||
- # 交易训练工具
- # 用于从交易记录CSV文件中随机选择交易,显示部分K线图让用户判断是否开仓,然后显示完整结果并记录
- from jqdata import *
- import pandas as pd
- import numpy as np
- import matplotlib.pyplot as plt
- import matplotlib.patches as patches
- from datetime import datetime, timedelta, date
- import re
- import os
- import random
- import warnings
- warnings.filterwarnings('ignore')
# Font setup for figures: prefer fonts that contain Chinese glyphs, and turn
# off the Unicode minus sign for axis labels.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Microsoft YaHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})
# ========== Configuration (every hard-coded parameter lives here) ==========
CONFIG = dict(
    # Name of the transaction-record CSV file
    csv_filename='transaction1.csv',
    # Name of the CSV file that training results are appended to
    result_filename='training_results.csv',
    # Number of historical trading days to load before the trade day
    history_days=100,
    # Number of trading days to load after the trade day
    future_days=20,
    # Directory (relative to the working directory) for saved chart images
    output_dir='training_images',
    # Whether to display figures (may need to be False in some environments)
    show_plots=True,
    # DPI used when saving figures
    plot_dpi=150,
    # Random seed (None means the seed is not fixed)
    random_seed=42,
)
# =====================================================
def _get_current_directory():
    """
    Locate the working directory that holds the transaction CSV.

    Starts from the directory of this file (or the process working directory
    when ``__file__`` is undefined). If the CSV named by
    ``CONFIG['csv_filename']`` is not there but exists in a sibling directory
    called ``future``, that sibling is returned instead. When the CSV is found
    in neither place the starting directory is returned unchanged.
    """
    csv_name = CONFIG['csv_filename']
    try:
        base_dir = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        # __file__ is not defined in some interactive environments.
        base_dir = os.getcwd()

    if os.path.exists(os.path.join(base_dir, csv_name)):
        return base_dir

    sibling_dir = os.path.join(os.path.dirname(base_dir), 'future')
    if os.path.exists(os.path.join(sibling_dir, csv_name)):
        return sibling_dir
    return base_dir
def read_transaction_data(csv_path):
    """
    Read the transaction-record CSV, trying several text encodings in turn.

    Encodings are attempted in a fixed order; the first one that parses wins
    and its DataFrame is returned. A ``UnicodeDecodeError`` moves on to the
    next encoding. Any other exception is ignored for every encoding except
    the last one, where it is reported and re-raised. An empty DataFrame is
    returned only if no attempt produced a result or an exception.
    """
    candidate_encodings = ('utf-8-sig', 'utf-8', 'gbk', 'gb2312', 'gb18030', 'latin1')
    final_encoding = candidate_encodings[-1]

    for enc in candidate_encodings:
        try:
            frame = pd.read_csv(csv_path, encoding=enc)
            print(f"成功使用 {enc} 编码读取CSV文件,共 {len(frame)} 条记录")
            return frame
        except UnicodeDecodeError:
            # Wrong guess for this file; fall through to the next encoding.
            continue
        except Exception as err:
            # Only give up (and surface the error) once every encoding failed.
            if enc == final_encoding:
                print(f"读取CSV文件时出错: {str(err)}")
                raise

    return pd.DataFrame()
def _parse_trade_date(date_text):
    """Parse ``date_text`` with the supported formats; return a ``date`` or None."""
    for date_format in ('%Y-%m-%d', '%d/%m/%Y', '%Y/%m/%d', '%d-%m-%Y', '%Y%m%d'):
        try:
            return datetime.strptime(date_text, date_format).date()
        except ValueError:
            continue
    return None


def _next_trade_day_after(trade_date):
    """Return the first trading day strictly after ``trade_date``.

    Falls back to ``trade_date`` itself when the trading calendar cannot be
    queried or contains no later day (best-effort, never raises).
    """
    try:
        # JoinQuant documents ``start_date`` and ``count`` as mutually
        # exclusive, so query an explicit date window instead. 30 calendar
        # days comfortably spans any exchange holiday.
        calendar = get_trade_days(start_date=trade_date,
                                  end_date=trade_date + timedelta(days=30))
    except Exception as e:
        print(f"获取下一交易日时出错,沿用原始日期: {str(e)}")
        return trade_date

    for day in calendar:
        if isinstance(day, datetime):
            day = day.date()
        if isinstance(day, date) and day > trade_date:
            return day
    return trade_date


