# Trading training tool.
# Randomly picks an opening trade from a trade-record CSV, shows a candlestick chart
# truncated at the trade day so the user can decide whether to open the position,
# then reveals the full chart (with the following days) and records the result.

# The jqdata wildcard import supplies get_trade_days / get_price. It stays first so
# that the explicit imports below take precedence over any overlapping names it exports.
from jqdata import *
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from datetime import datetime, timedelta, date
import re
import os
import random
import warnings

warnings.filterwarnings('ignore')

# Font setup so CJK text renders in matplotlib.
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

# ========== Configuration (all hard-coded parameters live here) ==========
CONFIG = {
    # Trade-record CSV file name
    'csv_filename': 'transaction1.csv',
    # Training-result CSV file name
    'result_filename': 'training_results.csv',
    # Number of trading days of history before the trade day
    'history_days': 100,
    # Number of trading days revealed after the trade day
    'future_days': 20,
    # Output directory for chart images
    'output_dir': 'training_images',
    # Whether to display charts (may need to be False in some environments)
    'show_plots': True,
    # Chart DPI
    'plot_dpi': 150,
    # Random seed (None means the seed is not fixed)
    'random_seed': 42,
}
# =========================================================================

# Returned by extract_contract_info when a row cannot be used.
_NO_TRADE_INFO = (None,) * 9
# Date formats accepted in the '日期' column, tried in this order.
_DATE_FORMATS = ('%Y-%m-%d', '%d/%m/%Y', '%Y/%m/%d', '%d-%m-%Y', '%Y%m%d')
# The contract code is the text inside parentheses in the '标的' column.
_CONTRACT_PATTERN = re.compile(r'\(([^)]+)\)')
# Moving-average windows and their plot colours.
_MA_STYLES = ((5, 'blue'), (10, 'orange'), (20, 'purple'), (30, 'brown'))


def _get_current_directory():
    """Return the directory that holds the trade CSV.

    Starts from this file's directory, or the working directory when ``__file__`` is
    undefined (e.g. interactive use). If the CSV is not there but exists in a sibling
    ``future`` directory, that directory is returned instead.
    """
    try:
        current_dir = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        current_dir = os.getcwd()

    if not os.path.exists(os.path.join(current_dir, CONFIG['csv_filename'])):
        future_dir = os.path.join(os.path.dirname(current_dir), 'future')
        if os.path.exists(os.path.join(future_dir, CONFIG['csv_filename'])):
            current_dir = future_dir
    return current_dir


def read_transaction_data(csv_path):
    """Read the trade-record CSV, trying several encodings in turn.

    Returns the DataFrame from the first encoding that succeeds. Errors from earlier
    encodings just move on to the next one; an error other than a decode error on the
    last encoding is printed and re-raised. Returns an empty DataFrame if no encoding
    works.
    """
    encodings = ['utf-8-sig', 'utf-8', 'gbk', 'gb2312', 'gb18030', 'latin1']
    for encoding in encodings:
        try:
            df = pd.read_csv(csv_path, encoding=encoding)
        except UnicodeDecodeError:
            continue
        except Exception as e:
            if encoding == encodings[-1]:
                print(f"读取CSV文件时出错: {str(e)}")
                raise
            continue
        print(f"成功使用 {encoding} 编码读取CSV文件,共 {len(df)} 条记录")
        return df
    return pd.DataFrame()


def _to_date(value):
    """Convert a ``datetime`` (including ``pd.Timestamp``) to a ``date``.

    ``date`` values and any other types are returned unchanged.
    """
    if isinstance(value, datetime):
        return value.date()
    return value


def _trade_days_from(start_date, count):
    """Return up to ``count`` trading days on or after ``start_date``, as dates.

    Queries an explicit calendar window instead of passing ``start_date`` together
    with ``count``, which jqdata documents as alternative parameters. The window of
    ``2 * count + 15`` calendar days is meant to cover weekends and exchange holidays.
    """
    window_end = start_date + timedelta(days=2 * count + 15)
    days = get_trade_days(start_date=start_date, end_date=window_end)
    return [_to_date(d) for d in list(days)[:count]]


