From ecd5c81d00ea034cc8e44d41abf34ff05fee8333 Mon Sep 17 00:00:00 2001 From: zengbin93 <1257391203@qq.com> Date: Sun, 1 Jan 2023 20:45:31 +0800 Subject: [PATCH] =?UTF-8?q?V0.9.6=20=E9=87=8D=E8=A6=81=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=20(#116)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 0.9.6 start coding * 0.9.6 Event / Factor 对象优化 * 0.9.6 Event / Factor 对象优化 * 0.9.6 Event / Factor 对象优化 * 0.9.6 update word writer * 0.9.6 优化 dummy backtest 过程 * 0.9.6 update * 0.9.6 update * 0.9.6 update * 0.9.6 update * 0.9.6 删除 ts_backtest.py * 0.9.6 优化 BI 对象的缓存机制 * 0.9.6 新增大小级别中枢共振信号函数 * 0.9.6 新增 CzscSignals 作为信号计算的基类 * 0.9.6 update * 0.9.6 新增简单持仓对象 * 0.9.6 update * 0.9.6 update * 0.9.6 update * 0.9.6 update * 0.9.6 优化缓存机制 * 0.9.6 新增 qmt 工具集 * 0.9.6 update * 0.9.6 update * 0.9.6 update * 0.9.6 update * 0.9.6 update * 0.9.6 update * 0.9.6 更新单元测试 * 0.9.6 fix ta * 0.9.6 fix ta --- czsc/__init__.py | 4 +- czsc/cmd.py | 132 +--- czsc/connectors/__init__.py | 7 + czsc/connectors/qmt_connector.py | 124 ++++ czsc/data/ts.py | 15 +- czsc/gm_utils.py | 906 ---------------------------- czsc/objects.py | 344 +++++++++-- czsc/sensors/__init__.py | 1 - czsc/sensors/utils.py | 199 +----- czsc/signals/__init__.py | 2 + czsc/signals/bar.py | 26 +- czsc/signals/cxt.py | 53 ++ czsc/signals/tas.py | 353 +++++++---- czsc/traders/__init__.py | 5 +- czsc/traders/advanced.py | 3 + czsc/traders/base.py | 375 ++++++++++++ czsc/traders/dummy.py | 216 +++++++ czsc/traders/ts_backtest.py | 254 -------- czsc/utils/word_writer.py | 12 +- examples/__init__.py | 16 + examples/gm_backtest.py | 16 +- examples/gm_check_point.py | 12 +- examples/gm_realtime.py | 12 +- examples/quick_start.py | 2 +- examples/strategies/cat_sma.py | 132 ++-- examples/strategies/check_signal.py | 80 +-- examples/strategies/qmt_cat_sma.py | 225 +++++++ examples/ts_check_signal_acc.py | 15 +- examples/ts_continue_simulator.py | 2 +- examples/ts_dummy_trader.py | 2 +- examples/ts_fast_backtest.py | 67 -- examples/ts_signals_analyze.py | 74 --- examples/ts_stocks_sensors.py | 26 +- test/test_advanced_trader.py | 3 +- test/test_objects.py | 9 + test/test_trader_base.py | 362 +++++++++++ 36 files changed, 2118 insertions(+), 1968 deletions(-) create mode 100644 czsc/connectors/__init__.py create mode 100644 czsc/connectors/qmt_connector.py delete mode 100644 czsc/gm_utils.py create mode 100644 czsc/traders/base.py create mode 100644 czsc/traders/dummy.py delete mode 100644 czsc/traders/ts_backtest.py create mode 100644 examples/strategies/qmt_cat_sma.py delete mode 100644 examples/ts_fast_backtest.py delete mode 100644 examples/ts_signals_analyze.py create mode 100644 test/test_trader_base.py diff --git a/czsc/__init__.py b/czsc/__init__.py index 51200b118..e7a4f4812 100644 --- a/czsc/__init__.py +++ b/czsc/__init__.py @@ -17,10 +17,10 @@ from czsc.utils.cache import home_path, get_dir_size, empty_cache_path -__version__ = "0.9.5" +__version__ = "0.9.6" __author__ = "zengbin93" __email__ = "zeng_bin8888@163.com" -__date__ = "20221212" +__date__ = "20221221" if envs.get_welcome(): diff --git a/czsc/cmd.py b/czsc/cmd.py index 21d201de7..c73ee036d 100644 --- a/czsc/cmd.py +++ b/czsc/cmd.py @@ -7,10 +7,7 @@ https://click.palletsprojects.com/en/8.0.x/quickstart/ """ -import os import click -import glob -import pandas as pd from loguru import logger @@ -28,153 +25,46 @@ def aphorism(): print_one() -@czsc.command() -@click.option('-f', '--file_strategy', type=str, required=True, help="Python择时策略文件路径") -def backtest(file_strategy): - """使用 TradeSimulator 进行择时策略回测/仿真研究""" - from collections import Counter - from czsc.traders.ts_simulator import TradeSimulator - from czsc.utils import get_py_namespace - - py = get_py_namespace(file_strategy) - results_path = os.path.join(py['results_path'], "backtest") - os.makedirs(results_path, exist_ok=True) - - strategy = py['trader_strategy'] - dc = py['dc'] - symbols = py['symbols'] - - ts = TradeSimulator(dc, strategy, res_path=results_path) - counter = Counter([x.split("#")[1] for x in symbols]) - for asset, c in counter.items(): - ts_codes = [x.split("#")[0] for x in symbols if x.endswith(asset)] - ts.update_traders(ts_codes, asset) - - @czsc.command() @click.option('-f', '--file_strategy', type=str, required=True, help="Python择时策略文件路径") def dummy(file_strategy): """使用 CzscDummyTrader 进行快速的择时策略研究""" - import shutil - from datetime import datetime - from czsc.traders.advanced import CzscDummyTrader - from czsc.sensors.utils import generate_symbol_signals - from czsc.utils import get_py_namespace, dill_dump, dill_load - - py = get_py_namespace(file_strategy) - signals_path = os.path.join(py['results_path'], "signals") - results_path = os.path.join(py['results_path'], f"DEXP{datetime.now().strftime('%Y%m%d%H%M')}") - os.makedirs(signals_path, exist_ok=True) - os.makedirs(results_path, exist_ok=True) - shutil.copy(file_strategy, os.path.join(results_path, 'strategy.py')) - - strategy = py['trader_strategy'] - dc = py['dc'] - symbols = py['symbols'] - - for symbol in symbols: - file_dfs = os.path.join(signals_path, f"{symbol}_signals.pkl") - - try: - # 可以直接生成信号,也可以直接读取信号 - if os.path.exists(file_dfs): - dfs = pd.read_pickle(file_dfs) - else: - ts_code, asset = symbol.split('#') - dfs = generate_symbol_signals(dc, ts_code, asset, "20170101", "20221001", strategy, 'hfq') - dfs.to_pickle(file_dfs) - - cdt = CzscDummyTrader(dfs, strategy) - dill_dump(cdt, os.path.join(results_path, f"{symbol}.cdt")) - - res = cdt.results - if "long_performance" in res.keys(): - logger.info(f"{res['long_performance']}") - - if "short_performance" in res.keys(): - logger.info(f"{res['short_performance']}") - except: - logger.exception(f"fail on {symbol}") - - # 汇总结果 - tactic = strategy('symbol') - files = glob.glob(f"{results_path}/*.cdt") - if tactic.get("long_pos", None): - lpf = pd.DataFrame([dill_load(file).results['long_performance'] for file in files]) - lpf.to_excel(os.path.join(results_path, f'{strategy.__doc__}多头回测结果.xlsx'), index=False) - - if tactic.get("short_pos", None): - spf = pd.DataFrame([dill_load(file).results['short_performance'] for file in files]) - spf.to_excel(os.path.join(results_path, f'{strategy.__doc__}空头回测结果.xlsx'), index=False) + from czsc.traders.dummy import DummyBacktest + dbt = DummyBacktest(file_strategy) + dbt.replay() + dbt.execute() + dbt.report() @czsc.command() @click.option('-f', '--file_strategy', type=str, required=True, help="Python择时策略文件路径") def replay(file_strategy): """执行择时策略在某个品种上的交易回放""" - from czsc.data import freq_cn2ts - from czsc.utils import BarGenerator - from czsc.traders.utils import trade_replay - from czsc.utils import get_py_namespace - - py = get_py_namespace(file_strategy) - strategy = py['trader_strategy'] - dc = py['dc'] - replay_params = py.get('replay_params', None) - - if not replay_params: - logger.warning(f"{file_strategy} 中没有设置策略回放参数,将使用默认参数执行") - - # 获取单个品种的基础周期K线 - tactic = strategy("000001.SZ") - base_freq = tactic['base_freq'] - symbol = replay_params.get('symbol', "000001.SZ#E") - ts_code, asset = symbol.split("#") - sdt = replay_params.get('sdt', '20150101') - edt = replay_params.get('edt', '20220101') - bars = dc.pro_bar_minutes(ts_code, sdt, edt, freq_cn2ts[base_freq], asset, adj="hfq") - logger.info(f"交易回放参数 | {symbol} - sdt:{sdt} - edt: {edt}") - - # 设置回放快照文件保存目录 - res_path = os.path.join(py['results_path'], f"replay_{symbol}") - os.makedirs(res_path, exist_ok=True) - - # 拆分基础周期K线,一部分用来初始化BarGenerator,随后的K线是回放区间 - start_date = pd.to_datetime(replay_params.get('mdt', '20200101')) - bg = BarGenerator(base_freq, freqs=tactic['freqs']) - bars1 = [x for x in bars if x.dt <= start_date] - bars2 = [x for x in bars if x.dt > start_date] - for bar in bars1: - bg.update(bar) - - trade_replay(bg, bars2, strategy, res_path) + from czsc.traders.dummy import DummyBacktest + dbt = DummyBacktest(file_strategy) + dbt.replay() @czsc.command() -@click.option('-f', '--file_strategy', type=str, required=True, help="Python择时策略文件路径") -@click.option('-d', '--delta_days', type=int, required=False, default=5, help="两次相同信号之间的间隔天数") +@click.option('-f', '--file_strategy', type=str, required=True, help="Python信号检查文件路径") +@click.option('-d', '--delta_days', type=int, required=False, default=1, help="两次相同信号之间的间隔天数") def check(file_strategy, delta_days): """执行择时策略中使用的信号在某个品种上的校验""" - from czsc.data import freq_cn2ts from czsc.sensors.utils import check_signals_acc from czsc.utils import get_py_namespace py = get_py_namespace(file_strategy) strategy = py['trader_strategy'] - dc = py['dc'] check_params = py.get('check_params', None) if not check_params: logger.warning(f"{file_strategy} 中没有设置策略回放参数,将使用默认参数执行") # 获取单个品种的基础周期K线 - tactic = strategy("000001.SZ") - base_freq = tactic['base_freq'] symbol = check_params.get('symbol', "000001.SZ#E") - ts_code, asset = symbol.split("#") sdt = check_params.get('sdt', '20200101') edt = check_params.get('edt', '20220101') - bars = dc.pro_bar_minutes(ts_code, sdt, edt, freq_cn2ts[base_freq], asset, adj="hfq") + bars = py['read_bars'](symbol, sdt, edt) logger.info(f"信号检查参数 | {symbol} - sdt: {sdt} - edt: {edt}") check_signals_acc(bars, strategy=strategy, delta_days=delta_days) diff --git a/czsc/connectors/__init__.py b/czsc/connectors/__init__.py new file mode 100644 index 000000000..062fab7e7 --- /dev/null +++ b/czsc/connectors/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- +""" +author: zengbin93 +email: zeng_bin8888@163.com +create_dt: 2022/12/31 16:02 +describe: 常用第三方交易框架的连接器 +""" diff --git a/czsc/connectors/qmt_connector.py b/czsc/connectors/qmt_connector.py new file mode 100644 index 000000000..4b70715cf --- /dev/null +++ b/czsc/connectors/qmt_connector.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- +""" +author: zengbin93 +email: zeng_bin8888@163.com +create_dt: 2022/12/31 16:03 +describe: +""" +import pandas as pd +from typing import List +from czsc.objects import Freq, RawBar +from xtquant import xtdata + + +def format_stock_kline(kline: pd.DataFrame, freq: Freq) -> List[RawBar]: + """QMT A股市场K线数据转换 + + :param kline: QMT 数据接口返回的K线数据 + time open high low close volume amount \ + 0 2022-12-01 10:15:00 13.22 13.22 13.16 13.18 20053 26432861.0 + 1 2022-12-01 10:20:00 13.18 13.19 13.15 13.15 32667 43002512.0 + 2 2022-12-01 10:25:00 13.16 13.18 13.13 13.16 32466 42708049.0 + 3 2022-12-01 10:30:00 13.16 13.19 13.13 13.18 15606 20540461.0 + 4 2022-12-01 10:35:00 13.20 13.25 13.19 13.22 29959 39626170.0 + symbol + 0 000001.SZ + 1 000001.SZ + 2 000001.SZ + 3 000001.SZ + 4 000001.SZ + :param freq: K线周期 + :return: 转换好的K线数据 + """ + bars = [] + dt_key = 'time' + kline = kline.sort_values(dt_key, ascending=True, ignore_index=True) + records = kline.to_dict('records') + + for i, record in enumerate(records): + # 将每一根K线转换成 RawBar 对象 + bar = RawBar(symbol=record['symbol'], dt=pd.to_datetime(record[dt_key]), id=i, freq=freq, + open=record['open'], close=record['close'], high=record['high'], low=record['low'], + vol=record['volume'] * 100 if record['volume'] else 0, # 成交量,单位:股 + amount=record['amount'] if record['amount'] > 0 else 0, # 成交额,单位:元 + ) + bars.append(bar) + return bars + + +def get_local_kline(symbol, period, start_time, end_time, count=-1, dividend_type='none', data_dir=None, update=True): + """获取 QMT 本地K线数据 + + :param symbol: 股票代码 例如:'300001.SZ' + :param period: 周期 分笔"tick" 分钟线"1m"/"5m" 日线"1d" + :param start_time: 开始时间,格式YYYYMMDD/YYYYMMDDhhmmss/YYYYMMDDhhmmss.milli, + 例如:"20200427" "20200427093000" "20200427093000.000" + :param end_time: 结束时间 格式同上 + :param count: 数量 -1全部,n: 从结束时间向前数n个 + :param dividend_type: 除权类型"none" "front" "back" "front_ratio" "back_ratio" + :param data_dir: 下载QMT本地数据路径,如 D:/迅投极速策略交易系统交易终端/datadir + :param update: 更新QMT本地数据路径中的数据 + :return: df Dataframe格式的数据,样例如下 + time open high low close volume amount \ + 0 2022-12-01 10:15:00 13.22 13.22 13.16 13.18 20053 26432861.0 + 1 2022-12-01 10:20:00 13.18 13.19 13.15 13.15 32667 43002512.0 + 2 2022-12-01 10:25:00 13.16 13.18 13.13 13.16 32466 42708049.0 + 3 2022-12-01 10:30:00 13.16 13.19 13.13 13.18 15606 20540461.0 + 4 2022-12-01 10:35:00 13.20 13.25 13.19 13.22 29959 39626170.0 + symbol + 0 000001.SZ + 1 000001.SZ + 2 000001.SZ + 3 000001.SZ + 4 000001.SZ + """ + field_list = ['time', 'open', 'high', 'low', 'close', 'volume', 'amount'] + if update: + xtdata.download_history_data(symbol, period, start_time='20100101', end_time='21000101') + local_data = xtdata.get_local_data(field_list, [symbol], period, count=count, dividend_type=dividend_type, + start_time=start_time, end_time=end_time, data_dir=data_dir) + + df = pd.DataFrame({key: value.values[0] for key, value in local_data.items()}) + df['time'] = pd.to_datetime(df['time'], unit='ms') + pd.to_timedelta('8H') + df.reset_index(inplace=True, drop=True) + df['symbol'] = symbol + return df + + +def get_symbols(step): + """获取择时策略投研不同阶段对应的标的列表 + + :param step: 投研阶段 + :return: 标的列表 + """ + stocks = xtdata.get_stock_list_in_sector('沪深A股') + stocks_map = { + "index": ['000905.SH', '000016.SH', '000300.SH', '000001.SH', '000852.SH', + '399001.SZ', '399006.SZ', '399376.SZ', '399377.SZ', '399317.SZ', '399303.SZ'], + "stock": stocks.ts_code.to_list(), + "check": ['000001.SZ'], + "train": stocks[:200], + "valid": stocks[200:600], + "etfs": ['512880.SH', '518880.SH', '515880.SH', '513050.SH', '512690.SH', + '512660.SH', '512400.SH', '512010.SH', '512000.SH', '510900.SH', + '510300.SH', '510500.SH', '510050.SH', '159992.SZ', '159985.SZ', + '159981.SZ', '159949.SZ', '159915.SZ'], + } + return stocks_map[step] + + +def test_local_kline(): + # 获取所有板块 + slt = xtdata.get_sector_list() + stocks = xtdata.get_stock_list_in_sector('沪深A股') + + df = get_local_kline(symbol='000001.SZ', period='1m', count=1000, dividend_type='front', + data_dir=r"D:\迅投极速策略交易系统交易终端 华鑫证券QMT实盘\datadir", + start_time='20200427', end_time='20221231', update=True) + assert not df.empty + # df = get_local_kline(symbol='000001.SZ', period='5m', count=1000, dividend_type='front', + # data_dir=r"D:\迅投极速策略交易系统交易终端 华鑫证券QMT实盘\datadir", + # start_time='20200427', end_time='20221231', update=False) + # df = get_local_kline(symbol='000001.SZ', period='1d', count=1000, dividend_type='front', + # data_dir=r"D:\迅投极速策略交易系统交易终端 华鑫证券QMT实盘\datadir", + # start_time='20200427', end_time='20221231', update=False) diff --git a/czsc/data/ts.py b/czsc/data/ts.py index 2a1de9163..f2b493478 100644 --- a/czsc/data/ts.py +++ b/czsc/data/ts.py @@ -15,8 +15,7 @@ from functools import partial from loguru import logger -from ..analyze import RawBar -from ..enum import Freq +from czsc.objects import RawBar, Freq # 数据频度 :支持分钟(min)/日(D)/周(W)/月(M)K线,其中1min表示1分钟(类推1/5/15/30/60分钟)。 @@ -25,18 +24,6 @@ Freq.F60: "60min", Freq.D: 'D', Freq.W: "W", Freq.M: "M"} freq_cn_map = {"1分钟": Freq.F1, "5分钟": Freq.F5, "15分钟": Freq.F15, "30分钟": Freq.F30, "60分钟": Freq.F60, "日线": Freq.D} -exchanges = { - "SSE": "上交所", - "SZSE": "深交所", - "CFFEX": "中金所", - "SHFE": "上期所", - "CZCE": "郑商所", - "DCE": "大商所", - "INE": "能源", - "IB": "银行间", - "XHKG": "港交所" -} - dt_fmt = "%Y-%m-%d %H:%M:%S" date_fmt = "%Y%m%d" diff --git a/czsc/gm_utils.py b/czsc/gm_utils.py deleted file mode 100644 index 9e42ea71d..000000000 --- a/czsc/gm_utils.py +++ /dev/null @@ -1,906 +0,0 @@ -# -*- coding: utf-8 -*- -""" -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2021/11/17 22:11 -describe: 配合 CzscAdvancedTrader 进行使用的掘金工具 -""" -import os -import dill -import inspect -import czsc -import traceback -import pandas as pd -from gm.api import * -from loguru import logger -from datetime import datetime, timedelta -from collections import OrderedDict -from typing import List, Callable -from czsc import CzscAdvancedTrader, create_advanced_trader -from czsc.data import freq_cn2gm -from czsc.utils import qywx as wx -from czsc.utils import x_round, BarGenerator, create_logger -from czsc.objects import RawBar, Event, Freq, Operate, PositionLong, PositionShort - -logger.warning("gm_utils.py 即将废弃,请使用 czsc.gms 模块") - -dt_fmt = "%Y-%m-%d %H:%M:%S" -date_fmt = "%Y-%m-%d" - -assert czsc.__version__ >= "0.8.27" - - -def set_gm_token(token): - with open(os.path.join(os.path.expanduser("~"), "gm_token.txt"), 'w', encoding='utf-8') as f: - f.write(token) - - -file_token = os.path.join(os.path.expanduser("~"), "gm_token.txt") -if not os.path.exists(file_token): - print("{} 文件不存在,请单独启动一个 python 终端,调用 set_gm_token 方法创建该文件,再重新执行。".format(file_token)) -else: - gm_token = open(file_token, encoding="utf-8").read() - set_token(gm_token) - -indices = { - "上证指数": 'SHSE.000001', - "上证50": 'SHSE.000016', - "沪深300": "SHSE.000300", - "中证1000": "SHSE.000852", - "中证500": "SHSE.000905", - - "深证成指": "SZSE.399001", - "创业板指数": 'SZSE.399006', - "深次新股": "SZSE.399678", - "中小板指": "SZSE.399005", - "国证2000": "SZSE.399303", - "小盘成长": "SZSE.399376", - "小盘价值": "SZSE.399377", -} - - -def is_trade_date(dt): - """判断 dt 时刻是不是交易日期""" - dt = pd.to_datetime(dt) - date_ = dt.strftime("%Y-%m-%d") - trade_dates = get_trading_dates(exchange='SZSE', start_date=date_, end_date=date_) - if trade_dates: - return True - else: - return False - - -def is_trade_time(dt): - """判断 dt 时刻是不是交易时间""" - dt = pd.to_datetime(dt) - date_ = dt.strftime("%Y-%m-%d") - trade_dates = get_trading_dates(exchange='SZSE', start_date=date_, end_date=date_) - if trade_dates and "15:00" > dt.strftime("%H:%M") > "09:30": - return True - else: - return False - - -def get_stocks(): - """获取股票市场标的列表,包括股票、指数等""" - df = get_instruments(exchanges='SZSE,SHSE', fields="symbol,sec_name", df=True) - shares = {row['symbol']: row['sec_name'] for _, row in df.iterrows()} - return shares - - -def get_index_shares(name, end_date=None): - """获取某一交易日的指数成分股列表 - - symbols = get_index_shares("上证50", "2019-01-01 09:30:00") - """ - if not end_date: - end_date = datetime.now().strftime(date_fmt) - else: - end_date = pd.to_datetime(end_date).strftime(date_fmt) - constituents = get_history_constituents(indices[name], end_date, end_date)[0] - symbol_list = [k for k, v in constituents['constituents'].items()] - return list(set(symbol_list)) - - -def format_kline(df, freq: Freq): - bars = [] - for i, row in df.iterrows(): - # amount 单位:元 - bar = RawBar(symbol=row['symbol'], id=i, freq=freq, dt=row['eob'], open=round(row['open'], 2), - close=round(row['close'], 2), high=round(row['high'], 2), - low=round(row['low'], 2), vol=row['volume'], amount=row['amount']) - bars.append(bar) - return bars - - -def get_kline(symbol, end_time, freq='60s', count=33000, adjust=ADJUST_PREV): - """获取K线数据 - - :param symbol: 标的代码 - :param end_time: 结束时间 - :param freq: K线周期 - :param count: K线数量 - :param adjust: 复权方式 - :return: - """ - if isinstance(end_time, datetime): - end_time = end_time.strftime(dt_fmt) - - exchange = symbol.split(".")[0] - freq_map_ = {'60s': Freq.F1, '300s': Freq.F5, '900s': Freq.F15, '1800s': Freq.F30, - '3600s': Freq.F60, '1d': Freq.D} - - if exchange in ["SZSE", "SHSE"]: - df = history_n(symbol=symbol, frequency=freq, end_time=end_time, adjust=adjust, - fields='symbol,eob,open,close,high,low,volume,amount', count=count, df=True) - else: - df = history_n(symbol=symbol, frequency=freq, end_time=end_time, adjust=adjust, - fields='symbol,eob,open,close,high,low,volume,amount,position', count=count, df=True) - return format_kline(df, freq_map_[freq]) - - -def get_init_bg(symbol: str, - end_dt: [str, datetime], - base_freq: str, - freqs: List[str], - max_count=1000, - adjust=ADJUST_PREV): - """获取 symbol 的初始化 bar generator""" - if isinstance(end_dt, str): - end_dt = pd.to_datetime(end_dt, utc=True) - end_dt = end_dt.tz_convert('dateutil/PRC') - # 时区转换之后,要减去8个小时才是设置的时间 - end_dt = end_dt - timedelta(hours=8) - else: - assert end_dt.tzinfo._filename == 'PRC' - - delta_days = 180 - last_day = (end_dt - timedelta(days=delta_days)).replace(hour=16, minute=0) - - bg = BarGenerator(base_freq, freqs, max_count) - if "周线" in freqs or "月线" in freqs: - d_bars = get_kline(symbol, last_day, freq_cn2gm["日线"], count=5000, adjust=adjust) - bgd = BarGenerator("日线", ['周线', '月线', '季线', '年线']) - for b in d_bars: - bgd.update(b) - else: - bgd = None - - for freq in bg.bars.keys(): - if freq in ['周线', '月线', '季线', '年线']: - bars_ = bgd.bars[freq] - else: - bars_ = get_kline(symbol, last_day, freq_cn2gm[freq], max_count, adjust) - bg.init_freq_bars(freq, bars_) - print(f"{symbol} - {freq} - {len(bg.bars[freq])} - last_dt: {bg.bars[freq][-1].dt} - last_day: {last_day}") - - bars2 = get_kline(symbol, end_dt, freq_cn2gm[base_freq], - count=int(240 / int(base_freq.strip('分钟')) * delta_days)) - data = [x for x in bars2 if x.dt > last_day] - assert len(data) > 0 - print(f"{symbol}: bar generator 最新时间 {bg.bars[base_freq][-1].dt.strftime(dt_fmt)},还有{len(data)}行数据需要update") - return bg, data - - -order_side_map = {OrderSide_Unknown: '其他', OrderSide_Buy: '买入', OrderSide_Sell: '卖出'} -order_status_map = { - OrderStatus_Unknown: "其他", - OrderStatus_New: "已报", - OrderStatus_PartiallyFilled: "部成", - OrderStatus_Filled: "已成", - OrderStatus_Canceled: "已撤", - OrderStatus_PendingCancel: "待撤", - OrderStatus_Rejected: "已拒绝", - OrderStatus_Suspended: "挂起(无效)", - OrderStatus_PendingNew: "待报", - OrderStatus_Expired: "已过期", -} -pos_side_map = {PositionSide_Unknown: '其他', PositionSide_Long: '多头', PositionSide_Short: '空头'} -pos_effect_map = { - PositionEffect_Unknown: '其他', - PositionEffect_Open: '开仓', - PositionEffect_Close: '平仓', - PositionEffect_CloseToday: '平今仓', - PositionEffect_CloseYesterday: '平昨仓', -} -exec_type_map = { - ExecType_Unknown: "其他", - ExecType_New: "已报", - ExecType_Canceled: "已撤销", - ExecType_PendingCancel: "待撤销", - ExecType_Rejected: "已拒绝", - ExecType_Suspended: "挂起", - ExecType_PendingNew: "待报", - ExecType_Expired: "过期", - ExecType_Trade: "成交(有效)", - ExecType_OrderStatus: "委托状态", - ExecType_CancelRejected: "撤单被拒绝(有效)", -} - - -def on_order_status(context, order): - """ - https://www.myquant.cn/docs/python/python_object_trade#007ae8f5c7ec5298 - - :param context: - :param order: - :return: - """ - if not is_trade_time(context.now): - return - - symbol = order.symbol - latest_dt = context.now.strftime("%Y-%m-%d %H:%M:%S") - logger = context.logger - - if symbol not in context.symbols_info.keys(): - msg = f"订单状态更新通知:\n{'*' * 31}\n" \ - f"更新时间:{latest_dt}\n" \ - f"标的名称:{symbol} {context.stocks.get(symbol, '无名')}\n" \ - f"操作类型:{order_side_map[order.side]}{pos_effect_map[order.position_effect]}\n" \ - f"操作描述:非机器交易标的\n" \ - f"下单价格:{round(order.price, 2)}\n" \ - f"最新状态:{order_status_map[order.status]}\n" \ - f"委托(股):{int(order.volume)}\n" \ - f"已成(股):{int(order.filled_volume)}\n" \ - f"均价(元):{round(order.filled_vwap, 2)}" - - else: - trader: CzscAdvancedTrader = context.symbols_info[symbol]['trader'] - if trader.long_pos.operates: - last_op_desc = trader.long_pos.operates[-1]['op_desc'] - else: - last_op_desc = "" - - msg = f"订单状态更新通知:\n{'*' * 31}\n" \ - f"更新时间:{latest_dt}\n" \ - f"标的名称:{symbol} {context.stocks.get(symbol, '无名')}\n" \ - f"操作类型:{order_side_map[order.side]}{pos_effect_map[order.position_effect]}\n" \ - f"操作描述:{last_op_desc}\n" \ - f"下单价格:{round(order.price, 2)}\n" \ - f"最新状态:{order_status_map[order.status]}\n" \ - f"委托(股):{int(order.volume)}\n" \ - f"已成(股):{int(order.filled_volume)}\n" \ - f"均价(元):{round(order.filled_vwap, 2)}" - - logger.info(msg.replace("\n", " - ").replace('*', "")) - if context.mode != MODE_BACKTEST and order.status in [1, 3, 5, 8, 9, 12]: - wx.push_text(content=str(msg), key=context.wx_key) - - -def on_execution_report(context, execrpt): - """响应委托被执行事件,委托成交或者撤单拒绝后被触发。 - - https://www.myquant.cn/docs/python/python_trade_event#on_execution_report%20-%20%E5%A7%94%E6%89%98%E6%89%A7%E8%A1%8C%E5%9B%9E%E6%8A%A5%E4%BA%8B%E4%BB%B6 - https://www.myquant.cn/docs/python/python_object_trade#ExecRpt%20-%20%E5%9B%9E%E6%8A%A5%E5%AF%B9%E8%B1%A1 - - :param context: - :param execrpt: - :return: - """ - if not is_trade_time(context.now): - return - - latest_dt = context.now.strftime(dt_fmt) - logger = context.logger - msg = f"委托订单被执行通知:\n{'*' * 31}\n" \ - f"时间:{latest_dt}\n" \ - f"标的:{execrpt.symbol}\n" \ - f"名称:{context.stocks.get(execrpt.symbol, '无名')}\n" \ - f"方向:{order_side_map[execrpt.side]}{pos_effect_map[execrpt.position_effect]}\n" \ - f"成交量:{int(execrpt.volume)}\n" \ - f"成交价:{round(execrpt.price, 2)}\n" \ - f"执行回报类型:{exec_type_map[execrpt.exec_type]}" - - logger.info(msg.replace("\n", " - ").replace('*', "")) - if context.mode != MODE_BACKTEST and execrpt.exec_type in [1, 5, 6, 8, 12, 19]: - wx.push_text(content=str(msg), key=context.wx_key) - - -def on_backtest_finished(context, indicator): - """回测结束回调函数 - - :param context: - :param indicator: - https://www.myquant.cn/docs/python/python_object_trade#bd7f5adf22081af5 - :return: - """ - wx_key = context.wx_key - symbols = context.symbols - data_path = context.data_path - logger = context.logger - - logger.info(str(indicator)) - logger.info("回测结束 ... ") - cash = context.account().cash - - for k, v in indicator.items(): - if isinstance(v, float): - indicator[k] = round(v, 4) - - row = OrderedDict({ - "标的数量": len(context.symbols_info.keys()), - "开始时间": context.backtest_start_time, - "结束时间": context.backtest_end_time, - "累计收益": indicator['pnl_ratio'], - "最大回撤": indicator['max_drawdown'], - "年化收益": indicator['pnl_ratio_annual'], - "夏普比率": indicator['sharp_ratio'], - "盈利次数": indicator['win_count'], - "亏损次数": indicator['lose_count'], - "交易胜率": indicator['win_ratio'], - "累计出入金": int(cash['cum_inout']), - "累计交易额": int(cash['cum_trade']), - "累计手续费": int(cash['cum_commission']), - "累计平仓收益": int(cash['cum_pnl']), - "净收益": int(cash['pnl']), - }) - sdt = pd.to_datetime(context.backtest_start_time).strftime('%Y%m%d') - edt = pd.to_datetime(context.backtest_end_time).strftime('%Y%m%d') - file_xlsx = os.path.join(data_path, f'{context.name}_{sdt}_{edt}.xlsx') - file = pd.ExcelWriter(file_xlsx, mode='w') - - dfe = pd.DataFrame({"指标": list(row.keys()), "值": list(row.values())}) - dfe.to_excel(file, sheet_name='回测表现', index=False) - - logger.info("回测结果:{}".format(row)) - content = "" - for k, v in row.items(): - content += "{}: {}\n".format(k, v) - wx.push_text(content=content, key=wx_key) - - trades = [] - operates = [] - performances = [] - for symbol in symbols: - trader: CzscAdvancedTrader = context.symbols_info[symbol]['trader'] - trades.extend(trader.long_pos.pairs) - operates.extend(trader.long_pos.operates) - performances.append(trader.long_pos.evaluate_operates()) - - df = pd.DataFrame(trades) - df['开仓时间'] = df['开仓时间'].apply(lambda x: x.strftime("%Y-%m-%d %H:%M")) - df['平仓时间'] = df['平仓时间'].apply(lambda x: x.strftime("%Y-%m-%d %H:%M")) - df.to_excel(file, sheet_name='交易汇总', index=False) - - dfo = pd.DataFrame(operates) - dfo['dt'] = dfo['dt'].apply(lambda x: x.strftime("%Y-%m-%d %H:%M")) - dfo.to_excel(file, sheet_name='操作汇总', index=False) - - dfp = pd.DataFrame(performances) - dfp.to_excel(file, sheet_name='表现汇总', index=False) - file.close() - - wx.push_file(file_xlsx, wx_key) - - -def on_error(context, code, info): - if not is_trade_time(context.now): - return - - logger = context.logger - msg = "{} - {}".format(code, info) - logger.warn(msg) - if context.mode != MODE_BACKTEST: - wx.push_text(content=msg, key=context.wx_key) - - -def on_account_status(context, account): - """响应交易账户状态更新事件,交易账户状态变化时被触发 - https://www.myquant.cn/docs/python/python_trade_event#4f07d24fc4314e3c - """ - status = account['status'] - if status['state'] == 3: - return - - if not is_trade_time(context.now): - return - - logger = context.logger - msg = f"{str(account)}" - logger.warn(msg) - if context.mode != MODE_BACKTEST: - wx.push_text(content=msg, key=context.wx_key) - - -def on_bar(context, bars): - """订阅K线回调函数""" - context.unfinished_orders = get_unfinished_orders() - cancel_timeout_orders(context, max_m=30) - - for bar in bars: - symbol = bar['symbol'] - trader: CzscAdvancedTrader = context.symbols_info[symbol]['trader'] - - # 确保数据更新到最新时刻 - base_freq = trader.base_freq - bars = context.data(symbol=symbol, frequency=freq_cn2gm[base_freq], count=100, - fields='symbol,eob,open,close,high,low,volume,amount') - bars = format_kline(bars, freq=trader.bg.freq_map[base_freq]) - bars_new = [x for x in bars if x.dt > trader.bg.bars[base_freq][-1].dt] - if bars_new: - for bar_ in bars_new: - trader.update(bar_) - - sync_long_position(context, trader) - - -def is_order_exist(context, symbol, side) -> bool: - """判断同方向订单是否已经存在 - - :param context: - :param symbol: 交易标的 - :param side: 交易方向 - :return: bool - """ - uo = context.unfinished_orders - if not uo: - return False - else: - for o in uo: - if o.symbol == symbol and o.side == side: - context.logger.info("同类型订单已存在:{} - {}".format(symbol, side)) - return True - return False - - -def cancel_timeout_orders(context, max_m=30): - """实盘仿真,撤销挂单时间超过 max_m 分钟的订单。 - - :param context: - :param max_m: 最大允许挂单分钟数 - :return: - """ - for u_order in context.unfinished_orders: - if context.now - u_order.created_at >= timedelta(minutes=max_m): - order_cancel(u_order) - - -def report_account_status(context): - """报告账户持仓状态""" - if context.now.isoweekday() > 5: - return - - logger = context.logger - latest_dt = context.now.strftime(dt_fmt) - account = context.account(account_id=context.account_id) - cash = account.cash - positions = account.positions() - - logger.info("=" * 30 + f" 账户状态【{latest_dt}】 " + "=" * 30) - cash_report = f"净值:{int(cash.nav)},可用资金:{int(cash.available)}," \ - f"浮动盈亏:{int(cash.fpnl)},标的数量:{len(positions)}" - logger.info(cash_report) - - for p in positions: - p_report = f"标的:{p.symbol},名称:{context.stocks.get(p.symbol, '无名')}," \ - f"数量:{p.volume},成本:{round(p.vwap, 2)},方向:{p.side}," \ - f"当前价:{round(p.price, 2)},成本市值:{int(p.volume * p.vwap)}," \ - f"建仓时间:{p.created_at.strftime(dt_fmt)}" - logger.info(p_report) - - # 实盘或仿真,推送账户信息到企业微信 - if context.mode != MODE_BACKTEST: - - msg = f"股票账户状态报告\n{'*' * 31}\n" - msg += f"账户净值:{int(cash.nav)}\n" \ - f"持仓市值:{int(cash.market_value)}\n" \ - f"可用资金:{int(cash.available)}\n" \ - f"浮动盈亏:{int(cash.fpnl)}\n" \ - f"标的数量:{len(positions)}\n" - wx.push_text(msg.strip("\n *"), key=context.wx_key) - - results = [] - for symbol, info in context.symbols_info.items(): - name = context.stocks.get(symbol, '无名') - trader: CzscAdvancedTrader = context.symbols_info[symbol]['trader'] - p = account.position(symbol=symbol, side=PositionSide_Long) - - row = {'交易标的': symbol, '标的名称': name, - '最新时间': trader.end_dt.strftime(dt_fmt), - '最新价格': trader.latest_price} - - if "日线" in trader.kas.keys(): - bar1, bar2 = trader.kas['日线'].bars_raw[-2:] - row.update({'昨日收盘': round(bar1.close, 2), - '今日涨幅': round(bar2.close / bar1.close - 1, 4)}) - - if trader.long_pos.pos > 0: - row.update({'多头持仓': trader.long_pos.pos, - '多头成本': trader.long_pos.long_cost, - '多头收益': round(trader.latest_price / trader.long_pos.long_cost - 1, 4), - '开多时间': trader.long_pos.operates[-1]['dt'].strftime(dt_fmt)}) - else: - row.update({'多头持仓': 0, '多头成本': 0, '多头收益': 0, '开多时间': None}) - - if p: - row.update({"实盘持仓数量": p.volume, - "实盘持仓成本": x_round(p.vwap, 2), - "实盘持仓市值": int(p.volume * p.vwap)}) - else: - row.update({"实盘持仓数量": 0, "实盘持仓成本": 0, "实盘持仓市值": 0}) - - results.append(row) - - df = pd.DataFrame(results) - df.sort_values(['多头持仓', '多头收益'], ascending=False, inplace=True, ignore_index=True) - file_xlsx = os.path.join(context.data_path, f"holds_{context.now.strftime('%Y%m%d_%H%M')}.xlsx") - df.to_excel(file_xlsx, index=False) - wx.push_file(file_xlsx, key=context.wx_key) - os.remove(file_xlsx) - - # 提示非策略交易标的持仓 - process_out_of_symbols(context) - - -def sync_long_position(context, trader: CzscAdvancedTrader): - """同步多头仓位到交易账户""" - if not trader.long_events: - return - - symbol = trader.symbol - name = context.stocks.get(symbol, "无名标的") - long_pos: PositionLong = trader.long_pos - max_sym_pos = context.symbols_info[symbol]['max_sym_pos'] # 最大标的仓位 - logger = context.logger - if context.mode == MODE_BACKTEST: - account = context.account() - else: - account = context.account(account_id=context.account_id) - cash = account.cash - - price = trader.latest_price - print(f"{trader.end_dt}: {name},多头:{long_pos.pos},成本:{long_pos.long_cost}," - f"现价:{price},操作次数:{len(long_pos.operates)}") - - algo_name = os.environ.get('algo_name', None) - if algo_name: - # 算法名称,TWAP、VWAP、ATS-SMART、ZC-POV - algo_name = algo_name.upper() - start_time = trader.end_dt.strftime("%H:%M:%S") - end_time = (trader.end_dt + timedelta(minutes=30)).strftime("%H:%M:%S") - end_time = min(end_time, '14:55:00') - - if algo_name == 'TWAP' or algo_name == 'VWAP' or algo_name == 'ZC-POV': - algo_param = { - "start_time": start_time, - "end_time": end_time, - "part_rate": 0.5, - "min_amount": 5000, - } - elif algo_name == 'ATS-SMART': - algo_param = { - 'start_time': start_time, - 'end_time_referred': end_time, - 'end_time': end_time, - 'end_time_valid': 1, - 'stop_sell_when_dl': 1, - 'cancel_when_pl': 0, - 'min_trade_amount': 5000 - } - else: - raise ValueError("算法单名称输入错误") - else: - algo_param = {} - - sym_position = account.position(symbol, PositionSide_Long) - if long_pos.pos == 0 and not sym_position: - # 如果多头仓位为0且掘金账户没有对应持仓,直接退出 - return - - if long_pos.pos == 0 and sym_position and sym_position.volume > 0: - # 如果多头仓位为0且掘金账户依然还有持仓,清掉仓位 - volume = sym_position.volume - if algo_name: - assert len(algo_param) > 0, f"error: {algo_name}, {algo_param}" - _ = algo_order(symbol=symbol, volume=volume, side=OrderSide_Sell, - order_type=OrderType_Limit, position_effect=PositionEffect_Close, - price=price, algo_name=algo_name, algo_param=algo_param, account=account.id) - else: - order_target_volume(symbol=symbol, volume=0, position_side=PositionSide_Long, - order_type=OrderType_Limit, price=price, account=account.id) - return - - if not long_pos.pos_changed: - return - - assert long_pos.pos > 0 - cash_left = cash.available - if long_pos.operates[-1]['op'] in [Operate.LO, Operate.LA1, Operate.LA2]: - change_amount = max_sym_pos * long_pos.operates[-1]['pos_change'] * cash.nav - if cash_left < change_amount: - logger.info(f"{context.now} {symbol} {name} 可用资金不足,无法开多仓;" - f"剩余资金{int(cash_left)}元,所需资金{int(change_amount)}元") - return - - if is_order_exist(context, symbol, PositionSide_Long): - logger.info(f"{context.now} {symbol} {name} 同方向订单已存在") - return - - percent = max_sym_pos * long_pos.pos - volume = int((cash.nav * percent / price // 100) * 100) # 单位:股 - if algo_name: - _ = algo_order(symbol=symbol, volume=volume, side=OrderSide_Buy, - order_type=OrderType_Limit, position_effect=PositionEffect_Open, - price=price, algo_name=algo_name, algo_param=algo_param, account=account.id) - else: - order_target_volume(symbol=symbol, volume=volume, position_side=PositionSide_Long, - order_type=OrderType_Limit, price=price, account=account.id) - - -def sync_short_position(trader: CzscAdvancedTrader, context): - """同步空头仓位到交易账户""" - if not trader.short_events: - return - - symbol = trader.symbol - name = context.stocks.get(symbol, "无名标的") - short_pos: PositionShort = trader.short_pos - max_sym_pos = context.symbols_info[symbol]['max_sym_pos'] # 最大标的仓位 - logger = context.logger - if context.mode == MODE_BACKTEST: - account = context.account() - else: - account = context.account(account_id=context.account_id) - cash = account.cash - - price = trader.latest_price - print(f"{trader.end_dt}: {name},空头:{short_pos.pos},成本:{short_pos.short_cost}," - f"现价:{price},操作次数:{len(short_pos.operates)}") - - sym_position = account.position(symbol, PositionSide_Short) - if short_pos.pos == 0 and sym_position and sym_position.volume > 0: - order_target_percent(symbol=symbol, percent=0, position_side=PositionSide_Short, - order_type=OrderType_Limit, price=price, account=account.id) - return - - if not short_pos.pos_changed: - return - - cash_left = cash.available - if short_pos.operates[-1]['op'] in [Operate.SO, Operate.SA1, Operate.SA2]: - change_amount = max_sym_pos * short_pos.operates[-1]['pos_change'] * cash.nav - if cash_left < change_amount: - logger.info(f"{context.now} {symbol} {name} 可用资金不足,无法开空仓;" - f"剩余资金{int(cash_left)}元,所需资金{int(change_amount)}元") - return - - if is_order_exist(context, symbol, PositionSide_Long): - logger.info(f"{context.now} {symbol} {name} 同方向订单已存在") - return - - percent = max_sym_pos * short_pos.pos - order_target_percent(symbol=symbol, percent=percent, position_side=PositionSide_Short, - order_type=OrderType_Limit, price=price, account=account.id) - - -def gm_take_snapshot(gm_symbol, end_dt=None, file_html=None, - freqs=('1分钟', '5分钟', '15分钟', '30分钟', '60分钟', '日线', '周线', '月线'), - adjust=ADJUST_PREV, max_count=1000): - """使用掘金的数据对任意标的、任意时刻的状态进行快照 - - :param gm_symbol: - :param end_dt: - :param file_html: - :param freqs: - :param adjust: - :param max_count: - :return: - """ - if not end_dt: - end_dt = datetime.now().strftime(dt_fmt) - - bg, data = get_init_bg(gm_symbol, end_dt, freqs[0], freqs[1:], max_count, adjust) - ct = CzscAdvancedTrader(bg) - for bar in data: - ct.update(bar) - - if file_html: - ct.take_snapshot(file_html) - print(f'saved into {file_html}') - else: - ct.open_in_browser() - return ct - - -def trader_tactic_snapshot(symbol, strategy: Callable, end_dt=None, file_html=None, adjust=ADJUST_PREV, max_count=1000): - """使用掘金的数据对任意标的、任意时刻的状态进行策略快照 - - :param symbol: 交易标的 - :param strategy: 择时交易策略 - :param end_dt: 结束时间,精确到分钟 - :param file_html: 结果文件 - :param adjust: 复权类型 - :param max_count: 最大K线数量 - :return: trader - """ - tactic = strategy(symbol) - base_freq = tactic['base_freq'] - freqs = tactic['freqs'] - bg, data = get_init_bg(symbol, end_dt, base_freq, freqs, max_count, adjust) - trader = create_advanced_trader(bg, data, strategy) - if file_html: - trader.take_snapshot(file_html) - print(f'saved into {file_html}') - else: - trader.open_in_browser() - return trader - - -def check_index_status(qywx_key): - """查看主要指数状态""" - from czsc.utils.cache import home_path - - wx.push_text(f"{datetime.now()} 开始获取主要指数行情快照", qywx_key) - for gm_symbol in indices.values(): - try: - file_html = os.path.join(home_path, f"{gm_symbol}_{datetime.now().strftime('%Y%m%d')}.html") - gm_take_snapshot(gm_symbol, file_html=file_html) - wx.push_file(file_html, qywx_key) - os.remove(file_html) - except: - traceback.print_exc() - wx.push_text(f"{datetime.now()} 获取主要指数行情快照获取结束,请仔细观察!!!", qywx_key) - - -def realtime_check_index_status(context): - """实盘:发送主要指数行情图表""" - if context.now.isoweekday() > 5: - print(f"realtime_check_index_status: {context.now} 不是交易时间") - return - - check_index_status(context.wx_key) - - -def process_out_of_symbols(context): - """实盘:处理不在交易列表的持仓股""" - if context.now.isoweekday() > 5: - print(f"process_out_of_symbols: {context.now} 不是交易时间") - return - - if context.mode == MODE_BACKTEST: - print(f"process_out_of_symbols: 回测模式下不需要执行") - return - - account = context.account(account_id=context.account_id) - positions = account.positions(symbol="", side=PositionSide_Long) - - oos = [] - for p in positions: - symbol = p.symbol - if p.volume > 0 and p.symbol not in context.symbols_info.keys(): - oos.append(symbol) - # order_target_volume(symbol=symbol, volume=0, position_side=PositionSide_Long, - # order_type=OrderType_Limit, price=p.price, account=account.id) - if oos: - wx.push_text(f"不在交易列表的持仓股:{', '.join(oos)}", context.wx_key) - - -def save_traders(context): - """实盘:保存交易员快照""" - if context.now.isoweekday() > 5: - print(f"save_traders: {context.now} 不是交易时间") - return - - for symbol in context.symbols_info.keys(): - trader: CzscAdvancedTrader = context.symbols_info[symbol]['trader'] - if context.mode != MODE_BACKTEST: - file_trader = os.path.join(context.data_path, f'traders/{symbol}.cat') - dill.dump(trader, open(file_trader, 'wb')) - - -def init_context_universal(context, name): - """通用 context 初始化:1、创建文件目录和日志记录 - - :param context: - :param name: 交易策略名称,建议使用英文 - """ - path_gm_logs = os.environ.get('path_gm_logs', None) - if context.mode == MODE_BACKTEST: - data_path = os.path.join(path_gm_logs, f"backtest/{name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}") - else: - data_path = os.path.join(path_gm_logs, f"realtime/{name}") - os.makedirs(data_path, exist_ok=True) - - context.name = name - context.data_path = data_path - context.stocks = get_stocks() - context.logger = create_logger(os.path.join(data_path, "gm_trader.log"), cmd=True, name="gm") - - context.logger.info("运行配置:") - context.logger.info(f"data_path = {data_path}") - - if context.mode == MODE_BACKTEST: - context.logger.info("backtest_start_time = " + str(context.backtest_start_time)) - context.logger.info("backtest_end_time = " + str(context.backtest_end_time)) - - -def init_context_env(context): - """通用 context 初始化:2、读入环境变量 - - :param context: - """ - context.wx_key = os.environ['wx_key'] - context.account_id = os.environ.get('account_id', '') - if context.mode != MODE_BACKTEST: - assert len(context.account_id) > 10, "非回测模式,必须设置 account_id " - - # 单个标的仓位控制[0, 1],按资金百分比控制,1表示满仓,仅在开仓的时候控制 - context.max_sym_pos = float(os.environ['max_sym_pos']) - assert 0 <= context.max_sym_pos <= 1 - - logger = context.logger - logger.info(f"环境变量读取结果如下:") - logger.info(f"单标的控制:context.max_sym_pos = {context.max_sym_pos}") - - -def init_context_traders(context, symbols: List[str], strategy: Callable): - """通用 context 初始化:3、为每个标的创建 trader - - :param context: - :param symbols: 交易标的列表 - :param strategy: 交易策略 - :return: - """ - with open(os.path.join(context.data_path, f'{strategy.__name__}.txt'), mode='w') as f: - f.write(inspect.getsource(strategy)) - - tactic = strategy("000001") - base_freq, freqs = tactic['base_freq'], tactic['freqs'] - frequency = freq_cn2gm[base_freq] - unsubscribe(symbols='*', frequency=frequency) - - data_path = context.data_path - logger = context.logger - logger.info(f"输入交易标的数量:{len(symbols)}") - logger.info(f"交易员的周期列表:base_freq = {base_freq}; freqs = {freqs}") - - os.makedirs(os.path.join(data_path, 'traders'), exist_ok=True) - symbols_info = {symbol: dict() for symbol in symbols} - for symbol in symbols: - try: - symbols_info[symbol]['max_sym_pos'] = context.max_sym_pos - file_trader = os.path.join(data_path, f'traders/{symbol}.cat') - - if os.path.exists(file_trader) and context.mode != MODE_BACKTEST: - trader: CzscAdvancedTrader = dill.load(open(file_trader, 'rb')) - logger.info(f"{symbol} Loaded Trader from {file_trader}") - - else: - bg, data = get_init_bg(symbol, context.now, base_freq, freqs, 1000, ADJUST_PREV) - trader = create_advanced_trader(bg, data, strategy) - dill.dump(trader, open(file_trader, 'wb')) - - symbols_info[symbol]['trader'] = trader - logger.info("{} Trader 构建成功,最新时间:{},多仓:{}".format(symbol, trader.end_dt, trader.long_pos.pos)) - - except: - del symbols_info[symbol] - logger.info(f"{symbol} - {context.stocks.get(symbol, '无名')} 初始化失败,当前时间:{context.now}") - traceback.print_exc() - - subscribe(",".join(symbols_info.keys()), frequency=frequency, count=300, wait_group=False) - logger.info(f"订阅成功数量:{len(symbols_info)}") - logger.info(f"交易标的配置:{symbols_info}") - context.symbols_info = symbols_info - - -def init_context_schedule(context): - """通用 context 初始化:设置定时任务""" - schedule(schedule_func=report_account_status, date_rule='1d', time_rule='09:31:00') - schedule(schedule_func=report_account_status, date_rule='1d', time_rule='10:01:00') - schedule(schedule_func=report_account_status, date_rule='1d', time_rule='10:31:00') - schedule(schedule_func=report_account_status, date_rule='1d', time_rule='11:01:00') - schedule(schedule_func=report_account_status, date_rule='1d', time_rule='11:31:00') - schedule(schedule_func=report_account_status, date_rule='1d', time_rule='13:01:00') - schedule(schedule_func=report_account_status, date_rule='1d', time_rule='13:31:00') - schedule(schedule_func=report_account_status, date_rule='1d', time_rule='14:01:00') - schedule(schedule_func=report_account_status, date_rule='1d', time_rule='14:31:00') - schedule(schedule_func=report_account_status, date_rule='1d', time_rule='15:01:00') - - # 以下是 实盘/仿真 模式下的定时任务 - if context.mode != MODE_BACKTEST: - schedule(schedule_func=save_traders, date_rule='1d', time_rule='11:40:00') - schedule(schedule_func=save_traders, date_rule='1d', time_rule='15:10:00') - # schedule(schedule_func=realtime_check_index_status, date_rule='1d', time_rule='17:30:00') - # schedule(schedule_func=process_out_of_symbols, date_rule='1d', time_rule='09:40:00') diff --git a/czsc/objects.py b/czsc/objects.py index d97fc48a8..9a5abb000 100644 --- a/czsc/objects.py +++ b/czsc/objects.py @@ -8,11 +8,11 @@ import math from dataclasses import dataclass from datetime import datetime -from typing import List +from loguru import logger +from typing import List, Callable from transitions import Machine from czsc.enum import Mark, Direction, Freq, Operate -from czsc.utils.ta import RSQ - +from czsc.utils.corr import single_linear long_operates = [Operate.HO, Operate.LO, Operate.LA1, Operate.LA2, Operate.LE, Operate.LR1, Operate.LR2] shor_operates = [Operate.HO, Operate.SO, Operate.SA1, Operate.SA2, Operate.SE, Operate.SR1, Operate.SR2] @@ -39,7 +39,7 @@ class RawBar: low: [float, int] vol: [float, int] amount: [float, int] = None - cache: dict = None # cache 用户缓存,一个最常见的场景是缓存技术指标计算结果 + cache: dict = None # cache 用户缓存,一个最常见的场景是缓存技术指标计算结果 @property def upper(self): @@ -70,8 +70,8 @@ class NewBar: low: [float, int] vol: [float, int] amount: [float, int] = None - elements: List = None # 存入具有包含关系的原始K线 - cache: dict = None # cache 用户缓存 + elements: List = None # 存入具有包含关系的原始K线 + cache: dict = None # cache 用户缓存 @property def raw_bars(self): @@ -87,7 +87,7 @@ class FX: low: [float, int] fx: [float, int] elements: List = None - cache: dict = None # cache 用户缓存 + cache: dict = None # cache 用户缓存 @property def new_bars(self): @@ -149,7 +149,7 @@ class FakeBI: high: [float, int] low: [float, int] power: [float, int] - cache: dict = None # cache 用户缓存 + cache: dict = None # cache 用户缓存 def create_fake_bis(fxs: List[FX]) -> List[FakeBI]: @@ -163,15 +163,15 @@ def create_fake_bis(fxs: List[FX]) -> List[FakeBI]: fake_bis = [] for i in range(1, len(fxs)): - fx1 = fxs[i-1] + fx1 = fxs[i - 1] fx2 = fxs[i] assert fx1.mark != fx2.mark if fx1.mark == Mark.D: fake_bi = FakeBI(symbol=fx1.symbol, sdt=fx1.dt, edt=fx2.dt, direction=Direction.Up, - high=fx2.high, low=fx1.low, power=round(fx2.high-fx1.low, 2)) + high=fx2.high, low=fx1.low, power=round(fx2.high - fx1.low, 2)) elif fx1.mark == Mark.G: fake_bi = FakeBI(symbol=fx1.symbol, sdt=fx1.dt, edt=fx2.dt, direction=Direction.Down, - high=fx1.high, low=fx2.low, power=round(fx1.high-fx2.low, 2)) + high=fx1.high, low=fx2.low, power=round(fx1.high - fx2.low, 2)) else: raise ValueError fake_bis.append(fake_bi) @@ -181,9 +181,9 @@ def create_fake_bis(fxs: List[FX]) -> List[FakeBI]: @dataclass class BI: symbol: str - fx_a: FX = None # 笔开始的分型 - fx_b: FX = None # 笔结束的分型 - fxs: List = None # 笔内部的分型列表 + fx_a: FX = None # 笔开始的分型 + fx_b: FX = None # 笔结束的分型 + fxs: List = None # 笔内部的分型列表 direction: Direction = None bars: List[NewBar] = None cache: dict = None # cache 用户缓存 @@ -196,15 +196,57 @@ def __repr__(self): return f"BI(symbol={self.symbol}, sdt={self.sdt}, edt={self.edt}, " \ f"direction={self.direction}, high={self.high}, low={self.low})" + def get_cache_with_default(self, key, default: Callable): + """带有默认值计算的缓存读取 + + :param key: 缓存 key + :param default: 如果没有缓存数据,用来计算默认值并更新缓存的函数 + :return: + """ + cache = self.cache if self.cache else {} + value = cache.get(key, None) + if not value: + value = default() + cache[key] = value + self.cache = cache + return value + + def get_price_linear(self, price_key="close"): + """计算 price 的单变量线性回归特征 + + :param price_key: 指定价格类型,可选值 open close high low + :return value: 单变量线性回归特征,样例如下 + {'slope': 1.565, 'intercept': 67.9783, 'r2': 0.9967} + + slope 标识斜率 + intercept 截距 + r2 拟合优度 + """ + cache = self.cache if self.cache else {} + key = f"{price_key}_linear_info" + value = cache.get(key, None) + + if not value: + value = single_linear([x.__dict__[price_key] for x in self.raw_bars]) + cache[key] = value + self.cache = cache + return value + # 定义一些附加属性,用的时候才会计算,提高效率 # ====================================================================== @property def fake_bis(self): - return create_fake_bis(self.fxs) + """笔的内部分型连接得到近似次级别笔列表""" + + def __default(): return create_fake_bis(self.fxs) + + return self.get_cache_with_default('fake_bis', __default) @property def high(self): - return max(self.fx_a.high, self.fx_b.high) + def __default(): return max(self.fx_a.high, self.fx_b.high) + + return self.get_cache_with_default('high', __default) @property def low(self): @@ -237,22 +279,26 @@ def length(self): @property def rsq(self): - """笔的斜率""" - close = [x.close for x in self.raw_bars] - return round(RSQ(close), 4) + """笔的原始K线 close 单变量线性回归 r2""" + value = self.get_price_linear('close') + return round(value['r2'], 4) @property def raw_bars(self): """构成笔的原始K线序列""" - x = [] - for bar in self.bars[1:-1]: - x.extend(bar.raw_bars) - return x + + def __default(): + value = [] + for bar in self.bars[1:-1]: + value.extend(bar.raw_bars) + return value + + return self.get_cache_with_default('raw_bars', __default) @property def hypotenuse(self): """笔的斜边长度""" - return pow(pow(self.power_price, 2) + pow(len(self.raw_bars), 2), 1/2) + return pow(pow(self.power_price, 2) + pow(len(self.raw_bars), 2), 1 / 2) @property def angle(self): @@ -265,7 +311,7 @@ class ZS: """中枢对象,主要用于辅助信号函数计算""" symbol: str bis: List[BI] - cache: dict = None # cache 用户缓存 + cache: dict = None # cache 用户缓存 @property def sdt(self): @@ -315,6 +361,7 @@ def __repr__(self): f"len_bis={len(self.bis)}, zg={self.zg}, zd={self.zd}, " \ f"gg={self.gg}, dd={self.dd}, zz={self.zz})" + @dataclass class Signal: signal: str = None @@ -413,6 +460,35 @@ def is_match(self, s: dict) -> bool: return True return False + def dump(self) -> dict: + """将 Factor 对象转存为 dict""" + raw = { + "name": self.name, + "signals_all": [x.signal for x in self.signals_all], + "signals_any": [] if not self.signals_any else [x.signal for x in self.signals_any], + "signals_not": [] if not self.signals_not else [x.signal for x in self.signals_not], + } + return raw + + @classmethod + def load(cls, raw: dict): + """从 dict 中创建 Factor + + :param raw: 样例如下 + {'name': '单测', + 'signals_all': ['15分钟_倒0笔_方向_向上_其他_其他_0', '15分钟_倒0笔_长度_大于5_其他_其他_0'], + 'signals_any': [], + 'signals_not': []} + + :return: + """ + fa = Factor(name=raw['name'], + signals_all=[Signal(x) for x in raw['signals_all']], + signals_any=[Signal(x) for x in raw['signals_any']] if raw['signals_any'] else None, + signals_not=[Signal(x) for x in raw['signals_not']] if raw['signals_not'] else None + ) + return fa + @dataclass class Event: @@ -465,6 +541,42 @@ def is_match(self, s: dict): return False, None + def dump(self) -> dict: + """将 Event 对象转存为 dict""" + raw = { + "name": self.name, + "operate": self.operate.value, + "signals_all": [] if not self.signals_all else [x.signal for x in self.signals_all], + "signals_any": [] if not self.signals_any else [x.signal for x in self.signals_any], + "signals_not": [] if not self.signals_not else [x.signal for x in self.signals_not], + "factors": [x.dump() for x in self.factors], + } + return raw + + @classmethod + def load(cls, raw: dict): + """从 dict 中创建 Event + + :param raw: 样例如下 + {'name': '单测', + 'operate': '开多', + 'factors': [{'name': '测试', + 'signals_all': ['15分钟_倒0笔_长度_大于5_其他_其他_0'], + 'signals_any': [], + 'signals_not': []}], + 'signals_all': ['15分钟_倒0笔_方向_向上_其他_其他_0'], + 'signals_any': [], + 'signals_not': []} + :return: + """ + e = Event(name=raw['name'], operate=Operate.__dict__["_value2member_map_"][raw['operate']], + factors=[Factor.load(x) for x in raw['factors']], + signals_all=[Signal(x) for x in raw['signals_all']] if raw['signals_all'] else None, + signals_any=[Signal(x) for x in raw['signals_any']] if raw['signals_any'] else None, + signals_not=[Signal(x) for x in raw['signals_not']] if raw['signals_not'] else None + ) + return e + def cal_break_even_point(seq: List[float]) -> float: """计算单笔收益序列的盈亏平衡点 @@ -577,9 +689,9 @@ def __init__(self, symbol: str, self.operates = [] self.last_pair_operates = [] self.pairs = [] - self.long_high = -1 # 持多仓期间出现的最高价 - self.long_cost = -1 # 最近一次加多仓的成本 - self.long_bid = -1 # 最近一次加多仓的1分钟Bar ID + self.long_high = -1 # 持多仓期间出现的最高价 + self.long_cost = -1 # 最近一次加多仓的成本 + self.long_bid = -1 # 最近一次加多仓的1分钟Bar ID self.today = None self.today_pos = 0 @@ -615,7 +727,7 @@ def operates_to_pair(self, operates): '持仓K线数': operates[-1]['bid'] - operates[0]['bid'], '事件序列': " > ".join([x['op_desc'] for x in operates]), } - pair['持仓天数'] = (pair['平仓时间'] - pair['开仓时间']).total_seconds() / (24*3600) + pair['持仓天数'] = (pair['平仓时间'] - pair['开仓时间']).total_seconds() / (24 * 3600) pair['盈亏金额'] = pair['累计平仓'] - pair['累计开仓'] # 注意:【交易盈亏】的计算是对交易进行的,不是对账户,所以不能用来统计账户的收益 pair['交易盈亏'] = int((pair['盈亏金额'] / pair['累计开仓']) * 10000) / 10000 @@ -756,9 +868,9 @@ def __init__(self, symbol: str, self.operates = [] self.last_pair_operates = [] self.pairs = [] - self.short_low = -1 # 持多仓期间出现的最低价 - self.short_cost = -1 # 最近一次加空仓的成本 - self.short_bid = -1 # 最近一次加空仓的1分钟Bar ID + self.short_low = -1 # 持多仓期间出现的最低价 + self.short_cost = -1 # 最近一次加空仓的成本 + self.short_bid = -1 # 最近一次加空仓的1分钟Bar ID self.today = None self.today_pos = 0 @@ -794,7 +906,7 @@ def operates_to_pair(self, operates): '持仓K线数': operates[-1]['bid'] - operates[0]['bid'], '事件序列': " > ".join([x['op_desc'] for x in operates]), } - pair['持仓天数'] = (pair['平仓时间'] - pair['开仓时间']).total_seconds() / (24*3600) + pair['持仓天数'] = (pair['平仓时间'] - pair['开仓时间']).total_seconds() / (24 * 3600) # 空头计算盈亏,需要取反 pair['盈亏金额'] = -(pair['累计平仓'] - pair['累计开仓']) # 注意:【交易盈亏】的计算是对交易进行的,不是对账户,所以不能用来统计账户的收益 @@ -895,39 +1007,69 @@ def update(self, dt: datetime, op: Operate, price: float, bid: int, op_desc: str class Position: - def __init__(self, symbol: str, - events: List[Event], - hold_a: float = 0.5, - hold_b: float = 0.8, - hold_c: float = 1.0, - min_interval: int = None, - cost: float = 0.003, - T0: bool = False): - """空头持仓对象 + def __init__(self, symbol: str, opens: List[Event], exits: List[Event] = None, interval: int = 0, + timeout: int = 1000, stop_loss=1000, T0: bool = False, name=None): + """简单持仓对象,仓位表达:1 持有多头,-1 持有空头,0 空仓 :param symbol: 标的代码 - :param hold_a: 首次开仓后的仓位 - :param hold_b: 第一次加仓后的仓位 - :param hold_c: 第二次加仓的仓位 - :param min_interval: 两次开空仓之间的最小时间间隔,单位:秒 - :param cost: 双边交易成本,默认为千分之三 + :param opens: 开仓交易事件列表 + :param exits: 平仓交易事件列表,允许为空 + :param interval: 同类型开仓间隔时间,单位:秒;默认值为 0,表示同类型开仓间隔没有约束 + 假设上次开仓为多头,那么下一次多头开仓时间必须大于 上次开仓时间 + interval;空头也是如此。 + :param timeout: 最大允许持仓K线数量限制为最近一个开仓事件触发后的 timeout 根基础周期K线 + :param stop_loss: 最大允许亏损比例,单位:BP, 1BP = 0.01%;成本的计算以最近一个开仓事件触发价格为准 :param T0: 是否允许T0交易,默认为 False 表示不允许T0交易 + :param name: 仓位名称,默认值为第一个开仓事件的名称 """ - assert 0 <= hold_a <= hold_b <= hold_c <= 1.0 - if events[0].operate in long_operates: - for event in events: - assert event.operate in long_operates - self._position = PositionLong(symbol, hold_a, hold_b, hold_c, min_interval, cost, T0) - else: - for event in events: - assert event.operate in shor_operates - self._position = PositionShort(symbol, hold_a, hold_b, hold_c, min_interval, cost, T0) - self.events = events + self.symbol = symbol + self.opens = opens + self.name = name if name else opens[0].name + self.exits = exits if exits else [] + self.events = self.opens + self.exits + for event in self.events: + assert event.operate in [Operate.LO, Operate.LE, Operate.SO, Operate.SE] + + self.interval = interval + self.timeout = timeout + self.stop_loss = stop_loss + self.T0 = T0 + + self.operates = [] # 事件触发的操作列表 + self.holds = [] # 持仓状态列表 + self.pos = 0 + + # 辅助判断的缓存数据 + self.last_event = {'dt': None, 'bid': None, 'price': None, "op": None, 'op_desc': None} + self.last_lo_dt = None # 最近一次开多交易的时间 + self.last_so_dt = None # 最近一次开空交易的时间 + self.end_dt = None # 最近一次信号传入的时间 + + def __two_operates_pair(self, op1, op2): + assert op1['op'] in [Operate.LO, Operate.SO] + pair = { + '标的代码': self.symbol, + '交易方向': "多头" if op1['op'] == Operate.LO else "空头", + '开仓时间': op1['dt'], + '平仓时间': op2['dt'], + '开仓价格': op1['price'], + '平仓价格': op2['price'], + '持仓K线数': op2['bid'] - op1['bid'], + '事件序列': f"{op1['op_desc']} -> {op2['op_desc']}", + '持仓天数': (op2['dt'] - op1['dt']).total_seconds() / (24 * 3600), + '盈亏比例': op2['price'] / op1['price'] - 1 if op1['op'] == Operate.LO else 1 - op2['price'] / op1['price'], + } + # 盈亏比例 转换成以 BP 为单位的收益,1BP = 0.0001 + pair['盈亏比例'] = round(pair['盈亏比例'] * 10000, 2) + return pair @property - def pos(self): - """返回状态对应的仓位""" - return self._position.pos + def pairs(self): + """开平交易列表""" + pairs = [] + for op1, op2 in zip(self.operates, self.operates[1:]): + if op1['op'] in [Operate.LO, Operate.SO]: + pairs.append(self.__two_operates_pair(op1, op2)) + return pairs def update(self, s: dict): """更新持仓状态 @@ -935,14 +1077,90 @@ def update(self, s: dict): :param s: 最新信号字典 :return: """ + if self.end_dt and s['dt'] <= self.end_dt: + logger.warning(f"请检查信号传入:最新信号时间{s['dt']}在上次信号时间{self.end_dt}之前") + return + op = Operate.HO op_desc = "" - for event in self.events: m, f = event.is_match(s) if m: op = event.operate op_desc = f"{event.name}@{f}" break - dt, price, bid = s['dt'], s['close'], s['bid'] - self._position.update(dt, op, price, bid, op_desc) + + symbol, dt, price, bid = s['symbol'], s['dt'], s['close'], s['id'] + self.end_dt = dt + + # 当有新的开仓 event 发生,更新 last_event + if op in [Operate.LO, Operate.SO]: + self.last_event = {'dt': dt, 'bid': bid, 'price': price, 'op': op, 'op_desc': op_desc} + + def __create_operate(_op, _op_desc): + return {'symbol': self.symbol, 'dt': dt, 'bid': bid, 'price': price, + 'op': _op, 'op_desc': _op_desc, 'pos': self.pos} + + # 更新仓位 + if op == Operate.LO: + if not self.last_lo_dt or (dt - self.last_lo_dt).total_seconds() > self.interval: + # 与前一次开多间隔时间大于 interval,直接开多 + self.pos = 1 + self.operates.append(__create_operate(Operate.LO, op_desc)) + self.last_lo_dt = dt + else: + # 与前一次开多间隔时间小于 interval,仅对空头平仓 + if self.pos == -1 and (self.T0 or dt.date() != self.last_lo_dt.date()): + self.pos = 0 + self.operates.append(__create_operate(Operate.SE, op_desc)) + + if op == Operate.SO: + if not self.last_so_dt or (dt - self.last_so_dt).total_seconds() > self.interval: + # 与前一次开空间隔时间大于 interval,直接开空 + self.pos = -1 + self.operates.append(__create_operate(Operate.SO, op_desc)) + self.last_so_dt = dt + else: + # 与前一次开空间隔时间小于 interval,仅对多头平仓 + if self.pos == 1 and (self.T0 or dt.date() != self.last_so_dt.date()): + self.pos = 0 + self.operates.append(__create_operate(Operate.LE, op_desc)) + + # 多头出场 + if self.pos == 1 and (self.T0 or dt.date() != self.last_lo_dt.date()): + assert self.last_event['dt'] >= self.last_lo_dt + + # 多头平仓 + if op == Operate.LE: + self.pos = 0 + self.operates.append(__create_operate(Operate.LE, op_desc)) + + # 多头止损 + if price / self.last_event['price'] - 1 < -self.stop_loss / 10000: + self.pos = 0 + self.operates.append(__create_operate(Operate.LE, f"平多@{self.stop_loss}BP止损")) + + # 多头超时 + if bid - self.last_event['bid'] > self.timeout: + self.pos = 0 + self.operates.append(__create_operate(Operate.LE, f"平多@{self.timeout}K超时")) + + # 空头出场 + if self.pos == -1 and (self.T0 or dt.date() != self.last_so_dt.date()): + assert self.last_event['dt'] >= self.last_so_dt + + # 空头平仓 + if op == Operate.SE: + self.pos = 0 + self.operates.append(__create_operate(Operate.SE, op_desc)) + + # 空头止损 + if 1 - price / self.last_event['price'] < -self.stop_loss / 10000: + self.pos = 0 + self.operates.append(__create_operate(Operate.SE, f"平空@{self.stop_loss}BP止损")) + + # 空头超时 + if bid - self.last_event['bid'] > self.timeout: + self.pos = 0 + self.operates.append(__create_operate(Operate.SE, f"平空@{self.timeout}K超时")) + diff --git a/czsc/sensors/__init__.py b/czsc/sensors/__init__.py index 7fe242aa8..ceaf9d614 100644 --- a/czsc/sensors/__init__.py +++ b/czsc/sensors/__init__.py @@ -12,7 +12,6 @@ generate_signals, generate_stocks_signals, generate_symbol_signals, - read_cached_signals, turn_over_rate, discretizer, compound_returns, diff --git a/czsc/sensors/utils.py b/czsc/sensors/utils.py index 8005ddbb0..173810797 100644 --- a/czsc/sensors/utils.py +++ b/czsc/sensors/utils.py @@ -255,6 +255,8 @@ def turn_over_rate(df_holds: pd.DataFrame) -> [pd.DataFrame, float]: def compound_returns(n1b: List): """复利收益计算 + 等价于: `np.cumprod(np.array(n1b) / 10000 + 1) * 10000 - 10000` + :param n1b: 逐个结算周期的收益列表,单位:BP,换算关系是 10000BP = 100% 如,n1b = [100.1, -90.5, 212.6],表示第一个结算周期收益为100.1BP,也就是1.001%,以此类推。 :return: 累计复利收益,逐个结算周期的复利收益 @@ -267,57 +269,6 @@ def compound_returns(n1b: List): return v-10000, detail -def read_cached_signals(file_output: str, path_pat=None, sdt=None, edt=None, keys=None) -> pd.DataFrame: - """读取缓存信号 - - :param file_output: 读取后保存结果 - :param path_pat: 缓存信号文件路径模板,用于glob获取文件列表 - :param keys: 需要读取的信号名称列表 - :param sdt: 开始时间 - :param edt: 结束时间 - :return: 信号 - """ - verbose = envs.get_verbose() - - if os.path.exists(file_output): - sf = pd.read_pickle(file_output) - if verbose: - print(f"read_cached_signals: read from {file_output}, 数据占用内存大小" - f":{int(sf.memory_usage(deep=True).sum() / (1024 * 1024))} MB") - return sf - - files = glob.glob(path_pat, recursive=False) - results = [] - for file in tqdm(files, desc="read_cached_signals"): - df = pd.read_pickle(file) - if not df.empty: - if keys: - base_cols = [x for x in df.columns if len(x.split("_")) != 3] - df = df[base_cols + keys] - if sdt: - df = df[df['dt'] >= pd.to_datetime(sdt)] - if edt: - df = df[df['dt'] <= pd.to_datetime(edt)] - results.append(df) - else: - print(f"read_cached_signals: {file} is empty") - - sf = pd.concat(results, ignore_index=True) - if verbose: - print(f"read_cached_signals: 原始数据占用内存大小:{int(sf.memory_usage(deep=True).sum() / (1024 * 1024))} MB") - - c_cols = [k for k, v in sf.dtypes.to_dict().items() if v.name.startswith('object')] - sf[c_cols] = sf[c_cols].astype('category') - - float_cols = [k for k, v in sf.dtypes.to_dict().items() if v.name.startswith('float')] - sf[float_cols] = sf[float_cols].astype('float32') - if verbose: - print(f"read_cached_signals: 转类型后占用内存大小:{int(sf.memory_usage(deep=True).sum() / (1024 * 1024))} MB") - - sf.to_pickle(file_output, protocol=4) - return sf - - def generate_symbol_signals(dc: TsDataCache, ts_code: str, asset: str, @@ -412,21 +363,18 @@ def generate_stocks_signals(dc: TsDataCache, class SignalsPerformance: """信号表现分析""" - def __init__(self, dfs: pd.DataFrame, keys: List[AnyStr], dc: TsDataCache = None, base_freq="日线"): + def __init__(self, dfs: pd.DataFrame, keys: List[AnyStr]): """ :param dfs: 信号表 :param keys: 信号列,支持一个或多个信号列组合分析 - :param dc: Tushare 数据缓存对象 - :param base_freq: 信号对应的K线基础周期 """ if 'year' not in dfs.columns: - dfs['year'] = dfs['dt'].apply(lambda x: x.year) + y = dfs['dt'].apply(lambda x: x.year) + dfs['year'] = y.values self.dfs = dfs self.keys = keys - self.dc = dc - self.base_freq = base_freq self.b_cols = [x for x in dfs.columns if x[0] == 'b' and x[-1] == 'b'] self.n_cols = [x for x in dfs.columns if x[0] == 'n' and x[-1] == 'b'] @@ -451,7 +399,7 @@ def __return_performance(self, dfs: pd.DataFrame, mode: str = '1b') -> pd.DataFr edt = dfs['dt'].max().strftime("%Y%m%d") def __static(_df, _name): - _res = {"name": _name, "sdt": sdt, "edt": edt, + _res = {"name": _name, "date_span": f"{sdt} ~ {edt}", "count": len(_df), "cover": round(len(_df) / len_dfs, 4)} if mode.startswith('0'): _r = _df.groupby('dt')[cols].mean().mean().to_dict() @@ -462,7 +410,7 @@ def __static(_df, _name): results = [__static(dfs, "基准")] - for values, dfg in dfs.groupby(keys): + for values, dfg in dfs.groupby(by=keys if len(keys) > 1 else keys[0]): if isinstance(values, str): values = [values] assert len(keys) == len(values) @@ -474,7 +422,7 @@ def __static(_df, _name): dfr[cols] = dfr[cols].round(2) return dfr - def analyze_return(self, mode='0b') -> pd.DataFrame: + def analyze(self, mode='0b') -> pd.DataFrame: """分析信号出现前后的收益情况 :param mode: 分析模式, @@ -492,138 +440,11 @@ def analyze_return(self, mode='0b') -> pd.DataFrame: dfr = pd.concat(results, ignore_index=True) return dfr - def __corr_index(self, dfs: pd.DataFrame, index: str): - """分析信号每天出现的次数与指数的相关性""" - dc = self.dc - base_freq = self.base_freq - keys = self.keys - n_cols = self.n_cols - freq = freq_cn2ts[base_freq] - sdt = dfs['dt'].min().strftime("%Y%m%d") - edt = dfs['dt'].max().strftime("%Y%m%d") - adj = 'hfq' - asset = "I" - - if "分钟" in base_freq: - dfi = dc.pro_bar_minutes(index, sdt, edt, freq, asset, adj, raw_bar=False) - dfi['dt'] = pd.to_datetime(dfi['trade_time']) - else: - dfi = dc.pro_bar(index, sdt, edt, freq, asset, adj, raw_bar=False) - dfi['dt'] = pd.to_datetime(dfi['trade_date']) - - results = [] - for values, dfg in dfs.groupby(keys): - if isinstance(values, str): - values = [values] - assert len(keys) == len(values) - name = "#".join([f"{key1}_{name1}" for key1, name1 in zip(keys, values)]) - c = dfg.groupby("dt")['symbol'].count() - c_col = f'{name}_count' - dfc = pd.DataFrame({'dt': c.index, c_col: c.values}) - df_ = dfi.merge(dfc, on=['dt'], how='left') - df_[c_col] = df_[c_col].fillna(0) - - res = {"name": name, 'sdt': sdt, 'edt': edt, 'index': index} - corr_ = df_[[c_col] + n_cols].corr(method='spearman').iloc[0][n_cols].round(4).to_dict() - res.update(corr_) - results.append(res) - df_corr = pd.DataFrame(results) - return df_corr - - def analyze_corr_index(self, index: str) -> pd.DataFrame: - """分析信号出现前后的收益情况 - - :param index: Tushare 指数代码,如 000905.SH 表示中证500 - :return: - """ - dfr = self.__corr_index(self.dfs, index) - results = [dfr] - for year, df_ in self.dfs.groupby('year'): - dfr_ = self.__corr_index(df_, index) - results.append(dfr_) - dfr = pd.concat(results, ignore_index=True) - return dfr - - def __ar_counts(self, dfs: pd.DataFrame): - """分析信号每天出现的次数与自身收益的相关性""" - keys = self.keys - n_cols = self.n_cols - sdt = dfs['dt'].min().strftime("%Y%m%d") - edt = dfs['dt'].max().strftime("%Y%m%d") - - results = [] - for values, dfg in dfs.groupby(keys): - if isinstance(values, str): - values = [values] - assert len(keys) == len(values) - name = "#".join([f"{key1}_{name1}" for key1, name1 in zip(keys, values)]) - c = dfg.groupby("dt")['symbol'].count() - n_bars = dfg.groupby("dt")[n_cols].mean() - n_bars['count'] = c - res_ = {"name": name, 'sdt': sdt, 'edt': edt} - corr_ = n_bars[['count'] + n_cols].corr(method='spearman').iloc[0][n_cols].round(4).to_dict() - res_.update(corr_) - results.append(res_) - dfr = pd.DataFrame(results) - return dfr - - def analyze_ar_counts(self) -> pd.DataFrame: - """分析信号每天出现的次数与自身收益的相关性""" - dfr = self.__ar_counts(self.dfs) - results = [dfr] - for year, df_ in self.dfs.groupby('year'): - dfr_ = self.__ar_counts(df_) - results.append(dfr_) - dfr = pd.concat(results, ignore_index=True) - return dfr - - def __b_bar(self, dfs: pd.DataFrame, b_col='b21b'): - """分析信号出现前的收益与出现后收益的相关性""" - keys = self.keys - n_cols = self.n_cols - sdt = dfs['dt'].min().strftime("%Y%m%d") - edt = dfs['dt'].max().strftime("%Y%m%d") - - results = [] - for values, dfg in dfs.groupby(keys): - if isinstance(values, str): - values = [values] - assert len(keys) == len(values) - name = "#".join([f"{key1}_{name1}" for key1, name1 in zip(keys, values)]) - n_bars = dfg.groupby("dt")[[b_col] + n_cols].mean() - res_ = {"name": name, 'sdt': sdt, 'edt': edt, 'b_col': b_col} - corr_ = n_bars[[b_col] + n_cols].corr(method='spearman').iloc[0][n_cols].round(4).to_dict() - res_.update(corr_) - results.append(res_) - dfr = pd.DataFrame(results) - return dfr - - def analyze_b_bar(self, b_col='b21b') -> pd.DataFrame: - """分析信号出现前的收益与出现后收益的相关性""" - dfr = self.__b_bar(self.dfs, b_col) - results = [dfr] - for year, df_ in self.dfs.groupby('year'): - dfr_ = self.__b_bar(df_, b_col) - results.append(dfr_) - dfr = pd.concat(results, ignore_index=True) - return dfr - def report(self, file_xlsx=None): res = { - '向前看截面': self.analyze_return('0b'), - '向后看截面': self.analyze_return('0n'), - '向前看时序': self.analyze_return('1b'), - '向后看时序': self.analyze_return('1n'), - - '信号数量与自身收益相关性': self.analyze_ar_counts(), + '向后看截面': self.analyze('0n'), + '向后看时序': self.analyze('1n'), } - - if self.dc: - res.update({ - '信号数量与上证50相关性': self.analyze_corr_index('000016.SH'), - '信号数量与中证500相关性': self.analyze_corr_index('000905.SH'), - '信号数量与沪深300相关性': self.analyze_corr_index('000300.SH'), - }) if file_xlsx: writer = pd.ExcelWriter(file_xlsx) for sn, df_ in res.items(): diff --git a/czsc/signals/__init__.py b/czsc/signals/__init__.py index 0b91b7d50..187964908 100644 --- a/czsc/signals/__init__.py +++ b/czsc/signals/__init__.py @@ -26,6 +26,8 @@ cxt_first_buy_V221126, cxt_first_sell_V221126, cxt_bi_break_V221126, + cxt_sub_b3_V221212, + cxt_zhong_shu_gong_zhen_V221221, ) diff --git a/czsc/signals/bar.py b/czsc/signals/bar.py index e7ace0d19..f5b603432 100644 --- a/czsc/signals/bar.py +++ b/czsc/signals/bar.py @@ -8,8 +8,9 @@ import numpy as np from datetime import datetime from typing import List +from loguru import logger from collections import OrderedDict -from czsc import CZSC, Signal, CzscAdvancedTrader +from czsc import envs, CZSC, Signal, CzscAdvancedTrader from czsc.objects import RawBar from czsc.utils import check_pressure_support, get_sub_elements @@ -251,15 +252,20 @@ def bar_mean_amount_V221112(c: CZSC, di: int = 1, n: int = 10, th1: int = 1, th2 """ k1, k2, k3 = str(c.freq.value), f"D{di}K{n}B均额", f"{th1}至{th2}千万" - if len(c.bars_raw) < di + n + 5: - v1 = "其他" - - else: - bars = get_sub_elements(c.bars_raw, di=di, n=n) - assert len(bars) == n - - m = sum([x.amount for x in bars]) / n - v1 = "是" if th2 >= m / 10000000 >= th1 else "否" + v1 = "其他" + if len(c.bars_raw) > di + n + 5: + try: + bars = get_sub_elements(c.bars_raw, di=di, n=n) + assert len(bars) == n + m = sum([x.amount for x in bars]) / n + v1 = "是" if th2 >= m / 10000000 >= th1 else "否" + + except Exception as e: + msg = f"{c.symbol} - {c.bars_raw[-1].dt} fail: {e}" + if envs.get_verbose(): + logger.exception(msg) + else: + logger.warning(msg) s = OrderedDict() signal = Signal(k1=k1, k2=k2, k3=k3, v1=v1) diff --git a/czsc/signals/cxt.py b/czsc/signals/cxt.py index 09a32b2d7..c9e5c1a92 100644 --- a/czsc/signals/cxt.py +++ b/czsc/signals/cxt.py @@ -297,3 +297,56 @@ def cxt_sub_b3_V221212(cat: CzscAdvancedTrader, freq='60分钟', sub_freq='15分 signal = Signal(k1=k1, k2=k2, k3=k3, v1=v1) s[signal.key] = signal.value return s + + +def cxt_zhong_shu_gong_zhen_V221221(cat: CzscAdvancedTrader, freq1='日线', freq2='60分钟') -> OrderedDict: + """大小级别中枢共振,类二买共振;贡献者:琅盎 + + **信号逻辑:** + + 1. 不区分上涨或下跌中枢 + 2. 次级别中枢 DD 大于本级别中枢中轴 + 3. 次级别向下笔出底分型开多;反之看空 + + **信号列表:** + + - Signal('日线_60分钟_中枢共振_看多_任意_任意_0') + - Signal('日线_60分钟_中枢共振_看空_任意_任意_0') + + :param cat: + :param freq1:大级别周期 + :param freq2: 小级别周期 + :return: 信号识别结果 + """ + k1, k2, k3 = f"{freq1}_{freq2}_中枢共振".split('_') + + max_freq: CZSC = cat.kas[freq1] + min_freq: CZSC = cat.kas[freq2] + symbol = cat.symbol + + def __is_zs(_bis): + _zs = ZS(symbol=symbol, bis=_bis) + if _zs.zd < _zs.zg: + return True + else: + return False + + v1 = "其他" + if len(max_freq.bi_list) >= 5 and __is_zs(max_freq.bi_list[-3:]) \ + and len(min_freq.bi_list) >= 5 and __is_zs(min_freq.bi_list[-3:]): + + big_zs = ZS(symbol=symbol, bis=max_freq.bi_list[-3:]) + small_zs = ZS(symbol=symbol, bis=min_freq.bi_list[-3:]) + + if small_zs.dd > big_zs.zz and min_freq.bi_list[-1].direction == Direction.Down: + v1 = "看多" + + if small_zs.gg < big_zs.zz and min_freq.bi_list[-1].direction == Direction.Up: + v1 = "看空" + + s = OrderedDict() + signal = Signal(k1=k1, k2=k2, k3=k3, v1=v1) + s[signal.key] = signal.value + return s + + diff --git a/czsc/signals/tas.py b/czsc/signals/tas.py index 5bcd50e48..a9e4e7ac3 100644 --- a/czsc/signals/tas.py +++ b/czsc/signals/tas.py @@ -8,6 +8,7 @@ tas = ta-lib signals 的缩写 """ from loguru import logger + try: import talib as ta except: @@ -20,7 +21,7 @@ from collections import OrderedDict -def update_ma_cache(c: CZSC, ma_type: str, timeperiod: int, **kwargs) -> None: +def update_ma_cache(c: CZSC, ma_type: str, timeperiod: int, **kwargs): """更新均线缓存 :param c: CZSC对象 @@ -36,30 +37,39 @@ def update_ma_cache(c: CZSC, ma_type: str, timeperiod: int, **kwargs) -> None: 'TEMA': ta.MA_Type.TEMA, 'DEMA': ta.MA_Type.DEMA, 'MAMA': ta.MA_Type.MAMA, - 'T3': ta.MA_Type.T3, 'TRIMA': ta.MA_Type.TRIMA, } - min_count = timeperiod + ma_type = ma_type.upper() + assert ma_type in ma_type_map.keys(), f"{ma_type} 不是支持的均线类型,可选值:{list(ma_type_map.keys())}" cache_key = f"{ma_type.upper()}{timeperiod}" + if c.bars_raw[-1].cache and c.bars_raw[-1].cache.get(cache_key, None): + # 如果最后一根K线已经有对应的缓存,不执行更新 + return cache_key + last_cache = dict(c.bars_raw[-2].cache) if c.bars_raw[-2].cache else dict() - if cache_key not in last_cache.keys() or len(c.bars_raw) < min_count + 5: + if cache_key not in last_cache.keys() or len(c.bars_raw) < timeperiod + 15: # 初始化缓存 close = np.array([x.close for x in c.bars_raw]) - min_count = 0 + ma = ta.MA(close, timeperiod=timeperiod, matype=ma_type_map[ma_type.upper()]) + assert len(ma) == len(close) + for i in range(len(close)): + _c = dict(c.bars_raw[i].cache) if c.bars_raw[i].cache else dict() + _c.update({cache_key: ma[i] if ma[i] else close[i]}) + c.bars_raw[i].cache = _c + else: - # 增量更新缓存 + # 增量更新最近5个K线缓存 close = np.array([x.close for x in c.bars_raw[-timeperiod - 10:]]) - - ma = ta.MA(close, timeperiod=timeperiod, matype=ma_type_map[ma_type.upper()]) - - for i in range(1, len(close) - min_count - 5): - _c = dict(c.bars_raw[-i].cache) if c.bars_raw[-i].cache else dict() - _c.update({cache_key: ma[-i]}) - c.bars_raw[-i].cache = _c + ma = ta.MA(close, timeperiod=timeperiod, matype=ma_type_map[ma_type.upper()]) + for i in range(1, 6): + _c = dict(c.bars_raw[-i].cache) if c.bars_raw[-i].cache else dict() + _c.update({cache_key: ma[-i]}) + c.bars_raw[-i].cache = _c + return cache_key -def update_macd_cache(c: CZSC, **kwargs) -> None: +def update_macd_cache(c: CZSC, **kwargs): """更新MACD缓存 :param c: CZSC对象 @@ -69,24 +79,38 @@ def update_macd_cache(c: CZSC, **kwargs) -> None: slowperiod = kwargs.get('slowperiod', 26) signalperiod = kwargs.get('signalperiod', 9) - min_count = fastperiod + slowperiod cache_key = f"MACD" + if c.bars_raw[-1].cache and c.bars_raw[-1].cache.get(cache_key, None): + # 如果最后一根K线已经有对应的缓存,不执行更新 + return cache_key + + min_count = signalperiod + slowperiod last_cache = dict(c.bars_raw[-2].cache) if c.bars_raw[-2].cache else dict() - if cache_key not in last_cache.keys() or len(c.bars_raw) < min_count + 30: + if cache_key not in last_cache.keys() or len(c.bars_raw) < min_count + 15: + # 初始化缓存 close = np.array([x.close for x in c.bars_raw]) - min_count = 0 - else: - close = np.array([x.close for x in c.bars_raw[-min_count-30:]]) + dif, dea, macd = ta.MACD(close, fastperiod=fastperiod, slowperiod=slowperiod, signalperiod=signalperiod) + for i in range(len(close)): + _c = dict(c.bars_raw[i].cache) if c.bars_raw[i].cache else dict() + dif_i = dif[i] if dif[i] else close[i] + dea_i = dea[i] if dea[i] else close[i] + macd_i = dif_i - dea_i + _c.update({cache_key: {'dif': dif_i, 'dea': dea_i, 'macd': macd_i}}) + c.bars_raw[i].cache = _c - dif, dea, macd = ta.MACD(close, fastperiod=fastperiod, slowperiod=slowperiod, signalperiod=signalperiod) + else: + # 增量更新最近5个K线缓存 + close = np.array([x.close for x in c.bars_raw[-min_count - 10:]]) + dif, dea, macd = ta.MACD(close, fastperiod=fastperiod, slowperiod=slowperiod, signalperiod=signalperiod) + for i in range(1, 6): + _c = dict(c.bars_raw[-i].cache) if c.bars_raw[-i].cache else dict() + _c.update({cache_key: {'dif': dif[-i], 'dea': dea[-i], 'macd': macd[-i]}}) + c.bars_raw[-i].cache = _c - for i in range(1, len(close) - min_count - 10): - _c = dict(c.bars_raw[-i].cache) if c.bars_raw[-i].cache else dict() - _c.update({cache_key: {'dif': dif[-i], 'dea': dea[-i], 'macd': macd[-i]}}) - c.bars_raw[-i].cache = _c + return cache_key -def update_boll_cache(c: CZSC, **kwargs) -> None: +def update_boll_cache(c: CZSC, **kwargs): """更新K线的BOLL缓存 :param c: 交易对象 @@ -96,30 +120,52 @@ def update_boll_cache(c: CZSC, **kwargs) -> None: timeperiod = kwargs.get('timeperiod', 20) dev_seq = kwargs.get('dev_seq', (1.382, 2, 2.764)) - min_count = timeperiod + if c.bars_raw[-1].cache and c.bars_raw[-1].cache.get(cache_key, None): + # 如果最后一根K线已经有对应的缓存,不执行更新 + return cache_key + last_cache = dict(c.bars_raw[-2].cache) if c.bars_raw[-2].cache else dict() - if cache_key not in last_cache.keys() or len(c.bars_raw) < min_count + 30: + if cache_key not in last_cache.keys() or len(c.bars_raw) < timeperiod + 15: + # 初始化缓存 close = np.array([x.close for x in c.bars_raw]) - min_count = 0 + u1, m, l1 = ta.BBANDS(close, timeperiod=timeperiod, nbdevup=dev_seq[0], nbdevdn=dev_seq[0], matype=0) + u2, m, l2 = ta.BBANDS(close, timeperiod=timeperiod, nbdevup=dev_seq[1], nbdevdn=dev_seq[1], matype=0) + u3, m, l3 = ta.BBANDS(close, timeperiod=timeperiod, nbdevup=dev_seq[2], nbdevdn=dev_seq[2], matype=0) + + for i in range(len(close)): + _c = dict(c.bars_raw[i].cache) if c.bars_raw[i].cache else dict() + if not m[i]: + _data = {"上轨3": close[i], "上轨2": close[i], "上轨1": close[i], + "中线": close[i], + "下轨1": close[i], "下轨2": close[i], "下轨3": close[i]} + else: + _data = {"上轨3": u3[i], "上轨2": u2[i], "上轨1": u1[i], + "中线": m[i], + "下轨1": l1[i], "下轨2": l2[i], "下轨3": l3[i]} + _c.update({cache_key: _data}) + c.bars_raw[i].cache = _c + else: - close = np.array([x.close for x in c.bars_raw[-min_count-30:]]) + # 增量更新最近5个K线缓存 + close = np.array([x.close for x in c.bars_raw[-timeperiod - 10:]]) + u1, m, l1 = ta.BBANDS(close, timeperiod=timeperiod, nbdevup=dev_seq[0], nbdevdn=dev_seq[0], matype=0) + u2, m, l2 = ta.BBANDS(close, timeperiod=timeperiod, nbdevup=dev_seq[1], nbdevdn=dev_seq[1], matype=0) + u3, m, l3 = ta.BBANDS(close, timeperiod=timeperiod, nbdevup=dev_seq[2], nbdevdn=dev_seq[2], matype=0) - u1, m, l1 = ta.BBANDS(close, timeperiod=timeperiod, nbdevup=dev_seq[0], nbdevdn=dev_seq[0], matype=ta.MA_Type.SMA) - u2, m, l2 = ta.BBANDS(close, timeperiod=timeperiod, nbdevup=dev_seq[1], nbdevdn=dev_seq[1], matype=ta.MA_Type.SMA) - u3, m, l3 = ta.BBANDS(close, timeperiod=timeperiod, nbdevup=dev_seq[2], nbdevdn=dev_seq[2], matype=ta.MA_Type.SMA) + for i in range(1, 6): + _c = dict(c.bars_raw[-i].cache) if c.bars_raw[-i].cache else dict() + _c.update({cache_key: {"上轨3": u3[-i], "上轨2": u2[-i], "上轨1": u1[-i], + "中线": m[-i], + "下轨1": l1[-i], "下轨2": l2[-i], "下轨3": l3[-i]}}) + c.bars_raw[-i].cache = _c - for i in range(1, len(close) - min_count - 10): - _c = dict(c.bars_raw[-i].cache) if c.bars_raw[-i].cache else dict() - _c.update({cache_key: {"上轨3": u3[-i], "上轨2": u2[-i], "上轨1": u1[-i], - "中线": m[-i], - "下轨1": l1[-i], "下轨2": l2[-i], "下轨3": l3[-i]}}) - c.bars_raw[-i].cache = _c + return cache_key # MACD信号计算函数 # ====================================================================================================================== -def tas_macd_base_V221028(c: CZSC, di: int = 1, key="macd") -> OrderedDict: +def tas_macd_base_V221028(c: CZSC, di: int = 1, key="macd", **kwargs) -> OrderedDict: """MACD|DIF|DEA 多空和方向信号 **信号逻辑:** @@ -139,10 +185,11 @@ def tas_macd_base_V221028(c: CZSC, di: int = 1, key="macd") -> OrderedDict: :param key: 指定使用哪个Key来计算,可选值 [macd, dif, dea] :return: """ + cache_key = update_macd_cache(c, **kwargs) assert key.lower() in ['macd', 'dif', 'dea'] k1, k2, k3 = f"{c.freq.value}_D{di}K_{key.upper()}".split('_') - macd = [x.cache['MACD'][key.lower()] for x in c.bars_raw[-5 - di:]] + macd = [x.cache[cache_key][key.lower()] for x in c.bars_raw[-5 - di:]] v1 = "多头" if macd[-di] >= 0 else "空头" v2 = "向上" if macd[-di] >= macd[-di - 1] else "向下" @@ -152,7 +199,7 @@ def tas_macd_base_V221028(c: CZSC, di: int = 1, key="macd") -> OrderedDict: return s -def tas_macd_direct_V221106(c: CZSC, di: int = 1) -> OrderedDict: +def tas_macd_direct_V221106(c: CZSC, di: int = 1, **kwargs) -> OrderedDict: """MACD方向;贡献者:马鸣 **信号逻辑:** 连续三根macd柱子值依次增大,向上;反之,向下 @@ -167,9 +214,10 @@ def tas_macd_direct_V221106(c: CZSC, di: int = 1) -> OrderedDict: :param di: 连续倒3根K线 :return: """ + cache_key = update_macd_cache(c, **kwargs) k1, k2, k3 = f"{c.freq.value}_D{di}K_MACD方向".split("_") bars = get_sub_elements(c.bars_raw, di=di, n=3) - macd = [x.cache['MACD']['macd'] for x in bars] + macd = [x.cache[cache_key]['macd'] for x in bars] if len(macd) != 3: v1 = "模糊" @@ -188,7 +236,7 @@ def tas_macd_direct_V221106(c: CZSC, di: int = 1) -> OrderedDict: return s -def tas_macd_power_V221108(c: CZSC, di: int = 1) -> OrderedDict: +def tas_macd_power_V221108(c: CZSC, di: int = 1, **kwargs) -> OrderedDict: """MACD强弱 **信号逻辑:** @@ -209,12 +257,13 @@ def tas_macd_power_V221108(c: CZSC, di: int = 1) -> OrderedDict: :param di: 信号产生在倒数第di根K线 :return: 信号识别结果 """ + cache_key = update_macd_cache(c, **kwargs) k1, k2, k3 = f"{c.freq.value}_D{di}K_MACD强弱".split("_") v1 = "其他" if len(c.bars_raw) > di + 10: bar = c.bars_raw[-di] - dif, dea = bar.cache['MACD']['dif'], bar.cache['MACD']['dea'] + dif, dea = bar.cache[cache_key]['dif'], bar.cache[cache_key]['dea'] if dif >= dea >= 0: v1 = "超强" @@ -231,7 +280,7 @@ def tas_macd_power_V221108(c: CZSC, di: int = 1) -> OrderedDict: return s -def tas_macd_first_bs_V221201(c: CZSC, di: int = 1): +def tas_macd_first_bs_V221201(c: CZSC, di: int = 1, **kwargs): """MACD金叉死叉判断第一买卖点 **信号逻辑:** @@ -247,14 +296,15 @@ def tas_macd_first_bs_V221201(c: CZSC, di: int = 1): :param di: 倒数第i根K线 :return: 信号识别结果 """ + cache_key = update_macd_cache(c, **kwargs) k1, k2, k3 = f"{c.freq.value}_D{di}MACD_BS1".split('_') - bars = get_sub_elements(c.bars_raw, di=di, n=350)[50:] + bars = get_sub_elements(c.bars_raw, di=di, n=300) v1 = "其他" if len(bars) >= 100: - dif = [x.cache['MACD']['dif'] for x in bars] - dea = [x.cache['MACD']['dea'] for x in bars] - macd = [x.cache['MACD']['macd'] for x in bars] + dif = [x.cache[cache_key]['dif'] for x in bars] + dea = [x.cache[cache_key]['dea'] for x in bars] + macd = [x.cache[cache_key]['macd'] for x in bars] cross = fast_slow_cross(dif, dea) up = [x for x in cross if x['类型'] == "金叉" and x['距离'] > 5] @@ -278,7 +328,7 @@ def tas_macd_first_bs_V221201(c: CZSC, di: int = 1): return s -def tas_macd_first_bs_V221216(c: CZSC, di: int = 1): +def tas_macd_first_bs_V221216(c: CZSC, di: int = 1, **kwargs): """MACD金叉死叉判断第一买卖点 **信号逻辑:** @@ -297,15 +347,16 @@ def tas_macd_first_bs_V221216(c: CZSC, di: int = 1): :param di: 倒数第i根K线 :return: 信号识别结果 """ + cache_key = update_macd_cache(c, **kwargs) k1, k2, k3 = f"{c.freq.value}_D{di}MACD_BS1A".split('_') - bars = get_sub_elements(c.bars_raw, di=di, n=350)[50:] + bars = get_sub_elements(c.bars_raw, di=di, n=300) v1 = "其他" v2 = "任意" if len(bars) >= 100: - dif = [x.cache['MACD']['dif'] for x in bars] - dea = [x.cache['MACD']['dea'] for x in bars] - macd = [x.cache['MACD']['macd'] for x in bars] + dif = [x.cache[cache_key]['dif'] for x in bars] + dea = [x.cache[cache_key]['dea'] for x in bars] + macd = [x.cache[cache_key]['macd'] for x in bars] n_bars = bars[-10:] m_bars = bars[-100: -10] high_n = max([x.high for x in n_bars]) @@ -339,7 +390,7 @@ def tas_macd_first_bs_V221216(c: CZSC, di: int = 1): return s -def tas_macd_second_bs_V221201(c: CZSC, di: int = 1): +def tas_macd_second_bs_V221201(c: CZSC, di: int = 1, **kwargs): """MACD金叉死叉判断第二买卖点 **信号逻辑:** @@ -358,15 +409,16 @@ def tas_macd_second_bs_V221201(c: CZSC, di: int = 1): :param di: 倒数第i根K线 :return: 信号识别结果 """ + cache_key = update_macd_cache(c, **kwargs) k1, k2, k3 = f"{c.freq.value}_D{di}MACD_BS2".split('_') bars = get_sub_elements(c.bars_raw, di=di, n=350)[50:] v1 = "其他" v2 = "任意" if len(bars) >= 100: - dif = [x.cache['MACD']['dif'] for x in bars] - dea = [x.cache['MACD']['dea'] for x in bars] - macd = [x.cache['MACD']['macd'] for x in bars] + dif = [x.cache[cache_key]['dif'] for x in bars] + dea = [x.cache[cache_key]['dea'] for x in bars] + macd = [x.cache[cache_key]['macd'] for x in bars] cross = fast_slow_cross(dif, dea) up = [x for x in cross if x['类型'] == "金叉" and x['距离'] > 5] @@ -394,7 +446,7 @@ def tas_macd_second_bs_V221201(c: CZSC, di: int = 1): return s -def tas_macd_xt_V221208(c: CZSC, di: int = 1): +def tas_macd_xt_V221208(c: CZSC, di: int = 1, **kwargs): """MACD形态信号 **信号逻辑:** @@ -414,9 +466,10 @@ def tas_macd_xt_V221208(c: CZSC, di: int = 1): :param di: 倒数第i根K线 :return: """ + cache_key = update_macd_cache(c, **kwargs) k1, k2, k3 = f"{c.freq.value}_D{di}K_MACD形态".split('_') bars = get_sub_elements(c.bars_raw, di=di, n=5) - macd = [x.cache['MACD']['macd'] for x in bars] + macd = [x.cache[cache_key]['macd'] for x in bars] v1 = "其他" if len(macd) == 5: @@ -439,7 +492,7 @@ def tas_macd_xt_V221208(c: CZSC, di: int = 1): return s -def tas_macd_bc_V221201(c: CZSC, di: int = 1, n: int = 3, m: int = 50): +def tas_macd_bc_V221201(c: CZSC, di: int = 1, n: int = 3, m: int = 50, **kwargs): """MACD背驰辅助 **信号逻辑:** @@ -461,8 +514,9 @@ def tas_macd_bc_V221201(c: CZSC, di: int = 1, n: int = 3, m: int = 50): :param m: 远期窗口大小 :return: 信号识别结果 """ + cache_key = update_macd_cache(c, **kwargs) k1, k2, k3 = f"{c.freq.value}_D{di}N{n}M{m}_MACD背驰".split('_') - bars = get_sub_elements(c.bars_raw, di=di, n=n+m) + bars = get_sub_elements(c.bars_raw, di=di, n=n + m) assert n >= 3, "近期窗口大小至少要大于3" v1 = "其他" @@ -472,9 +526,9 @@ def tas_macd_bc_V221201(c: CZSC, di: int = 1, n: int = 3, m: int = 50): m_bars = bars[:m] assert len(n_bars) == n and len(m_bars) == m n_close = [x.close for x in n_bars] - n_macd = [x.cache['MACD']['macd'] for x in n_bars] + n_macd = [x.cache[cache_key]['macd'] for x in n_bars] m_close = [x.close for x in m_bars] - m_macd = [x.cache['MACD']['macd'] for x in m_bars] + m_macd = [x.cache[cache_key]['macd'] for x in m_bars] if n_macd[-1] > n_macd[-2] and min(n_close) < min(m_close) and min(n_macd) > min(m_macd): v1 = '底部' @@ -489,10 +543,12 @@ def tas_macd_bc_V221201(c: CZSC, di: int = 1, n: int = 3, m: int = 50): return s -def tas_macd_change_V221105(c: CZSC, di: int = 1, n: int = 55) -> OrderedDict: +def tas_macd_change_V221105(c: CZSC, di: int = 1, n: int = 55, **kwargs) -> OrderedDict: """MACD颜色变化;贡献者:马鸣 - **信号逻辑:** 从dik往前数n根k线对应的macd红绿柱子变换次数 + **信号逻辑:** + + 从dik往前数n根k线对应的macd红绿柱子变换次数 **信号列表:** @@ -511,11 +567,12 @@ def tas_macd_change_V221105(c: CZSC, di: int = 1, n: int = 55) -> OrderedDict: :param n: 从dik往前数n根k线 :return: """ + cache_key = update_macd_cache(c, **kwargs) k1, k2, k3 = f"{c.freq.value}_D{di}K{n}_MACD变色次数".split('_') bars = get_sub_elements(c.bars_raw, di=di, n=n) - dif = [x.cache['MACD']['dif'] for x in bars] - dea = [x.cache['MACD']['dea'] for x in bars] + dif = [x.cache[cache_key]['dif'] for x in bars] + dea = [x.cache[cache_key]['dea'] for x in bars] cross = fast_slow_cross(dif, dea) # 过滤低级别信号抖动造成的金叉死叉(这个参数根据自身需要进行修改) @@ -553,7 +610,7 @@ def tas_macd_change_V221105(c: CZSC, di: int = 1, n: int = 55) -> OrderedDict: # MA信号计算函数 # ====================================================================================================================== -def tas_ma_base_V221101(c: CZSC, di: int = 1, key="SMA5") -> OrderedDict: +def tas_ma_base_V221101(c: CZSC, di: int = 1, ma_type='SMA', timeperiod=5) -> OrderedDict: """MA 多空和方向信号 **信号逻辑:** @@ -570,10 +627,12 @@ def tas_ma_base_V221101(c: CZSC, di: int = 1, key="SMA5") -> OrderedDict: :param c: CZSC对象 :param di: 信号计算截止倒数第i根K线 - :param key: 指定使用哪个Key来计算,必须是 `update_ma_cache` 中已经缓存的 key + :param ma_type: 均线类型,必须是 `ma_type_map` 中的 key + :param timeperiod: 均线计算周期 :return: """ - k1, k2, k3 = f"{c.freq.value}_D{di}K_{key.upper()}".split('_') + key = update_ma_cache(c, ma_type, timeperiod) + k1, k2, k3 = f"{c.freq.value}_D{di}K_{key}".split('_') bars = get_sub_elements(c.bars_raw, di=di, n=3) v1 = "多头" if bars[-1].close >= bars[-1].cache[key] else "空头" v2 = "向上" if bars[-1].cache[key] >= bars[-2].cache[key] else "向下" @@ -584,7 +643,7 @@ def tas_ma_base_V221101(c: CZSC, di: int = 1, key="SMA5") -> OrderedDict: return s -def tas_ma_base_V221203(c: CZSC, di: int = 1, key="SMA5", th=100) -> OrderedDict: +def tas_ma_base_V221203(c: CZSC, di: int = 1, ma_type='SMA', timeperiod=5, th=100) -> OrderedDict: """MA 多空和方向信号,加距离限制 **信号逻辑:** @@ -606,11 +665,13 @@ def tas_ma_base_V221203(c: CZSC, di: int = 1, key="SMA5", th=100) -> OrderedDict :param c: CZSC对象 :param di: 信号计算截止倒数第i根K线 - :param key: 指定使用哪个Key来计算,必须是 `update_ma_cache` 中已经缓存的 key + :param ma_type: 均线类型,必须是 `ma_type_map` 中的 key + :param timeperiod: 均线计算周期 :param th: 距离阈值,单位 BP :return: """ - k1, k2, k3 = f"{c.freq.value}_D{di}T{th}_{key.upper()}".split('_') + key = update_ma_cache(c, ma_type, timeperiod) + k1, k2, k3 = f"{c.freq.value}_D{di}T{th}_{key}".split('_') bars = get_sub_elements(c.bars_raw, di=di, n=3) c = bars[-1].close m = bars[-1].cache[key] @@ -624,7 +685,7 @@ def tas_ma_base_V221203(c: CZSC, di: int = 1, key="SMA5", th=100) -> OrderedDict return s -def tas_ma_round_V221206(c: CZSC, di: int = 1, key: str = "SMA60", th: int = 10) -> OrderedDict: +def tas_ma_round_V221206(c: CZSC, di: int = 1, ma_type='SMA', timeperiod=60, th: int = 10) -> OrderedDict: """笔端点在均线附近,贡献者:谌意勇 **信号逻辑:** @@ -638,10 +699,12 @@ def tas_ma_round_V221206(c: CZSC, di: int = 1, key: str = "SMA60", th: int = 10) :param c: CZSC对象 :param di: 指定倒数第几笔 - :param key: 指定均线名称 + :param ma_type: 均线类型,必须是 `ma_type_map` 中的 key + :param timeperiod: 均线计算周期 :param th: 笔的端点到均线的绝对价差 / 笔的价差 < th / 100 表示笔端点在均线附近 :return: 信号识别结果 """ + key = update_ma_cache(c, ma_type, timeperiod) k1, k2, k3 = f'{c.freq.value}_D{di}TH{th}_碰{key}'.split('_') v1 = "其他" @@ -661,7 +724,7 @@ def tas_ma_round_V221206(c: CZSC, di: int = 1, key: str = "SMA60", th: int = 10) return s -def tas_double_ma_V221203(c: CZSC, di: int = 1, ma1="SMA5", ma2='SMA10', th: int = 100) -> OrderedDict: +def tas_double_ma_V221203(c: CZSC, di: int = 1, ma_type='SMA', ma_seq=(5, 10), th: int = 100) -> OrderedDict: """双均线多空和强弱信号 **信号逻辑:** @@ -678,12 +741,16 @@ def tas_double_ma_V221203(c: CZSC, di: int = 1, ma1="SMA5", ma2='SMA10', th: int :param c: CZSC对象 :param di: 信号计算截止倒数第i根K线 - :param ma1: 指定短期均线,必须是 `update_ma_cache` 中已经缓存的 key - :param ma2: 指定长期均线,必须是 `update_ma_cache` 中已经缓存的 key + :param ma_type: 均线类型,必须是 `ma_type_map` 中的 key + :param ma_seq: 快慢均线计算周期,快线在前 :param th: ma1 相比 ma2 的距离阈值,单位 BP :return: 信号识别结果 """ - k1, k2, k3 = f"{c.freq.value}_D{di}T{th}_{ma1.upper()}{ma2.upper()}".split('_') + assert len(ma_seq) == 2 and ma_seq[1] > ma_seq[0] + ma1 = update_ma_cache(c, ma_type, ma_seq[0]) + ma2 = update_ma_cache(c, ma_type, ma_seq[1]) + + k1, k2, k3 = f"{c.freq.value}_D{di}T{th}_{ma1}{ma2}".split('_') bars = get_sub_elements(c.bars_raw, di=di, n=3) ma1v = bars[-1].cache[ma1] ma2v = bars[-1].cache[ma2] @@ -700,7 +767,7 @@ def tas_double_ma_V221203(c: CZSC, di: int = 1, ma1="SMA5", ma2='SMA10', th: int # ====================================================================================================================== -def tas_boll_power_V221112(c: CZSC, di: int = 1): +def tas_boll_power_V221112(c: CZSC, di: int = 1, **kwargs): """BOLL指标强弱 **信号逻辑:** @@ -723,6 +790,7 @@ def tas_boll_power_V221112(c: CZSC, di: int = 1): :param di: 信号计算截止倒数第i根K线 :return: s """ + cache_key = update_boll_cache(c, **kwargs) k1, k2, k3 = f"{c.freq.value}_D{di}K_BOLL强弱".split("_") if len(c.bars_raw) < di + 20: @@ -730,7 +798,7 @@ def tas_boll_power_V221112(c: CZSC, di: int = 1): else: last = c.bars_raw[-di] - cache = last.cache['boll'] + cache = last.cache[cache_key] latest_c = last.close m = cache['中线'] @@ -754,7 +822,7 @@ def tas_boll_power_V221112(c: CZSC, di: int = 1): return s -def tas_boll_bc_V221118(c: CZSC, di=1, n=3, m=10, line=3): +def tas_boll_bc_V221118(c: CZSC, di=1, n=3, m=10, line=3, **kwargs): """BOLL背驰辅助 **信号逻辑:** @@ -773,18 +841,19 @@ def tas_boll_bc_V221118(c: CZSC, di=1, n=3, m=10, line=3): :param line: 选第几个上下轨 :return: """ + cache_key = update_boll_cache(c, **kwargs) k1, k2, k3 = f"{c.freq.value}_D{di}N{n}M{m}L{line}_BOLL背驰".split('_') bn = get_sub_elements(c.bars_raw, di=di, n=n) bm = get_sub_elements(c.bars_raw, di=di, n=m) d_c1 = min([x.low for x in bn]) <= min([x.low for x in bm]) - d_c2 = sum([x.close < x.cache['boll'][f'下轨{line}'] for x in bm]) > 1 - d_c3 = sum([x.close < x.cache['boll'][f'下轨{line}'] for x in bn]) == 0 + d_c2 = sum([x.close < x.cache[cache_key][f'下轨{line}'] for x in bm]) > 1 + d_c3 = sum([x.close < x.cache[cache_key][f'下轨{line}'] for x in bn]) == 0 g_c1 = max([x.high for x in bn]) == max([x.high for x in bm]) - g_c2 = sum([x.close > x.cache['boll'][f'上轨{line}'] for x in bm]) > 1 - g_c3 = sum([x.close > x.cache['boll'][f'上轨{line}'] for x in bn]) == 0 + g_c2 = sum([x.close > x.cache[cache_key][f'上轨{line}'] for x in bm]) > 1 + g_c3 = sum([x.close > x.cache[cache_key][f'上轨{line}'] for x in bn]) == 0 if d_c1 and d_c2 and d_c3: v1 = "一买" @@ -801,7 +870,7 @@ def tas_boll_bc_V221118(c: CZSC, di=1, n=3, m=10, line=3): # KDJ信号计算函数 # ====================================================================================================================== -def update_kdj_cache(c: CZSC, **kwargs) -> None: +def update_kdj_cache(c: CZSC, **kwargs): """更新KDJ缓存 :param c: CZSC对象 @@ -810,30 +879,45 @@ def update_kdj_cache(c: CZSC, **kwargs) -> None: fastk_period = kwargs.get('fastk_period', 9) slowk_period = kwargs.get('slowk_period', 3) slowd_period = kwargs.get('slowd_period', 3) + cache_key = f"KDJ({fastk_period},{slowk_period},{slowd_period})" + + if c.bars_raw[-1].cache and c.bars_raw[-1].cache.get(cache_key, None): + # 如果最后一根K线已经有对应的缓存,不执行更新 + return cache_key min_count = fastk_period + slowk_period - cache_key = f"KDJ({fastk_period},{slowk_period},{slowd_period})" last_cache = dict(c.bars_raw[-2].cache) if c.bars_raw[-2].cache else dict() - if cache_key not in last_cache.keys() or len(c.bars_raw) < min_count + 30: + if cache_key not in last_cache.keys() or len(c.bars_raw) < min_count + 15: bars = c.bars_raw - min_count = 0 - else: - bars = c.bars_raw[-min_count-30:] + high = np.array([x.high for x in bars]) + low = np.array([x.low for x in bars]) + close = np.array([x.close for x in bars]) + + k, d = ta.STOCH(high, low, close, fastk_period=fastk_period, slowk_period=slowk_period, slowd_period=slowd_period) + j = list(map(lambda x, y: 3 * x - 2 * y, k, d)) - high = np.array([x.high for x in bars]) - low = np.array([x.low for x in bars]) - close = np.array([x.close for x in bars]) + for i in range(len(close)): + _c = dict(c.bars_raw[i].cache) if c.bars_raw[i].cache else dict() + _c.update({cache_key: {'k': k[i] if k[i] else 0, 'd': d[i] if d[i] else 0, 'j': j[i] if j[i] else 0}}) + c.bars_raw[i].cache = _c - k, d = ta.STOCH(high, low, close, fastk_period=fastk_period, slowk_period=slowk_period, slowd_period=slowd_period) - j = list(map(lambda x, y: 3*x - 2*y, k, d)) + else: + bars = c.bars_raw[-min_count - 10:] + high = np.array([x.high for x in bars]) + low = np.array([x.low for x in bars]) + close = np.array([x.close for x in bars]) + k, d = ta.STOCH(high, low, close, fastk_period=fastk_period, slowk_period=slowk_period, slowd_period=slowd_period) + j = list(map(lambda x, y: 3 * x - 2 * y, k, d)) + + for i in range(1, 6): + _c = dict(c.bars_raw[-i].cache) if c.bars_raw[-i].cache else dict() + _c.update({cache_key: {'k': k[-i], 'd': d[-i], 'j': j[-i]}}) + c.bars_raw[-i].cache = _c - for i in range(1, len(close) - min_count - 10): - _c = dict(c.bars_raw[-i].cache) if c.bars_raw[-i].cache else dict() - _c.update({cache_key: {'k': k[-i], 'd': d[-i], 'j': j[-i]}}) - c.bars_raw[-i].cache = _c + return cache_key -def tas_kdj_base_V221101(c: CZSC, di: int = 1, key="KDJ(9,3,3)") -> OrderedDict: +def tas_kdj_base_V221101(c: CZSC, di: int = 1, **kwargs) -> OrderedDict: """KDJ金叉死叉信号 **信号逻辑:** @@ -850,12 +934,12 @@ def tas_kdj_base_V221101(c: CZSC, di: int = 1, key="KDJ(9,3,3)") -> OrderedDict: :param c: CZSC对象 :param di: 信号计算截止倒数第i根K线 - :param key: 指定使用哪个Key来计算,必须是 `update_kdj_cache` 中已经缓存的 key :return: """ + cache_key = update_kdj_cache(c, **kwargs) k1, k2, k3 = f"{c.freq.value}_D{di}K_KDJ".split('_') bars = get_sub_elements(c.bars_raw, di=di, n=3) - kdj = bars[-1].cache[key] + kdj = bars[-1].cache[cache_key] if kdj['j'] > kdj['k'] > kdj['d']: v1 = "多头" @@ -864,7 +948,7 @@ def tas_kdj_base_V221101(c: CZSC, di: int = 1, key="KDJ(9,3,3)") -> OrderedDict: else: v1 = "其他" - v2 = "向上" if kdj['j'] >= bars[-2].cache[key]['j'] else "向下" + v2 = "向上" if kdj['j'] >= bars[-2].cache[cache_key]['j'] else "向下" s = OrderedDict() signal = Signal(k1=k1, k2=k2, k3=k3, v1=v1, v2=v2) @@ -874,10 +958,10 @@ def tas_kdj_base_V221101(c: CZSC, di: int = 1, key="KDJ(9,3,3)") -> OrderedDict: # RSI信号计算函数 # ====================================================================================================================== -def update_rsi_cache(c: CZSC, **kwargs) -> None: +def update_rsi_cache(c: CZSC, **kwargs): """更新RSI缓存 - 相对强弱指数(RSI)是通过比较一段时期内的平均收盘涨数和平均收盘跌数来分析市场买沽盘的意向和实力,从而作出未来市场的走势。 + 相对强弱指数(RSI)是通过比较一段时期内的平均收盘涨数和平均收盘跌数来分析市场买沽盘的意向和实力,从而判断未来市场的走势。 RSI在1978年6月由WellsWider创制。 RSI = 100 × RS / (1 + RS) 或者 RSI=100-100÷(1+RS) @@ -887,26 +971,36 @@ def update_rsi_cache(c: CZSC, **kwargs) -> None: :return: """ timeperiod = kwargs.get('timeperiod', 9) - - min_count = timeperiod + 5 cache_key = f"RSI{timeperiod}" + if c.bars_raw[-1].cache and c.bars_raw[-1].cache.get(cache_key, None): + # 如果最后一根K线已经有对应的缓存,不执行更新 + return cache_key + last_cache = dict(c.bars_raw[-2].cache) if c.bars_raw[-2].cache else dict() - if cache_key not in last_cache.keys() or len(c.bars_raw) < min_count + 30: - bars = c.bars_raw - min_count = 0 + if cache_key not in last_cache.keys() or len(c.bars_raw) < timeperiod + 15: + # 初始化缓存 + close = np.array([x.close for x in c.bars_raw]) + rsi = ta.RSI(close, timeperiod=timeperiod) + + for i in range(len(close)): + _c = dict(c.bars_raw[i].cache) if c.bars_raw[i].cache else dict() + _c.update({cache_key: rsi[i] if rsi[i] else 0}) + c.bars_raw[i].cache = _c + else: - bars = c.bars_raw[-min_count-30:] - close = np.array([x.close for x in bars]) + # 增量更新最近5个K线缓存 + close = np.array([x.close for x in c.bars_raw[-timeperiod - 10:]]) + rsi = ta.RSI(close, timeperiod=timeperiod) - rsi = ta.RSI(close, timeperiod=timeperiod) + for i in range(1, 6): + _c = dict(c.bars_raw[-i].cache) if c.bars_raw[-i].cache else dict() + _c.update({cache_key: rsi[-i]}) + c.bars_raw[-i].cache = _c - for i in range(1, len(close) - min_count - 10): - _c = dict(c.bars_raw[-i].cache) if c.bars_raw[-i].cache else dict() - _c.update({cache_key: rsi[-i]}) - c.bars_raw[-i].cache = _c + return cache_key -def tas_double_rsi_V221203(c: CZSC, di: int = 1, rsi1="RSI5", rsi2='RSI10') -> OrderedDict: +def tas_double_rsi_V221203(c: CZSC, di: int = 1, rsi_seq=(5, 10), **kwargs) -> OrderedDict: """两个周期的RSI多空信号 **信号逻辑:** @@ -915,16 +1009,20 @@ def tas_double_rsi_V221203(c: CZSC, di: int = 1, rsi1="RSI5", rsi2='RSI10') -> O **信号列表:** - - Signal('日线_D2K_RSI6RSI12_多头_任意_任意_0') - - Signal('日线_D2K_RSI6RSI12_空头_任意_任意_0') + - Signal('15分钟_D1K_RSI5#10_空头_任意_任意_0') + - Signal('15分钟_D1K_RSI5#10_多头_任意_任意_0') :param c: CZSC对象 :param di: 信号计算截止倒数第i根K线 - :param rsi1: 指定短期RSI,必须是 `update_rsi_cache` 中已经缓存的 key - :param rsi2: 指定长期RSI,必须是 `update_rsi_cache` 中已经缓存的 key + :param di: 信号计算截止倒数第i根K线 + :param rsi_seq: 指定短期RSI, 长期RSI 参数 :return: 信号识别结果 """ - k1, k2, k3 = f"{c.freq.value}_D{di}K_{rsi1.upper()}{rsi2.upper()}".split('_') + assert len(rsi_seq) == 2 and rsi_seq[1] > rsi_seq[0] + rsi1 = update_rsi_cache(c, timeperiod=rsi_seq[0]) + rsi2 = update_rsi_cache(c, timeperiod=rsi_seq[1]) + + k1, k2, k3 = f"{c.freq.value}_D{di}K_RSI{rsi_seq[0]}#{rsi_seq[1]}".split('_') bars = get_sub_elements(c.bars_raw, di=di, n=3) rsi1v = bars[-1].cache[rsi1] rsi2v = bars[-1].cache[rsi2] @@ -934,6 +1032,3 @@ def tas_double_rsi_V221203(c: CZSC, di: int = 1, rsi1="RSI5", rsi2='RSI10') -> O signal = Signal(k1=k1, k2=k2, k3=k3, v1=v1) s[signal.key] = signal.value return s - - - diff --git a/czsc/traders/__init__.py b/czsc/traders/__init__.py index 689a87bc2..2454386f2 100644 --- a/czsc/traders/__init__.py +++ b/czsc/traders/__init__.py @@ -5,9 +5,8 @@ create_dt: 2021/11/1 22:20 describe: 交易员(traders)的主要职能是依据感应系统(sensors)的输出来调整仓位,以此应对变幻无常的市场风险。 """ - -from czsc.traders.advanced import CzscAdvancedTrader, create_advanced_trader, CzscDummyTrader -from czsc.traders.ts_backtest import TsStocksBacktest +from czsc.traders.base import CzscSignals, CzscAdvancedTrader +from czsc.traders.advanced import create_advanced_trader, CzscDummyTrader from czsc.traders.performance import TradersPerformance, PairsPerformance from czsc.traders.ts_simulator import TradeSimulator from czsc.traders.utils import trader_fast_backtest, trade_replay diff --git a/czsc/traders/advanced.py b/czsc/traders/advanced.py index 1ecdda3a8..c07b46b74 100644 --- a/czsc/traders/advanced.py +++ b/czsc/traders/advanced.py @@ -8,6 +8,7 @@ import os import webbrowser import pandas as pd +from deprecated import deprecated from collections import OrderedDict from typing import Callable, List from pyecharts.charts import Tab @@ -19,6 +20,7 @@ from czsc.utils.cache import home_path +@deprecated(reason="已统一到 czsc.traders.base 中") class CzscAdvancedTrader: """缠中说禅技术分析理论之多级别联立交易决策类(支持分批开平仓 / 支持从任意周期开始交易)""" @@ -227,6 +229,7 @@ def create_advanced_trader(bg: BarGenerator, raw_bars: List[RawBar], strategy: C :param strategy: 择时交易策略 :return: trader """ + from czsc.traders.base import CzscAdvancedTrader trader = CzscAdvancedTrader(bg, strategy) for bar in raw_bars: trader.update(bar) diff --git a/czsc/traders/base.py b/czsc/traders/base.py new file mode 100644 index 000000000..3500585e0 --- /dev/null +++ b/czsc/traders/base.py @@ -0,0 +1,375 @@ +# -*- coding: utf-8 -*- +""" +author: zengbin93 +email: zeng_bin8888@163.com +create_dt: 2022/12/24 22:20 +describe: 简单的单仓位策略执行 +""" +import os +import webbrowser +import numpy as np +import pandas as pd +from collections import OrderedDict +from typing import Callable, List +from pyecharts.charts import Tab +from pyecharts.components import Table +from pyecharts.options import ComponentTitleOpts +from czsc.analyze import CZSC +from czsc.objects import Position, PositionLong, PositionShort, Operate, Event, RawBar +from czsc.utils import BarGenerator, x_round +from czsc.utils.cache import home_path + + +class CzscSignals: + """缠中说禅技术分析理论之多级别信号计算""" + + def __init__(self, bg: BarGenerator, get_signals: Callable = None): + """ + + :param bg: K线合成器 + :param get_signals: 信号计算函数 + """ + self.name = "CzscAdvancedTrader" + self.bg = bg + assert bg.symbol, "bg.symbol is None" + self.symbol = bg.symbol + self.base_freq = bg.base_freq + self.freqs = list(bg.bars.keys()) + self.get_signals: Callable = get_signals + self.kas = {freq: CZSC(b) for freq, b in bg.bars.items()} + + # cache 是信号计算过程的缓存容器,需要信号计算函数自行维护 + self.cache = OrderedDict() + + last_bar = self.kas[self.base_freq].bars_raw[-1] + self.end_dt, self.bid, self.latest_price = last_bar.dt, last_bar.id, last_bar.close + if self.get_signals: + self.s = self.get_signals(self) + self.s.update(last_bar.__dict__) + else: + self.s = OrderedDict() + + def __repr__(self): + return "<{} for {}>".format(self.name, self.symbol) + + def take_snapshot(self, file_html=None, width: str = "1400px", height: str = "580px"): + """获取快照 + + :param file_html: 交易快照保存的 html 文件名 + :param width: 图表宽度 + :param height: 图表高度 + :return: + """ + tab = Tab(page_title="{}@{}".format(self.symbol, self.end_dt.strftime("%Y-%m-%d %H:%M"))) + for freq in self.freqs: + ka: CZSC = self.kas[freq] + chart = ka.to_echarts(width, height) + tab.add(chart, freq) + + signals = {k: v for k, v in self.s.items() if len(k.split("_")) == 3} + for freq in self.freqs: + # 按各周期K线分别加入信号表 + freq_signals = {k: signals[k] for k in signals.keys() if k.startswith("{}_".format(freq))} + for k in freq_signals.keys(): + signals.pop(k) + if len(freq_signals) <= 0: + continue + t1 = Table() + t1.add(["名称", "数据"], [[k, v] for k, v in freq_signals.items()]) + t1.set_global_opts(title_opts=ComponentTitleOpts(title="缠中说禅信号表", subtitle="")) + tab.add(t1, f"{freq}信号") + + if len(signals) > 0: + # 加入时间、持仓状态之类的其他信号 + t1 = Table() + t1.add(["名称", "数据"], [[k, v] for k, v in signals.items()]) + t1.set_global_opts(title_opts=ComponentTitleOpts(title="缠中说禅信号表", subtitle="")) + tab.add(t1, "其他信号") + + if file_html: + tab.render(file_html) + else: + return tab + + def open_in_browser(self, width="1400px", height="580px"): + """直接在浏览器中打开分析结果""" + file_html = os.path.join(home_path, "temp_czsc_advanced_trader.html") + self.take_snapshot(file_html, width, height) + webbrowser.open(file_html) + + def update_signals(self, bar: RawBar): + """输入基础周期已完成K线,更新信号,更新仓位""" + self.bg.update(bar) + for freq, b in self.bg.bars.items(): + self.kas[freq].update(b[-1]) + + self.symbol = bar.symbol + last_bar = self.kas[self.base_freq].bars_raw[-1] + self.end_dt, self.bid, self.latest_price = last_bar.dt, last_bar.id, last_bar.close + + if self.get_signals: + self.s = self.get_signals(self) + self.s.update(last_bar.__dict__) + + +class CzscTrader(CzscSignals): + """缠中说禅技术分析理论之多级别联立交易决策类(支持多策略独立执行)""" + + def __init__(self, bg: BarGenerator, get_signals: Callable = None, positions: List[Position] = None): + super().__init__(bg, get_signals=get_signals) + self.positions = positions + + def update(self, bar: RawBar): + """输入基础周期已完成K线,更新信号,更新仓位""" + self.update_signals(bar) + if self.positions: + for position in self.positions: + position.update(self.s) + + def get_ensemble_pos(self, method="mean"): + """获取多个仓位的集成仓位 + + :param method: 多个仓位集成一个仓位的方法,可选值 mean, vote, max + 假设有三个仓位对象,当前仓位分别是 1, 1, -1 + mean - 平均仓位,pos = np.mean([1, 1, -1]) = 0.33 + vote - 投票表决,pos = 1 + max - 取最大,pos = 1 + :return: pos, 集成仓位 + """ + if not self.positions: + return 0 + pos_seq = [x.pos for x in self.positions] + + if method.lower() == 'mean': + pos = np.mean(pos_seq) + elif method.lower() == 'vote': + _v = sum(pos_seq) + if _v > 0: + pos = 1 + elif _v < 0: + pos = -1 + else: + pos = 0 + elif method.lower() == 'max': + pos = max(pos_seq) + else: + raise ValueError + return pos + + def take_snapshot(self, file_html=None, width: str = "1400px", height: str = "580px"): + """获取快照 + + :param file_html: 交易快照保存的 html 文件名 + :param width: 图表宽度 + :param height: 图表高度 + :return: + """ + tab = Tab(page_title="{}@{}".format(self.symbol, self.end_dt.strftime("%Y-%m-%d %H:%M"))) + for freq in self.freqs: + ka: CZSC = self.kas[freq] + bs = None + if freq == self.base_freq: + # 在基础周期K线上加入最近的操作记录 + bs = [] + for pos in self.positions: + for op in pos.operates: + if op['dt'] >= ka.bars_raw[0].dt: + bs.append(op) + + chart = ka.to_echarts(width, height, bs) + tab.add(chart, freq) + + signals = {k: v for k, v in self.s.items() if len(k.split("_")) == 3} + for freq in self.freqs: + # 按各周期K线分别加入信号表 + freq_signals = {k: signals[k] for k in signals.keys() if k.startswith("{}_".format(freq))} + for k in freq_signals.keys(): + signals.pop(k) + if len(freq_signals) <= 0: + continue + t1 = Table() + t1.add(["名称", "数据"], [[k, v] for k, v in freq_signals.items()]) + t1.set_global_opts(title_opts=ComponentTitleOpts(title="缠中说禅信号表", subtitle="")) + tab.add(t1, f"{freq}信号") + + if len(signals) > 0: + # 加入时间、持仓状态之类的其他信号 + t1 = Table() + t1.add(["名称", "数据"], [[k, v] for k, v in signals.items()]) + t1.set_global_opts(title_opts=ComponentTitleOpts(title="缠中说禅信号表", subtitle="")) + tab.add(t1, "其他信号") + + if file_html: + tab.render(file_html) + else: + return tab + + +class CzscAdvancedTrader(CzscSignals): + """缠中说禅技术分析理论之多级别联立交易决策类(支持分批开平仓 / 支持从任意周期开始交易)""" + + def __init__(self, bg: BarGenerator, strategy: Callable = None): + """ + + :param bg: K线合成器 + :param strategy: 择时策略描述函数 + 注意,strategy 函数必须是仅接受一个 symbol 参数的函数 + """ + self.name = "CzscAdvancedTrader" + self.strategy = strategy + tactic = self.strategy("") if strategy else {} + self.get_signals: Callable = tactic.get('get_signals') + self.tactic = tactic + self.long_events: List[Event] = tactic.get('long_events', None) + self.long_pos: PositionLong = tactic.get('long_pos', None) + self.long_holds = [] # 记录基础周期结束时间对应的多头仓位信息 + self.short_events: List[Event] = tactic.get('short_events', None) + self.short_pos: PositionShort = tactic.get('short_pos', None) + self.short_holds = [] # 记录基础周期结束时间对应的空头仓位信息 + super().__init__(bg, get_signals=self.get_signals) + + def __repr__(self): + return "<{} for {}>".format(self.name, self.symbol) + + def take_snapshot(self, file_html=None, width: str = "1400px", height: str = "580px"): + """获取快照 + + :param file_html: 交易快照保存的 html 文件名 + :param width: 图表宽度 + :param height: 图表高度 + :return: + """ + tab = Tab(page_title="{}@{}".format(self.symbol, self.end_dt.strftime("%Y-%m-%d %H:%M"))) + for freq in self.freqs: + ka: CZSC = self.kas[freq] + bs = None + if freq == self.base_freq: + # 在基础周期K线上加入最近的操作记录 + bs = [] + if self.long_pos: + for op in self.long_pos.operates[-10:]: + if op['dt'] >= ka.bars_raw[0].dt: + bs.append(op) + + if self.short_pos: + for op in self.short_pos.operates[-10:]: + if op['dt'] >= ka.bars_raw[0].dt: + bs.append(op) + + chart = ka.to_echarts(width, height, bs) + tab.add(chart, freq) + + signals = {k: v for k, v in self.s.items() if len(k.split("_")) == 3} + for freq in self.freqs: + # 按各周期K线分别加入信号表 + freq_signals = {k: signals[k] for k in signals.keys() if k.startswith("{}_".format(freq))} + for k in freq_signals.keys(): + signals.pop(k) + if len(freq_signals) <= 0: + continue + t1 = Table() + t1.add(["名称", "数据"], [[k, v] for k, v in freq_signals.items()]) + t1.set_global_opts(title_opts=ComponentTitleOpts(title="缠中说禅信号表", subtitle="")) + tab.add(t1, f"{freq}信号") + + if len(signals) > 0: + # 加入时间、持仓状态之类的其他信号 + t1 = Table() + t1.add(["名称", "数据"], [[k, v] for k, v in signals.items()]) + t1.set_global_opts(title_opts=ComponentTitleOpts(title="缠中说禅信号表", subtitle="")) + tab.add(t1, "其他信号") + + if file_html: + tab.render(file_html) + else: + return tab + + def update(self, bar: RawBar): + """输入基础周期已完成K线,更新信号,更新仓位""" + self.update_signals(bar) + last_bar = self.kas[self.base_freq].bars_raw[-1] + dt, bid, price, symbol = self.end_dt, self.bid, self.latest_price, self.symbol + assert last_bar.dt == dt and last_bar.id == bid and last_bar.close == price + + last_n1b = last_bar.close / self.kas[self.base_freq].bars_raw[-2].close - 1 + # 遍历 long_events,更新 long_pos + if self.long_events: + assert isinstance(self.long_pos, PositionLong), "long_events 必须配合 PositionLong 使用" + + op = Operate.HO + op_desc = "" + + for event in self.long_events: + m, f = event.is_match(self.s) + if m: + op = event.operate + op_desc = f"{event.name}@{f}" + break + + self.long_pos.update(dt, op, price, bid, op_desc) + if self.long_holds: + self.long_holds[-1]['n1b'] = last_n1b + self.long_holds.append({'dt': dt, 'symbol': symbol, 'long_pos': self.long_pos.pos, 'n1b': 0}) + + # 遍历 short_events,更新 short_pos + if self.short_events: + assert isinstance(self.short_pos, PositionShort), "short_events 必须配合 PositionShort 使用" + + op = Operate.HO + op_desc = "" + + for event in self.short_events: + m, f = event.is_match(self.s) + if m: + op = event.operate + op_desc = f"{event.name}@{f}" + break + + self.short_pos.update(dt, op, price, bid, op_desc) + if self.short_holds: + self.short_holds[-1]['n1b'] = -last_n1b + self.short_holds.append({'dt': dt, 'symbol': symbol, 'short_pos': self.short_pos.pos, 'n1b': 0}) + + @property + def results(self): + """汇集回测相关结果""" + res = {} + ct = self + dt_fmt = "%Y-%m-%d %H:%M" + if ct.long_pos: + df_holds = pd.DataFrame(ct.long_holds) + + p = {"开始时间": df_holds['dt'].min().strftime(dt_fmt), + "结束时间": df_holds['dt'].max().strftime(dt_fmt), + "基准收益": x_round(df_holds['n1b'].sum(), 4), + "覆盖率": x_round(df_holds['long_pos'].mean(), 4)} + + df_holds['持仓收益'] = df_holds['long_pos'] * df_holds['n1b'] + df_holds['累计基准'] = df_holds['n1b'].cumsum() + df_holds['累计收益'] = df_holds['持仓收益'].cumsum() + + res['long_holds'] = df_holds + res['long_operates'] = ct.long_pos.operates + res['long_pairs'] = ct.long_pos.pairs + res['long_performance'] = ct.long_pos.evaluate_operates() + res['long_performance'].update(dict(p)) + + if ct.short_pos: + df_holds = pd.DataFrame(ct.short_holds) + + p = {"开始时间": df_holds['dt'].min().strftime(dt_fmt), + "结束时间": df_holds['dt'].max().strftime(dt_fmt), + "基准收益": x_round(df_holds['n1b'].sum(), 4), + "覆盖率": x_round(df_holds['short_pos'].mean(), 4)} + + df_holds['持仓收益'] = df_holds['short_pos'] * df_holds['n1b'] + df_holds['累计基准'] = df_holds['n1b'].cumsum() + df_holds['累计收益'] = df_holds['持仓收益'].cumsum() + + res['short_holds'] = df_holds + res['short_operates'] = ct.short_pos.operates + res['short_pairs'] = ct.short_pos.pairs + res['short_performance'] = ct.short_pos.evaluate_operates() + res['short_performance'].update(dict(p)) + + return res diff --git a/czsc/traders/dummy.py b/czsc/traders/dummy.py new file mode 100644 index 000000000..2bc2e738d --- /dev/null +++ b/czsc/traders/dummy.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +""" +author: zengbin93 +email: zeng_bin8888@163.com +create_dt: 2022/12/21 20:04 +describe: +""" +import os +import glob +import json +import shutil +import pandas as pd +from loguru import logger +from tqdm import tqdm +from datetime import datetime +from czsc.traders.utils import trade_replay +from czsc.traders.advanced import CzscDummyTrader +from czsc.sensors.utils import generate_signals +from czsc.traders.performance import PairsPerformance +from czsc.utils import BarGenerator, get_py_namespace, dill_dump, dill_load, WordWriter + + +class DummyBacktest: + def __init__(self, file_strategy): + """ + + :param file_strategy: 策略定义文件,必须是 .py 结尾 + """ + res_path = get_py_namespace(file_strategy)['results_path'] + os.makedirs(res_path, exist_ok=True) + self.signals_path = os.path.join(res_path, "signals") + self.results_path = os.path.join(res_path, f"DEXP_{datetime.now().strftime('%Y%m%d_%H%M%S')}") + os.makedirs(self.signals_path, exist_ok=True) + os.makedirs(self.results_path, exist_ok=True) + + # 创建 CzscDummyTrader 缓存路径 + self.cdt_path = os.path.join(self.results_path, 'cache') + os.makedirs(self.cdt_path, exist_ok=True) + + self.strategy_file = os.path.join(self.results_path, os.path.basename(file_strategy)) + shutil.copy(file_strategy, self.strategy_file) + + self.__debug = get_py_namespace(self.strategy_file).get('debug', False) + logger.add(os.path.join(self.results_path, 'dummy.log')) + + def replay(self): + """执行策略回放""" + py = get_py_namespace(self.strategy_file) + strategy = py['trader_strategy'] + replay_params = py.get('replay_params', {}) + + # 获取单个品种的基础周期K线 + tactic = strategy("000001.SZ") + symbol = replay_params.get('symbol', py['dummy_params']['symbols'][0]) + sdt = pd.to_datetime(replay_params.get('sdt', '20170101')) + mdt = pd.to_datetime(replay_params.get('mdt', '20200101')) + edt = pd.to_datetime(replay_params.get('edt', '20220101')) + bars = py['read_bars'](symbol, sdt, edt) + logger.info(f"交易回放参数 | {symbol} | 时间区间:{sdt} ~ {edt}") + + # 设置回放快照文件保存目录 + res_path = os.path.join(self.results_path, f"replay_{symbol}") + os.makedirs(res_path, exist_ok=True) + + # 拆分基础周期K线,一部分用来初始化BarGenerator,随后的K线是回放区间 + bg = BarGenerator(tactic['base_freq'], freqs=tactic['freqs']) + bars1 = [x for x in bars if x.dt <= mdt] + bars2 = [x for x in bars if x.dt > mdt] + for bar in bars1: + bg.update(bar) + + trade_replay(bg, bars2, strategy, res_path) + + def generate_symbol_signals(self, symbol): + """生成指定品种的交易信号 + + :param symbol: + :return: + """ + py = get_py_namespace(self.strategy_file) + sdt, mdt, edt = py['dummy_params']['sdt'], py['dummy_params']['mdt'], py['dummy_params']['edt'] + + bars = py['read_bars'](symbol, sdt, edt) + signals = generate_signals(bars, sdt=mdt, strategy=py['trader_strategy']) + + df = pd.DataFrame(signals) + if 'cache' in df.columns: + del df['cache'] + + c_cols = [k for k, v in df.dtypes.to_dict().items() if v.name.startswith('object')] + df[c_cols] = df[c_cols].astype('category') + + float_cols = [k for k, v in df.dtypes.to_dict().items() if v.name.startswith('float')] + df[float_cols] = df[float_cols].astype('float32') + return df + + def execute(self): + """执行策略文件中定义的内容""" + signals_path = self.signals_path + py = get_py_namespace(self.strategy_file) + + strategy = py['trader_strategy'] + symbols = py['dummy_params']['symbols'] + + for symbol in symbols: + file_dfs = os.path.join(signals_path, f"{symbol}_signals.pkl") + + try: + # 可以直接生成信号,也可以直接读取信号 + if os.path.exists(file_dfs): + dfs = pd.read_pickle(file_dfs) + else: + dfs = self.generate_symbol_signals(symbol) + dfs.to_pickle(file_dfs) + + if len(dfs) == 0: + continue + + cdt = CzscDummyTrader(dfs, strategy) + dill_dump(cdt, os.path.join(self.cdt_path, f"{symbol}.cdt")) + + res = cdt.results + if "long_performance" in res.keys(): + logger.info(f"{res['long_performance']}") + + if "short_performance" in res.keys(): + logger.info(f"{res['short_performance']}") + except Exception as e: + msg = f"fail on {symbol}: {e}" + if self.__debug: + logger.exception(msg) + else: + logger.warning(msg) + + def collect(self): + """汇集回测结果""" + res = {'long_pairs': [], 'lpf': [], 'short_pairs': [], 'spf': []} + files = glob.glob(f"{self.cdt_path}/*.cdt") + for file in tqdm(files, desc="DummyBacktest::collect"): + cdt = dill_load(file) + + if cdt.results.get("long_pairs", None): + res['lpf'].append(cdt.results['long_performance']) + res['long_pairs'].append(pd.DataFrame(cdt.results['long_pairs'])) + + if cdt.results.get("short_pairs", None): + res['spf'].append(cdt.results['short_performance']) + res['short_pairs'].append(pd.DataFrame(cdt.results['short_pairs'])) + + if res['long_pairs'] and res['lpf']: + long_ppf = PairsPerformance(pd.concat(res['long_pairs'])) + res['long_ppf_basic'] = long_ppf.basic_info + res['long_ppf_year'] = long_ppf.agg_statistics('平仓年') + long_ppf.agg_to_excel(os.path.join(self.results_path, 'long_ppf.xlsx')) + + if res['short_pairs'] and res['spf']: + short_ppf = PairsPerformance(pd.concat(res['short_pairs'])) + res['short_ppf_basic'] = short_ppf.basic_info + res['short_ppf_year'] = short_ppf.agg_statistics('平仓年') + short_ppf.agg_to_excel(os.path.join(self.results_path, 'short_ppf.xlsx')) + + return res + + def report(self): + py = get_py_namespace(self.strategy_file) + strategy = py['trader_strategy']('symbol') + + res = self.collect() + file_word = os.path.join(self.results_path, "report.docx") + if os.path.exists(file_word): + os.remove(file_word) + writer = WordWriter(file_word) + + writer.add_title("策略Dummy回测分析报告") + writer.add_heading("一、基础信息", level=1) + if strategy.get('long_events', None): + writer.add_heading("多头事件定义", level=2) + writer.add_paragraph(json.dumps([x.dump() for x in strategy['long_events']], + ensure_ascii=False, indent=4), first_line_indent=0) + + if strategy.get('short_events', None): + writer.add_heading("空头事件定义", level=2) + writer.add_paragraph(json.dumps([x.dump() for x in strategy['short_events']], ensure_ascii=False, indent=4)) + writer.add_paragraph('\n') + + writer.add_heading("二、回测分析", level=1) + if res.get("long_ppf_basic", None): + writer.add_heading("多头表现", level=2) + lpb = pd.DataFrame([res['long_ppf_basic']]).T.reset_index() + lpb.columns = ['名称', '取值'] + writer.add_df_table(lpb) + writer.add_paragraph('\n') + + lpy = res['long_ppf_year'].T.reset_index() + lpy.columns = lpy.iloc[0] + lpy = lpy.iloc[1:] + writer.add_df_table(lpy) + writer.add_paragraph('\n') + + if res.get("short_ppf_basic", None): + writer.add_heading("空头表现", level=2) + pb = pd.DataFrame([res['short_ppf_basic']]).T.reset_index() + pb.columns = ['名称', '取值'] + writer.add_df_table(pb) + writer.add_paragraph('\n') + + py = res['short_ppf_year'].T.reset_index() + py.columns = py.iloc[0] + py = py.iloc[1:] + writer.add_df_table(py) + writer.add_paragraph('\n') + + writer.save() + + + diff --git a/czsc/traders/ts_backtest.py b/czsc/traders/ts_backtest.py deleted file mode 100644 index 7af6963d1..000000000 --- a/czsc/traders/ts_backtest.py +++ /dev/null @@ -1,254 +0,0 @@ -# -*- coding: utf-8 -*- -""" -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2022/2/14 17:25 -describe: 基于 Tushare 分钟数据的择时策略快速回测 -""" - -import os -import inspect -import pandas as pd -from tqdm import tqdm -from loguru import logger -from typing import Callable -from czsc import envs -from czsc.data import TsDataCache, freq_cn2ts -from czsc.traders.utils import trader_fast_backtest -from czsc.traders.performance import PairsPerformance - - -def read_raw_results(raw_path, trade_dir="long"): - """读入指定路径下的回测原始结果 - - :param raw_path: 原始结果路径 - :param trade_dir: 交易方向 - :return: - """ - assert trade_dir in ['long', 'short'] - - pairs, p = [], [] - for file in tqdm(os.listdir(raw_path)): - if len(file) != 14: - continue - file = os.path.join(raw_path, file) - try: - pairs.append(pd.read_excel(file, sheet_name=f'{trade_dir}_pairs')) - p.append(pd.read_excel(file, sheet_name=f'{trade_dir}_performance')) - except: - logger.exception(f"fail on {file}") - - df_pairs = pd.concat(pairs, ignore_index=True) - df_p = pd.concat(p, ignore_index=True) - return df_pairs, df_p - - -class TsStocksBacktest: - """基于 Tushare 数据的择时回测系统(股票市场)""" - - def __init__(self, - dc: TsDataCache, - strategy: Callable, - init_n: int, - sdt: str, - edt: str, - ): - """ - - :param dc: Tushare 数据缓存对象 - :param strategy: 股票择时策略 - :param init_n: 初始化 Trader 需要的最少基础K线数量 - :param sdt: 开始回测时间 - :param edt: 结束回测时间 - """ - self.name = self.__class__.__name__ - self.strategy = strategy - self.init_n = init_n - self.data_path = dc.data_path - self.res_path = os.path.join(self.data_path, f"{strategy.__name__}_mbl{envs.get_min_bi_len()}") - os.makedirs(self.res_path, exist_ok=True) - - file_strategy = os.path.join(self.res_path, f'{strategy.__name__}_strategy.txt') - with open(file_strategy, 'w', encoding='utf-8') as f: - f.write(inspect.getsource(strategy)) - logger.info(f"strategy saved into {file_strategy}") - - self.dc, self.sdt, self.edt = dc, sdt, edt - stocks = self.dc.stock_basic() - stocks_ = stocks[stocks['list_date'] < '2010-01-01'].ts_code.to_list() - self.stocks_map = { - "index": ['000905.SH', '000016.SH', '000300.SH', '000001.SH', '000852.SH', - '399001.SZ', '399006.SZ', '399376.SZ', '399377.SZ', '399317.SZ', '399303.SZ'], - "stock": stocks.ts_code.to_list(), - "check": ['000001.SZ'], - "train": stocks_[:200], - "valid": stocks_[200:600], - "etfs": ['512880.SH', '518880.SH', '515880.SH', '513050.SH', '512690.SH', - '512660.SH', '512400.SH', '512010.SH', '512000.SH', '510900.SH', - '510300.SH', '510500.SH', '510050.SH', '159992.SZ', '159985.SZ', - '159981.SZ', '159949.SZ', '159915.SZ'], - } - - self.asset_map = { - "index": "I", - "stock": "E", - "check": "E", - "train": "E", - "valid": "E", - "etfs": "FD" - } - - def analyze_results(self, step, trade_dir="long"): - res_path = self.res_path - raw_path = os.path.join(res_path, f'raw_{step}') - if not os.path.exists(raw_path): - return - - df_pairs, df_p = read_raw_results(raw_path, trade_dir) - s_name = self.strategy.__name__ - - df_pairs.to_excel(os.path.join(res_path, f"{s_name}_{step}_{trade_dir}_pairs.xlsx"), index=False) - f = pd.ExcelWriter(os.path.join(res_path, f"{s_name}_{step}_{trade_dir}_performance.xlsx")) - df_p.to_excel(f, sheet_name="评估", index=False) - tp = PairsPerformance(df_pairs) - for col in tp.agg_columns: - df_ = tp.agg_statistics(col) - df_.to_excel(f, sheet_name=f"{col}聚合", index=False) - f.close() - logger.info(f"{s_name} - {step} - {trade_dir}: \n{tp.basic_info}") - - def update_step(self, step: str, ts_codes: list): - """更新指定阶段的批量回测标的 - - :param step: 阶段名称 - :param ts_codes: 标的列表 - :return: - """ - self.stocks_map[step] += ts_codes - - def batch_backtest(self, step): - """批量回测 - - :param step: 择时策略研究阶段 - check 在给定的股票上观察策略交易的准确性,输出交易快照 - index 在股票指数上评估策略表现 - train 在训练集上评估策略表现 - valid 在验证集上评估策略表现 - stock 用全市场所有股票评估策略表现 - :return: - """ - assert step in self.stocks_map.keys(), f"step 参数输入错误,可选值:{list(self.stocks_map.keys())}" - - init_n = self.init_n - save_html = True if step == 'check' else False - ts_codes = self.stocks_map[step] - dc, sdt, edt = self.dc, self.sdt, self.edt - res_path = self.res_path - strategy = self.strategy - raw_path = os.path.join(res_path, f"raw_{step}") - os.makedirs(raw_path, exist_ok=True) - asset = self.asset_map[step] - - for ts_code in ts_codes: - tactic = strategy(ts_code) - base_freq = tactic['base_freq'] - if save_html: - html_path = os.path.join(res_path, f"raw_{step}/{ts_code}") - os.makedirs(html_path, exist_ok=True) - else: - html_path = None - - try: - file_res = os.path.join(res_path, f"raw_{step}/{ts_code}.xlsx") - file_signals = os.path.join(res_path, f"raw_{step}/{ts_code}_signals.pkl") - if os.path.exists(file_res) and os.path.exists(file_signals): - logger.info(f"exits: {file_res}") - continue - - if "分钟" in base_freq: - bars = dc.pro_bar_minutes(ts_code, sdt, edt, freq=freq_cn2ts[base_freq], - asset=asset, adj='hfq', raw_bar=True) - else: - bars = dc.pro_bar(ts_code, sdt, edt, freq=freq_cn2ts[base_freq], - asset=asset, adj='hfq', raw_bar=True) - res = trader_fast_backtest(bars, init_n, strategy, html_path) - - # 保存信号结果 - dfs = pd.DataFrame(res['signals']) - c_cols = [k for k, v in dfs.dtypes.to_dict().items() if v.name.startswith('object')] - dfs[c_cols] = dfs[c_cols].astype('category') - float_cols = [k for k, v in dfs.dtypes.to_dict().items() if v.name.startswith('float')] - dfs[float_cols] = dfs[float_cols].astype('float32') - dfs.to_pickle(file_signals, protocol=4) - - f = pd.ExcelWriter(file_res) - if res.get('long_performance', None): - logger.info(f"{strategy.__name__} long_performance: \n{res['long_performance']}") - pd.DataFrame(res['long_holds']).to_excel(f, sheet_name="long_holds", index=False) - pd.DataFrame(res['long_operates']).to_excel(f, sheet_name="long_operates", index=False) - pd.DataFrame(res['long_pairs']).to_excel(f, sheet_name="long_pairs", index=False) - pd.DataFrame([res['long_performance']]).to_excel(f, sheet_name="long_performance", index=False) - - if res.get('short_performance', None): - logger.info(f"{strategy.__name__} short_performance: \n{res['short_performance']}") - pd.DataFrame(res['short_holds']).to_excel(f, sheet_name="short_holds", index=False) - pd.DataFrame(res['short_operates']).to_excel(f, sheet_name="short_operates", index=False) - pd.DataFrame(res['short_pairs']).to_excel(f, sheet_name="short_pairs", index=False) - pd.DataFrame([res['short_performance']]).to_excel(f, sheet_name="short_performance", index=False) - - f.close() - except: - logger.exception(f"fail on {ts_code}") - - # self.analyze_results(step, 'long') - # self.analyze_results(step, 'short') - # print(f"results saved into {self.res_path}") - - def analyze_signals(self, step: str): - """分析策略中信号的基础表现 - - :param step: - :return: - """ - dc = self.dc - raw_path = os.path.join(self.res_path, f"raw_{step}") - file_dfs = os.path.join(self.res_path, f"{step}_dfs.pkl") - signals_pat = fr"{raw_path}\*_signals.pkl" - freq = freq_cn2ts[self.strategy()['base_freq']] - - # 由于python存在循环导入的问题,只能把两个导入放到这里 - from ..sensors.utils import read_cached_signals, SignalsPerformance - - if not os.path.exists(file_dfs): - dfs = read_cached_signals(file_dfs, signals_pat) - asset = "I" if step == 'index' else "E" - results = [] - for symbol, dfg in tqdm(dfs.groupby('symbol'), desc='add nbar'): - dfk = dc.pro_bar_minutes(symbol, sdt=dfg['dt'].min(), edt=dfg['dt'].max(), - freq=freq, asset=asset, adj='hfq', raw_bar=False) - dfk_cols = ['dt'] + [x for x in dfk.columns if x not in dfs.columns] - dfk = dfk[dfk_cols] - dfs_ = dfg.merge(dfk, on='dt', how='left') - results.append(dfs_) - - dfs = pd.concat(results, ignore_index=True) - c_cols = [k for k, v in dfs.dtypes.to_dict().items() if v.name.startswith('object')] - dfs[c_cols] = dfs[c_cols].astype('category') - float_cols = [k for k, v in dfs.dtypes.to_dict().items() if v.name.startswith('float')] - dfs[float_cols] = dfs[float_cols].astype('float32') - dfs.to_pickle(file_dfs, protocol=4) - else: - dfs = pd.read_pickle(file_dfs) - - results_path = os.path.join(raw_path, 'signals_performance') - if os.path.exists(results_path): - return - - os.makedirs(results_path, exist_ok=True) - signal_cols = [x for x in dfs.columns if len(x.split("_")) == 3] - for key in signal_cols: - file_xlsx = os.path.join(results_path, f"{key.replace(':', '')}.xlsx") - sp = SignalsPerformance(dfs, keys=[key], dc=dc) - sp.report(file_xlsx) - logger.info(f"{key} performance saved into {file_xlsx}") - diff --git a/czsc/utils/word_writer.py b/czsc/utils/word_writer.py index 532f83a15..aa20e9343 100644 --- a/czsc/utils/word_writer.py +++ b/czsc/utils/word_writer.py @@ -56,10 +56,18 @@ def add_heading(self, text, level=1): title_run.element.rPr.rFonts.set(qn('w:eastAsia'), '微软雅黑') title_run.font.color.rgb = RGBColor(0, 0, 0) - def add_paragraph(self, text, style=None, bold=False): + def add_paragraph(self, text, style=None, bold=False, first_line_indent=0.74): + """新增段落 + + :param text: 文本 + :param style: 段落样式 + :param bold: 是否加粗 + :param first_line_indent: 首行缩进,0.74 表示两个空格 + :return: + """ p = self.document.add_paragraph(style=style) p.paragraph_format.left_indent = Cm(0) - p.paragraph_format.first_line_indent = Cm(0.74) + p.paragraph_format.first_line_indent = Cm(first_line_indent) p.paragraph_format.line_spacing = 1.25 p.paragraph_format.space_before = Pt(8) p.paragraph_format.space_after = Pt(8) diff --git a/examples/__init__.py b/examples/__init__.py index bd620fe96..e51f1a93c 100644 --- a/examples/__init__.py +++ b/examples/__init__.py @@ -4,3 +4,19 @@ email: zeng_bin8888@163.com create_dt: 2022/2/15 16:09 """ +import os +import pandas as pd +from czsc.data import TsDataCache + +os.environ['czsc_verbose'] = "0" # 是否输出详细执行信息,包括一些用于debug的信息,0 不输出,1 输出 +os.environ['czsc_min_bi_len'] = "6" # 通过环境变量设定最小笔长度,6 对应新笔定义,7 对应老笔定义 + +pd.set_option('mode.chained_assignment', None) +pd.set_option('display.max_rows', 1000) +pd.set_option('display.max_columns', 20) + + +# data_path 是 TS_CACHE 缓存数据文件夹所在目录 +dc = TsDataCache(data_path=r"C:\ts_data_czsc", refresh=False, sdt="20120101", edt="20221001") + + diff --git a/examples/gm_backtest.py b/examples/gm_backtest.py index a927f5bac..18ac6abf2 100644 --- a/examples/gm_backtest.py +++ b/examples/gm_backtest.py @@ -32,7 +32,18 @@ os.environ['backtest_slippage_ratio'] = '0.0005' """ from czsc.gms.gm_stocks import * -from czsc.strategies import trader_strategy_a as strategy +from examples.strategies.cat_sma import trader_strategy + +os.environ['strategy_id'] = 'b24661f5-838d-11ed-882c-988fe0675a5b' +os.environ['wx_key'] = '2****96b-****-4f83-818b-2952fe2731c0' +os.environ['max_sym_pos'] = '0.5' +os.environ['path_gm_logs'] = 'C:/gm_logs' +os.environ['backtest_start_time'] = '2020-01-01 14:30:00' +os.environ['backtest_end_time'] = '2020-12-31 15:30:00' +os.environ['backtest_initial_cash'] = '100000000' +os.environ['backtest_transaction_ratio'] = '1' +os.environ['backtest_commission_ratio'] = '0.001' +os.environ['backtest_slippage_ratio'] = '0.0005' def init(context): @@ -50,7 +61,8 @@ def init(context): 'SHSE.600010', 'SHSE.600011' ] - name = f"{strategy.__name__}" + name = "stocks_sma5" + strategy = trader_strategy init_context_universal(context, name) init_context_env(context) init_context_traders(context, symbols, strategy) diff --git a/examples/gm_check_point.py b/examples/gm_check_point.py index 4f096aff0..ced59682e 100644 --- a/examples/gm_check_point.py +++ b/examples/gm_check_point.py @@ -5,15 +5,17 @@ create_dt: 2021/11/17 22:26 describe: 使用掘金数据验证买卖点 """ -from czsc.gm_utils import trader_tactic_snapshot, gm_take_snapshot +from czsc.gms.gm_stocks import strategy_snapshot, gm_take_snapshot from czsc.strategies import trader_strategy_a as strategy if __name__ == '__main__': - # _symbol = "SZSE.300669" - # ct = trader_tactic_snapshot(_symbol, end_dt="2022-03-18 13:15", strategy=strategy) _symbol = "SHSE.000016" - ct = trader_tactic_snapshot(_symbol, end_dt="2022-07-27 13:15", strategy=strategy) - # ct = gm_take_snapshot('SHSE.000001') + # 查看含策略交易信号的快照 + ct = strategy_snapshot(_symbol, end_dt="2022-07-27 13:15", strategy=strategy) + + # 仅查看分型、笔的程序化识别 + cts = gm_take_snapshot(_symbol, end_dt="2022-07-27 13:15", max_count=1000) + cts.open_in_browser() diff --git a/examples/gm_realtime.py b/examples/gm_realtime.py index f498c94cc..fa5052c60 100644 --- a/examples/gm_realtime.py +++ b/examples/gm_realtime.py @@ -21,7 +21,14 @@ os.environ['path_gm_logs'] = 'C:/gm_logs' """ from czsc.gms.gm_stocks import * -from czsc.strategies import trader_strategy_a as strategy +from examples.strategies.cat_sma import trader_strategy + + +os.environ['strategy_id'] = 'c7991760-****-11eb-b66a-00163e0c87d1' +os.environ['account_id'] = 'c7991760-****-11eb-b66a-00163e0c87d1' +os.environ['wx_key'] = '2daec96b-****-4f83-818b-2952fe2731c0' +os.environ['max_sym_pos'] = '0.5' +os.environ['path_gm_logs'] = 'C:/gm_logs' def init(context): @@ -39,7 +46,8 @@ def init(context): 'SHSE.600010', 'SHSE.600011' ] - name = f"{strategy.__name__}" + name = "stocks_sma5" + strategy = trader_strategy init_context_universal(context, name) init_context_env(context) init_context_traders(context, symbols, strategy) diff --git a/examples/quick_start.py b/examples/quick_start.py index 97be1db9c..de653c5aa 100644 --- a/examples/quick_start.py +++ b/examples/quick_start.py @@ -11,7 +11,7 @@ from czsc import CZSC, CzscAdvancedTrader, Freq from czsc.utils import BarGenerator from czsc import signals -from czsc.traders.ts_backtest import TsDataCache +from czsc.data import TsDataCache os.environ['czsc_verbose'] = "1" # 是否输出详细执行信息,0 不输出,1 输出 os.environ['czsc_min_bi_len'] = "6" # 通过环境变量设定最小笔长度,6 对应新笔定义,7 对应老笔定义 diff --git a/examples/strategies/cat_sma.py b/examples/strategies/cat_sma.py index 40868c649..1b3d055f5 100644 --- a/examples/strategies/cat_sma.py +++ b/examples/strategies/cat_sma.py @@ -3,8 +3,9 @@ author: zengbin93 email: zeng_bin8888@163.com create_dt: 2022/10/21 19:56 -describe: 择时交易策略 +describe: 择时交易策略样例 """ +from loguru import logger from collections import OrderedDict from czsc import signals from czsc.data import TsDataCache, get_symbols @@ -12,64 +13,65 @@ from czsc.traders import CzscAdvancedTrader from czsc.objects import PositionLong, PositionShort, RawBar +logger.disable('czsc.signals.cxt') + # 定义择时交易策略,策略函数名称必须是 trader_strategy # ---------------------------------------------------------------------------------------------------------------------- - def trader_strategy(symbol): - """择时策略""" + """5日线""" + def get_signals(cat: CzscAdvancedTrader) -> OrderedDict: s = OrderedDict({"symbol": cat.symbol, "dt": cat.end_dt, "close": cat.latest_price}) - s.update(signals.pos.get_s_long01(cat, th=100)) - s.update(signals.pos.get_s_long02(cat, th=100)) - s.update(signals.pos.get_s_long05(cat, span='月', th=500)) - for _, c in cat.kas.items(): - if c.freq in [Freq.F15]: - s.update(signals.bxt.get_s_d0_bi(c)) - s.update(signals.other.get_s_zdt(c, di=1)) - s.update(signals.other.get_s_op_time_span(c, op='开多', time_span=('13:00', '14:50'))) - s.update(signals.other.get_s_op_time_span(c, op='平多', time_span=('09:35', '14:50'))) - - if c.freq in [Freq.F60, Freq.D, Freq.W]: - s.update(signals.ta.get_s_macd(c, di=1)) + s.update(signals.bar_operate_span_V221111(cat.kas['15分钟'], k1='全天', span=('0935', '1450'))) + s.update(signals.bar_operate_span_V221111(cat.kas['15分钟'], k1='临收盘', span=('1410', '1450'))) + s.update(signals.bar_operate_span_V221111(cat.kas['15分钟'], k1='下午', span=('1300', '1450'))) + s.update(signals.bar_operate_span_V221111(cat.kas['15分钟'], k1='上午', span=('0935', '1130'))) + s.update(signals.bar_zdt_V221111(cat, '15分钟', di=1)) + s.update(signals.bar_mean_amount_V221112(cat.kas['日线'], di=2, n=20, th1=2, th2=1000)) + + signals.update_ma_cache(cat.kas["日线"], ma_type='SMA', timeperiod=5) + s.update(signals.tas_ma_base_V221101(cat.kas["日线"], di=1, key='SMA5')) + s.update(signals.tas_ma_base_V221101(cat.kas["日线"], di=2, key='SMA5')) + s.update(signals.tas_ma_base_V221101(cat.kas["日线"], di=5, key='SMA5')) + c = cat.kas['30分钟'] + s.update(signals.cxt_first_buy_V221126(c, di=1)) return s # 定义多头持仓对象和交易事件 long_pos = PositionLong(symbol, hold_long_a=1, hold_long_b=1, hold_long_c=1, - T0=False, long_min_interval=3600*4) + T0=False, long_min_interval=3600 * 4) long_events = [ - Event(name="开多", operate=Operate.LO, factors=[ - Factor(name="低吸", signals_all=[ - Signal("开多时间范围_13:00_14:50_是_任意_任意_0"), - Signal("15分钟_倒1K_ZDT_非涨跌停_任意_任意_0"), - Signal("60分钟_倒1K_MACD多空_多头_任意_任意_0"), - Signal("15分钟_倒0笔_方向_向上_任意_任意_0"), - Signal("15分钟_倒0笔_长度_5根K线以下_任意_任意_0"), - ]), - ]), - - Event(name="平多", operate=Operate.LE, factors=[ - Factor(name="持有资金", signals_all=[ - Signal("平多时间范围_09:35_14:50_是_任意_任意_0"), - Signal("15分钟_倒1K_ZDT_非涨跌停_任意_任意_0"), - ], signals_not=[ - Signal("15分钟_倒0笔_方向_向上_任意_任意_0"), - Signal("60分钟_倒1K_MACD多空_多头_任意_任意_0"), - ]), - ]), + Event(name="开多", operate=Operate.LO, + signals_not=[Signal('15分钟_D1K_涨跌停_涨停_任意_任意_0')], + signals_all=[Signal("日线_D2K20B均额_2至1000千万_是_任意_任意_0")], + factors=[ + Factor(name="站上SMA5", signals_all=[ + Signal("全天_0935_1450_是_任意_任意_0"), + Signal("日线_D1K_SMA5_多头_任意_任意_0"), + Signal("日线_D5K_SMA5_空头_向下_任意_0"), + Signal('30分钟_D1B_BUY1_一买_任意_任意_0'), + ]), + ]), + + Event(name="平多", operate=Operate.LE, + signals_not=[Signal('15分钟_D1K_涨跌停_跌停_任意_任意_0')], + factors=[ + Factor(name="跌破SMA5", signals_all=[ + Signal("下午_1300_1450_是_任意_任意_0"), + Signal("日线_D1K_SMA5_空头_任意_任意_0"), + Signal("日线_D2K_SMA5_多头_任意_任意_0"), + ]), + ]), ] tactic = { "base_freq": '15分钟', - "freqs": ['60分钟', '日线'], + "freqs": ['30分钟', '日线'], "get_signals": get_signals, "long_pos": long_pos, "long_events": long_events, - - # 空头策略不进行定义,也就是不做空头交易 - "short_pos": None, - "short_events": None, } return tactic @@ -78,21 +80,49 @@ def get_signals(cat: CzscAdvancedTrader) -> OrderedDict: # 定义命令行接口的特定参数 # ---------------------------------------------------------------------------------------------------------------------- -# 初始化 Tushare 数据缓存 -dc = TsDataCache(r"C:\ts_data_czsc") +# 【必须】执行结果路径 +results_path = r"D:\ts_data\TS_SMA5" -# 定义回测使用的标的列表 -symbols = get_symbols(dc, 'train')[:3] +# 初始化 Tushare 数据缓存 +dc = TsDataCache(r"D:\ts_data") + +# 【必须】策略回测参数设置 +dummy_params = { + "symbols": get_symbols(dc, 'train'), # 回测使用的标的列表 + "sdt": "20150101", # K线数据开始时间 + "mdt": "20200101", # 策略回测开始时间 + "edt": "20220101", # 策略回测结束时间 +} -# 执行结果路径 -results_path = r"C:\ts_data_czsc\cat_sma" -# 策略回放参数设置【可选】 +# 【可选】策略回放参数设置 replay_params = { - "symbol": "000001.SZ#E", # 回放交易品种 - "sdt": "20150101", # K线数据开始时间 - "mdt": "20180101", # 策略回放开始时间 - "edt": "20220101", # 策略回放结束时间 + "symbol": "000002.SZ#E", # 回放交易品种 + "sdt": "20150101", # K线数据开始时间 + "mdt": "20200101", # 策略回放开始时间 + "edt": "20220101", # 策略回放结束时间 } +# 【必须】定义K线数据读取函数,这里是为了方便接入任意数据源的K线行情 +# ---------------------------------------------------------------------------------------------------------------------- + +def read_bars(symbol, sdt, edt): + """自定义K线数据读取函数,便于接入任意来源的行情数据进行回测一类的分析 + + :param symbol: 标的名称 + :param sdt: 行情开始时间 + :param edt: 行情介绍时间 + :return: list of RawBar + """ + adj = 'hfq' + freq = '15min' + ts_code, asset = symbol.split("#") + + if "min" in freq: + bars = dc.pro_bar_minutes(ts_code, sdt, edt, freq=freq, asset=asset, adj=adj, raw_bar=True) + else: + bars = dc.pro_bar(ts_code, sdt, edt, freq=freq, asset=asset, adj=adj, raw_bar=True) + return bars + + diff --git a/examples/strategies/check_signal.py b/examples/strategies/check_signal.py index 0d21a0d3e..2fd30da98 100644 --- a/examples/strategies/check_signal.py +++ b/examples/strategies/check_signal.py @@ -18,46 +18,8 @@ from czsc.traders import CzscAdvancedTrader -# 定义信号函数 +# 【必须】定义信号函数 # ---------------------------------------------------------------------------------------------------------------------- -def macd_bs2_v2(cat: CzscAdvancedTrader, freq: str): - """MACD金叉死叉判断第二买卖点 - - 原理:最近一次交叉为死叉,DEA大于0,且前面三次死叉都在零轴下方,那么二买即将出现;二卖反之。 - - 完全分类: - Signal('15分钟_MACD_BS2V2_二卖_任意_任意_0'), - Signal('15分钟_MACD_BS2V2_二买_任意_任意_0') - :return: - """ - s = OrderedDict() - cache_key = f"{freq}MACD" - cache = cat.cache[cache_key] - assert cache and cache['update_dt'] == cat.end_dt - cross = cache['cross'] - macd = cache['macd'] - up = [x for x in cross if x['类型'] == "金叉"] - dn = [x for x in cross if x['类型'] == "死叉"] - - v1 = "其他" - - b2_con1 = len(cross) > 3 and cross[-1]['类型'] == '死叉' and cross[-1]['慢线'] > 0 - b2_con2 = len(dn) > 3 and dn[-3]['慢线'] < 0 and dn[-2]['慢线'] < 0 and dn[-3]['慢线'] < 0 - b2_con3 = len(macd) > 10 and macd[-1] > macd[-2] - if b2_con1 and b2_con2 and b2_con3: - v1 = "二买" - - s2_con1 = len(cross) > 3 and cross[-1]['类型'] == '金叉' and cross[-1]['慢线'] < 0 - s2_con2 = len(up) > 3 and up[-3]['慢线'] > 0 and up[-2]['慢线'] > 0 and up[-3]['慢线'] > 0 - s2_con3 = len(macd) > 10 and macd[-1] < macd[-2] - if s2_con1 and s2_con2 and s2_con3: - v1 = "二卖" - - signal = Signal(k1=freq, k2="MACD", k3="BS2V2", v1=v1) - s[signal.key] = signal.value - return s - - def tas_macd_first_bs_V221216(c: CZSC, di: int = 1): """MACD金叉死叉判断第一买卖点 @@ -142,19 +104,33 @@ def get_signals(cat: CzscAdvancedTrader) -> OrderedDict: # 定义命令行接口【信号检查】的特定参数 # ---------------------------------------------------------------------------------------------------------------------- -# 初始化 Tushare 数据缓存 -dc = TsDataCache(r"D:\ts_data") - # 信号检查参数设置【可选】 -# check_params = { -# "symbol": "000001.SZ#E", # 交易品种,格式为 {ts_code}#{asset} -# "sdt": "20180101", # 开始时间 -# "edt": "20220101", # 结束时间 -# } - - check_params = { - "symbol": "300001.SZ#E", # 交易品种,格式为 {ts_code}#{asset} - "sdt": "20150101", # 开始时间 - "edt": "20220101", # 结束时间 + "symbol": "000001.SZ#E", # 交易品种,格式为 {ts_code}#{asset} + "sdt": "20180101", # 开始时间 + "edt": "20220101", # 结束时间 } + + +# 【必须】定义K线数据读取函数,这里是为了方便接入任意数据源的K线行情 +# ---------------------------------------------------------------------------------------------------------------------- +def read_bars(symbol, sdt='20170101', edt='20221001'): + """自定义K线数据读取函数,便于接入任意来源的行情数据进行回测一类的分析 + + :param symbol: 标的名称 + :param sdt: 行情开始时间 + :param edt: 行情介绍时间 + :return: list of RawBar + """ + adj = 'hfq' + freq = '15min' + ts_code, asset = symbol.split("#") + # 初始化 Tushare 数据缓存 + dc = TsDataCache(r"D:\ts_data") + + if "min" in freq: + bars = dc.pro_bar_minutes(ts_code, sdt, edt, freq=freq, asset=asset, adj=adj, raw_bar=True) + else: + bars = dc.pro_bar(ts_code, sdt, edt, freq=freq, asset=asset, adj=adj, raw_bar=True) + return bars + diff --git a/examples/strategies/qmt_cat_sma.py b/examples/strategies/qmt_cat_sma.py new file mode 100644 index 000000000..11380759d --- /dev/null +++ b/examples/strategies/qmt_cat_sma.py @@ -0,0 +1,225 @@ +# -*- coding: utf-8 -*- +""" +author: zengbin93 +email: zeng_bin8888@163.com +create_dt: 2022/10/21 19:56 +describe: 择时交易策略样例 +""" +from loguru import logger +from collections import OrderedDict +from czsc import signals +from czsc.traders import CzscAdvancedTrader +from czsc.objects import Freq, Operate, Signal, Factor, Event, RawBar, PositionLong, PositionShort + +logger.disable('czsc.signals.cxt') + +# QMT 数据相关函数 +# ---------------------------------------------------------------------------------------------------------------------- +import pandas as pd +from typing import List +from xtquant import xtdata + + +def format_stock_kline(kline: pd.DataFrame, freq: Freq) -> List[RawBar]: + """QMT A股市场K线数据转换 + + :param kline: QMT 数据接口返回的K线数据 + time open high low close volume amount \ + 0 2022-12-01 10:15:00 13.22 13.22 13.16 13.18 20053 26432861.0 + 1 2022-12-01 10:20:00 13.18 13.19 13.15 13.15 32667 43002512.0 + 2 2022-12-01 10:25:00 13.16 13.18 13.13 13.16 32466 42708049.0 + 3 2022-12-01 10:30:00 13.16 13.19 13.13 13.18 15606 20540461.0 + 4 2022-12-01 10:35:00 13.20 13.25 13.19 13.22 29959 39626170.0 + symbol + 0 000001.SZ + 1 000001.SZ + 2 000001.SZ + 3 000001.SZ + 4 000001.SZ + :param freq: K线周期 + :return: 转换好的K线数据 + """ + bars = [] + dt_key = 'time' + kline = kline.sort_values(dt_key, ascending=True, ignore_index=True) + records = kline.to_dict('records') + + for i, record in enumerate(records): + # 将每一根K线转换成 RawBar 对象 + bar = RawBar(symbol=record['symbol'], dt=pd.to_datetime(record[dt_key]), id=i, freq=freq, + open=record['open'], close=record['close'], high=record['high'], low=record['low'], + vol=record['volume'] * 100 if record['volume'] else 0, # 成交量,单位:股 + amount=record['amount'] if record['amount'] > 0 else 0, # 成交额,单位:元 + ) + bars.append(bar) + return bars + + +def get_local_kline(symbol, period, start_time, end_time, count=-1, dividend_type='none', data_dir=None, update=True): + """获取 QMT 本地K线数据 + + :param symbol: 股票代码 例如:'300001.SZ' + :param period: 周期 分笔"tick" 分钟线"1m"/"5m" 日线"1d" + :param start_time: 开始时间,格式YYYYMMDD/YYYYMMDDhhmmss/YYYYMMDDhhmmss.milli, + 例如:"20200427" "20200427093000" "20200427093000.000" + :param end_time: 结束时间 格式同上 + :param count: 数量 -1全部,n: 从结束时间向前数n个 + :param dividend_type: 除权类型"none" "front" "back" "front_ratio" "back_ratio" + :param data_dir: 下载QMT本地数据路径,如 D:/迅投极速策略交易系统交易终端/datadir + :param update: 更新QMT本地数据路径中的数据 + :return: df Dataframe格式的数据,样例如下 + time open high low close volume amount \ + 0 2022-12-01 10:15:00 13.22 13.22 13.16 13.18 20053 26432861.0 + 1 2022-12-01 10:20:00 13.18 13.19 13.15 13.15 32667 43002512.0 + 2 2022-12-01 10:25:00 13.16 13.18 13.13 13.16 32466 42708049.0 + 3 2022-12-01 10:30:00 13.16 13.19 13.13 13.18 15606 20540461.0 + 4 2022-12-01 10:35:00 13.20 13.25 13.19 13.22 29959 39626170.0 + symbol + 0 000001.SZ + 1 000001.SZ + 2 000001.SZ + 3 000001.SZ + 4 000001.SZ + """ + field_list = ['time', 'open', 'high', 'low', 'close', 'volume', 'amount'] + if update: + xtdata.download_history_data(symbol, period, start_time='20100101', end_time='21000101') + local_data = xtdata.get_local_data(field_list, [symbol], period, count=count, dividend_type=dividend_type, + start_time=start_time, end_time=end_time, data_dir=data_dir) + + df = pd.DataFrame({key: value.values[0] for key, value in local_data.items()}) + df['time'] = pd.to_datetime(df['time'], unit='ms') + pd.to_timedelta('8H') + df.reset_index(inplace=True, drop=True) + df['symbol'] = symbol + return df + + +def get_symbols(step): + """获取择时策略投研不同阶段对应的标的列表 + + :param step: 投研阶段 + :return: 标的列表 + """ + stocks = xtdata.get_stock_list_in_sector('沪深A股') + stocks_map = { + "index": ['000905.SH', '000016.SH', '000300.SH', '000001.SH', '000852.SH', + '399001.SZ', '399006.SZ', '399376.SZ', '399377.SZ', '399317.SZ', '399303.SZ'], + "stock": stocks, + "check": ['000001.SZ'], + "train": stocks[:200], + "valid": stocks[200:600], + "etfs": ['512880.SH', '518880.SH', '515880.SH', '513050.SH', '512690.SH', + '512660.SH', '512400.SH', '512010.SH', '512000.SH', '510900.SH', + '510300.SH', '510500.SH', '510050.SH', '159992.SZ', '159985.SZ', + '159981.SZ', '159949.SZ', '159915.SZ'], + } + return stocks_map[step] + + +# 定义择时交易策略,策略函数名称必须是 trader_strategy +# ---------------------------------------------------------------------------------------------------------------------- +def trader_strategy(symbol): + """5日线""" + + def get_signals(cat: CzscAdvancedTrader) -> OrderedDict: + s = OrderedDict({"symbol": cat.symbol, "dt": cat.end_dt, "close": cat.latest_price}) + s.update(signals.bar_operate_span_V221111(cat.kas['15分钟'], k1='全天', span=('0935', '1450'))) + s.update(signals.bar_operate_span_V221111(cat.kas['15分钟'], k1='临收盘', span=('1410', '1450'))) + s.update(signals.bar_operate_span_V221111(cat.kas['15分钟'], k1='下午', span=('1300', '1450'))) + s.update(signals.bar_operate_span_V221111(cat.kas['15分钟'], k1='上午', span=('0935', '1130'))) + s.update(signals.bar_zdt_V221111(cat, '15分钟', di=1)) + s.update(signals.bar_mean_amount_V221112(cat.kas['日线'], di=2, n=20, th1=2, th2=1000)) + + signals.update_ma_cache(cat.kas["日线"], ma_type='SMA', timeperiod=5) + s.update(signals.tas_ma_base_V221101(cat.kas["日线"], di=1, key='SMA5')) + s.update(signals.tas_ma_base_V221101(cat.kas["日线"], di=2, key='SMA5')) + s.update(signals.tas_ma_base_V221101(cat.kas["日线"], di=5, key='SMA5')) + c = cat.kas['30分钟'] + s.update(signals.cxt_first_buy_V221126(c, di=1)) + return s + + # 定义多头持仓对象和交易事件 + long_pos = PositionLong(symbol, hold_long_a=1, hold_long_b=1, hold_long_c=1, + T0=False, long_min_interval=3600 * 4) + long_events = [ + Event(name="开多", operate=Operate.LO, + signals_not=[Signal('15分钟_D1K_涨跌停_涨停_任意_任意_0')], + signals_all=[Signal("日线_D2K20B均额_2至1000千万_是_任意_任意_0")], + factors=[ + Factor(name="站上SMA5", signals_all=[ + Signal("全天_0935_1450_是_任意_任意_0"), + Signal("日线_D1K_SMA5_多头_任意_任意_0"), + Signal("日线_D5K_SMA5_空头_向下_任意_0"), + Signal('30分钟_D1B_BUY1_一买_任意_任意_0'), + ]), + ]), + + Event(name="平多", operate=Operate.LE, + signals_not=[Signal('15分钟_D1K_涨跌停_跌停_任意_任意_0')], + factors=[ + Factor(name="跌破SMA5", signals_all=[ + Signal("下午_1300_1450_是_任意_任意_0"), + Signal("日线_D1K_SMA5_空头_任意_任意_0"), + Signal("日线_D2K_SMA5_多头_任意_任意_0"), + ]), + ]), + ] + + tactic = { + "base_freq": '5分钟', + "freqs": ['15分钟', '30分钟', '日线'], + "get_signals": get_signals, + + "long_pos": long_pos, + "long_events": long_events, + } + + return tactic + + +# 定义命令行接口的特定参数 +# ---------------------------------------------------------------------------------------------------------------------- +# 【必须】执行结果路径 +results_path = r"D:\ts_data\SMA5" + + +# 【必须】策略回测参数设置 +dummy_params = { + "symbols": get_symbols('train'), # 回测使用的标的列表 + "sdt": "20150101", # K线数据开始时间 + "mdt": "20200101", # 策略回测开始时间 + "edt": "20220101", # 策略回测结束时间 +} + + +# 【可选】策略回放参数设置 +replay_params = { + "symbol": get_symbols('check')[0], # 回放交易品种 + "sdt": "20150101", # K线数据开始时间 + "mdt": "20200101", # 策略回放开始时间 + "edt": "20220101", # 策略回放结束时间 +} + +# 【可选】是否使用 debug 模式输出更多信息 +debug = True + + +# 【必须】定义K线数据读取函数,这里是为了方便接入任意数据源的K线行情 +# ---------------------------------------------------------------------------------------------------------------------- + +def read_bars(symbol, sdt, edt): + """自定义K线数据读取函数,便于接入任意来源的行情数据进行回测一类的分析 + + :param symbol: 标的名称 + :param sdt: 行情开始时间 + :param edt: 行情介绍时间 + :return: list of RawBar + """ + sdt = pd.to_datetime(sdt).strftime("%Y%m%d") + edt = pd.to_datetime(edt).strftime("%Y%m%d") + df = get_local_kline(symbol, period='5m', start_time=sdt, end_time=edt, dividend_type='back', + data_dir=r"D:\迅投极速策略交易系统交易终端 华鑫证券QMT实盘\datadir", update=True) + bars = format_stock_kline(df, Freq.F5) + return bars + + diff --git a/examples/ts_check_signal_acc.py b/examples/ts_check_signal_acc.py index 58624f047..3b3a23a68 100644 --- a/examples/ts_check_signal_acc.py +++ b/examples/ts_check_signal_acc.py @@ -29,14 +29,19 @@ def get_signals(cat: CzscAdvancedTrader) -> OrderedDict: s = OrderedDict({"symbol": cat.symbol, "dt": cat.end_dt, "close": cat.latest_price}) - # signals.update_ma_cache(cat.kas['15分钟'], ma_type='SMA', timeperiod=5) - # signals.update_ma_cache(cat.kas['15分钟'], ma_type='SMA', timeperiod=10) - s.update(signals.bar_mean_amount_V221112(cat.kas['15分钟'], di=2, n=20)) - # s.update(signals.bar_zdt_V221111(cat, '15分钟', di=2)) - # # 使用缓存来更新信号的方法 # signals.update_macd_cache(cat.kas['15分钟']) # s.update(signals.tas_macd_direct_V221106(cat.kas['15分钟'], di=1)) + # s.update(signals.tas_macd_base_V221028(cat.kas['15分钟'], di=1)) + # s.update(signals.tas_macd_first_bs_V221201(cat.kas['15分钟'], di=1)) + # s.update(signals.tas_macd_second_bs_V221201(cat.kas['15分钟'], di=1)) + # s.update(signals.tas_macd_xt_V221208(cat.kas['15分钟'], di=1)) + # s.update(signals.tas_macd_bc_V221201(cat.kas['15分钟'], di=1)) + # s.update(signals.tas_macd_change_V221105(cat.kas['15分钟'], di=1)) + # s.update(signals.tas_boll_bc_V221118(cat.kas['15分钟'], di=1)) + # s.update(signals.tas_boll_power_V221112(cat.kas['15分钟'], di=1)) + # s.update(signals.tas_kdj_base_V221101(cat.kas['15分钟'], di=1)) + s.update(signals.tas_double_rsi_V221203(cat.kas['15分钟'], di=1)) # signals.update_boll_cache(cat.kas['15分钟']) # s.update(signals.tas_boll_power_V221112(cat.kas['15分钟'], di=1)) return s diff --git a/examples/ts_continue_simulator.py b/examples/ts_continue_simulator.py index 9057ad035..911768593 100644 --- a/examples/ts_continue_simulator.py +++ b/examples/ts_continue_simulator.py @@ -5,7 +5,7 @@ create_dt: 2022/5/6 15:54 describe: 使用 Tushare 数据对交易策略进行持续仿真研究 """ -from ts_fast_backtest import TsDataCache +from czsc.data.ts_cache import TsDataCache from czsc.traders.ts_simulator import TradeSimulator from czsc.strategies import trader_strategy_a diff --git a/examples/ts_dummy_trader.py b/examples/ts_dummy_trader.py index 2097708fd..9b4000b92 100644 --- a/examples/ts_dummy_trader.py +++ b/examples/ts_dummy_trader.py @@ -10,7 +10,7 @@ from czsc.strategies import trader_strategy_a as strategy from czsc.traders.advanced import CzscDummyTrader from czsc.sensors.utils import generate_symbol_signals -from examples.ts_fast_backtest import dc +from examples import dc # 可以直接生成信号,也可以直接读取信号 diff --git a/examples/ts_fast_backtest.py b/examples/ts_fast_backtest.py deleted file mode 100644 index f5a93767a..000000000 --- a/examples/ts_fast_backtest.py +++ /dev/null @@ -1,67 +0,0 @@ -# -*- coding: utf-8 -*- -""" -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2021/12/12 22:00 -""" -import os -import pandas as pd -from czsc.traders.ts_backtest import TsDataCache, TsStocksBacktest, freq_cn2ts -from czsc.strategies import trader_strategy_a as strategy - -os.environ['czsc_verbose'] = "0" # 是否输出详细执行信息,包括一些用于debug的信息,0 不输出,1 输出 -os.environ['czsc_min_bi_len'] = "6" # 通过环境变量设定最小笔长度,6 对应新笔定义,7 对应老笔定义 - -pd.set_option('mode.chained_assignment', None) -pd.set_option('display.max_rows', 1000) -pd.set_option('display.max_columns', 20) - - -# data_path 是 TS_CACHE 缓存数据文件夹所在目录 -dc = TsDataCache(data_path=r"C:\ts_data_czsc", refresh=False, sdt="20120101", edt="20221001") - -# 获取策略的基础K线周期,回测开始时间 sdt,回测结束时间 edt,初始化K线数量init_n -freq = freq_cn2ts[strategy('000001.SH')['base_freq']] -sdt = '20140101' -edt = "20211216" -init_n = 1000*4 - - -def run_backtest(step_seq=('check', 'index', 'etfs', 'train', 'valid', 'stock')): - """ - - :param step_seq: 回测执行顺序 - :return: - """ - tsb = TsStocksBacktest(dc, strategy, init_n, sdt, edt) - for step in step_seq: - tsb.batch_backtest(step.lower()) - tsb.analyze_results(step, 'long') - # tsb.analyze_results(step, 'short') - print(f"results saved into {tsb.res_path}") - - -def run_more_backtest(step, ts_codes): - """指定在某个阶段多回测一些标的,最常见的需求是在 check 阶段多检查几个标的 - - :param step: 阶段名称 - :param ts_codes: 新增回测标的列表 - :return: - """ - tsb = TsStocksBacktest(dc, strategy, init_n, sdt, edt) - tsb.update_step(step, ts_codes) - tsb.batch_backtest(step.lower()) - tsb.analyze_results(step, 'long') - # tsb.analyze_results(step, 'short') - print(f"results saved into {tsb.res_path}") - - -if __name__ == '__main__': - # run_more_backtest(step='check', ts_codes=['000002.SZ']) - # run_backtest(step_seq=('index',)) - run_backtest(step_seq=('train',)) - # run_backtest(step_seq=('etfs',)) - # run_backtest(step_seq=('index', 'train')) - # run_backtest(step_seq=('check', 'index', 'train')) - # run_backtest(step_seq=('check', 'index', 'train', 'valid')) - diff --git a/examples/ts_signals_analyze.py b/examples/ts_signals_analyze.py deleted file mode 100644 index a9059d844..000000000 --- a/examples/ts_signals_analyze.py +++ /dev/null @@ -1,74 +0,0 @@ -# -*- coding: utf-8 -*- -""" -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2022/3/24 16:33 -describe: 使用 Tushare 数据分析信号表现 -""" -import os -import traceback -import pandas as pd -from czsc import CZSC, Freq, CzscAdvancedTrader -from collections import OrderedDict -from czsc.data.ts_cache import TsDataCache -from czsc.sensors.utils import read_cached_signals, generate_stocks_signals, SignalsPerformance -from czsc import signals - - -pd.set_option('display.max_rows', 1000) -pd.set_option('display.max_columns', 20) - - -def generate_all_signals(data_path=r"C:\ts_data", name="signals_b"): - """给出信号定义函数,计算全市场股票的所有信号""" - def get_v1_signals(cat: CzscAdvancedTrader): - s = OrderedDict() - for freq, c in cat.kas.items(): - if c.freq == Freq.D: - s.update(signals.bxt.get_s_three_bi(c, di=1)) - s.update(signals.vol.get_s_vol_single_sma(c, di=1, t_seq=(10, 20))) - return s - - def __strategy(symbol): - return { - "symbol": symbol, - "get_signals": get_v1_signals, - "base_freq": '日线', - "freqs": ['周线', '月线'], - } - - # tushare 研究数据缓存,一次缓存,可以重复使用 - dc = TsDataCache(data_path, sdt='2000-01-01', edt='2022-02-18') - - signals_path = os.path.join(data_path, name) - generate_stocks_signals(dc, signals_path, sdt='20100101', edt='20220101', strategy=__strategy) - - -def analyze_signals(): - results_path = r"C:\ts_data\signals_b_analyze" - signals_path = r"C:\ts_data\signals_b" - path_pat = f"{signals_path}\*_signals.pkl" - sdt = "20150101" - edt = "20220101" - file_output = os.path.join(signals_path, f"{sdt}_{edt}_merged.pkl") - dfs = read_cached_signals(file_output, path_pat, sdt, edt) - # 为了方便逐年查看信号表现,新增 year - dfs['year'] = dfs['dt'].apply(lambda x: x.year) - - os.makedirs(results_path, exist_ok=True) - - for col in [x for x in dfs.columns if len(x.split("_")) == 3]: - try: - sp = SignalsPerformance(dfs, keys=[col]) - file_res = os.path.join(results_path, f"{col}_{sdt}_{edt}.xlsx") - sp.report(file_res) - print(f"signal results saved into {file_res}") - except: - print(f"signal analyze failed: {col}") - traceback.print_exc() - - -if __name__ == '__main__': - generate_all_signals(data_path=r"C:\ts_data_czsc", name="signals_b") - analyze_signals() - diff --git a/examples/ts_stocks_sensors.py b/examples/ts_stocks_sensors.py index 2662bf0b6..350c290ba 100644 --- a/examples/ts_stocks_sensors.py +++ b/examples/ts_stocks_sensors.py @@ -60,18 +60,20 @@ def get_event(): row = {"index_code": None, "fc_top_n": None, 'fc_min_n': None, "min_total_mv": None, "max_count": None, 'window_size': 1} sss.write_validate_report("原始选股结果", row) - row = {"index_code": None, "fc_top_n": 10, 'fc_min_n': 3, "min_total_mv": 1e6, "max_count": 20, 'window_size': 1} - sss.write_validate_report("1日聚合测试20", row) - row = {"index_code": None, "fc_top_n": 10, 'fc_min_n': 2, "min_total_mv": 1e6, "max_count": 50, 'window_size': 1} - sss.write_validate_report("1日聚合测试50", row) - row = {"index_code": None, "fc_top_n": 10, 'fc_min_n': 2, "min_total_mv": 1e6, "max_count": 100, 'window_size': 1} - sss.write_validate_report("1日聚合测试100", row) - row = {"index_code": None, "fc_top_n": 10, 'fc_min_n': 2, "min_total_mv": 1e6, "max_count": 50, 'window_size': 8} - sss.write_validate_report("8日聚合测试", row) - row = {"index_code": None, "fc_top_n": 10, 'fc_min_n': 3, "min_total_mv": 1e6, "max_count": 50, 'window_size': 1} - sss.write_validate_report("1日聚合测试", row) - row = {"index_code": None, "fc_top_n": 10, 'fc_min_n': 3, "min_total_mv": 1e6, "max_count": 50, 'window_size': 8} - sss.write_validate_report("8日聚合测试", row) + # row = {"index_code": None, "fc_top_n": 10, 'fc_min_n': 3, "min_total_mv": 1e6, "max_count": 20, 'window_size': 1} + # sss.write_validate_report("1日聚合测试20", row) + # row = {"index_code": None, "fc_top_n": 10, 'fc_min_n': 2, "min_total_mv": 1e6, "max_count": 50, 'window_size': 1} + # sss.write_validate_report("1日聚合测试50", row) + # row = {"index_code": None, "fc_top_n": 10, 'fc_min_n': 2, "min_total_mv": 1e6, "max_count": 100, 'window_size': 1} + # sss.write_validate_report("1日聚合测试100", row) + # row = {"index_code": None, "fc_top_n": 10, 'fc_min_n': 2, "min_total_mv": 1e6, "max_count": 50, 'window_size': 8} + # sss.write_validate_report("8日聚合测试", row) + # row = {"index_code": None, "fc_top_n": 10, 'fc_min_n': 3, "min_total_mv": 1e6, "max_count": 50, 'window_size': 1} + # sss.write_validate_report("1日聚合测试", row) + # row = {"index_code": None, "fc_top_n": 10, 'fc_min_n': 3, "min_total_mv": 1e6, "max_count": 50, 'window_size': 8} + # sss.write_validate_report("8日聚合测试", row) # 给定参数获取最新的强势股列表 df = sss.get_latest_selected(fc_top_n=None, fc_min_n=None, min_total_mv=None, max_count=None, window_size=1) + print(df) + diff --git a/test/test_advanced_trader.py b/test/test_advanced_trader.py index f94c6d28d..06e00f203 100644 --- a/test/test_advanced_trader.py +++ b/test/test_advanced_trader.py @@ -8,7 +8,8 @@ from tqdm import tqdm from collections import OrderedDict from czsc import signals -from czsc.traders.advanced import CzscAdvancedTrader, BarGenerator +from czsc.utils import BarGenerator +from czsc.traders import CzscAdvancedTrader from czsc.objects import Signal, Factor, Event, Operate, PositionLong, PositionShort from test.test_analyze import read_1min, read_daily diff --git a/test/test_objects.py b/test/test_objects.py index 550e3cf08..bc5d704cf 100644 --- a/test/test_objects.py +++ b/test/test_objects.py @@ -85,6 +85,10 @@ def test_factor(): ) assert factor.is_match(s) + factor_raw = factor.dump() + new_factor = Factor.load(factor_raw) + assert new_factor.is_match(s) + factor = Factor( name="单测", signals_all=[ @@ -148,6 +152,11 @@ def test_event(): m, f = event.is_match(s) assert m and f + raw = event.dump() + new_event = Event.load(raw) + m, f = new_event.is_match(s) + assert m and f + event = Event(name="单测", operate=Operate.LO, factors=[ Factor(name="测试", signals_all=[ diff --git a/test/test_trader_base.py b/test/test_trader_base.py new file mode 100644 index 000000000..a35d77f3f --- /dev/null +++ b/test/test_trader_base.py @@ -0,0 +1,362 @@ +# -*- coding: utf-8 -*- +""" +author: zengbin93 +email: zeng_bin8888@163.com +create_dt: 2021/11/7 21:07 +""" +import pandas as pd +from tqdm import tqdm +from loguru import logger +from collections import OrderedDict +from czsc import signals +from czsc.traders.base import CzscSignals, CzscAdvancedTrader, BarGenerator, CzscTrader +from czsc.objects import Signal, Factor, Event, Operate, PositionLong, PositionShort, Position +from test.test_analyze import read_1min, read_daily + + +def test_object_position(): + bars = read_daily() + bg = BarGenerator(base_freq='日线', freqs=['周线', '月线']) + for bar in bars[:1000]: + bg.update(bar) + + def __get_signals(cat) -> OrderedDict: + s = OrderedDict({"symbol": cat.symbol, "dt": cat.end_dt, "close": cat.latest_price}) + s.update(signals.bxt.get_s_three_bi(cat.kas['日线'], di=1)) + s.update(signals.cxt_first_buy_V221126(cat.kas['日线'], di=1)) + s.update(signals.cxt_first_buy_V221126(cat.kas['日线'], di=2)) + s.update(signals.cxt_first_sell_V221126(cat.kas['日线'], di=1)) + s.update(signals.cxt_first_sell_V221126(cat.kas['日线'], di=2)) + return s + + opens = [ + Event(name='开多', operate=Operate.LO, factors=[ + Factor(name="站上SMA5", signals_all=[ + Signal("日线_D1B_BUY1_一买_任意_任意_0"), + ]) + ]), + Event(name='开空', operate=Operate.SO, factors=[ + Factor(name="跌破SMA5", signals_all=[ + Signal("日线_D1B_BUY1_一卖_任意_任意_0"), + ]) + ]), + ] + + exits = [ + Event(name='平多', operate=Operate.LE, factors=[ + Factor(name="跌破SMA5", signals_all=[ + Signal("日线_倒1笔_三笔形态_向上收敛_任意_任意_0"), + ]) + ]), + Event(name='平空', operate=Operate.SE, factors=[ + Factor(name="站上SMA5", signals_all=[ + Signal("日线_倒1笔_三笔形态_向下收敛_任意_任意_0"), + ]) + ]), + ] + + pos = Position(symbol=bg.symbol, opens=opens, exits=exits, interval=0, timeout=20, stop_loss=100) + + cs = CzscSignals(bg, get_signals=__get_signals) + for bar in bars[1000:]: + cs.update_signals(bar) + pos.update(cs.s) + + df = pd.DataFrame(pos.pairs) + assert df.shape == (244, 10) + assert len(cs.s) == 16 + + +def test_czsc_trader(): + bars = read_daily() + bg = BarGenerator(base_freq='日线', freqs=['周线', '月线']) + for bar in bars[:1000]: + bg.update(bar) + + def __get_signals(cat) -> OrderedDict: + s = OrderedDict({"symbol": cat.symbol, "dt": cat.end_dt, "close": cat.latest_price}) + s.update(signals.bxt.get_s_three_bi(cat.kas['日线'], di=1)) + s.update(signals.cxt_first_buy_V221126(cat.kas['日线'], di=1)) + s.update(signals.cxt_first_buy_V221126(cat.kas['日线'], di=2)) + s.update(signals.cxt_first_sell_V221126(cat.kas['日线'], di=1)) + s.update(signals.cxt_first_sell_V221126(cat.kas['日线'], di=2)) + return s + + def __create_sma5_pos(): + opens = [ + Event(name='开多', operate=Operate.LO, factors=[ + Factor(name="站上SMA5", signals_all=[ + Signal("日线_D1B_BUY1_一买_任意_任意_0"), + ]) + ]), + Event(name='开空', operate=Operate.SO, factors=[ + Factor(name="跌破SMA5", signals_all=[ + Signal("日线_D1B_BUY1_一卖_任意_任意_0"), + ]) + ]), + ] + + exits = [ + Event(name='平多', operate=Operate.LE, factors=[ + Factor(name="跌破SMA5", signals_all=[ + Signal("日线_倒1笔_三笔形态_向上收敛_任意_任意_0"), + ]) + ]), + Event(name='平空', operate=Operate.SE, factors=[ + Factor(name="站上SMA5", signals_all=[ + Signal("日线_倒1笔_三笔形态_向下收敛_任意_任意_0"), + ]) + ]), + ] + + pos = Position(symbol=bg.symbol, opens=opens, exits=exits, interval=0, timeout=20, stop_loss=100) + + return pos + + def __create_sma10_pos(): + opens = [ + Event(name='开多', operate=Operate.LO, factors=[ + Factor(name="站上SMA5", signals_all=[ + Signal("日线_倒1笔_三笔形态_向下无背_任意_任意_0"), + ]) + ]), + Event(name='开空', operate=Operate.SO, factors=[ + Factor(name="跌破SMA5", signals_all=[ + Signal("日线_倒1笔_三笔形态_向上无背_任意_任意_0"), + ]) + ]), + ] + + exits = [ + Event(name='平多', operate=Operate.LE, factors=[ + Factor(name="跌破SMA5", signals_all=[ + Signal("日线_倒1笔_三笔形态_向上收敛_任意_任意_0"), + ]) + ]), + Event(name='平空', operate=Operate.SE, factors=[ + Factor(name="站上SMA5", signals_all=[ + Signal("日线_倒1笔_三笔形态_向下收敛_任意_任意_0"), + ]) + ]), + ] + + pos = Position(symbol=bg.symbol, opens=opens, exits=exits, interval=0, timeout=20, stop_loss=100) + return pos + + def __create_sma20_pos(): + opens = [ + Event(name='开多', operate=Operate.LO, factors=[ + Factor(name="站上SMA5", signals_all=[ + Signal("日线_D2B_BUY1_一买_任意_任意_0"), + ]) + ]), + Event(name='开空', operate=Operate.SO, factors=[ + Factor(name="跌破SMA5", signals_all=[ + Signal("日线_D2B_BUY1_一卖_任意_任意_0"), + ]) + ]), + ] + + exits = [ + Event(name='平多', operate=Operate.LE, factors=[ + Factor(name="跌破SMA5", signals_all=[ + Signal("日线_倒1笔_三笔形态_向上收敛_任意_任意_0"), + ]) + ]), + Event(name='平空', operate=Operate.SE, factors=[ + Factor(name="站上SMA5", signals_all=[ + Signal("日线_倒1笔_三笔形态_向下收敛_任意_任意_0"), + ]) + ]), + ] + + pos = Position(symbol=bg.symbol, opens=opens, exits=exits, interval=0, timeout=20, stop_loss=100) + return pos + + ct = CzscTrader(bg, get_signals=__get_signals, + positions=[__create_sma5_pos(), __create_sma10_pos(), __create_sma20_pos()]) + for bar in bars[1000:]: + ct.update(bar) + print(f"{bar.dt}: pos_seq = {[x.pos for x in ct.positions]}mean_pos = {ct.get_ensemble_pos('mean')}; vote_pos = {ct.get_ensemble_pos('vote')}; max_pos = {ct.get_ensemble_pos('max')}") + + assert [x.pos for x in ct.positions] == [0, -1, 0] + + +def get_signals(cat) -> OrderedDict: + s = OrderedDict({"symbol": cat.symbol, "dt": cat.end_dt, "close": cat.latest_price}) + for _, c in cat.kas.items(): + s.update(signals.bxt.get_s_like_bs(c, di=1)) + + if isinstance(cat, CzscAdvancedTrader) and cat.long_pos: + s.update(signals.cat.get_s_position(cat, cat.long_pos)) + if isinstance(cat, CzscAdvancedTrader) and cat.short_pos: + s.update(signals.cat.get_s_position(cat, cat.short_pos)) + return s + + +def test_czsc_signals(): + bars = read_daily() + bg = BarGenerator(base_freq='日线', freqs=['周线', '月线']) + for bar in bars[:1000]: + bg.update(bar) + + cs = CzscSignals(bg, get_signals=get_signals) + for bar in bars[1000:]: + cs.update_signals(bar) + assert len(cs.s) == 14 + + +def trader_strategy_test(symbol, T0=False): + """A股市场择时策略样例,支持按交易标的独立设置参数 + + :param symbol: + :param T0: 是否允许T0交易 + :return: + """ + long_events = [ + Event(name="开多", operate=Operate.LO, factors=[ + Factor(name="5分钟一买", signals_all=[Signal("5分钟_倒1笔_类买卖点_类一买_任意_任意_0")]), + Factor(name="1分钟一买", signals_all=[Signal("1分钟_倒1笔_类买卖点_类一买_任意_任意_0")]), + ]), + + Event(name="加多1", operate=Operate.LA1, factors=[ + Factor(name="5分钟二买", signals_all=[Signal("5分钟_倒1笔_类买卖点_类二买_任意_任意_0")]), + Factor(name="1分钟二买", signals_all=[Signal("1分钟_倒1笔_类买卖点_类二买_任意_任意_0")]), + ]), + + Event(name="加多2", operate=Operate.LA1, factors=[ + Factor(name="5分钟三买", signals_all=[Signal("5分钟_倒1笔_类买卖点_类三买_任意_任意_0")]), + Factor(name="1分钟三买", signals_all=[Signal("1分钟_倒1笔_类买卖点_类三买_任意_任意_0")]), + ]), + + Event(name="平多", operate=Operate.LE, factors=[ + Factor(name="5分钟二卖", signals_all=[Signal("5分钟_倒1笔_类买卖点_类二卖_任意_任意_0")]), + Factor(name="5分钟三卖", signals_all=[Signal("5分钟_倒1笔_类买卖点_类三卖_任意_任意_0")]) + ]), + ] + long_pos = PositionLong(symbol=symbol, hold_long_a=0.5, hold_long_b=0.8, hold_long_c=1, T0=T0) + short_events = [ + Event(name="开空", operate=Operate.SO, factors=[ + Factor(name="5分钟一买", signals_all=[Signal("5分钟_倒1笔_类买卖点_类一买_任意_任意_0")]), + Factor(name="1分钟一买", signals_all=[Signal("1分钟_倒1笔_类买卖点_类一买_任意_任意_0")]), + ]), + + Event(name="加空1", operate=Operate.SA1, factors=[ + Factor(name="5分钟二买", signals_all=[Signal("5分钟_倒1笔_类买卖点_类二买_任意_任意_0")]), + Factor(name="1分钟二买", signals_all=[Signal("1分钟_倒1笔_类买卖点_类二买_任意_任意_0")]), + ]), + + Event(name="加空2", operate=Operate.SA1, factors=[ + Factor(name="5分钟三买", signals_all=[Signal("5分钟_倒1笔_类买卖点_类三买_任意_任意_0")]), + Factor(name="1分钟三买", signals_all=[Signal("1分钟_倒1笔_类买卖点_类三买_任意_任意_0")]), + ]), + + Event(name="平空", operate=Operate.SE, factors=[ + Factor(name="5分钟二卖", signals_all=[Signal("5分钟_倒1笔_类买卖点_类二卖_任意_任意_0")]), + Factor(name="5分钟三卖", signals_all=[Signal("5分钟_倒1笔_类买卖点_类三卖_任意_任意_0")]) + ]), + ] + short_pos = PositionShort(symbol=symbol, hold_short_a=0.5, hold_short_b=0.8, hold_short_c=1, T0=T0) + + tactic = { + "base_freq": '1分钟', + "freqs": ['5分钟', '15分钟', '30分钟', '60分钟', '日线'], + "get_signals": get_signals, + "signals_n": 0, + + "long_pos": long_pos, + "long_events": long_events, + + # 空头策略不进行定义,也就是不做空头交易 + "short_pos": short_pos, + "short_events": short_events, + } + + return tactic + + +def test_daily_trader(): + bars = read_daily() + kg = BarGenerator(base_freq='日线', freqs=['周线', '月线']) + for bar in bars[:1000]: + kg.update(bar) + + def __trader_strategy(symbol): + tactic = { + "base_freq": '1分钟', + "freqs": ['5分钟', '15分钟', '30分钟', '60分钟', '日线'], + "get_signals": get_signals, + + "long_pos": None, + "long_events": None, + + # 空头策略不进行定义,也就是不做空头交易 + "short_pos": None, + "short_events": None, + } + + return tactic + ct = CzscAdvancedTrader(kg, __trader_strategy) + + signals_ = [] + for bar in bars[1000:]: + ct.update(bar) + signals_.append(dict(ct.s)) + + assert len(signals_) == 2332 + + # 测试传入空策略 + ct = CzscAdvancedTrader(kg) + assert len(ct.s) == 0 and len(ct.kas) == 3 + + +def run_advanced_trader(T0=True): + bars = read_1min() + kg = BarGenerator(base_freq='1分钟', freqs=['5分钟', '15分钟', '30分钟', '60分钟', '日线'], max_count=3000) + for row in tqdm(bars[:150000], desc='init kg'): + kg.update(row) + + def _strategy(symbol): + return trader_strategy_test(symbol, T0=T0) + ct = CzscAdvancedTrader(kg, _strategy) + + assert len(ct.s) == 29 + for row in tqdm(bars[150000:], desc="trade"): + ct.update(row) + # if long_pos.pos_changed: + # print(" : long op : ", long_pos.operates[-1]) + # if short_pos.pos_changed: + # print(" : short op : ", short_pos.operates[-1]) + + if ct.long_pos.pos > 0: + assert ct.long_pos.long_high > 0 + assert ct.long_pos.long_cost > 0 + assert ct.long_pos.long_bid > 0 + + if ct.short_pos.pos > 0: + assert ct.short_pos.short_low > 0 + assert ct.short_pos.short_cost > 0 + assert ct.short_pos.short_bid > 0 + + long_yk = pd.DataFrame(ct.long_pos.pairs)['盈亏比例'].sum() + short_yk = pd.DataFrame(ct.short_pos.pairs)['盈亏比例'].sum() + assert abs(long_yk) == abs(short_yk) + print(f"\nT0={T0}: 多头累计盈亏比例:{long_yk};空头累计盈亏比例:{short_yk}") + if not T0: + assert ct.s['多头_最大_盈利'] == '超过800BP_任意_任意_0' + assert ct.s['多头_累计_盈亏'] == '盈利_超过800BP_任意_0' + assert ct.s['空头_最大_回撤'] == '超过800BP_任意_任意_0' + assert ct.s['空头_累计_盈亏'] == '亏损_超过800BP_任意_0' + + holds_long = pd.DataFrame(ct.long_holds) + assert round(holds_long['long_pos'].mean(), 4) == 0.7376 + + holds_short = pd.DataFrame(ct.short_holds) + assert round(holds_short['short_pos'].mean(), 4) == 0.7376 + + +def test_advanced_trader(): + run_advanced_trader(T0=False) + run_advanced_trader(T0=True) + +