def extract_contract_info(row):
    """
    Extract contract and trade details from one transaction record.

    Args:
        row: Mapping-like record (e.g. a pandas Series) with the columns
            '标的', '日期', '委托时间', '成交价' and '交易类型'; the columns
            '交易对ID' and '连续交易对ID' are optional and default to 'N/A'.

    Returns:
        A 9-tuple ``(contract_code, actual_trade_date, trade_price, direction,
        action, order_time_str, trade_type, trade_pair_id,
        continuous_pair_id)``. ``direction`` is 'long' or 'short' and
        ``action`` is 'open' or 'close'. Every element is None when the record
        cannot be interpreted (no contract code in parentheses, unparseable
        date, missing/non-numeric price, unknown trade type, or any
        unexpected error).

    Orders placed at 21:00 or later are attributed to the next trading day;
    if that day cannot be determined the record's own date is kept.
    """
    nothing = (None,) * 9
    try:
        # Contract code is the text inside the first pair of parentheses.
        match = re.search(r'\(([^)]+)\)', str(row['标的']))
        if match is None:
            return nothing
        contract_code = match.group(1)

        trade_date = _parse_trade_date(str(row['日期']).strip())
        if trade_date is None:
            return nothing

        try:
            trade_price = float(row['成交价'])
        except (TypeError, ValueError):
            return nothing
        # An empty CSV cell arrives as NaN, which float() accepts silently.
        if pd.isna(trade_price):
            return nothing

        trade_type = str(row['交易类型']).strip()
        for keyword, direction, action in (
            ('开多', 'long', 'open'),
            ('开空', 'short', 'open'),
            ('平多', 'long', 'close'),
            ('平空', 'short', 'close'),
        ):
            if keyword in trade_type:
                break
        else:
            return nothing

        # Resolve the night-session date last so that rejected rows never
        # trigger a trading-calendar query.
        order_time_str = str(row['委托时间']).strip()
        try:
            hour = int(order_time_str.split(':')[0])
        except ValueError:
            hour = None
        if hour is not None and hour >= 21:
            actual_trade_date = _next_trade_day_after(trade_date)
        else:
            actual_trade_date = trade_date

        trade_pair_id = row.get('交易对ID', 'N/A')
        continuous_pair_id = row.get('连续交易对ID', 'N/A')
        return (contract_code, actual_trade_date, trade_price, direction, action,
                order_time_str, trade_type, trade_pair_id, continuous_pair_id)
    except Exception as e:
        print(f"提取合约信息时出错: {str(e)}")
        return nothing
def get_trade_day_range(trade_date, days_before, days_after):
    """
    Compute the date window covering trading days around ``trade_date``.

    Args:
        trade_date: The trade day (``datetime.date``).
        days_before: Number of trading days wanted before ``trade_date``.
        days_after: Number of trading days wanted after ``trade_date``.

    Returns:
        ``(start_date, end_date)``: the day ``days_before`` trading days
        before ``trade_date`` and the day ``days_after`` trading days after
        it. ``(None, None)`` when the calendar holds fewer days than
        requested on either side, or when the calendar query fails.
    """
    def _plain_date(day):
        # Calendar entries may be datetimes; callers expect plain dates.
        return day.date() if isinstance(day, datetime) else day

    try:
        # History: ``count`` days ending at (and including) trade_date.
        history = get_trade_days(end_date=trade_date, count=days_before + 1)
        if len(history) < days_before + 1:
            return None, None
        start_date = _plain_date(history[0])

        # Future: JoinQuant documents ``start_date`` and ``count`` as mutually
        # exclusive, so request an explicit window instead. Twice the wanted
        # trading days plus 15 calendar days is enough to absorb weekends and
        # long holidays.
        lookahead_end = trade_date + timedelta(days=days_after * 2 + 15)
        upcoming = get_trade_days(start_date=trade_date, end_date=lookahead_end)
        if len(upcoming) < days_after + 1:
            return None, None
        # Index 0 is trade_date itself, so index days_after is the last wanted day.
        end_date = _plain_date(upcoming[days_after])

        return start_date, end_date
    except Exception as e:
        print(f"计算交易日范围时出错: {str(e)}")
        return None, None
