|
|
@@ -0,0 +1,486 @@
|
|
|
+# K线复原工具
|
|
|
+# 用于从交易记录CSV文件中提取开仓记录,获取对应的K线数据并绘制包含均线的K线图
|
|
|
+
|
|
|
+from jqdata import *
|
|
|
+import pandas as pd
|
|
|
+import numpy as np
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+import matplotlib.patches as patches
|
|
|
+from datetime import datetime, timedelta, date
|
|
|
+import re
|
|
|
+import os
|
|
|
+import warnings
|
|
|
# Suppress every warning category for the remainder of the process.
warnings.filterwarnings('ignore')

# Font fallback list for CJK glyphs. Chart text is written in English, but the
# setting is kept in case Chinese labels are needed later. The second entry
# stops the minus sign from being rendered with a Unicode glyph.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Microsoft YaHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})
|
|
|
+
|
|
|
+
|
|
|
def _get_current_directory():
    """Return the directory in which ``transaction.csv`` is looked up.

    When ``__file__`` is defined (regular script), the directory containing
    this file is returned. When it is not (e.g. a Jupyter notebook), the
    current working directory is used; if that directory has no
    ``transaction.csv`` and its path does not contain the text ``future``,
    a sibling directory named ``future`` is returned when it holds the CSV.

    Returns:
        str: The resolved directory path.
    """
    # NOTE(review): the fallback search below is assumed to belong to the
    # notebook branch only; confirm against the original indentation.
    try:
        return os.path.dirname(os.path.abspath(__file__))
    except NameError:
        # __file__ does not exist in notebook sessions; start from the cwd.
        base_dir = os.getcwd()

    # The cwd already holds the CSV, or already looks like the future dir.
    if os.path.exists(os.path.join(base_dir, 'transaction.csv')):
        return base_dir
    if 'future' in base_dir:
        return base_dir

    # Otherwise try <parent>/future next to the cwd.
    sibling_dir = os.path.join(os.path.dirname(base_dir), 'future')
    if os.path.exists(os.path.join(sibling_dir, 'transaction.csv')):
        return sibling_dir
    return base_dir
|
|
|
+
|
|
|
+
|
|
|
def read_and_filter_open_positions(csv_path):
    """Load the transaction CSV and keep only the position-opening rows.

    A row is kept when the first character of its '交易类型' (trade type)
    value is '开' (open).

    Args:
        csv_path (str): Path of the CSV file, read as UTF-8 with optional BOM.

    Returns:
        pandas.DataFrame: A copy of the opening rows. An empty DataFrame is
        returned when reading or filtering raises any exception.
    """
    try:
        records = pd.read_csv(csv_path, encoding='utf-8-sig')

        # Boolean mask: first character of the trade type is '开'.
        is_open = records['交易类型'].str[0] == '开'
        opened = records.loc[is_open].copy()

        print(f"从CSV文件中读取到 {len(records)} 条记录")
        print(f"筛选出 {len(opened)} 条开仓记录")

        return opened
    except Exception as e:
        # Best-effort: report the problem and hand back an empty frame.
        print(f"读取CSV文件时出错: {str(e)}")
        return pd.DataFrame()
|
|
|
+
|
|
|
+
|
|
|
def extract_contract_code_and_date(row):
    """Extract the contract code and the trade date from one record.

    The contract code is the text inside the first pair of ASCII parentheses
    in the '标的' (instrument) field. The trade date comes from the '日期'
    (date) field, which may be a ``YYYY-MM-DD`` string or an object that is
    already a ``datetime.date`` / ``datetime.datetime`` (including
    ``pandas.Timestamp``).

    Args:
        row (pandas.Series): One row of the transaction table; any mapping
            supporting ``row['标的']`` and ``row['日期']`` works.

    Returns:
        tuple: ``(contract_code, trade_date)`` with ``trade_date`` as a
        ``datetime.date``, or ``(None, None)`` when extraction fails.
    """
    try:
        # Contract code: the content of the first "(...)" in the instrument.
        target_str = str(row['标的'])
        match = re.search(r'\(([^)]+)\)', target_str)
        if match is None:
            print(f"无法从标的 '{target_str}' 中提取合约编号")
            return None, None
        contract_code = match.group(1)

        raw_date = row['日期']
        # datetime must be tested before date: datetime is a subclass of date,
        # and its str() form carries a time part that the pattern rejects.
        if isinstance(raw_date, datetime):
            return contract_code, raw_date.date()
        if isinstance(raw_date, date):
            return contract_code, raw_date

        date_str = str(raw_date)
        try:
            trade_date = datetime.strptime(date_str, '%Y-%m-%d').date()
        except ValueError:
            # strptime signals a format mismatch with ValueError only.
            print(f"日期格式错误: {date_str}")
            return None, None

        return contract_code, trade_date
    except Exception as e:
        # Best-effort boundary: a missing column etc. skips the record.
        print(f"提取合约编号和日期时出错: {str(e)}")
        return None, None
