-
Notifications
You must be signed in to change notification settings - Fork 2.7k
/
Copy pathexample.py
38 lines (30 loc) · 1.18 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import argparse
import qlib
from ruamel.yaml import YAML
from qlib.utils import init_instance_by_config
def main(seed, config_file="configs/config_alstm.yaml"):
    """Build the dataset and model described by a YAML workflow config and train the model.

    Args:
        seed: Value stored as ``seed`` in the model's ``kwargs`` before the
            model is instantiated. This function does not seed any RNG itself.
        config_file: Path to a YAML config. It must contain ``qlib_init``
            (with ``provider_uri`` and ``region``) and ``task`` (with
            ``dataset`` and ``model`` sections accepted by
            ``init_instance_by_config``).

    Returns:
        The fitted model instance. The original returned ``None``, so
        existing callers are unaffected.
    """
    # Explicit encoding so parsing does not depend on the platform's locale.
    with open(config_file, encoding="utf-8") as f:
        yaml = YAML(typ="safe", pure=True)
        config = yaml.load(f)

    # Forward the seed to the model; presumably the model class applies it
    # to its RNGs — NOTE(review): confirm in the model implementation.
    model_kwargs = config["task"]["model"]["kwargs"]
    model_kwargs["seed"] = seed

    # Suffix appended to the configured log directory. It is left empty so that
    # all seeds share one logdir; set it to e.g. f"/seed{seed}" for per-seed logs.
    seed_suffix = ""
    logdir = model_kwargs.get("logdir")
    # Only rewrite logdir when one is configured; a missing or None logdir
    # is passed through instead of crashing on the concatenation.
    if logdir is not None:
        model_kwargs["logdir"] = logdir + seed_suffix

    # Initialize the qlib data provider before any dataset is constructed.
    qlib_init = config["qlib_init"]
    qlib.init(
        provider_uri=qlib_init["provider_uri"],
        region=qlib_init["region"],
    )

    dataset = init_instance_by_config(config["task"]["dataset"])
    model = init_instance_by_config(config["task"]["model"])

    model.fit(dataset)
    return model
if __name__ == "__main__":
    # Command-line entry point: parse --seed / --config_file and train.
    cli = argparse.ArgumentParser(allow_abbrev=False)
    cli.add_argument(
        "--seed",
        type=int,
        default=1000,
        help="random seed",
    )
    cli.add_argument(
        "--config_file",
        type=str,
        default="configs/config_alstm.yaml",
        help="config file",
    )
    options = cli.parse_args()
    main(seed=options.seed, config_file=options.config_file)