Skip to content

Commit

Permalink
Merge pull request #76 from databricks-industry-solutions/update-moirai
Browse files Browse the repository at this point in the history
integrate-moirai-moe-models
  • Loading branch information
ryuta-yoshimatsu authored Jan 15, 2025
2 parents f8b0790 + 1be2769 commit 6d4b537
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 36 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Get started now!

- Jan 2025: [TimesFM](https://github.com/google-research/timesfm) is available for univariate forecasting. Try the [notebook](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/foundation_daily.py).
- Jan 2025: [Chronos Bolt](https://github.com/amazon-science/chronos-forecasting) models are available for univariate forecasting. Try the [notebook](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/foundation_daily.py).
- Jan 2025: [Moirai MoE](https://github.com/SalesforceAIResearch/uni2ts) models are available for univariate forecasting. Try the [notebook](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/foundation_daily.py).

## Getting started

Expand Down Expand Up @@ -211,6 +212,9 @@ active_models = [
"MoiraiSmall",
"MoiraiBase",
"MoiraiLarge",
"MoiraiMoESmall",
"MoiraiMoEBase",
"MoiraiMoELarge",
"TimesFM_1_0_200m",
"TimesFM_2_0_500m",
]
Expand Down
3 changes: 3 additions & 0 deletions examples/foundation_daily.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def transform_group(df):
"MoiraiSmall",
"MoiraiBase",
"MoiraiLarge",
"MoiraiMoESmall",
"MoiraiMoEBase",
"MoiraiMoELarge",
"TimesFM_1_0_200m",
"TimesFM_2_0_500m",
]
Expand Down
3 changes: 3 additions & 0 deletions examples/foundation_monthly.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def transform_group(df):
"MoiraiSmall",
"MoiraiBase",
"MoiraiLarge",
"MoiraiMoESmall",
"MoiraiMoEBase",
"MoiraiMoELarge",
"TimesFM_1_0_200m",
"TimesFM_2_0_500m",
]
Expand Down
3 changes: 3 additions & 0 deletions examples/m5-examples/foundation_daily_m5.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
"MoiraiSmall",
"MoiraiBase",
"MoiraiLarge",
"MoiraiMoESmall",
"MoiraiMoEBase",
"MoiraiMoELarge",
"TimesFM_1_0_200m",
"TimesFM_2_0_500m",
]
Expand Down
31 changes: 29 additions & 2 deletions mmf_sa/models/models_conf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,15 @@ models:
num_samples: 10
batch_size: 16

MoiraiSmall:
module: mmf_sa.models.moiraiforecast.MoiraiPipeline
model_class: MoiraiSmall
framework: Moirai
model_type: foundation
num_samples: 10
patch_size: 32
batch_size: 10

MoiraiBase:
module: mmf_sa.models.moiraiforecast.MoiraiPipeline
model_class: MoiraiBase
Expand All @@ -426,9 +435,27 @@ models:
patch_size: 32
batch_size: 10

MoiraiSmall:
MoiraiMoESmall:
module: mmf_sa.models.moiraiforecast.MoiraiPipeline
model_class: MoiraiSmall
model_class: MoiraiMoESmall
framework: Moirai
model_type: foundation
num_samples: 10
patch_size: 32
batch_size: 10

MoiraiMoEBase:
module: mmf_sa.models.moiraiforecast.MoiraiPipeline
model_class: MoiraiMoEBase
framework: Moirai
model_type: foundation
num_samples: 10
patch_size: 32
batch_size: 10

MoiraiMoELarge:
module: mmf_sa.models.moiraiforecast.MoiraiPipeline
model_class: MoiraiMoELarge
framework: Moirai
model_type: foundation
num_samples: 10
Expand Down
110 changes: 76 additions & 34 deletions mmf_sa/models/moiraiforecast/MoiraiPipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from abc import ABC
import subprocess
import sys
import pandas as pd
import numpy as np
import torch
Expand All @@ -24,11 +21,6 @@ def __init__(self, params):
self.params = params
self.device = None
self.model = None
self.install("git+https://github.com/SalesforceAIResearch/uni2ts.git")

@staticmethod
def install(package: str):
subprocess.check_call([sys.executable, "-m", "pip", "install", package, "--quiet"])

