future_ma_cross_analysis.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import pandas as pd
  2. import numpy as np
  3. from jqdata import *
  4. import datetime
  5. import matplotlib.pyplot as plt
  6. def get_all_key_info(the_type='futures'):
  7. """Get all futures main contracts information"""
  8. main = get_all_securities(types=[the_type]).reset_index()
  9. the_main = main[main.display_name.str.endswith('主力合约')]
  10. the_main.rename(columns={'index': 'code'}, inplace=True)
  11. all_future_list = the_main['code'].unique()
  12. return the_main, all_future_list
  13. def get_period_range(start_year, end_year):
  14. """Get trading period range between years"""
  15. now = datetime.datetime.now()
  16. start_date = datetime.datetime(start_year, 1, 1)
  17. end_date = datetime.datetime(end_year, 12, 31)
  18. if end_year == now.year and now < end_date:
  19. end_date = now
  20. trade_days = get_trade_days(start_date=start_date, end_date=end_date)
  21. actual_start_date = trade_days[0]
  22. actual_end_date = trade_days[-1]
  23. print(f'Analysis period: {actual_start_date} to {actual_end_date}')
  24. return actual_start_date, actual_end_date
  25. def get_future_data(future_code, start_date, end_date):
  26. """Get price data for a single future contract"""
  27. data = get_price(future_code,
  28. start_date=start_date,
  29. end_date=end_date,
  30. frequency='daily',
  31. fields=['open', 'close', 'high', 'low', 'volume'],
  32. skip_paused=False,
  33. panel=False)
  34. if data is None or len(data) == 0:
  35. return None
  36. # Create a copy to avoid SettingWithCopyWarning
  37. data = data.copy()
  38. # Calculate 5-day MA
  39. data['MA5'] = data['close'].rolling(window=5).mean()
  40. return data
  41. def analyze_ma_crosses(data):
  42. """Analyze MA crosses and calculate statistics"""
  43. if data is None or len(data) < 5: # Need at least 5 days for MA5
  44. return pd.DataFrame()
  45. # Create a copy of the data to avoid warnings
  46. data = data.copy()
  47. # Initialize cross detection
  48. data['cross_up'] = (data['open'] < data['MA5']) & (data['close'] > data['MA5'])
  49. data['cross_down'] = (data['open'] > data['MA5']) & (data['close'] < data['MA5'])
  50. # Get indices where crosses occur
  51. cross_dates = data.index[data['cross_up'] | data['cross_down']].tolist()
  52. results = []
  53. for i in range(len(cross_dates) - 1):
  54. current_date = cross_dates[i]
  55. next_date = cross_dates[i + 1]
  56. # Get cross type
  57. is_up_cross = data.loc[current_date, 'cross_up']
  58. # Calculate trading days between crosses
  59. trading_days = len(data.loc[current_date:next_date].index) - 1
  60. # Calculate price change
  61. start_price = data.loc[current_date, 'close']
  62. end_price = data.loc[next_date, 'close']
  63. price_change_pct = (end_price - start_price) / start_price * 100
  64. results.append({
  65. 'cross_date': current_date,
  66. 'next_cross_date': next_date,
  67. 'cross_type': 'Upward' if is_up_cross else 'Downward',
  68. 'trading_days': trading_days,
  69. 'price_change_pct': price_change_pct
  70. })
  71. return pd.DataFrame(results)
  72. def analyze_all_futures(start_year, end_year):
  73. """Analyze MA crosses for all futures contracts"""
  74. start_date, end_date = get_period_range(start_year, end_year)
  75. all_future_df, all_future_list = get_all_key_info()
  76. all_results = []
  77. for future in all_future_list:
  78. print(f'Analyzing {future}...')
  79. data = get_future_data(future, start_date, end_date)
  80. if data is not None and len(data) >= 5:
  81. results = analyze_ma_crosses(data)
  82. if not results.empty:
  83. results['future_code'] = future
  84. all_results.append(results)
  85. if not all_results:
  86. return pd.DataFrame()
  87. combined_results = pd.concat(all_results, ignore_index=True)
  88. return combined_results
  89. def generate_statistics(results):
  90. """Generate summary statistics for the analysis"""
  91. if results.empty:
  92. print("No results to analyze")
  93. return
  94. # Group by future code and cross type
  95. stats = results.groupby(['future_code', 'cross_type']).agg({
  96. 'trading_days': ['count', 'mean', 'std', 'min', 'max'],
  97. 'price_change_pct': ['mean', 'std', 'min', 'max']
  98. }).round(2)
  99. return stats
  100. def plot_results(results):
  101. """Plot distribution of trading days and price changes"""
  102. if results.empty:
  103. print("No results to plot")
  104. return
  105. fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
  106. # Plot trading days distribution
  107. for cross_type in ['Upward', 'Downward']:
  108. data = results[results['cross_type'] == cross_type]['trading_days']
  109. ax1.hist(data, bins=30, alpha=0.5, label=cross_type)
  110. ax1.set_title('Distribution of Trading Days Between Crosses')
  111. ax1.set_xlabel('Trading Days')
  112. ax1.set_ylabel('Frequency')
  113. ax1.legend()
  114. # Plot price change distribution
  115. for cross_type in ['Upward', 'Downward']:
  116. data = results[results['cross_type'] == cross_type]['price_change_pct']
  117. ax2.hist(data, bins=30, alpha=0.5, label=cross_type)
  118. ax2.set_title('Distribution of Price Changes')
  119. ax2.set_xlabel('Price Change (%)')
  120. ax2.set_ylabel('Frequency')
  121. ax2.legend()
  122. plt.tight_layout()
  123. plt.show()
  124. def main():
  125. # Set analysis period (e.g., last 2 years)
  126. current_year = datetime.datetime.now().year
  127. results = analyze_all_futures(current_year - 1, current_year)
  128. if not results.empty:
  129. # Generate and display statistics
  130. stats = generate_statistics(results)
  131. print("\nSummary Statistics:")
  132. print(stats)
  133. # Plot distributions
  134. plot_results(results)
  135. # Export results to CSV
  136. results.to_csv('ma_cross_analysis_results.csv', index=False)
  137. stats.to_csv('ma_cross_analysis_stats.csv')
  138. if __name__ == "__main__":
  139. main()