def get_kline_data_with_future(contract_code, trade_date, days_before=100, days_after=20):
    """
    Fetch daily K-line data spanning both history and future around a trade.

    Args:
        contract_code: Contract identifier passed to ``get_price``.
        trade_date: The trade day (``datetime.date``).
        days_before: Trading days of history to include.
        days_after: Trading days after the trade day to include.

    Returns:
        ``(price_data, trade_idx)``. ``price_data`` is the DataFrame returned
        by ``get_price`` with added ``ma5``/``ma10``/``ma20``/``ma30``
        columns (rolling means of ``close``); ``trade_idx`` is the positional
        index of ``trade_date`` in it, or None if that day is absent.
        ``(None, None)`` on any failure -- always a 2-tuple, so callers can
        unpack two values unconditionally.
    """
    def _calendar_day(value):
        # pd.Timestamp subclasses datetime, so this covers both.
        return value.date() if isinstance(value, datetime) else value

    try:
        start_date, end_date = get_trade_day_range(trade_date, days_before, days_after)
        if start_date is None or end_date is None:
            return None, None

        price_data = get_price(
            contract_code,
            start_date=start_date,
            end_date=end_date,
            frequency='1d',
            fields=['open', 'close', 'high', 'low']
        )
        if price_data is None or len(price_data) == 0:
            return None, None

        # Moving averages of the close.
        for window in (5, 10, 20, 30):
            price_data[f'ma{window}'] = price_data['close'].rolling(window=window).mean()

        # Positional index of the trade day within the fetched bars.
        target_day = _calendar_day(trade_date)
        trade_idx = None
        for position, stamp in enumerate(price_data.index):
            if _calendar_day(stamp) == target_day:
                trade_idx = position
                break

        return price_data, trade_idx
    except Exception as e:
        print(f"获取K线数据时出错: {str(e)}")
        return None, None
def plot_partial_kline(data, trade_idx, trade_price, direction, contract_code, trade_date, order_time, save_path=None):
    """
    Draw the decision chart: history up to and including the trade day only.

    The trade day's close is replaced by the fill price, and the moving
    averages are recomputed afterwards so that no value on the chart depends
    on the trade day's real close.

    Args:
        data: K-line DataFrame with 'open', 'high', 'low', 'close' columns.
        trade_idx: Positional index of the trade day in ``data``.
        trade_price: Fill price of the opening trade.
        direction: 'long' for a long position; anything else is shown as short.
        contract_code: Contract identifier; the part before the first '.' is
            used in the title.
        trade_date: Trade day (must support ``strftime``).
        order_time: Order time text shown in the annotation.
        save_path: Optional path the figure is saved to.

    Raises:
        Any exception raised while drawing; all figures are closed first.
    """
    try:
        # Keep history plus the trade day, and hide that day's real close.
        partial_data = data.iloc[:trade_idx + 1].copy()
        partial_data.iloc[-1, partial_data.columns.get_loc('close')] = trade_price

        # Recompute the averages on the truncated series. Earlier rows are
        # unchanged (same prefix); only the trade-day values differ, which
        # would otherwise still reflect the close that was just hidden.
        for window in (5, 10, 20, 30):
            partial_data[f'ma{window}'] = partial_data['close'].rolling(window=window).mean()

        fig, ax = plt.subplots(figsize=(16, 10))

        dates = partial_data.index
        highs = partial_data['high']
        lows = partial_data['low']
        bar_count = len(partial_data)
        x_positions = range(bar_count)

        # Candles: red when close > open, green otherwise.
        for x, (open_px, high_px, low_px, close_px) in enumerate(
                zip(partial_data['open'], highs, lows, partial_data['close'])):
            rising = close_px > open_px
            # Wick
            ax.plot([x, x], [low_px, high_px], color='black', linewidth=1)
            # Body; a flat bar gets a tiny height so it stays visible.
            body_height = abs(close_px - open_px)
            if body_height == 0:
                body_height = 0.01
            rect = patches.Rectangle((x - 0.4, min(open_px, close_px)), 0.8, body_height,
                                     linewidth=1,
                                     edgecolor='darkred' if rising else 'darkgreen',
                                     facecolor='red' if rising else 'green',
                                     alpha=0.8)
            ax.add_patch(rect)

        # Moving averages
        for column, label, line_color in (
            ('ma5', 'MA5', 'blue'),
            ('ma10', 'MA10', 'orange'),
            ('ma20', 'MA20', 'purple'),
            ('ma30', 'MA30', 'brown'),
        ):
            ax.plot(x_positions, partial_data[column], label=label,
                    color=line_color, linewidth=1.5, alpha=0.8)

        # Mark the opening trade with a star and a vertical guide line.
        ax.plot(trade_idx, trade_price, marker='*', markersize=20,
                color='yellow', markeredgecolor='black', markeredgewidth=2,
                label='Open Position', zorder=10)
        ax.axvline(x=trade_idx, color='yellow', linestyle='--',
                   linewidth=2, alpha=0.7, zorder=5)

        # Annotation text
        direction_text = "Long" if direction == "long" else "Short"
        date_label = trade_date.strftime('%Y-%m-%d')
        price_label = f'Price: {trade_price:.2f}'
        direction_label = f'Direction: {direction_text}'
        time_label = f'Time: {order_time}'
        annotation_text = f'{date_label}\n{price_label}\n{direction_label}\n{time_label}'

        # Place the box above the fill price, or below it when that would
        # start above the highest bar.
        price_range = highs.max() - lows.min()
        y_offset = max(price_range * 0.08, (highs.max() - trade_price) * 0.3)
        text_y = trade_price + y_offset
        if text_y > highs.max():
            text_y = trade_price - price_range * 0.08
        ax.text(trade_idx, text_y, annotation_text,
                fontsize=10, ha='center', va='bottom',
                bbox=dict(boxstyle='round,pad=0.6', facecolor='yellow', alpha=0.9,
                          edgecolor='black', linewidth=1.5),
                zorder=11, weight='bold')

        # Title and axis labels
        contract_simple = contract_code.split('.')[0]
        ax.set_title(f'{contract_simple} - {direction_text} Position Decision\n'
                     f'Historical Data + Trade Day Only',
                     fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel('Time', fontsize=12)
        ax.set_ylabel('Price', fontsize=12)
        ax.grid(True, alpha=0.3)
        ax.legend(loc='lower left', fontsize=10)

        # About ten date labels along the x axis.
        step = max(1, bar_count // 10)
        tick_positions = range(0, bar_count, step)
        tick_labels = []
        for pos in tick_positions:
            date_val = dates[pos]
            if isinstance(date_val, (date, datetime)):
                tick_labels.append(date_val.strftime('%Y-%m-%d'))
            else:
                tick_labels.append(str(date_val))
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels, rotation=45, ha='right')

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=CONFIG['plot_dpi'], bbox_inches='tight')
        if CONFIG['show_plots']:
            plt.show()
        plt.close(fig)
    except Exception as e:
        print(f"绘制部分K线图时出错: {str(e)}")
        plt.close('all')
        raise