def _next_trade_day(day):
    """Return the first trading day strictly after ``day``.

    Returns None if no trading day is found within the next 20 calendar days.
    """
    following = get_trade_days(start_date=day + timedelta(days=1),
                               end_date=day + timedelta(days=20))
    if len(following) == 0:
        return None
    first = _to_date(following[0])
    return first if isinstance(first, date) else None


def _parse_trade_date(date_str):
    """Parse ``date_str`` with the first matching format in ``_DATE_FORMATS``; None if none match."""
    for fmt in _DATE_FORMATS:
        try:
            return datetime.strptime(date_str, fmt).date()
        except ValueError:
            continue
    return None


def _resolve_session_date(trade_date, order_time_str):
    """Return the trading day an order belongs to.

    Orders timed at or after 21:00 are attributed to the next trading day after
    ``trade_date``. An unparseable time or a failed lookup falls back to ``trade_date``.
    """
    try:
        hour = int(order_time_str.split(':')[0])
    except ValueError:
        return trade_date
    if hour < 21:
        return trade_date
    try:
        next_day = _next_trade_day(trade_date)
    except Exception:
        return trade_date
    return next_day if next_day is not None else trade_date


def extract_contract_info(row):
    """Extract contract and trade details from one trade-record row.

    Returns the 9-tuple ``(contract_code, actual_trade_date, trade_price, direction,
    action, order_time_str, trade_type, trade_pair_id, continuous_pair_id)``.
    ``direction`` is 'long'/'short' and ``action`` is 'open'/'close'. A tuple of nine
    ``None`` values is returned when the row cannot be parsed.
    """
    try:
        match = _CONTRACT_PATTERN.search(str(row['标的']))
        if not match:
            return _NO_TRADE_INFO
        contract_code = match.group(1)

        trade_date = _parse_trade_date(str(row['日期']).strip())
        if trade_date is None:
            return _NO_TRADE_INFO

        order_time_str = str(row['委托时间']).strip()
        actual_trade_date = _resolve_session_date(trade_date, order_time_str)

        try:
            trade_price = float(row['成交价'])
        except (KeyError, TypeError, ValueError):
            return _NO_TRADE_INFO

        trade_type = str(row['交易类型']).strip()
        if '开多' in trade_type:
            direction, action = 'long', 'open'
        elif '开空' in trade_type:
            direction, action = 'short', 'open'
        elif '平多' in trade_type:
            direction, action = 'long', 'close'
        elif '平空' in trade_type:
            direction, action = 'short', 'close'
        else:
            return _NO_TRADE_INFO

        trade_pair_id = row.get('交易对ID', 'N/A')
        continuous_pair_id = row.get('连续交易对ID', 'N/A')
        return (contract_code, actual_trade_date, trade_price, direction, action,
                order_time_str, trade_type, trade_pair_id, continuous_pair_id)
    except Exception as e:
        print(f"提取合约信息时出错: {str(e)}")
        return _NO_TRADE_INFO


def get_trade_day_range(trade_date, days_before, days_after):
    """Return the date span for a chart centred on ``trade_date``.

    The result is ``(start_date, end_date)``, where ``start_date`` is ``days_before``
    trading days before ``trade_date`` and ``end_date`` is ``days_after`` trading days
    after it; both windows count ``trade_date`` itself. Returns ``(None, None)`` when
    either side has too few trading days or a lookup fails.
    """
    try:
        days_back = get_trade_days(end_date=trade_date, count=days_before + 1)
        if len(days_back) < days_before + 1:
            return None, None

        days_ahead = _trade_days_from(trade_date, days_after + 1)
        if len(days_ahead) < days_after + 1:
            return None, None

        return _to_date(days_back[0]), days_ahead[-1]
    except Exception as e:
        print(f"计算交易日范围时出错: {str(e)}")
        return None, None


def _add_moving_averages(frame):
    """Add (or overwrite) the ``ma5``/``ma10``/``ma20``/``ma30`` close-price columns in place."""
    for window, _ in _MA_STYLES:
        frame[f'ma{window}'] = frame['close'].rolling(window=window).mean()


