# 交易训练工具 # 用于从交易记录CSV文件中随机选择交易,显示部分K线图让用户判断是否开仓,然后显示完整结果并记录 from jqdata import * import pandas as pd import numpy as np import matplotlib.pyplot as plt import matplotlib.patches as patches from datetime import datetime, timedelta, date import re import os import random import warnings warnings.filterwarnings('ignore') # 中文字体设置 plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans'] plt.rcParams['axes.unicode_minus'] = False # ========== 参数配置区域(硬编码参数都在这里) ========== CONFIG = { # CSV文件名 'csv_filename': 'transaction1.csv', # 结果记录文件名 'result_filename': 'training_results.csv', # 历史数据天数 'history_days': 100, # 未来数据天数 'future_days': 20, # 输出目录 'output_dir': 'training_images', # 是否显示图片(在某些环境下可能需要设置为False) 'show_plots': True, # 图片DPI 'plot_dpi': 150, # 随机种子(设置为None表示不固定种子) 'random_seed': 42 } # ===================================================== def _get_current_directory(): """ 获取当前文件所在目录 """ try: current_dir = os.path.dirname(os.path.abspath(__file__)) except NameError: current_dir = os.getcwd() if not os.path.exists(os.path.join(current_dir, CONFIG['csv_filename'])): parent_dir = os.path.dirname(current_dir) future_dir = os.path.join(parent_dir, 'future') if os.path.exists(os.path.join(future_dir, CONFIG['csv_filename'])): current_dir = future_dir return current_dir def read_transaction_data(csv_path): """ 读取交易记录CSV文件 """ encodings = ['utf-8-sig', 'utf-8', 'gbk', 'gb2312', 'gb18030', 'latin1'] for encoding in encodings: try: df = pd.read_csv(csv_path, encoding=encoding) print(f"成功使用 {encoding} 编码读取CSV文件,共 {len(df)} 条记录") return df except UnicodeDecodeError: continue except Exception as e: if encoding == encodings[-1]: print(f"读取CSV文件时出错: {str(e)}") raise return pd.DataFrame() def extract_contract_info(row): """ 从交易记录中提取合约信息和交易信息 """ try: # 提取合约编号 target_str = str(row['标的']) match = re.search(r'\(([^)]+)\)', target_str) if match: contract_code = match.group(1) else: return None, None, None, None, None, None, None, None, None # 提取日期 date_str = str(row['日期']).strip() date_formats = ['%Y-%m-%d', '%d/%m/%Y', '%Y/%m/%d', '%d-%m-%Y', '%Y%m%d'] trade_date = None for date_format in date_formats: try: trade_date = datetime.strptime(date_str, date_format).date() break except ValueError: continue if trade_date is None: return None, None, None, None, None, None, None, None, None # 提取委托时间 order_time_str = str(row['委托时间']).strip() try: time_parts = order_time_str.split(':') hour = int(time_parts[0]) # 如果委托时间 >= 21:00,需要找到下一个交易日 if hour >= 21: try: trade_days = get_trade_days(start_date=trade_date, count=2) if len(trade_days) >= 2: next_trade_day = trade_days[1] if isinstance(next_trade_day, datetime): actual_trade_date = next_trade_day.date() elif isinstance(next_trade_day, date): actual_trade_date = next_trade_day else: actual_trade_date = trade_date else: actual_trade_date = trade_date except: actual_trade_date = trade_date else: actual_trade_date = trade_date except: actual_trade_date = trade_date # 提取成交价 try: trade_price = float(row['成交价']) except: return None, None, None, None, None, None, None, None, None # 提取交易方向和类型 trade_type = str(row['交易类型']).strip() if '开多' in trade_type: direction = 'long' action = 'open' elif '开空' in trade_type: direction = 'short' action = 'open' elif '平多' in trade_type or '平空' in trade_type: action = 'close' direction = 'long' if '平多' in trade_type else 'short' else: return None, None, None, None, None, None, None, None, None # 提取交易对ID和连续交易对ID trade_pair_id = row.get('交易对ID', 'N/A') continuous_pair_id = row.get('连续交易对ID', 'N/A') return contract_code, actual_trade_date, trade_price, direction, action, order_time_str, trade_type, trade_pair_id, continuous_pair_id except Exception as e: print(f"提取合约信息时出错: {str(e)}") return None, None, None, None, None, None, None, None, None def get_trade_day_range(trade_date, days_before, days_after): """ 获取交易日范围 """ try: # 获取历史交易日 trade_days_before = get_trade_days(end_date=trade_date, count=days_before + 1) if len(trade_days_before) < days_before + 1: return None, None first_day = trade_days_before[0] if isinstance(first_day, datetime): start_date = first_day.date() elif isinstance(first_day, date): start_date = first_day else: start_date = first_day # 获取未来交易日 trade_days_after = get_trade_days(start_date=trade_date, count=days_after + 1) if len(trade_days_after) < days_after + 1: return None, None last_day = trade_days_after[-1] if isinstance(last_day, datetime): end_date = last_day.date() elif isinstance(last_day, date): end_date = last_day else: end_date = last_day return start_date, end_date except Exception as e: print(f"计算交易日范围时出错: {str(e)}") return None, None def get_kline_data_with_future(contract_code, trade_date, days_before=100, days_after=20): """ 获取包含历史和未来数据的K线数据 """ try: # 获取完整的数据范围 start_date, end_date = get_trade_day_range(trade_date, days_before, days_after) if start_date is None or end_date is None: return None, None, None # 获取K线数据 price_data = get_price( contract_code, start_date=start_date, end_date=end_date, frequency='1d', fields=['open', 'close', 'high', 'low'] ) if price_data is None or len(price_data) == 0: return None, None, None # 计算均线 price_data['ma5'] = price_data['close'].rolling(window=5).mean() price_data['ma10'] = price_data['close'].rolling(window=10).mean() price_data['ma20'] = price_data['close'].rolling(window=20).mean() price_data['ma30'] = price_data['close'].rolling(window=30).mean() # 找到交易日在数据中的位置 trade_date_normalized = pd.Timestamp(trade_date) trade_idx = None for i, idx in enumerate(price_data.index): if isinstance(idx, pd.Timestamp): if idx.date() == trade_date: trade_idx = i break return price_data, trade_idx except Exception as e: print(f"获取K线数据时出错: {str(e)}") return None, None, None def plot_partial_kline(data, trade_idx, trade_price, direction, contract_code, trade_date, order_time, save_path=None): # contract_code 参数保留用于可能的扩展功能 """ 绘制部分K线图(仅显示历史数据和当天) """ try: # 截取历史数据和当天数据 partial_data = data.iloc[:trade_idx + 1].copy() # 根据交易方向修改当天的价格数据 if direction == 'long': # 做多时,用成交价替代最高价(表示买入点) partial_data.iloc[-1, partial_data.columns.get_loc('close')] = trade_price partial_data.iloc[-1, partial_data.columns.get_loc('high')] = trade_price else: # 做空时,用成交价替代最低价(表示卖出点) partial_data.iloc[-1, partial_data.columns.get_loc('close')] = trade_price partial_data.iloc[-1, partial_data.columns.get_loc('low')] = trade_price fig, ax = plt.subplots(figsize=(16, 10)) # 准备数据 dates = partial_data.index opens = partial_data['open'] highs = partial_data['high'] lows = partial_data['low'] closes = partial_data['close'] # 绘制K线 for i in range(len(partial_data)): # 检查是否是交易日 is_trade_day = (i == trade_idx) if is_trade_day: # 成交日根据涨跌用不同颜色 if closes.iloc[i] > opens.iloc[i]: # 上涨 color = '#FFD700' # 金黄色(黄红色混合) edge_color = '#FF8C00' # 深橙色 else: # 下跌 color = '#ADFF2F' # 黄绿色 edge_color = '#9ACD32' # 黄绿色深版 else: # 正常K线颜色 color = 'red' if closes.iloc[i] > opens.iloc[i] else 'green' edge_color = 'darkred' if closes.iloc[i] > opens.iloc[i] else 'darkgreen' # 影线 ax.plot([i, i], [lows.iloc[i], highs.iloc[i]], color='black', linewidth=1) # 实体 body_height = abs(closes.iloc[i] - opens.iloc[i]) if body_height == 0: body_height = 0.01 bottom = min(opens.iloc[i], closes.iloc[i]) rect = patches.Rectangle((i-0.4, bottom), 0.8, body_height, linewidth=1, edgecolor=edge_color, facecolor=color, alpha=0.8) ax.add_patch(rect) # 绘制均线 ax.plot(range(len(partial_data)), partial_data['ma5'], label='MA5', color='blue', linewidth=1.5, alpha=0.8) ax.plot(range(len(partial_data)), partial_data['ma10'], label='MA10', color='orange', linewidth=1.5, alpha=0.8) ax.plot(range(len(partial_data)), partial_data['ma20'], label='MA20', color='purple', linewidth=1.5, alpha=0.8) ax.plot(range(len(partial_data)), partial_data['ma30'], label='MA30', color='brown', linewidth=1.5, alpha=0.8) # 获取当天的最高价(用于画连接线) day_high = highs.iloc[trade_idx] # 标注信息 date_label = trade_date.strftime('%Y-%m-%d') price_label = f'Price: {trade_price:.2f}' direction_label = f'Direction: {"Long" if direction == "long" else "Short"}' time_label = f'Time: {order_time}' # 将文本框移到左上角 annotation_text = f'{date_label}\n{price_label}\n{direction_label}\n{time_label}' text_box = ax.text(0.02, 0.98, annotation_text, fontsize=10, ha='left', va='top', transform=ax.transAxes, bbox=dict(boxstyle='round,pad=0.6', facecolor='yellow', alpha=0.9, edgecolor='black', linewidth=1.5), zorder=11, weight='bold') # 画黄色虚线连接文本框底部和交易日最高价 # 获取文本框在数据坐标系中的位置 fig.canvas.draw() # 需要先绘制一次才能获取准确位置 bbox = text_box.get_window_extent().transformed(ax.transData.inverted()) text_bottom_y = bbox.ymin # 从文本框底部到交易日最高价画虚线 ax.plot([trade_idx, trade_idx], [day_high, text_bottom_y], color='yellow', linestyle='--', linewidth=1.5, alpha=0.7, zorder=5) # 设置标题和标签 direction_text = "Long" if direction == "long" else "Short" ax.set_title(f'{direction_text} Position Decision\n' f'Historical Data + Trade Day Only', fontsize=14, fontweight='bold', pad=20) ax.set_xlabel('Time', fontsize=12) ax.set_ylabel('Price', fontsize=12) ax.grid(True, alpha=0.3) ax.legend(loc='lower left', fontsize=10) # 设置x轴标签 step = max(1, len(partial_data) // 10) tick_positions = range(0, len(partial_data), step) tick_labels = [] for pos in tick_positions: date_val = dates[pos] if isinstance(date_val, (date, datetime)): tick_labels.append(date_val.strftime('%Y-%m-%d')) else: tick_labels.append(str(date_val)) ax.set_xticks(tick_positions) ax.set_xticklabels(tick_labels, rotation=45, ha='right') plt.tight_layout() if save_path: plt.savefig(save_path, dpi=CONFIG['plot_dpi'], bbox_inches='tight') if CONFIG['show_plots']: plt.show() plt.close(fig) except Exception as e: print(f"绘制部分K线图时出错: {str(e)}") plt.close('all') raise def plot_full_kline(data, trade_idx, trade_price, direction, contract_code, trade_date, order_time, profit_loss, save_path=None): """ 绘制完整K线图(包含未来数据) """ try: fig, ax = plt.subplots(figsize=(16, 10)) # 准备数据 dates = data.index opens = data['open'] highs = data['high'] lows = data['low'] closes = data['close'] # 绘制K线 for i in range(len(data)): # 检查是否是交易日 is_trade_day = (i == trade_idx) if is_trade_day: # 成交日根据涨跌用不同颜色 if closes.iloc[i] > opens.iloc[i]: # 上涨 color = '#FFD700' # 金黄色(黄红色混合) edge_color = '#FF8C00' # 深橙色 else: # 下跌 color = '#ADFF2F' # 黄绿色 edge_color = '#9ACD32' # 黄绿色深版 else: # 正常K线颜色 color = 'red' if closes.iloc[i] > opens.iloc[i] else 'green' edge_color = 'darkred' if closes.iloc[i] > opens.iloc[i] else 'darkgreen' # 影线 ax.plot([i, i], [lows.iloc[i], highs.iloc[i]], color='black', linewidth=1) # 实体 body_height = abs(closes.iloc[i] - opens.iloc[i]) if body_height == 0: body_height = 0.01 bottom = min(opens.iloc[i], closes.iloc[i]) rect = patches.Rectangle((i-0.4, bottom), 0.8, body_height, linewidth=1, edgecolor=edge_color, facecolor=color, alpha=0.8) ax.add_patch(rect) # 绘制均线 ax.plot(range(len(data)), data['ma5'], label='MA5', color='blue', linewidth=1.5, alpha=0.8) ax.plot(range(len(data)), data['ma10'], label='MA10', color='orange', linewidth=1.5, alpha=0.8) ax.plot(range(len(data)), data['ma20'], label='MA20', color='purple', linewidth=1.5, alpha=0.8) ax.plot(range(len(data)), data['ma30'], label='MA30', color='brown', linewidth=1.5, alpha=0.8) # 获取当天的最高价(用于画连接线) day_high = highs.iloc[trade_idx] # 添加未来区域背景 ax.axvspan(trade_idx + 0.5, len(data) - 0.5, alpha=0.1, color='gray', label='Future Data') # 标注信息 date_label = trade_date.strftime('%Y-%m-%d') price_label = f'Price: {trade_price:.2f}' direction_label = f'Direction: {"Long" if direction == "long" else "Short"}' time_label = f'Time: {order_time}' profit_label = f'P&L: {profit_loss:+.2f}' # 将文本框移到左上角 annotation_text = f'{date_label}\n{price_label}\n{direction_label}\n{time_label}\n{profit_label}' text_box = ax.text(0.02, 0.98, annotation_text, fontsize=10, ha='left', va='top', transform=ax.transAxes, bbox=dict(boxstyle='round,pad=0.6', facecolor='yellow', alpha=0.9, edgecolor='black', linewidth=1.5), zorder=11, weight='bold') # 画黄色虚线连接文本框底部和交易日最高价 # 获取文本框在数据坐标系中的位置 fig.canvas.draw() # 需要先绘制一次才能获取准确位置 bbox = text_box.get_window_extent().transformed(ax.transData.inverted()) text_bottom_y = bbox.ymin # 从文本框底部到交易日最高价画虚线 ax.plot([trade_idx, trade_idx], [day_high, text_bottom_y], color='yellow', linestyle='--', linewidth=1.5, alpha=0.7, zorder=5) # 设置标题和标签 contract_simple = contract_code.split('.')[0] direction_text = "Long" if direction == "long" else "Short" ax.set_title(f'{contract_simple} - {direction_text} Position Result\n' f'Complete Data with Future {CONFIG["future_days"]} Days', fontsize=14, fontweight='bold', pad=20) ax.set_xlabel('Time', fontsize=12) ax.set_ylabel('Price', fontsize=12) ax.grid(True, alpha=0.3) ax.legend(loc='lower left', fontsize=10) # 设置x轴标签 step = max(1, len(data) // 15) tick_positions = range(0, len(data), step) tick_labels = [] for pos in tick_positions: date_val = dates[pos] if isinstance(date_val, (date, datetime)): tick_labels.append(date_val.strftime('%Y-%m-%d')) else: tick_labels.append(str(date_val)) ax.set_xticks(tick_positions) ax.set_xticklabels(tick_labels, rotation=45, ha='right') plt.tight_layout() if save_path: plt.savefig(save_path, dpi=CONFIG['plot_dpi'], bbox_inches='tight') if CONFIG['show_plots']: plt.show() plt.close(fig) except Exception as e: print(f"绘制完整K线图时出错: {str(e)}") plt.close('all') raise def load_processed_results(result_path): """ 加载已处理的结果文件 """ if not os.path.exists(result_path): return pd.DataFrame(), set() try: # 简单读取CSV文件 df = pd.read_csv(result_path, header=0) # 确保必要的列存在 required_columns = ['交易对ID'] for col in required_columns: if col not in df.columns: print(f"警告:结果文件缺少必要列 '{col}'") return pd.DataFrame(), set() # 获取已处理的交易对ID processed_pairs = set(df['交易对ID'].dropna().unique()) return df, processed_pairs except Exception as e: # 详细打印错误信息 print(f"加载结果文件时出错: {str(e)}") print(f"错误类型: {type(e)}") return pd.DataFrame(), set() def calculate_profit_loss(df, trade_pair_id, continuous_pair_id): """ 计算平仓盈亏 """ try: if continuous_pair_id != 'N/A' and pd.notna(continuous_pair_id): # 合并所有同一连续交易对ID的平仓盈亏 close_trades = df[ (df['连续交易对ID'] == continuous_pair_id) & (df['交易类型'].str[0] == '平') ] total_profit = close_trades['平仓盈亏'].sum() else: # 只查找当前交易对ID的平仓交易 close_trades = df[ (df['交易对ID'] == trade_pair_id) & (df['交易类型'].str[0] == '平') ] if len(close_trades) > 0: total_profit = close_trades['平仓盈亏'].iloc[0] else: total_profit = 0 return total_profit except Exception as e: print(f"计算盈亏时出错: {str(e)}") return 0 def record_result(result_data, result_path): """ 记录训练结果 """ try: # 创建结果DataFrame result_df = pd.DataFrame([result_data]) # 如果文件已存在,读取现有格式并确保新数据格式一致 if os.path.exists(result_path): try: # 读取现有文件的列名 existing_df = pd.read_csv(result_path, nrows=0) # 只读取列名 existing_columns = existing_df.columns.tolist() # 如果新数据列与现有文件不一致,调整格式 if list(result_df.columns) != existing_columns: # 重新创建DataFrame,确保列顺序一致 aligned_data = {} for col in existing_columns: aligned_data[col] = result_data.get(col, 'N/A' if col == '连续交易总盈亏' else '') result_df = pd.DataFrame([aligned_data]) # 追加写入 result_df.to_csv(result_path, mode='a', header=False, index=False, encoding='utf-8-sig') except Exception: # 如果无法读取现有格式,直接覆盖 result_df.to_csv(result_path, mode='w', header=True, index=False, encoding='utf-8-sig') else: # 文件不存在,创建新文件 result_df.to_csv(result_path, mode='w', header=True, index=False, encoding='utf-8-sig') print(f"结果已记录到: {result_path}") except Exception as e: print(f"记录结果时出错: {str(e)}") def is_first_continuous_trade(transaction_df, trade_pair_id, continuous_pair_id): """ 判断是否为连续交易的第一笔交易 参数: transaction_df: 交易数据DataFrame trade_pair_id: 当前交易对ID continuous_pair_id: 连续交易对ID 返回: bool: 是否为连续交易的第一笔交易(或不是连续交易) """ # 如果不是连续交易,返回True if continuous_pair_id == 'N/A' or pd.isna(continuous_pair_id): return True # 获取同一连续交易组的所有交易 continuous_trades = transaction_df[transaction_df['连续交易对ID'] == continuous_pair_id] # 获取所有交易对ID并按时间排序 pair_ids = continuous_trades['交易对ID'].unique() # 获取每个交易对的开仓时间 pair_times = [] for pid in pair_ids: pair_records = continuous_trades[continuous_trades['交易对ID'] == pid] open_records = pair_records[pair_records['交易类型'].str.contains('开', na=False)] if len(open_records) > 0: # 获取第一个开仓记录的日期和时间 first_open = open_records.iloc[0] date_str = str(first_open['日期']).strip() time_str = str(first_open['委托时间']).strip() try: dt = pd.to_datetime(f"{date_str} {time_str}") pair_times.append((pid, dt)) except: pass # 按时间排序 pair_times.sort(key=lambda x: x[1]) # 检查当前交易对是否为第一个 if pair_times and pair_times[0][0] == trade_pair_id: return True return False def get_user_decision(): """ 获取用户的开仓决策和信心指数 返回: tuple: (是否开仓, 信心指数) - 是否开仓: bool - 信心指数: int (1-3) """ while True: decision = input("\n是否开仓?请输入 'y,信心指数' (开仓) 或 'n,信心指数' (不开仓)\n" + "例如: 'y,3' (开仓,高信心) 或 'n,1' (不开仓,低信心)\n" + "信心指数: 1=低, 2=中, 3=高 (默认为2): ").strip().lower() # 解析输入 parts = decision.split(',') decision_part = parts[0].strip() confidence = 2 # 默认信心指数 # 检查是否提供了信心指数 if len(parts) >= 2: try: confidence = int(parts[1].strip()) if confidence not in [1, 2, 3]: print("信心指数必须是 1、2 或 3,请重新输入") continue except ValueError: print("信心指数必须是数字 1、2 或 3,请重新输入") continue # 检查开仓决策 if decision_part in ['y', 'yes', '是', '开仓']: return True, confidence elif decision_part in ['n', 'no', '否', '不开仓']: return False, confidence else: print("请输入有效的选项: 'y' 或 'n' (可选择性添加信心指数,如 'y,3')") def main(): """ 主函数 """ print("=" * 60) print("交易训练工具") print("=" * 60) # 设置随机种子 if CONFIG['random_seed'] is not None: random.seed(CONFIG['random_seed']) np.random.seed(CONFIG['random_seed']) # 获取当前目录 current_dir = _get_current_directory() csv_path = os.path.join(current_dir, CONFIG['csv_filename']) result_path = os.path.join(current_dir, CONFIG['result_filename']) output_dir = os.path.join(current_dir, CONFIG['output_dir']) # 创建输出目录 os.makedirs(output_dir, exist_ok=True) # 1. 读取交易数据 print("\n=== 步骤1: 读取交易数据 ===") transaction_df = read_transaction_data(csv_path) if len(transaction_df) == 0: print("未能读取交易数据,退出") return # 2. 加载已处理的结果 print("\n=== 步骤2: 加载已处理记录 ===") _, processed_pairs = load_processed_results(result_path) print(f"已处理 {len(processed_pairs)} 个交易对") # 3. 提取所有开仓交易 print("\n=== 步骤3: 提取开仓交易 ===") open_trades = [] for idx, row in transaction_df.iterrows(): contract_code, trade_date, trade_price, direction, action, order_time, trade_type, trade_pair_id, continuous_pair_id = extract_contract_info(row) if contract_code is None or action != 'open': continue # 跳过已处理的交易对 if trade_pair_id in processed_pairs: continue # 检查是否为连续交易的第一笔交易(如果不是第一笔,跳过) if not is_first_continuous_trade(transaction_df, trade_pair_id, continuous_pair_id): continue # 查找对应的平仓交易 profit_loss = calculate_profit_loss(transaction_df, trade_pair_id, continuous_pair_id) # 如果是连续交易,获取连续交易总盈亏 continuous_total_profit = 'N/A' if continuous_pair_id != 'N/A' and pd.notna(continuous_pair_id): continuous_trades = transaction_df[transaction_df['连续交易对ID'] == continuous_pair_id] try: close_profit_loss_str = continuous_trades['平仓盈亏'].astype(str).str.replace(',', '') close_profit_loss_numeric = pd.to_numeric(close_profit_loss_str, errors='coerce').fillna(0) continuous_total_profit = close_profit_loss_numeric.sum() except: continuous_total_profit = 0 open_trades.append({ 'index': idx, 'contract_code': contract_code, 'trade_date': trade_date, 'trade_price': trade_price, 'direction': direction, 'order_time': order_time, 'trade_type': trade_type, 'trade_pair_id': trade_pair_id, 'continuous_pair_id': continuous_pair_id, 'profit_loss': profit_loss, 'continuous_total_profit': continuous_total_profit, 'original_row': row }) print(f"找到 {len(open_trades)} 个未处理的开仓交易(已过滤非首笔连续交易)") if len(open_trades) == 0: print("没有未处理的开仓交易,退出") return # 4. 随机选择一个交易(按标的类型分组随机抽取,避免同类连续出现) print("\n=== 步骤4: 随机选择交易 ===") # 按标的类型分组(提取合约代码的核心字母部分) def get_contract_type(contract_code): """提取合约类型,如'M2405'提取为'M','AG2406'提取为'AG'""" import re match = re.match(r'^([A-Za-z]+)', contract_code.split('.')[0]) return match.group(1) if match else 'UNKNOWN' # 按合约类型分组 trades_by_type = {} for trade in open_trades: contract_type = get_contract_type(trade['contract_code']) if contract_type not in trades_by_type: trades_by_type[contract_type] = [] trades_by_type[contract_type].append(trade) # 打乱每个组内的顺序 for contract_type in trades_by_type: random.shuffle(trades_by_type[contract_type]) # 从各组中轮流抽取,确保类型分散 selected_trade = None available_types = list(trades_by_type.keys()) # 随机打乱类型顺序,然后从第一个有交易的类型中抽取 random.shuffle(available_types) for contract_type in available_types: if trades_by_type[contract_type]: selected_trade = trades_by_type[contract_type].pop(0) break if selected_trade is None: # 如果上述方法失败,回退到简单随机选择 random.shuffle(open_trades) selected_trade = random.choice(open_trades) # print(f"选中交易: {selected_trade['contract_code']} - {selected_trade['trade_date']} - {selected_trade['direction']}") print(f"剩余未处理交易: {len(open_trades) - 1} 个") # 5. 获取K线数据 print("\n=== 步骤5: 获取K线数据 ===") kline_data, trade_idx = get_kline_data_with_future( selected_trade['contract_code'], selected_trade['trade_date'], CONFIG['history_days'], CONFIG['future_days'] ) if kline_data is None or trade_idx is None: print("获取K线数据失败,退出") return # 6. 显示部分K线图 print("\n=== 步骤6: 显示部分K线图 ===") partial_image_name = f"partial_{selected_trade['contract_code']}_{selected_trade['trade_date']}_{selected_trade['direction']}.png" partial_image_path = os.path.join(output_dir, partial_image_name) plot_partial_kline( kline_data, trade_idx, selected_trade['trade_price'], selected_trade['direction'], selected_trade['contract_code'], selected_trade['trade_date'], selected_trade['order_time'], partial_image_path ) # 7. 获取用户决策和信心指数 user_decision, confidence_level = get_user_decision() # 8. 显示完整K线图 print("\n=== 步骤7: 显示完整K线图 ===") full_image_name = f"full_{selected_trade['contract_code']}_{selected_trade['trade_date']}_{selected_trade['direction']}.png" full_image_path = os.path.join(output_dir, full_image_name) plot_full_kline( kline_data, trade_idx, selected_trade['trade_price'], selected_trade['direction'], selected_trade['contract_code'], selected_trade['trade_date'], selected_trade['order_time'], selected_trade['profit_loss'], full_image_path ) # 在完整K线图之后显示交易信息 print(f"\n交易信息:") print(f"合约: {selected_trade['contract_code']}") print(f"日期: {selected_trade['trade_date']}") print(f"方向: {'多头' if selected_trade['direction'] == 'long' else '空头'}") print(f"成交价: {selected_trade['trade_price']}") # 9. 记录结果 print("\n=== 步骤8: 记录结果 ===") # 计算判定收益(使用连续交易总盈亏或普通盈亏) if selected_trade['continuous_total_profit'] != 'N/A': # 连续交易使用连续交易总盈亏 decision_profit = selected_trade['continuous_total_profit'] if user_decision else -selected_trade['continuous_total_profit'] profit_to_show = selected_trade['continuous_total_profit'] else: # 普通交易使用单笔盈亏 decision_profit = selected_trade['profit_loss'] if user_decision else -selected_trade['profit_loss'] profit_to_show = selected_trade['profit_loss'] result_data = { '日期': selected_trade['original_row']['日期'], '委托时间': selected_trade['original_row']['委托时间'], '标的': selected_trade['original_row']['标的'], '交易类型': selected_trade['original_row']['交易类型'], '成交数量': selected_trade['original_row']['成交数量'], '成交价': selected_trade['original_row']['成交价'], '平仓盈亏': selected_trade['profit_loss'], '用户判定': '开仓' if user_decision else '不开仓', '信心指数': confidence_level, '判定收益': decision_profit, '交易对ID': selected_trade['trade_pair_id'], '连续交易对ID': selected_trade['continuous_pair_id'], '连续交易总盈亏': selected_trade['continuous_total_profit'] } record_result(result_data, result_path) print(f"\n=== 训练完成 ===") print(f"用户判定: {'开仓' if user_decision else '不开仓'}") print(f"信心指数: {confidence_level} ({'低' if confidence_level == 1 else '中' if confidence_level == 2 else '高'})") if selected_trade['continuous_total_profit'] != 'N/A': print(f"连续交易总盈亏: {profit_to_show:+.2f}") else: print(f"实际盈亏: {profit_to_show:+.2f}") print(f"判定收益: {decision_profit:+.2f}") print(f"结果已保存到: {result_path}") if __name__ == "__main__": main()