def register(self, registered_model_name: str):
pipeline = MoiraiModel(
Expand All @@ -48,7 +40,7 @@ def register(self, registered_model_name: str):
signature=signature,
input_example=input_example,
pip_requirements=[
"git+https://github.com/SalesforceAIResearch/uni2ts.git",
"uni2ts",
"git+https://github.com/databricks-industry-solutions/many-model-forecasting.git",
"pyspark==3.5.0",
],
Expand Down Expand Up @@ -171,21 +163,37 @@ def predict_udf(batch_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
import pandas as pd
from einops import rearrange
from uni2ts.model.moirai import MoiraiModule, MoiraiForecast
module = MoiraiModule.from_pretrained(self.repo)
from uni2ts.model.moirai_moe import MoiraiMoEForecast, MoiraiMoEModule
if 'moe' in self.repo:
module = MoiraiMoEModule.from_pretrained(self.repo)
else:
module = MoiraiModule.from_pretrained(self.repo)
# inference
for batch in batch_iterator:
median = []
for series in batch:
model = MoiraiForecast(
module=module,
prediction_length=self.params["prediction_length"],
context_length=len(series),
patch_size=self.params["patch_size"],
num_samples=self.params["num_samples"],
target_dim=1,
feat_dynamic_real_dim=0,
past_feat_dynamic_real_dim=0,
)
if 'moe' in self.repo:
model = MoiraiMoEForecast(
module=module,
prediction_length=self.params["prediction_length"],
context_length=len(series),
patch_size=16,
num_samples=self.params["num_samples"],
target_dim=1,
feat_dynamic_real_dim=0,
past_feat_dynamic_real_dim=0,
)
else:
model = MoiraiForecast(
module=module,
prediction_length=self.params["prediction_length"],
context_length=len(series),
patch_size=self.params["patch_size"],
num_samples=self.params["num_samples"],
target_dim=1,
feat_dynamic_real_dim=0,
past_feat_dynamic_real_dim=0,
)

# Time series values. Shape: (batch, time, variate)
past_target = rearrange(
Expand All @@ -210,46 +218,80 @@ class MoiraiSmall(MoiraiForecaster):
def __init__(self, params):
super().__init__(params)
self.params = params
self.repo = "Salesforce/moirai-1.0-R-small"
self.repo = "Salesforce/moirai-1.1-R-small"


class MoiraiBase(MoiraiForecaster):
def __init__(self, params):
super().__init__(params)
self.params = params
self.repo = "Salesforce/moirai-1.0-R-base"
self.repo = "Salesforce/moirai-1.1-R-base"


class MoiraiLarge(MoiraiForecaster):
def __init__(self, params):
super().__init__(params)
self.params = params
self.repo = "Salesforce/moirai-1.0-R-base"
self.repo = "Salesforce/moirai-1.1-R-large"

class MoiraiMoESmall(MoiraiForecaster):
def __init__(self, params):
super().__init__(params)
self.params = params
self.repo = "Salesforce/moirai-moe-1.0-R-small"

class MoiraiMoEBase(MoiraiForecaster):
def __init__(self, params):
super().__init__(params)
self.params = params
self.repo = "Salesforce/moirai-moe-1.0-R-base"

class MoiraiMoELarge(MoiraiForecaster):
def __init__(self, params):
super().__init__(params)
self.params = params
self.repo = "Salesforce/moirai-moe-1.0-R-large"

class MoiraiModel(mlflow.pyfunc.PythonModel):
def __init__(self, repository, prediction_length, patch_size, num_samples):
from uni2ts.model.moirai import MoiraiForecast, MoiraiModule
from uni2ts.model.moirai_moe import MoiraiMoEForecast, MoiraiMoEModule
self.repository = repository
self.prediction_length = prediction_length
self.patch_size = patch_size
self.num_samples = num_samples
self.module = MoiraiModule.from_pretrained(self.repository)
if 'moe' in self.repository:
self.module = MoiraiMoEModule.from_pretrained(self.repository)
else:
self.module = MoiraiModule.from_pretrained(self.repository)
self.pipeline = None

def predict(self, context, input_data, params=None):
from einops import rearrange
from uni2ts.model.moirai import MoiraiForecast, MoiraiModule
self.pipeline = MoiraiForecast(
module=self.module,
prediction_length=self.prediction_length,
context_length=len(input_data),
patch_size=self.patch_size,
num_samples=self.num_samples,
target_dim=1,
feat_dynamic_real_dim=0,
past_feat_dynamic_real_dim=0,
)
from uni2ts.model.moirai_moe import MoiraiMoEForecast, MoiraiMoEModule
if 'moe' in self.repository:
self.pipeline = MoiraiMoEForecast(
module=self.module,
prediction_length=self.prediction_length,
context_length=len(input_data),
patch_size=self.patch_size,
num_samples=self.num_samples,
target_dim=1,
feat_dynamic_real_dim=0,
past_feat_dynamic_real_dim=0,
)
else:
self.pipeline = MoiraiForecast(
module=self.module,
prediction_length=self.prediction_length,
context_length=len(input_data),
patch_size=16,
num_samples=self.num_samples,
target_dim=1,
feat_dynamic_real_dim=0,
past_feat_dynamic_real_dim=0,
)
# Time series values. Shape: (batch, time, variate)
past_target = rearrange(torch.as_tensor(input_data, dtype=torch.float32), "t -> 1 t 1")
# 1s if the value is observed, 0s otherwise. Shape: (batch, time, variate)
Expand Down

0 comments on commit 6d4b537

Please sign in to comment.