Эх сурвалжийг харах

1. 新增交易训练工具:辅助顺势交易。
- 提供从交易记录中随机选择交易并显示K线图的功能
- 让用户根据K线图判断是否跟随交易。
- 记录用户决策和收益

2. 针对期货交易的一些工具的逻辑记录文档

maxfeng 1 сар өмнө
parent
commit
39b9aff47c

+ 122 - 0
Lib/future/tools_documentation.md

@@ -0,0 +1,122 @@
+# 期货交易工具逻辑文档
+
+## 1. 交易训练工具 (trading_training_tool.py)
+
+### 功能概述
+交易训练工具是一个用于训练交易决策能力的交互式工具。它从交易记录中随机选择一笔开仓交易,只显示历史K线数据和当天的部分信息(当天收盘价被替换为实际成交价),让用户判断是否开仓,然后显示包含未来20天数据的完整K线图,并记录用户的决策结果。
+
+### 核心逻辑流程
+
+#### 1.1 数据准备阶段
+1. **读取交易数据**
+   - 从CSV文件读取所有交易记录
+   - 支持多种编码格式(utf-8, gbk, gb2312等)
+   - 自动识别文件编码
+
+2. **加载历史处理记录**
+   - 读取已存在的训练结果文件(training_results.csv)
+   - 提取已处理的交易对ID列表,避免重复处理
+
+3. **提取开仓交易**
+   - 遍历所有交易记录,筛选出开仓交易(交易类型以"开"开头)
+   - 提取合约代码、交易日期、成交价、方向等关键信息
+   - 跳过已处理的交易对
+   - 计算对应的平仓盈亏(考虑连续交易对ID的情况)
+
+#### 1.2 随机选择机制
+- **完全随机选择**:使用random.shuffle打乱交易列表,然后随机选择
+- **避免重复**:通过已处理的交易对ID列表确保不重复选择
+- **显示统计**:显示剩余未处理的交易数量
+
+#### 1.3 K线数据处理
+1. **获取完整数据**
+   - 获取历史100天 + 未来20天的完整K线数据
+   - 包含开高低收价格和5/10/20/30日均线
+
+2. **部分K线图生成**
+   - 截取历史数据+当天数据
+   - 将当天收盘价替换为实际成交价
+   - 生成只有历史信息的K线图
+
+3. **完整K线图生成**
+   - 使用原始数据(不修改当天收盘价)
+   - 显示历史和未来20天的完整走势
+   - 用灰色背景标注未来数据区域
+
+#### 1.4 用户交互
+1. **显示交易信息**
+   - 合约代码
+   - 交易日期
+   - 交易方向(多头/空头)
+   - 成交价
+
+2. **获取用户决策**
+   - 提示用户输入'y'(开仓)或'n'(不开仓)
+   - 循环验证输入直到获得有效决策
+
+#### 1.5 结果记录
+1. **计算判定收益**
+   - 用户判定开仓:判定收益 = 实际平仓盈亏
+   - 用户判定不开仓:判定收益 = -实际平仓盈亏
+
+2. **记录格式**
+   - 包含原始交易信息(日期、时间、标的、类型、数量、价格)
+   - 平仓盈亏(可能合并连续交易对)
+   - 用户判定(开仓/不开仓)
+   - 判定收益
+   - 交易对ID和连续交易对ID
+
+### 特殊处理逻辑
+
+#### 连续交易对ID处理
+- 如果记录包含有效的连续交易对ID(非'N/A'),则合并所有同一连续交易对ID的平仓盈亏
+- 这允许处理跨期换月等连续持仓的情况
+
+#### 日期和时间处理
+- 支持多种日期格式(YYYY-MM-DD, DD/MM/YYYY等)
+- 夜盘交易处理:委托时间>=21:00的交易,实际交易日为下一个交易日
+- 使用JoinQuant的get_trade_days获取准确的交易日历
+
+#### 合约代码提取
+- 从"标的"列的括号中提取标准合约代码(如:"豆粕2501(DM2501)" -> "DM2501")
+- 自动添加交易所后缀(如.XDCE)
+
+### 配置参数
+所有参数集中在CONFIG字典中:
+- `csv_filename`: 输入交易记录文件名
+- `result_filename`: 输出结果文件名
+- `history_days`: 历史数据天数(默认100)
+- `future_days`: 未来数据天数(默认20)
+- `output_dir`: 图片输出目录
+- `show_plots`: 是否显示图片
+- `plot_dpi`: 图片分辨率
+- `random_seed`: 随机数种子(确保可重复性)
+
+### 输出文件
+1. **训练结果CSV**(training_results.csv)
+   - 追加模式写入,保留所有历史记录
+   - UTF-8编码,支持中文
+
+2. **K线图片**(training_images/目录)
+   - partial_*.png: 部分K线图
+   - full_*.png: 完整K线图
+
+### 使用场景
+- 交易决策训练:通过历史数据练习判断开仓时机
+- 策略验证:验证人工判断与策略信号的一致性
+- 交易复盘:分析决策质量,改进交易技能
+
+## 2. K线复原工具 (kline_reconstruction.py)
+
+### 功能概述
+(待补充详细逻辑说明)
+
+### 核心功能
+- 从交易记录中提取所有开仓交易
+- 为每个开仓交易生成对应的K线图
+- 包含均线和技术指标
+- 批量生成并打包成ZIP文件
+
+---
+
+*注:本文档主要介绍交易训练工具的详细逻辑,其他工具的逻辑说明将逐步补充。*

