|
|
@@ -0,0 +1,720 @@
|
|
|
+# 交易训练工具
|
|
|
+# 用于从交易记录CSV文件中随机选择交易,显示部分K线图让用户判断是否开仓,然后显示完整结果并记录
|
|
|
+
|
|
|
+from jqdata import *
|
|
|
+import pandas as pd
|
|
|
+import numpy as np
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+import matplotlib.patches as patches
|
|
|
+from datetime import datetime, timedelta, date
|
|
|
+import re
|
|
|
+import os
|
|
|
+import random
|
|
|
+import warnings
|
|
|
+warnings.filterwarnings('ignore')
|
|
|
+
|
|
|
+# 中文字体设置
|
|
|
+plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
|
|
|
+plt.rcParams['axes.unicode_minus'] = False
|
|
|
+
|
|
|
+# ========== 参数配置区域(硬编码参数都在这里) ==========
|
|
|
+CONFIG = {
|
|
|
+ # CSV文件名
|
|
|
+ 'csv_filename': 'transaction1.csv',
|
|
|
+
|
|
|
+ # 结果记录文件名
|
|
|
+ 'result_filename': 'training_results.csv',
|
|
|
+
|
|
|
+ # 历史数据天数
|
|
|
+ 'history_days': 100,
|
|
|
+
|
|
|
+ # 未来数据天数
|
|
|
+ 'future_days': 20,
|
|
|
+
|
|
|
+ # 输出目录
|
|
|
+ 'output_dir': 'training_images',
|
|
|
+
|
|
|
+ # 是否显示图片(在某些环境下可能需要设置为False)
|
|
|
+ 'show_plots': True,
|
|
|
+
|
|
|
+ # 图片DPI
|
|
|
+ 'plot_dpi': 150,
|
|
|
+
|
|
|
+ # 随机种子(设置为None表示不固定种子)
|
|
|
+ 'random_seed': 42
|
|
|
+}
|
|
|
+# =====================================================
|
|
|
+
|
|
|
+
|
|
|
+def _get_current_directory():
|
|
|
+ """
|
|
|
+ 获取当前文件所在目录
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
+ except NameError:
|
|
|
+ current_dir = os.getcwd()
|
|
|
+ if not os.path.exists(os.path.join(current_dir, CONFIG['csv_filename'])):
|
|
|
+ parent_dir = os.path.dirname(current_dir)
|
|
|
+ future_dir = os.path.join(parent_dir, 'future')
|
|
|
+ if os.path.exists(os.path.join(future_dir, CONFIG['csv_filename'])):
|
|
|
+ current_dir = future_dir
|
|
|
+ return current_dir
|
|
|
+
|
|
|
+
|
|
|
+def read_transaction_data(csv_path):
|
|
|
+ """
|
|
|
+ 读取交易记录CSV文件
|
|
|
+ """
|
|
|
+ encodings = ['utf-8-sig', 'utf-8', 'gbk', 'gb2312', 'gb18030', 'latin1']
|
|
|
+
|
|
|
+ for encoding in encodings:
|
|
|
+ try:
|
|
|
+ df = pd.read_csv(csv_path, encoding=encoding)
|
|
|
+ print(f"成功使用 {encoding} 编码读取CSV文件,共 {len(df)} 条记录")
|
|
|
+ return df
|
|
|
+ except UnicodeDecodeError:
|
|
|
+ continue
|
|
|
+ except Exception as e:
|
|
|
+ if encoding == encodings[-1]:
|
|
|
+ print(f"读取CSV文件时出错: {str(e)}")
|
|
|
+ raise
|
|
|
+
|
|
|
+ return pd.DataFrame()
|
|
|
+
|
|
|
+
|
|
|
+def extract_contract_info(row):
|
|
|
+ """
|
|
|
+ 从交易记录中提取合约信息和交易信息
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # 提取合约编号
|
|
|
+ target_str = str(row['标的'])
|
|
|
+ match = re.search(r'\(([^)]+)\)', target_str)
|
|
|
+ if match:
|
|
|
+ contract_code = match.group(1)
|
|
|
+ else:
|
|
|
+ return None, None, None, None, None, None, None, None, None
|
|
|
+
|
|
|
+ # 提取日期
|
|
|
+ date_str = str(row['日期']).strip()
|
|
|
+ date_formats = ['%Y-%m-%d', '%d/%m/%Y', '%Y/%m/%d', '%d-%m-%Y', '%Y%m%d']
|
|
|
+
|
|
|
+ trade_date = None
|
|
|
+ for date_format in date_formats:
|
|
|
+ try:
|
|
|
+ trade_date = datetime.strptime(date_str, date_format).date()
|
|
|
+ break
|
|
|
+ except ValueError:
|
|
|
+ continue
|
|
|
+
|
|
|
+ if trade_date is None:
|
|
|
+ return None, None, None, None, None, None, None, None, None
|
|
|
+
|
|
|
+ # 提取委托时间
|
|
|
+ order_time_str = str(row['委托时间']).strip()
|
|
|
+ try:
|
|
|
+ time_parts = order_time_str.split(':')
|
|
|
+ hour = int(time_parts[0])
|
|
|
+
|
|
|
+ # 如果委托时间 >= 21:00,需要找到下一个交易日
|
|
|
+ if hour >= 21:
|
|
|
+ try:
|
|
|
+ trade_days = get_trade_days(start_date=trade_date, count=2)
|
|
|
+ if len(trade_days) >= 2:
|
|
|
+ next_trade_day = trade_days[1]
|
|
|
+ if isinstance(next_trade_day, datetime):
|
|
|
+ actual_trade_date = next_trade_day.date()
|
|
|
+ elif isinstance(next_trade_day, date):
|
|
|
+ actual_trade_date = next_trade_day
|
|
|
+ else:
|
|
|
+ actual_trade_date = trade_date
|
|
|
+ else:
|
|
|
+ actual_trade_date = trade_date
|
|
|
+ except:
|
|
|
+ actual_trade_date = trade_date
|
|
|
+ else:
|
|
|
+ actual_trade_date = trade_date
|
|
|
+ except:
|
|
|
+ actual_trade_date = trade_date
|
|
|
+
|
|
|
+ # 提取成交价
|
|
|
+ try:
|
|
|
+ trade_price = float(row['成交价'])
|
|
|
+ except:
|
|
|
+ return None, None, None, None, None, None, None, None, None
|
|
|
+
|
|
|
+ # 提取交易方向和类型
|
|
|
+ trade_type = str(row['交易类型']).strip()
|
|
|
+ if '开多' in trade_type:
|
|
|
+ direction = 'long'
|
|
|
+ action = 'open'
|
|
|
+ elif '开空' in trade_type:
|
|
|
+ direction = 'short'
|
|
|
+ action = 'open'
|
|
|
+ elif '平多' in trade_type or '平空' in trade_type:
|
|
|
+ action = 'close'
|
|
|
+ direction = 'long' if '平多' in trade_type else 'short'
|
|
|
+ else:
|
|
|
+ return None, None, None, None, None, None, None, None, None
|
|
|
+
|
|
|
+ # 提取交易对ID和连续交易对ID
|
|
|
+ trade_pair_id = row.get('交易对ID', 'N/A')
|
|
|
+ continuous_pair_id = row.get('连续交易对ID', 'N/A')
|
|
|
+
|
|
|
+ return contract_code, actual_trade_date, trade_price, direction, action, order_time_str, trade_type, trade_pair_id, continuous_pair_id
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"提取合约信息时出错: {str(e)}")
|
|
|
+ return None, None, None, None, None, None, None, None, None
|
|
|
+
|
|
|
+
|
|
|
+def get_trade_day_range(trade_date, days_before, days_after):
|
|
|
+ """
|
|
|
+ 获取交易日范围
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # 获取历史交易日
|
|
|
+ trade_days_before = get_trade_days(end_date=trade_date, count=days_before + 1)
|
|
|
+ if len(trade_days_before) < days_before + 1:
|
|
|
+ return None, None
|
|
|
+
|
|
|
+ first_day = trade_days_before[0]
|
|
|
+ if isinstance(first_day, datetime):
|
|
|
+ start_date = first_day.date()
|
|
|
+ elif isinstance(first_day, date):
|
|
|
+ start_date = first_day
|
|
|
+ else:
|
|
|
+ start_date = first_day
|
|
|
+
|
|
|
+ # 获取未来交易日
|
|
|
+ trade_days_after = get_trade_days(start_date=trade_date, count=days_after + 1)
|
|
|
+ if len(trade_days_after) < days_after + 1:
|
|
|
+ return None, None
|
|
|
+
|
|
|
+ last_day = trade_days_after[-1]
|
|
|
+ if isinstance(last_day, datetime):
|
|
|
+ end_date = last_day.date()
|
|
|
+ elif isinstance(last_day, date):
|
|
|
+ end_date = last_day
|
|
|
+ else:
|
|
|
+ end_date = last_day
|
|
|
+
|
|
|
+ return start_date, end_date
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"计算交易日范围时出错: {str(e)}")
|
|
|
+ return None, None
|
|
|
+
|
|
|
+
|
|
|
+def get_kline_data_with_future(contract_code, trade_date, days_before=100, days_after=20):
|
|
|
+ """
|
|
|
+ 获取包含历史和未来数据的K线数据
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # 获取完整的数据范围
|
|
|
+ start_date, end_date = get_trade_day_range(trade_date, days_before, days_after)
|
|
|
+ if start_date is None or end_date is None:
|
|
|
+ return None, None, None
|
|
|
+
|
|
|
+ # 获取K线数据
|
|
|
+ price_data = get_price(
|
|
|
+ contract_code,
|
|
|
+ start_date=start_date,
|
|
|
+ end_date=end_date,
|
|
|
+ frequency='1d',
|
|
|
+ fields=['open', 'close', 'high', 'low']
|
|
|
+ )
|
|
|
+
|
|
|
+ if price_data is None or len(price_data) == 0:
|
|
|
+ return None, None, None
|
|
|
+
|
|
|
+ # 计算均线
|
|
|
+ price_data['ma5'] = price_data['close'].rolling(window=5).mean()
|
|
|
+ price_data['ma10'] = price_data['close'].rolling(window=10).mean()
|
|
|
+ price_data['ma20'] = price_data['close'].rolling(window=20).mean()
|
|
|
+ price_data['ma30'] = price_data['close'].rolling(window=30).mean()
|
|
|
+
|
|
|
+ # 找到交易日在数据中的位置
|
|
|
+ trade_date_normalized = pd.Timestamp(trade_date)
|
|
|
+ trade_idx = None
|
|
|
+
|
|
|
+ for i, idx in enumerate(price_data.index):
|
|
|
+ if isinstance(idx, pd.Timestamp):
|
|
|
+ if idx.date() == trade_date:
|
|
|
+ trade_idx = i
|
|
|
+ break
|
|
|
+
|
|
|
+ return price_data, trade_idx
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"获取K线数据时出错: {str(e)}")
|
|
|
+ return None, None, None
|
|
|
+
|
|
|
+
|
|
|
+def plot_partial_kline(data, trade_idx, trade_price, direction, contract_code, trade_date, order_time, save_path=None):
|
|
|
+ """
|
|
|
+ 绘制部分K线图(仅显示历史数据和当天)
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # 截取历史数据和当天数据
|
|
|
+ partial_data = data.iloc[:trade_idx + 1].copy()
|
|
|
+
|
|
|
+ # 修改当天的收盘价为成交价
|
|
|
+ partial_data.iloc[-1, partial_data.columns.get_loc('close')] = trade_price
|
|
|
+
|
|
|
+ fig, ax = plt.subplots(figsize=(16, 10))
|
|
|
+
|
|
|
+ # 准备数据
|
|
|
+ dates = partial_data.index
|
|
|
+ opens = partial_data['open']
|
|
|
+ highs = partial_data['high']
|
|
|
+ lows = partial_data['low']
|
|
|
+ closes = partial_data['close']
|
|
|
+
|
|
|
+ # 绘制K线
|
|
|
+ for i in range(len(partial_data)):
|
|
|
+ color = 'red' if closes.iloc[i] > opens.iloc[i] else 'green'
|
|
|
+ edge_color = 'darkred' if closes.iloc[i] > opens.iloc[i] else 'darkgreen'
|
|
|
+
|
|
|
+ # 影线
|
|
|
+ ax.plot([i, i], [lows.iloc[i], highs.iloc[i]], color='black', linewidth=1)
|
|
|
+
|
|
|
+ # 实体
|
|
|
+ body_height = abs(closes.iloc[i] - opens.iloc[i])
|
|
|
+ if body_height == 0:
|
|
|
+ body_height = 0.01
|
|
|
+ bottom = min(opens.iloc[i], closes.iloc[i])
|
|
|
+
|
|
|
+ rect = patches.Rectangle((i-0.4, bottom), 0.8, body_height,
|
|
|
+ linewidth=1, edgecolor=edge_color,
|
|
|
+ facecolor=color, alpha=0.8)
|
|
|
+ ax.add_patch(rect)
|
|
|
+
|
|
|
+ # 绘制均线
|
|
|
+ ax.plot(range(len(partial_data)), partial_data['ma5'], label='MA5', color='blue', linewidth=1.5, alpha=0.8)
|
|
|
+ ax.plot(range(len(partial_data)), partial_data['ma10'], label='MA10', color='orange', linewidth=1.5, alpha=0.8)
|
|
|
+ ax.plot(range(len(partial_data)), partial_data['ma20'], label='MA20', color='purple', linewidth=1.5, alpha=0.8)
|
|
|
+ ax.plot(range(len(partial_data)), partial_data['ma30'], label='MA30', color='brown', linewidth=1.5, alpha=0.8)
|
|
|
+
|
|
|
+ # 标注开仓位置
|
|
|
+ ax.plot(trade_idx, trade_price, marker='*', markersize=20,
|
|
|
+ color='yellow', markeredgecolor='black', markeredgewidth=2,
|
|
|
+ label='Open Position', zorder=10)
|
|
|
+
|
|
|
+ # 添加垂直线
|
|
|
+ ax.axvline(x=trade_idx, color='yellow', linestyle='--',
|
|
|
+ linewidth=2, alpha=0.7, zorder=5)
|
|
|
+
|
|
|
+ # 标注信息
|
|
|
+ date_label = trade_date.strftime('%Y-%m-%d')
|
|
|
+ price_label = f'Price: {trade_price:.2f}'
|
|
|
+ direction_label = f'Direction: {"Long" if direction == "long" else "Short"}'
|
|
|
+ time_label = f'Time: {order_time}'
|
|
|
+
|
|
|
+ # 计算文本位置
|
|
|
+ price_range = highs.max() - lows.min()
|
|
|
+ y_offset = max(price_range * 0.08, (highs.max() - trade_price) * 0.3)
|
|
|
+ text_y = trade_price + y_offset
|
|
|
+
|
|
|
+ if text_y > highs.max():
|
|
|
+ text_y = trade_price - price_range * 0.08
|
|
|
+
|
|
|
+ annotation_text = f'{date_label}\n{price_label}\n{direction_label}\n{time_label}'
|
|
|
+ ax.text(trade_idx, text_y, annotation_text,
|
|
|
+ fontsize=10, ha='center', va='bottom',
|
|
|
+ bbox=dict(boxstyle='round,pad=0.6', facecolor='yellow', alpha=0.9, edgecolor='black', linewidth=1.5),
|
|
|
+ zorder=11, weight='bold')
|
|
|
+
|
|
|
+ # 设置标题和标签
|
|
|
+ contract_simple = contract_code.split('.')[0]
|
|
|
+ direction_text = "Long" if direction == "long" else "Short"
|
|
|
+ ax.set_title(f'{contract_simple} - {direction_text} Position Decision\n'
|
|
|
+ f'Historical Data + Trade Day Only',
|
|
|
+ fontsize=14, fontweight='bold', pad=20)
|
|
|
+
|
|
|
+ ax.set_xlabel('Time', fontsize=12)
|
|
|
+ ax.set_ylabel('Price', fontsize=12)
|
|
|
+ ax.grid(True, alpha=0.3)
|
|
|
+ ax.legend(loc='lower left', fontsize=10)
|
|
|
+
|
|
|
+ # 设置x轴标签
|
|
|
+ step = max(1, len(partial_data) // 10)
|
|
|
+ tick_positions = range(0, len(partial_data), step)
|
|
|
+ tick_labels = []
|
|
|
+ for pos in tick_positions:
|
|
|
+ date_val = dates[pos]
|
|
|
+ if isinstance(date_val, (date, datetime)):
|
|
|
+ tick_labels.append(date_val.strftime('%Y-%m-%d'))
|
|
|
+ else:
|
|
|
+ tick_labels.append(str(date_val))
|
|
|
+
|
|
|
+ ax.set_xticks(tick_positions)
|
|
|
+ ax.set_xticklabels(tick_labels, rotation=45, ha='right')
|
|
|
+
|
|
|
+ plt.tight_layout()
|
|
|
+
|
|
|
+ if save_path:
|
|
|
+ plt.savefig(save_path, dpi=CONFIG['plot_dpi'], bbox_inches='tight')
|
|
|
+
|
|
|
+ if CONFIG['show_plots']:
|
|
|
+ plt.show()
|
|
|
+
|
|
|
+ plt.close(fig)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"绘制部分K线图时出错: {str(e)}")
|
|
|
+ plt.close('all')
|
|
|
+ raise
|
|
|
+
|
|
|
+
|
|
|
+def plot_full_kline(data, trade_idx, trade_price, direction, contract_code, trade_date, order_time, profit_loss, save_path=None):
|
|
|
+ """
|
|
|
+ 绘制完整K线图(包含未来数据)
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ fig, ax = plt.subplots(figsize=(16, 10))
|
|
|
+
|
|
|
+ # 准备数据
|
|
|
+ dates = data.index
|
|
|
+ opens = data['open']
|
|
|
+ highs = data['high']
|
|
|
+ lows = data['low']
|
|
|
+ closes = data['close']
|
|
|
+
|
|
|
+ # 绘制K线
|
|
|
+ for i in range(len(data)):
|
|
|
+ color = 'red' if closes.iloc[i] > opens.iloc[i] else 'green'
|
|
|
+ edge_color = 'darkred' if closes.iloc[i] > opens.iloc[i] else 'darkgreen'
|
|
|
+
|
|
|
+ # 影线
|
|
|
+ ax.plot([i, i], [lows.iloc[i], highs.iloc[i]], color='black', linewidth=1)
|
|
|
+
|
|
|
+ # 实体
|
|
|
+ body_height = abs(closes.iloc[i] - opens.iloc[i])
|
|
|
+ if body_height == 0:
|
|
|
+ body_height = 0.01
|
|
|
+ bottom = min(opens.iloc[i], closes.iloc[i])
|
|
|
+
|
|
|
+ rect = patches.Rectangle((i-0.4, bottom), 0.8, body_height,
|
|
|
+ linewidth=1, edgecolor=edge_color,
|
|
|
+ facecolor=color, alpha=0.8)
|
|
|
+ ax.add_patch(rect)
|
|
|
+
|
|
|
+ # 绘制均线
|
|
|
+ ax.plot(range(len(data)), data['ma5'], label='MA5', color='blue', linewidth=1.5, alpha=0.8)
|
|
|
+ ax.plot(range(len(data)), data['ma10'], label='MA10', color='orange', linewidth=1.5, alpha=0.8)
|
|
|
+ ax.plot(range(len(data)), data['ma20'], label='MA20', color='purple', linewidth=1.5, alpha=0.8)
|
|
|
+ ax.plot(range(len(data)), data['ma30'], label='MA30', color='brown', linewidth=1.5, alpha=0.8)
|
|
|
+
|
|
|
+ # 标注开仓位置
|
|
|
+ ax.plot(trade_idx, trade_price, marker='*', markersize=20,
|
|
|
+ color='yellow', markeredgecolor='black', markeredgewidth=2,
|
|
|
+ label='Open Position', zorder=10)
|
|
|
+
|
|
|
+ # 添加垂直线分隔历史和未来
|
|
|
+ ax.axvline(x=trade_idx, color='yellow', linestyle='--',
|
|
|
+ linewidth=2, alpha=0.7, zorder=5)
|
|
|
+
|
|
|
+ # 添加未来区域背景
|
|
|
+ ax.axvspan(trade_idx + 0.5, len(data) - 0.5, alpha=0.1, color='gray', label='Future Data')
|
|
|
+
|
|
|
+ # 标注信息
|
|
|
+ date_label = trade_date.strftime('%Y-%m-%d')
|
|
|
+ price_label = f'Price: {trade_price:.2f}'
|
|
|
+ direction_label = f'Direction: {"Long" if direction == "long" else "Short"}'
|
|
|
+ time_label = f'Time: {order_time}'
|
|
|
+ profit_label = f'P&L: {profit_loss:+.2f}'
|
|
|
+
|
|
|
+ # 计算文本位置
|
|
|
+ price_range = highs.max() - lows.min()
|
|
|
+ y_offset = max(price_range * 0.08, (highs.max() - trade_price) * 0.3)
|
|
|
+ text_y = trade_price + y_offset
|
|
|
+
|
|
|
+ if text_y > highs.max():
|
|
|
+ text_y = trade_price - price_range * 0.08
|
|
|
+
|
|
|
+ annotation_text = f'{date_label}\n{price_label}\n{direction_label}\n{time_label}\n{profit_label}'
|
|
|
+ ax.text(trade_idx, text_y, annotation_text,
|
|
|
+ fontsize=10, ha='center', va='bottom',
|
|
|
+ bbox=dict(boxstyle='round,pad=0.6', facecolor='yellow', alpha=0.9, edgecolor='black', linewidth=1.5),
|
|
|
+ zorder=11, weight='bold')
|
|
|
+
|
|
|
+ # 设置标题和标签
|
|
|
+ contract_simple = contract_code.split('.')[0]
|
|
|
+ direction_text = "Long" if direction == "long" else "Short"
|
|
|
+ ax.set_title(f'{contract_simple} - {direction_text} Position Result\n'
|
|
|
+ f'Complete Data with Future {CONFIG["future_days"]} Days',
|
|
|
+ fontsize=14, fontweight='bold', pad=20)
|
|
|
+
|
|
|
+ ax.set_xlabel('Time', fontsize=12)
|
|
|
+ ax.set_ylabel('Price', fontsize=12)
|
|
|
+ ax.grid(True, alpha=0.3)
|
|
|
+ ax.legend(loc='lower left', fontsize=10)
|
|
|
+
|
|
|
+ # 设置x轴标签
|
|
|
+ step = max(1, len(data) // 15)
|
|
|
+ tick_positions = range(0, len(data), step)
|
|
|
+ tick_labels = []
|
|
|
+ for pos in tick_positions:
|
|
|
+ date_val = dates[pos]
|
|
|
+ if isinstance(date_val, (date, datetime)):
|
|
|
+ tick_labels.append(date_val.strftime('%Y-%m-%d'))
|
|
|
+ else:
|
|
|
+ tick_labels.append(str(date_val))
|
|
|
+
|
|
|
+ ax.set_xticks(tick_positions)
|
|
|
+ ax.set_xticklabels(tick_labels, rotation=45, ha='right')
|
|
|
+
|
|
|
+ plt.tight_layout()
|
|
|
+
|
|
|
+ if save_path:
|
|
|
+ plt.savefig(save_path, dpi=CONFIG['plot_dpi'], bbox_inches='tight')
|
|
|
+
|
|
|
+ if CONFIG['show_plots']:
|
|
|
+ plt.show()
|
|
|
+
|
|
|
+ plt.close(fig)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"绘制完整K线图时出错: {str(e)}")
|
|
|
+ plt.close('all')
|
|
|
+ raise
|
|
|
+
|
|
|
+
|
|
|
+def load_processed_results(result_path):
|
|
|
+ """
|
|
|
+ 加载已处理的结果文件
|
|
|
+ """
|
|
|
+ if not os.path.exists(result_path):
|
|
|
+ return pd.DataFrame(), set()
|
|
|
+
|
|
|
+ try:
|
|
|
+ df = pd.read_csv(result_path)
|
|
|
+ # 获取已处理的交易对ID
|
|
|
+ processed_pairs = set(df['交易对ID'].unique())
|
|
|
+ return df, processed_pairs
|
|
|
+ except Exception as e:
|
|
|
+ print(f"加载结果文件时出错: {str(e)}")
|
|
|
+ return pd.DataFrame(), set()
|
|
|
+
|
|
|
+
|
|
|
+def calculate_profit_loss(df, trade_pair_id, continuous_pair_id):
|
|
|
+ """
|
|
|
+ 计算平仓盈亏
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ if continuous_pair_id != 'N/A' and pd.notna(continuous_pair_id):
|
|
|
+ # 合并所有同一连续交易对ID的平仓盈亏
|
|
|
+ close_trades = df[
|
|
|
+ (df['连续交易对ID'] == continuous_pair_id) &
|
|
|
+ (df['交易类型'].str[0] == '平')
|
|
|
+ ]
|
|
|
+ total_profit = close_trades['平仓盈亏'].sum()
|
|
|
+ else:
|
|
|
+ # 只查找当前交易对ID的平仓交易
|
|
|
+ close_trades = df[
|
|
|
+ (df['交易对ID'] == trade_pair_id) &
|
|
|
+ (df['交易类型'].str[0] == '平')
|
|
|
+ ]
|
|
|
+ if len(close_trades) > 0:
|
|
|
+ total_profit = close_trades['平仓盈亏'].iloc[0]
|
|
|
+ else:
|
|
|
+ total_profit = 0
|
|
|
+
|
|
|
+ return total_profit
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"计算盈亏时出错: {str(e)}")
|
|
|
+ return 0
|
|
|
+
|
|
|
+
|
|
|
+def record_result(result_data, result_path):
|
|
|
+ """
|
|
|
+ 记录训练结果
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ # 创建结果DataFrame
|
|
|
+ result_df = pd.DataFrame([result_data])
|
|
|
+
|
|
|
+ # 如果文件已存在,追加写入;否则创建新文件
|
|
|
+ if os.path.exists(result_path):
|
|
|
+ result_df.to_csv(result_path, mode='a', header=False, index=False, encoding='utf-8-sig')
|
|
|
+ else:
|
|
|
+ result_df.to_csv(result_path, mode='w', header=True, index=False, encoding='utf-8-sig')
|
|
|
+
|
|
|
+ print(f"结果已记录到: {result_path}")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"记录结果时出错: {str(e)}")
|
|
|
+
|
|
|
+
|
|
|
+def get_user_decision():
|
|
|
+ """
|
|
|
+ 获取用户的开仓决策
|
|
|
+ """
|
|
|
+ while True:
|
|
|
+ decision = input("\n是否开仓?请输入 'y' (开仓) 或 'n' (不开仓): ").strip().lower()
|
|
|
+ if decision in ['y', 'yes', '是', '开仓']:
|
|
|
+ return True
|
|
|
+ elif decision in ['n', 'no', '否', '不开仓']:
|
|
|
+ return False
|
|
|
+ else:
|
|
|
+ print("请输入有效的选项: 'y' 或 'n'")
|
|
|
+
|
|
|
+
|
|
|
+def main():
|
|
|
+ """
|
|
|
+ 主函数
|
|
|
+ """
|
|
|
+ print("=" * 60)
|
|
|
+ print("交易训练工具")
|
|
|
+ print("=" * 60)
|
|
|
+
|
|
|
+ # 设置随机种子
|
|
|
+ if CONFIG['random_seed'] is not None:
|
|
|
+ random.seed(CONFIG['random_seed'])
|
|
|
+ np.random.seed(CONFIG['random_seed'])
|
|
|
+
|
|
|
+ # 获取当前目录
|
|
|
+ current_dir = _get_current_directory()
|
|
|
+ csv_path = os.path.join(current_dir, CONFIG['csv_filename'])
|
|
|
+ result_path = os.path.join(current_dir, CONFIG['result_filename'])
|
|
|
+ output_dir = os.path.join(current_dir, CONFIG['output_dir'])
|
|
|
+
|
|
|
+ # 创建输出目录
|
|
|
+ os.makedirs(output_dir, exist_ok=True)
|
|
|
+
|
|
|
+ # 1. 读取交易数据
|
|
|
+ print("\n=== 步骤1: 读取交易数据 ===")
|
|
|
+ transaction_df = read_transaction_data(csv_path)
|
|
|
+ if len(transaction_df) == 0:
|
|
|
+ print("未能读取交易数据,退出")
|
|
|
+ return
|
|
|
+
|
|
|
+ # 2. 加载已处理的结果
|
|
|
+ print("\n=== 步骤2: 加载已处理记录 ===")
|
|
|
+ existing_results, processed_pairs = load_processed_results(result_path)
|
|
|
+ print(f"已处理 {len(processed_pairs)} 个交易对")
|
|
|
+
|
|
|
+ # 3. 提取所有开仓交易
|
|
|
+ print("\n=== 步骤3: 提取开仓交易 ===")
|
|
|
+ open_trades = []
|
|
|
+
|
|
|
+ for idx, row in transaction_df.iterrows():
|
|
|
+ contract_code, trade_date, trade_price, direction, action, order_time, trade_type, trade_pair_id, continuous_pair_id = extract_contract_info(row)
|
|
|
+
|
|
|
+ if contract_code is None or action != 'open':
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 跳过已处理的交易对
|
|
|
+ if trade_pair_id in processed_pairs:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 查找对应的平仓交易
|
|
|
+ profit_loss = calculate_profit_loss(transaction_df, trade_pair_id, continuous_pair_id)
|
|
|
+
|
|
|
+ open_trades.append({
|
|
|
+ 'index': idx,
|
|
|
+ 'contract_code': contract_code,
|
|
|
+ 'trade_date': trade_date,
|
|
|
+ 'trade_price': trade_price,
|
|
|
+ 'direction': direction,
|
|
|
+ 'order_time': order_time,
|
|
|
+ 'trade_type': trade_type,
|
|
|
+ 'trade_pair_id': trade_pair_id,
|
|
|
+ 'continuous_pair_id': continuous_pair_id,
|
|
|
+ 'profit_loss': profit_loss,
|
|
|
+ 'original_row': row
|
|
|
+ })
|
|
|
+
|
|
|
+ print(f"找到 {len(open_trades)} 个未处理的开仓交易")
|
|
|
+
|
|
|
+ if len(open_trades) == 0:
|
|
|
+ print("没有未处理的开仓交易,退出")
|
|
|
+ return
|
|
|
+
|
|
|
+ # 4. 随机选择一个交易(完全随机)
|
|
|
+ print("\n=== 步骤4: 随机选择交易 ===")
|
|
|
+
|
|
|
+ # 打乱交易列表顺序确保完全随机
|
|
|
+ random.shuffle(open_trades)
|
|
|
+ selected_trade = random.choice(open_trades)
|
|
|
+
|
|
|
+ print(f"选中交易: {selected_trade['contract_code']} - {selected_trade['trade_date']} - {selected_trade['direction']}")
|
|
|
+ print(f"剩余未处理交易: {len(open_trades) - 1} 个")
|
|
|
+
|
|
|
+ # 5. 获取K线数据
|
|
|
+ print("\n=== 步骤5: 获取K线数据 ===")
|
|
|
+ kline_data, trade_idx = get_kline_data_with_future(
|
|
|
+ selected_trade['contract_code'],
|
|
|
+ selected_trade['trade_date'],
|
|
|
+ CONFIG['history_days'],
|
|
|
+ CONFIG['future_days']
|
|
|
+ )
|
|
|
+
|
|
|
+ if kline_data is None or trade_idx is None:
|
|
|
+ print("获取K线数据失败,退出")
|
|
|
+ return
|
|
|
+
|
|
|
+ # 6. 显示部分K线图
|
|
|
+ print("\n=== 步骤6: 显示部分K线图 ===")
|
|
|
+ partial_image_path = os.path.join(output_dir, f"partial_{selected_trade['trade_pair_id']}.png")
|
|
|
+ plot_partial_kline(
|
|
|
+ kline_data, trade_idx, selected_trade['trade_price'],
|
|
|
+ selected_trade['direction'], selected_trade['contract_code'],
|
|
|
+ selected_trade['trade_date'], selected_trade['order_time'],
|
|
|
+ partial_image_path
|
|
|
+ )
|
|
|
+
|
|
|
+ # 7. 获取用户决策
|
|
|
+ print(f"\n交易信息:")
|
|
|
+ print(f"合约: {selected_trade['contract_code']}")
|
|
|
+ print(f"日期: {selected_trade['trade_date']}")
|
|
|
+ print(f"方向: {'多头' if selected_trade['direction'] == 'long' else '空头'}")
|
|
|
+ print(f"成交价: {selected_trade['trade_price']}")
|
|
|
+
|
|
|
+ user_decision = get_user_decision()
|
|
|
+
|
|
|
+ # 8. 显示完整K线图
|
|
|
+ print("\n=== 步骤7: 显示完整K线图 ===")
|
|
|
+ full_image_path = os.path.join(output_dir, f"full_{selected_trade['trade_pair_id']}.png")
|
|
|
+ plot_full_kline(
|
|
|
+ kline_data, trade_idx, selected_trade['trade_price'],
|
|
|
+ selected_trade['direction'], selected_trade['contract_code'],
|
|
|
+ selected_trade['trade_date'], selected_trade['order_time'],
|
|
|
+ selected_trade['profit_loss'],
|
|
|
+ full_image_path
|
|
|
+ )
|
|
|
+
|
|
|
+ # 9. 记录结果
|
|
|
+ print("\n=== 步骤8: 记录结果 ===")
|
|
|
+
|
|
|
+ # 计算判定收益
|
|
|
+ decision_profit = selected_trade['profit_loss'] if user_decision else -selected_trade['profit_loss']
|
|
|
+
|
|
|
+ result_data = {
|
|
|
+ '日期': selected_trade['original_row']['日期'],
|
|
|
+ '委托时间': selected_trade['original_row']['委托时间'],
|
|
|
+ '标的': selected_trade['original_row']['标的'],
|
|
|
+ '交易类型': selected_trade['original_row']['交易类型'],
|
|
|
+ '成交数量': selected_trade['original_row']['成交数量'],
|
|
|
+ '成交价': selected_trade['original_row']['成交价'],
|
|
|
+ '平仓盈亏': selected_trade['profit_loss'],
|
|
|
+ '用户判定': '开仓' if user_decision else '不开仓',
|
|
|
+ '判定收益': decision_profit,
|
|
|
+ '交易对ID': selected_trade['trade_pair_id'],
|
|
|
+ '连续交易对ID': selected_trade['continuous_pair_id']
|
|
|
+ }
|
|
|
+
|
|
|
+ record_result(result_data, result_path)
|
|
|
+
|
|
|
+ print(f"\n=== 训练完成 ===")
|
|
|
+ print(f"用户判定: {'开仓' if user_decision else '不开仓'}")
|
|
|
+ print(f"实际盈亏: {selected_trade['profit_loss']:+.2f}")
|
|
|
+ print(f"判定收益: {decision_profit:+.2f}")
|
|
|
+ print(f"结果已保存到: {result_path}")
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ main()
|