# machine_learning.py (20 KB)
  1. # 机器学习框架
  2. # https://www.joinquant.com/view/community/detail/7253d06b938dc84af3e0c3c996d7d5bd?type=1
  3. # 1. 数据集制作
  4. from jqdata import *
  5. from jqlib.technical_analysis import *
  6. from jqfactor import get_factor_values, winsorize_med, standardlize, neutralize
  7. import datetime
  8. import pandas as pd
  9. import numpy as np
  10. from scipy import stats
  11. import statsmodels.api as sm
  12. from statsmodels import regression
  13. from six import StringIO
  14. from sklearn.decomposition import PCA
  15. from sklearn import svm
  16. from sklearn.model_selection import train_test_split
  17. from sklearn.grid_search import GridSearchCV
  18. from sklearn import metrics
  19. from tqdm import tqdm
  20. import matplotlib.dates as mdates
  21. import matplotlib.pyplot as plt
  22. import warnings
  23. import seaborn as sns
  24. import pickle
  25. warnings.filterwarnings("ignore")
  26. import pandas as pd
  27. import numpy as np
  28. import matplotlib.pyplot as plt
  29. import seaborn as sns
  30. from sklearn.metrics import (accuracy_score, precision_score, recall_score,
  31. f1_score, roc_auc_score, confusion_matrix,
  32. roc_curve, precision_recall_curve, auc, classification_report)
  33. import lightgbm as lgb
# Candidate factor names; this list is what get_factor_data requests from
# get_factor_values and what the analysis section iterates over as plot_cols.
jqfactors_list=['asset_impairment_loss_ttm', 'cash_flow_to_price_ratio', 'market_cap', 'interest_free_current_liability', 'EBITDA', 'financial_assets', 'gross_profit_ttm', 'net_working_capital', 'non_recurring_gain_loss', 'EBIT', 'sales_to_price_ratio', 'AR', 'ARBR', 'ATR6', 'DAVOL10', 'MAWVAD', 'TVMA6', 'PSY', 'VOL10', 'VDIFF', 'VEMA26', 'VMACD', 'VOL120', 'VOSC', 'VR', 'WVAD', 'arron_down_25', 'arron_up_25', 'BBIC', 'MASS', 'Rank1M', 'single_day_VPT', 'single_day_VPT_12', 'single_day_VPT_6', 'Volume1M', 'capital_reserve_fund_per_share', 'net_asset_per_share', 'net_operate_cash_flow_per_share', 'operating_profit_per_share', 'total_operating_revenue_per_share', 'surplus_reserve_fund_per_share', 'ACCA', 'account_receivable_turnover_days', 'account_receivable_turnover_rate', 'adjusted_profit_to_total_profit', 'super_quick_ratio', 'MLEV', 'debt_to_equity_ratio', 'debt_to_tangible_equity_ratio', 'equity_to_fixed_asset_ratio', 'fixed_asset_ratio', 'intangible_asset_ratio', 'invest_income_associates_to_total_profit', 'long_debt_to_asset_ratio', 'long_debt_to_working_capital_ratio', 'net_operate_cash_flow_to_total_liability', 'net_operating_cash_flow_coverage', 'non_current_asset_ratio', 'operating_profit_to_total_profit', 'roa_ttm', 'roe_ttm', 'Kurtosis120', 'Kurtosis20', 'Kurtosis60', 'sharpe_ratio_20', 'sharpe_ratio_60', 'Skewness120', 'Skewness20', 'Skewness60', 'Variance120', 'Variance20', 'liquidity', 'beta', 'book_to_price_ratio', 'cash_earnings_to_price_ratio', 'cube_of_size', 'earnings_to_price_ratio', 'earnings_yield', 'growth', 'momentum', 'natural_log_of_market_cap', 'boll_down', 'MFI14', 'price_no_fq']
# Print how many candidate factors are configured.
print(len(jqfactors_list))
  36. def get_period_date(peroid, start_date, end_date):
  37. stock_data = get_price('000001.XSHE', start_date, end_date, 'daily', fields=['close'])
  38. stock_data['date'] = stock_data.index
  39. period_stock_data = stock_data.resample(peroid, how='last')
  40. period_stock_data = period_stock_data.set_index('date').dropna()
  41. date = period_stock_data.index
  42. pydate_array = date.to_pydatetime()
  43. date_only_array = np.vectorize(lambda s: s.strftime('%Y-%m-%d'))(pydate_array)
  44. date_only_series = pd.Series(date_only_array)
  45. start_date = datetime.datetime.strptime(start_date, "%Y-%m-%d")
  46. start_date = start_date - datetime.timedelta(days=1)
  47. start_date = start_date.strftime("%Y-%m-%d")
  48. date_list = date_only_series.values.tolist()
  49. date_list.insert(0, start_date)
  50. return date_list
