From 22509d0060e07c94f0009b16471c3b5158f8fc4b Mon Sep 17 00:00:00 2001 From: Ryuta Yoshimatsu Date: Mon, 27 Jan 2025 13:55:18 +0100 Subject: [PATCH] added covariate support for timesfm models --- mmf_sa/Forecaster.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmf_sa/Forecaster.py b/mmf_sa/Forecaster.py index 5d5bdc4..eb37abd 100644 --- a/mmf_sa/Forecaster.py +++ b/mmf_sa/Forecaster.py @@ -589,8 +589,8 @@ def score_foundation_model(self, model_conf): model_name = model_conf["name"] _, model_uri = self.get_model_for_scoring(model_conf) model = self.model_registry.get_model(model_name) - hist_df, removed = self.prepare_data_for_global_model() - prediction_df, model_pretrained = model.forecast(hist_df, spark=self.spark) + score_df, removed = self.prepare_data_for_global_model("scoring") + prediction_df, model_pretrained = model.forecast(score_df, spark=self.spark) sdf = self.spark.createDataFrame(prediction_df).drop('index') ( sdf.withColumn(self.conf["group_id"], col(self.conf["group_id"]).cast(StringType()))