| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227 |
- """
- 初始化trend_info表的脚本
- 从CSV文件中读取数据,按照规则生成存储趋势信息
- 根据规则:
- 1. 从"name"解析出对应的字段:
- 1.1 首先获得"name"的长度,如果是13,"category"一定是1;如果长度是4或5,"category"一定是2;如果长度是8"category"可能是0或1
- 1.2 如果是13,那么最后5个字符就是"extra_info",剩下的就是8个字符
- 1.3 如果长度是8,分析最后两个字符,如果是"上涨"或"下跌","category"是0,否则是1
- 1.4 针对8个字符的部分,每两个字符为一组,如果"category"是0,对应的分别是"time_range", "amplitude", "speed_type", "trend_type",然后根据内容找到id
- 1.5 针对8个字符的部分,每两个字符为一组,如果"category"是1,对应的分别是"time_range", "position", "amplitude", "trend_type",然后根据内容找到id
- """
- import os
- import csv
- import logging
- from pathlib import Path
- from app import db, create_app
- from app.models.dimension import (
- TrendInfo, DimTimeRange, DimAmplitude,
- DimPosition, DimSpeedType, DimTrendType, CandleInfo
- )
- import pandas as pd
# Configure logging for this script: INFO level, with timestamp, logger
# name, level and message on every line.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
def load_csv_data(csv_path=None):
    """Load every data row of the trend_info CSV file.

    Args:
        csv_path: Optional path to the CSV file. When omitted, the file
            ``config/trend_info.csv`` located in the parent directory of
            this script's directory is used (the original fixed location).

    Returns:
        list[dict]: One dict per data row, keyed by the CSV header row.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
        Exception: Any error raised while opening/decoding/parsing the file
            is logged and re-raised unchanged.
    """
    if csv_path is None:
        csv_path = Path(__file__).parent.parent / 'config' / 'trend_info.csv'
    else:
        csv_path = Path(csv_path)

    if not csv_path.exists():
        logger.error(f"CSV文件不存在: {csv_path}")
        raise FileNotFoundError(f"CSV文件不存在: {csv_path}")

    try:
        # newline='' is required by the csv module so quoted fields that
        # contain line breaks are parsed correctly. 'utf-8-sig' strips a
        # leading BOM (typical of Excel exports) that would otherwise make
        # the first header '\ufeffname' and break every row['name'] lookup;
        # files without a BOM decode exactly as with plain 'utf-8'.
        with open(csv_path, 'r', encoding='utf-8-sig', newline='') as file:
            data = list(csv.DictReader(file))
    except Exception as e:
        logger.error(f"加载CSV文件时出错: {e}")
        raise

    logger.info(f"从CSV文件成功加载了{len(data)}条记录")
    return data
def init_trend_info():
    """Populate the trend_info table from the CSV file when it is empty.

    Steps:
      1. Return immediately if trend_info already contains any rows.
      2. Load the CSV rows with ``load_csv_data``; on failure, log the
         error and return without raising.
      3. Load every dimension table once into ``{name: object}`` dicts.
      4. For each row, parse ``name`` with ``parse_name`` and build a
         ``TrendInfo`` object. Rows with an empty or missing name are
         skipped with a warning; rows that fail to parse are logged and
         skipped.
      5. If at least one record was built, add them all and commit once.

    Raises:
        Exception: Any error raised while adding/committing is re-raised
            after the session has been rolled back.
    """
    # Guard: never insert on top of existing data.
    if db.session.query(TrendInfo).count() > 0:
        logger.info("trend_info表已有数据,跳过初始化")
        return

    try:
        data = load_csv_data()
    except Exception as e:
        logger.error(f"加载CSV数据失败: {e}")
        return

    # Query each dimension table once, keyed by name, so per-row parsing is
    # a dict lookup instead of a database round-trip.
    time_ranges = {tr.name: tr for tr in db.session.query(DimTimeRange).all()}
    amplitudes = {a.name: a for a in db.session.query(DimAmplitude).all()}
    positions = {p.name: p for p in db.session.query(DimPosition).all()}
    speed_types = {s.name: s for s in db.session.query(DimSpeedType).all()}
    trend_types = {t.name: t for t in db.session.query(DimTrendType).all()}

    trend_records = []
    errors = []

    for i, row in enumerate(data, 1):
        # csv.DictReader fills fields missing from a short row with None,
        # and the key is absent if the header has no "name" column; treat
        # both as an empty name rather than failing on None.strip().
        name = (row.get('name') or '').strip()
        if not name:
            logger.warning(f"第{i}行缺少name值,跳过")
            continue

        try:
            # Derive category and dimension ids from the name's length/content.
            (category, time_range_id, amplitude_id, position_id,
             speed_type_id, trend_type_id, extra_info) = parse_name(
                name, time_ranges, amplitudes, positions, speed_types, trend_types)

            trend_records.append(TrendInfo(
                category=category,
                name=name,
                time_range_id=time_range_id,
                position_id=position_id,
                amplitude_id=amplitude_id,
                speed_type_id=speed_type_id,
                trend_type_id=trend_type_id,
                extra_info=extra_info,
            ))
        except Exception as e:
            error_msg = f"处理第{i}行时出错: {e}, 行数据: {row}"
            logger.error(error_msg)
            errors.append(error_msg)

    if errors:
        logger.warning(f"初始化过程中有{len(errors)}个错误,请检查日志获取详情")

    # Nothing was built (e.g. empty CSV or every row failed): skip the
    # commit instead of reporting a successful initialisation of 0 rows.
    if not trend_records:
        logger.warning("没有可写入的trend_info记录,跳过提交")
        return

    try:
        # Insert all records in a single transaction.
        db.session.add_all(trend_records)
        db.session.commit()
    except Exception as e:
        db.session.rollback()
        logger.error(f"提交数据到数据库时出错: {e}")
        raise

    logger.info(f"成功初始化 {len(trend_records)} 条trend_info记录")
