فهرست منبع

增加了一个回测代码

maxfeng 3 ماه پیش
والد
کامیت
3d715c5e86
1فایلهای تغییر یافته به همراه476 افزوده شده و 0 حذف شده
  1. 476 0
      Lib/stock/machine_learning.py

+ 476 - 0
Lib/stock/machine_learning.py

@@ -0,0 +1,476 @@
+# 机器学习框架
+# https://www.joinquant.com/view/community/detail/7253d06b938dc84af3e0c3c996d7d5bd?type=1
+
+# 1. 数据集制作
+from jqdata import *
+from jqlib.technical_analysis import *
+from jqfactor import get_factor_values, winsorize_med, standardlize, neutralize
+import datetime
+import pandas as pd
+import numpy as np
+from scipy import stats
+import statsmodels.api as sm
+from statsmodels import regression
+from six import StringIO
+from sklearn.decomposition import PCA
+from sklearn import svm
+from sklearn.model_selection import train_test_split
+from sklearn.grid_search import GridSearchCV
+from sklearn import metrics
+from tqdm import tqdm
+import matplotlib.dates as mdates
+import matplotlib.pyplot as plt
+import warnings
+import seaborn as sns
+import pickle
+warnings.filterwarnings("ignore")
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+from sklearn.metrics import (accuracy_score, precision_score, recall_score, 
+                            f1_score, roc_auc_score, confusion_matrix, 
+                            roc_curve, precision_recall_curve, auc, classification_report)
+import lightgbm as lgb
+
+jqfactors_list=['asset_impairment_loss_ttm', 'cash_flow_to_price_ratio', 'market_cap', 'interest_free_current_liability', 'EBITDA', 'financial_assets', 'gross_profit_ttm', 'net_working_capital', 'non_recurring_gain_loss', 'EBIT', 'sales_to_price_ratio', 'AR', 'ARBR', 'ATR6', 'DAVOL10', 'MAWVAD', 'TVMA6', 'PSY', 'VOL10', 'VDIFF', 'VEMA26', 'VMACD', 'VOL120', 'VOSC', 'VR', 'WVAD', 'arron_down_25', 'arron_up_25', 'BBIC', 'MASS', 'Rank1M', 'single_day_VPT', 'single_day_VPT_12', 'single_day_VPT_6', 'Volume1M', 'capital_reserve_fund_per_share', 'net_asset_per_share', 'net_operate_cash_flow_per_share', 'operating_profit_per_share', 'total_operating_revenue_per_share', 'surplus_reserve_fund_per_share', 'ACCA', 'account_receivable_turnover_days', 'account_receivable_turnover_rate', 'adjusted_profit_to_total_profit', 'super_quick_ratio', 'MLEV', 'debt_to_equity_ratio', 'debt_to_tangible_equity_ratio', 'equity_to_fixed_asset_ratio', 'fixed_asset_ratio', 'intangible_asset_ratio', 'invest_income_associates_to_total_profit', 'long_debt_to_asset_ratio', 'long_debt_to_working_capital_ratio', 'net_operate_cash_flow_to_total_liability', 'net_operating_cash_flow_coverage', 'non_current_asset_ratio', 'operating_profit_to_total_profit', 'roa_ttm', 'roe_ttm', 'Kurtosis120', 'Kurtosis20', 'Kurtosis60', 'sharpe_ratio_20', 'sharpe_ratio_60', 'Skewness120', 'Skewness20', 'Skewness60', 'Variance120', 'Variance20', 'liquidity', 'beta', 'book_to_price_ratio', 'cash_earnings_to_price_ratio', 'cube_of_size', 'earnings_to_price_ratio', 'earnings_yield', 'growth', 'momentum', 'natural_log_of_market_cap', 'boll_down', 'MFI14', 'price_no_fq']
+print(len(jqfactors_list))
+
+def get_period_date(peroid, start_date, end_date):
+    stock_data = get_price('000001.XSHE', start_date, end_date, 'daily', fields=['close'])
+    stock_data['date'] = stock_data.index
+    period_stock_data = stock_data.resample(peroid, how='last')
+    period_stock_data = period_stock_data.set_index('date').dropna()
+    date = period_stock_data.index
+    pydate_array = date.to_pydatetime()
+    date_only_array = np.vectorize(lambda s: s.strftime('%Y-%m-%d'))(pydate_array)
+    date_only_series = pd.Series(date_only_array)
+    start_date = datetime.datetime.strptime(start_date, "%Y-%m-%d")
+    start_date = start_date - datetime.timedelta(days=1)
+    start_date = start_date.strftime("%Y-%m-%d")
+    date_list = date_only_series.values.tolist()
+    date_list.insert(0, start_date)
+    return date_list
+
+peroid = 'M'
+start_date = '2019-01-01'
+end_date = '2024-01-01'
+DAY = get_period_date(peroid, start_date, end_date)
+print(len(DAY))
+
+def delect_stop(stocks, beginDate, n=30 * 3):
+    stockList = []
+    beginDate = datetime.datetime.strptime(beginDate, "%Y-%m-%d")
+    for stock in stocks:
+        start_date = get_security_info(stock).start_date
+        if start_date < (beginDate - datetime.timedelta(days=n)).date():
+            stockList.append(stock)
+    return stockList
+
+def get_stock(stockPool, begin_date):
+    if stockPool == 'HS300':
+        stockList = get_index_stocks('000300.XSHG', begin_date)
+    elif stockPool == 'ZZ500':
+        stockList = get_index_stocks('399905.XSHE', begin_date)
+    elif stockPool == 'ZZ800':
+        stockList = get_index_stocks('399906.XSHE', begin_date)
+    elif stockPool == 'CYBZ':
+        stockList = get_index_stocks('399006.XSHE', begin_date)
+    elif stockPool == 'ZXBZ':
+        stockList = get_index_stocks('399101.XSHE', begin_date)
+    elif stockPool == 'A':
+        stockList = get_index_stocks('000002.XSHG', begin_date) + get_index_stocks('399107.XSHE', begin_date)
+        stockList = [stock for stock in stockList if not stock.startswith(('68', '4', '8'))]
+    elif stockPool == 'AA':
+        stockList = get_index_stocks('000985.XSHG', begin_date)
+        stockList = [stock for stock in stockList if not stock.startswith(('3', '68', '4', '8'))]
+    st_data = get_extras('is_st', stockList, count=1, end_date=begin_date)
+    stockList = [stock for stock in stockList if not st_data[stock][0]]
+    stockList = delect_stop(stockList, begin_date)
+    return stockList
+
+def get_factor_data(securities_list, date):
+    factor_data = get_factor_values(securities=securities_list, factors=jqfactors_list, count=1, end_date=date)
+    df_jq_factor = pd.DataFrame(index=securities_list)
+    for i in factor_data.keys():
+        df_jq_factor[i] = factor_data[i].iloc[0, :]
+    return df_jq_factor
+
+dateList = get_period_date(peroid, start_date, end_date)
+DF = pd.DataFrame()
+
+for date in tqdm(dateList[:-1]):
+    stockList = get_stock('ZXBZ', date)
+    factor_origl_data = get_factor_data(stockList, date)
+    data_close = get_price(stockList, date, dateList[dateList.index(date) + 1], '1d', 'close')['close']
+    factor_origl_data['pchg'] = data_close.iloc[-1] / data_close.iloc[1] - 1
+    factor_origl_data = factor_origl_data.dropna()
+    median_pchg = factor_origl_data['pchg'].median()
+    factor_origl_data['label'] = np.where(factor_origl_data['pchg'] >= median_pchg, 1, 0)
+    factor_origl_data = factor_origl_data.drop(columns=['pchg'])
+    DF = pd.concat([DF, factor_origl_data], ignore_index=True)
+
+DF.to_csv(r'train_small.csv', index=False)
+
+# 2. 数据分析
+df = pd.read_csv(r'train_small.csv')
+plot_cols = jqfactors_list
+print(len(plot_cols))
+corr_matrix = df.corr()
+plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题
+label_corr = corr_matrix['label'].sort_values(ascending=False)
+plt.figure(figsize=(12, 10))
+corr = df[plot_cols].corr()
+mask = np.triu(np.ones_like(corr, dtype=bool))
+sns.heatmap(corr, mask=mask, annot=False, fmt=".2f", cmap='RdBu_r', vmin=-1, vmax=1)
+plt.title('因子间相关性矩阵')
+plt.show()
+
+plt.figure(figsize=(14, 100))
+for i, col in enumerate(plot_cols, 1):
+    plt.subplot(42, 2, i)
+    sns.distplot(df[col].dropna(), bins=30, color='skyblue', kde=True)
+    stats_text = f"均值: {df[col].mean():.2f}\n中位数: {df[col].median():.2f}\n标准差: {df[col].std():.2f}"
+    plt.gca().text(0.95, 0.95, stats_text, transform=plt.gca().transAxes, 
+                  verticalalignment='top', horizontalalignment='right',
+                  bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
+    plt.title(f'{col} 分布')
+    plt.xlabel('')
+plt.tight_layout()
+plt.show()
+
+# 3. 数据预处理
+from collections import defaultdict
+# 计算每个特征的缺失值数量
+missing_counts = df[plot_cols].isnull().sum().to_dict()
+
+# 计算特征间的相关系数矩阵
+corr_matrix = df[plot_cols].corr()
+
+# 创建图结构存储高度相关的特征对
+graph = defaultdict(list)
+threshold = 0.6  # 相关性阈值
+
+# 遍历上三角矩阵找到高度相关的特征对
+n = len(plot_cols)
+for i in range(n):
+    for j in range(i + 1, n):
+        col1, col2 = plot_cols[i], plot_cols[j]
+        corr_value = corr_matrix.iloc[i, j]
+
+        if not pd.isna(corr_value) and abs(corr_value) > threshold:
+            graph[col1].append(col2)
+            graph[col2].append(col1)
+
+# 使用DFS找到连通分量(高度相关的特征组)
+visited = set()
+components = []
+
+def dfs(node, comp):
+    visited.add(node)
+    comp.append(node)
+    for neighbor in graph[node]:
+        if neighbor not in visited:
+            dfs(neighbor, comp)
+
+for col in plot_cols:
+    if col not in visited:
+        comp = []
+        dfs(col, comp)
+        components.append(comp)
+
+# 处理每个连通分量:保留缺失值最少的特征
+to_keep = []
+to_remove = []
+
+for comp in components:
+    if len(comp) == 1:  # 独立特征直接保留
+        to_keep.append(comp[0])
+    else:
+        # 按缺失值数量排序(升序),相同缺失值时按特征名字母顺序排序
+        comp_sorted = sorted(comp, key=lambda x: (missing_counts[x], x))
+        keep_feature = comp_sorted[0]  # 缺失值最少的特征
+        to_keep.append(keep_feature)
+        to_remove.extend(comp_sorted[1:])  # 组内其他特征移除
+
+
+
+print(f"\n最终保留特征数量: {len(to_keep)}")
+print(f"移除特征数量: {len(to_remove)}")
+print("\n移除的特征列表:", to_remove)
+print("\n保留的特征列表:", to_keep)
+# 可视化保留特征的相关矩阵(可选)
+plt.figure(figsize=(12, 10))
+corr_kept = df[to_keep].corr()
+mask = np.triu(np.ones_like(corr_kept, dtype=bool))
+sns.heatmap(corr_kept, mask=mask, annot=True, fmt=".2f", cmap='RdBu_r', vmin=-1, vmax=1)
+plt.title('保留特征间相关性矩阵')
+plt.show()
+
+# 4. 训练模型
+X = df[to_keep]
+y = df['label']
+lgb_train = lgb.Dataset(X, label=y)
+params = {
+    'objective': 'binary',
+    'metric': 'binary_logloss',
+    'boosting_type': 'gbdt',
+    'verbose': -1
+}
+model = lgb.train(params, lgb_train, num_boost_round=200)
+y_pred_proba = model.predict(X)
+y_pred = (y_pred_proba > 0.5).astype(int)
+precision, recall, _ = precision_recall_curve(y, y_pred_proba)
+prauc = auc(recall, precision)
+print("\n模型性能评估:")
+print("准确率 (Accuracy):", accuracy_score(y, y_pred))
+print("精确率 (Precision):", precision_score(y, y_pred))
+print("召回率 (Recall):", recall_score(y, y_pred))
+print("F1分数 (F1-score):", f1_score(y, y_pred))
+print("AUC分数:", roc_auc_score(y, y_pred_proba))
+print("PRAUC分数:", prauc)  # 新增PRAUC输出
+plt.figure(figsize=(15, 12))
+plt.subplot(2, 2, 1)
+cm = confusion_matrix(y, y_pred)
+sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
+            xticklabels=['预测0', '预测1'], 
+            yticklabels=['实际0', '实际1'])
+plt.title('混淆矩阵')
+plt.ylabel('实际标签')
+plt.xlabel('预测标签')
+plt.subplot(2, 2, 2)
+fpr, tpr, _ = roc_curve(y, y_pred_proba)
+roc_auc = roc_auc_score(y, y_pred_proba)
+plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'AUC = {roc_auc:.3f}')
+plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
+plt.xlim([0.0, 1.0])
+plt.ylim([0.0, 1.05])
+plt.xlabel('假正率 (FPR)')
+plt.ylabel('真正率 (TPR)')
+plt.title('ROC曲线')
+plt.legend(loc="lower right")
+plt.subplot(2, 2, 3)
+importance = pd.Series(model.feature_importance(), index=to_keep)
+importance.sort_values().plot(kind='barh')
+plt.title('特征重要性')
+plt.xlabel('重要性分数')
+plt.ylabel('特征')
+plt.subplot(2, 2, 4)
+for label in [0, 1]:
+    sns.kdeplot(y_pred_proba[y == label], label=f'真实标签={label}', shade=True)
+plt.title('预测概率分布')
+plt.xlabel('预测为正类的概率')
+plt.ylabel('密度')
+plt.legend()
+plt.axvline(0.5, color='red', linestyle='--')
+plt.tight_layout()
+plt.show()
+plt.figure(figsize=(12, 5))
+plt.subplot(1, 2, 1)
+plt.plot(recall, precision, color='darkblue', lw=2, label=f'PRAUC = {prauc:.3f}')
+plt.fill_between(recall, precision, alpha=0.2, color='darkblue')
+plt.xlabel('召回率 (Recall)')
+plt.ylabel('精确率 (Precision)')
+plt.title('PRAUC曲线')
+plt.legend(loc='upper right')
+plt.grid(True)
+plt.tight_layout()
+plt.show()
+with open('model_small.pkl', 'wb') as model_file:
+    pickle.dump(model, model_file)
+
+# 5. 回测代码
+from jqdata import *
+from jqfactor import *
+import numpy as np
+import pandas as pd
+import pickle
+
+
+
+# 初始化函数
+def initialize(context):
+    # 设定基准
+    set_benchmark('399101.XSHE')
+    # 用真实价格交易
+    set_option('use_real_price', True)
+    # 打开防未来函数
+    set_option("avoid_future_data", True)
+    # 将滑点设置为0
+    set_slippage(FixedSlippage(0))
+    # 设置交易成本万分之三,不同滑点影响可在归因分析中查看
+    set_order_cost(OrderCost(open_tax=0, close_tax=0.001, open_commission=0.0003, close_commission=0.0003,
+                             close_today_commission=0, min_commission=5), type='stock')
+    # 过滤order中低于error级别的日志
+    log.set_level('order', 'error')
+    # 初始化全局变量
+
+    g.stock_num = 10
+    g.hold_list = []  # 当前持仓的全部股票
+    g.yesterday_HL_list = []  # 记录持仓中昨日涨停的股票
+
+
+    g.model_small = pickle.loads(read_file('model_small.pkl'))
+
+    # 因子列表
+    g.factor_list =  ['asset_impairment_loss_ttm', 'cash_flow_to_price_ratio', 'EBIT', 'net_working_capital', 'non_recurring_gain_loss', 'sales_to_price_ratio', 'AR', 'ARBR', 'ATR6', 'DAVOL10', 'MAWVAD', 'TVMA6', 'PSY', 'VOL10', 'VDIFF', 'VEMA26', 'VMACD', 'VOL120', 'VOSC', 'VR', 'arron_down_25', 'arron_up_25', 'BBIC', 'MASS', 'Rank1M', 'single_day_VPT', 'single_day_VPT_12', 'Volume1M', 'capital_reserve_fund_per_share', 'net_operate_cash_flow_per_share', 'operating_profit_per_share', 'total_operating_revenue_per_share', 'surplus_reserve_fund_per_share', 'ACCA', 'account_receivable_turnover_days', 'account_receivable_turnover_rate', 'adjusted_profit_to_total_profit', 'super_quick_ratio', 'MLEV', 'debt_to_equity_ratio', 'debt_to_tangible_equity_ratio', 'equity_to_fixed_asset_ratio', 'fixed_asset_ratio', 'intangible_asset_ratio', 'invest_income_associates_to_total_profit', 'long_debt_to_asset_ratio', 'long_debt_to_working_capital_ratio', 'net_operate_cash_flow_to_total_liability', 'net_operating_cash_flow_coverage', 'non_current_asset_ratio', 'operating_profit_to_total_profit', 'roa_ttm', 'Kurtosis120', 'Kurtosis20', 'Kurtosis60', 'sharpe_ratio_20', 'sharpe_ratio_60', 'Skewness120', 'Skewness20', 'Skewness60', 'Variance120', 'Variance20', 'beta', 'book_to_price_ratio', 'cash_earnings_to_price_ratio', 'cube_of_size', 'earnings_to_price_ratio', 'earnings_yield', 'growth', 'momentum', 'natural_log_of_market_cap', 'boll_down', 'MFI14', 'price_no_fq']
+    run_daily(prepare_stock_list, '9:05')
+    run_monthly(weekly_adjustment, 1, '9:30')
+    run_daily(check_limit_up, '14:00') 
+
+
+
+
+# 1-1 准备股票池
+def prepare_stock_list(context):
+    # 获取已持有列表
+    g.hold_list = []
+    for position in list(context.portfolio.positions.values()):
+        stock = position.security
+        g.hold_list.append(stock)
+    # 获取昨日涨停列表
+    if g.hold_list != []:
+        df = get_price(g.hold_list, end_date=context.previous_date, frequency='daily', fields=['close', 'high_limit'],
+                       count=1, panel=False, fill_paused=False)
+        df = df[df['close'] == df['high_limit']]
+        g.yesterday_HL_list = list(df.code)
+    else:
+        g.yesterday_HL_list = []
+
+# 1-2 选股模块
+def get_stock_list(context):
+    yesterday = context.previous_date
+    today = context.current_dt
+    stocks = get_index_stocks('399101.XSHE', yesterday)
+    initial_list = filter_kcbj_stock(stocks)
+    initial_list = filter_st_stock(initial_list)
+    initial_list = filter_paused_stock(initial_list)
+    initial_list = filter_new_stock(context, initial_list)
+    initial_list = filter_limitup_stock(context,initial_list)
+    initial_list = filter_limitdown_stock(context,initial_list)
+    factor_data = get_factor_values(initial_list, g.factor_list, end_date=yesterday, count=1)
+    df_jq_factor_value = pd.DataFrame(index=initial_list, columns=g.factor_list)
+    for factor in g.factor_list:
+        df_jq_factor_value[factor] = list(factor_data[factor].T.iloc[:, 0])
+    tar = g.model_small.predict(df_jq_factor_value)
+    df = df_jq_factor_value
+    df['total_score'] = list(tar)
+    df = df.sort_values(by=['total_score'], ascending=False)  
+    lst = df.index.tolist()
+    lst = lst[:min(g.stock_num, len(lst))]
+    return lst
+
+
+# 1-3 整体调整持仓
+def weekly_adjustment(context):
+
+        # 获取应买入列表
+        target_list = get_stock_list(context)
+        # 调仓卖出
+        for stock in g.hold_list:
+            if (stock not in target_list) and (stock not in g.yesterday_HL_list):
+                log.info("卖出[%s]" % (stock))
+                position = context.portfolio.positions[stock]
+                close_position(position)
+            else:
+                log.info("已持有[%s]" % (stock))
+        # 调仓买入
+        position_count = len(context.portfolio.positions)
+        target_num = len(target_list)
+        if target_num > position_count:
+            value = context.portfolio.cash / (target_num - position_count)
+            for stock in target_list:
+                if context.portfolio.positions[stock].total_amount == 0:
+                    if open_position(stock, value):
+                        if len(context.portfolio.positions) == target_num:
+                            break
+
+
+
+# 1-4 调整昨日涨停股票
+def check_limit_up(context):
+    now_time = context.current_dt
+    if g.yesterday_HL_list != []:
+        # 对昨日涨停股票观察到尾盘如不涨停则提前卖出,如果涨停即使不在应买入列表仍暂时持有
+        for stock in g.yesterday_HL_list:
+            current_data = get_price(stock, end_date=now_time, frequency='1m', fields=['close', 'high_limit'],
+                                     skip_paused=False, fq='pre', count=1, panel=False, fill_paused=True)
+            if current_data.iloc[0, 0] < current_data.iloc[0, 1]:
+                log.info("[%s]涨停打开,卖出" % (stock))
+                position = context.portfolio.positions[stock]
+                close_position(position)
+            else:
+                log.info("[%s]涨停,继续持有" % (stock))
+
+# 3-1 交易模块-自定义下单
+def order_target_value_(security, value):
+    if value == 0:
+        log.debug("Selling out %s" % (security))
+    else:
+        log.debug("Order %s to value %f" % (security, value))
+    return order_target_value(security, value)
+
+
+# 3-2 交易模块-开仓
+def open_position(security, value):
+    order = order_target_value_(security, value)
+    if order != None and order.filled > 0:
+        return True
+    return False
+
+
+# 3-3 交易模块-平仓
+def close_position(position):
+    security = position.security
+    order = order_target_value_(security, 0)  # 可能会因停牌失败
+    if order != None:
+        if order.status == OrderStatus.held and order.filled == order.amount:
+            return True
+    return False
+
+
+# 2-1 过滤停牌股票
+def filter_paused_stock(stock_list):
+    current_data = get_current_data()
+    return [stock for stock in stock_list if not current_data[stock].paused]
+
+
+# 2-2 过滤ST及其他具有退市标签的股票
+def filter_st_stock(stock_list):
+    current_data = get_current_data()
+    return [stock for stock in stock_list
+            if not current_data[stock].is_st
+            and 'ST' not in current_data[stock].name
+            and '*' not in current_data[stock].name
+            and '退' not in current_data[stock].name]
+
+
+# 2-3 过滤科创北交股票
+def filter_kcbj_stock(stock_list):
+    for stock in stock_list[:]:
+        if stock[0] == '4' or stock[0] == '8' or stock[:2] == '68' or stock[0] == '3':
+            stock_list.remove(stock)
+    return stock_list
+
+
+# 2-4 过滤涨停的股票
+def filter_limitup_stock(context, stock_list):
+    last_prices = history(1, unit='1m', field='close', security_list=stock_list)
+    current_data = get_current_data()
+    return [stock for stock in stock_list if stock in context.portfolio.positions.keys()
+            or last_prices[stock][-1] < current_data[stock].high_limit]
+
+
+# 2-5 过滤跌停的股票
+def filter_limitdown_stock(context, stock_list):
+    last_prices = history(1, unit='1m', field='close', security_list=stock_list)
+    current_data = get_current_data()
+    return [stock for stock in stock_list if stock in context.portfolio.positions.keys()
+            or last_prices[stock][-1] > current_data[stock].low_limit]
+
+
+# 2-6 过滤次新股
+def filter_new_stock(context, stock_list):
+    yesterday = context.previous_date
+    return [stock for stock in stock_list if
+            not yesterday - get_security_info(stock).start_date < datetime.timedelta(days=375)]