# Resample rule handed to get_period_date.
peroid = 'M'
# Date range covered by the data set.
start_date = '2019-01-01'
end_date = '2024-01-01'
# Sample dates for the range above; the count is printed as a quick check.
DAY = get_period_date(peroid, start_date, end_date)
print(len(DAY))
  56. def delect_stop(stocks, beginDate, n=30 * 3):
  57. stockList = []
  58. beginDate = datetime.datetime.strptime(beginDate, "%Y-%m-%d")
  59. for stock in stocks:
  60. start_date = get_security_info(stock).start_date
  61. if start_date < (beginDate - datetime.timedelta(days=n)).date():
  62. stockList.append(stock)
  63. return stockList
  64. def get_stock(stockPool, begin_date):
  65. if stockPool == 'HS300':
  66. stockList = get_index_stocks('000300.XSHG', begin_date)
  67. elif stockPool == 'ZZ500':
  68. stockList = get_index_stocks('399905.XSHE', begin_date)
  69. elif stockPool == 'ZZ800':
  70. stockList = get_index_stocks('399906.XSHE', begin_date)
  71. elif stockPool == 'CYBZ':
  72. stockList = get_index_stocks('399006.XSHE', begin_date)
  73. elif stockPool == 'ZXBZ':
  74. stockList = get_index_stocks('399101.XSHE', begin_date)
  75. elif stockPool == 'A':
  76. stockList = get_index_stocks('000002.XSHG', begin_date) + get_index_stocks('399107.XSHE', begin_date)
  77. stockList = [stock for stock in stockList if not stock.startswith(('68', '4', '8'))]
  78. elif stockPool == 'AA':
  79. stockList = get_index_stocks('000985.XSHG', begin_date)
  80. stockList = [stock for stock in stockList if not stock.startswith(('3', '68', '4', '8'))]
  81. st_data = get_extras('is_st', stockList, count=1, end_date=begin_date)
  82. stockList = [stock for stock in stockList if not st_data[stock][0]]
  83. stockList = delect_stop(stockList, begin_date)
  84. return stockList
  85. def get_factor_data(securities_list, date):
  86. factor_data = get_factor_values(securities=securities_list, factors=jqfactors_list, count=1, end_date=date)
  87. df_jq_factor = pd.DataFrame(index=securities_list)
  88. for i in factor_data.keys():
  89. df_jq_factor[i] = factor_data[i].iloc[0, :]
  90. return df_jq_factor
  91. dateList = get_period_date(peroid, start_date, end_date)
  92. DF = pd.DataFrame()
  93. for date in tqdm(dateList[:-1]):
  94. stockList = get_stock('ZXBZ', date)
  95. factor_origl_data = get_factor_data(stockList, date)
  96. data_close = get_price(stockList, date, dateList[dateList.index(date) + 1], '1d', 'close')['close']
  97. factor_origl_data['pchg'] = data_close.iloc[-1] / data_close.iloc[1] - 1
  98. factor_origl_data = factor_origl_data.dropna()
  99. median_pchg = factor_origl_data['pchg'].median()
  100. factor_origl_data['label'] = np.where(factor_origl_data['pchg'] >= median_pchg, 1, 0)
  101. factor_origl_data = factor_origl_data.drop(columns=['pchg'])
  102. DF = pd.concat([DF, factor_origl_data], ignore_index=True)
  103. DF.to_csv(r'train_small.csv', index=False)
  104. # 2. 数据分析
  105. df = pd.read_csv(r'train_small.csv')
  106. plot_cols = jqfactors_list
  107. print(len(plot_cols))
  108. corr_matrix = df.corr()
  109. plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
  110. label_corr = corr_matrix['label'].sort_values(ascending=False)
  111. plt.figure(figsize=(12, 10))
  112. corr = df[plot_cols].corr()
  113. mask = np.triu(np.ones_like(corr, dtype=bool))
  114. sns.heatmap(corr, mask=mask, annot=False, fmt=".2f", cmap='RdBu_r', vmin=-1, vmax=1)
  115. plt.title('因子间相关性矩阵')
  116. plt.show()
  117. plt.figure(figsize=(14, 100))
  118. for i, col in enumerate(plot_cols, 1):
  119. plt.subplot(42, 2, i)
  120. sns.distplot(df[col].dropna(), bins=30, color='skyblue', kde=True)
  121. stats_text = f"均值: {df[col].mean():.2f}\n中位数: {df[col].median():.2f}\n标准差: {df[col].std():.2f}"
  122. plt.gca().text(0.95, 0.95, stats_text, transform=plt.gca().transAxes,
  123. verticalalignment='top', horizontalalignment='right',
  124. bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
  125. plt.title(f'{col} 分布')
  126. plt.xlabel('')
  127. plt.tight_layout()
  128. plt.show()
  129. # 3. 数据预处理
  130. from collections import defaultdict
  131. # 计算每个特征的缺失值数量
  132. missing_counts = df[plot_cols].isnull().sum().to_dict()
  133. # 计算特征间的相关系数矩阵
  134. corr_matrix = df[plot_cols].corr()
  135. # 创建图结构存储高度相关的特征对
  136. graph = defaultdict(list)
  137. threshold = 0.6 # 相关性阈值
  138. # 遍历上三角矩阵找到高度相关的特征对
  139. n = len(plot_cols)
  140. for i in range(n):
  141. for j in range(i + 1, n):
  142. col1, col2 = plot_cols[i], plot_cols[j]
  143. corr_value = corr_matrix.iloc[i, j]
  144. if not pd.isna(corr_value) and abs(corr_value) > threshold:
  145. graph[col1].append(col2)
  146. graph[col2].append(col1)
  147. # 使用DFS找到连通分量(高度相关的特征组)
  148. visited = set()
  149. components = []
  150. def dfs(node, comp):
  151. visited.add(node)
  152. comp.append(node)
  153. for neighbor in graph[node]:
  154. if neighbor not in visited:
  155. dfs(neighbor, comp)
  156. for col in plot_cols:
  157. if col not in visited:
  158. comp = []
  159. dfs(col, comp)
  160. components.append(comp)
  161. # 处理每个连通分量:保留缺失值最少的特征
  162. to_keep = []
  163. to_remove = []
  164. for comp in components:
  165. if len(comp) == 1: # 独立特征直接保留
  166. to_keep.append(comp[0])
  167. else:
  168. # 按缺失值数量排序(升序),相同缺失值时按特征名字母顺序排序
  169. comp_sorted = sorted(comp, key=lambda x: (missing_counts[x], x))
  170. keep_feature = comp_sorted[0] # 缺失值最少的特征
  171. to_keep.append(keep_feature)
  172. to_remove.extend(comp_sorted[1:]) # 组内其他特征移除
  173. print(f"\n最终保留特征数量: {len(to_keep)}")
  174. print(f"移除特征数量: {len(to_remove)}")
  175. print("\n移除的特征列表:", to_remove)
  176. print("\n保留的特征列表:", to_keep)
# Optional: heatmap of the correlations among the kept features.
plt.figure(figsize=(12, 10))
corr_kept = df[to_keep].corr()
# Mask the upper triangle (diagonal included) so each pair is drawn once.
mask = np.triu(np.ones_like(corr_kept, dtype=bool))
sns.heatmap(corr_kept, mask=mask, annot=True, fmt=".2f", cmap='RdBu_r', vmin=-1, vmax=1)
plt.title('保留特征间相关性矩阵')
plt.show()
  184. # 4. 训练模型
  185. X = df[to_keep]
  186. y = df['label']
  187. lgb_train = lgb.Dataset(X, label=y)
  188. params = {
  189. 'objective': 'binary',
  190. 'metric': 'binary_logloss',
  191. 'boosting_type': 'gbdt',
  192. 'verbose': -1
  193. }
  194. model = lgb.train(params, lgb_train, num_boost_round=200)
  195. y_pred_proba = model.predict(X)
  196. y_pred = (y_pred_proba > 0.5).astype(int)
  197. precision, recall, _ = precision_recall_curve(y, y_pred_proba)
  198. prauc = auc(recall, precision)
  199. print("\n模型性能评估:")
  200. print("准确率 (Accuracy):", accuracy_score(y, y_pred))
  201. print("精确率 (Precision):", precision_score(y, y_pred))
  202. print("召回率 (Recall):", recall_score(y, y_pred))
  203. print("F1分数 (F1-score):", f1_score(y, y_pred))
  204. print("AUC分数:", roc_auc_score(y, y_pred_proba))
  205. print("PRAUC分数:", prauc) # 新增PRAUC输出
# 2x2 evaluation figure: confusion matrix, ROC curve, feature importance and
# predicted-probability distributions.
plt.figure(figsize=(15, 12))
# Panel 1: confusion matrix of the thresholded predictions.
plt.subplot(2, 2, 1)
cm = confusion_matrix(y, y_pred)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['预测0', '预测1'],
            yticklabels=['实际0', '实际1'])
