Skip to content

Commit

Permalink
added covariate support for timesfm models
Browse files Browse the repository at this point in the history
  • Loading branch information
ryuta-yoshimatsu committed Jan 27, 2025
1 parent 4c64793 commit 22509d0
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions mmf_sa/Forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,8 +589,8 @@ def score_foundation_model(self, model_conf):
model_name = model_conf["name"]
_, model_uri = self.get_model_for_scoring(model_conf)
model = self.model_registry.get_model(model_name)
hist_df, removed = self.prepare_data_for_global_model()
prediction_df, model_pretrained = model.forecast(hist_df, spark=self.spark)
score_df, removed = self.prepare_data_for_global_model("scoring")
prediction_df, model_pretrained = model.forecast(score_df, spark=self.spark)
sdf = self.spark.createDataFrame(prediction_df).drop('index')
(
sdf.withColumn(self.conf["group_id"], col(self.conf["group_id"]).cast(StringType()))
Expand Down

0 comments on commit 22509d0

Please sign in to comment.