def get_kline_data_with_future(contract_code, trade_date, days_before=100, days_after=20):
    """Fetch daily bars around ``trade_date``, including the following days.

    Returns ``(price_data, trade_idx)``. ``price_data`` holds open/close/high/low plus
    the MA columns, and ``trade_idx`` is the row position of ``trade_date`` (None if
    no bar matches). On any failure, returns ``(None, None)``. Always returning a pair
    keeps the two-value unpacking in callers safe.
    """
    try:
        start_date, end_date = get_trade_day_range(trade_date, days_before, days_after)
        if start_date is None or end_date is None:
            return None, None

        price_data = get_price(
            contract_code,
            start_date=start_date,
            end_date=end_date,
            frequency='1d',
            fields=['open', 'close', 'high', 'low']
        )
        if price_data is None or len(price_data) == 0:
            return None, None

        _add_moving_averages(price_data)

        trade_idx = next(
            (pos for pos, ts in enumerate(price_data.index)
             if isinstance(ts, pd.Timestamp) and ts.date() == trade_date),
            None,
        )
        return price_data, trade_idx
    except Exception as e:
        print(f"获取K线数据时出错: {str(e)}")
        return None, None


def _draw_candles_and_mas(ax, data):
    """Draw candlesticks and the MA lines for ``data`` on ``ax``; bar i is plotted at x = i."""
    bars = zip(data['open'], data['high'], data['low'], data['close'])
    for pos, (op, hi, lo, cl) in enumerate(bars):
        rising = cl > op
        # Red for rising bars, green otherwise.
        face = 'red' if rising else 'green'
        edge = 'darkred' if rising else 'darkgreen'
        # Wick
        ax.plot([pos, pos], [lo, hi], color='black', linewidth=1)
        # Body; give flat bars a minimal height so they remain visible.
        body_height = abs(cl - op)
        if body_height == 0:
            body_height = 0.01
        rect = patches.Rectangle((pos - 0.4, min(op, cl)), 0.8, body_height,
                                 linewidth=1, edgecolor=edge,
                                 facecolor=face, alpha=0.8)
        ax.add_patch(rect)

    x_positions = range(len(data))
    for window, color in _MA_STYLES:
        ax.plot(x_positions, data[f'ma{window}'], label=f'MA{window}',
                color=color, linewidth=1.5, alpha=0.8)


def _mark_trade(ax, trade_idx, trade_price):
    """Mark the fill with a star and a dashed vertical line at the trade bar."""
    ax.plot(trade_idx, trade_price, marker='*', markersize=20, color='yellow',
            markeredgecolor='black', markeredgewidth=2, label='Open Position', zorder=10)
    ax.axvline(x=trade_idx, color='yellow', linestyle='--', linewidth=2, alpha=0.7, zorder=5)


def _trade_label_lines(trade_date, trade_price, direction, order_time):
    """Return the date, price, direction and time lines for the trade annotation box."""
    return [
        trade_date.strftime('%Y-%m-%d'),
        f'Price: {trade_price:.2f}',
        f'Direction: {"Long" if direction == "long" else "Short"}',
        f'Time: {order_time}',
    ]


def _place_trade_label(ax, trade_idx, trade_price, data, text):
    """Draw ``text`` in a box near the trade.

    The box is placed above the price, or below it when above would exceed the
    chart's highest high.
    """
    highest = data['high'].max()
    price_range = highest - data['low'].min()
    text_y = trade_price + max(price_range * 0.08, (highest - trade_price) * 0.3)
    if text_y > highest:
        text_y = trade_price - price_range * 0.08
    ax.text(trade_idx, text_y, text,
            fontsize=10, ha='center', va='bottom',
            bbox=dict(boxstyle='round,pad=0.6', facecolor='yellow', alpha=0.9,
                      edgecolor='black', linewidth=1.5),
            zorder=11, weight='bold')