plt.title('混淆矩阵')
plt.ylabel('实际标签')
plt.xlabel('预测标签')
# Panel 2: ROC curve with the diagonal as the chance reference.
plt.subplot(2, 2, 2)
fpr, tpr, _ = roc_curve(y, y_pred_proba)
roc_auc = roc_auc_score(y, y_pred_proba)
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'AUC = {roc_auc:.3f}')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('假正率 (FPR)')
plt.ylabel('真正率 (TPR)')
plt.title('ROC曲线')
plt.legend(loc="lower right")
# Panel 3: model feature importance, one horizontal bar per kept feature.
plt.subplot(2, 2, 3)
importance = pd.Series(model.feature_importance(), index=to_keep)
importance.sort_values().plot(kind='barh')
plt.title('特征重要性')
plt.xlabel('重要性分数')
plt.ylabel('特征')
# Panel 4: density of predicted probabilities, split by true label.
plt.subplot(2, 2, 4)
for label in [0, 1]:
    # NOTE(review): `shade` is reported as deprecated in favour of `fill` in
    # recent seaborn releases - confirm against the installed version.
    sns.kdeplot(y_pred_proba[y == label], label=f'真实标签={label}', shade=True)
plt.title('预测概率分布')
plt.xlabel('预测为正类的概率')
plt.ylabel('密度')
plt.legend()
# Vertical line at the 0.5 decision threshold used for y_pred.
plt.axvline(0.5, color='red', linestyle='--')
plt.tight_layout()
plt.show()
# Precision-recall curve with the area under it shaded.
plt.figure(figsize=(12, 5))
# Only the first cell of the 1x2 grid is drawn.
plt.subplot(1, 2, 1)
plt.plot(recall, precision, color='darkblue', lw=2, label=f'PRAUC = {prauc:.3f}')
plt.fill_between(recall, precision, alpha=0.2, color='darkblue')
plt.xlabel('召回率 (Recall)')
plt.ylabel('精确率 (Precision)')
plt.title('PRAUC曲线')
plt.legend(loc='upper right')
plt.grid(True)
plt.tight_layout()
plt.show()
  253. with open('model_small.pkl', 'wb') as model_file:
  254. pickle.dump(model, model_file)
  255. # 5. 回测代码
  256. from jqdata import *
  257. from jqfactor import *
  258. import numpy as np
  259. import pandas as pd
  260. import pickle
