Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: TimeXer #1267

Merged
merged 10 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 88 additions & 14 deletions nbs/common.modules.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"source": [
"#| export\n",
"import math\n",
"import numpy as np\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
Expand Down Expand Up @@ -409,40 +410,93 @@
"source": [
"#| export\n",
"class AttentionLayer(nn.Module):\n",
" def __init__(self, attention, hidden_size, n_head, d_keys=None,\n",
" def __init__(self, attention, hidden_size, n_heads, d_keys=None,\n",
" d_values=None):\n",
" super(AttentionLayer, self).__init__()\n",
"\n",
" d_keys = d_keys or (hidden_size // n_head)\n",
" d_values = d_values or (hidden_size // n_head)\n",
" d_keys = d_keys or (hidden_size // n_heads)\n",
" d_values = d_values or (hidden_size // n_heads)\n",
"\n",
" self.inner_attention = attention\n",
" self.query_projection = nn.Linear(hidden_size, d_keys * n_head)\n",
" self.key_projection = nn.Linear(hidden_size, d_keys * n_head)\n",
" self.value_projection = nn.Linear(hidden_size, d_values * n_head)\n",
" self.out_projection = nn.Linear(d_values * n_head, hidden_size)\n",
" self.n_head = n_head\n",
" self.query_projection = nn.Linear(hidden_size, d_keys * n_heads)\n",
" self.key_projection = nn.Linear(hidden_size, d_keys * n_heads)\n",
" self.value_projection = nn.Linear(hidden_size, d_values * n_heads)\n",
" self.out_projection = nn.Linear(d_values * n_heads, hidden_size)\n",
" self.n_heads = n_heads\n",
"\n",
" def forward(self, queries, keys, values, attn_mask):\n",
" def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):\n",
" B, L, _ = queries.shape\n",
" _, S, _ = keys.shape\n",
" H = self.n_head\n",
" H = self.n_heads\n",
"\n",
" queries = self.query_projection(queries).view(B, L, H, -1)\n",
" keys = self.key_projection(keys).view(B, S, H, -1)\n",
" values = self.value_projection(values).view(B, S, H, -1)\n",
"\n",
" out, attn = self.inner_attention(\n",
" queries,\n",
" keys,\n",
" values,\n",
" attn_mask\n",
" queries=queries,\n",
" keys=keys,\n",
" values=values,\n",
" attn_mask=attn_mask,\n",
" tau=tau,\n",
" delta=delta\n",
" )\n",
" out = out.view(B, L, -1)\n",
"\n",
" return self.out_projection(out), attn"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"\n",
"class TriangularCausalMask():\n",
" \"\"\"\n",
" TriangularCausalMask\n",
" \"\"\" \n",
" def __init__(self, B, L, device=\"cpu\"):\n",
" mask_shape = [B, 1, L, L]\n",
" with torch.no_grad():\n",
" self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)\n",
"\n",
" @property\n",
" def mask(self):\n",
" return self._mask\n",
"\n",
"class FullAttention(nn.Module):\n",
" def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):\n",
" super(FullAttention, self).__init__()\n",
" self.scale = scale\n",
" self.mask_flag = mask_flag\n",
" self.output_attention = output_attention\n",
" self.dropout = nn.Dropout(attention_dropout)\n",
"\n",
" def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):\n",
" B, L, H, E = queries.shape\n",
" _, S, _, D = values.shape\n",
" scale = self.scale or 1. / math.sqrt(E)\n",
"\n",
" scores = torch.einsum(\"blhe,bshe->bhls\", queries, keys)\n",
"\n",
" if self.mask_flag:\n",
" if attn_mask is None:\n",
" attn_mask = TriangularCausalMask(B, L, device=queries.device)\n",
"\n",
" scores.masked_fill_(attn_mask.mask, -np.inf)\n",
"\n",
" A = self.dropout(torch.softmax(scale * scores, dim=-1))\n",
" V = torch.einsum(\"bhls,bshd->blhd\", A, values)\n",
"\n",
" if self.output_attention:\n",
" return V.contiguous(), A\n",
" else:\n",
" return V.contiguous(), None "
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -570,6 +624,26 @@
" if self.temporal_embedding is not None:\n",
" x = x + self.temporal_embedding(x_mark) \n",
"\n",
" return self.dropout(x)\n",
"\n",
"class DataEmbedding_inverted(nn.Module):\n",
" \"\"\"\n",
" DataEmbedding_inverted\n",
" \"\"\" \n",
" def __init__(self, c_in, hidden_size, dropout=0.1):\n",
" super(DataEmbedding_inverted, self).__init__()\n",
" self.value_embedding = nn.Linear(c_in, hidden_size)\n",
" self.dropout = nn.Dropout(p=dropout)\n",
"\n",
" def forward(self, x, x_mark):\n",
" x = x.permute(0, 2, 1)\n",
" # x: [Batch Variate Time]\n",
" if x_mark is None:\n",
" x = self.value_embedding(x)\n",
" else:\n",
" # the potential to take covariates (e.g. timestamps) as tokens\n",
" x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) \n",
" # x: [Batch Variate hidden_size]\n",
" return self.dropout(x)"
]
},
Expand Down
5 changes: 3 additions & 2 deletions nbs/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
" StemGNN, PatchTST, TimesNet, TimeLLM, TSMixer, TSMixerx,\n",
" MLPMultivariate, iTransformer,\n",
" BiTCN, TiDE, DeepNPTS, SOFTS,\n",
" TimeMixer, KAN, RMoK\n",
" TimeMixer, KAN, RMoK, TimeXer\n",
")\n",
"from neuralforecast.common._base_auto import BaseAuto, MockTrial\n",
"from neuralforecast.utils import PredictionIntervals, get_prediction_interval_method"
Expand Down Expand Up @@ -247,7 +247,8 @@
" 'softs': SOFTS, 'autosofts': SOFTS,\n",
" 'timemixer': TimeMixer, 'autotimemixer': TimeMixer,\n",
" 'kan': KAN, 'autokan': KAN,\n",
" 'rmok': RMoK, 'autormok': RMoK\n",
" 'rmok': RMoK, 'autormok': RMoK,\n",
" 'timexer': TimeXer, 'autotimexer': TimeXer\n",
"}"
]
},
Expand Down
3 changes: 2 additions & 1 deletion nbs/docs/capabilities/01_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
"|`TiDE` | `AutoTiDE` | MLP | Univariate | Direct | F/H/S | \n",
"|`TimeMixer` | `AutoTimeMixer` | MLP | Multivariate | Direct | - | \n",
"|`TimeLLM` | - | LLM | Univariate | Direct | - | \n",
"|`TimesNet` | `AutoTimesNet` | CNN | Univariate | Direct | F | \n",
"|`TimesNet` | `AutoTimesNet` | CNN | Univariate | Direct | F |\n",
"|`TimeXer` | `AutoTimeXer` | Transformer | Multivariate | Direct | F | \n",
"|`TSMixer` | `AutoTSMixer` | MLP | Multivariate | Direct | - | \n",
"|`TSMixerx` | `AutoTSMixerx` | MLP | Multivariate | Direct | F/H/S | \n",
"|`VanillaTransformer` | `AutoVanillaTransformer` | Transformer | Univariate | Direct | F | \n",
Expand Down
Binary file added nbs/imgs_models/timexer.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions nbs/mint.json
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@
"models.timellm.html",
"models.timemixer.html",
"models.timesnet.html",
"models.timexer.html",
"models.tsmixer.html",
"models.tsmixerx.html",
"models.vanillatransformer.html"
Expand Down
2 changes: 1 addition & 1 deletion nbs/models.informer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@
" else:\n",
" return (context_in, None)\n",
"\n",
" def forward(self, queries, keys, values, attn_mask):\n",
" def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):\n",
" B, L_Q, H, D = queries.shape\n",
" _, L_K, _, _ = keys.shape\n",
"\n",
Expand Down
152 changes: 152 additions & 0 deletions nbs/models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
"from neuralforecast.models.patchtst import PatchTST\n",
"from neuralforecast.models.timesnet import TimesNet\n",
"from neuralforecast.models.itransformer import iTransformer\n",
"from neuralforecast.models.timexer import TimeXer\n",
"\n",
"from neuralforecast.models.kan import KAN\n",
"from neuralforecast.models.rmok import RMoK\n",
Expand Down Expand Up @@ -3430,6 +3431,157 @@
"model.fit(dataset=dataset)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "34660732",
"metadata": {},
"outputs": [],
"source": [
"#| export\n",
"class AutoTimeXer(BaseAuto):\n",
"\n",
" default_config = {\n",
" \"input_size_multiplier\": [1, 2, 3, 4, 5],\n",
" \"h\": None,\n",
" \"n_series\": None,\n",
" \"hidden_size\": tune.choice([128, 256, 512]),\n",
" \"n_heads\": tune.choice([4, 8]),\n",
" \"learning_rate\": tune.loguniform(1e-4, 1e-1),\n",
" \"scaler_type\": tune.choice([None, 'robust', 'standard']),\n",
" \"max_steps\": tune.choice([500, 1000, 2000]),\n",
" \"batch_size\": tune.choice([32, 64, 128, 256]),\n",
" \"loss\": None,\n",
" \"random_seed\": tune.randint(1, 20),\n",
" }\n",
"\n",
" def __init__(self,\n",
" h,\n",
" n_series,\n",
" loss=MAE(),\n",
" valid_loss=None,\n",
" config=None, \n",
" search_alg=BasicVariantGenerator(random_state=1),\n",
" num_samples=10,\n",
" refit_with_val=False,\n",
" cpus=cpu_count(),\n",
" gpus=torch.cuda.device_count(),\n",
" verbose=False,\n",
" alias=None,\n",
" backend='ray',\n",
" callbacks=None):\n",
" \n",
" # Define search space, input/output sizes\n",
" if config is None:\n",
" config = self.get_default_config(h=h, backend=backend, n_series=n_series) \n",
"\n",
" # Always use n_series from parameters, raise exception with Optuna because we can't enforce it\n",
" if backend == 'ray':\n",
" config['n_series'] = n_series\n",
" elif backend == 'optuna':\n",
" mock_trial = MockTrial()\n",
" if ('n_series' in config(mock_trial) and config(mock_trial)['n_series'] != n_series) or ('n_series' not in config(mock_trial)):\n",
" raise Exception(f\"config needs 'n_series': {n_series}\") \n",
"\n",
" super(AutoTimeXer, self).__init__(\n",
" cls_model=TimeXer, \n",
" h=h,\n",
" loss=loss,\n",
" valid_loss=valid_loss,\n",
" config=config,\n",
" search_alg=search_alg,\n",
" num_samples=num_samples, \n",
" refit_with_val=refit_with_val,\n",
" cpus=cpus,\n",
" gpus=gpus,\n",
" verbose=verbose,\n",
" alias=alias,\n",
" backend=backend,\n",
" callbacks=callbacks, \n",
" )\n",
"\n",
" @classmethod\n",
" def get_default_config(cls, h, backend, n_series):\n",
" config = cls.default_config.copy() \n",
" config['input_size'] = tune.choice([h * x \\\n",
" for x in config[\"input_size_multiplier\"]])\n",
"\n",
" # Rolling windows with step_size=1 or step_size=h\n",
" # See `BaseWindows` and `BaseRNN`'s create_windows\n",
" config['step_size'] = tune.choice([1, h])\n",
" del config[\"input_size_multiplier\"]\n",
" if backend == 'optuna':\n",
" # Always use n_series from parameters\n",
" config['n_series'] = n_series\n",
" config = cls._ray_config_to_optuna(config) \n",
"\n",
" return config "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "de761efc",
"metadata": {},
"outputs": [],
"source": [
"show_doc(AutoTimeXer, title_level=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f08f23a2",
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"# Use your own config or AutoTimeXer.default_config\n",
"config = dict(max_steps=1, val_check_steps=1, input_size=12, patch_len=12)\n",
"model = AutoTimeXer(h=12, n_series=1, config=config, num_samples=1, cpus=1)\n",
"\n",
"# Fit and predict\n",
"model.fit(dataset=dataset)\n",
"y_hat = model.predict(dataset=dataset)\n",
"\n",
"# Optuna\n",
"model = AutoTimeXer(h=12, n_series=1, config=None, backend='optuna')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8488c991",
"metadata": {},
"outputs": [],
"source": [
"#| hide\n",
"# Check Optuna\n",
"assert model.config(MockTrial())['h'] == 12\n",
"\n",
"# Unit test to test that Auto* model contains all required arguments from BaseAuto\n",
"test_args(AutoTimeXer, exclude_args=['cls_model']) \n",
"\n",
"# Unit test for situation: Optuna with updated default config\n",
"my_config = AutoTimeXer.get_default_config(h=12, n_series=1, backend='optuna')\n",
"def my_config_new(trial):\n",
" config = {**my_config(trial)}\n",
" config.update({'max_steps': 1, 'val_check_steps': 1, 'input_size': 12, 'patch_len': 12})\n",
" return config\n",
"\n",
"model = AutoTimeXer(h=12, n_series=1, config=my_config_new, backend='optuna', num_samples=1, cpus=1)\n",
"model.fit(dataset=dataset)\n",
"\n",
"# Unit test for situation: Ray with updated default config\n",
"my_config = AutoTimeXer.get_default_config(h=12, n_series=1, backend='ray')\n",
"my_config['max_steps'] = 1\n",
"my_config['val_check_steps'] = 1\n",
"my_config['input_size'] = 12\n",
"my_config['patch_len'] = 12\n",
"model = AutoTimeXer(h=12, n_series=1, config=my_config, backend='ray', num_samples=1, cpus=1)\n",
"model.fit(dataset=dataset)"
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand Down
Loading