model logging and registry for foundation models
ryuta-yoshimatsu committed Jun 3, 2024
1 parent 8d2d082 commit d4f2178
Showing 2 changed files with 5 additions and 6 deletions.
9 changes: 4 additions & 5 deletions mmf_sa/models/moiraiforecast/MoiraiPipeline.py
@@ -210,11 +210,12 @@ def __init__(self, repository, prediction_length, patch_size, num_samples):
         self.patch_size = patch_size
         self.num_samples = num_samples
         self.module = MoiraiModule.from_pretrained(self.repository)
+        self.pipeline = None

     def predict(self, context, input_data, params=None):
         from einops import rearrange
         from uni2ts.model.moirai import MoiraiForecast, MoiraiModule
-        model = MoiraiForecast(
+        self.pipeline = MoiraiForecast(
             module=self.module,
             prediction_length=self.prediction_length,
             context_length=len(input_data),
@@ -225,14 +226,12 @@ def predict(self, context, input_data, params=None):
             past_feat_dynamic_real_dim=0,
         )
         # Time series values. Shape: (batch, time, variate)
-        past_target = rearrange(
-            torch.as_tensor(input_data, dtype=torch.float32), "t -> 1 t 1"
-        )
+        past_target = rearrange(torch.as_tensor(input_data, dtype=torch.float32), "t -> 1 t 1")
         # 1s if the value is observed, 0s otherwise. Shape: (batch, time, variate)
         past_observed_target = torch.ones_like(past_target, dtype=torch.bool)
         # 1s if the value is padding, 0s otherwise. Shape: (batch, time)
         past_is_pad = torch.zeros_like(past_target, dtype=torch.bool).squeeze(-1)
-        forecast = model(
+        forecast = self.pipeline(
             past_target=past_target,
             past_observed_target=past_observed_target,
             past_is_pad=past_is_pad,
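
Note on the change: the MoiraiForecast object is now held on self.pipeline (initialized to None in __init__) instead of a local variable, and the predict(self, context, input_data, params=None) signature matches MLflow's pyfunc PythonModel interface, which is what the commit title's "model logging and registry" points to. Below is a minimal sketch of how such a wrapper might be logged and registered; it is not part of this commit, and the class name MoiraiModel, the checkpoint id, and the registry name are assumptions for illustration.

import mlflow

# Hypothetical instantiation of the wrapper edited above (the actual class name may differ).
model = MoiraiModel(
    repository="Salesforce/moirai-1.0-R-small",  # assumed Hugging Face checkpoint id
    prediction_length=10,
    patch_size=32,
    num_samples=100,
)

with mlflow.start_run():
    # Passing registered_model_name both logs the pyfunc artifact and creates
    # a new version of the model in the MLflow Model Registry.
    info = mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=model,
        registered_model_name="moirai_forecast",  # assumed registry name
    )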
2 changes: 1 addition & 1 deletion mmf_sa/models/momentforecast/MomentPipeline.py
@@ -210,7 +210,7 @@ def predict(self, context, input_data, params=None):
         series = series[-512:]
         input_mask = torch.reshape(torch.tensor(input_mask),(1, 512)).to(self.device)
         series = torch.reshape(torch.tensor(series),(1, 1, 512)).to(dtype=torch.float32).to(self.device)
-        output = self.model(series, input_mask=input_mask)
+        output = self.pipeline(series, input_mask=input_mask)
         forecast = output.forecast.squeeze().tolist()
         return forecast
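
The Moment wrapper gets the same treatment: the underlying model is now referenced as self.pipeline, matching the attribute name used in the Moirai wrapper above. A minimal sketch of loading such a registered pyfunc model back and running inference, assuming a registered model named "moment_forecast" at version 1 (both illustrative, not from this commit):

import mlflow
import numpy as np

# Load a specific version of the registered pyfunc model from the MLflow Model Registry.
loaded = mlflow.pyfunc.load_model("models:/moment_forecast/1")

# pyfunc dispatches to the wrapper's predict(context, input_data, params), which
# reshapes the series to (1, 1, 512) and returns the forecast as a Python list.
history = np.random.rand(512).astype(np.float32)
forecast = loaded.predict(history)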
