diff --git a/mmf_sa/Forecaster.py b/mmf_sa/Forecaster.py
index 957b9d5..f2ff733 100644
--- a/mmf_sa/Forecaster.py
+++ b/mmf_sa/Forecaster.py
@@ -355,8 +355,8 @@ def evaluate_foundation_model(self, model_conf):
         with mlflow.start_run(experiment_id=self.experiment_id) as run:
             model_name = model_conf["name"]
             model = self.model_registry.get_model(model_name)
-            # For now, only support registering chronos and moirai models
-            if model_conf["framework"] in ["Chronos", "Moirai"]:
+            # For now, only support registering chronos, moirai and moment models
+            if model_conf["framework"] in ["Chronos", "Moirai", "Moment"]:
                 model.register(
                     registered_model_name=f"{self.conf['model_output']}.{model_conf['name']}_{self.conf['use_case_name']}"
                 )
diff --git a/mmf_sa/models/momentforecast/MomentPipeline.py b/mmf_sa/models/momentforecast/MomentPipeline.py
index ff2dbfe..bd27e34 100644
--- a/mmf_sa/models/momentforecast/MomentPipeline.py
+++ b/mmf_sa/models/momentforecast/MomentPipeline.py
@@ -4,6 +4,9 @@
 import pandas as pd
 import numpy as np
 import torch
+import mlflow
+from mlflow.types import Schema, TensorSpec
+from mlflow.models.signature import ModelSignature
 from sktime.performance_metrics.forecasting import mean_absolute_percentage_error
 from typing import Iterator
 from pyspark.sql.functions import collect_list, pandas_udf
@@ -19,9 +22,42 @@ def __init__(self, params):
         self.model = None
         self.install("git+https://github.com/moment-timeseries-foundation-model/moment.git")
 
-    def install(self, package: str):
+    @staticmethod
+    def install(package: str):
+        """Install ``package`` with pip using the running interpreter."""
         subprocess.check_call([sys.executable, "-m", "pip", "install", package, "--quiet"])
 
+    def register(self, registered_model_name: str):
+        """Log the MOMENT forecaster to MLflow as a pyfunc model and register it.
+
+        Args:
+            registered_model_name: Name under which the logged model is
+                registered.
+        """
+        pipeline = MomentModel(
+            self.repo,
+            self.params["prediction_length"],
+        )
+        # One 1-D, variable-length series in; one 1-D forecast out.
+        input_schema = Schema([TensorSpec(np.dtype(np.double), (-1,))])
+        # MomentModel.predict returns real-valued forecasts, so the output is
+        # declared as double rather than uint8.
+        output_schema = Schema([TensorSpec(np.dtype(np.double), (-1,))])
+        signature = ModelSignature(inputs=input_schema, outputs=output_schema)
+        input_example = np.random.rand(52)
+        mlflow.pyfunc.log_model(
+            "model",
+            python_model=pipeline,
+            registered_model_name=registered_model_name,
+            signature=signature,
+            input_example=input_example,
+            pip_requirements=[
+                "git+https://github.com/moment-timeseries-foundation-model/moment.git",
+                "git+https://github.com/databricks-industry-solutions/many-model-forecasting.git",
+                "pyspark==3.5.0",
+            ],
+        )
+
     def create_horizon_timestamps_udf(self):
         @pandas_udf('array')
         def horizon_timestamps_udf(batch_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
@@ -157,3 +193,65 @@ def __init__(self, params):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.repo = "AutonLab/MOMENT-1-large"
 
+
+class MomentModel(mlflow.pyfunc.PythonModel):
+    """MLflow pyfunc wrapper that serves forecasts from a pretrained MOMENT pipeline."""
+
+    # Length every input series is padded or truncated to before it is
+    # passed to the pipeline.
+    CONTEXT_LENGTH = 512
+
+    def __init__(self, repository, prediction_length):
+        """Load the pretrained pipeline and move it to the selected device.
+
+        Args:
+            repository: Identifier passed to ``MOMENTPipeline.from_pretrained``.
+            prediction_length: Value used as the pipeline's ``forecast_horizon``.
+        """
+        # Function-scope import, presumably because momentfm is pip-installed
+        # at runtime and may be missing at module import time -- TODO confirm.
+        from momentfm import MOMENTPipeline
+        self.repository = repository
+        self.prediction_length = prediction_length
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.pipeline = MOMENTPipeline.from_pretrained(
+            self.repository,
+            device_map=self.device,
+            model_kwargs={
+                "task_name": "forecasting",
+                "forecast_horizon": self.prediction_length},
+        )
+        self.pipeline.init()
+        self.pipeline = self.pipeline.to(self.device)
+
+    def predict(self, context, input_data, params=None):
+        """Forecast the configured horizon for a single univariate series.
+
+        Args:
+            context: MLflow pyfunc context (unused).
+            input_data: Iterable of numeric observations. Only the last
+                ``CONTEXT_LENGTH`` values are used; shorter inputs are
+                zero-padded on the right and masked out.
+            params: Optional inference parameters (unused).
+
+        Returns:
+            list: The values of ``output.forecast`` flattened to one dimension.
+        """
+        context_length = self.CONTEXT_LENGTH
+        series = list(input_data)[-context_length:]
+        observed = len(series)
+        padding = context_length - observed
+        # Right-pad short series with zeros; the mask marks real
+        # observations with 1 and padding with 0.
+        input_mask = [1] * observed + [0] * padding
+        series = series + [0] * padding
+        input_mask = torch.reshape(torch.tensor(input_mask), (1, context_length)).to(self.device)
+        series = torch.reshape(torch.tensor(series), (1, 1, context_length)).to(dtype=torch.float32).to(self.device)
+        # Inference only: skip autograd bookkeeping.
+        with torch.no_grad():
+            # Call the loaded pipeline; this class never defines ``self.model``.
+            output = self.pipeline(series, input_mask=input_mask)
+        # Flatten instead of squeeze so a horizon of 1 still yields a list.
+        forecast = output.forecast.reshape(-1).tolist()
+        return forecast
+