def plot_full_kline(data, trade_idx, trade_price, direction, contract_code, trade_date, order_time, profit_loss, save_path=None):
    """
    Draw the result chart: every bar in ``data``, including those after the trade.

    Args:
        data: K-line DataFrame with 'open', 'high', 'low', 'close' and
            'ma5'/'ma10'/'ma20'/'ma30' columns.
        trade_idx: Positional index of the trade day in ``data``.
        trade_price: Fill price of the opening trade.
        direction: 'long' for a long position; anything else is shown as short.
        contract_code: Contract identifier; the part before the first '.' is
            used in the title.
        trade_date: Trade day (must support ``strftime``).
        order_time: Order time text shown in the annotation.
        profit_loss: Realized profit/loss shown in the annotation.
        save_path: Optional path the figure is saved to.

    Raises:
        Any exception raised while drawing; all figures are closed first.
    """
    try:
        fig, ax = plt.subplots(figsize=(16, 10))

        timeline = data.index
        high_series = data['high']
        low_series = data['low']
        bar_count = len(data)
        x_positions = range(bar_count)

        # Candles: red when close > open, green otherwise.
        for x, (open_px, high_px, low_px, close_px) in enumerate(
                zip(data['open'], high_series, low_series, data['close'])):
            if close_px > open_px:
                fill, outline = 'red', 'darkred'
            else:
                fill, outline = 'green', 'darkgreen'
            # Wick
            ax.plot([x, x], [low_px, high_px], color='black', linewidth=1)
            # Body; a flat bar gets a tiny height so it stays visible.
            body = abs(close_px - open_px)
            if body == 0:
                body = 0.01
            ax.add_patch(patches.Rectangle(
                (x - 0.4, min(open_px, close_px)), 0.8, body,
                linewidth=1, edgecolor=outline, facecolor=fill, alpha=0.8))

        # Moving averages
        for column, legend_name, tint in (
            ('ma5', 'MA5', 'blue'),
            ('ma10', 'MA10', 'orange'),
            ('ma20', 'MA20', 'purple'),
            ('ma30', 'MA30', 'brown'),
        ):
            ax.plot(x_positions, data[column], label=legend_name,
                    color=tint, linewidth=1.5, alpha=0.8)

        # Star marker on the opening trade.
        ax.plot(trade_idx, trade_price, marker='*', markersize=20,
                color='yellow', markeredgecolor='black', markeredgewidth=2,
                label='Open Position', zorder=10)
        # Dashed line separating history from what came afterwards.
        ax.axvline(x=trade_idx, color='yellow', linestyle='--',
                   linewidth=2, alpha=0.7, zorder=5)
        # Shade the bars after the trade day.
        ax.axvspan(trade_idx + 0.5, bar_count - 0.5, alpha=0.1, color='gray', label='Future Data')

        # Annotation text
        side = "Long" if direction == "long" else "Short"
        note_lines = [
            trade_date.strftime('%Y-%m-%d'),
            f'Price: {trade_price:.2f}',
            f'Direction: {side}',
            f'Time: {order_time}',
            f'P&L: {profit_loss:+.2f}',
        ]

        # Place the box above the fill price, or below it when that would
        # start above the highest bar.
        top = high_series.max()
        span = top - low_series.min()
        note_y = trade_price + max(span * 0.08, (top - trade_price) * 0.3)
        if note_y > top:
            note_y = trade_price - span * 0.08
        ax.text(trade_idx, note_y, '\n'.join(note_lines),
                fontsize=10, ha='center', va='bottom',
                bbox=dict(boxstyle='round,pad=0.6', facecolor='yellow', alpha=0.9,
                          edgecolor='black', linewidth=1.5),
                zorder=11, weight='bold')

        # Title and axis labels
        short_code = contract_code.split('.')[0]
        ax.set_title(f'{short_code} - {side} Position Result\n'
                     f'Complete Data with Future {CONFIG["future_days"]} Days',
                     fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel('Time', fontsize=12)
        ax.set_ylabel('Price', fontsize=12)
        ax.grid(True, alpha=0.3)
        ax.legend(loc='lower left', fontsize=10)

        # About fifteen date labels along the x axis.
        tick_positions = range(0, bar_count, max(1, bar_count // 15))
        tick_labels = [
            timeline[pos].strftime('%Y-%m-%d')
            if isinstance(timeline[pos], (date, datetime))
            else str(timeline[pos])
            for pos in tick_positions
        ]
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels, rotation=45, ha='right')

        plt.tight_layout()
        if save_path:
            plt.savefig(save_path, dpi=CONFIG['plot_dpi'], bbox_inches='tight')
        if CONFIG['show_plots']:
            plt.show()
        plt.close(fig)
    except Exception as e:
        print(f"绘制完整K线图时出错: {str(e)}")
        plt.close('all')
        raise
def load_processed_results(result_path):
    """
    Load previously recorded training results.

    Returns:
        ``(results_df, processed_pairs)`` where ``processed_pairs`` is the set
        of distinct values in the '交易对ID' column. An empty DataFrame and an
        empty set are returned when the file does not exist or cannot be
        loaded (the error is printed in the latter case).
    """
    if os.path.exists(result_path):
        try:
            history = pd.read_csv(result_path)
            # Distinct pair IDs that already have a recorded decision.
            seen_pairs = set(history['交易对ID'].unique())
            return history, seen_pairs
        except Exception as error:
            print(f"加载结果文件时出错: {str(error)}")
    return pd.DataFrame(), set()
def calculate_profit_loss(df, trade_pair_id, continuous_pair_id):
    """
    Look up the realized profit/loss for an opening trade.

    When ``continuous_pair_id`` is present (neither 'N/A' nor NaN), the
    '平仓盈亏' values of every closing record sharing that continuous ID are
    summed. Otherwise the '平仓盈亏' of the first closing record with the same
    '交易对ID' is returned, or 0 if there is none. A closing record is one
    whose '交易类型' starts with '平'. Any error is printed and yields 0.
    """
    try:
        closing_mask = df['交易类型'].str[0] == '平'

        if continuous_pair_id != 'N/A' and pd.notna(continuous_pair_id):
            # Aggregate over the whole chain of linked trade pairs.
            chain_closes = df[(df['连续交易对ID'] == continuous_pair_id) & closing_mask]
            return chain_closes['平仓盈亏'].sum()

        # Stand-alone pair: take its first closing record, if any.
        pair_closes = df[(df['交易对ID'] == trade_pair_id) & closing_mask]
        if len(pair_closes) > 0:
            return pair_closes['平仓盈亏'].iloc[0]
        return 0
    except Exception as error:
        print(f"计算盈亏时出错: {str(error)}")
        return 0
def record_result(result_data, result_path):
    """
    Append one training result to the results CSV.

    ``result_data`` is a mapping of column name to value and becomes a single
    row. If ``result_path`` already exists the row is appended without a
    header; otherwise a new file is created with a header. Errors are printed
    rather than raised.
    """
    try:
        row_frame = pd.DataFrame([result_data])
        appending = os.path.exists(result_path)
        row_frame.to_csv(
            result_path,
            mode='a' if appending else 'w',
            # The header is written only when the file is first created.
            header=not appending,
            index=False,
            encoding='utf-8-sig',
        )
        print(f"结果已记录到: {result_path}")
    except Exception as error:
        print(f"记录结果时出错: {str(error)}")
def get_user_decision():
    """
    Ask the user whether to open the position.

    Prompts repeatedly until a recognized answer is typed. The answer is
    stripped and lower-cased before matching.

    Returns:
        True for 'y', 'yes', '是' or '开仓'; False for 'n', 'no', '否' or
        '不开仓'.
    """
    accept_answers = ('y', 'yes', '是', '开仓')
    reject_answers = ('n', 'no', '否', '不开仓')

    while True:
        answer = input("\n是否开仓?请输入 'y' (开仓) 或 'n' (不开仓): ").strip().lower()
        if answer in accept_answers:
            return True
        if answer in reject_answers:
            return False
        print("请输入有效的选项: 'y' 或 'n'")
def _collect_pending_open_trades(transaction_df, processed_pairs):
    """Return one dict per opening trade whose pair ID is not in ``processed_pairs``."""
    pending = []
    for idx, row in transaction_df.iterrows():
        (contract_code, trade_date, trade_price, direction, action,
         order_time, trade_type, trade_pair_id, continuous_pair_id) = extract_contract_info(row)
        if contract_code is None or action != 'open':
            continue
        # Skip pairs that already have a recorded decision.
        if trade_pair_id in processed_pairs:
            continue
        pending.append({
            'index': idx,
            'contract_code': contract_code,
            'trade_date': trade_date,
            'trade_price': trade_price,
            'direction': direction,
            'order_time': order_time,
            'trade_type': trade_type,
            'trade_pair_id': trade_pair_id,
            'continuous_pair_id': continuous_pair_id,
            'original_row': row
        })
    return pending


def main():
    """
    Run one training round.

    Reads the transaction CSV, picks a random not-yet-processed opening trade,
    shows the chart up to the trade day, asks the user whether they would open
    the position, then shows the full chart and appends the outcome to the
    results CSV.

    If K-line data cannot be fetched for the picked trade, further candidates
    are tried (up to ``CONFIG['max_fetch_attempts']``, default 5). With a
    fixed random seed the same trade would otherwise be picked on every run
    and the tool could never progress past it.
    """
    print("=" * 60)
    print("交易训练工具")
    print("=" * 60)

    # Seed both generators when a fixed seed is configured.
    seed = CONFIG['random_seed']
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    # Resolve input/output locations.
    current_dir = _get_current_directory()
    csv_path = os.path.join(current_dir, CONFIG['csv_filename'])
    result_path = os.path.join(current_dir, CONFIG['result_filename'])
    output_dir = os.path.join(current_dir, CONFIG['output_dir'])
    os.makedirs(output_dir, exist_ok=True)

    # 1. Read the transaction records.
    print("\n=== 步骤1: 读取交易数据 ===")
    transaction_df = read_transaction_data(csv_path)
    if len(transaction_df) == 0:
        print("未能读取交易数据,退出")
        return

    # 2. Load the pair IDs that were already processed.
    print("\n=== 步骤2: 加载已处理记录 ===")
    _, processed_pairs = load_processed_results(result_path)
    print(f"已处理 {len(processed_pairs)} 个交易对")

    # 3. Collect every unprocessed opening trade.
    print("\n=== 步骤3: 提取开仓交易 ===")
    open_trades = _collect_pending_open_trades(transaction_df, processed_pairs)
    print(f"找到 {len(open_trades)} 个未处理的开仓交易")
    if len(open_trades) == 0:
        print("没有未处理的开仓交易,退出")
        return

    # 4/5. Walk the shuffled candidates until one has usable K-line data.
    print("\n=== 步骤4: 随机选择交易 ===")
    random.shuffle(open_trades)
    max_attempts = max(1, min(len(open_trades), CONFIG.get('max_fetch_attempts', 5)))
    selected_trade = kline_data = trade_idx = None
    for candidate in open_trades[:max_attempts]:
        print(f"选中交易: {candidate['contract_code']} - {candidate['trade_date']} - {candidate['direction']}")
        print(f"剩余未处理交易: {len(open_trades) - 1} 个")
        print("\n=== 步骤5: 获取K线数据 ===")
        fetched = get_kline_data_with_future(
            candidate['contract_code'],
            candidate['trade_date'],
            CONFIG['history_days'],
            CONFIG['future_days']
        )
        # Only the first two elements are meaningful; index them explicitly
        # so the length of a failure result cannot break the unpacking.
        kline_data, trade_idx = fetched[0], fetched[1]
        if kline_data is not None and trade_idx is not None:
            selected_trade = candidate
            break
        print("获取K线数据失败,尝试下一个交易")
    if selected_trade is None:
        print("获取K线数据失败,退出")
        return

    # The realized P&L is needed only for the trade actually shown.
    selected_trade['profit_loss'] = calculate_profit_loss(
        transaction_df, selected_trade['trade_pair_id'], selected_trade['continuous_pair_id'])

    # 6. Show the chart that stops at the trade day.
    print("\n=== 步骤6: 显示部分K线图 ===")
    partial_image_path = os.path.join(output_dir, f"partial_{selected_trade['trade_pair_id']}.png")
    plot_partial_kline(
        kline_data, trade_idx, selected_trade['trade_price'],
        selected_trade['direction'], selected_trade['contract_code'],
        selected_trade['trade_date'], selected_trade['order_time'],
        partial_image_path
    )

    # 7. Ask for the user's decision.
    print("\n交易信息:")
    print(f"合约: {selected_trade['contract_code']}")
    print(f"日期: {selected_trade['trade_date']}")
    print(f"方向: {'多头' if selected_trade['direction'] == 'long' else '空头'}")
    print(f"成交价: {selected_trade['trade_price']}")
    user_decision = get_user_decision()

    # 8. Reveal the full chart, including what happened afterwards.
    print("\n=== 步骤7: 显示完整K线图 ===")
    full_image_path = os.path.join(output_dir, f"full_{selected_trade['trade_pair_id']}.png")
    plot_full_kline(
        kline_data, trade_idx, selected_trade['trade_price'],
        selected_trade['direction'], selected_trade['contract_code'],
        selected_trade['trade_date'], selected_trade['order_time'],
        selected_trade['profit_loss'],
        full_image_path
    )

    # 9. Record the outcome.
    print("\n=== 步骤8: 记录结果 ===")
    # Declining a trade is scored as the negative of its realized P&L.
    decision_profit = selected_trade['profit_loss'] if user_decision else -selected_trade['profit_loss']
    source_row = selected_trade['original_row']
    result_data = {
        '日期': source_row['日期'],
        '委托时间': source_row['委托时间'],
        '标的': source_row['标的'],
        '交易类型': source_row['交易类型'],
        '成交数量': source_row['成交数量'],
        '成交价': source_row['成交价'],
        '平仓盈亏': selected_trade['profit_loss'],
        '用户判定': '开仓' if user_decision else '不开仓',
        '判定收益': decision_profit,
        '交易对ID': selected_trade['trade_pair_id'],
        '连续交易对ID': selected_trade['continuous_pair_id']
    }
    record_result(result_data, result_path)

    print("\n=== 训练完成 ===")
    print(f"用户判定: {'开仓' if user_decision else '不开仓'}")
    print(f"实际盈亏: {selected_trade['profit_loss']:+.2f}")
    print(f"判定收益: {decision_profit:+.2f}")
    print(f"结果已保存到: {result_path}")


if __name__ == "__main__":
    main()
|