-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathpredict.py
84 lines (73 loc) · 2.84 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
Predict per-ticker trades for one specific day with the PPO trading agent.
"""
import functools
import logging
from datetime import datetime, timedelta
import pandas as pd
import tensorflow as tf
from absl import app
from absl import logging as absl_logging
from tf_agents.system import system_multiprocessing as multiprocessing
from config import settings
from envirement.trading_py_env import TradingPyEnv
from models.model_ppo import TradeDRLAgent
from preprocess_data import preprocess_data
def main(_, predict_day="2020-06-01", days_to_subtract=60):
    """Log the trades the PPO agent predicts for the last day of a data window.

    Preprocesses ``days_to_subtract`` calendar days of data ending at
    ``predict_day`` for every ticker in ``settings.DOW_30_TICKER``, runs
    ``TradeDRLAgent.test_trade`` on a ``TradingPyEnv`` built from that data,
    and logs the transactions from the last row of the returned actions frame
    next to the ticker symbols.

    Args:
        _: Positional argv list supplied by ``absl.app.run``; unused.
        predict_day: Last day of the data window, formatted ``YYYY-MM-DD``.
        days_to_subtract: Length of the look-back window in calendar days.

    Raises:
        ValueError: If none of ``settings.DATA_COLUMNS`` exist in the
            preprocessed data, if the agent returns no actions, or if the last
            transactions do not contain exactly one entry per ticker.
    """
    ticker_list = settings.DOW_30_TICKER
    data_columns = settings.DATA_COLUMNS

    # Format the start like ``predict_day`` (YYYY-MM-DD). ``str(datetime)`` would
    # append " 00:00:00", making the two bounds inconsistent and, with a plain
    # string comparison, excluding the start day itself.
    start_date = (
        datetime.strptime(predict_day, '%Y-%m-%d')
        - timedelta(days=days_to_subtract)
    ).strftime('%Y-%m-%d')

    df_trade = preprocess_data.preprocess_data(
        tic_list=ticker_list,
        start_date=start_date,
        end_date=predict_day,
        field_mappings=settings.CSV_FIELD_MAPPINGS,
        # (sic) keyword name kept as-is; presumably matches preprocess_data's signature.
        baseline_filed_mappings=settings.BASELINE_FIELD_MAPPINGS,
        csv_file_info=settings.CSV_FILE_SETTINGS,
        user_columns=settings.USER_DEFINED_FEATURES,
    )

    # Keep the configured column order; split into present / missing columns.
    available = set(df_trade.columns)
    information_cols = [col for col in data_columns if col in available]
    unavailable_cols = [col for col in data_columns if col not in available]

    if not information_cols:
        logging.error('No column to train')
        raise ValueError('No column to train')
    logging.info('Columns used to train:\n%s ✅', information_cols)
    if unavailable_cols:
        logging.info('Unavailable columns:\n%s ❌', unavailable_cols)

    logging.info('TensorFlow v%s', tf.version.VERSION)
    logging.info(
        'Available [GPU] devices:\n%s', tf.config.list_physical_devices('GPU'))

    # Predict
    test_py_env = TradingPyEnv(
        df=df_trade,
        daily_information_cols=information_cols,
    )
    model = TradeDRLAgent()
    _, df_actions = model.test_trade(env=test_py_env)

    if df_actions.empty:
        raise ValueError('Agent returned no actions to report')

    # The last row of df_actions holds the transactions (and date) reported
    # as the prediction; fetch them once.
    last_transactions = df_actions['transactions'].iloc[-1]
    tickers = df_trade.tic.unique()

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if len(tickers) != len(last_transactions):
        raise ValueError(
            f'Expected {len(tickers)} transactions (one per ticker), '
            f'got {len(last_transactions)}')

    pred_inf_df = pd.DataFrame({'ticker': tickers})
    pred_inf_df['trade'] = pd.Series(last_transactions)

    last_day = pd.to_datetime(df_actions['date'].iloc[-1])
    logging.info('\nPredicted trades for %s:\n%s',
                 last_day.strftime("%B %d, %Y"), pred_inf_df)
if __name__ == '__main__':
    # Send every record through absl's handler using a compact
    # "[LEVEL] message" layout, and emit INFO and above.
    absl_handler = absl_logging.get_absl_handler()
    absl_handler.setFormatter(logging.Formatter('[%(levelname)s] %(message)s'))
    absl_logging.set_verbosity('info')

    # Let tf_agents' multiprocessing wrapper drive absl's app.run(main).
    run_app = functools.partial(app.run, main)
    multiprocessing.handle_main(run_app)