|
|
|
+
|
|
|
+
|
|
|
def calculate_trade_days_range(trade_date, days_before=60, days_after=10):
    """Compute the date window surrounding an opening date, in trading days.

    The window starts ``days_before`` trading days before ``trade_date`` and
    ends ``days_after`` trading days after it.

    Args:
        trade_date (date): The opening date.
        days_before (int): Number of trading days to look back, default 60.
        days_after (int): Number of trading days to look ahead, default 10.

    Returns:
        tuple: ``(start_date, end_date)`` as ``datetime.date`` objects, or
        ``(None, None)`` when the trading calendar cannot supply enough days
        or any error occurs.
    """
    try:
        # Look back: `count` is measured backwards from `end_date` and
        # includes it, so days_before + 1 entries put the wanted start date
        # at position 0.
        earlier_days = get_trade_days(end_date=trade_date, count=days_before + 1)
        if len(earlier_days) < days_before + 1:
            print(f"无法获取足够的往前交易日,只获取到 {len(earlier_days)} 个")
            return None, None

        # pd.Timestamp accepts date, datetime and numpy.datetime64 alike, so
        # the conversion does not depend on the calendar's element type
        # (the API is documented as returning datetime.date objects, which
        # have no .date() method).
        start_date = pd.Timestamp(earlier_days[0]).date()

        # Look ahead: per the JoinQuant documentation `count` cannot be
        # combined with `start_date`, so query an explicit calendar window
        # instead. Twice the trading-day count plus 15 calendar days is
        # intended to cover weekends and a long public holiday.
        window_end = trade_date + timedelta(days=days_after * 2 + 15)
        later_days = get_trade_days(start_date=trade_date, end_date=window_end)
        if len(later_days) < days_after + 1:
            print(f"无法获取足够的往后交易日,只获取到 {len(later_days)} 个")
            return None, None

        # Position 0 is trade_date itself; position days_after is the end.
        end_date = pd.Timestamp(later_days[days_after]).date()

        return start_date, end_date
    except Exception as e:
        # Best-effort boundary: the caller skips the record on (None, None).
        print(f"计算交易日范围时出错: {str(e)}")
        return None, None
|
|
|
+
|
|
|
+
|
|
|
def get_kline_data(contract_code, start_date, end_date):
    """Fetch daily OHLC bars for one contract over a date range.

    Args:
        contract_code (str): Contract code, e.g. 'JD2502.XDCE'.
        start_date (date): First day of the range.
        end_date (date): Last day of the range.

    Returns:
        pandas.DataFrame: The result of ``get_price`` with the fields open,
        close, high and low, or ``None`` when nothing is returned or the
        request raises.
    """
    try:
        bars = get_price(
            contract_code,
            start_date=start_date,
            end_date=end_date,
            frequency='1d',
            fields=['open', 'close', 'high', 'low'],
        )

        if bars is not None and len(bars) != 0:
            return bars

        print(f"未获取到 {contract_code} 在 {start_date} 至 {end_date} 的数据")
        return None
    except Exception as e:
        print(f"获取K线数据时出错: {str(e)}")
        return None
|
|
|
+
|
|
|
+
|
|
|
def calculate_moving_averages(data):
    """Add 5-, 10-, 20- and 30-bar simple moving averages of the close.

    Args:
        data (pandas.DataFrame): Frame containing a 'close' column.

    Returns:
        pandas.DataFrame: A copy of ``data`` with the extra columns 'ma5',
        'ma10', 'ma20' and 'ma30'; the input frame is left untouched. Rows
        without a full window hold NaN in the corresponding column.
    """
    enriched = data.copy()
    close_prices = enriched['close']

    for window in (5, 10, 20, 30):
        enriched[f'ma{window}'] = close_prices.rolling(window=window).mean()

    return enriched
|
|
|
+
|
|
|
+
|
|
|
def filter_data_with_ma(data):
    """Keep only the rows on which all four moving averages are present.

    Args:
        data (pandas.DataFrame): Frame with 'ma5', 'ma10', 'ma20' and 'ma30'.

    Returns:
        pandas.DataFrame: The rows of ``data`` in which none of the four
        moving-average columns is missing.
    """
    ma_columns = ['ma5', 'ma10', 'ma20', 'ma30']
    has_every_ma = data[ma_columns].notna().all(axis=1)
    return data[has_every_ma]
|
|
|
+
|
|
|
+
|
|
|
def _as_calendar_date(value):
    """Return ``value`` as a plain ``datetime.date``, or None if not date-like."""
    # datetime has to be tested first: datetime (and pandas.Timestamp) is a
    # subclass of date, so testing date first would never reach this branch.
    if isinstance(value, datetime):
        return value.date()
    if isinstance(value, date):
        return value
    return None