def _lookup_id(mapping, key, label):
    """Return ``mapping[key].id``; raise ValueError naming *label* if *key* is absent."""
    try:
        return mapping[key].id
    except KeyError:
        raise ValueError(f"找不到{label}: {key}") from None


def parse_name(name, time_ranges, amplitudes, positions, speed_types, trend_types):
    """Derive the category and dimension ids encoded in a trend name.

    Layout by length:
      * 4 or 5 characters: category 2, no other fields.
      * 13 characters: category 1; the first 8 characters are the main part
        and the last 5 are ``extra_info``.
      * 8 characters: the whole name is the main part; category 0 if it
        ends with "上涨" or "下跌", otherwise category 1.

    The 8-character main part is read as four 2-character segments:
      * category 0: time range, amplitude, speed type, trend type;
      * category 1: time range, position, amplitude, and a fourth segment
        that is not looked up; the trend type is always "震荡".

    Args:
        name: Trend name (already stripped by the caller).
        time_ranges: ``{name: object}`` map of time-range dimension rows.
        amplitudes: ``{name: object}`` map of amplitude dimension rows.
        positions: ``{name: object}`` map of position dimension rows.
        speed_types: ``{name: object}`` map of speed-type dimension rows.
        trend_types: ``{name: object}`` map of trend-type dimension rows.
        Each mapped object must expose an ``id`` attribute.

    Returns:
        tuple: ``(category, time_range_id, amplitude_id, position_id,
        speed_type_id, trend_type_id, extra_info)``; fields that do not
        apply to the category are None.

    Raises:
        ValueError: If the length is not 4, 5, 8 or 13, or if a segment
            (or "震荡" for category 1) is missing from its mapping.
    """
    name_length = len(name)

    if name_length in (4, 5):
        # Category 2 names carry no dimension information.
        return 2, None, None, None, None, None, None

    extra_info = None
    if name_length == 13:
        category = 1
        main_part, extra_info = name[:8], name[-5:]
    elif name_length == 8:
        main_part = name
        category = 0 if name.endswith(('上涨', '下跌')) else 1
    else:
        raise ValueError(f"无法处理的名称长度: {name_length}, 名称: {name}")

    segments = [main_part[i:i + 2] for i in range(0, 8, 2)]
    position_id = speed_type_id = None

    # Lookups run in segment order so the first missing one is reported.
    time_range_id = _lookup_id(time_ranges, segments[0], '时间范围')
    if category == 0:
        amplitude_id = _lookup_id(amplitudes, segments[1], '幅度范围')
        speed_type_id = _lookup_id(speed_types, segments[2], '速度类型')
        trend_type_id = _lookup_id(trend_types, segments[3], '趋势类型')
    else:
        position_id = _lookup_id(positions, segments[1], '位置范围')
        amplitude_id = _lookup_id(amplitudes, segments[2], '幅度范围')
        # NOTE(review): segments[3] is intentionally ignored; category 1 is
        # always stored as "震荡" — confirm this matches the CSV data.
        trend_type_id = _lookup_id(trend_types, '震荡', '趋势类型')

    return (category, time_range_id, amplitude_id, position_id,
            speed_type_id, trend_type_id, extra_info)
if __name__ == "__main__":
    # Script entry point: build the application and run the initialization
    # inside its application context; errors are logged, then re-raised.
    logger.info("开始执行trend_info表初始化")
    application = create_app()
    with application.app_context():
        try:
            init_trend_info()
        except Exception as exc:
            logger.error(f"初始化trend_info表时发生错误: {exc}")
            raise
        else:
            logger.info("trend_info表初始化完成")
|