+ 720 - 0
Lib/future/trading_training_tool.py

@@ -0,0 +1,720 @@
+# 交易训练工具
+# 用于从交易记录CSV文件中随机选择交易,显示部分K线图让用户判断是否开仓,然后显示完整结果并记录
+
+from jqdata import *
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.patches as patches
+from datetime import datetime, timedelta, date
+import re
+import os
+import random
+import warnings
+warnings.filterwarnings('ignore')
+
+# 中文字体设置
+plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
+plt.rcParams['axes.unicode_minus'] = False
+
+# ========== 参数配置区域(硬编码参数都在这里) ==========
+CONFIG = {
+    # CSV文件名
+    'csv_filename': 'transaction1.csv',
+
+    # 结果记录文件名
+    'result_filename': 'training_results.csv',
+
+    # 历史数据天数
+    'history_days': 100,
+
+    # 未来数据天数
+    'future_days': 20,
+
+    # 输出目录
+    'output_dir': 'training_images',
+
+    # 是否显示图片(在某些环境下可能需要设置为False)
+    'show_plots': True,
+
+    # 图片DPI
+    'plot_dpi': 150,
+
+    # 随机种子(设置为None表示不固定种子)
+    'random_seed': 42
+}
+# =====================================================
+
+
+def _get_current_directory():
+    """
+    获取当前文件所在目录
+    """
+    try:
+        current_dir = os.path.dirname(os.path.abspath(__file__))
+    except NameError:
+        current_dir = os.getcwd()
+        if not os.path.exists(os.path.join(current_dir, CONFIG['csv_filename'])):
+            parent_dir = os.path.dirname(current_dir)
+            future_dir = os.path.join(parent_dir, 'future')
+            if os.path.exists(os.path.join(future_dir, CONFIG['csv_filename'])):
+                current_dir = future_dir
+    return current_dir
+
+
+def read_transaction_data(csv_path):
+    """
+    读取交易记录CSV文件
+    """
+    encodings = ['utf-8-sig', 'utf-8', 'gbk', 'gb2312', 'gb18030', 'latin1']
+
+    for encoding in encodings:
+        try:
+            df = pd.read_csv(csv_path, encoding=encoding)
+            print(f"成功使用 {encoding} 编码读取CSV文件,共 {len(df)} 条记录")
+            return df
+        except UnicodeDecodeError:
+            continue
+        except Exception as e:
+            if encoding == encodings[-1]:
+                print(f"读取CSV文件时出错: {str(e)}")
+                raise
+
+    return pd.DataFrame()
+
+
+def extract_contract_info(row):
+    """
+    从交易记录中提取合约信息和交易信息
+    """
+    try:
+        # 提取合约编号
+        target_str = str(row['标的'])
+        match = re.search(r'\(([^)]+)\)', target_str)
+        if match:
+            contract_code = match.group(1)
+        else:
+            return None, None, None, None, None, None, None, None, None
+
+        # 提取日期
+        date_str = str(row['日期']).strip()
+        date_formats = ['%Y-%m-%d', '%d/%m/%Y', '%Y/%m/%d', '%d-%m-%Y', '%Y%m%d']
+
+        trade_date = None
+        for date_format in date_formats:
+            try:
+                trade_date = datetime.strptime(date_str, date_format).date()
+                break
+            except ValueError:
+                continue
+
+        if trade_date is None:
+            return None, None, None, None, None, None, None, None, None
+
+        # 提取委托时间
+        order_time_str = str(row['委托时间']).strip()
+        try:
+            time_parts = order_time_str.split(':')
+            hour = int(time_parts[0])
+
+            # 如果委托时间 >= 21:00,需要找到下一个交易日
+            if hour >= 21:
+                try:
+                    trade_days = get_trade_days(start_date=trade_date, count=2)
+                    if len(trade_days) >= 2:
+                        next_trade_day = trade_days[1]
+                        if isinstance(next_trade_day, datetime):
+                            actual_trade_date = next_trade_day.date()
+                        elif isinstance(next_trade_day, date):
+                            actual_trade_date = next_trade_day
+                        else:
+                            actual_trade_date = trade_date
+                    else:
+                        actual_trade_date = trade_date
+                except:
+                    actual_trade_date = trade_date
+            else:
+                actual_trade_date = trade_date
+        except:
+            actual_trade_date = trade_date
+
+        # 提取成交价
+        try:
+            trade_price = float(row['成交价'])
+        except:
+            return None, None, None, None, None, None, None, None, None
+
+        # 提取交易方向和类型
+        trade_type = str(row['交易类型']).strip()
+        if '开多' in trade_type:
+            direction = 'long'
+            action = 'open'
+        elif '开空' in trade_type:
+            direction = 'short'
+            action = 'open'
+        elif '平多' in trade_type or '平空' in trade_type:
+            action = 'close'
+            direction = 'long' if '平多' in trade_type else 'short'
+        else:
+            return None, None, None, None, None, None, None, None, None
+
+        # 提取交易对ID和连续交易对ID
+        trade_pair_id = row.get('交易对ID', 'N/A')
+        continuous_pair_id = row.get('连续交易对ID', 'N/A')
+
+        return contract_code, actual_trade_date, trade_price, direction, action, order_time_str, trade_type, trade_pair_id, continuous_pair_id
+
+    except Exception as e:
+        print(f"提取合约信息时出错: {str(e)}")
+        return None, None, None, None, None, None, None, None, None
+
+
+def get_trade_day_range(trade_date, days_before, days_after):
+    """
+    获取交易日范围
+    """
+    try:
+        # 获取历史交易日
+        trade_days_before = get_trade_days(end_date=trade_date, count=days_before + 1)
+        if len(trade_days_before) < days_before + 1:
+            return None, None
+
+        first_day = trade_days_before[0]
+        if isinstance(first_day, datetime):
+            start_date = first_day.date()
+        elif isinstance(first_day, date):
+            start_date = first_day
+        else:
+            start_date = first_day
+
+        # 获取未来交易日
+        trade_days_after = get_trade_days(start_date=trade_date, count=days_after + 1)
+        if len(trade_days_after) < days_after + 1:
+            return None, None
+
+        last_day = trade_days_after[-1]
+        if isinstance(last_day, datetime):
+            end_date = last_day.date()
+        elif isinstance(last_day, date):
+            end_date = last_day
+        else:
+            end_date = last_day
+
+        return start_date, end_date
+
+    except Exception as e:
+        print(f"计算交易日范围时出错: {str(e)}")
+        return None, None
+
+
+def get_kline_data_with_future(contract_code, trade_date, days_before=100, days_after=20):
+    """
+    获取包含历史和未来数据的K线数据
+    """
+    try:
+        # 获取完整的数据范围
+        start_date, end_date = get_trade_day_range(trade_date, days_before, days_after)
+        if start_date is None or end_date is None:
+            return None, None, None
+
+        # 获取K线数据
+        price_data = get_price(
+            contract_code,
+            start_date=start_date,
+            end_date=end_date,
+            frequency='1d',
+            fields=['open', 'close', 'high', 'low']
+        )
+
+        if price_data is None or len(price_data) == 0:
+            return None, None, None
+
+        # 计算均线
+        price_data['ma5'] = price_data['close'].rolling(window=5).mean()
+        price_data['ma10'] = price_data['close'].rolling(window=10).mean()
+        price_data['ma20'] = price_data['close'].rolling(window=20).mean()
+        price_data['ma30'] = price_data['close'].rolling(window=30).mean()
+
+        # 找到交易日在数据中的位置
+        trade_date_normalized = pd.Timestamp(trade_date)
+        trade_idx = None
+
+        for i, idx in enumerate(price_data.index):
+            if isinstance(idx, pd.Timestamp):
+                if idx.date() == trade_date:
+                    trade_idx = i
+                    break
+
+        return price_data, trade_idx
+
+    except Exception as e:
+        print(f"获取K线数据时出错: {str(e)}")
+        return None, None, None
+
+
+def plot_partial_kline(data, trade_idx, trade_price, direction, contract_code, trade_date, order_time, save_path=None):
+    """
+    绘制部分K线图(仅显示历史数据和当天)
+    """
+    try:
+        # 截取历史数据和当天数据
+        partial_data = data.iloc[:trade_idx + 1].copy()
+
+        # 修改当天的收盘价为成交价
+        partial_data.iloc[-1, partial_data.columns.get_loc('close')] = trade_price
+
+        fig, ax = plt.subplots(figsize=(16, 10))
+
+        # 准备数据
+        dates = partial_data.index
+        opens = partial_data['open']
+        highs = partial_data['high']
+        lows = partial_data['low']
+        closes = partial_data['close']
+
+        # 绘制K线
+        for i in range(len(partial_data)):
+            color = 'red' if closes.iloc[i] > opens.iloc[i] else 'green'
+            edge_color = 'darkred' if closes.iloc[i] > opens.iloc[i] else 'darkgreen'
+
+            # 影线
+            ax.plot([i, i], [lows.iloc[i], highs.iloc[i]], color='black', linewidth=1)
+
+            # 实体
+            body_height = abs(closes.iloc[i] - opens.iloc[i])
+            if body_height == 0:
+                body_height = 0.01
+            bottom = min(opens.iloc[i], closes.iloc[i])
+
+            rect = patches.Rectangle((i-0.4, bottom), 0.8, body_height,
+                                   linewidth=1, edgecolor=edge_color,
+                                   facecolor=color, alpha=0.8)
+            ax.add_patch(rect)
+
+        # 绘制均线
+        ax.plot(range(len(partial_data)), partial_data['ma5'], label='MA5', color='blue', linewidth=1.5, alpha=0.8)
+        ax.plot(range(len(partial_data)), partial_data['ma10'], label='MA10', color='orange', linewidth=1.5, alpha=0.8)
+        ax.plot(range(len(partial_data)), partial_data['ma20'], label='MA20', color='purple', linewidth=1.5, alpha=0.8)
+        ax.plot(range(len(partial_data)), partial_data['ma30'], label='MA30', color='brown', linewidth=1.5, alpha=0.8)
+
+        # 标注开仓位置
+        ax.plot(trade_idx, trade_price, marker='*', markersize=20,
+               color='yellow', markeredgecolor='black', markeredgewidth=2,
+               label='Open Position', zorder=10)
+
+        # 添加垂直线
+        ax.axvline(x=trade_idx, color='yellow', linestyle='--',
+                  linewidth=2, alpha=0.7, zorder=5)
+
+        # 标注信息
+        date_label = trade_date.strftime('%Y-%m-%d')
+        price_label = f'Price: {trade_price:.2f}'
+        direction_label = f'Direction: {"Long" if direction == "long" else "Short"}'
+        time_label = f'Time: {order_time}'
+
+        # 计算文本位置
+        price_range = highs.max() - lows.min()
+        y_offset = max(price_range * 0.08, (highs.max() - trade_price) * 0.3)
+        text_y = trade_price + y_offset
+
+        if text_y > highs.max():
+            text_y = trade_price - price_range * 0.08
+
+        annotation_text = f'{date_label}\n{price_label}\n{direction_label}\n{time_label}'
+        ax.text(trade_idx, text_y, annotation_text,
+               fontsize=10, ha='center', va='bottom',
+               bbox=dict(boxstyle='round,pad=0.6', facecolor='yellow', alpha=0.9, edgecolor='black', linewidth=1.5),
+               zorder=11, weight='bold')
+
+        # 设置标题和标签
+        contract_simple = contract_code.split('.')[0]
+        direction_text = "Long" if direction == "long" else "Short"
+        ax.set_title(f'{contract_simple} - {direction_text} Position Decision\n'
+                    f'Historical Data + Trade Day Only',
+                    fontsize=14, fontweight='bold', pad=20)
+
+        ax.set_xlabel('Time', fontsize=12)
+        ax.set_ylabel('Price', fontsize=12)
+        ax.grid(True, alpha=0.3)
+        ax.legend(loc='lower left', fontsize=10)
+
+        # 设置x轴标签
+        step = max(1, len(partial_data) // 10)
+        tick_positions = range(0, len(partial_data), step)
+        tick_labels = []
+        for pos in tick_positions:
+            date_val = dates[pos]
+            if isinstance(date_val, (date, datetime)):
+                tick_labels.append(date_val.strftime('%Y-%m-%d'))
+            else:
+                tick_labels.append(str(date_val))
+
+        ax.set_xticks(tick_positions)
+        ax.set_xticklabels(tick_labels, rotation=45, ha='right')
+
+        plt.tight_layout()
+
+        if save_path:
+            plt.savefig(save_path, dpi=CONFIG['plot_dpi'], bbox_inches='tight')
+
+        if CONFIG['show_plots']:
+            plt.show()
+
+        plt.close(fig)
+
+    except Exception as e:
+        print(f"绘制部分K线图时出错: {str(e)}")
+        plt.close('all')
+        raise
+
+
+def plot_full_kline(data, trade_idx, trade_price, direction, contract_code, trade_date, order_time, profit_loss, save_path=None):
+    """
+    绘制完整K线图(包含未来数据)
+    """
+    try:
+        fig, ax = plt.subplots(figsize=(16, 10))
+
+        # 准备数据
+        dates = data.index
+        opens = data['open']
+        highs = data['high']
+        lows = data['low']
+        closes = data['close']
+
+        # 绘制K线
+        for i in range(len(data)):
+            color = 'red' if closes.iloc[i] > opens.iloc[i] else 'green'
+            edge_color = 'darkred' if closes.iloc[i] > opens.iloc[i] else 'darkgreen'
+
+            # 影线
+            ax.plot([i, i], [lows.iloc[i], highs.iloc[i]], color='black', linewidth=1)
+
+            # 实体
+            body_height = abs(closes.iloc[i] - opens.iloc[i])
+            if body_height == 0:
+                body_height = 0.01
+            bottom = min(opens.iloc[i], closes.iloc[i])
+
+            rect = patches.Rectangle((i-0.4, bottom), 0.8, body_height,
+                                   linewidth=1, edgecolor=edge_color,
+                                   facecolor=color, alpha=0.8)
+            ax.add_patch(rect)
+
+        # 绘制均线
+        ax.plot(range(len(data)), data['ma5'], label='MA5', color='blue', linewidth=1.5, alpha=0.8)
+        ax.plot(range(len(data)), data['ma10'], label='MA10', color='orange', linewidth=1.5, alpha=0.8)
+        ax.plot(range(len(data)), data['ma20'], label='MA20', color='purple', linewidth=1.5, alpha=0.8)
+        ax.plot(range(len(data)), data['ma30'], label='MA30', color='brown', linewidth=1.5, alpha=0.8)
+
+        # 标注开仓位置
+        ax.plot(trade_idx, trade_price, marker='*', markersize=20,
+               color='yellow', markeredgecolor='black', markeredgewidth=2,
+               label='Open Position', zorder=10)
+
+        # 添加垂直线分隔历史和未来
+        ax.axvline(x=trade_idx, color='yellow', linestyle='--',
+                  linewidth=2, alpha=0.7, zorder=5)
+
+        # 添加未来区域背景
+        ax.axvspan(trade_idx + 0.5, len(data) - 0.5, alpha=0.1, color='gray', label='Future Data')
+
+        # 标注信息
+        date_label = trade_date.strftime('%Y-%m-%d')
+        price_label = f'Price: {trade_price:.2f}'
+        direction_label = f'Direction: {"Long" if direction == "long" else "Short"}'
+        time_label = f'Time: {order_time}'
+        profit_label = f'P&L: {profit_loss:+.2f}'
+
+        # 计算文本位置
+        price_range = highs.max() - lows.min()
+        y_offset = max(price_range * 0.08, (highs.max() - trade_price) * 0.3)
+        text_y = trade_price + y_offset
+
+        if text_y > highs.max():
+            text_y = trade_price - price_range * 0.08
+
+        annotation_text = f'{date_label}\n{price_label}\n{direction_label}\n{time_label}\n{profit_label}'
+        ax.text(trade_idx, text_y, annotation_text,
+               fontsize=10, ha='center', va='bottom',
+               bbox=dict(boxstyle='round,pad=0.6', facecolor='yellow', alpha=0.9, edgecolor='black', linewidth=1.5),
+               zorder=11, weight='bold')
+
+        # 设置标题和标签
+        contract_simple = contract_code.split('.')[0]
+        direction_text = "Long" if direction == "long" else "Short"
+        ax.set_title(f'{contract_simple} - {direction_text} Position Result\n'
+                    f'Complete Data with Future {CONFIG["future_days"]} Days',
+                    fontsize=14, fontweight='bold', pad=20)
+
+        ax.set_xlabel('Time', fontsize=12)
+        ax.set_ylabel('Price', fontsize=12)
+        ax.grid(True, alpha=0.3)
+        ax.legend(loc='lower left', fontsize=10)
+
+        # 设置x轴标签
+        step = max(1, len(data) // 15)
+        tick_positions = range(0, len(data), step)
+        tick_labels = []
+        for pos in tick_positions:
+            date_val = dates[pos]
+            if isinstance(date_val, (date, datetime)):
+                tick_labels.append(date_val.strftime('%Y-%m-%d'))
+            else:
+                tick_labels.append(str(date_val))
+
+        ax.set_xticks(tick_positions)
+        ax.set_xticklabels(tick_labels, rotation=45, ha='right')
+
+        plt.tight_layout()
+
+        if save_path:
+            plt.savefig(save_path, dpi=CONFIG['plot_dpi'], bbox_inches='tight')
+
+        if CONFIG['show_plots']:
+            plt.show()
+
+        plt.close(fig)
+
+    except Exception as e:
+        print(f"绘制完整K线图时出错: {str(e)}")
+        plt.close('all')
+        raise
+
+
+def load_processed_results(result_path):
+    """
+    加载已处理的结果文件
+    """
+    if not os.path.exists(result_path):
+        return pd.DataFrame(), set()
+
+    try:
+        df = pd.read_csv(result_path)
+        # 获取已处理的交易对ID
+        processed_pairs = set(df['交易对ID'].unique())
+        return df, processed_pairs
+    except Exception as e:
+        print(f"加载结果文件时出错: {str(e)}")
+        return pd.DataFrame(), set()
+
+
+def calculate_profit_loss(df, trade_pair_id, continuous_pair_id):
+    """
+    计算平仓盈亏
+    """
+    try:
+        if continuous_pair_id != 'N/A' and pd.notna(continuous_pair_id):
+            # 合并所有同一连续交易对ID的平仓盈亏
+            close_trades = df[
+                (df['连续交易对ID'] == continuous_pair_id) &
+                (df['交易类型'].str[0] == '平')
+            ]
+            total_profit = close_trades['平仓盈亏'].sum()
+        else:
+            # 只查找当前交易对ID的平仓交易
+            close_trades = df[
+                (df['交易对ID'] == trade_pair_id) &
+                (df['交易类型'].str[0] == '平')
+            ]
+            if len(close_trades) > 0:
+                total_profit = close_trades['平仓盈亏'].iloc[0]
+            else:
+                total_profit = 0
+
+        return total_profit
+
+    except Exception as e:
+        print(f"计算盈亏时出错: {str(e)}")
+        return 0
+
+
+def record_result(result_data, result_path):
+    """
+    记录训练结果
+    """
+    try:
+        # 创建结果DataFrame
+        result_df = pd.DataFrame([result_data])
+
+        # 如果文件已存在,追加写入;否则创建新文件
+        if os.path.exists(result_path):
+            result_df.to_csv(result_path, mode='a', header=False, index=False, encoding='utf-8-sig')
+        else:
+            result_df.to_csv(result_path, mode='w', header=True, index=False, encoding='utf-8-sig')
+
+        print(f"结果已记录到: {result_path}")
+
+    except Exception as e:
+        print(f"记录结果时出错: {str(e)}")
+
+
+def get_user_decision():
+    """
+    获取用户的开仓决策
+    """
+    while True:
+        decision = input("\n是否开仓?请输入 'y' (开仓) 或 'n' (不开仓): ").strip().lower()
+        if decision in ['y', 'yes', '是', '开仓']:
+            return True
+        elif decision in ['n', 'no', '否', '不开仓']:
+            return False
+        else:
+            print("请输入有效的选项: 'y' 或 'n'")
+
+
+def main():
+    """
+    主函数
+    """
+    print("=" * 60)
+    print("交易训练工具")
+    print("=" * 60)
+
+    # 设置随机种子
+    if CONFIG['random_seed'] is not None:
+        random.seed(CONFIG['random_seed'])
+        np.random.seed(CONFIG['random_seed'])
+
+    # 获取当前目录
+    current_dir = _get_current_directory()
+    csv_path = os.path.join(current_dir, CONFIG['csv_filename'])
+    result_path = os.path.join(current_dir, CONFIG['result_filename'])
+    output_dir = os.path.join(current_dir, CONFIG['output_dir'])
+
+    # 创建输出目录
+    os.makedirs(output_dir, exist_ok=True)
+
+    # 1. 读取交易数据
+    print("\n=== 步骤1: 读取交易数据 ===")
+    transaction_df = read_transaction_data(csv_path)
+    if len(transaction_df) == 0:
+        print("未能读取交易数据,退出")
+        return
+
+    # 2. 加载已处理的结果
+    print("\n=== 步骤2: 加载已处理记录 ===")
+    existing_results, processed_pairs = load_processed_results(result_path)
+    print(f"已处理 {len(processed_pairs)} 个交易对")
+
+    # 3. 提取所有开仓交易
+    print("\n=== 步骤3: 提取开仓交易 ===")
+    open_trades = []
+
+    for idx, row in transaction_df.iterrows():
+        contract_code, trade_date, trade_price, direction, action, order_time, trade_type, trade_pair_id, continuous_pair_id = extract_contract_info(row)
+
+        if contract_code is None or action != 'open':
+            continue
+
+        # 跳过已处理的交易对
+        if trade_pair_id in processed_pairs:
+            continue
+
+        # 查找对应的平仓交易
+        profit_loss = calculate_profit_loss(transaction_df, trade_pair_id, continuous_pair_id)
+
+        open_trades.append({
+            'index': idx,
+            'contract_code': contract_code,
+            'trade_date': trade_date,
+            'trade_price': trade_price,
+            'direction': direction,
+            'order_time': order_time,
+            'trade_type': trade_type,
+            'trade_pair_id': trade_pair_id,
+            'continuous_pair_id': continuous_pair_id,
+            'profit_loss': profit_loss,
+            'original_row': row
+        })
+
+    print(f"找到 {len(open_trades)} 个未处理的开仓交易")
+
+    if len(open_trades) == 0:
+        print("没有未处理的开仓交易,退出")
+        return
+
+    # 4. 随机选择一个交易(完全随机)
+    print("\n=== 步骤4: 随机选择交易 ===")
+
+    # 打乱交易列表顺序确保完全随机
+    random.shuffle(open_trades)
+    selected_trade = random.choice(open_trades)
+
+    print(f"选中交易: {selected_trade['contract_code']} - {selected_trade['trade_date']} - {selected_trade['direction']}")
+    print(f"剩余未处理交易: {len(open_trades) - 1} 个")
+
+    # 5. 获取K线数据
+    print("\n=== 步骤5: 获取K线数据 ===")
+    kline_data, trade_idx = get_kline_data_with_future(
+        selected_trade['contract_code'],
+        selected_trade['trade_date'],
+        CONFIG['history_days'],
+        CONFIG['future_days']
+    )
+
+    if kline_data is None or trade_idx is None:
+        print("获取K线数据失败,退出")
+        return
+
+    # 6. 显示部分K线图
+    print("\n=== 步骤6: 显示部分K线图 ===")
+    partial_image_path = os.path.join(output_dir, f"partial_{selected_trade['trade_pair_id']}.png")
+    plot_partial_kline(
+        kline_data, trade_idx, selected_trade['trade_price'],
+        selected_trade['direction'], selected_trade['contract_code'],
+        selected_trade['trade_date'], selected_trade['order_time'],
+        partial_image_path
+    )
+
+    # 7. 获取用户决策
+    print(f"\n交易信息:")
+    print(f"合约: {selected_trade['contract_code']}")
+    print(f"日期: {selected_trade['trade_date']}")
+    print(f"方向: {'多头' if selected_trade['direction'] == 'long' else '空头'}")
+    print(f"成交价: {selected_trade['trade_price']}")
+
+    user_decision = get_user_decision()
+
+    # 8. 显示完整K线图
+    print("\n=== 步骤7: 显示完整K线图 ===")
+    full_image_path = os.path.join(output_dir, f"full_{selected_trade['trade_pair_id']}.png")
+    plot_full_kline(
+        kline_data, trade_idx, selected_trade['trade_price'],
+        selected_trade['direction'], selected_trade['contract_code'],
+        selected_trade['trade_date'], selected_trade['order_time'],
+        selected_trade['profit_loss'],
+        full_image_path
+    )
+
+    # 9. 记录结果
+    print("\n=== 步骤8: 记录结果 ===")
+
+    # 计算判定收益
+    decision_profit = selected_trade['profit_loss'] if user_decision else -selected_trade['profit_loss']
+
+    result_data = {
+        '日期': selected_trade['original_row']['日期'],
+        '委托时间': selected_trade['original_row']['委托时间'],
+        '标的': selected_trade['original_row']['标的'],
+        '交易类型': selected_trade['original_row']['交易类型'],
+        '成交数量': selected_trade['original_row']['成交数量'],
+        '成交价': selected_trade['original_row']['成交价'],
+        '平仓盈亏': selected_trade['profit_loss'],
+        '用户判定': '开仓' if user_decision else '不开仓',
+        '判定收益': decision_profit,
+        '交易对ID': selected_trade['trade_pair_id'],
+        '连续交易对ID': selected_trade['continuous_pair_id']
+    }
+
+    record_result(result_data, result_path)
+
+    print(f"\n=== 训练完成 ===")
+    print(f"用户判定: {'开仓' if user_decision else '不开仓'}")
+    print(f"实际盈亏: {selected_trade['profit_loss']:+.2f}")
+    print(f"判定收益: {decision_profit:+.2f}")
+    print(f"结果已保存到: {result_path}")
+
+
+if __name__ == "__main__":
+    main()