Browse Source

优化交易训练工具的绘图功能:根据交易方向调整当天K线颜色,改进文本框位置和连接线绘制逻辑,确保信息展示更清晰。同时,更新随机选择交易的逻辑,以合约类型分组随机抽取,避免同类交易连续出现。

maxfeng 1 tháng trước cách đây
mục cha
commit
6c4f45d12a
1 tập tin đã thay đổi với 106 bổ sung44 xóa
  1. 106 44
      Lib/future/trading_training_tool.py

+ 106 - 44
Lib/future/trading_training_tool.py

@@ -253,6 +253,7 @@ def get_kline_data_with_future(contract_code, trade_date, days_before=100, days_
 
 
 def plot_partial_kline(data, trade_idx, trade_price, direction, contract_code, trade_date, order_time, save_path=None):
+    # contract_code 参数保留用于可能的扩展功能
     """
     绘制部分K线图(仅显示历史数据和当天)
     """
@@ -260,8 +261,15 @@ def plot_partial_kline(data, trade_idx, trade_price, direction, contract_code, t
         # 截取历史数据和当天数据
         partial_data = data.iloc[:trade_idx + 1].copy()
 
-        # 修改当天的收盘价为成交价
-        partial_data.iloc[-1, partial_data.columns.get_loc('close')] = trade_price
+        # 根据交易方向修改当天的价格数据
+        if direction == 'long':
+            # 做多时,用成交价替代最高价(表示买入点)
+            partial_data.iloc[-1, partial_data.columns.get_loc('close')] = trade_price
+            partial_data.iloc[-1, partial_data.columns.get_loc('high')] = trade_price
+        else:
+            # 做空时,用成交价替代最低价(表示卖出点)
+            partial_data.iloc[-1, partial_data.columns.get_loc('close')] = trade_price
+            partial_data.iloc[-1, partial_data.columns.get_loc('low')] = trade_price
 
         fig, ax = plt.subplots(figsize=(16, 10))
 
@@ -274,8 +282,21 @@ def plot_partial_kline(data, trade_idx, trade_price, direction, contract_code, t
 
         # 绘制K线
         for i in range(len(partial_data)):
-            color = 'red' if closes.iloc[i] > opens.iloc[i] else 'green'
-            edge_color = 'darkred' if closes.iloc[i] > opens.iloc[i] else 'darkgreen'
+            # 检查是否是交易日
+            is_trade_day = (i == trade_idx)
+
+            if is_trade_day:
+                # 成交日根据涨跌用不同颜色
+                if closes.iloc[i] > opens.iloc[i]:  # 上涨
+                    color = '#FFD700'  # 金黄色(黄红色混合)
+                    edge_color = '#FF8C00'  # 深橙色
+                else:  # 下跌
+                    color = '#ADFF2F'  # 黄绿色
+                    edge_color = '#9ACD32'  # 黄绿色深版
+            else:
+                # 正常K线颜色
+                color = 'red' if closes.iloc[i] > opens.iloc[i] else 'green'
+                edge_color = 'darkred' if closes.iloc[i] > opens.iloc[i] else 'darkgreen'
 
             # 影线
             ax.plot([i, i], [lows.iloc[i], highs.iloc[i]], color='black', linewidth=1)
@@ -306,28 +327,26 @@ def plot_partial_kline(data, trade_idx, trade_price, direction, contract_code, t
         direction_label = f'Direction: {"Long" if direction == "long" else "Short"}'
         time_label = f'Time: {order_time}'
 
-        # 计算文本位置
-        price_range = highs.max() - lows.min()
-        y_offset = max(price_range * 0.08, (highs.max() - trade_price) * 0.3)
-        text_y = trade_price + y_offset
-
-        if text_y > highs.max():
-            text_y = trade_price - price_range * 0.08
-
+        # 将文本框移到左上角
         annotation_text = f'{date_label}\n{price_label}\n{direction_label}\n{time_label}'
-        ax.text(trade_idx, text_y, annotation_text,
-               fontsize=10, ha='center', va='bottom',
+        text_box = ax.text(0.02, 0.98, annotation_text,
+               fontsize=10, ha='left', va='top', transform=ax.transAxes,
                bbox=dict(boxstyle='round,pad=0.6', facecolor='yellow', alpha=0.9, edgecolor='black', linewidth=1.5),
                zorder=11, weight='bold')
 
-        # 画黄色虚线连接当天最高价和文本框
-        ax.plot([trade_idx, trade_idx], [day_high, text_y],
+        # 画黄色虚线连接文本框底部和交易日最高价
+        # 获取文本框在数据坐标系中的位置
+        fig.canvas.draw()  # 需要先绘制一次才能获取准确位置
+        bbox = text_box.get_window_extent().transformed(ax.transData.inverted())
+        text_bottom_y = bbox.ymin
+
+        # 从文本框底部到交易日最高价画虚线
+        ax.plot([trade_idx, trade_idx], [day_high, text_bottom_y],
                color='yellow', linestyle='--', linewidth=1.5, alpha=0.7, zorder=5)
 
         # 设置标题和标签
-        contract_simple = contract_code.split('.')[0]
         direction_text = "Long" if direction == "long" else "Short"
-        ax.set_title(f'{contract_simple} - {direction_text} Position Decision\n'
+        ax.set_title(f'{direction_text} Position Decision\n'
                     f'Historical Data + Trade Day Only',
                     fontsize=14, fontweight='bold', pad=20)
 
@@ -382,8 +401,21 @@ def plot_full_kline(data, trade_idx, trade_price, direction, contract_code, trad
 
         # 绘制K线
         for i in range(len(data)):
-            color = 'red' if closes.iloc[i] > opens.iloc[i] else 'green'
-            edge_color = 'darkred' if closes.iloc[i] > opens.iloc[i] else 'darkgreen'
+            # 检查是否是交易日
+            is_trade_day = (i == trade_idx)
+
+            if is_trade_day:
+                # 成交日根据涨跌用不同颜色
+                if closes.iloc[i] > opens.iloc[i]:  # 上涨
+                    color = '#FFD700'  # 金黄色(黄红色混合)
+                    edge_color = '#FF8C00'  # 深橙色
+                else:  # 下跌
+                    color = '#ADFF2F'  # 黄绿色
+                    edge_color = '#9ACD32'  # 黄绿色深版
+            else:
+                # 正常K线颜色
+                color = 'red' if closes.iloc[i] > opens.iloc[i] else 'green'
+                edge_color = 'darkred' if closes.iloc[i] > opens.iloc[i] else 'darkgreen'
 
             # 影线
             ax.plot([i, i], [lows.iloc[i], highs.iloc[i]], color='black', linewidth=1)
@@ -418,22 +450,21 @@ def plot_full_kline(data, trade_idx, trade_price, direction, contract_code, trad
         time_label = f'Time: {order_time}'
         profit_label = f'P&L: {profit_loss:+.2f}'
 
-        # 计算文本位置
-        price_range = highs.max() - lows.min()
-        y_offset = max(price_range * 0.08, (highs.max() - trade_price) * 0.3)
-        text_y = trade_price + y_offset
-
-        if text_y > highs.max():
-            text_y = trade_price - price_range * 0.08
-
+        # 将文本框移到左上角
         annotation_text = f'{date_label}\n{price_label}\n{direction_label}\n{time_label}\n{profit_label}'
-        ax.text(trade_idx, text_y, annotation_text,
-               fontsize=10, ha='center', va='bottom',
+        text_box = ax.text(0.02, 0.98, annotation_text,
+               fontsize=10, ha='left', va='top', transform=ax.transAxes,
                bbox=dict(boxstyle='round,pad=0.6', facecolor='yellow', alpha=0.9, edgecolor='black', linewidth=1.5),
                zorder=11, weight='bold')
 
-        # 画黄色虚线连接当天最高价和文本框
-        ax.plot([trade_idx, trade_idx], [day_high, text_y],
+        # 画黄色虚线连接文本框底部和交易日最高价
+        # 获取文本框在数据坐标系中的位置
+        fig.canvas.draw()  # 需要先绘制一次才能获取准确位置
+        bbox = text_box.get_window_extent().transformed(ax.transData.inverted())
+        text_bottom_y = bbox.ymin
+
+        # 从文本框底部到交易日最高价画虚线
+        ax.plot([trade_idx, trade_idx], [day_high, text_bottom_y],
                color='yellow', linestyle='--', linewidth=1.5, alpha=0.7, zorder=5)
 
         # 设置标题和标签
@@ -682,9 +713,8 @@ def main():
 
     # 2. 加载已处理的结果
     print("\n=== 步骤2: 加载已处理记录 ===")
-    existing_results, processed_pairs = load_processed_results(result_path)
+    _, processed_pairs = load_processed_results(result_path)
     print(f"已处理 {len(processed_pairs)} 个交易对")
-    # existing_results 保留用于后续可能的数据分析功能
 
     # 3. 提取所有开仓交易
     print("\n=== 步骤3: 提取开仓交易 ===")
@@ -739,12 +769,43 @@ def main():
         print("没有未处理的开仓交易,退出")
         return
 
-    # 4. 随机选择一个交易(完全随机
+    # 4. 随机选择一个交易(按标的类型分组随机抽取,避免同类连续出现
     print("\n=== 步骤4: 随机选择交易 ===")
 
-    # 打乱交易列表顺序确保完全随机
-    random.shuffle(open_trades)
-    selected_trade = random.choice(open_trades)
+    # 按标的类型分组(提取合约代码的核心字母部分)
+    def get_contract_type(contract_code):
+        """提取合约类型,如'M2405'提取为'M','AG2406'提取为'AG'"""
+        import re
+        match = re.match(r'^([A-Za-z]+)', contract_code.split('.')[0])
+        return match.group(1) if match else 'UNKNOWN'
+
+    # 按合约类型分组
+    trades_by_type = {}
+    for trade in open_trades:
+        contract_type = get_contract_type(trade['contract_code'])
+        if contract_type not in trades_by_type:
+            trades_by_type[contract_type] = []
+        trades_by_type[contract_type].append(trade)
+
+    # 打乱每个组内的顺序
+    for contract_type in trades_by_type:
+        random.shuffle(trades_by_type[contract_type])
+
+    # 从各组中轮流抽取,确保类型分散
+    selected_trade = None
+    available_types = list(trades_by_type.keys())
+
+    # 随机打乱类型顺序,然后从第一个有交易的类型中抽取
+    random.shuffle(available_types)
+    for contract_type in available_types:
+        if trades_by_type[contract_type]:
+            selected_trade = trades_by_type[contract_type].pop(0)
+            break
+
+    if selected_trade is None:
+        # 如果上述方法失败,回退到简单随机选择
+        random.shuffle(open_trades)
+        selected_trade = random.choice(open_trades)
 
     print(f"选中交易: {selected_trade['contract_code']} - {selected_trade['trade_date']} - {selected_trade['direction']}")
     print(f"剩余未处理交易: {len(open_trades) - 1} 个")
@@ -773,12 +834,6 @@ def main():
     )
 
     # 7. 获取用户决策
-    print(f"\n交易信息:")
-    print(f"合约: {selected_trade['contract_code']}")
-    print(f"日期: {selected_trade['trade_date']}")
-    print(f"方向: {'多头' if selected_trade['direction'] == 'long' else '空头'}")
-    print(f"成交价: {selected_trade['trade_price']}")
-
     user_decision = get_user_decision()
 
     # 8. 显示完整K线图
@@ -792,6 +847,13 @@ def main():
         full_image_path
     )
 
+    # 在完整K线图之后显示交易信息
+    print(f"\n交易信息:")
+    print(f"合约: {selected_trade['contract_code']}")
+    print(f"日期: {selected_trade['trade_date']}")
+    print(f"方向: {'多头' if selected_trade['direction'] == 'long' else '空头'}")
+    print(f"成交价: {selected_trade['trade_price']}")
+
     # 9. 记录结果
     print("\n=== 步骤8: 记录结果 ===")