diff --git a/examples/foundation-model-examples/chronos/02_chronos_fine_tune.py b/examples/foundation-model-examples/chronos/02_chronos_fine_tune.py index 79d8fad..d4cbe75 100644 --- a/examples/foundation-model-examples/chronos/02_chronos_fine_tune.py +++ b/examples/foundation-model-examples/chronos/02_chronos_fine_tune.py @@ -22,7 +22,7 @@ catalog = "mmf" # Name of the catalog we use to manage our assets db = "m4" # Name of the schema we use to manage our assets (e.g. datasets) volume = "chronos_fine_tune" # Name of the volume we store the data and the weights -chronos_model = "chronos-t5-tiny" # Chronos model to finetune. Alternatives: -mini, -small, -base, -large +model = "chronos-t5-tiny" # Chronos model to finetune. Alternatives: -mini, -small, -base, -large n = 1000 # Number of time series to sample # COMMAND ---------- @@ -89,7 +89,7 @@ def convert_to_arrow( start_times = list(df["ds"].apply(lambda x: x.min().to_numpy())) # Make sure that the volume exists. We stored the fine-tuned weights here. -_ = spark.sql(f"CREATE VOLUME IF NOT EXISTS {catalog}.{db}.chronos_fine_tune") +_ = spark.sql(f"CREATE VOLUME IF NOT EXISTS {catalog}.{db}.{volume}") # Convert to GluonTS arrow format and save it in UC Volume convert_to_arrow( @@ -105,7 +105,7 @@ def convert_to_arrow( # MAGIC # MAGIC In this example, we will fine-tune `amazon/chronos-t5-tiny` for 1000 steps with an initial learning rate of 1e-3. # MAGIC -# MAGIC Make sure that you have the configuration yaml files placed inside the `configs` folder and the `train.py` script in the same directory. These two assets are taken directly from [chronos-forecasting/scripts/training](https://github.com/amazon-science/chronos-forecasting/tree/main/scripts/training). They are subject to change as the Chronos' team develops the framework further. Keep your eyes on the latest changes (we will try too) and use the latest versions if needed. We have made a small change to our `train.py` script and set the frequency of the time series to daily ("D"). +# MAGIC Make sure that you have the configuration yaml files placed inside the `configs` folder and the `train.py` script in the same directory. These two assets are taken directly from [chronos-forecasting/scripts/training](https://github.com/amazon-science/chronos-forecasting/tree/main/scripts/training). They are subject to change as the Chronos team develops the framework further. Keep your eyes on the latest changes (we will try too) and use the latest versions as needed. We have made a small change to our `train.py` script and set the frequency of the time series to daily ("D"). # MAGIC # MAGIC Inside the configuration yaml (for this example, `configs/chronos-t5-tiny.yaml`), make sure to set the parameters: # MAGIC - `training_data_paths` to `/Volumes/mmf/m4/chronos_fine_tune/data.arrow`, where your arrow converted file is stored @@ -168,7 +168,7 @@ def predict(self, context, input_data, params=None): files = os.listdir(f"/Volumes/{catalog}/{db}/{volume}/") runs = [int(file[4:]) for file in files if "run-" in file] latest_run = max(runs) -registered_model_name=f"{catalog}.{db}.{chronos_model}_finetuned" +registered_model_name=f"{catalog}.{db}.{model}_finetuned" weights = f"/Volumes/{catalog}/{db}/{volume}/run-{latest_run}/checkpoint-final/" # Get the model signature for registry @@ -195,7 +195,7 @@ def predict(self, context, input_data, params=None): # MAGIC %md # MAGIC ##Reload Model -# MAGIC We reload the model from the registry and perform forecasting on the in-training time series (for testing purpose).
You can also go ahead and deploy this model behind a Model Serving's real-time endpoint. See the previous notebook: `01_chronos_load_inference` for more information. +# MAGIC We reload the model from the registry and perform forecasting on the in-training time series (for testing purposes). You can also go ahead and deploy this model behind a Model Serving real-time endpoint. See the previous notebook: [`01_chronos_load_inference`](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/foundation-model-examples/chronos/01_chronos_load_inference.py) for more information. # COMMAND ---------- diff --git a/examples/foundation-model-examples/moirai/01_moirai_load_inference.py b/examples/foundation-model-examples/moirai/01_moirai_load_inference.py index 01a7e78..51b2967 100644 --- a/examples/foundation-model-examples/moirai/01_moirai_load_inference.py +++ b/examples/foundation-model-examples/moirai/01_moirai_load_inference.py @@ -117,7 +117,7 @@ def forecast_udf(bulk_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: # COMMAND ---------- -moirai_model = "moirai-1.0-R-small" # Alternatibely moirai-1.0-R-base, moirai-1.0-R-large +model = "moirai-1.0-R-small" # Alternatively moirai-1.0-R-base, moirai-1.0-R-large prediction_length = 10 # Time horizon for forecasting num_samples = 10 # Number of forecasts to generate. We will take the median as our final forecast. patch_size = 32 # Patch size: choose from {"auto", 8, 16, 32, 64, 128} @@ -129,7 +129,7 @@ def forecast_udf(bulk_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: get_horizon_timestamps = create_get_horizon_timestamps(freq=freq, prediction_length=prediction_length) forecast_udf = create_forecast_udf( - repository=f"Salesforce/{moirai_model}", + repository=f"Salesforce/{model}", prediction_length=prediction_length, patch_size=patch_size, num_samples=num_samples, @@ -192,7 +192,7 @@ def predict(self, context, input_data, params=None): ) return np.median(forecast[0], axis=0) -pipeline = MoiraiModel(f"Salesforce/{moirai_model}") +pipeline = MoiraiModel(f"Salesforce/{model}") input_schema = Schema([TensorSpec(np.dtype(np.double), (-1,))]) output_schema = Schema([TensorSpec(np.dtype(np.uint8), (-1,))]) signature = ModelSignature(inputs=input_schema, outputs=output_schema) diff --git a/examples/foundation-model-examples/moirai/02_moirai_fine_tune.py b/examples/foundation-model-examples/moirai/02_moirai_fine_tune.py new file mode 100644 index 0000000..9e106b5 --- /dev/null +++ b/examples/foundation-model-examples/moirai/02_moirai_fine_tune.py @@ -0,0 +1,224 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC This is an example notebook that shows how to use [Moirai](https://github.com/SalesforceAIResearch/uni2ts) models on Databricks. +# MAGIC +# MAGIC The notebook loads, fine-tunes, and registers the model. + +# COMMAND ---------- + +# MAGIC %pip install git+https://github.com/SalesforceAIResearch/uni2ts.git --quiet +# MAGIC dbutils.library.restartPython() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Prepare Data +# MAGIC Make sure that the catalog and the schema already exist. + +# COMMAND ---------- + +catalog = "mmf" # Name of the catalog we use to manage our assets +db = "random" # Name of the schema we use to manage our assets (e.g.
datasets) +volume = "moirai_fine_tune" # Name of the volume we store the data and the weights +model = "moirai-1.0-R-small" # Alternatively: moirai-1.0-R-base, moirai-1.0-R-large +n = 100 # Number of time series to sample + +# COMMAND ---------- + +# Make sure that the database exists. +_ = spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog}.{db}") + +# Make sure that the volume exists. We stored the fine-tuned weights here. +_ = spark.sql(f"CREATE VOLUME IF NOT EXISTS {catalog}.{db}.{volume}") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC We synthesize `n` time series (randomly sampled) of daily resolution and store them as a csv file in UC Volume. + +# COMMAND ---------- + +import pandas as pd +import numpy as np + +df_dict = {} + +for i in range(n): + + # Create a date range for the index + date_range = pd.date_range(start='2021-01-01', end='2023-12-31', freq='D') + + # Create a DataFrame with a date range index and two columns: 'item_id' and 'target' + df = pd.DataFrame({ + 'item_id': str(f"item_{i}"), + 'target': np.random.randn(len(date_range)) + }, index=date_range) + + # Set 'item_id' as the second level of the MultiIndex + df.set_index('item_id', append=True, inplace=True) + + # Sort the index + df.sort_index(inplace=True) + + df_dict[i] = df + + +pdf = pd.concat([df_dict[i] for i in range(n)]) +pdf.to_csv(f"/Volumes/{catalog}/{db}/{volume}/random.csv", index=True) +pdf + +# COMMAND ---------- + +# MAGIC %md +# MAGIC This dotenv file is needed to use the `uni2ts.data.builder.simple` function from the `uni2ts` library to build a dataset. + +# COMMAND ---------- + +import os +import site + +uni2ts = os.path.join(site.getsitepackages()[0], "uni2ts") +dotenv = os.path.join(uni2ts, ".env") +os.environ['DOTENV'] = dotenv +os.environ['CUSTOM_DATA_PATH'] = f"/Volumes/{catalog}/{db}/{volume}" + +# COMMAND ---------- + +# MAGIC %sh +# MAGIC rm -f $DOTENV +# MAGIC touch $DOTENV +# MAGIC echo "CUSTOM_DATA_PATH=$CUSTOM_DATA_PATH" >> $DOTENV + +# COMMAND ---------- + +# MAGIC %md +# MAGIC We convert the dataset into the Uni2TS format. `random` is the name of the dataset to use, which we load from our volume's location. See the [README](https://github.com/SalesforceAIResearch/uni2ts/tree/main?tab=readme-ov-file#fine-tuning) of Uni2TS for more information on the parameters. + +# COMMAND ---------- + +# MAGIC %sh python -m uni2ts.data.builder.simple random /Volumes/mmf/random/moirai_fine_tune/random.csv \ +# MAGIC --dataset_type long \ +# MAGIC --offset 640 + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ##Run Fine-tuning +# MAGIC +# MAGIC In this example, we will fine-tune `moirai-1.0-R-small` for a maximum of 100 epochs with early stopping (can be specified here: `examples/foundation-model-examples/moirai/conf/finetune/default.yaml`). The learning rate is set to 1e-3, which you can modify in `examples/foundation-model-examples/moirai/conf/finetune/default.yaml`. +# MAGIC +# MAGIC Make sure that you have the configuration yaml files placed inside the `conf` folder and the `train.py` script in the same directory. These two assets are taken directly from [cli/conf](https://github.com/SalesforceAIResearch/uni2ts/tree/main/cli/conf) and [cli/train.py](https://github.com/SalesforceAIResearch/uni2ts/blob/main/cli/train.py). They are subject to change as the Moirai team develops the framework further. Keep your eyes on the latest changes (we will try too) and use the latest versions as needed.
+# MAGIC +# MAGIC The key configuration files to be customized for your use case are `examples/foundation-model-examples/moirai/conf/finetune/default.yaml`, `examples/foundation-model-examples/moirai/conf/finetune/data/random.yaml` and `examples/foundation-model-examples/moirai/conf/finetune/val_data/random.yaml`. Refer to the Moirai [documentation](https://github.com/SalesforceAIResearch/uni2ts) for more details. + +# COMMAND ---------- + +# MAGIC %sh python train.py \ +# MAGIC -cp conf/finetune \ +# MAGIC run_name=random_run \ +# MAGIC model=moirai_1.0_R_small \ +# MAGIC data=random \ +# MAGIC val_data=random + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ##Register Model +# MAGIC We get the fine-tuned weights of the run from the UC Volume, wrap the pipeline in `mlflow.pyfunc.PythonModel`, and register it in Unity Catalog. + +# COMMAND ---------- + +import mlflow +import torch +import numpy as np +from mlflow.models.signature import ModelSignature +from mlflow.types import DataType, Schema, TensorSpec +mlflow.set_registry_uri("databricks-uc") + + +class FineTunedMoiraiModel(mlflow.pyfunc.PythonModel): + def predict(self, context, input_data, params=None): + from einops import rearrange + from uni2ts.model.moirai import MoiraiForecast, MoiraiModule + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = MoiraiForecast.load_from_checkpoint( + prediction_length=10, + context_length=len(input_data), + patch_size=32, + num_samples=10, + target_dim=1, + feat_dynamic_real_dim=0, + past_feat_dynamic_real_dim=0, + checkpoint_path=context.artifacts["weights"], + ).to(device) + + # Time series values. Shape: (batch, time, variate) + past_target = rearrange( + torch.as_tensor(input_data, dtype=torch.float32), "t -> 1 t 1" + ) + # 1s if the value is observed, 0s otherwise. Shape: (batch, time, variate) + past_observed_target = torch.ones_like(past_target, dtype=torch.bool) + # 1s if the value is padding, 0s otherwise. Shape: (batch, time) + past_is_pad = torch.zeros_like(past_target, dtype=torch.bool).squeeze(-1) + forecast = model( + past_target=past_target.to(device), + past_observed_target=past_observed_target.to(device), + past_is_pad=past_is_pad.to(device), + ) + return np.median(forecast.cpu()[0], axis=0) + +input_schema = Schema([TensorSpec(np.dtype(np.double), (-1,))]) +output_schema = Schema([TensorSpec(np.dtype(np.uint8), (-1,))]) +signature = ModelSignature(inputs=input_schema, outputs=output_schema) +input_example = np.random.rand(52) +registered_model_name=f"{catalog}.{db}.moirai-1-r-small_finetuned" +weights = f"/Volumes/{catalog}/{db}/{volume}/outputs/moirai_1.0_R_small/random/random_run/checkpoints/epoch=0-step=100.ckpt" + + +with mlflow.start_run() as run: + mlflow.pyfunc.log_model( + "model", + python_model=FineTunedMoiraiModel(), + registered_model_name=registered_model_name, + artifacts={"weights": weights}, + signature=signature, + input_example=input_example, + pip_requirements=[ + "git+https://github.com/SalesforceAIResearch/uni2ts.git", + ], + ) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ##Reload Model +# MAGIC We reload the model from the registry and perform forecasting on a randomly generated time series (for testing purposes). You can also go ahead and deploy this model behind a Model Serving real-time endpoint.
See the previous notebook: [`01_moirai_load_inference`](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/examples/foundation-model-examples/moirai/01_moirai_load_inference.py) for more information. + +# COMMAND ---------- + +from mlflow import MlflowClient +client = MlflowClient() + +def get_latest_model_version(client, registered_model_name): + latest_version = 1 + for mv in client.search_model_versions(f"name='{registered_model_name}'"): + version_int = int(mv.version) + if version_int > latest_version: + latest_version = version_int + return latest_version + +model_version = get_latest_model_version(client, registered_model_name) +logged_model = f"models:/{registered_model_name}/{model_version}" + +# Load model as a PyFuncModel +loaded_model = mlflow.pyfunc.load_model(logged_model) + +# Create input data +input_data = np.random.rand(52) + +# Generate forecasts +loaded_model.predict(input_data) + +# COMMAND ---------- + + diff --git a/examples/foundation-model-examples/moirai/conf/finetune/data/etth1.yaml b/examples/foundation-model-examples/moirai/conf/finetune/data/etth1.yaml new file mode 100644 index 0000000..a5de611 --- /dev/null +++ b/examples/foundation-model-examples/moirai/conf/finetune/data/etth1.yaml @@ -0,0 +1,3 @@ +_target_: uni2ts.data.builder.simple.SimpleDatasetBuilder +dataset: ETTh1 +weight: 1000 \ No newline at end of file diff --git a/examples/foundation-model-examples/moirai/conf/finetune/data/random.yaml b/examples/foundation-model-examples/moirai/conf/finetune/data/random.yaml new file mode 100644 index 0000000..ab8153e --- /dev/null +++ b/examples/foundation-model-examples/moirai/conf/finetune/data/random.yaml @@ -0,0 +1,3 @@ +_target_: uni2ts.data.builder.simple.SimpleDatasetBuilder +dataset: random +weight: 1000 \ No newline at end of file diff --git a/examples/foundation-model-examples/moirai/conf/finetune/default.yaml b/examples/foundation-model-examples/moirai/conf/finetune/default.yaml new file mode 100644 index 0000000..d8b1abd --- /dev/null +++ b/examples/foundation-model-examples/moirai/conf/finetune/default.yaml @@ -0,0 +1,83 @@ +hydra: + run: + dir: /Volumes/mmf/random/moirai_fine_tune/outputs/${hydra:runtime.choices.model}/${hydra:runtime.choices.data}/${run_name} +defaults: + - model: ??? + - data: ??? + - val_data: null + - _self_ +run_name: ???
+seed: 0 +tf32: true +compile: false # set to mode: default, reduce-overhead, max-autotune +trainer: + _target_: lightning.Trainer + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 32 + logger: + _target_: lightning.pytorch.loggers.TensorBoardLogger + save_dir: ${hydra:runtime.output_dir} + name: logs + callbacks: + - _target_: lightning.pytorch.callbacks.LearningRateMonitor + logging_interval: epoch + - _target_: lightning.pytorch.callbacks.ModelCheckpoint + dirpath: ${hydra:runtime.output_dir}/checkpoints + monitor: val/PackedNLLLoss + save_weights_only: true + mode: min + save_top_k: 1 + every_n_epochs: 1 + - _target_: lightning.pytorch.callbacks.EarlyStopping + monitor: val/PackedNLLLoss + min_delta: 0.0 + patience: 3 + mode: min + strict: false + verbose: true + max_epochs: 100 + enable_progress_bar: true + accumulate_grad_batches: 1 + gradient_clip_val: 1.0 + gradient_clip_algorithm: norm +train_dataloader: + _target_: uni2ts.data.loader.DataLoader + batch_size: 128 + batch_size_factor: 2.0 + cycle: true + num_batches_per_epoch: 100 + shuffle: true + num_workers: 11 + collate_fn: + _target_: uni2ts.data.loader.PackCollate + max_length: ${model.module_kwargs.max_seq_len} + seq_fields: ${cls_getattr:${model._target_},seq_fields} + pad_func_map: ${cls_getattr:${model._target_},pad_func_map} + pin_memory: true + drop_last: false + fill_last: false + worker_init_fn: null + prefetch_factor: 2 + persistent_workers: true +val_dataloader: + _target_: uni2ts.data.loader.DataLoader + batch_size: 128 + batch_size_factor: 2.0 + cycle: false + num_batches_per_epoch: null + shuffle: false + num_workers: 11 + collate_fn: + _target_: uni2ts.data.loader.PackCollate + max_length: ${model.module_kwargs.max_seq_len} + seq_fields: ${cls_getattr:${model._target_},seq_fields} + pad_func_map: ${cls_getattr:${model._target_},pad_func_map} + pin_memory: false + drop_last: false + fill_last: true + worker_init_fn: null + prefetch_factor: 2 + persistent_workers: true \ No newline at end of file diff --git a/examples/foundation-model-examples/moirai/conf/finetune/model/moirai_1.0_R_base.yaml b/examples/foundation-model-examples/moirai/conf/finetune/model/moirai_1.0_R_base.yaml new file mode 100644 index 0000000..96e5c4e --- /dev/null +++ b/examples/foundation-model-examples/moirai/conf/finetune/model/moirai_1.0_R_base.yaml @@ -0,0 +1,33 @@ +# load a pretrained checkpoint from huggingface hub +_target_: uni2ts.model.moirai.MoiraiFinetune +module: + _target_: uni2ts.model.moirai.MoiraiModule.from_pretrained + pretrained_model_name_or_path: Salesforce/moirai-1.0-R-base +module_kwargs: + _target_: builtins.dict + distr_output: + _target_: uni2ts.distribution.MixtureOutput + components: + - _target_: uni2ts.distribution.StudentTOutput + - _target_: uni2ts.distribution.NormalFixedScaleOutput + - _target_: uni2ts.distribution.NegativeBinomialOutput + - _target_: uni2ts.distribution.LogNormalOutput + d_model: 768 + num_layers: 12 + patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} + max_seq_len: 512 + attn_dropout_p: 0.0 + dropout_p: 0.0 + scaling: true +min_patches: 2 +min_mask_ratio: 0.15 +max_mask_ratio: 0.5 +max_dim: 128 +loss_func: + _target_: uni2ts.loss.packed.PackedNLLLoss +lr: 1e-3 +weight_decay: 1e-1 +beta1: 0.9 +beta2: 0.98 +num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} +num_warmup_steps: 0 \ No newline at end of file diff --git a/examples/foundation-model-examples/moirai/conf/finetune/model/moirai_1.0_R_large.yaml 
b/examples/foundation-model-examples/moirai/conf/finetune/model/moirai_1.0_R_large.yaml new file mode 100644 index 0000000..991ba8d --- /dev/null +++ b/examples/foundation-model-examples/moirai/conf/finetune/model/moirai_1.0_R_large.yaml @@ -0,0 +1,33 @@ +# load a pretrained checkpoint from huggingface hub +_target_: uni2ts.model.moirai.MoiraiFinetune +module: + _target_: uni2ts.model.moirai.MoiraiModule.from_pretrained + pretrained_model_name_or_path: Salesforce/moirai-1.0-R-large +module_kwargs: + _target_: builtins.dict + distr_output: + _target_: uni2ts.distribution.MixtureOutput + components: + - _target_: uni2ts.distribution.StudentTOutput + - _target_: uni2ts.distribution.NormalFixedScaleOutput + - _target_: uni2ts.distribution.NegativeBinomialOutput + - _target_: uni2ts.distribution.LogNormalOutput + d_model: 1024 + num_layers: 24 + patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} + max_seq_len: 512 + attn_dropout_p: 0.0 + dropout_p: 0.0 + scaling: true +min_patches: 2 +min_mask_ratio: 0.15 +max_mask_ratio: 0.5 +max_dim: 128 +loss_func: + _target_: uni2ts.loss.packed.PackedNLLLoss +lr: 1e-3 +weight_decay: 1e-1 +beta1: 0.9 +beta2: 0.98 +num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} +num_warmup_steps: 0 \ No newline at end of file diff --git a/examples/foundation-model-examples/moirai/conf/finetune/model/moirai_1.0_R_small.yaml b/examples/foundation-model-examples/moirai/conf/finetune/model/moirai_1.0_R_small.yaml new file mode 100644 index 0000000..9f799d7 --- /dev/null +++ b/examples/foundation-model-examples/moirai/conf/finetune/model/moirai_1.0_R_small.yaml @@ -0,0 +1,37 @@ +# load a pretrained checkpoint from huggingface hub +_target_: uni2ts.model.moirai.MoiraiFinetune +module: + _target_: uni2ts.model.moirai.MoiraiModule.from_pretrained + pretrained_model_name_or_path: Salesforce/moirai-1.0-R-small +module_kwargs: + _target_: builtins.dict + distr_output: + _target_: uni2ts.distribution.MixtureOutput + components: + - _target_: uni2ts.distribution.StudentTOutput + - _target_: uni2ts.distribution.NormalFixedScaleOutput + - _target_: uni2ts.distribution.NegativeBinomialOutput + - _target_: uni2ts.distribution.LogNormalOutput + d_model: 384 + num_layers: 6 + patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} + max_seq_len: 512 + attn_dropout_p: 0.0 + dropout_p: 0.0 + scaling: true +min_patches: 2 +min_mask_ratio: 0.15 +max_mask_ratio: 0.5 +max_dim: 128 +loss_func: + _target_: uni2ts.loss.packed.PackedNLLLoss +val_metric: + - _target_: uni2ts.loss.packed.PackedMSELoss + - _target_: uni2ts.loss.packed.PackedNRMSELoss + normalize: absolute_target_squared +lr: 1e-3 +weight_decay: 1e-1 +beta1: 0.9 +beta2: 0.98 +num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} +num_warmup_steps: 0 \ No newline at end of file diff --git a/examples/foundation-model-examples/moirai/conf/finetune/model/moirai_base.yaml b/examples/foundation-model-examples/moirai/conf/finetune/model/moirai_base.yaml new file mode 100644 index 0000000..8962bb3 --- /dev/null +++ b/examples/foundation-model-examples/moirai/conf/finetune/model/moirai_base.yaml @@ -0,0 +1,31 @@ +# load a pytorch lightning checkpoint +_target_: uni2ts.model.moirai.MoiraiFinetune.load_from_checkpoint +module_kwargs: + _target_: builtins.dict + distr_output: + _target_: uni2ts.distribution.MixtureOutput + components: + - _target_: uni2ts.distribution.StudentTOutput + - _target_: uni2ts.distribution.NormalFixedScaleOutput + - _target_: 
uni2ts.distribution.NegativeBinomialOutput + - _target_: uni2ts.distribution.LogNormalOutput + d_model: 768 + num_layers: 12 + patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} + max_seq_len: 512 + attn_dropout_p: 0.0 + dropout_p: 0.0 + scaling: true +min_patches: 2 +min_mask_ratio: 0.15 +max_mask_ratio: 0.5 +max_dim: 128 +loss_func: + _target_: uni2ts.loss.packed.PackedNLLLoss +lr: 1e-3 +weight_decay: 1e-1 +beta1: 0.9 +beta2: 0.98 +num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} +num_warmup_steps: 0 +checkpoint_path: ... \ No newline at end of file diff --git a/examples/foundation-model-examples/moirai/conf/finetune/model/moirai_large.yaml b/examples/foundation-model-examples/moirai/conf/finetune/model/moirai_large.yaml new file mode 100644 index 0000000..d52f67f --- /dev/null +++ b/examples/foundation-model-examples/moirai/conf/finetune/model/moirai_large.yaml @@ -0,0 +1,31 @@ +# load a pytorch lightning checkpoint +_target_: uni2ts.model.moirai.MoiraiFinetune.load_from_checkpoint +module_kwargs: + _target_: builtins.dict + distr_output: + _target_: uni2ts.distribution.MixtureOutput + components: + - _target_: uni2ts.distribution.StudentTOutput + - _target_: uni2ts.distribution.NormalFixedScaleOutput + - _target_: uni2ts.distribution.NegativeBinomialOutput + - _target_: uni2ts.distribution.LogNormalOutput + d_model: 1024 + num_layers: 24 + patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} + max_seq_len: 512 + attn_dropout_p: 0.0 + dropout_p: 0.0 + scaling: true +min_patches: 2 +min_mask_ratio: 0.15 +max_mask_ratio: 0.5 +max_dim: 128 +loss_func: + _target_: uni2ts.loss.packed.PackedNLLLoss +lr: 1e-3 +weight_decay: 1e-1 +beta1: 0.9 +beta2: 0.98 +num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} +num_warmup_steps: 0 +checkpoint_path: ... \ No newline at end of file diff --git a/examples/foundation-model-examples/moirai/conf/finetune/model/moirai_small.yaml b/examples/foundation-model-examples/moirai/conf/finetune/model/moirai_small.yaml new file mode 100644 index 0000000..741e08f --- /dev/null +++ b/examples/foundation-model-examples/moirai/conf/finetune/model/moirai_small.yaml @@ -0,0 +1,35 @@ +# load a pytorch lightning checkpoint +_target_: uni2ts.model.moirai.MoiraiFinetune.load_from_checkpoint +module_kwargs: + _target_: builtins.dict + distr_output: + _target_: uni2ts.distribution.MixtureOutput + components: + - _target_: uni2ts.distribution.StudentTOutput + - _target_: uni2ts.distribution.NormalFixedScaleOutput + - _target_: uni2ts.distribution.NegativeBinomialOutput + - _target_: uni2ts.distribution.LogNormalOutput + d_model: 384 + num_layers: 6 + patch_sizes: ${as_tuple:[8, 16, 32, 64, 128]} + max_seq_len: 512 + attn_dropout_p: 0.0 + dropout_p: 0.0 + scaling: true +min_patches: 2 +min_mask_ratio: 0.15 +max_mask_ratio: 0.5 +max_dim: 128 +loss_func: + _target_: uni2ts.loss.packed.PackedNLLLoss +val_metric: + - _target_: uni2ts.loss.packed.PackedMSELoss + - _target_: uni2ts.loss.packed.PackedNRMSELoss + normalize: absolute_target_squared +lr: 1e-3 +weight_decay: 1e-1 +beta1: 0.9 +beta2: 0.98 +num_training_steps: ${mul:${trainer.max_epochs},${train_dataloader.num_batches_per_epoch}} +num_warmup_steps: 0 +checkpoint_path: ... 
\ No newline at end of file diff --git a/examples/foundation-model-examples/moirai/conf/finetune/val_data/etth1.yaml b/examples/foundation-model-examples/moirai/conf/finetune/val_data/etth1.yaml new file mode 100644 index 0000000..00c462a --- /dev/null +++ b/examples/foundation-model-examples/moirai/conf/finetune/val_data/etth1.yaml @@ -0,0 +1,9 @@ +_target_: uni2ts.data.builder.ConcatDatasetBuilder +_args_: + _target_: uni2ts.data.builder.simple.generate_eval_builders + dataset: ETTh1_eval + offset: 11520 + eval_length: 2880 + prediction_lengths: [96, 192, 336, 720] + context_lengths: [1000, 2000, 3000, 4000, 5000] + patch_sizes: [32, 64] \ No newline at end of file diff --git a/examples/foundation-model-examples/moirai/conf/finetune/val_data/etth1_multi.yaml b/examples/foundation-model-examples/moirai/conf/finetune/val_data/etth1_multi.yaml new file mode 100644 index 0000000..56e19ae --- /dev/null +++ b/examples/foundation-model-examples/moirai/conf/finetune/val_data/etth1_multi.yaml @@ -0,0 +1,16 @@ +- _target_: uni2ts.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTh1_eval + offset: 11520 + windows: 10 + distance: 96 + prediction_length: 96 + context_length: 1000 + patch_size: 32 +- _target_: uni2ts.data.builder.simple.SimpleEvalDatasetBuilder + dataset: ETTh1_eval + offset: 11520 + windows: 10 + distance: 192 + prediction_length: 192 + context_length: 1000 + patch_size: 32 \ No newline at end of file diff --git a/examples/foundation-model-examples/moirai/conf/finetune/val_data/random.yaml b/examples/foundation-model-examples/moirai/conf/finetune/val_data/random.yaml new file mode 100644 index 0000000..261af17 --- /dev/null +++ b/examples/foundation-model-examples/moirai/conf/finetune/val_data/random.yaml @@ -0,0 +1,9 @@ +_target_: uni2ts.data.builder.ConcatDatasetBuilder +_args_: + _target_: uni2ts.data.builder.simple.generate_eval_builders + dataset: random_eval + offset: 273 + eval_length: 10 + prediction_lengths: [10] + context_lengths: [270] + patch_sizes: [32, 64] \ No newline at end of file diff --git a/examples/foundation-model-examples/moirai/train.py b/examples/foundation-model-examples/moirai/train.py new file mode 100644 index 0000000..ee2aa3b --- /dev/null +++ b/examples/foundation-model-examples/moirai/train.py @@ -0,0 +1,149 @@ +# Copyright (c) 2024, Salesforce, Inc. +# SPDX-License-Identifier: Apache-2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import partial +from typing import Callable, Optional + +import hydra +import lightning as L +import torch +from hydra.utils import instantiate +from omegaconf import DictConfig +from torch.utils._pytree import tree_map +from torch.utils.data import Dataset, DistributedSampler + +from uni2ts.common import hydra_util # noqa: hydra resolvers +from uni2ts.data.loader import DataLoader + + +class DataModule(L.LightningDataModule): + def __init__( + self, + cfg: DictConfig, + train_dataset: Dataset, + val_dataset: Optional[Dataset | list[Dataset]], + ): + super().__init__() + self.cfg = cfg + self.train_dataset = train_dataset + + if val_dataset is not None: + self.val_dataset = val_dataset + self.val_dataloader = self._val_dataloader + + @staticmethod + def get_dataloader( + dataset: Dataset, + dataloader_func: Callable[..., DataLoader], + shuffle: bool, + world_size: int, + batch_size: int, + num_batches_per_epoch: Optional[int] = None, + ) -> DataLoader: + sampler = ( + DistributedSampler( + dataset, + num_replicas=None, + rank=None, + shuffle=shuffle, + seed=0, + drop_last=False, + ) + if world_size > 1 + else None + ) + return dataloader_func( + dataset=dataset, + shuffle=shuffle if sampler is None else None, + sampler=sampler, + batch_size=batch_size, + num_batches_per_epoch=num_batches_per_epoch, + ) + + def train_dataloader(self) -> DataLoader: + return self.get_dataloader( + self.train_dataset, + instantiate(self.cfg.train_dataloader, _partial_=True), + self.cfg.train_dataloader.shuffle, + self.trainer.world_size, + self.train_batch_size, + num_batches_per_epoch=self.train_num_batches_per_epoch, + ) + + def _val_dataloader(self) -> DataLoader | list[DataLoader]: + return tree_map( + partial( + self.get_dataloader, + dataloader_func=instantiate(self.cfg.val_dataloader, _partial_=True), + shuffle=self.cfg.val_dataloader.shuffle, + world_size=self.trainer.world_size, + batch_size=self.val_batch_size, + num_batches_per_epoch=None, + ), + self.val_dataset, + ) + + @property + def train_batch_size(self) -> int: + return self.cfg.train_dataloader.batch_size // ( + self.trainer.world_size * self.trainer.accumulate_grad_batches + ) + + @property + def val_batch_size(self) -> int: + return self.cfg.val_dataloader.batch_size // ( + self.trainer.world_size * self.trainer.accumulate_grad_batches + ) + + @property + def train_num_batches_per_epoch(self) -> int: + return ( + self.cfg.train_dataloader.num_batches_per_epoch + * self.trainer.accumulate_grad_batches + ) + + +@hydra.main(version_base="1.3", config_name="default.yaml") +def main(cfg: DictConfig): + if cfg.tf32: + assert cfg.trainer.precision == 32 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + model: L.LightningModule = instantiate(cfg.model, _convert_="all") + + if cfg.compile: + model.module.compile(mode=cfg.compile) + trainer: L.Trainer = instantiate(cfg.trainer) + train_dataset: Dataset = instantiate(cfg.data).load_dataset( + model.train_transform_map + ) + val_dataset: Optional[Dataset | list[Dataset]] = ( + tree_map( + lambda ds: ds.load_dataset(model.val_transform_map), + instantiate(cfg.val_data, _convert_="all"), + ) + if "val_data" in cfg + else None + ) + L.seed_everything(cfg.seed + trainer.logger.version, workers=True) + trainer.fit( + model, + datamodule=DataModule(cfg, train_dataset, val_dataset), + ) + + +if __name__ == "__main__": + main()
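Both fine-tuning notebooks finish by reloading the registered model locally and note that it can also be deployed behind a Databricks Model Serving real-time endpoint, deferring the details to the earlier load/inference notebooks. As a rough, hedged sketch of that final step (not part of this PR), the snippet below shows one common way to query such an endpoint over REST once it has been created; the workspace URL, endpoint name, token handling, and exact payload shape are illustrative assumptions and may need adjusting to your endpoint's signature.

import json
import numpy as np
import requests

# Hypothetical values -- replace with your workspace URL, endpoint name, and a real token (e.g. from a secret scope).
workspace_url = "https://<your-workspace>.cloud.databricks.com"
endpoint_name = "moirai-finetuned"  # assumed name of a serving endpoint hosting the registered model
token = "<databricks-personal-access-token>"

# Same kind of input as the notebooks' local test: a 1-D context of 52 values.
payload = {"inputs": np.random.rand(52).tolist()}

response = requests.post(
    f"{workspace_url}/serving-endpoints/{endpoint_name}/invocations",
    headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"},
    data=json.dumps(payload),
)
# Expect a JSON body containing the median forecast, mirroring loaded_model.predict(input_data) above.
print(response.json())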