def _style_axes(ax, title, dates, max_ticks):
    """Apply title, labels, grid, legend and date tick labels.

    Date ticks are placed every ``len(dates) // max_ticks`` bars (at least every bar).
    """
    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
    ax.set_xlabel('Time', fontsize=12)
    ax.set_ylabel('Price', fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.legend(loc='lower left', fontsize=10)

    step = max(1, len(dates) // max_ticks)
    tick_positions = range(0, len(dates), step)
    tick_labels = [
        dates[pos].strftime('%Y-%m-%d') if isinstance(dates[pos], (date, datetime))
        else str(dates[pos])
        for pos in tick_positions
    ]
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels, rotation=45, ha='right')


def _save_and_show(save_path):
    """Lay out the current figure, save it if ``save_path`` is given, and show it if configured."""
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=CONFIG['plot_dpi'], bbox_inches='tight')
    if CONFIG['show_plots']:
        plt.show()


def plot_partial_kline(data, trade_idx, trade_price, direction, contract_code, trade_date, order_time, save_path=None):
    """Plot the decision chart: history up to and including the trade day only.

    The trade day's close is replaced by the fill price. The moving averages are
    recomputed on this truncated series, so the day's real close does not leak
    through the MA values at the last bar. Errors are printed and re-raised.
    """
    try:
        partial_data = data.iloc[:trade_idx + 1].copy()
        partial_data.iloc[-1, partial_data.columns.get_loc('close')] = trade_price
        _add_moving_averages(partial_data)

        fig, ax = plt.subplots(figsize=(16, 10))
        _draw_candles_and_mas(ax, partial_data)
        _mark_trade(ax, trade_idx, trade_price)
        label_lines = _trade_label_lines(trade_date, trade_price, direction, order_time)
        _place_trade_label(ax, trade_idx, trade_price, partial_data, '\n'.join(label_lines))

        contract_simple = contract_code.split('.')[0]
        direction_text = "Long" if direction == "long" else "Short"
        _style_axes(ax,
                    f'{contract_simple} - {direction_text} Position Decision\n'
                    f'Historical Data + Trade Day Only',
                    partial_data.index, 10)
        _save_and_show(save_path)
        plt.close(fig)
    except Exception as e:
        print(f"绘制部分K线图时出错: {str(e)}")
        plt.close('all')
        raise


def plot_full_kline(data, trade_idx, trade_price, direction, contract_code, trade_date, order_time, profit_loss, save_path=None):
    """Plot the result chart: full history plus the days after the trade.

    The post-trade region is shaded and the annotation includes the realised P&L.
    Errors are printed and re-raised.
    """
    try:
        fig, ax = plt.subplots(figsize=(16, 10))
        _draw_candles_and_mas(ax, data)
        _mark_trade(ax, trade_idx, trade_price)
        # Shade the bars that the decision chart hid.
        ax.axvspan(trade_idx + 0.5, len(data) - 0.5, alpha=0.1, color='gray', label='Future Data')

        label_lines = _trade_label_lines(trade_date, trade_price, direction, order_time)
        label_lines.append(f'P&L: {profit_loss:+.2f}')
        _place_trade_label(ax, trade_idx, trade_price, data, '\n'.join(label_lines))

        contract_simple = contract_code.split('.')[0]
        direction_text = "Long" if direction == "long" else "Short"
        _style_axes(ax,
                    f'{contract_simple} - {direction_text} Position Result\n'
                    f'Complete Data with Future {CONFIG["future_days"]} Days',
                    data.index, 15)
        _save_and_show(save_path)
        plt.close(fig)
    except Exception as e:
        print(f"绘制完整K线图时出错: {str(e)}")
        plt.close('all')
        raise


def load_processed_results(result_path):
    """Load earlier training results.

    Returns ``(results_df, processed_pair_ids)``, where the set holds the '交易对ID'
    values already recorded. Both are empty when the file is missing or unreadable.
    """
    if not os.path.exists(result_path):
        return pd.DataFrame(), set()
    try:
        # Read with the same encoding record_result writes with.
        df = pd.read_csv(result_path, encoding='utf-8-sig')
        return df, set(df['交易对ID'].unique())
    except Exception as e:
        print(f"加载结果文件时出错: {str(e)}")
        return pd.DataFrame(), set()


