Skip to content

Commit

Permalink
model logging and registry for foundation models
Browse files Browse the repository at this point in the history
  • Loading branch information
ryuta-yoshimatsu committed Jun 3, 2024
1 parent d087995 commit 340b2f4
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 5 deletions.
8 changes: 5 additions & 3 deletions mmf_sa/Forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,11 @@ def evaluate_foundation_model(self, model_conf):
with mlflow.start_run(experiment_id=self.experiment_id) as run:
model_name = model_conf["name"]
model = self.model_registry.get_model(model_name)
model.register(
registered_model_name=f"{self.conf['model_output']}.{model_conf['name']}_{self.conf['use_case_name']}"
)
# For now, only support registering chronos and moirai models
if model_conf["framework"] in ["Chronos", "Moirai"]:
model.register(
registered_model_name=f"{self.conf['model_output']}.{model_conf['name']}_{self.conf['use_case_name']}"
)
hist_df, removed = self.prepare_data_for_global_model("evaluating") # Reuse the same as global
train_df, val_df = self.split_df_train_val(hist_df)
model_uri = f"runs:/{run.info.run_id}/model"
Expand Down
2 changes: 1 addition & 1 deletion mmf_sa/models/models_conf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ models:
framework: Chronos
model_type: foundation
num_samples: 10
batch_size: 4
batch_size: 2

MoiraiBase:
module: mmf_sa.models.moiraiforecast.MoiraiPipeline
Expand Down
68 changes: 67 additions & 1 deletion mmf_sa/models/moiraiforecast/MoiraiPipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import pandas as pd
import numpy as np
import torch
import mlflow
from mlflow.types import Schema, TensorSpec
from mlflow.models.signature import ModelSignature
from sktime.performance_metrics.forecasting import mean_absolute_percentage_error
from typing import Iterator
from pyspark.sql.functions import collect_list, pandas_udf
Expand All @@ -19,9 +22,34 @@ def __init__(self, params):
self.model = None
self.install("git+https://github.com/SalesforceAIResearch/uni2ts.git")

def install(self, package: str):
@staticmethod
def install(package: str):
    """Install *package* with pip, using the interpreter running this code.

    Args:
        package: Anything ``pip install`` accepts (a name, a pinned
            requirement, or a VCS URL).

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    # ``sys.executable -m pip`` targets the environment of the current
    # interpreter rather than whichever ``pip`` happens to be first on PATH.
    pip_command = [sys.executable, "-m", "pip", "install", package, "--quiet"]
    subprocess.check_call(pip_command)

def register(self, registered_model_name: str):
    """Log the Moirai model as an MLflow pyfunc and register it.

    Wraps the pretrained weights referenced by ``self.repo`` in a
    ``MoiraiModel`` and logs it under the artifact path ``"model"`` with
    an explicit signature and an input example.

    Args:
        registered_model_name: Name under which the logged model is
            registered in the model registry.
    """
    pipeline = MoiraiModel(
        self.repo,
        self.params["prediction_length"],
        self.params["patch_size"],
        self.params["num_samples"],
    )
    # Input: a single univariate history of arbitrary length.
    input_schema = Schema([TensorSpec(np.dtype(np.double), (-1,))])
    # Output: the median over forecast samples, which is floating point.
    # The signature previously declared uint8, which misdescribed the
    # values returned by MoiraiModel.predict.
    output_schema = Schema([TensorSpec(np.dtype(np.float32), (-1,))])
    signature = ModelSignature(inputs=input_schema, outputs=output_schema)
    # Synthetic history used only as the stored input example.
    input_example = np.random.rand(52)
    mlflow.pyfunc.log_model(
        "model",
        python_model=pipeline,
        registered_model_name=registered_model_name,
        signature=signature,
        input_example=input_example,
        pip_requirements=[
            "git+https://github.com/SalesforceAIResearch/uni2ts.git",
            "git+https://github.com/databricks-industry-solutions/many-model-forecasting.git",
            "pyspark==3.5.0",
        ],
    )

def create_horizon_timestamps_udf(self):
@pandas_udf('array<timestamp>')
def horizon_timestamps_udf(batch_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
Expand Down Expand Up @@ -172,3 +200,41 @@ def __init__(self, params):
super().__init__(params)
self.params = params
self.repo = "Salesforce/moirai-1.0-R-base"


class MoiraiModel(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper that serves point forecasts from Moirai.

    The pretrained ``MoiraiModule`` is loaded once at construction time; a
    ``MoiraiForecast`` head is built per call because its context length
    depends on the length of the incoming series.
    """

    def __init__(self, repository, prediction_length, patch_size, num_samples):
        """Load the pretrained module and store the forecasting settings.

        Args:
            repository: Identifier passed to ``MoiraiModule.from_pretrained``.
            prediction_length: Forecast horizon handed to ``MoiraiForecast``.
            patch_size: Patch size handed to ``MoiraiForecast``.
            num_samples: Number of sample paths drawn per forecast.
        """
        # Imported inside the method so that importing this file does not
        # require uni2ts to be installed yet.
        from uni2ts.model.moirai import MoiraiModule

        self.repository = repository
        self.prediction_length = prediction_length
        self.patch_size = patch_size
        self.num_samples = num_samples
        self.module = MoiraiModule.from_pretrained(self.repository)

    def predict(self, context, input_data, params=None):
        """Return the per-step median forecast for one univariate series.

        Args:
            context: MLflow pyfunc context (unused).
            input_data: 1-D sequence of historical values; its length is
                used as the model's context length.
            params: Optional MLflow inference params (unused).

        Returns:
            numpy.ndarray: Median across the sampled forecast paths
            (presumably one value per horizon step; confirm against the
            ``MoiraiForecast`` output layout).
        """
        from einops import rearrange
        from uni2ts.model.moirai import MoiraiForecast

        model = MoiraiForecast(
            module=self.module,
            prediction_length=self.prediction_length,
            context_length=len(input_data),
            patch_size=self.patch_size,
            num_samples=self.num_samples,
            target_dim=1,
            feat_dynamic_real_dim=0,
            past_feat_dynamic_real_dim=0,
        )
        # Time series values. Shape: (batch, time, variate)
        past_target = rearrange(
            torch.as_tensor(input_data, dtype=torch.float32), "t -> 1 t 1"
        )
        # 1s if the value is observed, 0s otherwise. Shape: (batch, time, variate)
        past_observed_target = torch.ones_like(past_target, dtype=torch.bool)
        # 1s if the value is padding, 0s otherwise. Shape: (batch, time)
        past_is_pad = torch.zeros_like(past_target, dtype=torch.bool).squeeze(-1)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            forecast = model(
                past_target=past_target,
                past_observed_target=past_observed_target,
                past_is_pad=past_is_pad,
            )
        # Convert explicitly instead of relying on NumPy's implicit tensor
        # coercion, which fails for tensors that require grad or are not on
        # the CPU.
        samples = forecast[0].detach().cpu().numpy()
        return np.median(samples, axis=0)

0 comments on commit 340b2f4

Please sign in to comment.