diff --git a/mmf_sa/models/moiraiforecast/MoiraiPipeline.py b/mmf_sa/models/moiraiforecast/MoiraiPipeline.py index 5839926..b770427 100644 --- a/mmf_sa/models/moiraiforecast/MoiraiPipeline.py +++ b/mmf_sa/models/moiraiforecast/MoiraiPipeline.py @@ -210,11 +210,12 @@ def __init__(self, repository, prediction_length, patch_size, num_samples): self.patch_size = patch_size self.num_samples = num_samples self.module = MoiraiModule.from_pretrained(self.repository) + self.pipeline = None def predict(self, context, input_data, params=None): from einops import rearrange from uni2ts.model.moirai import MoiraiForecast, MoiraiModule - model = MoiraiForecast( + self.pipeline = MoiraiForecast( module=self.module, prediction_length=self.prediction_length, context_length=len(input_data), @@ -225,14 +226,12 @@ def predict(self, context, input_data, params=None): past_feat_dynamic_real_dim=0, ) # Time series values. Shape: (batch, time, variate) - past_target = rearrange( - torch.as_tensor(input_data, dtype=torch.float32), "t -> 1 t 1" - ) + past_target = rearrange(torch.as_tensor(input_data, dtype=torch.float32), "t -> 1 t 1") # 1s if the value is observed, 0s otherwise. Shape: (batch, time, variate) past_observed_target = torch.ones_like(past_target, dtype=torch.bool) # 1s if the value is padding, 0s otherwise. Shape: (batch, time) past_is_pad = torch.zeros_like(past_target, dtype=torch.bool).squeeze(-1) - forecast = model( + forecast = self.pipeline( past_target=past_target, past_observed_target=past_observed_target, past_is_pad=past_is_pad, diff --git a/mmf_sa/models/momentforecast/MomentPipeline.py b/mmf_sa/models/momentforecast/MomentPipeline.py index bd27e34..b421089 100644 --- a/mmf_sa/models/momentforecast/MomentPipeline.py +++ b/mmf_sa/models/momentforecast/MomentPipeline.py @@ -210,7 +210,7 @@ def predict(self, context, input_data, params=None): series = series[-512:] input_mask = torch.reshape(torch.tensor(input_mask),(1, 512)).to(self.device) series = torch.reshape(torch.tensor(series),(1, 1, 512)).to(dtype=torch.float32).to(self.device) - output = self.model(series, input_mask=input_mask) + output = self.pipeline(series, input_mask=input_mask) forecast = output.forecast.squeeze().tolist() return forecast