# 机器学习框架 # https://www.joinquant.com/view/community/detail/7253d06b938dc84af3e0c3c996d7d5bd?type=1 # 1. 数据集制作 from jqdata import * from jqlib.technical_analysis import * from jqfactor import get_factor_values, winsorize_med, standardlize, neutralize import datetime import pandas as pd import numpy as np from scipy import stats import statsmodels.api as sm from statsmodels import regression from six import StringIO from sklearn.decomposition import PCA from sklearn import svm from sklearn.model_selection import train_test_split from sklearn.grid_search import GridSearchCV from sklearn import metrics from tqdm import tqdm import matplotlib.dates as mdates import matplotlib.pyplot as plt import warnings import seaborn as sns import pickle warnings.filterwarnings("ignore") import pandas as pd import numpy as np import matplotlib.pyplot as plt import seaborn as sns from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix, roc_curve, precision_recall_curve, auc, classification_report) import lightgbm as lgb jqfactors_list=['asset_impairment_loss_ttm', 'cash_flow_to_price_ratio', 'market_cap', 'interest_free_current_liability', 'EBITDA', 'financial_assets', 'gross_profit_ttm', 'net_working_capital', 'non_recurring_gain_loss', 'EBIT', 'sales_to_price_ratio', 'AR', 'ARBR', 'ATR6', 'DAVOL10', 'MAWVAD', 'TVMA6', 'PSY', 'VOL10', 'VDIFF', 'VEMA26', 'VMACD', 'VOL120', 'VOSC', 'VR', 'WVAD', 'arron_down_25', 'arron_up_25', 'BBIC', 'MASS', 'Rank1M', 'single_day_VPT', 'single_day_VPT_12', 'single_day_VPT_6', 'Volume1M', 'capital_reserve_fund_per_share', 'net_asset_per_share', 'net_operate_cash_flow_per_share', 'operating_profit_per_share', 'total_operating_revenue_per_share', 'surplus_reserve_fund_per_share', 'ACCA', 'account_receivable_turnover_days', 'account_receivable_turnover_rate', 'adjusted_profit_to_total_profit', 'super_quick_ratio', 'MLEV', 'debt_to_equity_ratio', 'debt_to_tangible_equity_ratio', 'equity_to_fixed_asset_ratio', 'fixed_asset_ratio', 'intangible_asset_ratio', 'invest_income_associates_to_total_profit', 'long_debt_to_asset_ratio', 'long_debt_to_working_capital_ratio', 'net_operate_cash_flow_to_total_liability', 'net_operating_cash_flow_coverage', 'non_current_asset_ratio', 'operating_profit_to_total_profit', 'roa_ttm', 'roe_ttm', 'Kurtosis120', 'Kurtosis20', 'Kurtosis60', 'sharpe_ratio_20', 'sharpe_ratio_60', 'Skewness120', 'Skewness20', 'Skewness60', 'Variance120', 'Variance20', 'liquidity', 'beta', 'book_to_price_ratio', 'cash_earnings_to_price_ratio', 'cube_of_size', 'earnings_to_price_ratio', 'earnings_yield', 'growth', 'momentum', 'natural_log_of_market_cap', 'boll_down', 'MFI14', 'price_no_fq'] print(len(jqfactors_list)) def get_period_date(peroid, start_date, end_date): stock_data = get_price('000001.XSHE', start_date, end_date, 'daily', fields=['close']) stock_data['date'] = stock_data.index period_stock_data = stock_data.resample(peroid, how='last') period_stock_data = period_stock_data.set_index('date').dropna() date = period_stock_data.index pydate_array = date.to_pydatetime() date_only_array = np.vectorize(lambda s: s.strftime('%Y-%m-%d'))(pydate_array) date_only_series = pd.Series(date_only_array) start_date = datetime.datetime.strptime(start_date, "%Y-%m-%d") start_date = start_date - datetime.timedelta(days=1) start_date = start_date.strftime("%Y-%m-%d") date_list = date_only_series.values.tolist() date_list.insert(0, start_date) return date_list peroid = 'M' start_date = '2019-01-01' end_date = '2024-01-01' DAY = get_period_date(peroid, start_date, end_date) print(len(DAY)) def delect_stop(stocks, beginDate, n=30 * 3): stockList = [] beginDate = datetime.datetime.strptime(beginDate, "%Y-%m-%d") for stock in stocks: start_date = get_security_info(stock).start_date if start_date < (beginDate - datetime.timedelta(days=n)).date(): stockList.append(stock) return stockList def get_stock(stockPool, begin_date): if stockPool == 'HS300': stockList = get_index_stocks('000300.XSHG', begin_date) elif stockPool == 'ZZ500': stockList = get_index_stocks('399905.XSHE', begin_date) elif stockPool == 'ZZ800': stockList = get_index_stocks('399906.XSHE', begin_date) elif stockPool == 'CYBZ': stockList = get_index_stocks('399006.XSHE', begin_date) elif stockPool == 'ZXBZ': stockList = get_index_stocks('399101.XSHE', begin_date) elif stockPool == 'A': stockList = get_index_stocks('000002.XSHG', begin_date) + get_index_stocks('399107.XSHE', begin_date) stockList = [stock for stock in stockList if not stock.startswith(('68', '4', '8'))] elif stockPool == 'AA': stockList = get_index_stocks('000985.XSHG', begin_date) stockList = [stock for stock in stockList if not stock.startswith(('3', '68', '4', '8'))] st_data = get_extras('is_st', stockList, count=1, end_date=begin_date) stockList = [stock for stock in stockList if not st_data[stock][0]] stockList = delect_stop(stockList, begin_date) return stockList def get_factor_data(securities_list, date): factor_data = get_factor_values(securities=securities_list, factors=jqfactors_list, count=1, end_date=date) df_jq_factor = pd.DataFrame(index=securities_list) for i in factor_data.keys(): df_jq_factor[i] = factor_data[i].iloc[0, :] return df_jq_factor dateList = get_period_date(peroid, start_date, end_date) DF = pd.DataFrame() for date in tqdm(dateList[:-1]): stockList = get_stock('ZXBZ', date) factor_origl_data = get_factor_data(stockList, date) data_close = get_price(stockList, date, dateList[dateList.index(date) + 1], '1d', 'close')['close'] factor_origl_data['pchg'] = data_close.iloc[-1] / data_close.iloc[1] - 1 factor_origl_data = factor_origl_data.dropna() median_pchg = factor_origl_data['pchg'].median() factor_origl_data['label'] = np.where(factor_origl_data['pchg'] >= median_pchg, 1, 0) factor_origl_data = factor_origl_data.drop(columns=['pchg']) DF = pd.concat([DF, factor_origl_data], ignore_index=True) DF.to_csv(r'train_small.csv', index=False) # 2. 数据分析 df = pd.read_csv(r'train_small.csv') plot_cols = jqfactors_list print(len(plot_cols)) corr_matrix = df.corr() plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题 label_corr = corr_matrix['label'].sort_values(ascending=False) plt.figure(figsize=(12, 10)) corr = df[plot_cols].corr() mask = np.triu(np.ones_like(corr, dtype=bool)) sns.heatmap(corr, mask=mask, annot=False, fmt=".2f", cmap='RdBu_r', vmin=-1, vmax=1) plt.title('因子间相关性矩阵') plt.show() plt.figure(figsize=(14, 100)) for i, col in enumerate(plot_cols, 1): plt.subplot(42, 2, i) sns.distplot(df[col].dropna(), bins=30, color='skyblue', kde=True) stats_text = f"均值: {df[col].mean():.2f}\n中位数: {df[col].median():.2f}\n标准差: {df[col].std():.2f}" plt.gca().text(0.95, 0.95, stats_text, transform=plt.gca().transAxes, verticalalignment='top', horizontalalignment='right', bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)) plt.title(f'{col} 分布') plt.xlabel('') plt.tight_layout() plt.show() # 3. 数据预处理 from collections import defaultdict # 计算每个特征的缺失值数量 missing_counts = df[plot_cols].isnull().sum().to_dict() # 计算特征间的相关系数矩阵 corr_matrix = df[plot_cols].corr() # 创建图结构存储高度相关的特征对 graph = defaultdict(list) threshold = 0.6 # 相关性阈值 # 遍历上三角矩阵找到高度相关的特征对 n = len(plot_cols) for i in range(n): for j in range(i + 1, n): col1, col2 = plot_cols[i], plot_cols[j] corr_value = corr_matrix.iloc[i, j] if not pd.isna(corr_value) and abs(corr_value) > threshold: graph[col1].append(col2) graph[col2].append(col1) # 使用DFS找到连通分量(高度相关的特征组) visited = set() components = [] def dfs(node, comp): visited.add(node) comp.append(node) for neighbor in graph[node]: if neighbor not in visited: dfs(neighbor, comp) for col in plot_cols: if col not in visited: comp = [] dfs(col, comp) components.append(comp) # 处理每个连通分量:保留缺失值最少的特征 to_keep = [] to_remove = [] for comp in components: if len(comp) == 1: # 独立特征直接保留 to_keep.append(comp[0]) else: # 按缺失值数量排序(升序),相同缺失值时按特征名字母顺序排序 comp_sorted = sorted(comp, key=lambda x: (missing_counts[x], x)) keep_feature = comp_sorted[0] # 缺失值最少的特征 to_keep.append(keep_feature) to_remove.extend(comp_sorted[1:]) # 组内其他特征移除 print(f"\n最终保留特征数量: {len(to_keep)}") print(f"移除特征数量: {len(to_remove)}") print("\n移除的特征列表:", to_remove) print("\n保留的特征列表:", to_keep) # 可视化保留特征的相关矩阵(可选) plt.figure(figsize=(12, 10)) corr_kept = df[to_keep].corr() mask = np.triu(np.ones_like(corr_kept, dtype=bool)) sns.heatmap(corr_kept, mask=mask, annot=True, fmt=".2f", cmap='RdBu_r', vmin=-1, vmax=1) plt.title('保留特征间相关性矩阵') plt.show() # 4. 训练模型 X = df[to_keep] y = df['label'] lgb_train = lgb.Dataset(X, label=y) params = { 'objective': 'binary', 'metric': 'binary_logloss', 'boosting_type': 'gbdt', 'verbose': -1 } model = lgb.train(params, lgb_train, num_boost_round=200) y_pred_proba = model.predict(X) y_pred = (y_pred_proba > 0.5).astype(int) precision, recall, _ = precision_recall_curve(y, y_pred_proba) prauc = auc(recall, precision) print("\n模型性能评估:") print("准确率 (Accuracy):", accuracy_score(y, y_pred)) print("精确率 (Precision):", precision_score(y, y_pred)) print("召回率 (Recall):", recall_score(y, y_pred)) print("F1分数 (F1-score):", f1_score(y, y_pred)) print("AUC分数:", roc_auc_score(y, y_pred_proba)) print("PRAUC分数:", prauc) # 新增PRAUC输出 plt.figure(figsize=(15, 12)) plt.subplot(2, 2, 1) cm = confusion_matrix(y, y_pred) sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['预测0', '预测1'], yticklabels=['实际0', '实际1']) plt.title('混淆矩阵') plt.ylabel('实际标签') plt.xlabel('预测标签') plt.subplot(2, 2, 2) fpr, tpr, _ = roc_curve(y, y_pred_proba) roc_auc = roc_auc_score(y, y_pred_proba) plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'AUC = {roc_auc:.3f}') plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--') plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.05]) plt.xlabel('假正率 (FPR)') plt.ylabel('真正率 (TPR)') plt.title('ROC曲线') plt.legend(loc="lower right") plt.subplot(2, 2, 3) importance = pd.Series(model.feature_importance(), index=to_keep) importance.sort_values().plot(kind='barh') plt.title('特征重要性') plt.xlabel('重要性分数') plt.ylabel('特征') plt.subplot(2, 2, 4) for label in [0, 1]: sns.kdeplot(y_pred_proba[y == label], label=f'真实标签={label}', shade=True) plt.title('预测概率分布') plt.xlabel('预测为正类的概率') plt.ylabel('密度') plt.legend() plt.axvline(0.5, color='red', linestyle='--') plt.tight_layout() plt.show() plt.figure(figsize=(12, 5)) plt.subplot(1, 2, 1) plt.plot(recall, precision, color='darkblue', lw=2, label=f'PRAUC = {prauc:.3f}') plt.fill_between(recall, precision, alpha=0.2, color='darkblue') plt.xlabel('召回率 (Recall)') plt.ylabel('精确率 (Precision)') plt.title('PRAUC曲线') plt.legend(loc='upper right') plt.grid(True) plt.tight_layout() plt.show() with open('model_small.pkl', 'wb') as model_file: pickle.dump(model, model_file) # 5. 回测代码 from jqdata import * from jqfactor import * import numpy as np import pandas as pd import pickle # 初始化函数 def initialize(context): # 设定基准 set_benchmark('399101.XSHE') # 用真实价格交易 set_option('use_real_price', True) # 打开防未来函数 set_option("avoid_future_data", True) # 将滑点设置为0 set_slippage(FixedSlippage(0)) # 设置交易成本万分之三,不同滑点影响可在归因分析中查看 set_order_cost(OrderCost(open_tax=0, close_tax=0.001, open_commission=0.0003, close_commission=0.0003, close_today_commission=0, min_commission=5), type='stock') # 过滤order中低于error级别的日志 log.set_level('order', 'error') # 初始化全局变量 g.stock_num = 10 g.hold_list = [] # 当前持仓的全部股票 g.yesterday_HL_list = [] # 记录持仓中昨日涨停的股票 g.model_small = pickle.loads(read_file('model_small.pkl')) # 因子列表 g.factor_list = ['asset_impairment_loss_ttm', 'cash_flow_to_price_ratio', 'EBIT', 'net_working_capital', 'non_recurring_gain_loss', 'sales_to_price_ratio', 'AR', 'ARBR', 'ATR6', 'DAVOL10', 'MAWVAD', 'TVMA6', 'PSY', 'VOL10', 'VDIFF', 'VEMA26', 'VMACD', 'VOL120', 'VOSC', 'VR', 'arron_down_25', 'arron_up_25', 'BBIC', 'MASS', 'Rank1M', 'single_day_VPT', 'single_day_VPT_12', 'Volume1M', 'capital_reserve_fund_per_share', 'net_operate_cash_flow_per_share', 'operating_profit_per_share', 'total_operating_revenue_per_share', 'surplus_reserve_fund_per_share', 'ACCA', 'account_receivable_turnover_days', 'account_receivable_turnover_rate', 'adjusted_profit_to_total_profit', 'super_quick_ratio', 'MLEV', 'debt_to_equity_ratio', 'debt_to_tangible_equity_ratio', 'equity_to_fixed_asset_ratio', 'fixed_asset_ratio', 'intangible_asset_ratio', 'invest_income_associates_to_total_profit', 'long_debt_to_asset_ratio', 'long_debt_to_working_capital_ratio', 'net_operate_cash_flow_to_total_liability', 'net_operating_cash_flow_coverage', 'non_current_asset_ratio', 'operating_profit_to_total_profit', 'roa_ttm', 'Kurtosis120', 'Kurtosis20', 'Kurtosis60', 'sharpe_ratio_20', 'sharpe_ratio_60', 'Skewness120', 'Skewness20', 'Skewness60', 'Variance120', 'Variance20', 'beta', 'book_to_price_ratio', 'cash_earnings_to_price_ratio', 'cube_of_size', 'earnings_to_price_ratio', 'earnings_yield', 'growth', 'momentum', 'natural_log_of_market_cap', 'boll_down', 'MFI14', 'price_no_fq'] run_daily(prepare_stock_list, '9:05') run_monthly(weekly_adjustment, 1, '9:30') run_daily(check_limit_up, '14:00') # 1-1 准备股票池 def prepare_stock_list(context): # 获取已持有列表 g.hold_list = [] for position in list(context.portfolio.positions.values()): stock = position.security g.hold_list.append(stock) # 获取昨日涨停列表 if g.hold_list != []: df = get_price(g.hold_list, end_date=context.previous_date, frequency='daily', fields=['close', 'high_limit'], count=1, panel=False, fill_paused=False) df = df[df['close'] == df['high_limit']] g.yesterday_HL_list = list(df.code) else: g.yesterday_HL_list = [] # 1-2 选股模块 def get_stock_list(context): yesterday = context.previous_date today = context.current_dt stocks = get_index_stocks('399101.XSHE', yesterday) initial_list = filter_kcbj_stock(stocks) initial_list = filter_st_stock(initial_list) initial_list = filter_paused_stock(initial_list) initial_list = filter_new_stock(context, initial_list) initial_list = filter_limitup_stock(context,initial_list) initial_list = filter_limitdown_stock(context,initial_list) factor_data = get_factor_values(initial_list, g.factor_list, end_date=yesterday, count=1) df_jq_factor_value = pd.DataFrame(index=initial_list, columns=g.factor_list) for factor in g.factor_list: df_jq_factor_value[factor] = list(factor_data[factor].T.iloc[:, 0]) tar = g.model_small.predict(df_jq_factor_value) df = df_jq_factor_value df['total_score'] = list(tar) df = df.sort_values(by=['total_score'], ascending=False) lst = df.index.tolist() lst = lst[:min(g.stock_num, len(lst))] return lst # 1-3 整体调整持仓 def weekly_adjustment(context): # 获取应买入列表 target_list = get_stock_list(context) # 调仓卖出 for stock in g.hold_list: if (stock not in target_list) and (stock not in g.yesterday_HL_list): log.info("卖出[%s]" % (stock)) position = context.portfolio.positions[stock] close_position(position) else: log.info("已持有[%s]" % (stock)) # 调仓买入 position_count = len(context.portfolio.positions) target_num = len(target_list) if target_num > position_count: value = context.portfolio.cash / (target_num - position_count) for stock in target_list: if context.portfolio.positions[stock].total_amount == 0: if open_position(stock, value): if len(context.portfolio.positions) == target_num: break # 1-4 调整昨日涨停股票 def check_limit_up(context): now_time = context.current_dt if g.yesterday_HL_list != []: # 对昨日涨停股票观察到尾盘如不涨停则提前卖出,如果涨停即使不在应买入列表仍暂时持有 for stock in g.yesterday_HL_list: current_data = get_price(stock, end_date=now_time, frequency='1m', fields=['close', 'high_limit'], skip_paused=False, fq='pre', count=1, panel=False, fill_paused=True) if current_data.iloc[0, 0] < current_data.iloc[0, 1]: log.info("[%s]涨停打开,卖出" % (stock)) position = context.portfolio.positions[stock] close_position(position) else: log.info("[%s]涨停,继续持有" % (stock)) # 3-1 交易模块-自定义下单 def order_target_value_(security, value): if value == 0: log.debug("Selling out %s" % (security)) else: log.debug("Order %s to value %f" % (security, value)) return order_target_value(security, value) # 3-2 交易模块-开仓 def open_position(security, value): order = order_target_value_(security, value) if order != None and order.filled > 0: return True return False # 3-3 交易模块-平仓 def close_position(position): security = position.security order = order_target_value_(security, 0) # 可能会因停牌失败 if order != None: if order.status == OrderStatus.held and order.filled == order.amount: return True return False # 2-1 过滤停牌股票 def filter_paused_stock(stock_list): current_data = get_current_data() return [stock for stock in stock_list if not current_data[stock].paused] # 2-2 过滤ST及其他具有退市标签的股票 def filter_st_stock(stock_list): current_data = get_current_data() return [stock for stock in stock_list if not current_data[stock].is_st and 'ST' not in current_data[stock].name and '*' not in current_data[stock].name and '退' not in current_data[stock].name] # 2-3 过滤科创北交股票 def filter_kcbj_stock(stock_list): for stock in stock_list[:]: if stock[0] == '4' or stock[0] == '8' or stock[:2] == '68' or stock[0] == '3': stock_list.remove(stock) return stock_list # 2-4 过滤涨停的股票 def filter_limitup_stock(context, stock_list): last_prices = history(1, unit='1m', field='close', security_list=stock_list) current_data = get_current_data() return [stock for stock in stock_list if stock in context.portfolio.positions.keys() or last_prices[stock][-1] < current_data[stock].high_limit] # 2-5 过滤跌停的股票 def filter_limitdown_stock(context, stock_list): last_prices = history(1, unit='1m', field='close', security_list=stock_list) current_data = get_current_data() return [stock for stock in stock_list if stock in context.portfolio.positions.keys() or last_prices[stock][-1] > current_data[stock].low_limit] # 2-6 过滤次新股 def filter_new_stock(context, stock_list): yesterday = context.previous_date return [stock for stock in stock_list if not yesterday - get_security_info(stock).start_date < datetime.timedelta(days=375)]