# Initialisation function
def initialize(context):
    """Configure the backtest, load the model and register the scheduled tasks."""
    # Benchmark index
    set_benchmark('399101.XSHE')
    # Trade at real prices
    set_option('use_real_price', True)
    # Enable the guard against using future data
    set_option("avoid_future_data", True)
    # Set slippage to 0
    set_slippage(FixedSlippage(0))
    # Trading costs: 0.0003 commission on open and close, 0.001 tax on close,
    # min_commission=5
    set_order_cost(OrderCost(open_tax=0, close_tax=0.001, open_commission=0.0003, close_commission=0.0003,
                             close_today_commission=0, min_commission=5), type='stock')
    # Only log order messages at error level or above
    log.set_level('order', 'error')
    # Global state
    g.stock_num = 10
    g.hold_list = []  # all stocks currently held
    g.yesterday_HL_list = []  # held stocks that closed at the upper limit yesterday
    # NOTE(review): pickle.loads executes code embedded in the file - only load
    # a model file you produced yourself.
    g.model_small = pickle.loads(read_file('model_small.pkl'))
    # Factor list fed to the model
    # NOTE(review): presumably this must match, in order, the features the
    # model was trained on - verify against the training output.
    g.factor_list = ['asset_impairment_loss_ttm', 'cash_flow_to_price_ratio', 'EBIT', 'net_working_capital', 'non_recurring_gain_loss', 'sales_to_price_ratio', 'AR', 'ARBR', 'ATR6', 'DAVOL10', 'MAWVAD', 'TVMA6', 'PSY', 'VOL10', 'VDIFF', 'VEMA26', 'VMACD', 'VOL120', 'VOSC', 'VR', 'arron_down_25', 'arron_up_25', 'BBIC', 'MASS', 'Rank1M', 'single_day_VPT', 'single_day_VPT_12', 'Volume1M', 'capital_reserve_fund_per_share', 'net_operate_cash_flow_per_share', 'operating_profit_per_share', 'total_operating_revenue_per_share', 'surplus_reserve_fund_per_share', 'ACCA', 'account_receivable_turnover_days', 'account_receivable_turnover_rate', 'adjusted_profit_to_total_profit', 'super_quick_ratio', 'MLEV', 'debt_to_equity_ratio', 'debt_to_tangible_equity_ratio', 'equity_to_fixed_asset_ratio', 'fixed_asset_ratio', 'intangible_asset_ratio', 'invest_income_associates_to_total_profit', 'long_debt_to_asset_ratio', 'long_debt_to_working_capital_ratio', 'net_operate_cash_flow_to_total_liability', 'net_operating_cash_flow_coverage', 'non_current_asset_ratio', 'operating_profit_to_total_profit', 'roa_ttm', 'Kurtosis120', 'Kurtosis20', 'Kurtosis60', 'sharpe_ratio_20', 'sharpe_ratio_60', 'Skewness120', 'Skewness20', 'Skewness60', 'Variance120', 'Variance20', 'beta', 'book_to_price_ratio', 'cash_earnings_to_price_ratio', 'cube_of_size', 'earnings_to_price_ratio', 'earnings_yield', 'growth', 'momentum', 'natural_log_of_market_cap', 'boll_down', 'MFI14', 'price_no_fq']
    # Scheduled tasks; note that weekly_adjustment is registered with run_monthly.
    run_daily(prepare_stock_list, '9:05')
    run_monthly(weekly_adjustment, 1, '9:30')
    run_daily(check_limit_up, '14:00')
  286. # 1-1 准备股票池
  287. def prepare_stock_list(context):
  288. # 获取已持有列表
  289. g.hold_list = []
  290. for position in list(context.portfolio.positions.values()):
  291. stock = position.security
  292. g.hold_list.append(stock)
  293. # 获取昨日涨停列表
  294. if g.hold_list != []:
  295. df = get_price(g.hold_list, end_date=context.previous_date, frequency='daily', fields=['close', 'high_limit'],
  296. count=1, panel=False, fill_paused=False)
  297. df = df[df['close'] == df['high_limit']]
  298. g.yesterday_HL_list = list(df.code)
  299. else:
  300. g.yesterday_HL_list = []
  301. # 1-2 选股模块
  302. def get_stock_list(context):
  303. yesterday = context.previous_date
  304. today = context.current_dt
  305. stocks = get_index_stocks('399101.XSHE', yesterday)
  306. initial_list = filter_kcbj_stock(stocks)
  307. initial_list = filter_st_stock(initial_list)
  308. initial_list = filter_paused_stock(initial_list)
  309. initial_list = filter_new_stock(context, initial_list)
  310. initial_list = filter_limitup_stock(context,initial_list)
  311. initial_list = filter_limitdown_stock(context,initial_list)
  312. factor_data = get_factor_values(initial_list, g.factor_list, end_date=yesterday, count=1)
  313. df_jq_factor_value = pd.DataFrame(index=initial_list, columns=g.factor_list)
  314. for factor in g.factor_list:
  315. df_jq_factor_value[factor] = list(factor_data[factor].T.iloc[:, 0])
  316. tar = g.model_small.predict(df_jq_factor_value)
  317. df = df_jq_factor_value
  318. df['total_score'] = list(tar)
  319. df = df.sort_values(by=['total_score'], ascending=False)
  320. lst = df.index.tolist()
  321. lst = lst[:min(g.stock_num, len(lst))]
  322. return lst
  323. # 1-3 整体调整持仓
  324. def weekly_adjustment(context):
  325. # 获取应买入列表
  326. target_list = get_stock_list(context)
  327. # 调仓卖出
  328. for stock in g.hold_list:
  329. if (stock not in target_list) and (stock not in g.yesterday_HL_list):
  330. log.info("卖出[%s]" % (stock))
  331. position = context.portfolio.positions[stock]
  332. close_position(position)
  333. else:
  334. log.info("已持有[%s]" % (stock))
  335. # 调仓买入
  336. position_count = len(context.portfolio.positions)
  337. target_num = len(target_list)
  338. if target_num > position_count:
  339. value = context.portfolio.cash / (target_num - position_count)
  340. for stock in target_list:
  341. if context.portfolio.positions[stock].total_amount == 0:
  342. if open_position(stock, value):
  343. if len(context.portfolio.positions) == target_num:
  344. break
  345. # 1-4 调整昨日涨停股票
  346. def check_limit_up(context):
  347. now_time = context.current_dt
  348. if g.yesterday_HL_list != []:
  349. # 对昨日涨停股票观察到尾盘如不涨停则提前卖出,如果涨停即使不在应买入列表仍暂时持有
  350. for stock in g.yesterday_HL_list:
  351. current_data = get_price(stock, end_date=now_time, frequency='1m', fields=['close', 'high_limit'],
  352. skip_paused=False, fq='pre', count=1, panel=False, fill_paused=True)
  353. if current_data.iloc[0, 0] < current_data.iloc[0, 1]:
  354. log.info("[%s]涨停打开,卖出" % (stock))
  355. position = context.portfolio.positions[stock]
  356. close_position(position)
  357. else:
  358. log.info("[%s]涨停,继续持有" % (stock))
  359. # 3-1 交易模块-自定义下单
  360. def order_target_value_(security, value):
  361. if value == 0:
  362. log.debug("Selling out %s" % (security))
  363. else:
  364. log.debug("Order %s to value %f" % (security, value))
  365. return order_target_value(security, value)
  366. # 3-2 交易模块-开仓
  367. def open_position(security, value):
  368. order = order_target_value_(security, value)
  369. if order != None and order.filled > 0:
  370. return True
  371. return False
  372. # 3-3 交易模块-平仓
  373. def close_position(position):
  374. security = position.security
  375. order = order_target_value_(security, 0) # 可能会因停牌失败
  376. if order != None:
  377. if order.status == OrderStatus.held and order.filled == order.amount:
  378. return True
  379. return False
  380. # 2-1 过滤停牌股票
  381. def filter_paused_stock(stock_list):
  382. current_data = get_current_data()
  383. return [stock for stock in stock_list if not current_data[stock].paused]
  384. # 2-2 过滤ST及其他具有退市标签的股票
  385. def filter_st_stock(stock_list):
  386. current_data = get_current_data()
  387. return [stock for stock in stock_list
  388. if not current_data[stock].is_st
  389. and 'ST' not in current_data[stock].name
  390. and '*' not in current_data[stock].name
  391. and '退' not in current_data[stock].name]
  392. # 2-3 过滤科创北交股票
  393. def filter_kcbj_stock(stock_list):
  394. for stock in stock_list[:]:
  395. if stock[0] == '4' or stock[0] == '8' or stock[:2] == '68' or stock[0] == '3':
  396. stock_list.remove(stock)
  397. return stock_list
  398. # 2-4 过滤涨停的股票
  399. def filter_limitup_stock(context, stock_list):
  400. last_prices = history(1, unit='1m', field='close', security_list=stock_list)
  401. current_data = get_current_data()
  402. return [stock for stock in stock_list if stock in context.portfolio.positions.keys()
  403. or last_prices[stock][-1] < current_data[stock].high_limit]
  404. # 2-5 过滤跌停的股票
  405. def filter_limitdown_stock(context, stock_list):
  406. last_prices = history(1, unit='1m', field='close', security_list=stock_list)
  407. current_data = get_current_data()
  408. return [stock for stock in stock_list if stock in context.portfolio.positions.keys()
  409. or last_prices[stock][-1] > current_data[stock].low_limit]
  410. # 2-6 过滤次新股
  411. def filter_new_stock(context, stock_list):
  412. yesterday = context.previous_date
  413. return [stock for stock in stock_list if
  414. not yesterday - get_security_info(stock).start_date < datetime.timedelta(days=375)]