From f5eb4789a441e68d4258cbf7a7c8f336ef78054e Mon Sep 17 00:00:00 2001 From: LINGLONGQIAN <15869023990@163.com> Date: Tue, 28 May 2024 16:18:27 +0100 Subject: [PATCH 1/5] fix mismatch issue --- pypots/imputation/etsformer/model.py | 16 ++++++++-------- pypots/nn/modules/etsformer/layers.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pypots/imputation/etsformer/model.py b/pypots/imputation/etsformer/model.py index 6dbb2fbc..a46e45b6 100644 --- a/pypots/imputation/etsformer/model.py +++ b/pypots/imputation/etsformer/model.py @@ -103,14 +103,14 @@ class ETSformer(BaseNNImputer): def __init__( self, - n_steps, - n_features, - n_e_layers, - n_d_layers, - d_model, - n_heads, - d_ffn, - top_k, + n_steps: int, + n_features: int, + n_e_layers: int, + n_d_layers: int, + d_model: int, + n_heads: int, + d_ffn: int, + top_k: int, dropout: float = 0, ORT_weight: float = 1, MIT_weight: float = 1, diff --git a/pypots/nn/modules/etsformer/layers.py b/pypots/nn/modules/etsformer/layers.py index 60e44798..7fe446cf 100644 --- a/pypots/nn/modules/etsformer/layers.py +++ b/pypots/nn/modules/etsformer/layers.py @@ -160,7 +160,7 @@ def forward(self, x): f = fft.rfftfreq(t)[self.low_freq :] x_freq, index_tuple = self.topk_freq(x_freq) - f = repeat(f, "f -> b f d", b=x_freq.size(0), d=x_freq.size(2)) + f = repeat(f, "f -> b f d", b=x_freq.size(0), d=x_freq.size(2)).to(x_freq.device) f = rearrange(f[index_tuple], "b f d -> b f () d").to(x_freq.device) return self.extrapolate(x_freq, f, t) From ab63adbc7fd9b2a6c8e6fcebb0050266ff6583d6 Mon Sep 17 00:00:00 2001 From: LINGLONGQIAN <15869023990@163.com> Date: Tue, 28 May 2024 16:45:55 +0100 Subject: [PATCH 2/5] update --- pypots/imputation/fedformer/model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pypots/imputation/fedformer/model.py b/pypots/imputation/fedformer/model.py index 2d8ca073..5b0a8491 100644 --- a/pypots/imputation/fedformer/model.py +++ b/pypots/imputation/fedformer/model.py @@ -111,13 +111,13 @@ class FEDformer(BaseNNImputer): def __init__( self, - n_steps, - n_features, - n_layers, - d_model, - n_heads, - d_ffn, - moving_avg_window_size, + n_steps: int, + n_features: int, + n_layers: int, + d_model: int, + n_heads: int, + d_ffn: int, + moving_avg_window_size: int, dropout: float = 0, version="Fourier", modes=32, From 283b42cb830199266f666640512ab2c037361fdf Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 29 May 2024 00:33:52 +0800 Subject: [PATCH 3/5] feat: expose both imputation and classification GRUD; --- pypots/cli/tuning.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pypots/cli/tuning.py b/pypots/cli/tuning.py index c99eca9d..b12baeed 100644 --- a/pypots/cli/tuning.py +++ b/pypots/cli/tuning.py @@ -14,6 +14,7 @@ from .base import BaseCommand from .utils import load_package_from_path from ..classification import BRITS as BRITS_classification +from ..classification import GRUD as GRUD_classification from ..classification import Raindrop from ..clustering import CRLI, VaDER from ..data.saving.h5 import load_dict_from_h5 @@ -80,8 +81,9 @@ "pypots.imputation.GPVAE": GPVAE, "pypots.imputation.BRITS": BRITS, "pypots.imputation.MRNN": MRNN, + "pypots.imputation.GRUD": GRUD, # classification models - "pypots.classification.GRUD": GRUD, + "pypots.classification.GRUD": GRUD_classification, "pypots.classification.BRITS": BRITS_classification, "pypots.classification.Raindrop": Raindrop, # clustering models @@ -248,7 +250,7 @@ def run(self): ) raise RuntimeError( f"Hyperparameters do not match. Mismatched hyperparameters " - f"(in the tuning configuration but not in the given model's arguments): {list(mismatched)}" + f"(in the tuning configuration but not in {model_class.__name__}'s arguments): {list(mismatched)}" ) # initializing optimizer and model From bb80e474cec9a738c8725ae8d167eb9f0fa61edb Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 29 May 2024 00:34:10 +0800 Subject: [PATCH 4/5] feat: release v0.6; --- pypots/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pypots/__init__.py b/pypots/__init__.py index 62da4b21..022de605 100644 --- a/pypots/__init__.py +++ b/pypots/__init__.py @@ -22,7 +22,7 @@ # # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. # 'X.Y.dev0' is the canonical version of 'X.Y.dev' -__version__ = "0.5" +__version__ = "0.6" from . import imputation, classification, clustering, forecasting, optim, data, utils From c39beb2b9a15ef3d708f7d8ddb75398981441ae4 Mon Sep 17 00:00:00 2001 From: Wenjie Du Date: Wed, 29 May 2024 00:40:44 +0800 Subject: [PATCH 5/5] fix: release as v0.6 RC1 instead of v0.6; --- pypots/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pypots/__init__.py b/pypots/__init__.py index 022de605..5ab736d3 100644 --- a/pypots/__init__.py +++ b/pypots/__init__.py @@ -22,7 +22,7 @@ # # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. # 'X.Y.dev0' is the canonical version of 'X.Y.dev' -__version__ = "0.6" +__version__ = "0.6rc1" from . import imputation, classification, clustering, forecasting, optim, data, utils