def calculate_profit_loss(df, trade_pair_id, continuous_pair_id):
    """Return the realised P&L for a trade as a float.

    Sums '平仓盈亏' over the closing rows ('交易类型' starting with '平'). Rows are
    matched on the continuous trade-pair ID when one is present, otherwise on the
    trade-pair ID. Both branches sum every closing leg, so a pair closed in several
    fills is fully counted. Non-numeric P&L entries are ignored. Returns 0 on error.
    """
    try:
        is_close = df['交易类型'].astype(str).str.startswith('平')
        if continuous_pair_id != 'N/A' and pd.notna(continuous_pair_id):
            mask = (df['连续交易对ID'] == continuous_pair_id) & is_close
        else:
            mask = (df['交易对ID'] == trade_pair_id) & is_close
        # Coerce so string-typed columns are added numerically instead of concatenated.
        pnl = pd.to_numeric(df.loc[mask, '平仓盈亏'], errors='coerce')
        return float(pnl.sum())
    except Exception as e:
        print(f"计算盈亏时出错: {str(e)}")
        return 0


def record_result(result_data, result_path):
    """Append one result row to the result CSV.

    The header is written only when the file is new. Returns True on success and
    False if writing failed (the error is printed).
    """
    try:
        result_df = pd.DataFrame([result_data])
        if os.path.exists(result_path):
            result_df.to_csv(result_path, mode='a', header=False, index=False, encoding='utf-8-sig')
        else:
            result_df.to_csv(result_path, mode='w', header=True, index=False, encoding='utf-8-sig')
        print(f"结果已记录到: {result_path}")
        return True
    except Exception as e:
        print(f"记录结果时出错: {str(e)}")
        return False


def get_user_decision():
    """Prompt until the user answers; return True to open the position, False otherwise."""
    while True:
        decision = input("\n是否开仓?请输入 'y' (开仓) 或 'n' (不开仓): ").strip().lower()
        if decision in ['y', 'yes', '是', '开仓']:
            return True
        if decision in ['n', 'no', '否', '不开仓']:
            return False
        print("请输入有效的选项: 'y' 或 'n'")


