diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 37e947695..f5da163ab 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -5,7 +5,7 @@ name: Python package on: push: - branches: [ master, V0.9.16 ] + branches: [ master, V0.9.17 ] pull_request: branches: [ master ] @@ -15,7 +15,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.7, 3.8, 3.9, 3.10.10, 3.11.2] + python-version: [3.7, 3.8, 3.9, 3.10.11, 3.11.3] steps: - uses: actions/checkout@v2 diff --git a/czsc/__init__.py b/czsc/__init__.py index 965fcedb8..d31146e3a 100644 --- a/czsc/__init__.py +++ b/czsc/__init__.py @@ -20,14 +20,14 @@ from czsc.strategies import CzscStrategyBase from czsc.utils import KlineChart, BarGenerator, resample_bars, dill_dump, dill_load, read_json, save_json from czsc.utils import get_sub_elements, get_py_namespace, freqs_sorted, x_round, import_by_name, create_grid_params -from czsc.utils import cal_trade_price, cross_sectional_ic +from czsc.utils import cal_trade_price, cross_sectional_ic, update_bbars, update_tbars, update_nbars from czsc.sensors import holds_concepts_effect, StocksDaySensor, ThsConceptsSensor, SignalsPerformance -__version__ = "0.9.16" +__version__ = "0.9.17" __author__ = "zengbin93" __email__ = "zeng_bin8888@163.com" -__date__ = "20230404" +__date__ = "20230415" def welcome(): diff --git a/czsc/analyze.py b/czsc/analyze.py index 742d791d6..ee3ef0c3c 100644 --- a/czsc/analyze.py +++ b/czsc/analyze.py @@ -153,8 +153,8 @@ def check_bi(bars: List[NewBar], benchmark: float = None): fxs_ = [x for x in fxs if fx_a.elements[0].dt <= x.dt <= fx_b.elements[2].dt] bi = BI(symbol=fx_a.symbol, fx_a=fx_a, fx_b=fx_b, fxs=fxs_, direction=direction, bars=bars_a) - low_ubi = min([x.low for x in bars_b]) - high_ubi = max([x.high for x in bars_b]) + low_ubi = min([x.low for y in bars_b for x in y.raw_bars]) + high_ubi = max([x.high for y in bars_b for x in y.raw_bars]) if (bi.direction == Direction.Up and high_ubi > bi.high) \ or (bi.direction == Direction.Down and low_ubi < bi.low): return None, bars @@ -379,6 +379,32 @@ def ubi_fxs(self) -> List[FX]: else: return check_fxs(self.bars_ubi) + @property + def ubi(self): + """Unfinished Bi,未完成的笔""" + if not self.bars_ubi or not self.bi_list: + return None + + bars_raw = [y for x in self.bars_ubi for y in x.raw_bars] + # 获取最高点和最低点,以及对应的时间 + high_bar = max(bars_raw, key=lambda x: x.high) + low_bar = min(bars_raw, key=lambda x: x.low) + direction = Direction.Up if self.bi_list[-1].direction == Direction.Down else Direction.Down + + bi = { + "symbol": self.symbol, + "direction": direction, + "high": high_bar.high, + "low": low_bar.low, + "high_bar": high_bar, + "low_bar": low_bar, + "bars": self.bars_ubi, + "raw_bars": bars_raw, + "fxs": self.ubi_fxs, + "fx_a": self.ubi_fxs[0], + } + return bi + @property def fx_list(self) -> List[FX]: """分型列表,包括 bars_ubi 中的分型""" @@ -387,6 +413,6 @@ def fx_list(self) -> List[FX]: fxs.extend(bi_.fxs[1:]) ubi = self.ubi_fxs for x in ubi: - if not fxs or x.dt > fxs[-1].dt: + if not fxs or x.dt > fxs[-1].raw_bars[-1].dt: fxs.append(x) return fxs diff --git a/czsc/sensors/utils.py b/czsc/sensors/utils.py index f87fce474..f842a2436 100644 --- a/czsc/sensors/utils.py +++ b/czsc/sensors/utils.py @@ -175,6 +175,9 @@ def __init__(self, dfs: pd.DataFrame, keys: List[AnyStr]): :param dfs: 信号表 :param keys: 信号列,支持一个或多个信号列组合分析 """ + base_cols = [x for x in dfs.columns if len(x.split("_")) != 3] + dfs = dfs[base_cols + keys].copy() + if 'year' not in dfs.columns: y = dfs['dt'].apply(lambda x: x.year) dfs['year'] = y.values diff --git a/czsc/signals/__init__.py b/czsc/signals/__init__.py index 7fe028904..500c50727 100644 --- a/czsc/signals/__init__.py +++ b/czsc/signals/__init__.py @@ -78,6 +78,7 @@ bar_time_V230327, bar_weekday_V230328, bar_r_breaker_V230326, + bar_dual_thrust_V230403, ) from czsc.signals.jcc import ( @@ -110,6 +111,7 @@ update_kdj_cache, update_boll_cache, update_rsi_cache, + update_cci_cache, tas_macd_base_V221028, tas_macd_change_V221105, @@ -146,6 +148,10 @@ tas_second_bs_V230303, tas_hlma_V230301, + tas_cci_base_V230402, + tas_kdj_evc_V230401 ) - +from czsc.signals.pos import ( + pos_fx_stop_V230414, +) diff --git a/czsc/signals/bar.py b/czsc/signals/bar.py index 1b2312161..80f25d89f 100644 --- a/czsc/signals/bar.py +++ b/czsc/signals/bar.py @@ -939,3 +939,62 @@ def bar_r_breaker_V230326(c: CZSC, **kwargs): v2 = '其他' return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1, v2=v2) + + +def bar_dual_thrust_V230403(c: CZSC, **kwargs): + """Dual Thrust 通道突破 + + 参数模板:"{freq}_D{di}通道突破#{N}#{K1}#{K2}_BS辅助V230403" + + **信号逻辑:** + + 参见:https://www.myquant.cn/docs/python_strategyies/424 + + 其核心思想是定义一个区间,区间的上界和下界分别为支撑线和阻力线。当价格超过上界时,看多,跌破下界,看空。 + + **信号列表:** + + - Signal('日线_D1通道突破#5#20#20_BS辅助V230403_看空_任意_任意_0') + - Signal('日线_D1通道突破#5#20#20_BS辅助V230403_看多_任意_任意_0') + + :param c: 基础周期的 CZSC 对象 + :param kwargs: 其他参数 + - di: 倒数第 di 根 K 线 + - N: 前N天的数据 + - K1: 参数,根据经验优化 + - K2: 参数,根据经验优化 + :return: 信号字典 + """ + di = int(kwargs.get('di', 1)) + N = int(kwargs.get('N', 5)) + K1 = int(kwargs.get('K1', 20)) + K2 = int(kwargs.get('K2', 20)) + + freq = c.freq.value + k1, k2, k3 = f"{freq}_D{di}通道突破#{N}#{K1}#{K2}_BS辅助V230403".split('_') + if len(c.bars_raw) < 3: + return create_single_signal(k1=k1, k2=k2, k3=k3, v1='其他') + + bars = get_sub_elements(c.bars_raw, di=di+1, n=N+1) + HH = max([i.high for i in bars]) + HC = max([i.close for i in bars]) + LC = min([i.close for i in bars]) + LL = min([i.low for i in bars]) + Range = max(HH - LC, HC - LL) + + current_bar = c.bars_raw[-di] + buy_line = current_bar.open + Range * K1 / 100 # 上轨 + sell_line = current_bar.open - Range * K2 / 100 # 下轨 + + # 根据价格位置判断信号 + if current_bar.close > buy_line: + v1 = '看多' + elif current_bar.close < sell_line: + v1 = '看空' + else: + v1 = '其他' + + return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1) + + + diff --git a/czsc/signals/byi.py b/czsc/signals/byi.py index 4d7c1e0da..0bc4cda4f 100644 --- a/czsc/signals/byi.py +++ b/czsc/signals/byi.py @@ -165,7 +165,7 @@ def byi_bi_end_V230107(c: CZSC, **kwargs) -> OrderedDict: return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1) -def byi_second_bs_V230324(c: CZSC, di=1, **kwargs) -> OrderedDict: +def byi_second_bs_V230324(c: CZSC, **kwargs) -> OrderedDict: """白仪二类买卖点辅助V230324 参数模板:"{freq}_D{di}MACD{fastperiod}#{slowperiod}#{signalperiod}回抽零轴_BS2辅助V230324" diff --git a/czsc/signals/pos.py b/czsc/signals/pos.py new file mode 100644 index 000000000..999a5f2c1 --- /dev/null +++ b/czsc/signals/pos.py @@ -0,0 +1,67 @@ +# -*- coding: utf-8 -*- +""" +author: zengbin93 +email: zeng_bin8888@163.com +create_dt: 2023/4/14 19:27 +describe: +""" +from collections import OrderedDict +from czsc.traders.base import CzscTrader +from czsc.utils import create_single_signal +from czsc.objects import Operate, Direction, Mark + + +def pos_fx_stop_V230414(cat: CzscTrader, **kwargs) -> OrderedDict: + """按照开仓点附近的分型止损 + + 参数模板:"{freq1}_{pos_name}N{n}_止损V230414" + + **信号逻辑:** + + 多头止损逻辑如下,反之为空头止损逻辑: + + 1. 从多头开仓点开始,在给定对的K线周期 freq1 上向前找 N 个底分型,记为 F1 + 2. 将这 N 个底分型的最低点,记为 L1,如果 L1 的价格低于开仓点的价格,则止损 + + **信号列表:** + + - Signal('日线_日线三买多头N1_止损V230414_多头止损_任意_任意_0') + - Signal('日线_日线三买多头N1_止损V230414_空头止损_任意_任意_0') + + :param cat: CzscTrader对象 + :param kwargs: 参数字典 + - pos_name: str,开仓信号的名称 + - freq1: str,给定的K线周期 + - n: int,向前找的分型个数,默认为 3 + :return: + """ + pos_name = kwargs["pos_name"] + freq1 = kwargs["freq1"] + n = int(kwargs.get('n', 3)) + k1, k2, k3 = f"{freq1}_{pos_name}N{n}_止损V230414".split("_") + v1 = '其他' + + # 如果没有持仓策略,则不产生信号 + if not hasattr(cat, "positions"): + return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1) + + pos = [x for x in cat.positions if x.name == pos_name][0] + if len(pos.operates) == 0 or pos.operates[-1]['op'] in [Operate.SE, Operate.LE]: + return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1) + + c = cat.kas[freq1] + op = pos.operates[-1] + + # 多头止损逻辑 + if op['op'] == Operate.LO: + fxs = [x for x in c.fx_list if x.mark == Mark.D and x.dt < op['dt']][-n:] + if cat.latest_price < min([x.low for x in fxs]): + v1 = '多头止损' + + # 空头止损逻辑 + if op['op'] == Operate.SO: + fxs = [x for x in c.fx_list if x.mark == Mark.G and x.dt < op['dt']][-n:] + if cat.latest_price > max([x.high for x in fxs]): + v1 = '空头止损' + + return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1) \ No newline at end of file diff --git a/czsc/signals/tas.py b/czsc/signals/tas.py index e6a12ca85..94b7e0001 100644 --- a/czsc/signals/tas.py +++ b/czsc/signals/tas.py @@ -81,7 +81,7 @@ def update_macd_cache(c: CZSC, **kwargs): # 如果最后一根K线已经有对应的缓存,不执行更新 return cache_key - min_count = signalperiod + slowperiod + min_count = signalperiod + slowperiod + 168 last_cache = dict(c.bars_raw[-2].cache) if c.bars_raw[-2].cache else dict() if cache_key not in last_cache.keys() or len(c.bars_raw) < min_count + 15: # 初始化缓存 @@ -103,7 +103,6 @@ def update_macd_cache(c: CZSC, **kwargs): _c = dict(c.bars_raw[-i].cache) if c.bars_raw[-i].cache else dict() _c.update({cache_key: {'dif': dif[-i], 'dea': dea[-i], 'macd': macd[-i]}}) c.bars_raw[-i].cache = _c - return cache_key @@ -1787,3 +1786,151 @@ def tas_macd_base_V230320(c: CZSC, **kwargs) -> OrderedDict: return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1, v2=v2) +def update_cci_cache(c: CZSC, **kwargs): + """更新CCI缓存 + + CCI = (TP - MA) / MD / 0.015; 其中, + + - TP=(最高价+最低价+收盘价)÷3; + - MA=最近N日收盘价的累计之和÷N; + - MD=最近N日(MA-收盘价)的累计之和÷N; + - 0.015为计算系数,N为计算周期 + + :param c: CZSC对象 + :return: + """ + timeperiod = int(kwargs.get('timeperiod', 14)) + cache_key = f"CCI{timeperiod}" + if c.bars_raw[-1].cache and c.bars_raw[-1].cache.get(cache_key, None): + # 如果最后一根K线已经有对应的缓存,不执行更新 + return cache_key + + last_cache = dict(c.bars_raw[-2].cache) if c.bars_raw[-2].cache else dict() + if cache_key not in last_cache.keys() or len(c.bars_raw) < timeperiod + 15: + # 初始化缓存 + bars = c.bars_raw + else: + # 增量更新最近5个K线缓存 + bars = c.bars_raw[-timeperiod - 10:] + + high = np.array([x.high for x in bars]) + low = np.array([x.low for x in bars]) + close = np.array([x.close for x in bars]) + cci = ta.CCI(high, low, close, timeperiod=timeperiod) + + for i in range(len(bars)): + _c = dict(bars[i].cache) if bars[i].cache else dict() + if cache_key not in _c.keys(): + _c.update({cache_key: cci[i] if cci[i] else 0}) + bars[i].cache = _c + + return cache_key + + +def tas_cci_base_V230402(c: CZSC, **kwargs) -> OrderedDict: + """CCI基础信号 + + 参数模板:"{freq}_D{di}CCI{timeperiod}#{min_count}#{max_count}_BS辅助V230402" + + **信号逻辑:** + + 1. CCI连续大于100的次数大于 min_count, 小于max_count,看空;反之,看多。 + + **信号列表:** + + - Signal('60分钟_D1CCI14#3#6_BS辅助V230402_空头_任意_任意_0') + - Signal('60分钟_D1CCI14#3#6_BS辅助V230402_多头_任意_任意_0') + + :param c: CZSC对象 + :param kwargs: 参数字典 + - di: int, 默认1,倒数第几根K线 + - timeperiod: int, 默认14,计算CCI的周期 + - min_count: int, 默认3,CCI连续大于100的次数 + - max_count: int, 默认min_count+3,CCI连续大于100的次数 + :return: 信号识别结果 + """ + di = int(kwargs.get('di', 1)) + timeperiod = int(kwargs.get('timeperiod', 14)) + min_count = int(kwargs.get('min_count', 3)) + max_count = int(kwargs.get('max_count', min_count + 3)) + assert min_count < max_count, "min_count 必须小于 max_count" + freq = c.freq.value + k1, k2, k3 = f"{freq}_D{di}CCI{timeperiod}#{min_count}#{max_count}_BS辅助V230402".split("_") + v1 = "其他" + + cache_key = update_cci_cache(c, timeperiod=timeperiod) + bars = get_sub_elements(c.bars_raw, di=di, n=max_count + 1) + if len(bars) != max_count + 1: + return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1) + cci = [x.cache[cache_key] for x in bars] + + long = [x > 100 for x in cci] + short = [x < -100 for x in cci] + lc = count_last_same(long) if long[-1] else 0 + sc = count_last_same(short) if short[-1] else 0 + + if max_count > lc >= min_count: + v1 = "多头" + if max_count > sc >= min_count: + v1 = "空头" + + return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1) + + +def tas_kdj_evc_V230401(c: CZSC, **kwargs) -> OrderedDict: + """KDJ极值计数信号, evc 是 extreme value counts 的首字母缩写 + + 参数模板:"{freq}_D{di}T{th}KDJ{fastk_period}#{slowk_period}#{slowd_period}#{key}值突破{min_count}#{max_count}_BS辅助V230401" + + **信号逻辑:** + + 1. K < th,记录一次多头信号,连续出现信号次数在 count_range 范围,则认为是有效多头信号; + 2. K > 100 - th, 记录一次空头信号,连续出现信号次数在 count_range 范围,则认为是有效空头信号 + + **信号列表:** + + - Signal('60分钟_D1T10KDJ9#3#3#K值突破5#8_BS辅助V230401_空头_任意_任意_0') + - Signal('60分钟_D1T10KDJ9#3#3#K值突破5#8_BS辅助V230401_多头_任意_任意_0') + + :param c: CZSC对象 + :param kwargs: 参数字典 + - di: 信号计算截止倒数第i根K线 + - key: KDJ 值的名称,可以是 K, D, J + - th: 信号计算截止倒数第i根K线 + - min_count: 连续出现信号次数的最小值 + - max_count: 连续出现信号次数的最大值 + :return: + """ + di = int(kwargs.get("di", 1)) + key = kwargs.get("key", "K") + th = int(kwargs.get("th", 10)) + min_count = int(kwargs.get("min_count", 5)) + max_count = int(kwargs.get("max_count", min_count + 3)) + freq = c.freq.value + key = key.upper() + assert min_count < max_count, "min_count 必须小于 max_count" + assert key in ['K', 'D', 'J'], "key 必须是 K, D, J 中的一个" + assert 0 < th < 100, "th 必须在 0 到 100 之间" + cache_key = update_kdj_cache(c, **kwargs) + + k1, k2, k3 = f"{freq}_D{di}T{th}{cache_key}#{key}值突破{min_count}#{max_count}_BS辅助V230401".split( + '_') + v1 = "其他" + if len(c.bars_raw) < di + max_count + 2: + return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1) + + bars = get_sub_elements(c.bars_raw, di=di, n=3 + max_count) + key = key.lower() + long = [x.cache[cache_key][key] < th for x in bars] + short = [x.cache[cache_key][key] > 100 - th for x in bars] + lc = count_last_same(long) if long[-1] else 0 + sc = count_last_same(short) if short[-1] else 0 + + if max_count > lc >= min_count: + v1 = "多头" + + if max_count > sc >= min_count: + assert v1 == '其他' + v1 = "空头" + + return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1) diff --git a/czsc/strategies.py b/czsc/strategies.py index 1ab7483bb..6ddae9841 100644 --- a/czsc/strategies.py +++ b/czsc/strategies.py @@ -13,6 +13,7 @@ import pandas as pd from tqdm import tqdm from copy import deepcopy +from datetime import timedelta from abc import ABC, abstractmethod from loguru import logger from czsc import signals @@ -204,6 +205,76 @@ def replay(self, bars: List[RawBar], res_path, **kwargs): logger.error(f"交易对象保存失败:{e};通常的原因是交易对象中包含了不支持序列化的对象,比如函数") return trader + def check(self, bars: List[RawBar], res_path, **kwargs): + """检查交易策略中的信号是否正确 + + :param bars: 基础周期K线 + :param res_path: 结果目录 + :param kwargs: + bg 已经初始化好的BarGenerator对象,如果传入了bg,则忽略sdt和n参数 + sdt 初始化开始日期 + n 初始化最小K线数量 + :return: + """ + if kwargs.get('refresh', False): + shutil.rmtree(res_path, ignore_errors=True) + + exist_ok = kwargs.get("exist_ok", False) + if os.path.exists(res_path) and not exist_ok: + logger.warning(f"结果文件夹存在且不允许覆盖:{res_path},如需执行,请先删除文件夹") + return + os.makedirs(res_path, exist_ok=exist_ok) + + # 第一遍执行,获取信号 + bg, bars2 = self.init_bar_generator(bars, **kwargs) + trader = CzscTrader(bg=bg, positions=deepcopy(self.positions), + signals_config=deepcopy(self.signals_config), **kwargs) + + _signals = [] + for bar in bars2: + trader.on_bar(bar) + _signals.append(trader.s) + + for position in trader.positions: + print(f"{position.name}: {position.evaluate()}") + + df = pd.DataFrame(_signals) + df.to_excel(os.path.join(res_path, "signals.xlsx"), index=False) + unique_signals = {} + for col in [x for x in df.columns if len(x.split("_")) == 3]: + unique_signals[col] = [Signal(f"{col}_{v}") for v in df[col].unique() if "其他" not in v] + + print('\n', "+" * 100) + for key, values in unique_signals.items(): + print(f"\n{key}:") + for value in values: + print(f"- {value}") + print('\n', "+" * 100) + + # 第二遍执行,检查信号,生成html + bg, bars2 = self.init_bar_generator(bars, **kwargs) + trader = CzscTrader(bg=bg, positions=deepcopy(self.positions), + signals_config=deepcopy(self.signals_config), **kwargs) + + # 记录每个信号最后一次出现的时间 + last_sig_dt = {y.key: trader.end_dt for x in unique_signals.values() for y in x} + delta_days = kwargs.get("delta_days", 1) + + for bar in bars2: + trader.on_bar(bar) + + for key, values in unique_signals.items(): + html_path = os.path.join(res_path, key) + os.makedirs(html_path, exist_ok=True) + + for signal in values: + if bar.dt - last_sig_dt[signal.key] > timedelta(days=delta_days) and signal.is_match(trader.s): + file_html = f"{bar.dt.strftime('%Y%m%d_%H%M')}_{signal.signal}.html" + file_html = os.path.join(html_path, file_html) + print(file_html) + trader.take_snapshot(file_html, height=kwargs.get("height", "680px")) + last_sig_dt[signal.key] = bar.dt + class CzscStrategyExample2(CzscStrategyBase): """仅传入Positions就完成策略创建""" diff --git a/czsc/traders/base.py b/czsc/traders/base.py index 7b7298ba3..a9c781539 100644 --- a/czsc/traders/base.py +++ b/czsc/traders/base.py @@ -50,9 +50,9 @@ def __init__(self, bg: BarGenerator = None, **kwargs): last_bar = self.kas[self.base_freq].bars_raw[-1] self.end_dt, self.bid, self.latest_price = last_bar.dt, last_bar.id, last_bar.close - self.s = OrderedDict(last_bar.__dict__) + self.s = OrderedDict() self.s.update(self.get_signals_by_conf()) - + self.s.update(last_bar.__dict__) else: self.bg = None self.symbol = None @@ -147,8 +147,9 @@ def update_signals(self, bar: RawBar): self.symbol = bar.symbol last_bar = self.kas[self.base_freq].bars_raw[-1] self.end_dt, self.bid, self.latest_price = last_bar.dt, last_bar.id, last_bar.close - self.s = OrderedDict(last_bar.__dict__) + self.s = OrderedDict() self.s.update(self.get_signals_by_conf()) + self.s.update(last_bar.__dict__) @deprecated(version="0.9.16", reason="请使用 CzscSignals 类") @@ -271,10 +272,10 @@ def check_signals_acc(bars: List[RawBar], signals_config: List[dict], delta_days html_path = os.path.join(home_path, signal.key) os.makedirs(html_path, exist_ok=True) if bar.dt - last_dt[signal.key] > timedelta(days=delta_days) and signal.is_match(ct.s): - file_html = f"{bar.symbol}_{signal.key}_{ct.s[signal.key]}_{bar.dt.strftime('%Y%m%d_%H%M')}.html" + file_html = f"{bar.dt.strftime('%Y%m%d_%H%M')}_{signal.key}_{ct.s[signal.key]}.html" file_html = os.path.join(html_path, file_html) print(file_html) - ct.take_snapshot(file_html) + ct.take_snapshot(file_html, height=kwargs.get("height", "680px")) last_dt[signal.key] = bar.dt @@ -286,7 +287,6 @@ def get_unique_signals(bars: List[RawBar], signals_config: List[dict], **kwargs) :param kwargs: :return: """ - base_freq = str(bars[-1].freq.value) assert bars[2].dt > bars[1].dt > bars[0].dt and bars[2].id > bars[1].id, "bars 中的K线元素必须按时间升序" if len(bars) < 600: return [] diff --git a/czsc/utils/__init__.py b/czsc/utils/__init__.py index 2e29a0feb..cbd89de9d 100644 --- a/czsc/utils/__init__.py +++ b/czsc/utils/__init__.py @@ -14,7 +14,7 @@ from .sig import check_pressure_support, check_gap_info, is_bis_down, is_bis_up, get_sub_elements from .sig import same_dir_counts, fast_slow_cross, count_last_same, create_single_signal from .plotly_plot import KlineChart -from .trade import cal_trade_price +from .trade import cal_trade_price, update_nbars, update_bbars, update_tbars sorted_freqs = ['Tick', '1分钟', '5分钟', '15分钟', '30分钟', '60分钟', '日线', '周线', '月线', '季线', '年线'] diff --git a/czsc/utils/echarts_plot.py b/czsc/utils/echarts_plot.py index 2d09418e0..ab9d8097e 100644 --- a/czsc/utils/echarts_plot.py +++ b/czsc/utils/echarts_plot.py @@ -120,12 +120,14 @@ def kline_pro(kline: List[dict], dz_slider = opts.DataZoomOpts(True, "slider", xaxis_index=[0, 1, 2], pos_top="96%", pos_bottom="0%", range_start=80, range_end=100) - yaxis_opts = opts.AxisOpts(is_scale=True, + yaxis_opts = opts.AxisOpts(is_scale=True, min_="dataMin", max_="dataMax", + splitline_opts=opts.SplitLineOpts(is_show=False), axislabel_opts=opts.LabelOpts(color="#c7c7c7", font_size=8, position="inside")) grid0_xaxis_opts = opts.AxisOpts(type_="category", grid_index=0, axislabel_opts=label_not_show_opts, split_number=20, min_="dataMin", max_="dataMax", is_scale=True, boundary_gap=False, + splitline_opts=opts.SplitLineOpts(is_show=False), axisline_opts=opts.AxisLineOpts(is_on_zero=False)) tool_tip_opts = opts.TooltipOpts( @@ -281,16 +283,15 @@ def kline_pro(kline: List[dict], chart_ma = Line() chart_ma.add_xaxis(xaxis_data=dts) if not t_seq: - t_seq = [5, 13, 21, 34, 55, 89, 144, 233] + t_seq = [5, 13, 21] ma_keys = dict() for t in t_seq: ma_keys[f"MA{t}"] = SMA(close, timeperiod=t) for i, (name, ma) in enumerate(ma_keys.items()): - is_selected = True if i < 4 else False chart_ma.add_yaxis(series_name=name, y_axis=ma, is_smooth=True, - is_selected=is_selected, symbol_size=0, label_opts=label_not_show_opts, + symbol_size=0, label_opts=label_not_show_opts, linestyle_opts=opts.LineStyleOpts(opacity=0.8, width=1)) chart_ma.set_global_opts(xaxis_opts=grid0_xaxis_opts, legend_opts=legend_not_show_opts) @@ -303,7 +304,7 @@ def kline_pro(kline: List[dict], fx_val = [round(x['fx'], 2) for x in fx] chart_fx = Line() chart_fx.add_xaxis(fx_dts) - chart_fx.add_yaxis(series_name="FX", y_axis=fx_val, is_selected=False, + chart_fx.add_yaxis(series_name="FX", y_axis=fx_val, symbol="circle", symbol_size=6, label_opts=label_show_opts, itemstyle_opts=opts.ItemStyleOpts(color="rgba(152, 147, 193, 1.0)", )) @@ -315,7 +316,7 @@ def kline_pro(kline: List[dict], bi_val = [round(x['bi'], 2) for x in bi] chart_bi = Line() chart_bi.add_xaxis(bi_dts) - chart_bi.add_yaxis(series_name="BI", y_axis=bi_val, is_selected=True, + chart_bi.add_yaxis(series_name="BI", y_axis=bi_val, symbol="diamond", symbol_size=10, label_opts=label_show_opts, itemstyle_opts=opts.ItemStyleOpts(color="rgba(184, 117, 225, 1.0)", ), linestyle_opts=opts.LineStyleOpts(width=1.5)) @@ -328,7 +329,8 @@ def kline_pro(kline: List[dict], xd_val = [x['xd'] for x in xd] chart_xd = Line() chart_xd.add_xaxis(xd_dts) - chart_xd.add_yaxis(series_name="XD", y_axis=xd_val, is_selected=True, symbol="triangle", symbol_size=10, + chart_xd.add_yaxis(series_name="XD", y_axis=xd_val, + symbol="triangle", symbol_size=10, itemstyle_opts=opts.ItemStyleOpts(color="rgba(37, 141, 54, 1.0)", )) chart_xd.set_global_opts(xaxis_opts=grid0_xaxis_opts, legend_opts=legend_not_show_opts) @@ -343,6 +345,7 @@ def kline_pro(kline: List[dict], xaxis_opts=opts.AxisOpts( type_="category", grid_index=1, + boundary_gap=False, axislabel_opts=opts.LabelOpts(is_show=True, font_size=8, color="#9b9da9"), ), yaxis_opts=yaxis_opts, legend_opts=legend_not_show_opts, @@ -358,6 +361,7 @@ def kline_pro(kline: List[dict], type_="category", grid_index=2, axislabel_opts=opts.LabelOpts(is_show=False), + splitline_opts=opts.SplitLineOpts(is_show=False), ), yaxis_opts=opts.AxisOpts( grid_index=2, diff --git a/czsc/utils/plotly_plot.py b/czsc/utils/plotly_plot.py index f252cf5b5..c9f4ecf85 100644 --- a/czsc/utils/plotly_plot.py +++ b/czsc/utils/plotly_plot.py @@ -25,23 +25,23 @@ class KlineChart: def __init__(self, n_rows=3, **kwargs): # 子图数量 self.n_rows = n_rows - if self.n_rows == 3: - row_heights = [0.6, 0.2, 0.2] - elif self.n_rows == 4: - row_heights = [0.55, 0.15, 0.15, 0.15] - elif self.n_rows == 5: - row_heights = [0.4, 0.15, 0.15, 0.15, 0.15] - else: - raise ValueError("n_rows 只能是 3, 4, 5") + row_heights = kwargs.get("row_heights", None) + if not row_heights: + heights_map = {3: [0.6, 0.2, 0.2], 4: [0.55, 0.15, 0.15, 0.15], 5: [0.4, 0.15, 0.15, 0.15, 0.15]} + assert self.n_rows in heights_map.keys(), "使用内置高度配置,n_rows 只能是 3, 4, 5" + row_heights = heights_map[self.n_rows] self.color_red = 'rgba(249,41,62,0.7)' self.color_green = 'rgba(0,170,59,0.7)' fig = make_subplots(rows=self.n_rows, cols=1, shared_xaxes=True, row_heights=row_heights, horizontal_spacing=0, vertical_spacing=0) + fig = fig.update_yaxes(showgrid=True, zeroline=False, automargin=True, - fixedrange=kwargs.get('y_fixed_range', True)) + fixedrange=kwargs.get('y_fixed_range', True), + showspikes=True, spikemode='across', spikesnap='cursor', showline=False, spikedash='dot') fig = fig.update_xaxes(type='category', rangeslider_visible=False, showgrid=False, automargin=True, - showticklabels=False) + showticklabels=False, showspikes=True, spikemode='across', spikesnap='cursor', + showline=False, spikedash='dot') # https://plotly.com/python/reference/layout/ fig.update_layout( @@ -56,11 +56,12 @@ def __init__(self, n_rows=3, **kwargs): legend=dict(orientation='h', yanchor="top", y=1.05, xanchor="left", x=0, bgcolor='rgba(0,0,0,0)'), template="plotly_dark", hovermode="x unified", - hoverlabel=dict(bgcolor='rgba(255,255,255,0.1)'), # 透明,更容易看清后面k线 + hoverlabel=dict(bgcolor='rgba(255,255,255,0.1)', font=dict(size=20)), # 透明,更容易看清后面k线 dragmode='pan', legend_title_font_color="red", height=kwargs.get('height', 300), ) + self.fig = fig def add_kline(self, kline: pd.DataFrame, name: str = "K线", **kwargs): diff --git a/examples/signals_dev/check_macd.py b/examples/signals_dev/check_macd.py new file mode 100644 index 000000000..b522f852b --- /dev/null +++ b/examples/signals_dev/check_macd.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +""" +author: zengbin93 +email: zeng_bin8888@163.com +create_dt: 2023/4/15 12:54 +describe: 检查MACD增量更新带来的影响 +""" +import sys +sys.path.insert(0, '../..') +import czsc +czsc.welcome() +import talib as ta +from test.test_analyze import read_1min + +bars = read_1min() +signals_config = [{'name': "czsc.signals.tas_macd_base_V230320", 'freq': '1分钟', 'di': 1}] +df = czsc.generate_czsc_signals(bars, signals_config=signals_config, signals_module_name='czsc.signals', df=True) +df['dif'], df['dea'], df['macd'] = ta.MACD(df['close'], fastperiod=12, slowperiod=26, signalperiod=9) +# parse cache +df['cache_macd'] = df['cache'].apply(lambda x: x['MACD12#26#9']['macd']) +df['cache_dif'] = df['cache'].apply(lambda x: x['MACD12#26#9']['dif']) +df['cache_dea'] = df['cache'].apply(lambda x: x['MACD12#26#9']['dea']) + +df = df.tail(10000) +print('macd 差异', (df['macd'] - df['cache_macd']).abs().sum()) +print('dif 差异', (df['dif'] - df['cache_dif']).abs().sum()) +print('dea 差异', (df['dea'] - df['cache_dea']).abs().sum()) + diff --git a/examples/signals_dev/pos_fx_stop_V230414.py b/examples/signals_dev/pos_fx_stop_V230414.py new file mode 100644 index 000000000..8bac38f82 --- /dev/null +++ b/examples/signals_dev/pos_fx_stop_V230414.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- +""" +author: zengbin93 +email: zeng_bin8888@163.com +create_dt: 2023/4/14 17:48 +describe: +""" +import os +import sys +sys.path.insert(0, r'D:\ZB\git_repo\waditu\czsc\examples\signals_dev') +os.environ['czsc_verbose'] = '1' +import pandas as pd +from typing import List +from loguru import logger +from czsc import CzscStrategyBase, Position +from czsc.connectors import research +logger.enable('czsc.analyze') + +pd.set_option('expand_frame_repr', False) +pd.set_option('display.max_rows', 1000) +pd.set_option('display.max_columns', 1000) +pd.set_option('display.width', 1000) + + +class MyStrategy(CzscStrategyBase): + + def create_pos(self, freq='60分钟', freq1='15分钟'): + _pos_name = f'{freq}通道突破' + _pos = {'symbol': self.symbol, + 'name': _pos_name, + 'opens': [{'operate': '开多', + 'factors': [ + {'name': f'{freq}看多', + 'signals_all': [f'{freq}_D1通道突破#5#30#30_BS辅助V230403_看多_任意_任意_0']}, + ]}, + + {'operate': '开空', + 'factors': [ + {'name': f'{freq}看空', + 'signals_all': [f'{freq}_D1通道突破#5#30#30_BS辅助V230403_看空_任意_任意_0']}, + ]} + ], + 'exits': [ + {'operate': '平多', + 'factors': [ + {'name': f'{freq1}_{_pos_name}_止损V230414', + 'signals_all': [f'{freq1}_{_pos_name}N1_止损V230414_多头止损_任意_任意_0']}, + ]}, + + {'operate': '平空', + 'factors': [ + {'name': f'{freq1}_{_pos_name}_止损V230414', + 'signals_all': [f'{freq1}_{_pos_name}N1_止损V230414_空头止损_任意_任意_0']}, + ]}, + ], + 'interval': 7200, + 'timeout': 100, + 'stop_loss': 500, + 'T0': True} + + return Position.load(_pos) + + @property + def positions(self) -> List[Position]: + _pos_list = [self.create_pos(freq='日线', freq1='60分钟')] + return _pos_list + + +def check(): + from czsc.connectors import research + + symbols = research.get_symbols('A股主要指数') + tactic = MyStrategy(symbol=symbols[0], signals_module_name='pos_signals') + bars = research.get_raw_bars(symbols[0], tactic.base_freq, '20151101', '20210101', fq='前复权') + + tactic.check(bars, res_path=r'C:\Users\zengb\.czsc\策略信号验证', refresh=True) + + +if __name__ == '__main__': + check() + + + diff --git a/examples/signals_dev/pos_signals.py b/examples/signals_dev/pos_signals.py new file mode 100644 index 000000000..c77b41d17 --- /dev/null +++ b/examples/signals_dev/pos_signals.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +""" +author: zengbin93 +email: zeng_bin8888@163.com +create_dt: 2023/4/14 18:41 +describe: +""" +from collections import OrderedDict +from czsc.signals import * +from czsc.traders.base import CzscTrader +from czsc.utils import create_single_signal +from czsc.objects import Operate, Direction, Mark + + +def pos_fx_stop_V230414(cat: CzscTrader, **kwargs) -> OrderedDict: + """按照开仓点附近的分型止损 + + 参数模板:"{freq1}_{pos_name}N{n}_止损V230414" + + **信号逻辑:** + + 多头止损逻辑如下,反之为空头止损逻辑: + + 1. 从多头开仓点开始,在给定对的K线周期 freq1 上向前找 N 个底分型,记为 F1 + 2. 将这 N 个底分型的最低点,记为 L1,如果 L1 的价格低于开仓点的价格,则止损 + + **信号列表:** + + - Signal('日线_日线三买多头N1_止损V230414_多头止损_任意_任意_0') + - Signal('日线_日线三买多头N1_止损V230414_空头止损_任意_任意_0') + + :param cat: CzscTrader对象 + :param kwargs: 参数字典 + - pos_name: str,开仓信号的名称 + - freq1: str,给定的K线周期 + - n: int,向前找的分型个数,默认为 3 + :return: + """ + pos_name = kwargs["pos_name"] + freq1 = kwargs["freq1"] + n = int(kwargs.get('n', 3)) + k1, k2, k3 = f"{freq1}_{pos_name}N{n}_止损V230414".split("_") + v1 = '其他' + + pos = [x for x in cat.positions if x.name == pos_name][0] + if len(pos.operates) == 0 or pos.operates[-1]['op'] in [Operate.SE, Operate.LE]: + return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1) + + c = cat.kas[freq1] + op = pos.operates[-1] + + # 多头止损逻辑 + if op['op'] == Operate.LO: + fxs = [x for x in c.fx_list if x.mark == Mark.D and x.dt < op['dt']][-n:] + if cat.latest_price < min([x.low for x in fxs]): + v1 = '多头止损' + + # 空头止损逻辑 + if op['op'] == Operate.SO: + fxs = [x for x in c.fx_list if x.mark == Mark.G and x.dt < op['dt']][-n:] + if cat.latest_price > max([x.high for x in fxs]): + v1 = '空头止损' + + return create_single_signal(k1=k1, k2=k2, k3=k3, v1=v1) diff --git a/examples/signals_dev/sig_parse_backup.py b/examples/signals_dev/sig_parse_backup.py deleted file mode 100644 index f598d0766..000000000 --- a/examples/signals_dev/sig_parse_backup.py +++ /dev/null @@ -1,290 +0,0 @@ -# -*- coding: utf-8 -*- -""" -author: zengbin93 -email: zeng_bin8888@163.com -create_dt: 2023/3/29 10:04 -describe: -""" -import re -from loguru import logger -from parse import parse -from difflib import SequenceMatcher -from czsc.objects import Signal -from czsc.utils import import_by_name - - -class SignalsParserBackup: - """解析一串信号,生成信号函数配置""" - - def __init__(self, signals_module='czsc.signals', **kwargs): - """ - - :param signals_module: 指定信号函数所在模块 - :param kwargs: - usr_parse_map: 用户自定义信号函数解析方法,字典类型,key 为信号函数名,value 为解析方法 - """ - self.signals_module = signals_module - sig_name_map = {} - sig_pats_map = {} - - signals_module = import_by_name(signals_module) - for name in dir(signals_module): - if "_" not in name: - continue - - try: - doc = getattr(signals_module, name).__doc__ - # 解析信号函数参数 - pats = re.findall(r"参数模板:\"(.*)\"", doc) - if pats: - sig_pats_map[name] = pats[0] - - # 解析信号列表 - sigs = re.findall(r"Signal\('(.*)'\)", doc) - if sigs: - sig_name_map[name] = [Signal(x) for x in sigs] - - except Exception as e: - logger.error(f"解析信号函数 {name} 出错:{e}") - - self.sig_name_map = sig_name_map - self.sig_pats_map = sig_pats_map - - # 自动获取解析函数 - self._parse_map = {k: getattr(self, f"_SignalsParser__parse_{k}") for k in self.sig_name_map.keys() - if getattr(self, f"_SignalsParser__parse_{k}", None)} - - # 用户自定义信号函数解析方法传入 - if kwargs.get("usr_parse_map", None): - self._parse_map.update(kwargs.get("usr_parse_map")) - - def parse_params(self, name, signal): - """获取信号函数参数 - - :param name: 信号函数名称 - :param signal: 需要解析的信号 - :return: - """ - key = Signal(signal).key - pats = self.sig_pats_map.get(name, None) - if not pats: - return None - - try: - params = parse(pats, key).named - if 'di' in params: - params['di'] = int(params['di']) - - params['name'] = f"{self.signals_module}.{name}" - return params - except Exception as e: - logger.error(f"解析信号 {signal} - {name} - {pats} 出错:{e}") - return None - - def get_function_name(self, signal): - """获取信号函数名称""" - sig_name_map = self.sig_name_map - _signal = Signal(signal) - _k3_match = list({k for k, v in sig_name_map.items() if v[0].k3 == _signal.k3}) - # 优先匹配 k3,满足条件直接返回 - if len(_k3_match) == 1: - return _k3_match[0] - - if len(_k3_match) > 1: - logger.warning(f"信号 {signal} 有多个匹配函数:{_k3_match},请手动解析信号") - return None - - _signal_k2_k3 = _signal.k2 + _signal.k3 - scores = {} - for k, v in sig_name_map.items(): - # 计算 k2, k3 的相似度 - _vs = [SequenceMatcher(None, s.k2 + s.k3, _signal_k2_k3).ratio() for s in v] - if max(_vs) >= 0.8: - scores[k] = max(_vs) - - if not scores: - return None - - return max(scores, key=scores.get) - - def parse(self, signal_seq): - """解析信号序列""" - res = [] - for signal in signal_seq: - name = self.get_function_name(signal) - - # 首先使用参数模板进行解析 - if name in self.sig_pats_map: - row = self.parse_params(name, signal) - if row and row not in res: - res.append(row) - - # 其次使用信号函数名称对应的解析方法进行解析 - elif name in self._parse_map: - row = self._parse_map[name](signal) - row['name'] = f"{self.signals_module}.{name}" - if row not in res: - res.append(row) - - else: - logger.warning(f"未找到解析函数:{name},请手动解析信号:{signal}") - - return res - - def parse_with_name(self, signal_map): - """解析信号字典,信号字典的 key 为信号函数名,value 为信号序列""" - res = [] - for name, signal_seq in signal_map.items(): - if name in self._parse_map: - for signal in signal_seq: - row = self._parse_map[name](signal) - row['name'] = f"czsc.signals.{name}" - if row not in res: - res.append(row) - else: - logger.warning(f"未找到解析函数:{name},请手动解析信号:{signal_seq}") - return res - - @staticmethod - def __remove_duplicates(_res): - # 去除重复的信号配置 - _res = [dict(t) for t in {tuple(d.items()) for d in _res}] - return _res - - @staticmethod - def __parse_bar_single_V230214(signal): - # https://czsc.readthedocs.io/en/0.9.13/api/czsc.signals.bar_single_V230214.html - pats = re.findall(r"(.*?)_D(\d+)T(\d+)_", signal)[0] - _row = {"freq": pats[0], "di": int(pats[1]), 't': int(pats[2]) / 10} - return _row - - @staticmethod - def __parse_cxt_third_bs_V230319(signal): - pats = re.findall(r"(.*?)_D(\d+)(\D+)(\d+)_", signal)[0] - _row = {"freq": pats[0], "di": int(pats[1]), 'ma_type': pats[2], 'timeperiod': int(pats[3])} - return _row - - @staticmethod - def __parse_byi_bi_end_V230107(signal): - pats = re.findall(r"(.*?)_", signal) - _row = {"freq": pats[0]} - return _row - - @staticmethod - def __parse_byi_bi_end_V230106(signal): - pats = re.findall(r"(.*?)_", signal) - _row = {"freq": pats[0]} - return _row - - @staticmethod - def __parse_bar_accelerate_V221110(signal): - # https://czsc.readthedocs.io/en/0.9.13/api/czsc.signals.bar_accelerate_V221110.html - pats = re.findall(r"(.*?)_D(\d+)W(\d+)_", signal)[0] - _row = {"freq": pats[0], "di": int(pats[1]), 'window': int(pats[2])} - return _row - - @staticmethod - def __parse_bar_accelerate_V221118(signal): - # https://czsc.readthedocs.io/en/0.9.13/api/czsc.signals.bar_accelerate_V221118.html - pats = re.findall(r"(.*?)_D(\d+)W(\d+)(\D+)(\d+)_", signal)[0] - _row = { - "freq": pats[0], - "di": int(pats[1]), - "window ": int(pats[2]), - "ma_type": pats[3], - "timeperiod ": int(pats[4]), - } - return _row - - @staticmethod - def __parse_bar_bpm_V230227(signal): - # https://czsc.readthedocs.io/en/0.9.13/api/czsc.signals.bar_bpm_V230227.html - pats = re.findall(r"(.*?)_D(\d+)N(\d+)T(\d+)_", signal)[0] - _row = {"freq": pats[0], "di": int(pats[1]), "n ": int(pats[2]), "th ": int(pats[3])} - return _row - - -class SignalsParser: - """解析一串信号,生成信号函数配置""" - - def __init__(self, signals_module='czsc.signals', **kwargs): - """ - - :param signals_module: 指定信号函数所在模块 - :param kwargs: - usr_parse_map: 用户自定义信号函数解析方法,字典类型,key 为信号函数名,value 为解析方法 - """ - self.signals_module = signals_module - sig_name_map = {} - sig_pats_map = {} - - signals_module = import_by_name(signals_module) - for name in dir(signals_module): - if "_" not in name: - continue - - try: - doc = getattr(signals_module, name).__doc__ - # 解析信号函数参数 - pats = re.findall(r"参数模板:\"(.*)\"", doc) - if pats: - sig_pats_map[name] = pats[0] - - # 解析信号列表 - sigs = re.findall(r"Signal\('(.*)'\)", doc) - if sigs: - sig_name_map[name] = [Signal(x) for x in sigs] - - except Exception as e: - logger.error(f"解析信号函数 {name} 出错:{e}") - - self.sig_name_map = sig_name_map - self.sig_pats_map = sig_pats_map - - def parse_params(self, name, signal): - """获取信号函数参数 - - :param name: 信号函数名称 - :param signal: 需要解析的信号 - :return: - """ - key = Signal(signal).key - pats = self.sig_pats_map.get(name, None) - if not pats: - return None - - try: - params = parse(pats, key).named - if 'di' in params: - params['di'] = int(params['di']) - - params['name'] = f"{self.signals_module}.{name}" - return params - except Exception as e: - logger.error(f"解析信号 {signal} - {name} - {pats} 出错:{e}") - return None - - def get_function_name(self, signal): - """获取信号函数名称""" - sig_name_map = self.sig_name_map - _signal = Signal(signal) - _k3_match = list({k for k, v in sig_name_map.items() if v[0].k3 == _signal.k3}) - # 优先匹配 k3,满足条件直接返回 - if len(_k3_match) == 1: - return _k3_match[0] - else: - logger.error(f"信号 {signal} 有多个匹配函数:{_k3_match},请手动解析信号") - return None - - def parse(self, signal_seq): - """解析信号序列""" - res = [] - for signal in signal_seq: - name = self.get_function_name(signal) - if name in self.sig_pats_map: - row = self.parse_params(name, signal) - if row and row not in res: - res.append(row) - else: - logger.warning(f"未找到解析函数:{name},请手动解析信号:{signal}") - return res diff --git a/requirements.txt b/requirements.txt index d87a3c0f3..e6e070334 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ requests>=2.24.0 -pyecharts==1.9.1 +pyecharts>=1.9.1 tqdm pandas>=1.1.0 numpy>=1.16.5 diff --git a/test/test_analyze.py b/test/test_analyze.py index 14225f8ad..bda9a5842 100644 --- a/test/test_analyze.py +++ b/test/test_analyze.py @@ -90,6 +90,10 @@ def test_czsc_update(): c = CZSC(bars) assert not c.signals + # 测试 ubi 属性 + ubi = c.ubi + assert ubi['direction'] == Direction.Down + assert ubi['high_bar'].dt < ubi['low_bar'].dt # 测试自定义信号 c = CZSC(bars, get_signals=get_user_signals) assert len(c.signals) == 7 diff --git a/test/test_strategy.py b/test/test_strategy.py index e0aee09ee..9bc66bf6b 100644 --- a/test/test_strategy.py +++ b/test/test_strategy.py @@ -42,3 +42,10 @@ def test_czsc_strategy_example2(): assert len(os.listdir("trade_replay_test")) == 3 shutil.rmtree("trade_replay_test") + # 验证信号计算的准确性 + strategy.check(bars, res_path="trade_check_test", sdt='20190101', exist_ok=False) + assert len(os.listdir("trade_check_test")) == 2 + assert os.path.exists(os.path.join("trade_check_test", "signals.xlsx")) + assert os.path.exists(os.path.join("trade_check_test", "15分钟_D0停顿分型_BE辅助V230106")) + shutil.rmtree("trade_check_test") +