Parcourir la source

1. 优化结果分析工具:
- 新增连续交易总盈亏计算和首笔连续交易判断功能
- 更新绘图功能以连接当天最高价和文本框

2. 优化交易训练工具:
- 新增首笔连续交易判断功能
- 更新绘图功能以连接当天最高价和文本框

maxfeng il y a 1 mois
Parent
commit
c3db52931b
2 fichiers modifiés avec 177 ajouts et 34 suppressions
  1. +141 -25
      Lib/future/trading_training_tool.py
  2. +36 -9
      Lib/future/transaction_pair_analysis.py

+ 141 - 25
Lib/future/trading_training_tool.py

@@ -297,14 +297,8 @@ def plot_partial_kline(data, trade_idx, trade_price, direction, contract_code, t
         ax.plot(range(len(partial_data)), partial_data['ma20'], label='MA20', color='purple', linewidth=1.5, alpha=0.8)
         ax.plot(range(len(partial_data)), partial_data['ma30'], label='MA30', color='brown', linewidth=1.5, alpha=0.8)
 
-        # 标注开仓位置
-        ax.plot(trade_idx, trade_price, marker='*', markersize=20,
-               color='yellow', markeredgecolor='black', markeredgewidth=2,
-               label='Open Position', zorder=10)
-
-        # 添加垂直线
-        ax.axvline(x=trade_idx, color='yellow', linestyle='--',
-                  linewidth=2, alpha=0.7, zorder=5)
+        # 获取当天的最高价(用于画连接线)
+        day_high = highs.iloc[trade_idx]
 
         # 标注信息
         date_label = trade_date.strftime('%Y-%m-%d')
@@ -326,6 +320,10 @@ def plot_partial_kline(data, trade_idx, trade_price, direction, contract_code, t
                bbox=dict(boxstyle='round,pad=0.6', facecolor='yellow', alpha=0.9, edgecolor='black', linewidth=1.5),
                zorder=11, weight='bold')
 
+        # 画黄色虚线连接当天最高价和文本框
+        ax.plot([trade_idx, trade_idx], [day_high, text_y],
+               color='yellow', linestyle='--', linewidth=1.5, alpha=0.7, zorder=5)
+
         # 设置标题和标签
         contract_simple = contract_code.split('.')[0]
         direction_text = "Long" if direction == "long" else "Short"
@@ -407,14 +405,8 @@ def plot_full_kline(data, trade_idx, trade_price, direction, contract_code, trad
         ax.plot(range(len(data)), data['ma20'], label='MA20', color='purple', linewidth=1.5, alpha=0.8)
         ax.plot(range(len(data)), data['ma30'], label='MA30', color='brown', linewidth=1.5, alpha=0.8)
 
-        # 标注开仓位置
-        ax.plot(trade_idx, trade_price, marker='*', markersize=20,
-               color='yellow', markeredgecolor='black', markeredgewidth=2,
-               label='Open Position', zorder=10)
-
-        # 添加垂直线分隔历史和未来
-        ax.axvline(x=trade_idx, color='yellow', linestyle='--',
-                  linewidth=2, alpha=0.7, zorder=5)
+        # 获取当天的最高价(用于画连接线)
+        day_high = highs.iloc[trade_idx]
 
         # 添加未来区域背景
         ax.axvspan(trade_idx + 0.5, len(data) - 0.5, alpha=0.1, color='gray', label='Future Data')
@@ -440,6 +432,10 @@ def plot_full_kline(data, trade_idx, trade_price, direction, contract_code, trad
                bbox=dict(boxstyle='round,pad=0.6', facecolor='yellow', alpha=0.9, edgecolor='black', linewidth=1.5),
                zorder=11, weight='bold')
 
+        # 画黄色虚线连接当天最高价和文本框
+        ax.plot([trade_idx, trade_idx], [day_high, text_y],
+               color='yellow', linestyle='--', linewidth=1.5, alpha=0.7, zorder=5)
+
         # 设置标题和标签
         contract_simple = contract_code.split('.')[0]
         direction_text = "Long" if direction == "long" else "Short"
@@ -490,12 +486,38 @@ def load_processed_results(result_path):
         return pd.DataFrame(), set()
 
     try:
-        df = pd.read_csv(result_path)
+        # 简单读取CSV文件
+        df = pd.read_csv(result_path, header=0)
+
+        # 确保必要的列存在
+        required_columns = ['交易对ID']
+        for col in required_columns:
+            if col not in df.columns:
+                print(f"警告:结果文件缺少必要列 '{col}'")
+                return pd.DataFrame(), set()
+
         # 获取已处理的交易对ID
-        processed_pairs = set(df['交易对ID'].unique())
+        processed_pairs = set(df['交易对ID'].dropna().unique())
         return df, processed_pairs
+
     except Exception as e:
+        # 详细打印错误信息
         print(f"加载结果文件时出错: {str(e)}")
+        print(f"错误类型: {type(e)}")
+
+        # 尝试打印问题行的信息
+        if "line 40" in str(e):
+            print("\n=== 尝试定位问题行 ===")
+            try:
+                with open(result_path, 'r', encoding='utf-8-sig') as f:
+                    lines = f.readlines()
+                    if len(lines) > 40:
+                        print(f"第41行内容: {lines[40]}")
+                        if len(lines) > 41:
+                            print(f"第42行内容: {lines[41]}")
+            except:
+                print("无法读取文件内容进行调试")
+
         return pd.DataFrame(), set()
 
 
@@ -537,10 +559,28 @@ def record_result(result_data, result_path):
         # 创建结果DataFrame
         result_df = pd.DataFrame([result_data])
 
-        # 如果文件已存在,追加写入;否则创建新文件
+        # 如果文件已存在,读取现有格式并确保新数据格式一致
         if os.path.exists(result_path):
-            result_df.to_csv(result_path, mode='a', header=False, index=False, encoding='utf-8-sig')
+            try:
+                # 读取现有文件的列名
+                existing_df = pd.read_csv(result_path, nrows=0)  # 只读取列名
+                existing_columns = existing_df.columns.tolist()
+
+                # 如果新数据列与现有文件不一致,调整格式
+                if list(result_df.columns) != existing_columns:
+                    # 重新创建DataFrame,确保列顺序一致
+                    aligned_data = {}
+                    for col in existing_columns:
+                        aligned_data[col] = result_data.get(col, 'N/A' if col == '连续交易总盈亏' else '')
+                    result_df = pd.DataFrame([aligned_data])
+
+                # 追加写入
+                result_df.to_csv(result_path, mode='a', header=False, index=False, encoding='utf-8-sig')
+            except Exception:
+                # 如果无法读取现有格式,直接覆盖
+                result_df.to_csv(result_path, mode='w', header=True, index=False, encoding='utf-8-sig')
         else:
+            # 文件不存在,创建新文件
             result_df.to_csv(result_path, mode='w', header=True, index=False, encoding='utf-8-sig')
 
         print(f"结果已记录到: {result_path}")
@@ -549,6 +589,54 @@ def record_result(result_data, result_path):
         print(f"记录结果时出错: {str(e)}")
 
 
+def is_first_continuous_trade(transaction_df, trade_pair_id, continuous_pair_id):
+    """
+    判断是否为连续交易的第一笔交易
+
+    参数:
+        transaction_df: 交易数据DataFrame
+        trade_pair_id: 当前交易对ID
+        continuous_pair_id: 连续交易对ID
+
+    返回:
+        bool: 是否为连续交易的第一笔交易(或不是连续交易)
+    """
+    # 如果不是连续交易,返回True
+    if continuous_pair_id == 'N/A' or pd.isna(continuous_pair_id):
+        return True
+
+    # 获取同一连续交易组的所有交易
+    continuous_trades = transaction_df[transaction_df['连续交易对ID'] == continuous_pair_id]
+
+    # 获取所有交易对ID并按时间排序
+    pair_ids = continuous_trades['交易对ID'].unique()
+
+    # 获取每个交易对的开仓时间
+    pair_times = []
+    for pid in pair_ids:
+        pair_records = continuous_trades[continuous_trades['交易对ID'] == pid]
+        open_records = pair_records[pair_records['交易类型'].str.contains('开', na=False)]
+        if len(open_records) > 0:
+            # 获取第一个开仓记录的日期和时间
+            first_open = open_records.iloc[0]
+            date_str = str(first_open['日期']).strip()
+            time_str = str(first_open['委托时间']).strip()
+            try:
+                dt = pd.to_datetime(f"{date_str} {time_str}")
+                pair_times.append((pid, dt))
+            except:
+                pass
+
+    # 按时间排序
+    pair_times.sort(key=lambda x: x[1])
+
+    # 检查当前交易对是否为第一个
+    if pair_times and pair_times[0][0] == trade_pair_id:
+        return True
+
+    return False
+
+
 def get_user_decision():
     """
     获取用户的开仓决策
@@ -596,6 +684,7 @@ def main():
     print("\n=== 步骤2: 加载已处理记录 ===")
     existing_results, processed_pairs = load_processed_results(result_path)
     print(f"已处理 {len(processed_pairs)} 个交易对")
+    # existing_results 保留用于后续可能的数据分析功能
 
     # 3. 提取所有开仓交易
     print("\n=== 步骤3: 提取开仓交易 ===")
@@ -611,9 +700,24 @@ def main():
         if trade_pair_id in processed_pairs:
             continue
 
+        # 检查是否为连续交易的第一笔交易(如果不是第一笔,跳过)
+        if not is_first_continuous_trade(transaction_df, trade_pair_id, continuous_pair_id):
+            continue
+
         # 查找对应的平仓交易
         profit_loss = calculate_profit_loss(transaction_df, trade_pair_id, continuous_pair_id)
 
+        # 如果是连续交易,获取连续交易总盈亏
+        continuous_total_profit = 'N/A'
+        if continuous_pair_id != 'N/A' and pd.notna(continuous_pair_id):
+            continuous_trades = transaction_df[transaction_df['连续交易对ID'] == continuous_pair_id]
+            try:
+                close_profit_loss_str = continuous_trades['平仓盈亏'].astype(str).str.replace(',', '')
+                close_profit_loss_numeric = pd.to_numeric(close_profit_loss_str, errors='coerce').fillna(0)
+                continuous_total_profit = close_profit_loss_numeric.sum()
+            except:
+                continuous_total_profit = 0
+
         open_trades.append({
             'index': idx,
             'contract_code': contract_code,
@@ -625,10 +729,11 @@ def main():
             'trade_pair_id': trade_pair_id,
             'continuous_pair_id': continuous_pair_id,
             'profit_loss': profit_loss,
+            'continuous_total_profit': continuous_total_profit,
             'original_row': row
         })
 
-    print(f"找到 {len(open_trades)} 个未处理的开仓交易")
+    print(f"找到 {len(open_trades)} 个未处理的开仓交易(已过滤非首笔连续交易)")
 
     if len(open_trades) == 0:
         print("没有未处理的开仓交易,退出")
@@ -690,8 +795,15 @@ def main():
     # 9. 记录结果
     print("\n=== 步骤8: 记录结果 ===")
 
-    # 计算判定收益
-    decision_profit = selected_trade['profit_loss'] if user_decision else -selected_trade['profit_loss']
+    # 计算判定收益(使用连续交易总盈亏或普通盈亏)
+    if selected_trade['continuous_total_profit'] != 'N/A':
+        # 连续交易使用连续交易总盈亏
+        decision_profit = selected_trade['continuous_total_profit'] if user_decision else -selected_trade['continuous_total_profit']
+        profit_to_show = selected_trade['continuous_total_profit']
+    else:
+        # 普通交易使用单笔盈亏
+        decision_profit = selected_trade['profit_loss'] if user_decision else -selected_trade['profit_loss']
+        profit_to_show = selected_trade['profit_loss']
 
     result_data = {
         '日期': selected_trade['original_row']['日期'],
@@ -704,14 +816,18 @@ def main():
         '用户判定': '开仓' if user_decision else '不开仓',
         '判定收益': decision_profit,
         '交易对ID': selected_trade['trade_pair_id'],
-        '连续交易对ID': selected_trade['continuous_pair_id']
+        '连续交易对ID': selected_trade['continuous_pair_id'],
+        '连续交易总盈亏': selected_trade['continuous_total_profit']
     }
 
     record_result(result_data, result_path)
 
     print(f"\n=== 训练完成 ===")
     print(f"用户判定: {'开仓' if user_decision else '不开仓'}")
-    print(f"实际盈亏: {selected_trade['profit_loss']:+.2f}")
+    if selected_trade['continuous_total_profit'] != 'N/A':
+        print(f"连续交易总盈亏: {profit_to_show:+.2f}")
+    else:
+        print(f"实际盈亏: {profit_to_show:+.2f}")
     print(f"判定收益: {decision_profit:+.2f}")
     print(f"结果已保存到: {result_path}")
 

+ 36 - 9
Lib/future/transaction_pair_analysis.py

@@ -608,35 +608,39 @@ def identify_continuous_trade_pairs(df):
 def save_result(df, output_path):
     """
     保存配对结果到CSV文件
-    
+
     参数:
         df (pandas.DataFrame): 包含交易对ID的DataFrame
         output_path (str): 输出文件路径
     """
     df = df.copy()
-    
+
     # 添加"开仓时间"列
     # 对于每个交易对ID,找到对应的开仓记录的"最后更新时间"
     df['开仓时间'] = ''
-    
+
     # 添加"交易盈亏"列,根据相同的交易对ID对平仓盈亏进行求和
     df['交易盈亏'] = ''
-    
+
+    # 添加"连续交易总盈亏"列
+    df['连续交易总盈亏'] = 'N/A'
+
+    # 先计算每个交易对的盈亏
     for pair_id in df['交易对ID'].unique():
         if pair_id and pair_id.startswith('P'):
             # 找到该交易对的所有记录
             pair_mask = df['交易对ID'] == pair_id
             pair_records = df[pair_mask]
-            
+
             # 找到开仓记录(仓位操作为"开仓")
             open_record = pair_records[pair_records['仓位操作'] == '开仓']
-            
+
             if len(open_record) > 0:
                 # 获取开仓记录的最后更新时间
                 open_time = open_record.iloc[0]['最后更新时间']
                 # 将开仓时间填充到该交易对的所有记录中
                 df.loc[pair_mask, '开仓时间'] = open_time
-            
+
             # 计算该交易对的总盈亏(对平仓盈亏求和)
             try:
                 # 提取平仓盈亏列,转换为数值
@@ -650,12 +654,32 @@ def save_result(df, output_path):
             except Exception as e:
                 # 如果计算失败,设为0
                 df.loc[pair_mask, '交易盈亏'] = 0
-    
+
+    # 计算连续交易总盈亏
+    for continuous_id in df['连续交易对ID'].unique():
+        if continuous_id != 'N/A' and pd.notna(continuous_id):
+            # 找到该连续交易组的所有记录
+            continuous_mask = df['连续交易对ID'] == continuous_id
+            continuous_records = df[continuous_mask]
+
+            # 计算该连续交易组的总盈亏
+            try:
+                # 提取平仓盈亏列,转换为数值
+                close_profit_loss_str = continuous_records['平仓盈亏'].astype(str).str.replace(',', '')
+                # 尝试转换为数值,无法转换的设为0
+                close_profit_loss_numeric = pd.to_numeric(close_profit_loss_str, errors='coerce').fillna(0)
+                total_continuous_profit = close_profit_loss_numeric.sum()
+                # 将连续交易总盈亏填充到该组的所有记录中
+                df.loc[continuous_mask, '连续交易总盈亏'] = total_continuous_profit
+            except Exception as e:
+                # 如果计算失败,设为0
+                df.loc[continuous_mask, '连续交易总盈亏'] = 0
+
     # 移除中间处理列
     columns_to_remove = ['标的_完整', '交易类型_标准', '仓位操作', '方向', '成交数量_数值', '交易时间']
     output_columns = [col for col in df.columns if col not in columns_to_remove]
 
-    # 调整列顺序,确保交易对ID、连续交易对ID、开仓时间和交易盈亏在最后
+    # 调整列顺序,确保交易对ID、连续交易对ID、开仓时间、交易盈亏和连续交易总盈亏在最后
     if '交易对ID' in output_columns:
         output_columns.remove('交易对ID')
     if '连续交易对ID' in output_columns:
@@ -664,10 +688,13 @@ def save_result(df, output_path):
         output_columns.remove('开仓时间')
     if '交易盈亏' in output_columns:
         output_columns.remove('交易盈亏')
+    if '连续交易总盈亏' in output_columns:
+        output_columns.remove('连续交易总盈亏')
     output_columns.append('交易对ID')
     output_columns.append('连续交易对ID')
     output_columns.append('开仓时间')
     output_columns.append('交易盈亏')
+    output_columns.append('连续交易总盈亏')
     
     # 按交易对ID和日期升序排序
     # 创建排序辅助列:未配对的排在最后,其他按ID数字排序