def main():
    """Run one training round.

    Reads the trade log and picks a random opening trade whose trade-pair ID is not
    yet in the result file. It then shows the truncated chart, asks for a decision,
    shows the full chart, and appends the outcome to the result file.
    """
    print("=" * 60)
    print("交易训练工具")
    print("=" * 60)

    if CONFIG['random_seed'] is not None:
        random.seed(CONFIG['random_seed'])
        np.random.seed(CONFIG['random_seed'])

    current_dir = _get_current_directory()
    csv_path = os.path.join(current_dir, CONFIG['csv_filename'])
    result_path = os.path.join(current_dir, CONFIG['result_filename'])
    output_dir = os.path.join(current_dir, CONFIG['output_dir'])
    os.makedirs(output_dir, exist_ok=True)

    # 1. Read the trade records.
    print("\n=== 步骤1: 读取交易数据 ===")
    transaction_df = read_transaction_data(csv_path)
    if len(transaction_df) == 0:
        print("未能读取交易数据,退出")
        return

    # 2. Load the trade pairs already used in earlier rounds.
    print("\n=== 步骤2: 加载已处理记录 ===")
    _, processed_pairs = load_processed_results(result_path)
    print(f"已处理 {len(processed_pairs)} 个交易对")

    # 3. Collect opening trades that have not been used yet.
    print("\n=== 步骤3: 提取开仓交易 ===")
    open_trades = []
    for idx, row in transaction_df.iterrows():
        (contract_code, trade_date, trade_price, direction, action, order_time,
         trade_type, trade_pair_id, continuous_pair_id) = extract_contract_info(row)
        if contract_code is None or action != 'open' or trade_pair_id in processed_pairs:
            continue
        open_trades.append({
            'index': idx,
            'contract_code': contract_code,
            'trade_date': trade_date,
            'trade_price': trade_price,
            'direction': direction,
            'order_time': order_time,
            'trade_type': trade_type,
            'trade_pair_id': trade_pair_id,
            'continuous_pair_id': continuous_pair_id,
            'original_row': row,
        })

    print(f"找到 {len(open_trades)} 个未处理的开仓交易")
    if len(open_trades) == 0:
        print("没有未处理的开仓交易,退出")
        return

    # 4. Pick one trade at random.
    print("\n=== 步骤4: 随机选择交易 ===")
    random.shuffle(open_trades)
    selected_trade = random.choice(open_trades)
    # Only the selected trade needs its P&L, so compute it once here rather than
    # scanning the whole log for every candidate.
    profit_loss = calculate_profit_loss(
        transaction_df, selected_trade['trade_pair_id'], selected_trade['continuous_pair_id'])
    selected_trade['profit_loss'] = profit_loss
    print(f"选中交易: {selected_trade['contract_code']} - {selected_trade['trade_date']} - {selected_trade['direction']}")
    print(f"剩余未处理交易: {len(open_trades) - 1} 个")

    # 5. Fetch the bars.
    print("\n=== 步骤5: 获取K线数据 ===")
    kline_data, trade_idx = get_kline_data_with_future(
        selected_trade['contract_code'],
        selected_trade['trade_date'],
        CONFIG['history_days'],
        CONFIG['future_days']
    )
    if kline_data is None or trade_idx is None:
        print("获取K线数据失败,退出")
        return

    # 6. Show the truncated chart.
    print("\n=== 步骤6: 显示部分K线图 ===")
    partial_image_path = os.path.join(output_dir, f"partial_{selected_trade['trade_pair_id']}.png")
    plot_partial_kline(
        kline_data, trade_idx, selected_trade['trade_price'],
        selected_trade['direction'], selected_trade['contract_code'],
        selected_trade['trade_date'], selected_trade['order_time'],
        partial_image_path
    )

    # 7. Ask for the user's decision.
    print("\n交易信息:")
    print(f"合约: {selected_trade['contract_code']}")
    print(f"日期: {selected_trade['trade_date']}")
    print(f"方向: {'多头' if selected_trade['direction'] == 'long' else '空头'}")
    print(f"成交价: {selected_trade['trade_price']}")
    user_decision = get_user_decision()

    # 8. Reveal the full chart.
    print("\n=== 步骤7: 显示完整K线图 ===")
    full_image_path = os.path.join(output_dir, f"full_{selected_trade['trade_pair_id']}.png")
    plot_full_kline(
        kline_data, trade_idx, selected_trade['trade_price'],
        selected_trade['direction'], selected_trade['contract_code'],
        selected_trade['trade_date'], selected_trade['order_time'],
        profit_loss, full_image_path
    )

    # 9. Record the outcome. Declining a trade scores the negated actual P&L.
    print("\n=== 步骤8: 记录结果 ===")
    decision_profit = profit_loss if user_decision else -profit_loss
    original_row = selected_trade['original_row']
    result_data = {
        '日期': original_row['日期'],
        '委托时间': original_row['委托时间'],
        '标的': original_row['标的'],
        '交易类型': original_row['交易类型'],
        '成交数量': original_row.get('成交数量', 'N/A'),
        '成交价': original_row['成交价'],
        '平仓盈亏': profit_loss,
        '用户判定': '开仓' if user_decision else '不开仓',
        '判定收益': decision_profit,
        '交易对ID': selected_trade['trade_pair_id'],
        '连续交易对ID': selected_trade['continuous_pair_id'],
    }
    saved = record_result(result_data, result_path)

    print("\n=== 训练完成 ===")
    print(f"用户判定: {'开仓' if user_decision else '不开仓'}")
    print(f"实际盈亏: {profit_loss:+.2f}")
    print(f"判定收益: {decision_profit:+.2f}")
    if saved:
        print(f"结果已保存到: {result_path}")


if __name__ == "__main__":
    main()