# K-line reconstruction tool
# Extracts opening trades from a transaction CSV file, fetches the matching daily
# K-line (candlestick) data and draws a candlestick chart with moving averages.
# Requires the JoinQuant research environment (jqdata provides get_trade_days and get_price).

from jqdata import *
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from datetime import datetime, timedelta, date
import re
import os
import warnings

warnings.filterwarnings('ignore')

# CJK font settings (chart text is in English, but the settings are kept in case
# Chinese labels are needed later)
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

# Column names and values of transaction.csv stay in Chinese because they must match
# the CSV exactly: '交易类型' = trade type, '标的' = instrument, '日期' = date, '开' = open.


def _get_current_directory():
    """
    Return the directory that holds transaction.csv; works both in scripts and in
    Jupyter notebooks.

    Returns:
        str: directory path
    """
    try:
        # In a regular Python script, use __file__
        current_dir = os.path.dirname(os.path.abspath(__file__))
    except NameError:
        # In a Jupyter notebook __file__ is undefined, so use the working directory
        current_dir = os.getcwd()

    # If transaction.csv is not here and we are not already inside a 'future'
    # directory, try the sibling directory ../future
    if not os.path.exists(os.path.join(current_dir, 'transaction.csv')):
        if 'future' not in current_dir:
            parent_dir = os.path.dirname(current_dir)
            future_dir = os.path.join(parent_dir, 'future')
            if os.path.exists(os.path.join(future_dir, 'transaction.csv')):
                current_dir = future_dir

    return current_dir


def read_and_filter_open_positions(csv_path):
    """
    Read the CSV file and keep only the opening trades.

    Args:
        csv_path (str): path to the CSV file

    Returns:
        pandas.DataFrame: opening trades (an empty DataFrame on failure)
    """
    try:
        df = pd.read_csv(csv_path, encoding='utf-8-sig')
        # Keep rows whose trade type ('交易类型') starts with '开' (open)
        open_positions = df[df['交易类型'].astype(str).str.startswith('开')].copy()
        print(f"Read {len(df)} records from the CSV file")
        print(f"Found {len(open_positions)} opening trades")
        return open_positions
    except Exception as e:
        print(f"Error reading CSV file: {e}")
        return pd.DataFrame()


def extract_contract_code_and_date(row):
    """
    Extract the contract code from the instrument column ('标的') and the trade
    date from the date column ('日期').

    Args:
        row (pandas.Series): one row of the DataFrame

    Returns:
        tuple: (contract_code, trade_date), or (None, None) on failure
    """
    try:
        # Contract code: the text inside the first pair of ASCII parentheses in '标的'
        target_str = str(row['标的'])
        match = re.search(r'\(([^)]+)\)', target_str)
        if match:
            contract_code = match.group(1)
        else:
            print(f"Could not extract a contract code from '{target_str}'")
            return None, None

        # Trade date, expected as YYYY-MM-DD
        date_str = str(row['日期'])
        try:
            trade_date = datetime.strptime(date_str, '%Y-%m-%d').date()
        except ValueError:
            print(f"Invalid date format: {date_str}")
            return None, None

        return contract_code, trade_date
    except Exception as e:
        print(f"Error extracting contract code and date: {e}")
        return None, None


def _to_date(value):
    """Convert a date, datetime, pandas Timestamp or numpy.datetime64 to datetime.date."""
    return pd.Timestamp(value).date()


def calculate_trade_days_range(trade_date, days_before=60, days_after=10):
    """
    Compute the trading-day window: days_before trading days before the open
    date and days_after trading days after it.

    Args:
        trade_date (date): open date
        days_before (int): number of trading days before, default 60
        days_after (int): number of trading days after, default 10

    Returns:
        tuple: (start_date, end_date), or (None, None) on failure
    """
    try:
        # Backward: get_trade_days(end_date=trade_date, count=n) returns the n trading
        # days ending on (and including) trade_date, so with count=days_before+1 the
        # first element is the day days_before trading days earlier.
        trade_days_before = get_trade_days(end_date=trade_date, count=days_before + 1)
        if len(trade_days_before) < days_before + 1:
            print(f"Not enough earlier trading days: got only {len(trade_days_before)}")
            return None, None
        # _to_date normalizes whatever date type get_trade_days returns
        start_date = _to_date(trade_days_before[0])

        # Forward: count only counts backward from end_date and cannot be combined
        # with start_date, so query a calendar window wide enough to cover weekends
        # and exchange holidays, then take trade_date plus the next days_after days.
        calendar_span = timedelta(days=days_after * 2 + 30)
        trade_days_after = get_trade_days(start_date=trade_date,
                                          end_date=trade_date + calendar_span)
        if len(trade_days_after) < days_after + 1:
            print(f"Not enough later trading days: got only {len(trade_days_after)}")
            return None, None
        end_date = _to_date(trade_days_after[days_after])

        return start_date, end_date
    except Exception as e:
        print(f"Error computing the trading-day range: {e}")
        return None, None


def get_kline_data(contract_code, start_date, end_date):
    """
    Fetch daily K-line data for a contract over a date range.

    Args:
        contract_code (str): contract code, e.g. 'JD2502.XDCE'
        start_date (date): start date
        end_date (date): end date

    Returns:
        pandas.DataFrame: OHLC data, or None on failure
    """
    try:
        price_data = get_price(
            contract_code,
            start_date=start_date,
            end_date=end_date,
            frequency='1d',
            fields=['open', 'close', 'high', 'low'],
        )
        if price_data is None or len(price_data) == 0:
            print(f"No data for {contract_code} between {start_date} and {end_date}")
            return None
        return price_data
    except Exception as e:
        print(f"Error fetching K-line data: {e}")
        return None


def calculate_moving_averages(data):
    """
    Add 5-, 10-, 20- and 30-bar simple moving averages of the close.

    Args:
        data (pandas.DataFrame): DataFrame with a 'close' column

    Returns:
        pandas.DataFrame: a copy of data with ma5, ma10, ma20 and ma30 columns
    """
    data = data.copy()
    for window in (5, 10, 20, 30):
        # rolling(window).mean() leaves the first window-1 rows as NaN
        data[f'ma{window}'] = data['close'].rolling(window=window).mean()
    return data


def filter_data_with_ma(data):
    """
    Drop every date on which any moving average is missing (NaN).

    Args:
        data (pandas.DataFrame): DataFrame with moving-average columns

    Returns:
        pandas.DataFrame: filtered DataFrame
    """
    return data.dropna(subset=['ma5', 'ma10', 'ma20', 'ma30'])


def plot_kline_chart(data, contract_code, trade_date, save_path):
    """
    Draw a candlestick chart with moving averages and mark the open date.

    Args:
        data (pandas.DataFrame): OHLC and moving-average data indexed by date
        contract_code (str): contract code
        trade_date (date): open date
        save_path (str): output image path
    """
    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        # Compare dates as midnight Timestamps. datetime (and pandas Timestamp) is a
        # subclass of date, so an isinstance(x, date) check cannot tell them apart,
        # and a Timestamp never equals a plain date in recent pandas versions.
        dates = pd.DatetimeIndex(data.index).normalize()
        opens = data['open']
        highs = data['high']
        lows = data['low']
        closes = data['close']
        x = np.arange(len(data))

        # Position of the open date in the data (None if it is not present)
        matches = np.flatnonzero(dates == pd.Timestamp(trade_date))
        trade_date_idx = int(matches[0]) if len(matches) else None

        # Draw one candle per bar
        for i in range(len(data)):
            open_price = opens.iloc[i]
            high_price = highs.iloc[i]
            low_price = lows.iloc[i]
            close_price = closes.iloc[i]

            # Chinese convention: red for up, green for down or flat
            rising = close_price > open_price
            color = 'red' if rising else 'green'
            edge_color = 'darkred' if rising else 'darkgreen'

            # Shadow: vertical line from low to high, drawn beneath the body
            ax.plot([i, i], [low_price, high_price], color='black', linewidth=1, zorder=1)

            # Body: rectangle between open and close
            body_height = abs(close_price - open_price)
            if body_height == 0:
                body_height = 0.01  # give doji candles a visible body
            bottom = min(open_price, close_price)
            rect = patches.Rectangle((i - 0.4, bottom), 0.8, body_height,
                                     linewidth=1, edgecolor=edge_color,
                                     facecolor=color, alpha=0.8)
            ax.add_patch(rect)

        # Moving averages
        ma_styles = [('ma5', 'MA5', 'blue'), ('ma10', 'MA10', 'orange'),
                     ('ma20', 'MA20', 'purple'), ('ma30', 'MA30', 'brown')]
        for column, label, ma_color in ma_styles:
            ax.plot(x, data[column], label=label, color=ma_color, linewidth=1.5, alpha=0.8)

        # Mark the open date: a star at that day's close plus a dashed vertical line
        if trade_date_idx is not None:
            trade_price = closes.iloc[trade_date_idx]
            ax.plot(trade_date_idx, trade_price, marker='*', markersize=15,
                    color='yellow', markeredgecolor='black', markeredgewidth=1.5,
                    label='Open Position', zorder=10)
            ax.axvline(x=trade_date_idx, color='yellow', linestyle='--',
                       linewidth=2, alpha=0.7, zorder=5)
        else:
            print(f"Open date {trade_date} is not in the plotted data; no marker drawn")

        # Title and axis labels (in English)
        contract_simple = contract_code.split('.')[0]  # e.g. 'JD2502' from 'JD2502.XDCE'
        ax.set_title(f'{contract_simple} ({contract_code}) K-Line Chart\n'
                     f'Period: {dates[0].strftime("%Y-%m-%d")} to {dates[-1].strftime("%Y-%m-%d")} '
                     f'({len(data)} bars)',
                     fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel('Time', fontsize=12)
        ax.set_ylabel('Price', fontsize=12)
        ax.grid(True, alpha=0.3)
        # Upper right, so the legend does not cover the statistics box in the upper left
        ax.legend(loc='upper right', fontsize=10)

        # About 10 date labels on the x axis
        step = max(1, len(data) // 10)
        tick_positions = list(range(0, len(data), step))
        tick_labels = [dates[pos].strftime('%Y-%m-%d') for pos in tick_positions]
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tick_labels, rotation=45, ha='right')

        # Summary statistics over the plotted period (in English)
        max_price = highs.max()
        min_price = lows.min()
        latest_close = closes.iloc[-1]
        first_close = closes.iloc[0]
        total_change = (latest_close - first_close) / first_close * 100
        stats_text = (f'High: {max_price:.2f}\n'
                      f'Low: {min_price:.2f}\n'
                      f'Latest Close: {latest_close:.2f}\n'
                      f'Total Change: {total_change:+.2f}%')
        ax.text(0.02, 0.98, stats_text, transform=ax.transAxes,
                verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8),
                fontsize=10)

        # Tidy the layout and save
        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"K-line chart saved to: {save_path}")
    except Exception as e:
        print(f"Error drawing the K-line chart: {e}")
        raise
    finally:
        # Release the figure even when drawing fails
        plt.close(fig)


def reconstruct_kline_from_transactions(csv_path=None, output_dir=None):
    """
    Main entry point: rebuild K-line charts from the transaction records.

    Args:
        csv_path (str): CSV file path; defaults to transaction.csv in the directory
            returned by _get_current_directory()
        output_dir (str): output directory; defaults to the 'K' subdirectory of
            that same directory
    """
    # Resolve default paths (works both as a script and in Jupyter)
    if csv_path is None or output_dir is None:
        base_dir = _get_current_directory()
        if csv_path is None:
            csv_path = os.path.join(base_dir, 'transaction.csv')
        if output_dir is None:
            output_dir = os.path.join(base_dir, 'K')

    # Make sure the output directory exists
    os.makedirs(output_dir, exist_ok=True)
    print(f"Output directory: {output_dir}")

    # Step 1: read and filter the opening trades
    print("\n=== Step 1: read and filter opening trades ===")
    open_positions = read_and_filter_open_positions(csv_path)
    if len(open_positions) == 0:
        print("No opening trades found; exiting")
        return

    # Step 2: process each opening trade
    total = len(open_positions)
    print(f"\n=== Step 2: process {total} opening trades ===")
    success_count = 0
    fail_count = 0

    # enumerate gives a 1-based counter; the DataFrame index has gaps after filtering
    for n, (_, row) in enumerate(open_positions.iterrows(), start=1):
        print(f"\n--- Record {n}/{total} ---")
        try:
            # Contract code and open date
            contract_code, trade_date = extract_contract_code_and_date(row)
            if contract_code is None or trade_date is None:
                print("Skipped: could not extract the contract code or date")
                fail_count += 1
                continue
            print(f"Contract: {contract_code}, open date: {trade_date}")

            # Trading-day window around the open date
            start_date, end_date = calculate_trade_days_range(trade_date, days_before=60, days_after=10)
            if start_date is None or end_date is None:
                print("Skipped: could not compute the trading-day range")
                fail_count += 1
                continue
            print(f"Data range: {start_date} to {end_date}")

            # Daily K-line data
            kline_data = get_kline_data(contract_code, start_date, end_date)
            if kline_data is None or len(kline_data) == 0:
                print("Skipped: could not fetch K-line data")
                fail_count += 1
                continue
            print(f"Fetched {len(kline_data)} K-line bars")

            # Moving averages, then drop the warm-up rows where any MA is NaN
            kline_data = calculate_moving_averages(kline_data)
            filtered_data = filter_data_with_ma(kline_data)
            if len(filtered_data) == 0:
                print("Skipped: no valid data left after filtering")
                fail_count += 1
                continue
            print(f"{len(filtered_data)} valid bars left after filtering")

            # File name: <short contract code>_<YYYYMMDD>.png
            contract_simple = contract_code.split('.')[0]
            filename = f"{contract_simple}_{trade_date.strftime('%Y%m%d')}.png"
            save_path = os.path.join(output_dir, filename)

            # Draw the chart
            plot_kline_chart(filtered_data, contract_code, trade_date, save_path)
            success_count += 1
            print("✓ Done")
        except Exception as e:
            print(f"✗ Error while processing: {e}")
            fail_count += 1
            continue

    # Summary
    print("\n=== Finished ===")
    print(f"Succeeded: {success_count}")
    print(f"Failed: {fail_count}")
    print(f"Total: {total}")


# Usage example
if __name__ == "__main__":
    print("=" * 60)
    print("K-line reconstruction tool")
    print("=" * 60)
    reconstruct_kline_from_transactions()
    print("\n=== Done ===")
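

# ---------------------------------------------------------------------------
# Sketch: the trading-day window without jqdata
# calculate_trade_days_range() relies on get_trade_days(). The hypothetical helper
# below reproduces the same "N trading days before, M after" idea offline, assuming
# `trading_days` is a sorted list of datetime.date objects supplied by the caller
# (for example a calendar exported once from get_trade_days). bisect finds the open
# date, and index arithmetic skips weekends and holidays automatically because they
# are simply absent from the list. This is an illustration, not part of the tool.
# ---------------------------------------------------------------------------
import bisect
from datetime import date, timedelta  # used by the usage example below


def trade_window_offline(trading_days, trade_date, days_before=60, days_after=10):
    """Hypothetical helper: (start, end) dates around trade_date, or (None, None)."""
    pos = bisect.bisect_left(trading_days, trade_date)
    # The open date must itself appear in the calendar
    if pos == len(trading_days) or trading_days[pos] != trade_date:
        return None, None
    # Not enough history before, or not enough bars after, the open date
    if pos < days_before or pos + days_after >= len(trading_days):
        return None, None
    return trading_days[pos - days_before], trading_days[pos + days_after]


# Usage with a weekday-only calendar (an assumption; real exchange calendars also
# drop public holidays):
#   cal = [d for d in (date(2024, 1, 1) + timedelta(days=i) for i in range(366))
#          if d.weekday() < 5]
#   trade_window_offline(cal, date(2024, 6, 3), 5, 2)
#   -> (datetime.date(2024, 5, 27), datetime.date(2024, 6, 5))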
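

# ---------------------------------------------------------------------------
# Worked check: how many bars survive the moving-average filter
# pandas' rolling(window=w).mean() leaves the first w-1 rows as NaN, and
# filter_data_with_ma() drops every row where any MA is NaN, so the longest window
# (MA30) decides how many bars reach the chart. The plain-Python stand-in below
# (hypothetical names, not used by the tool) mimics that with None. Assuming
# get_price returns one bar per trading day, the default window is 60 bars before
# the open date, the open bar and 10 bars after (71 bars); 29 warm-up rows are
# dropped, so 42 bars remain and the open bar moves from position 60 to 31.
# The open bar survives filtering only if days_before >= longest_window - 1.
# ---------------------------------------------------------------------------
def simple_moving_average(values, window):
    """Hypothetical helper: trailing mean of `window` values, None during warm-up."""
    return [
        sum(values[i - window + 1:i + 1]) / window if i >= window - 1 else None
        for i in range(len(values))
    ]


def bars_after_warmup(days_before=60, days_after=10, longest_window=30):
    """Hypothetical helper: (bars kept, index of the open bar after filtering)."""
    total = days_before + 1 + days_after
    closes = [float(i) for i in range(total)]  # placeholder prices; only the count matters
    ma = simple_moving_average(closes, longest_window)
    kept = [i for i, value in enumerate(ma) if value is not None]
    return len(kept), days_before - kept[0]


# Usage:
#   simple_moving_average([1.0, 2.0, 3.0, 4.0], 2)  -> [None, 1.5, 2.5, 3.5]
#   bars_after_warmup()                              -> (42, 31)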
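

# ---------------------------------------------------------------------------
# Sketch: drawing all candles in two calls instead of a per-bar loop
# plot_kline_chart() adds one Line2D and one Rectangle per bar. An alternative (a
# design choice, not the original method) is a single Axes.vlines call for the
# shadows and a single Axes.bar call for the bodies, keeping the same red-up /
# green-down colors and the 0.01 body height for doji candles. The helper name is
# hypothetical; it builds a matplotlib.figure.Figure directly, so it needs no
# pyplot state or GUI backend.
# ---------------------------------------------------------------------------
from matplotlib.figure import Figure


def plot_candles_vectorized(opens, highs, lows, closes):
    """Hypothetical helper: return (fig, ax) with shadows and bodies for every bar."""
    x = list(range(len(opens)))
    rising = [c > o for o, c in zip(opens, closes)]
    face_colors = ['red' if up else 'green' for up in rising]
    edge_colors = ['darkred' if up else 'darkgreen' for up in rising]
    bottoms = [min(o, c) for o, c in zip(opens, closes)]
    # Same rule as the loop version: a zero-height body becomes 0.01 tall
    heights = [abs(c - o) or 0.01 for o, c in zip(opens, closes)]

    fig = Figure(figsize=(16, 10))
    ax = fig.subplots()
    # One LineCollection for all shadows, drawn beneath the bodies
    ax.vlines(x, lows, highs, colors='black', linewidth=1, zorder=1)
    # One BarContainer for all bodies, centered on each x position
    ax.bar(x, heights, width=0.8, bottom=bottoms, color=face_colors,
           edgecolor=edge_colors, linewidth=1, alpha=0.8, zorder=2)
    return fig, ax


# Usage: fig, ax = plot_candles_vectorized(opens, highs, lows, closes), then draw
# the MA lines on ax and call fig.savefig(path) as in plot_kline_chart().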