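# trading_training_tool.py
#
# Trading training tool.
# Picks a random trade from a transaction-record CSV file, shows a partial
# candlestick chart so the user can decide whether to open the position, then
# shows the complete outcome and records the result.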
  3. from jqdata import *
  4. import pandas as pd
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. import matplotlib.patches as patches
  8. from datetime import datetime, timedelta, date
  9. import re
  10. import os
  11. import random
  12. import warnings
  13. warnings.filterwarnings('ignore')
  14. # 中文字体设置
  15. plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
  16. plt.rcParams['axes.unicode_minus'] = False
  17. # ========== 参数配置区域(硬编码参数都在这里) ==========
  18. CONFIG = {
  19. # CSV文件名
  20. 'csv_filename': 'transaction1.csv',
  21. # 结果记录文件名
  22. 'result_filename': 'training_results.csv',
  23. # 历史数据天数
  24. 'history_days': 100,
  25. # 未来数据天数
  26. 'future_days': 20,
  27. # 输出目录
  28. 'output_dir': 'training_images',
  29. # 是否显示图片(在某些环境下可能需要设置为False)
  30. 'show_plots': True,
  31. # 图片DPI
  32. 'plot_dpi': 150,
  33. # 随机种子(设置为None表示不固定种子)
  34. 'random_seed': 42
  35. }
  36. # =====================================================
  37. def _get_current_directory():
  38. """
  39. 获取当前文件所在目录
  40. """
  41. try:
  42. current_dir = os.path.dirname(os.path.abspath(__file__))
  43. except NameError:
  44. current_dir = os.getcwd()
  45. if not os.path.exists(os.path.join(current_dir, CONFIG['csv_filename'])):
  46. parent_dir = os.path.dirname(current_dir)
  47. future_dir = os.path.join(parent_dir, 'future')
  48. if os.path.exists(os.path.join(future_dir, CONFIG['csv_filename'])):
  49. current_dir = future_dir
  50. return current_dir
  51. def read_transaction_data(csv_path):
  52. """
  53. 读取交易记录CSV文件
  54. """
  55. encodings = ['utf-8-sig', 'utf-8', 'gbk', 'gb2312', 'gb18030', 'latin1']
  56. for encoding in encodings:
  57. try:
  58. df = pd.read_csv(csv_path, encoding=encoding)
  59. print(f"成功使用 {encoding} 编码读取CSV文件,共 {len(df)} 条记录")
  60. return df
  61. except UnicodeDecodeError:
  62. continue
  63. except Exception as e:
  64. if encoding == encodings[-1]:
  65. print(f"读取CSV文件时出错: {str(e)}")
  66. raise
  67. return pd.DataFrame()
  68. def extract_contract_info(row):
  69. """
  70. 从交易记录中提取合约信息和交易信息
  71. """
  72. try:
  73. # 提取合约编号
  74. target_str = str(row['标的'])
  75. match = re.search(r'\(([^)]+)\)', target_str)
  76. if match:
  77. contract_code = match.group(1)
  78. else:
  79. return None, None, None, None, None, None, None, None, None
  80. # 提取日期
  81. date_str = str(row['日期']).strip()
  82. date_formats = ['%Y-%m-%d', '%d/%m/%Y', '%Y/%m/%d', '%d-%m-%Y', '%Y%m%d']
  83. trade_date = None
  84. for date_format in date_formats:
  85. try:
  86. trade_date = datetime.strptime(date_str, date_format).date()
  87. break
  88. except ValueError:
  89. continue
  90. if trade_date is None:
  91. return None, None, None, None, None, None, None, None, None
  92. # 提取委托时间
  93. order_time_str = str(row['委托时间']).strip()
  94. try:
  95. time_parts = order_time_str.split(':')
  96. hour = int(time_parts[0])
  97. # 如果委托时间 >= 21:00,需要找到下一个交易日
  98. if hour >= 21:
  99. try:
  100. trade_days = get_trade_days(start_date=trade_date, count=2)
  101. if len(trade_days) >= 2:
  102. next_trade_day = trade_days[1]
  103. if isinstance(next_trade_day, datetime):
  104. actual_trade_date = next_trade_day.date()
  105. elif isinstance(next_trade_day, date):
  106. actual_trade_date = next_trade_day
  107. else:
  108. actual_trade_date = trade_date
  109. else:
  110. actual_trade_date = trade_date
  111. except:
  112. actual_trade_date = trade_date
  113. else:
  114. actual_trade_date = trade_date
  115. except:
  116. actual_trade_date = trade_date
  117. # 提取成交价
  118. try:
  119. trade_price = float(row['成交价'])
  120. except:
  121. return None, None, None, None, None, None, None, None, None
  122. # 提取交易方向和类型
  123. trade_type = str(row['交易类型']).strip()
  124. if '开多' in trade_type:
  125. direction = 'long'
  126. action = 'open'
  127. elif '开空' in trade_type:
  128. direction = 'short'
  129. action = 'open'
  130. elif '平多' in trade_type or '平空' in trade_type:
  131. action = 'close'
  132. direction = 'long' if '平多' in trade_type else 'short'
  133. else:
  134. return None, None, None, None, None, None, None, None, None
  135. # 提取交易对ID和连续交易对ID
  136. trade_pair_id = row.get('交易对ID', 'N/A')
  137. continuous_pair_id = row.get('连续交易对ID', 'N/A')
  138. return contract_code, actual_trade_date, trade_price, direction, action, order_time_str, trade_type, trade_pair_id, continuous_pair_id
  139. except Exception as e:
  140. print(f"提取合约信息时出错: {str(e)}")
  141. return None, None, None, None, None, None, None, None, None
  142. def get_trade_day_range(trade_date, days_before, days_after):
  143. """
  144. 获取交易日范围
  145. """
  146. try:
  147. # 获取历史交易日
  148. trade_days_before = get_trade_days(end_date=trade_date, count=days_before + 1)
  149. if len(trade_days_before) < days_before + 1:
  150. return None, None
  151. first_day = trade_days_before[0]
  152. if isinstance(first_day, datetime):
  153. start_date = first_day.date()
  154. elif isinstance(first_day, date):
  155. start_date = first_day
  156. else:
  157. start_date = first_day
  158. # 获取未来交易日
  159. trade_days_after = get_trade_days(start_date=trade_date, count=days_after + 1)
  160. if len(trade_days_after) < days_after + 1:
  161. return None, None
  162. last_day = trade_days_after[-1]
  163. if isinstance(last_day, datetime):
  164. end_date = last_day.date()
  165. elif isinstance(last_day, date):
  166. end_date = last_day
  167. else:
  168. end_date = last_day
  169. return start_date, end_date
  170. except Exception as e:
  171. print(f"计算交易日范围时出错: {str(e)}")
  172. return None, None
  173. def get_kline_data_with_future(contract_code, trade_date, days_before=100, days_after=20):
  174. """
  175. 获取包含历史和未来数据的K线数据
  176. """
  177. try:
  178. # 获取完整的数据范围
  179. start_date, end_date = get_trade_day_range(trade_date, days_before, days_after)
  180. if start_date is None or end_date is None:
  181. return None, None, None
  182. # 获取K线数据
  183. price_data = get_price(
  184. contract_code,
  185. start_date=start_date,
  186. end_date=end_date,
  187. frequency='1d',
  188. fields=['open', 'close', 'high', 'low']
  189. )
  190. if price_data is None or len(price_data) == 0:
  191. return None, None, None
  192. # 计算均线
  193. price_data['ma5'] = price_data['close'].rolling(window=5).mean()
  194. price_data['ma10'] = price_data['close'].rolling(window=10).mean()
  195. price_data['ma20'] = price_data['close'].rolling(window=20).mean()
  196. price_data['ma30'] = price_data['close'].rolling(window=30).mean()
  197. # 找到交易日在数据中的位置
  198. trade_date_normalized = pd.Timestamp(trade_date)
  199. trade_idx = None
  200. for i, idx in enumerate(price_data.index):
  201. if isinstance(idx, pd.Timestamp):
  202. if idx.date() == trade_date:
  203. trade_idx = i
  204. break
  205. return price_data, trade_idx
  206. except Exception as e:
  207. print(f"获取K线数据时出错: {str(e)}")
  208. return None, None, None
  209. def plot_partial_kline(data, trade_idx, trade_price, direction, contract_code, trade_date, order_time, save_path=None):
  210. """
  211. 绘制部分K线图(仅显示历史数据和当天)
  212. """
  213. try:
  214. # 截取历史数据和当天数据
  215. partial_data = data.iloc[:trade_idx + 1].copy()
  216. # 修改当天的收盘价为成交价
  217. partial_data.iloc[-1, partial_data.columns.get_loc('close')] = trade_price
  218. fig, ax = plt.subplots(figsize=(16, 10))
  219. # 准备数据
  220. dates = partial_data.index
  221. opens = partial_data['open']
  222. highs = partial_data['high']
  223. lows = partial_data['low']
  224. closes = partial_data['close']
  225. # 绘制K线
  226. for i in range(len(partial_data)):
  227. color = 'red' if closes.iloc[i] > opens.iloc[i] else 'green'
  228. edge_color = 'darkred' if closes.iloc[i] > opens.iloc[i] else 'darkgreen'
  229. # 影线
  230. ax.plot([i, i], [lows.iloc[i], highs.iloc[i]], color='black', linewidth=1)
  231. # 实体
  232. body_height = abs(closes.iloc[i] - opens.iloc[i])
  233. if body_height == 0:
  234. body_height = 0.01
  235. bottom = min(opens.iloc[i], closes.iloc[i])
  236. rect = patches.Rectangle((i-0.4, bottom), 0.8, body_height,
  237. linewidth=1, edgecolor=edge_color,
  238. facecolor=color, alpha=0.8)
  239. ax.add_patch(rect)
  240. # 绘制均线
  241. ax.plot(range(len(partial_data)), partial_data['ma5'], label='MA5', color='blue', linewidth=1.5, alpha=0.8)
  242. ax.plot(range(len(partial_data)), partial_data['ma10'], label='MA10', color='orange', linewidth=1.5, alpha=0.8)
  243. ax.plot(range(len(partial_data)), partial_data['ma20'], label='MA20', color='purple', linewidth=1.5, alpha=0.8)
  244. ax.plot(range(len(partial_data)), partial_data['ma30'], label='MA30', color='brown', linewidth=1.5, alpha=0.8)
  245. # 获取当天的最高价(用于画连接线)
  246. day_high = highs.iloc[trade_idx]
  247. # 标注信息
  248. date_label = trade_date.strftime('%Y-%m-%d')
  249. price_label = f'Price: {trade_price:.2f}'
  250. direction_label = f'Direction: {"Long" if direction == "long" else "Short"}'
  251. time_label = f'Time: {order_time}'
  252. # 计算文本位置
  253. price_range = highs.max() - lows.min()
  254. y_offset = max(price_range * 0.08, (highs.max() - trade_price) * 0.3)
  255. text_y = trade_price + y_offset
  256. if text_y > highs.max():
  257. text_y = trade_price - price_range * 0.08
  258. annotation_text = f'{date_label}\n{price_label}\n{direction_label}\n{time_label}'
  259. ax.text(trade_idx, text_y, annotation_text,
  260. fontsize=10, ha='center', va='bottom',
  261. bbox=dict(boxstyle='round,pad=0.6', facecolor='yellow', alpha=0.9, edgecolor='black', linewidth=1.5),
  262. zorder=11, weight='bold')
  263. # 画黄色虚线连接当天最高价和文本框
  264. ax.plot([trade_idx, trade_idx], [day_high, text_y],
  265. color='yellow', linestyle='--', linewidth=1.5, alpha=0.7, zorder=5)
  266. # 设置标题和标签
  267. contract_simple = contract_code.split('.')[0]
  268. direction_text = "Long" if direction == "long" else "Short"
  269. ax.set_title(f'{contract_simple} - {direction_text} Position Decision\n'
  270. f'Historical Data + Trade Day Only',
  271. fontsize=14, fontweight='bold', pad=20)
  272. ax.set_xlabel('Time', fontsize=12)
  273. ax.set_ylabel('Price', fontsize=12)
  274. ax.grid(True, alpha=0.3)
  275. ax.legend(loc='lower left', fontsize=10)
  276. # 设置x轴标签
  277. step = max(1, len(partial_data) // 10)
  278. tick_positions = range(0, len(partial_data), step)
  279. tick_labels = []
  280. for pos in tick_positions:
  281. date_val = dates[pos]
  282. if isinstance(date_val, (date, datetime)):
  283. tick_labels.append(date_val.strftime('%Y-%m-%d'))
  284. else:
  285. tick_labels.append(str(date_val))
  286. ax.set_xticks(tick_positions)
  287. ax.set_xticklabels(tick_labels, rotation=45, ha='right')
  288. plt.tight_layout()
  289. if save_path:
  290. plt.savefig(save_path, dpi=CONFIG['plot_dpi'], bbox_inches='tight')
  291. if CONFIG['show_plots']:
  292. plt.show()
  293. plt.close(fig)
  294. except Exception as e:
  295. print(f"绘制部分K线图时出错: {str(e)}")
  296. plt.close('all')
  297. raise
  298. def plot_full_kline(data, trade_idx, trade_price, direction, contract_code, trade_date, order_time, profit_loss, save_path=None):
  299. """
  300. 绘制完整K线图(包含未来数据)
  301. """
  302. try:
  303. fig, ax = plt.subplots(figsize=(16, 10))
  304. # 准备数据
  305. dates = data.index
  306. opens = data['open']
  307. highs = data['high']
  308. lows = data['low']
  309. closes = data['close']
  310. # 绘制K线
  311. for i in range(len(data)):
  312. color = 'red' if closes.iloc[i] > opens.iloc[i] else 'green'
  313. edge_color = 'darkred' if closes.iloc[i] > opens.iloc[i] else 'darkgreen'
  314. # 影线
  315. ax.plot([i, i], [lows.iloc[i], highs.iloc[i]], color='black', linewidth=1)
  316. # 实体
  317. body_height = abs(closes.iloc[i] - opens.iloc[i])
  318. if body_height == 0:
  319. body_height = 0.01
  320. bottom = min(opens.iloc[i], closes.iloc[i])
  321. rect = patches.Rectangle((i-0.4, bottom), 0.8, body_height,
  322. linewidth=1, edgecolor=edge_color,
  323. facecolor=color, alpha=0.8)
  324. ax.add_patch(rect)
  325. # 绘制均线
  326. ax.plot(range(len(data)), data['ma5'], label='MA5', color='blue', linewidth=1.5, alpha=0.8)
  327. ax.plot(range(len(data)), data['ma10'], label='MA10', color='orange', linewidth=1.5, alpha=0.8)
  328. ax.plot(range(len(data)), data['ma20'], label='MA20', color='purple', linewidth=1.5, alpha=0.8)
  329. ax.plot(range(len(data)), data['ma30'], label='MA30', color='brown', linewidth=1.5, alpha=0.8)
  330. # 获取当天的最高价(用于画连接线)
  331. day_high = highs.iloc[trade_idx]
  332. # 添加未来区域背景
  333. ax.axvspan(trade_idx + 0.5, len(data) - 0.5, alpha=0.1, color='gray', label='Future Data')
  334. # 标注信息
  335. date_label = trade_date.strftime('%Y-%m-%d')
  336. price_label = f'Price: {trade_price:.2f}'
  337. direction_label = f'Direction: {"Long" if direction == "long" else "Short"}'
  338. time_label = f'Time: {order_time}'
  339. profit_label = f'P&L: {profit_loss:+.2f}'
  340. # 计算文本位置
  341. price_range = highs.max() - lows.min()
  342. y_offset = max(price_range * 0.08, (highs.max() - trade_price) * 0.3)
  343. text_y = trade_price + y_offset
  344. if text_y > highs.max():
  345. text_y = trade_price - price_range * 0.08
  346. annotation_text = f'{date_label}\n{price_label}\n{direction_label}\n{time_label}\n{profit_label}'
  347. ax.text(trade_idx, text_y, annotation_text,
  348. fontsize=10, ha='center', va='bottom',
  349. bbox=dict(boxstyle='round,pad=0.6', facecolor='yellow', alpha=0.9, edgecolor='black', linewidth=1.5),
  350. zorder=11, weight='bold')
  351. # 画黄色虚线连接当天最高价和文本框
  352. ax.plot([trade_idx, trade_idx], [day_high, text_y],
  353. color='yellow', linestyle='--', linewidth=1.5, alpha=0.7, zorder=5)
  354. # 设置标题和标签
  355. contract_simple = contract_code.split('.')[0]
  356. direction_text = "Long" if direction == "long" else "Short"
  357. ax.set_title(f'{contract_simple} - {direction_text} Position Result\n'
  358. f'Complete Data with Future {CONFIG["future_days"]} Days',
  359. fontsize=14, fontweight='bold', pad=20)
  360. ax.set_xlabel('Time', fontsize=12)
  361. ax.set_ylabel('Price', fontsize=12)
  362. ax.grid(True, alpha=0.3)
  363. ax.legend(loc='lower left', fontsize=10)
  364. # 设置x轴标签
  365. step = max(1, len(data) // 15)
  366. tick_positions = range(0, len(data), step)
  367. tick_labels = []
  368. for pos in tick_positions:
  369. date_val = dates[pos]
  370. if isinstance(date_val, (date, datetime)):
  371. tick_labels.append(date_val.strftime('%Y-%m-%d'))
  372. else:
  373. tick_labels.append(str(date_val))
  374. ax.set_xticks(tick_positions)
  375. ax.set_xticklabels(tick_labels, rotation=45, ha='right')
  376. plt.tight_layout()
  377. if save_path:
  378. plt.savefig(save_path, dpi=CONFIG['plot_dpi'], bbox_inches='tight')
  379. if CONFIG['show_plots']:
  380. plt.show()
  381. plt.close(fig)
  382. except Exception as e:
  383. print(f"绘制完整K线图时出错: {str(e)}")
  384. plt.close('all')
  385. raise
  386. def load_processed_results(result_path):
  387. """
  388. 加载已处理的结果文件
  389. """
  390. if not os.path.exists(result_path):
  391. return pd.DataFrame(), set()
  392. try:
  393. # 简单读取CSV文件
  394. df = pd.read_csv(result_path, header=0)
  395. # 确保必要的列存在
  396. required_columns = ['交易对ID']
  397. for col in required_columns:
  398. if col not in df.columns:
  399. print(f"警告:结果文件缺少必要列 '{col}'")
  400. return pd.DataFrame(), set()
  401. # 获取已处理的交易对ID
  402. processed_pairs = set(df['交易对ID'].dropna().unique())
  403. return df, processed_pairs
  404. except Exception as e:
  405. # 详细打印错误信息
  406. print(f"加载结果文件时出错: {str(e)}")
  407. print(f"错误类型: {type(e)}")
  408. # 尝试打印问题行的信息
  409. if "line 40" in str(e):
  410. print("\n=== 尝试定位问题行 ===")
  411. try:
  412. with open(result_path, 'r', encoding='utf-8-sig') as f:
  413. lines = f.readlines()
  414. if len(lines) > 40:
  415. print(f"第41行内容: {lines[40]}")
  416. if len(lines) > 41:
  417. print(f"第42行内容: {lines[41]}")
  418. except:
  419. print("无法读取文件内容进行调试")
  420. return pd.DataFrame(), set()
  421. def calculate_profit_loss(df, trade_pair_id, continuous_pair_id):
  422. """
  423. 计算平仓盈亏
  424. """
  425. try:
  426. if continuous_pair_id != 'N/A' and pd.notna(continuous_pair_id):
  427. # 合并所有同一连续交易对ID的平仓盈亏
  428. close_trades = df[
  429. (df['连续交易对ID'] == continuous_pair_id) &
  430. (df['交易类型'].str[0] == '平')
  431. ]
  432. total_profit = close_trades['平仓盈亏'].sum()
  433. else:
  434. # 只查找当前交易对ID的平仓交易
  435. close_trades = df[
  436. (df['交易对ID'] == trade_pair_id) &
  437. (df['交易类型'].str[0] == '平')
  438. ]
  439. if len(close_trades) > 0:
  440. total_profit = close_trades['平仓盈亏'].iloc[0]
  441. else:
  442. total_profit = 0
  443. return total_profit
  444. except Exception as e:
  445. print(f"计算盈亏时出错: {str(e)}")
  446. return 0
  447. def record_result(result_data, result_path):
  448. """
  449. 记录训练结果
  450. """
  451. try:
  452. # 创建结果DataFrame
  453. result_df = pd.DataFrame([result_data])
  454. # 如果文件已存在,读取现有格式并确保新数据格式一致
  455. if os.path.exists(result_path):
  456. try:
  457. # 读取现有文件的列名
  458. existing_df = pd.read_csv(result_path, nrows=0) # 只读取列名
  459. existing_columns = existing_df.columns.tolist()
  460. # 如果新数据列与现有文件不一致,调整格式
  461. if list(result_df.columns) != existing_columns:
  462. # 重新创建DataFrame,确保列顺序一致
  463. aligned_data = {}
  464. for col in existing_columns:
  465. aligned_data[col] = result_data.get(col, 'N/A' if col == '连续交易总盈亏' else '')
  466. result_df = pd.DataFrame([aligned_data])
  467. # 追加写入
  468. result_df.to_csv(result_path, mode='a', header=False, index=False, encoding='utf-8-sig')
  469. except Exception:
  470. # 如果无法读取现有格式,直接覆盖
  471. result_df.to_csv(result_path, mode='w', header=True, index=False, encoding='utf-8-sig')
  472. else:
  473. # 文件不存在,创建新文件
  474. result_df.to_csv(result_path, mode='w', header=True, index=False, encoding='utf-8-sig')
  475. print(f"结果已记录到: {result_path}")
  476. except Exception as e:
  477. print(f"记录结果时出错: {str(e)}")
  478. def is_first_continuous_trade(transaction_df, trade_pair_id, continuous_pair_id):
  479. """
  480. 判断是否为连续交易的第一笔交易
  481. 参数:
  482. transaction_df: 交易数据DataFrame
  483. trade_pair_id: 当前交易对ID
  484. continuous_pair_id: 连续交易对ID
  485. 返回:
  486. bool: 是否为连续交易的第一笔交易(或不是连续交易)
  487. """
  488. # 如果不是连续交易,返回True
  489. if continuous_pair_id == 'N/A' or pd.isna(continuous_pair_id):
  490. return True
  491. # 获取同一连续交易组的所有交易
  492. continuous_trades = transaction_df[transaction_df['连续交易对ID'] == continuous_pair_id]
  493. # 获取所有交易对ID并按时间排序
  494. pair_ids = continuous_trades['交易对ID'].unique()
  495. # 获取每个交易对的开仓时间
  496. pair_times = []
  497. for pid in pair_ids:
  498. pair_records = continuous_trades[continuous_trades['交易对ID'] == pid]
  499. open_records = pair_records[pair_records['交易类型'].str.contains('开', na=False)]
  500. if len(open_records) > 0:
  501. # 获取第一个开仓记录的日期和时间
  502. first_open = open_records.iloc[0]
  503. date_str = str(first_open['日期']).strip()
  504. time_str = str(first_open['委托时间']).strip()
  505. try:
  506. dt = pd.to_datetime(f"{date_str} {time_str}")
  507. pair_times.append((pid, dt))
  508. except:
  509. pass
  510. # 按时间排序
  511. pair_times.sort(key=lambda x: x[1])
  512. # 检查当前交易对是否为第一个
  513. if pair_times and pair_times[0][0] == trade_pair_id:
  514. return True
  515. return False
  516. def get_user_decision():
  517. """
  518. 获取用户的开仓决策
  519. """
  520. while True:
  521. decision = input("\n是否开仓?请输入 'y' (开仓) 或 'n' (不开仓): ").strip().lower()
  522. if decision in ['y', 'yes', '是', '开仓']:
  523. return True
  524. elif decision in ['n', 'no', '否', '不开仓']:
  525. return False
  526. else:
  527. print("请输入有效的选项: 'y' 或 'n'")
  528. def main():
  529. """
  530. 主函数
  531. """
  532. print("=" * 60)
  533. print("交易训练工具")
  534. print("=" * 60)
  535. # 设置随机种子
  536. if CONFIG['random_seed'] is not None:
  537. random.seed(CONFIG['random_seed'])
  538. np.random.seed(CONFIG['random_seed'])
  539. # 获取当前目录
  540. current_dir = _get_current_directory()
  541. csv_path = os.path.join(current_dir, CONFIG['csv_filename'])
  542. result_path = os.path.join(current_dir, CONFIG['result_filename'])
  543. output_dir = os.path.join(current_dir, CONFIG['output_dir'])
  544. # 创建输出目录
  545. os.makedirs(output_dir, exist_ok=True)
  546. # 1. 读取交易数据
  547. print("\n=== 步骤1: 读取交易数据 ===")
  548. transaction_df = read_transaction_data(csv_path)
  549. if len(transaction_df) == 0:
  550. print("未能读取交易数据,退出")
  551. return
  552. # 2. 加载已处理的结果
  553. print("\n=== 步骤2: 加载已处理记录 ===")
  554. existing_results, processed_pairs = load_processed_results(result_path)
  555. print(f"已处理 {len(processed_pairs)} 个交易对")
  556. # existing_results 保留用于后续可能的数据分析功能
  557. # 3. 提取所有开仓交易
  558. print("\n=== 步骤3: 提取开仓交易 ===")
  559. open_trades = []
  560. for idx, row in transaction_df.iterrows():
  561. contract_code, trade_date, trade_price, direction, action, order_time, trade_type, trade_pair_id, continuous_pair_id = extract_contract_info(row)
  562. if contract_code is None or action != 'open':
  563. continue
  564. # 跳过已处理的交易对
  565. if trade_pair_id in processed_pairs:
  566. continue
  567. # 检查是否为连续交易的第一笔交易(如果不是第一笔,跳过)
  568. if not is_first_continuous_trade(transaction_df, trade_pair_id, continuous_pair_id):
  569. continue
  570. # 查找对应的平仓交易
  571. profit_loss = calculate_profit_loss(transaction_df, trade_pair_id, continuous_pair_id)
  572. # 如果是连续交易,获取连续交易总盈亏
  573. continuous_total_profit = 'N/A'
  574. if continuous_pair_id != 'N/A' and pd.notna(continuous_pair_id):
  575. continuous_trades = transaction_df[transaction_df['连续交易对ID'] == continuous_pair_id]
  576. try:
  577. close_profit_loss_str = continuous_trades['平仓盈亏'].astype(str).str.replace(',', '')
  578. close_profit_loss_numeric = pd.to_numeric(close_profit_loss_str, errors='coerce').fillna(0)
  579. continuous_total_profit = close_profit_loss_numeric.sum()
  580. except:
  581. continuous_total_profit = 0
  582. open_trades.append({
  583. 'index': idx,
  584. 'contract_code': contract_code,
  585. 'trade_date': trade_date,
  586. 'trade_price': trade_price,
  587. 'direction': direction,
  588. 'order_time': order_time,
  589. 'trade_type': trade_type,
  590. 'trade_pair_id': trade_pair_id,
  591. 'continuous_pair_id': continuous_pair_id,
  592. 'profit_loss': profit_loss,
  593. 'continuous_total_profit': continuous_total_profit,
  594. 'original_row': row
  595. })
  596. print(f"找到 {len(open_trades)} 个未处理的开仓交易(已过滤非首笔连续交易)")
  597. if len(open_trades) == 0:
  598. print("没有未处理的开仓交易,退出")
  599. return
  600. # 4. 随机选择一个交易(完全随机)
  601. print("\n=== 步骤4: 随机选择交易 ===")
  602. # 打乱交易列表顺序确保完全随机
  603. random.shuffle(open_trades)
  604. selected_trade = random.choice(open_trades)
  605. print(f"选中交易: {selected_trade['contract_code']} - {selected_trade['trade_date']} - {selected_trade['direction']}")
  606. print(f"剩余未处理交易: {len(open_trades) - 1} 个")
  607. # 5. 获取K线数据
  608. print("\n=== 步骤5: 获取K线数据 ===")
  609. kline_data, trade_idx = get_kline_data_with_future(
  610. selected_trade['contract_code'],
  611. selected_trade['trade_date'],
  612. CONFIG['history_days'],
  613. CONFIG['future_days']
  614. )
  615. if kline_data is None or trade_idx is None:
  616. print("获取K线数据失败,退出")
  617. return
  618. # 6. 显示部分K线图
  619. print("\n=== 步骤6: 显示部分K线图 ===")
  620. partial_image_path = os.path.join(output_dir, f"partial_{selected_trade['trade_pair_id']}.png")
  621. plot_partial_kline(
  622. kline_data, trade_idx, selected_trade['trade_price'],
  623. selected_trade['direction'], selected_trade['contract_code'],
  624. selected_trade['trade_date'], selected_trade['order_time'],
  625. partial_image_path
  626. )
  627. # 7. 获取用户决策
  628. print(f"\n交易信息:")
  629. print(f"合约: {selected_trade['contract_code']}")
  630. print(f"日期: {selected_trade['trade_date']}")
  631. print(f"方向: {'多头' if selected_trade['direction'] == 'long' else '空头'}")
  632. print(f"成交价: {selected_trade['trade_price']}")
  633. user_decision = get_user_decision()
  634. # 8. 显示完整K线图
  635. print("\n=== 步骤7: 显示完整K线图 ===")
  636. full_image_path = os.path.join(output_dir, f"full_{selected_trade['trade_pair_id']}.png")
  637. plot_full_kline(
  638. kline_data, trade_idx, selected_trade['trade_price'],
  639. selected_trade['direction'], selected_trade['contract_code'],
  640. selected_trade['trade_date'], selected_trade['order_time'],
  641. selected_trade['profit_loss'],
  642. full_image_path
  643. )
  644. # 9. 记录结果
  645. print("\n=== 步骤8: 记录结果 ===")
  646. # 计算判定收益(使用连续交易总盈亏或普通盈亏)
  647. if selected_trade['continuous_total_profit'] != 'N/A':
  648. # 连续交易使用连续交易总盈亏
  649. decision_profit = selected_trade['continuous_total_profit'] if user_decision else -selected_trade['continuous_total_profit']
  650. profit_to_show = selected_trade['continuous_total_profit']
  651. else:
  652. # 普通交易使用单笔盈亏
  653. decision_profit = selected_trade['profit_loss'] if user_decision else -selected_trade['profit_loss']
  654. profit_to_show = selected_trade['profit_loss']
  655. result_data = {
  656. '日期': selected_trade['original_row']['日期'],
  657. '委托时间': selected_trade['original_row']['委托时间'],
  658. '标的': selected_trade['original_row']['标的'],
  659. '交易类型': selected_trade['original_row']['交易类型'],
  660. '成交数量': selected_trade['original_row']['成交数量'],
  661. '成交价': selected_trade['original_row']['成交价'],
  662. '平仓盈亏': selected_trade['profit_loss'],
  663. '用户判定': '开仓' if user_decision else '不开仓',
  664. '判定收益': decision_profit,
  665. '交易对ID': selected_trade['trade_pair_id'],
  666. '连续交易对ID': selected_trade['continuous_pair_id'],
  667. '连续交易总盈亏': selected_trade['continuous_total_profit']
  668. }
  669. record_result(result_data, result_path)
  670. print(f"\n=== 训练完成 ===")
  671. print(f"用户判定: {'开仓' if user_decision else '不开仓'}")
  672. if selected_trade['continuous_total_profit'] != 'N/A':
  673. print(f"连续交易总盈亏: {profit_to_show:+.2f}")
  674. else:
  675. print(f"实际盈亏: {profit_to_show:+.2f}")
  676. print(f"判定收益: {decision_profit:+.2f}")
  677. print(f"结果已保存到: {result_path}")
  678. if __name__ == "__main__":
  679. main()