# kmeans.py (3.8 KB)
# Cloned from the JoinQuant article: https://www.joinquant.com/post/61157
# Title: A Tribute to the Market (13) - K-means Index Enhancement
# Author: Gyro^.^
  4. import numpy as np
  5. import pandas as pd
  6. import datetime as dt
  7. from sklearn.cluster import KMeans
  8. def initialize(context):
  9. # setting system
  10. log.set_level('order', 'error')
  11. set_option('use_real_price', True)
  12. set_option('avoid_future_data', True)
  13. run_monthly(iUpdate, 1, 'before_open')
  14. run_daily(iTrader, 'every_bar')
  15. run_daily(iReport, 'after_close')
  16. def iUpdate(context):
  17. # paramers
  18. nday = 243
  19. n_cluster = 24
  20. n_class = 5
  21. # all funds
  22. dt_now = context.current_dt.date()
  23. all_fund = get_all_securities('fund', dt_now)
  24. # onlist 1-year
  25. dt_1y = dt_now - dt.timedelta(days=365)
  26. funds = all_fund[all_fund.start_date < dt_1y].index.tolist()
  27. # filter, liquity
  28. hm = history(nday, '1d', 'money', funds).mean()
  29. funds = hm[hm > 1e6].index.tolist()
  30. # history return
  31. h = history(nday, '1d', 'close', funds).dropna(axis=1)
  32. r = np.log(h).diff()[1:]
  33. # annual return
  34. ar = 100*nday*r.mean()
  35. # K-means
  36. cluster = KMeans(n_clusters=n_cluster).fit(r.T)
  37. # labels
  38. c = pd.Series(cluster.labels_, r.columns)
  39. # choice
  40. df = pd.DataFrame(columns=['Cluster', 'Return', 'Name'])
  41. for k in range(n_cluster):
  42. # class
  43. choice = []
  44. for s in c.index:
  45. if c[s] == k:
  46. choice.append(s)
  47. log.info(k, len(choice))
  48. # high return
  49. cr = ar[choice].sort_values(ascending=False).head(n_class)
  50. for s in cr.index:
  51. if s in context.portfolio.positions:
  52. df.loc[s] = [k, cr[s], get_security_info(s).display_name]
  53. break
  54. else:
  55. s = cr.index[0]
  56. df.loc[s] = [k, cr[s], get_security_info(s).display_name]
  57. # report
  58. pd.set_option('display.max_rows', None)
  59. log.info('Funds', len(df), '\n', df)
  60. # results
  61. g.funds = df.index.tolist()
  62. g.position_size = 0.95/n_cluster * context.portfolio.total_value
  63. def iTrader(context):
  64. # load data
  65. choice = g.funds
  66. position_size = g.position_size
  67. lm_value = 0.7*position_size
  68. hm_value = 1.5*position_size
  69. cash_size = 0.05*context.portfolio.total_value
  70. cdata = get_current_data()
  71. # sell
  72. for s in context.portfolio.positions:
  73. if cdata[s].paused or \
  74. cdata[s].last_price >= cdata[s].high_limit or cdata[s].last_price <= cdata[s].low_limit:
  75. continue # 过滤三停
  76. if s not in choice:
  77. log.info('sell', s, cdata[s].name)
  78. order_target(s, 0, MarketOrderStyle(0.99*cdata[s].last_price))
  79. # buy
  80. for s in choice:
  81. if cdata[s].paused or \
  82. cdata[s].last_price >= cdata[s].high_limit or cdata[s].last_price <= cdata[s].low_limit:
  83. continue # 过滤三停
  84. if context.portfolio.available_cash < cash_size:
  85. break # 现金耗尽
  86. if s not in context.portfolio.positions:
  87. log.info('buy', s, cdata[s].name)
  88. order_target_value(s, position_size, MarketOrderStyle(1.01*cdata[s].last_price))
  89. def iReport(context):
  90. # table of positions
  91. ptable = pd.DataFrame(columns=['amount', 'value', 'weight', 'name'])
  92. tvalue = context.portfolio.total_value
  93. cdata = get_current_data()
  94. for s in context.portfolio.positions:
  95. ps = context.portfolio.positions[s]
  96. ptable.loc[s] = [ps.total_amount, int(ps.value), 100*ps.value/tvalue, cdata[s].name]
  97. ptable = ptable.sort_values(by='weight', ascending=False)
  98. # report portfolio
  99. pd.set_option('display.max_rows', None)
  100. log.info(' positions', len(ptable), '\n', ptable.head())
  101. log.info(' total value %.2f, cash %.2f', \
  102. context.portfolio.total_value/10000, context.portfolio.available_cash/10000)
  103. # end