def _find_trade_date_position(dates, trade_date):
    """Return the position of ``trade_date`` within ``dates``, or None."""
    target = _as_calendar_date(trade_date)
    if target is None:
        return None
    for position, stamp in enumerate(dates):
        if _as_calendar_date(stamp) == target:
            return position
    return None


def plot_kline_chart(data, contract_code, trade_date, save_path):
    """Draw a candlestick chart with moving averages and save it as an image.

    Bars are placed at integer x positions 0..len(data)-1. A rising bar
    (close > open) is red, any other bar is green. The four moving averages
    are drawn as lines, and the bar whose date equals ``trade_date`` is
    marked with a yellow star at its close plus a dashed vertical line.

    Args:
        data (pandas.DataFrame): Frame indexed by date with the columns
            open, high, low, close, ma5, ma10, ma20 and ma30.
        contract_code (str): Contract code, e.g. 'JD2502.XDCE'.
        trade_date (date): The opening date to highlight.
        save_path (str): Destination path of the image file.

    Raises:
        Exception: Any error raised while drawing or saving is reported and
            re-raised. The figure is closed in every case.
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=(16, 10))

        dates = data.index
        opens = data['open']
        highs = data['high']
        lows = data['low']
        closes = data['close']
        bar_count = len(data)
        x_positions = range(bar_count)

        # Compare on plain calendar dates; a direct Timestamp == date
        # comparison is unequal on current pandas and would never match.
        trade_date_idx = _find_trade_date_position(dates, trade_date)

        # Candles: one wick and one body rectangle per bar.
        for i, (open_price, high_price, low_price, close_price) in enumerate(
                zip(opens, highs, lows, closes)):
            # Red for rising bars, green otherwise.
            rising = close_price > open_price
            color = 'red' if rising else 'green'
            edge_color = 'darkred' if rising else 'darkgreen'

            # Wick: vertical line from low to high.
            ax.plot([i, i], [low_price, high_price], color='black', linewidth=1)

            # Body: rectangle spanning open to close.
            body_height = abs(close_price - open_price)
            if body_height == 0:
                body_height = 0.01  # keep a visible sliver when open == close
            bottom = min(open_price, close_price)

            rect = patches.Rectangle((i - 0.4, bottom), 0.8, body_height,
                                     linewidth=1, edgecolor=edge_color,
                                     facecolor=color, alpha=0.8)
            ax.add_patch(rect)

        # Moving averages.
        for column, label, line_color in (
                ('ma5', 'MA5', 'blue'),
                ('ma10', 'MA10', 'orange'),
                ('ma20', 'MA20', 'purple'),
                ('ma30', 'MA30', 'brown')):
            ax.plot(x_positions, data[column], label=label, color=line_color,
                    linewidth=1.5, alpha=0.8)

        # Highlight the opening date.
        if trade_date_idx is not None:
            trade_price = closes.iloc[trade_date_idx]
            ax.plot(trade_date_idx, trade_price, marker='*', markersize=15,
                    color='yellow', markeredgecolor='black', markeredgewidth=1.5,
                    label='Open Position', zorder=10)
            ax.axvline(x=trade_date_idx, color='yellow', linestyle='--',
                       linewidth=2, alpha=0.7, zorder=5)

        # Title and axis labels (English on purpose).
        contract_simple = contract_code.split('.')[0]  # code without exchange suffix
        ax.set_title(f'{contract_simple} ({contract_code}) K-Line Chart\n'
                     f'Period: {dates[0].strftime("%Y-%m-%d")} to {dates[-1].strftime("%Y-%m-%d")} '
                     f'({len(data)} bars)',
                     fontsize=14, fontweight='bold', pad=20)

        ax.set_xlabel('Time', fontsize=12)
        ax.set_ylabel('Price', fontsize=12)
        ax.grid(True, alpha=0.3)
        ax.legend(loc='upper left', fontsize=10)

        # X axis: roughly ten date labels.
        step = max(1, bar_count // 10)
        tick_positions = list(range(0, bar_count, step))
        tick_labels = []
        for pos in tick_positions:
            date_val = dates[pos]
            if isinstance(date_val, date):  # also covers datetime / Timestamp
                tick_labels.append(date_val.strftime('%Y-%m-%d'))
            else:
                tick_labels.append(str(date_val))

        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels, rotation=45, ha='right')

        # Summary box (English on purpose).
        max_price = highs.max()
        min_price = lows.min()
        latest_close = closes.iloc[-1]
        first_close = closes.iloc[0]
        total_change = (latest_close - first_close) / first_close * 100

        stats_text = (f'High: {max_price:.2f}\n'
                      f'Low: {min_price:.2f}\n'
                      f'Latest Close: {latest_close:.2f}\n'
                      f'Total Change: {total_change:+.2f}%')

        ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
                verticalalignment='top', bbox=dict(boxstyle='round',
                facecolor='wheat', alpha=0.8), fontsize=10)

        # Lay out and save.
        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

        print(f"K线图已保存到: {save_path}")

    except Exception as e:
        print(f"绘制K线图时出错: {str(e)}")
        raise
    finally:
        # Always release the figure so a failed chart does not leak memory
        # during a batch run.
        if fig is not None:
            plt.close(fig)
|
|
|
+
|
|
|
+
|
|
|
def _render_open_position(row, output_dir):
    """Run the extract/fetch/plot pipeline for one opening record.

    Returns True when the chart was written and False when the record was
    skipped. Exceptions raised while plotting propagate to the caller.
    """
    contract_code, trade_date = extract_contract_code_and_date(row)
    if contract_code is None or trade_date is None:
        print("跳过:无法提取合约编号或日期")
        return False

    print(f"合约编号: {contract_code}, 开仓日期: {trade_date}")

    start_date, end_date = calculate_trade_days_range(trade_date, days_before=60, days_after=10)
    if start_date is None or end_date is None:
        print("跳过:无法计算交易日范围")
        return False

    print(f"数据范围: {start_date} 至 {end_date}")

    kline_data = get_kline_data(contract_code, start_date, end_date)
    if kline_data is None or len(kline_data) == 0:
        print("跳过:无法获取K线数据")
        return False

    print(f"获取到 {len(kline_data)} 条K线数据")

    # Add the moving averages, then drop bars lacking any of them.
    kline_data = calculate_moving_averages(kline_data)
    filtered_data = filter_data_with_ma(kline_data)
    if len(filtered_data) == 0:
        print("跳过:过滤后无有效数据")
        return False

    print(f"过滤后剩余 {len(filtered_data)} 条有效数据")

    # File name: <code without exchange suffix>_<YYYYMMDD>.png
    contract_simple = contract_code.split('.')[0]
    filename = f"{contract_simple}_{trade_date.strftime('%Y%m%d')}.png"
    save_path = os.path.join(output_dir, filename)

    plot_kline_chart(filtered_data, contract_code, trade_date, save_path)
    return True


def reconstruct_kline_from_transactions(csv_path=None, output_dir=None):
    """Entry point: rebuild one K-line chart per opening record in the CSV.

    Args:
        csv_path (str): Path of the transaction CSV. Defaults to
            'transaction.csv' inside the directory returned by
            ``_get_current_directory()``.
        output_dir (str): Directory for the generated images; it is created
            when missing. Defaults to the sub-directory 'K' of that same
            directory.

    Returns:
        None. Progress and a success/failure summary are printed.
    """
    # Resolve both defaults through the shared helper instead of repeating
    # the directory search inline.
    if csv_path is None or output_dir is None:
        base_dir = _get_current_directory()
        if csv_path is None:
            csv_path = os.path.join(base_dir, 'transaction.csv')
        if output_dir is None:
            output_dir = os.path.join(base_dir, 'K')

    os.makedirs(output_dir, exist_ok=True)
    print(f"输出目录: {output_dir}")

    # Step 1: load and filter the opening records.
    print("\n=== 步骤1: 读取和筛选开仓记录 ===")
    open_positions = read_and_filter_open_positions(csv_path)

    total = len(open_positions)
    if total == 0:
        print("未找到开仓记录,退出")
        return

    # Step 2: process each opening record.
    print(f"\n=== 步骤2: 处理 {total} 条开仓记录 ===")
    success_count = 0
    fail_count = 0

    # Count positions with enumerate: the frame keeps the row labels of the
    # unfiltered CSV, so the label is not a valid "n of total" counter.
    for position, (_, row) in enumerate(open_positions.iterrows(), start=1):
        print(f"\n--- 处理第 {position}/{total} 条记录 ---")

        try:
            if _render_open_position(row, output_dir):
                success_count += 1
                print("✓ 成功处理")
            else:
                fail_count += 1
        except Exception as e:
            # Batch boundary: one bad record must not stop the others.
            print(f"✗ 处理时出错: {str(e)}")
            fail_count += 1

    # Summary.
    print("\n=== 处理完成 ===")
    print(f"成功: {success_count} 条")
    print(f"失败: {fail_count} 条")
    print(f"总计: {total} 条")
|
|
|
+
|
|
|
+
|
|
|
+# 使用示例
|
|
|
if __name__ == "__main__":
    # Script entry: print a banner, run the whole pipeline with its default
    # paths, then print a closing line.
    banner = "=" * 60
    print(banner)
    print("K线复原工具")
    print(banner)

    reconstruct_kline_from_transactions()

    print("\n=== 完成 ===")
|
|
|
+
|