# future_ma_cross_analysis.py (12 KB)
  1. import pandas as pd
  2. import numpy as np
  3. from jqdata import *
  4. import datetime
  5. import matplotlib.pyplot as plt
  6. def get_all_key_info(the_type='futures'):
  7. """获取所有期货主力合约信息并进行日期过滤"""
  8. main = get_all_securities(types=[the_type]).reset_index()
  9. the_main = main[main.display_name.str.endswith('主力合约')].copy() # 使用copy()避免警告
  10. the_main.rename(columns={'index': 'code'}, inplace=True)
  11. # 将start_date转换为datetime格式
  12. the_main['start_date'] = pd.to_datetime(the_main['start_date'])
  13. return the_main
  14. def get_future_data_with_ma(future_code, start_date, end_date):
  15. """获取单个期货合约的价格数据并计算多条移动平均线"""
  16. try:
  17. # if future_code != "A9999.XDCE":
  18. # print(f"跳过: {future_code}的数据")
  19. # return None
  20. # else:
  21. print(f"获取期货: {future_code}的数据")
  22. data = get_price(future_code,
  23. start_date=start_date,
  24. end_date=end_date,
  25. frequency='daily',
  26. fields=['open', 'close', 'high', 'low', 'volume'],
  27. skip_paused=False,
  28. panel=False)
  29. if data is None or len(data) == 0:
  30. return None
  31. # 重置索引,使日期成为一列,避免复制警告
  32. data = data.reset_index()
  33. # 计算多条移动平均线
  34. data['MA5'] = data['close'].rolling(window=5).mean()
  35. data['MA10'] = data['close'].rolling(window=10).mean()
  36. data['MA20'] = data['close'].rolling(window=20).mean()
  37. data['MA30'] = data['close'].rolling(window=30).mean()
  38. data.to_csv("A9999.csv", index=False, encoding='utf-8-sig')
  39. return data
  40. except Exception as e:
  41. print(f"获取{future_code}数据时出错: {str(e)}")
  42. return None
  43. def check_multi_ma_cross(row):
  44. """
  45. 检查单日K线是否向上或向下穿越了至少3条均线。
  46. 这个新版本精确计算被K线实体穿越的均线数量。
  47. 返回: 'up', 'down', 或 None
  48. """
  49. date = row['index']
  50. open_price = row['open']
  51. close_price = row['close']
  52. ma_values = [row['MA5'], row['MA10'], row['MA20'], row['MA30']]
  53. # print(f"date: {date}, open_price: {open_price}, close_price: {close_price}")
  54. # print(f"ma_values: {ma_values}")
  55. # 如果任何均线为NaN则跳过,确保数据完整性
  56. if pd.isna(row['MA5']) or pd.isna(row['MA10']) or pd.isna(row['MA20']) or pd.isna(row['MA30']):
  57. # print(f"date: {date}, MA5: {row['MA5']}, MA10: {row['MA10']}, MA20: {row['MA20']}, MA30: {row['MA30']}")
  58. # print("ma里有NaN跳过")
  59. return None
  60. # 如果开盘价和收盘价相等,不可能有穿越发生
  61. if open_price == close_price:
  62. # print("开盘价和收盘价相等,跳过")
  63. return None
  64. crossed_mas_count = 0
  65. # print("开始检查穿越情况")
  66. # 情况一:上涨(阳线),检查上穿
  67. if close_price > open_price:
  68. # print(f"收盘价大于开盘价: open_price: {open_price}, close_price: {close_price}")
  69. for ma in ma_values:
  70. # print(f"检查{ma}和收盘价和开盘价的关系")
  71. # 精确判断:开盘价在均线下方,且收盘价在均线上方
  72. if open_price < ma and close_price > ma:
  73. crossed_mas_count += 1
  74. if crossed_mas_count >= 3:
  75. return 'up'
  76. # 情况二:下跌(阴线),检查下穿
  77. elif open_price > close_price:
  78. for ma in ma_values:
  79. # 精确判断:开盘价在均线上方,且收盘价在均线下方
  80. if open_price > ma and close_price < ma:
  81. crossed_mas_count += 1
  82. if crossed_mas_count >= 3:
  83. return 'down'
  84. # 如果以上条件都不满足(包括穿越数量不足),则返回None
  85. # print("以上情况均不满足,跳过")
  86. return None
  87. def analyze_multi_ma_crosses(data, future_code, future_name):
  88. """分析多均线穿越并计算未来收益率"""
  89. if data is None or len(data) < 30: # 需要至少30天数据来计算30日均线
  90. return pd.DataFrame()
  91. results = []
  92. for i in range(len(data)):
  93. row = data.iloc[i]
  94. cross_type = check_multi_ma_cross(row)
  95. if cross_type is not None:
  96. # 使用包含日期的'index'列
  97. cross_date = row['index']
  98. open_price = row['open']
  99. close_price = row['close']
  100. # 计算当日收益率(收盘价和开盘价的变化率)
  101. intraday_return = (close_price - open_price) / open_price * 100
  102. # 计算未来收益率
  103. future_5d_return = None
  104. future_20d_return = None
  105. future_30d_return = None
  106. # 后5日收益率
  107. if i + 5 < len(data):
  108. future_5d_price = data.iloc[i + 5]['close']
  109. future_5d_return = (future_5d_price - close_price) / close_price * 100
  110. # 后10日收益率
  111. if i + 10 < len(data):
  112. future_10d_price = data.iloc[i + 10]['close']
  113. future_10d_return = (future_5d_price - close_price) / close_price * 100
  114. # 后20日收益率
  115. if i + 20 < len(data):
  116. future_20d_price = data.iloc[i + 20]['close']
  117. future_20d_return = (future_20d_price - close_price) / close_price * 100
  118. # 后30日收益率
  119. if i + 30 < len(data):
  120. future_30d_price = data.iloc[i + 30]['close']
  121. future_30d_return = (future_30d_price - close_price) / close_price * 100
  122. # 【【【 这是关键的修复点 】】】
  123. # 从 'row' 中获取当天的均线值,而不是从 'data' 中获取整个列
  124. results.append({
  125. '代码': future_code,
  126. '名称': future_name,
  127. '日期': cross_date,
  128. '方向': '上穿' if cross_type == 'up' else '下穿',
  129. '开盘价': round(open_price, 2),
  130. '收盘价': round(close_price, 2),
  131. 'MA5': round(row['MA5'], 2),
  132. 'MA10': round(row['MA10'], 2),
  133. 'MA20': round(row['MA20'], 2),
  134. 'MA30': round(row['MA30'], 2),
  135. '当日变化率': round(intraday_return, 2),
  136. '后5日变化率': round(future_5d_return, 2) if future_5d_return is not None else None,
  137. '后10日变化率': round(future_10d_return, 2) if future_10d_return is not None else None,
  138. '后20日变化率': round(future_20d_return, 2) if future_20d_return is not None else None,
  139. '后30日变化率': round(future_30d_return, 2) if future_30d_return is not None else None
  140. })
  141. # 将结果转换为DataFrame并确保列的顺序
  142. result_df = pd.DataFrame(results, columns=[
  143. '日期', '代码', '名称', '开盘价', '收盘价', '当日变化率', '方向',
  144. 'MA5', 'MA10', 'MA20', 'MA30', '后5日变化率', '后10日变化率',
  145. '后20日变化率', '后30日变化率'
  146. ])
  147. return result_df
  148. def analyze_all_futures_multi_ma(start_date, end_date):
  149. """分析所有期货合约的多均线穿越情况"""
  150. all_future_df = get_all_key_info()
  151. # 过滤在分析结束日期之前有数据且在分析结束日期之前开始交易的期货
  152. valid_futures = all_future_df[
  153. (all_future_df['start_date'] <= pd.to_datetime(end_date))
  154. ].copy()
  155. print(f"找到{len(valid_futures)}个期货合约需要分析")
  156. # 创建代码到名称的映射
  157. code_to_name = dict(zip(valid_futures['code'], valid_futures['display_name']))
  158. all_results = []
  159. total_futures = len(valid_futures)
  160. for idx, row in valid_futures.iterrows():
  161. future = row['code']
  162. future_start_date = row['start_date']
  163. # 调整开始日期为分析开始日期或期货开始日期中的较晚者
  164. effective_start_date = max(pd.to_datetime(start_date), future_start_date)
  165. print(f'正在分析 {future} ({idx+1}/{total_futures}) 从 {effective_start_date.strftime("%Y-%m-%d")} 开始...')
  166. try:
  167. data = get_future_data_with_ma(future, effective_start_date, end_date)
  168. if data is not None and len(data) >= 30:
  169. future_name = code_to_name.get(future, future)
  170. results = analyze_multi_ma_crosses(data, future, future_name)
  171. if not results.empty:
  172. all_results.append(results)
  173. print(f' 为{future}找到{len(results)}次多均线穿越')
  174. else:
  175. print(f' {future}未找到多均线穿越')
  176. else:
  177. print(f' {future}数据不足 (获得{len(data) if data is not None else 0}天数据)')
  178. except Exception as e:
  179. print(f' 分析{future}时出错: {str(e)}')
  180. continue
  181. if not all_results:
  182. print("未找到多均线穿越结果")
  183. return pd.DataFrame()
  184. combined_results = pd.concat(all_results, ignore_index=True)
  185. return combined_results
  186. def generate_summary_stats(results):
  187. """生成多均线穿越分析的汇总统计"""
  188. if results.empty:
  189. print("没有结果可供分析")
  190. return
  191. print(f"\n=== 多均线穿越分析结果汇总 ===")
  192. print(f"总共发现 {len(results)} 次多均线穿越事件")
  193. print(f"涉及 {results['代码'].nunique()} 个不同的期货品种")
  194. # 按穿越方向统计
  195. cross_type_stats = results['方向'].value_counts()
  196. print(f"\n穿越方向分布:")
  197. for cross_type, count in cross_type_stats.items():
  198. print(f" {cross_type}: {count} 次")
  199. # 收益率统计
  200. print(f"\n收益率统计:")
  201. for col in ['当日变化率', '后5日变化率',
  202. '后20日变化率', '后30日变化率']:
  203. valid_data = results[col].dropna()
  204. if len(valid_data) > 0:
  205. print(f" {col}:")
  206. print(f" 平均值: {valid_data.mean():.2f}%")
  207. print(f" 中位数: {valid_data.median():.2f}%")
  208. print(f" 标准差: {valid_data.std():.2f}%")
  209. print(f" 最小值: {valid_data.min():.2f}%")
  210. print(f" 最大值: {valid_data.max():.2f}%")
  211. # 按品种统计
  212. print(f"\n各品种穿越次数统计:")
  213. variety_stats = results.groupby('代码').size().sort_values(ascending=False)
  214. for code, count in variety_stats.head(10).items():
  215. name = results[results['代码'] == code]['名称'].iloc[0]
  216. print(f" {code} ({name}): {count} 次")
  217. def main():
  218. """运行多均线穿越分析的主函数"""
  219. # 设置分析期间
  220. start_date = datetime.datetime(2024, 6, 1)
  221. end_date = datetime.datetime(2025, 6, 1)
  222. print(f"开始分析期货多均线穿越情况...")
  223. print(f"分析期间: {start_date.strftime('%Y-%m-%d')} 到 {end_date.strftime('%Y-%m-%d')}")
  224. # 分析所有期货
  225. results = analyze_all_futures_multi_ma(start_date, end_date)
  226. if not results.empty:
  227. # 生成并显示统计信息
  228. generate_summary_stats(results)
  229. # 导出结果到CSV文件
  230. output_filename = 'multi_ma_cross_analysis_results.csv'
  231. results.to_csv(output_filename, index=False, encoding='utf-8-sig')
  232. print(f"\n结果已保存到文件: {output_filename}")
  233. # 显示前几行结果
  234. print(f"\n前10条结果预览:")
  235. print(results.head(10).to_string(index=False))
  236. return results
  237. else:
  238. print("未找到符合条件的多均线穿越事件")
  239. return pd.DataFrame()
  240. if __name__ == "__main__":
  241. results = main()