- # Machine learning framework
- # https://www.joinquant.com/view/community/detail/7253d06b938dc84af3e0c3c996d7d5bd?type=1
- # 1. Dataset construction
- from jqdata import *
- from jqlib.technical_analysis import *
- from jqfactor import get_factor_values, winsorize_med, standardlize, neutralize
- import datetime
- import pandas as pd
- import numpy as np
- from scipy import stats
- import statsmodels.api as sm
- from statsmodels import regression
- from six import StringIO
- from sklearn.decomposition import PCA
- from sklearn import svm
- from sklearn.model_selection import train_test_split, GridSearchCV
- from sklearn import metrics
- from tqdm import tqdm
- import matplotlib.dates as mdates
- import matplotlib.pyplot as plt
- import warnings
- import seaborn as sns
- import pickle
- warnings.filterwarnings("ignore")
- from sklearn.metrics import (accuracy_score, precision_score, recall_score,
- f1_score, roc_auc_score, confusion_matrix,
- roc_curve, precision_recall_curve, auc, classification_report)
- import lightgbm as lgb
- jqfactors_list=['asset_impairment_loss_ttm', 'cash_flow_to_price_ratio', 'market_cap', 'interest_free_current_liability', 'EBITDA', 'financial_assets', 'gross_profit_ttm', 'net_working_capital', 'non_recurring_gain_loss', 'EBIT', 'sales_to_price_ratio', 'AR', 'ARBR', 'ATR6', 'DAVOL10', 'MAWVAD', 'TVMA6', 'PSY', 'VOL10', 'VDIFF', 'VEMA26', 'VMACD', 'VOL120', 'VOSC', 'VR', 'WVAD', 'arron_down_25', 'arron_up_25', 'BBIC', 'MASS', 'Rank1M', 'single_day_VPT', 'single_day_VPT_12', 'single_day_VPT_6', 'Volume1M', 'capital_reserve_fund_per_share', 'net_asset_per_share', 'net_operate_cash_flow_per_share', 'operating_profit_per_share', 'total_operating_revenue_per_share', 'surplus_reserve_fund_per_share', 'ACCA', 'account_receivable_turnover_days', 'account_receivable_turnover_rate', 'adjusted_profit_to_total_profit', 'super_quick_ratio', 'MLEV', 'debt_to_equity_ratio', 'debt_to_tangible_equity_ratio', 'equity_to_fixed_asset_ratio', 'fixed_asset_ratio', 'intangible_asset_ratio', 'invest_income_associates_to_total_profit', 'long_debt_to_asset_ratio', 'long_debt_to_working_capital_ratio', 'net_operate_cash_flow_to_total_liability', 'net_operating_cash_flow_coverage', 'non_current_asset_ratio', 'operating_profit_to_total_profit', 'roa_ttm', 'roe_ttm', 'Kurtosis120', 'Kurtosis20', 'Kurtosis60', 'sharpe_ratio_20', 'sharpe_ratio_60', 'Skewness120', 'Skewness20', 'Skewness60', 'Variance120', 'Variance20', 'liquidity', 'beta', 'book_to_price_ratio', 'cash_earnings_to_price_ratio', 'cube_of_size', 'earnings_to_price_ratio', 'earnings_yield', 'growth', 'momentum', 'natural_log_of_market_cap', 'boll_down', 'MFI14', 'price_no_fq']
- print(len(jqfactors_list))
- def get_period_date(period, start_date, end_date):
- stock_data = get_price('000001.XSHE', start_date, end_date, 'daily', fields=['close'])
- stock_data['date'] = stock_data.index
- period_stock_data = stock_data.resample(period).last()
- period_stock_data = period_stock_data.set_index('date').dropna()
- date = period_stock_data.index
- pydate_array = date.to_pydatetime()
- date_only_array = np.vectorize(lambda s: s.strftime('%Y-%m-%d'))(pydate_array)
- date_only_series = pd.Series(date_only_array)
- start_date = datetime.datetime.strptime(start_date, "%Y-%m-%d")
- start_date = start_date - datetime.timedelta(days=1)
- start_date = start_date.strftime("%Y-%m-%d")
- date_list = date_only_series.values.tolist()
- date_list.insert(0, start_date)
- return date_list
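- # get_period_date returns the period-end trading dates as 'YYYY-MM-DD' strings, with the day before start_date prepended,
- # e.g. ['2018-12-31', '2019-01-31', ...] for the monthly settings below (illustrative values)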
- period = 'M'
- start_date = '2019-01-01'
- end_date = '2024-01-01'
- DAY = get_period_date(period, start_date, end_date)
- print(len(DAY))
- # Keep only stocks listed at least n days before beginDate (drops recent IPOs)
- def delete_stop(stocks, beginDate, n=30 * 3):
- stockList = []
- beginDate = datetime.datetime.strptime(beginDate, "%Y-%m-%d")
- for stock in stocks:
- start_date = get_security_info(stock).start_date
- if start_date < (beginDate - datetime.timedelta(days=n)).date():
- stockList.append(stock)
- return stockList
- def get_stock(stockPool, begin_date):
- if stockPool == 'HS300':
- stockList = get_index_stocks('000300.XSHG', begin_date)
- elif stockPool == 'ZZ500':
- stockList = get_index_stocks('399905.XSHE', begin_date)
- elif stockPool == 'ZZ800':
- stockList = get_index_stocks('399906.XSHE', begin_date)
- elif stockPool == 'CYBZ':
- stockList = get_index_stocks('399006.XSHE', begin_date)
- elif stockPool == 'ZXBZ':
- stockList = get_index_stocks('399101.XSHE', begin_date)
- elif stockPool == 'A':
- stockList = get_index_stocks('000002.XSHG', begin_date) + get_index_stocks('399107.XSHE', begin_date)
- stockList = [stock for stock in stockList if not stock.startswith(('68', '4', '8'))]
- elif stockPool == 'AA':
- stockList = get_index_stocks('000985.XSHG', begin_date)
- stockList = [stock for stock in stockList if not stock.startswith(('3', '68', '4', '8'))]
- st_data = get_extras('is_st', stockList, count=1, end_date=begin_date)
- stockList = [stock for stock in stockList if not st_data[stock][0]]
- stockList = delete_stop(stockList, begin_date)
- return stockList
- def get_factor_data(securities_list, date):
- factor_data = get_factor_values(securities=securities_list, factors=jqfactors_list, count=1, end_date=date)
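- # get_factor_values returns a dict mapping each factor name to a DataFrame of shape (dates x securities);
- # with count=1 there is a single row, which is copied into one column per factor below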
- df_jq_factor = pd.DataFrame(index=securities_list)
- for i in factor_data.keys():
- df_jq_factor[i] = factor_data[i].iloc[0, :]
- return df_jq_factor
- dateList = get_period_date(period, start_date, end_date)
- DF = pd.DataFrame()
- for date in tqdm(dateList[:-1]):
- stockList = get_stock('ZXBZ', date)
- factor_origl_data = get_factor_data(stockList, date)
- data_close = get_price(stockList, date, dateList[dateList.index(date) + 1], '1d', 'close')['close']
- factor_origl_data['pchg'] = data_close.iloc[-1] / data_close.iloc[1] - 1
- factor_origl_data = factor_origl_data.dropna()
- median_pchg = factor_origl_data['pchg'].median()
- factor_origl_data['label'] = np.where(factor_origl_data['pchg'] >= median_pchg, 1, 0)
- factor_origl_data = factor_origl_data.drop(columns=['pchg'])
- DF = pd.concat([DF, factor_origl_data], ignore_index=True)
- DF.to_csv(r'train_small.csv', index=False)
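- # Each monthly cross-section is labeled 1 when the stock's forward return (from the close at iloc[1], the second day
- # in the window, to the period-end close) is at or above the cross-sectional median, else 0, so the model learns a
- # relative ranking rather than absolute returns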
- # 2. Data analysis
- df = pd.read_csv(r'train_small.csv')
- plot_cols = jqfactors_list
- print(len(plot_cols))
- corr_matrix = df.corr()
- plt.rcParams['axes.unicode_minus'] = False # make sure minus signs render correctly in plots
- label_corr = corr_matrix['label'].sort_values(ascending=False)
- plt.figure(figsize=(12, 10))
- corr = df[plot_cols].corr()
- mask = np.triu(np.ones_like(corr, dtype=bool))
- sns.heatmap(corr, mask=mask, annot=False, fmt=".2f", cmap='RdBu_r', vmin=-1, vmax=1)
- plt.title('Correlation matrix of factors')
- plt.show()
- plt.figure(figsize=(14, 100))
- for i, col in enumerate(plot_cols, 1):
- plt.subplot(42, 2, i)
- sns.histplot(df[col].dropna(), bins=30, color='skyblue', kde=True)
- stats_text = f"Mean: {df[col].mean():.2f}\nMedian: {df[col].median():.2f}\nStd: {df[col].std():.2f}"
- plt.gca().text(0.95, 0.95, stats_text, transform=plt.gca().transAxes,
- verticalalignment='top', horizontalalignment='right',
- bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
- plt.title(f'{col} distribution')
- plt.xlabel('')
- plt.tight_layout()
- plt.show()
- # 3. Data preprocessing
- from collections import defaultdict
- # Count missing values for each feature
- missing_counts = df[plot_cols].isnull().sum().to_dict()
- # Correlation matrix between features
- corr_matrix = df[plot_cols].corr()
- # Build a graph structure holding pairs of highly correlated features
- graph = defaultdict(list)
- threshold = 0.6 # correlation threshold
- # Walk the upper triangle of the matrix to find highly correlated feature pairs
- n = len(plot_cols)
- for i in range(n):
- for j in range(i + 1, n):
- col1, col2 = plot_cols[i], plot_cols[j]
- corr_value = corr_matrix.iloc[i, j]
- if not pd.isna(corr_value) and abs(corr_value) > threshold:
- graph[col1].append(col2)
- graph[col2].append(col1)
- # Use DFS to find connected components (groups of highly correlated features)
- visited = set()
- components = []
- def dfs(node, comp):
- visited.add(node)
- comp.append(node)
- for neighbor in graph[node]:
- if neighbor not in visited:
- dfs(neighbor, comp)
- for col in plot_cols:
- if col not in visited:
- comp = []
- dfs(col, comp)
- components.append(comp)
- # For each connected component, keep the feature with the fewest missing values
- to_keep = []
- to_remove = []
- for comp in components:
- if len(comp) == 1: # standalone feature, keep it directly
- to_keep.append(comp[0])
- else:
- # Sort by missing-value count (ascending); ties are broken alphabetically by feature name
- comp_sorted = sorted(comp, key=lambda x: (missing_counts[x], x))
- keep_feature = comp_sorted[0] # feature with the fewest missing values
- to_keep.append(keep_feature)
- to_remove.extend(comp_sorted[1:]) # drop the remaining features in the group
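- # Example of the grouping above (hypothetical feature names): if corr(A, B) = 0.7 and corr(B, C) = 0.65, both above the
- # 0.6 threshold, then A, B and C form one component and only the member with the fewest missing values survives;
- # an uncorrelated feature D is kept as-is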
- print(f"\nNumber of features kept: {len(to_keep)}")
- print(f"Number of features removed: {len(to_remove)}")
- print("\nRemoved features:", to_remove)
- print("\nKept features:", to_keep)
- # Visualize the correlation matrix of the kept features (optional)
- plt.figure(figsize=(12, 10))
- corr_kept = df[to_keep].corr()
- mask = np.triu(np.ones_like(corr_kept, dtype=bool))
- sns.heatmap(corr_kept, mask=mask, annot=True, fmt=".2f", cmap='RdBu_r', vmin=-1, vmax=1)
- plt.title('Correlation matrix of kept features')
- plt.show()
- # 4. Model training
- X = df[to_keep]
- y = df['label']
- lgb_train = lgb.Dataset(X, label=y)
- params = {
- 'objective': 'binary',
- 'metric': 'binary_logloss',
- 'boosting_type': 'gbdt',
- 'verbose': -1
- }
- model = lgb.train(params, lgb_train, num_boost_round=200)
- y_pred_proba = model.predict(X)
- y_pred = (y_pred_proba > 0.5).astype(int)
- precision, recall, _ = precision_recall_curve(y, y_pred_proba)
- prauc = auc(recall, precision)
- print("\nModel performance (on the training data):")
- print("Accuracy:", accuracy_score(y, y_pred))
- print("Precision:", precision_score(y, y_pred))
- print("Recall:", recall_score(y, y_pred))
- print("F1-score:", f1_score(y, y_pred))
- print("ROC AUC:", roc_auc_score(y, y_pred_proba))
- print("PR AUC:", prauc)
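- # Note: the metrics above are computed on the very data the model was trained on, so they overstate out-of-sample
- # performance. A minimal hold-out sketch (illustrative variable names; reuses X, y and params from above and the
- # train_test_split imported in section 1); for panel data like this, a split by date would be even more appropriate:
- X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)
- holdout_model = lgb.train(params, lgb.Dataset(X_tr, label=y_tr), num_boost_round=200)
- holdout_proba = holdout_model.predict(X_te)
- print("Hold-out ROC AUC:", roc_auc_score(y_te, holdout_proba))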
- plt.figure(figsize=(15, 12))
- plt.subplot(2, 2, 1)
- cm = confusion_matrix(y, y_pred)
- sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
- xticklabels=['Pred 0', 'Pred 1'],
- yticklabels=['True 0', 'True 1'])
- plt.title('Confusion matrix')
- plt.ylabel('True label')
- plt.xlabel('Predicted label')
- plt.subplot(2, 2, 2)
- fpr, tpr, _ = roc_curve(y, y_pred_proba)
- roc_auc = roc_auc_score(y, y_pred_proba)
- plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'AUC = {roc_auc:.3f}')
- plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
- plt.xlim([0.0, 1.0])
- plt.ylim([0.0, 1.05])
- plt.xlabel('False positive rate (FPR)')
- plt.ylabel('True positive rate (TPR)')
- plt.title('ROC curve')
- plt.legend(loc="lower right")
- plt.subplot(2, 2, 3)
- importance = pd.Series(model.feature_importance(), index=to_keep)
- importance.sort_values().plot(kind='barh')
- plt.title('Feature importance')
- plt.xlabel('Importance score')
- plt.ylabel('Feature')
- plt.subplot(2, 2, 4)
- for label in [0, 1]:
- sns.kdeplot(y_pred_proba[y == label], label=f'true label={label}', fill=True)
- plt.title('Predicted probability distribution')
- plt.xlabel('Predicted probability of the positive class')
- plt.ylabel('Density')
- plt.legend()
- plt.axvline(0.5, color='red', linestyle='--')
- plt.tight_layout()
- plt.show()
- plt.figure(figsize=(12, 5))
- plt.subplot(1, 2, 1)
- plt.plot(recall, precision, color='darkblue', lw=2, label=f'PRAUC = {prauc:.3f}')
- plt.fill_between(recall, precision, alpha=0.2, color='darkblue')
- plt.xlabel('Recall')
- plt.ylabel('Precision')
- plt.title('Precision-Recall curve')
- plt.legend(loc='upper right')
- plt.grid(True)
- plt.tight_layout()
- plt.show()
- with open('model_small.pkl', 'wb') as model_file:
- pickle.dump(model, model_file)
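- # The saved model_small.pkl is consumed by the backtest below via pickle.loads(read_file('model_small.pkl')),
- # so the file has to be made available in the strategy's research/file space first (assumed JoinQuant workflow)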
- # 5. Backtest strategy code
- from jqdata import *
- from jqfactor import *
- import numpy as np
- import pandas as pd
- import pickle
- import datetime
- # Initialization function
- def initialize(context):
- # Set the benchmark
- set_benchmark('399101.XSHE')
- # Trade at real prices
- set_option('use_real_price', True)
- # Enable look-ahead bias protection
- set_option("avoid_future_data", True)
- # Set slippage to 0
- set_slippage(FixedSlippage(0))
- # Set transaction cost to 0.03% commission; the impact of different slippage settings can be inspected in attribution analysis
- set_order_cost(OrderCost(open_tax=0, close_tax=0.001, open_commission=0.0003, close_commission=0.0003,
- close_today_commission=0, min_commission=5), type='stock')
- # Filter order logs below the error level
- log.set_level('order', 'error')
- # Initialize global variables
- g.stock_num = 10
- g.hold_list = [] # all currently held stocks
- g.yesterday_HL_list = [] # held stocks that closed at the limit-up price yesterday
- g.model_small = pickle.loads(read_file('model_small.pkl'))
- # Factor list
- g.factor_list = ['asset_impairment_loss_ttm', 'cash_flow_to_price_ratio', 'EBIT', 'net_working_capital', 'non_recurring_gain_loss', 'sales_to_price_ratio', 'AR', 'ARBR', 'ATR6', 'DAVOL10', 'MAWVAD', 'TVMA6', 'PSY', 'VOL10', 'VDIFF', 'VEMA26', 'VMACD', 'VOL120', 'VOSC', 'VR', 'arron_down_25', 'arron_up_25', 'BBIC', 'MASS', 'Rank1M', 'single_day_VPT', 'single_day_VPT_12', 'Volume1M', 'capital_reserve_fund_per_share', 'net_operate_cash_flow_per_share', 'operating_profit_per_share', 'total_operating_revenue_per_share', 'surplus_reserve_fund_per_share', 'ACCA', 'account_receivable_turnover_days', 'account_receivable_turnover_rate', 'adjusted_profit_to_total_profit', 'super_quick_ratio', 'MLEV', 'debt_to_equity_ratio', 'debt_to_tangible_equity_ratio', 'equity_to_fixed_asset_ratio', 'fixed_asset_ratio', 'intangible_asset_ratio', 'invest_income_associates_to_total_profit', 'long_debt_to_asset_ratio', 'long_debt_to_working_capital_ratio', 'net_operate_cash_flow_to_total_liability', 'net_operating_cash_flow_coverage', 'non_current_asset_ratio', 'operating_profit_to_total_profit', 'roa_ttm', 'Kurtosis120', 'Kurtosis20', 'Kurtosis60', 'sharpe_ratio_20', 'sharpe_ratio_60', 'Skewness120', 'Skewness20', 'Skewness60', 'Variance120', 'Variance20', 'beta', 'book_to_price_ratio', 'cash_earnings_to_price_ratio', 'cube_of_size', 'earnings_to_price_ratio', 'earnings_yield', 'growth', 'momentum', 'natural_log_of_market_cap', 'boll_down', 'MFI14', 'price_no_fq']
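- # Note: g.factor_list should match the features the trained model expects (the to_keep list from section 3) in both
- # content and column order; a mismatch would silently distort the predictions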
- run_daily(prepare_stock_list, '9:05')
- run_monthly(monthly_adjustment, 1, '9:30')
- run_daily(check_limit_up, '14:00')
- # 1-1 Prepare the stock pool
- def prepare_stock_list(context):
- # Get the list of currently held stocks
- g.hold_list = []
- for position in list(context.portfolio.positions.values()):
- stock = position.security
- g.hold_list.append(stock)
- # Get the stocks that closed at the limit-up price yesterday
- if g.hold_list != []:
- df = get_price(g.hold_list, end_date=context.previous_date, frequency='daily', fields=['close', 'high_limit'],
- count=1, panel=False, fill_paused=False)
- df = df[df['close'] == df['high_limit']]
- g.yesterday_HL_list = list(df.code)
- else:
- g.yesterday_HL_list = []
- # 1-2 Stock selection
- def get_stock_list(context):
- yesterday = context.previous_date
- today = context.current_dt
- stocks = get_index_stocks('399101.XSHE', yesterday)
- initial_list = filter_kcbj_stock(stocks)
- initial_list = filter_st_stock(initial_list)
- initial_list = filter_paused_stock(initial_list)
- initial_list = filter_new_stock(context, initial_list)
- initial_list = filter_limitup_stock(context, initial_list)
- initial_list = filter_limitdown_stock(context, initial_list)
- factor_data = get_factor_values(initial_list, g.factor_list, end_date=yesterday, count=1)
- df_jq_factor_value = pd.DataFrame(index=initial_list, columns=g.factor_list)
- for factor in g.factor_list:
- df_jq_factor_value[factor] = list(factor_data[factor].T.iloc[:, 0])
- tar = g.model_small.predict(df_jq_factor_value)
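- # tar holds the booster's predicted probability of label 1 (beating the cross-sectional median); it is used as the
- # ranking score below, and the top g.stock_num names are returned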
- df = df_jq_factor_value
- df['total_score'] = list(tar)
- df = df.sort_values(by=['total_score'], ascending=False)
- lst = df.index.tolist()
- lst = lst[:min(g.stock_num, len(lst))]
- return lst
- # 1-3 Monthly portfolio rebalancing
- def monthly_adjustment(context):
- # Get the target buy list
- target_list = get_stock_list(context)
- # Rebalance: sell
- for stock in g.hold_list:
- if (stock not in target_list) and (stock not in g.yesterday_HL_list):
- log.info("Sell [%s]" % (stock))
- position = context.portfolio.positions[stock]
- close_position(position)
- else:
- log.info("Keep holding [%s]" % (stock))
- # Rebalance: buy
- position_count = len(context.portfolio.positions)
- target_num = len(target_list)
- if target_num > position_count:
- value = context.portfolio.cash / (target_num - position_count)
- for stock in target_list:
- if context.portfolio.positions[stock].total_amount == 0:
- if open_position(stock, value):
- if len(context.portfolio.positions) == target_num:
- break
- # 1-4 Handle yesterday's limit-up holdings
- def check_limit_up(context):
- now_time = context.current_dt
- if g.yesterday_HL_list != []:
- # Near the close, sell yesterday's limit-up stocks whose limit has opened; if still at the limit, keep holding them even if they are not in the target buy list
- for stock in g.yesterday_HL_list:
- current_data = get_price(stock, end_date=now_time, frequency='1m', fields=['close', 'high_limit'],
- skip_paused=False, fq='pre', count=1, panel=False, fill_paused=True)
- if current_data['close'].iloc[0] < current_data['high_limit'].iloc[0]:
- log.info("[%s] limit-up opened, selling" % (stock))
- position = context.portfolio.positions[stock]
- close_position(position)
- else:
- log.info("[%s] still limit-up, keep holding" % (stock))
- # 3-1 Trading module - custom order
- def order_target_value_(security, value):
- if value == 0:
- log.debug("Selling out %s" % (security))
- else:
- log.debug("Order %s to value %f" % (security, value))
- return order_target_value(security, value)
- # 3-2 Trading module - open a position
- def open_position(security, value):
- order = order_target_value_(security, value)
- if order != None and order.filled > 0:
- return True
- return False
- # 3-3 Trading module - close a position
- def close_position(position):
- security = position.security
- order = order_target_value_(security, 0) # may fail if the stock is suspended
- if order != None:
- if order.status == OrderStatus.held and order.filled == order.amount:
- return True
- return False
- # 2-1 Filter suspended stocks
- def filter_paused_stock(stock_list):
- current_data = get_current_data()
- return [stock for stock in stock_list if not current_data[stock].paused]
- # 2-2 Filter ST stocks and others flagged for delisting
- def filter_st_stock(stock_list):
- current_data = get_current_data()
- return [stock for stock in stock_list
- if not current_data[stock].is_st
- and 'ST' not in current_data[stock].name
- and '*' not in current_data[stock].name
- and '退' not in current_data[stock].name]
- # 2-3 Filter STAR Market (68), Beijing Stock Exchange (4/8) and ChiNext (3) stocks
- def filter_kcbj_stock(stock_list):
- for stock in stock_list[:]:
- if stock[0] == '4' or stock[0] == '8' or stock[:2] == '68' or stock[0] == '3':
- stock_list.remove(stock)
- return stock_list
- # 2-4 Filter limit-up stocks (unless already held)
- def filter_limitup_stock(context, stock_list):
- last_prices = history(1, unit='1m', field='close', security_list=stock_list)
- current_data = get_current_data()
- return [stock for stock in stock_list if stock in context.portfolio.positions.keys()
- or last_prices[stock][-1] < current_data[stock].high_limit]
- # 2-5 Filter limit-down stocks (unless already held)
- def filter_limitdown_stock(context, stock_list):
- last_prices = history(1, unit='1m', field='close', security_list=stock_list)
- current_data = get_current_data()
- return [stock for stock in stock_list if stock in context.portfolio.positions.keys()
- or last_prices[stock][-1] > current_data[stock].low_limit]
- # 2-6 Filter recently listed stocks
- def filter_new_stock(context, stock_list):
- yesterday = context.previous_date
- return [stock for stock in stock_list if
- not yesterday - get_security_info(stock).start_date < datetime.timedelta(days=375)]