From 76d3ef64c46ccf8acfa7af90db54c65fad289d8c Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Thu, 14 Nov 2024 16:15:10 +0000 Subject: [PATCH] fix: Allow updates to Scalars, without exclusion and more(#137) * fix: Allow updates to scalars * Add 'add' & 'update' * Add without & without_by_dim * Rework loss functions to use without - Allow limiting of scalars rather than turning off * Rework feature_indices to scalar_indices * Remove without in validation_metrics --- CHANGELOG.md | 2 + src/anemoi/training/losses/huber.py | 17 +- src/anemoi/training/losses/logcosh.py | 17 +- src/anemoi/training/losses/mae.py | 18 +- src/anemoi/training/losses/mse.py | 17 +- src/anemoi/training/losses/rmse.py | 17 +- src/anemoi/training/losses/utils.py | 164 +++++++++++++++--- src/anemoi/training/losses/weightedloss.py | 80 ++++++--- src/anemoi/training/train/forecaster.py | 3 +- .../train/{test_scaler.py => test_scalar.py} | 65 +++++++ 10 files changed, 310 insertions(+), 90 deletions(-) rename tests/train/{test_scaler.py => test_scalar.py} (66%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 20ec42a8..21ef6ff3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ Keep it human-readable, your future self will thank you! ### Added - Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70) - Fix that applies the metric_ranges in the post-processed variable space [#116](https://github.com/ecmwf/anemoi-training/pull/116) +- Allow updates to scalars [#137](https://github.com/ecmwf/anemoi-training/pulls/137) + - Add without subsetting in ScaleTensor - Sub-hour datasets [#63](https://github.com/ecmwf/anemoi-training/pull/63) - Add synchronisation workflow [#92](https://github.com/ecmwf/anemoi-training/pull/92) - Feat: Anemoi Profiler compatible with mlflow and using Pytorch (Kineto) Profiler for memory report [38](https://github.com/ecmwf/anemoi-training/pull/38/) diff --git a/src/anemoi/training/losses/huber.py b/src/anemoi/training/losses/huber.py index e42105c7..ed5b8d25 100644 --- a/src/anemoi/training/losses/huber.py +++ b/src/anemoi/training/losses/huber.py @@ -73,8 +73,8 @@ def forward( pred: torch.Tensor, target: torch.Tensor, squash: bool = True, - feature_indices: torch.Tensor | None = None, - feature_scale: bool = True, + scalar_indices: tuple[int, ...] | None = None, + without_scalars: list[str] | list[int] | None = None, ) -> torch.Tensor: """Calculates the lat-weighted Huber loss. @@ -86,10 +86,11 @@ def forward( Target tensor, shape (bs, ensemble, lat*lon, n_outputs) squash : bool, optional Average last dimension, by default True - feature_indices: - feature indices (relative to full model output) of the features passed in pred and target - feature_scale: - If True, scale the loss by the feature_weights + scalar_indices: tuple[int,...], optional + Indices to subset the calculated scalar with, by default None + without_scalars: list[str] | list[int] | None, optional + list of scalars to exclude from scaling. Can be list of names or dimensions to exclude. + By default None Returns ------- @@ -98,6 +99,6 @@ def forward( """ out = self.huber(pred, target) - if feature_scale: - out = self.scale_by_variable_scaling(out, feature_indices) + out = self.scale(out, scalar_indices, without_scalars=without_scalars) + return self.scale_by_node_weights(out, squash) diff --git a/src/anemoi/training/losses/logcosh.py b/src/anemoi/training/losses/logcosh.py index 6112d472..6f916177 100644 --- a/src/anemoi/training/losses/logcosh.py +++ b/src/anemoi/training/losses/logcosh.py @@ -67,8 +67,8 @@ def forward( pred: torch.Tensor, target: torch.Tensor, squash: bool = True, - feature_indices: torch.Tensor | None = None, - feature_scale: bool = True, + scalar_indices: tuple[int, ...] | None = None, + without_scalars: list[str] | list[int] | None = None, ) -> torch.Tensor: """Calculates the lat-weighted LogCosh loss. @@ -80,10 +80,11 @@ def forward( Target tensor, shape (bs, ensemble, lat*lon, n_outputs) squash : bool, optional Average last dimension, by default True - feature_indices: - feature indices (relative to full model output) of the features passed in pred and target - feature_scale: - If True, scale the loss by the feature_weights + scalar_indices: tuple[int,...], optional + Indices to subset the calculated scalar with, by default None + without_scalars: list[str] | list[int] | None, optional + list of scalars to exclude from scaling. Can be list of names or dimensions to exclude. + By default None Returns ------- @@ -92,7 +93,5 @@ def forward( """ out = LogCosh.apply(pred - target) - - if feature_scale: - out = self.scale(out, feature_indices) + out = self.scale(out, scalar_indices, without_scalars=without_scalars) return self.scale_by_node_weights(out, squash) diff --git a/src/anemoi/training/losses/mae.py b/src/anemoi/training/losses/mae.py index bea16ac2..b2112d98 100644 --- a/src/anemoi/training/losses/mae.py +++ b/src/anemoi/training/losses/mae.py @@ -53,8 +53,8 @@ def forward( pred: torch.Tensor, target: torch.Tensor, squash: bool = True, - feature_indices: torch.Tensor | None = None, - feature_scale: bool = True, + scalar_indices: tuple[int, ...] | None = None, + without_scalars: list[str] | list[int] | None = None, ) -> torch.Tensor: """Calculates the lat-weighted MAE loss. @@ -66,10 +66,12 @@ def forward( Target tensor, shape (bs, ensemble, lat*lon, n_outputs) squash : bool, optional Average last dimension, by default True - feature_indices: - feature indices (relative to full model output) of the features passed in pred and target - feature_scale: - If True, scale the loss by the feature_weights + scalar_indices: tuple[int,...], optional + Indices to subset the calculated scalar with, by default None + without_scalars: list[str] | list[int] | None, optional + list of scalars to exclude from scaling. Can be list of names or dimensions to exclude. + By default None + Returns ------- @@ -77,7 +79,5 @@ def forward( Weighted MAE loss """ out = torch.abs(pred - target) - - if feature_scale: - out = self.scale(out, feature_indices) + out = self.scale(out, scalar_indices, without_scalars=without_scalars) return self.scale_by_node_weights(out, squash) diff --git a/src/anemoi/training/losses/mse.py b/src/anemoi/training/losses/mse.py index 87365f8c..c30f8b9d 100644 --- a/src/anemoi/training/losses/mse.py +++ b/src/anemoi/training/losses/mse.py @@ -51,8 +51,8 @@ def forward( pred: torch.Tensor, target: torch.Tensor, squash: bool = True, - feature_indices: torch.Tensor | None = None, - feature_scale: bool = True, + scalar_indices: tuple[int, ...] | None = None, + without_scalars: list[str] | list[int] | None = None, ) -> torch.Tensor: """Calculates the lat-weighted MSE loss. @@ -64,10 +64,11 @@ def forward( Target tensor, shape (bs, ensemble, lat*lon, n_outputs) squash : bool, optional Average last dimension, by default True - feature_indices: - feature indices (relative to full model output) of the features passed in pred and target - feature_scale: - If True, scale the loss by the feature_weights + scalar_indices: tuple[int,...], optional + Indices to subset the calculated scalar with, by default None + without_scalars: list[str] | list[int] | None, optional + list of scalars to exclude from scaling. Can be list of names or dimensions to exclude. + By default None Returns ------- @@ -75,7 +76,5 @@ def forward( Weighted MSE loss """ out = torch.square(pred - target) - - if feature_scale: - out = self.scale(out, feature_indices) + out = self.scale(out, scalar_indices, without_scalars=without_scalars) return self.scale_by_node_weights(out, squash) diff --git a/src/anemoi/training/losses/rmse.py b/src/anemoi/training/losses/rmse.py index 34b913a9..6c97344a 100644 --- a/src/anemoi/training/losses/rmse.py +++ b/src/anemoi/training/losses/rmse.py @@ -50,8 +50,8 @@ def forward( pred: torch.Tensor, target: torch.Tensor, squash: bool = True, - feature_indices: torch.Tensor | None = None, - feature_scale: bool = True, + scalar_indices: tuple[int, ...] | None = None, + without_scalars: list[str] | list[int] | None = None, ) -> torch.Tensor: """Calculates the lat-weighted RMSE loss. @@ -63,10 +63,11 @@ def forward( Target tensor, shape (bs, ensemble, lat*lon, n_outputs) squash : bool, optional Average last dimension, by default True - feature_indices: - feature indices (relative to full model output) of the features passed in pred and target - feature_scale: - If True, scale the loss by the feature_weights + scalar_indices: tuple[int,...], optional + Indices to subset the calculated scalar with, by default None + without_scalars: list[str] | list[int] | None, optional + list of scalars to exclude from scaling. Can be list of names or dimensions to exclude. + By default None Returns ------- @@ -77,7 +78,7 @@ def forward( pred=pred, target=target, squash=squash, - feature_indices=feature_indices, - feature_scale=feature_scale, + scalar_indices=scalar_indices, + without_scalars=without_scalars, ) return torch.sqrt(mse) diff --git a/src/anemoi/training/losses/utils.py b/src/anemoi/training/losses/utils.py index 5e7c98a8..e98e0bfe 100644 --- a/src/anemoi/training/losses/utils.py +++ b/src/anemoi/training/losses/utils.py @@ -74,6 +74,7 @@ def __getitem__(self, dimension: int) -> int: return self.func(dimension) +# TODO(Harrison Cook): Consider moving this to subclass from a pytorch object and allow for device moving completely class ScaleTensor: """Dynamically resolved tensor scaling class. @@ -99,7 +100,7 @@ class ScaleTensor: """ tensors: dict[str, TENSOR_SPEC] - _specified_dimensions: list[tuple[int]] + _specified_dimensions: dict[str, tuple[int]] def __init__( self, @@ -120,13 +121,10 @@ def __init__( Kwargs form of {name: (dimension, tensor)} to add to the scalars """ self.tensors = {} - self._specified_dimensions = [] + self._specified_dimensions = {} - scalars = scalars or {} - scalars.update(named_tensors) - - for name, tensor_spec in scalars.items(): - self.add_scalar(*tensor_spec, name=name) + named_tensors.update(scalars or {}) + self.add(named_tensors) for tensor_spec in tensors: self.add_scalar(*tensor_spec) @@ -144,8 +142,10 @@ def get_dim_shape(dimension: int) -> int: if isinstance(dim_assign, tuple) and dimension in dim_assign: return tensor.shape[list(dim_assign).index(dimension)] + unique_dims = {dim for dim_assign in self._specified_dimensions.values() for dim in dim_assign} error_msg = ( - f"Could not find shape of dimension {dimension} with tensors in dims {list(self.tensors.keys())}" + f"Could not find shape of dimension {dimension}. " + f"Tensors are only specified for dimensions {list(unique_dims)}." ) raise IndexError(error_msg) @@ -175,8 +175,8 @@ def validate_scalar(self, dimension: int | tuple[int], scalar: torch.Tensor) -> if self.shape[dim] != scalar.shape[scalar_dim]: error_msg = ( - f"Scalar shape {scalar.shape} at dimension {scalar_dim}" - f"does not match shape of scalar at dimension {dim}. Expected {self.shape[dim]}", + f"Incoming scalar shape {scalar.shape} at dimension {scalar_dim} " + f"does not match shape of saved scalar. Expected {self.shape[dim]}" ) raise ValueError(error_msg) @@ -190,7 +190,7 @@ def add_scalar( """Add new scalar to be applied along `dimension`. Dimension can be a single int even for a multi-dimensional scalar, - in this case the dimensions are assigned as a range from the given int. + in this case the dimensions are assigned as a range starting from the given int. Negative indexes are also valid, and will be resolved against the tensor's ndim. Parameters @@ -210,6 +210,15 @@ def add_scalar( dimension = (dimension,) else: dimension = tuple(dimension + i for i in range(len(scalar.shape))) + else: + dimension = tuple(dimension) + + if name is None: + name = str(uuid.uuid4()) + + if name in self.tensors: + msg = f"Scalar {name!r} already exists in scalars." + raise ValueError(msg) try: self.validate_scalar(dimension, scalar) @@ -217,15 +226,79 @@ def add_scalar( error_msg = f"Validating tensor {name!r} raised an error." raise ValueError(error_msg) from e - if name is None: - name = str(uuid.uuid4()) + self.tensors[name] = (dimension, scalar) + self._specified_dimensions[name] = dimension - if name in self.tensors: - self._specified_dimensions.remove(self.tensors[name][0]) - self.tensors[name] = (dimension, self.tensors[name][1] * scalar) + def update_scalar(self, name: str, scalar: torch.Tensor, *, override: bool = False) -> None: + """Update an existing scalar maintaining original dimensions. + + If `override` is False, the scalar must be valid against the original dimensions. + If `override` is True, the scalar will be updated regardless of validity against original scalar. + + Parameters + ---------- + name : str + Name of the scalar to update + scalar : torch.Tensor + New scalar tensor + override : bool, optional + Whether to override the scalar ignoring dimension compatibility, by default False + """ + if name not in self.tensors: + msg = f"Scalar {name!r} not found in scalars." + raise ValueError(msg) + + dimension = self.tensors[name][0] + + if not override: + self.validate_scalar(dimension, scalar) + + original_scalar = self.tensors.pop(name) + original_dimension = self._specified_dimensions.pop(name) + + try: + self.add_scalar(dimension, scalar, name=name) + except ValueError: + self.tensors[name] = original_scalar + self._specified_dimensions[name] = original_dimension + raise + + def add(self, new_scalars: dict[str, TENSOR_SPEC] | list[TENSOR_SPEC] | None = None, **kwargs) -> None: + """Add multiple scalars to the existing scalars. + + Parameters + ---------- + new_scalars : dict[str, TENSOR_SPEC] | list[TENSOR_SPEC] | None, optional + Scalars to add, see `add_scalar` for more info, by default None + **kwargs: + Kwargs form of {name: (dimension, tensor)} to add to the scalars + """ + if isinstance(new_scalars, list): + for tensor_spec in new_scalars: + self.add_scalar(*tensor_spec) else: - self.tensors[name] = (dimension, scalar) - self._specified_dimensions.append(dimension) + kwargs.update(new_scalars or {}) + for name, tensor_spec in kwargs.items(): + self.add_scalar(*tensor_spec, name=name) + + def update(self, updated_scalars: dict[str, torch.Tensor] | None = None, override: bool = False, **kwargs) -> None: + """Update multiple scalars in the existing scalars. + + If `override` is False, the scalar must be valid against the original dimensions. + If `override` is True, the scalar will be updated regardless of shape. + + Parameters + ---------- + updated_scalars : dict[str, torch.Tensor] | None, optional + Scalars to update, referenced by name, by default None + override : bool, optional + Whether to override the scalar ignoring dimension compatibility, by default False + **kwargs: + Kwargs form of {name: tensor} to update in the scalars + """ + kwargs.update(updated_scalars or {}) + for name, tensor in kwargs.items(): + self.update_scalar(name, tensor, override=override) def subset(self, scalars: str | Sequence[str]) -> ScaleTensor: """Get subset of the scalars, filtering by name. @@ -246,6 +319,23 @@ def subset(self, scalars: str | Sequence[str]) -> ScaleTensor: scalars = [scalars] return ScaleTensor(**{name: self.tensors[name] for name in scalars}) + def without(self, scalars: str | Sequence[str]) -> ScaleTensor: + """Get subset of the scalars, filtering out by name. + + Parameters + ---------- + scalars : str | Sequence[str] + Name/s of the scalars to exclude + + Returns + ------- + ScaleTensor + Subset of self + """ + if isinstance(scalars, str): + scalars = [scalars] + return ScaleTensor(**{name: tensor for name, tensor in self.tensors.items() if name not in scalars}) + def subset_by_dim(self, dimensions: int | Sequence[int]) -> ScaleTensor: """Get subset of the scalars, filtering by dimension. @@ -274,6 +364,32 @@ def subset_by_dim(self, dimensions: int | Sequence[int]) -> ScaleTensor: return ScaleTensor(**subset_scalars) + def without_by_dim(self, dimensions: int | Sequence[int]) -> ScaleTensor: + """Get subset of the scalars, filtering out by dimension. + + Parameters + ---------- + dimensions : int | Sequence[int] + Dimensions to exclude scalars of + + Returns + ------- + ScaleTensor + Subset of self + """ + subset_scalars: dict[str, TENSOR_SPEC] = {} + + if isinstance(dimensions, int): + dimensions = (dimensions,) + + for name, (dim, scalar) in self.tensors.items(): + if isinstance(dim, int): + dim = (dim,) + if len(set(dimensions).intersection(dim)) == 0: + subset_scalars[name] = (dim, scalar) + + return ScaleTensor(**subset_scalars) + def resolve(self, ndim: int) -> ScaleTensor: """Resolve relative indexes in scalars by associating against ndim. @@ -313,7 +429,7 @@ def scale(self, tensor: torch.Tensor) -> torch.Tensor: torch.Tensor Scaled tensor """ - return tensor * self.get_scalar(tensor.ndim) + return tensor * self.get_scalar(tensor.ndim, device=tensor.device) def get_scalar(self, ndim: int, device: str | None = None) -> torch.Tensor: """Get completely resolved scalar tensor. @@ -364,13 +480,19 @@ def to(self, *args, **kwargs) -> None: def __mul__(self, tensor: torch.Tensor) -> torch.Tensor: return self.scale(tensor) + def __rmul__(self, tensor: torch.Tensor) -> torch.Tensor: + return self.scale(tensor) + def __repr__(self): - return f"ScalarTensor:\n - With {list(self.tensors.keys())}\n - With dims: {self._specified_dimensions}" + return ( + f"ScalarTensor:\n - With tensors : {list(self.tensors.keys())}\n" + f" - In dimensions : {list(self._specified_dimensions.values())}" + ) def __contains__(self, dimension: int | tuple[int] | str) -> bool: """Check if either scalar by name or dimension by int/tuple is being scaled.""" if isinstance(dimension, tuple): - return dimension in self._specified_dimensions + return dimension in self._specified_dimensions.values() if isinstance(dimension, str): return dimension in self.tensors diff --git a/src/anemoi/training/losses/weightedloss.py b/src/anemoi/training/losses/weightedloss.py index 81829b66..0deccc9d 100644 --- a/src/anemoi/training/losses/weightedloss.py +++ b/src/anemoi/training/losses/weightedloss.py @@ -42,6 +42,7 @@ def __init__( Registers: - self.node_weights: torch.Tensor of shape (N, ) + - self.scalar: ScaleTensor modified with `add_scalar` and `update_scalar` Parameters ---------- @@ -64,10 +65,16 @@ def __init__( def add_scalar(self, dimension: int | tuple[int], scalar: torch.Tensor, *, name: str | None = None) -> None: self.scalar.add_scalar(dimension=dimension, scalar=scalar, name=name) + @functools.wraps(ScaleTensor.update_scalar, assigned=("__doc__", "__annotations__")) + def update_scalar(self, name: str, scalar: torch.Tensor, *, override: bool = False) -> None: + self.scalar.update_scalar(name=name, scalar=scalar, override=override) + def scale( self, x: torch.Tensor, - feature_indices: torch.Tensor | None = None, + scalar_indices: tuple[int, ...] | None = None, + *, + without_scalars: list[str] | list[int] | None = None, ) -> torch.Tensor: """Scale a tensor by the variable_scaling. @@ -75,23 +82,32 @@ def scale( ---------- x : torch.Tensor Tensor to be scaled, shape (bs, ensemble, lat*lon, n_outputs) - feature_indices: - feature indices (relative to full model output) of the features passed in pred and target + scalar_indices: tuple[int,...], optional + Indices to subset the calculated scalar with, by default None. + without_scalars: list[str] | list[int] | None, optional + list of scalars to exclude from scaling. Can be list of names or dimensions to exclude. + By default None Returns ------- torch.Tensor Scaled error tensor """ - # Use feature_weights if available if len(self.scalar) == 0: return x - scalar = self.scalar.get_scalar(x.ndim).to(x) + scale_tensor = self.scalar + if without_scalars is not None and len(without_scalars) > 0: + if isinstance(without_scalars[0], str): + scale_tensor = self.scalar.without(without_scalars) + else: + scale_tensor = self.scalar.without_by_dim(without_scalars) + + scalar = scale_tensor.get_scalar(x.ndim).to(x) - if feature_indices is None: + if scalar_indices is None: return x * scalar - return x * scalar[..., feature_indices] + return x * scalar[scalar_indices] def scale_by_node_weights(self, x: torch.Tensor, squash: bool = True) -> torch.Tensor: """Scale a tensor by the node_weights. @@ -132,8 +148,9 @@ def forward( pred: torch.Tensor, target: torch.Tensor, squash: bool = True, - feature_indices: torch.Tensor | None = None, - feature_scale: bool = True, + *, + scalar_indices: tuple[int, ...] | None = None, + without_scalars: list[str] | list[int] | None = None, ) -> torch.Tensor: """Calculates the lat-weighted scaled loss. @@ -145,10 +162,11 @@ def forward( Target tensor, shape (bs, ensemble, lat*lon, n_outputs) squash : bool, optional Average last dimension, by default True - feature_indices: - feature indices (relative to full model output) of the features passed in pred and target - feature_scale: - If True, scale the loss by the feature_weights + scalar_indices: tuple[int,...], optional + Indices to subset the calculated scalar with, by default None + without_scalars: list[str] | list[int] | None, optional + list of scalars to exclude from scaling. Can be list of names or dimensions to exclude. + By default None Returns ------- @@ -157,8 +175,8 @@ def forward( """ out = pred - target - if feature_scale: - out = self.scale(out, feature_indices) + out = self.scale(out, scalar_indices, without_scalars=without_scalars) + return self.scale_by_node_weights(out, squash) @property @@ -168,7 +186,19 @@ def name(self) -> str: class FunctionalWeightedLoss(BaseWeightedLoss): - """WeightedLoss which a user can subclass and provide `calculate_difference`.""" + """WeightedLoss which a user can subclass and provide `calculate_difference`. + + `calculate_difference` should calculate the difference between the prediction and target. + All scaling and weighting is handled by the parent class. + + Example: + -------- + ```python + class MyLoss(FunctionalWeightedLoss): + def calculate_difference(self, pred, target): + return pred - target + ``` + """ def __init__( self, @@ -186,8 +216,9 @@ def forward( pred: torch.Tensor, target: torch.Tensor, squash: bool = True, - feature_indices: torch.Tensor | None = None, - feature_scale: bool = True, + *, + scalar_indices: tuple[int, ...] | None = None, + without_scalars: list[str] | list[int] | None = None, ) -> torch.Tensor: """Calculates the lat-weighted scaled loss. @@ -199,10 +230,12 @@ def forward( Target tensor, shape (bs, ensemble, lat*lon, n_outputs) squash : bool, optional Average last dimension, by default True - feature_indices: - feature indices (relative to full model output) of the features passed in pred and target - feature_scale: - If True, scale the loss by the feature_weights + scalar_indices: tuple[int,...], optional + Indices to subset the calculated scalar with, by default None + without_scalars: list[str] | list[int] | None, optional + list of scalars to exclude from scaling. Can be list of names or dimensions to exclude. + By default None + Returns ------- @@ -211,6 +244,5 @@ def forward( """ out = self.calculate_difference(pred, target) - if feature_scale: - out = self.scale(out, feature_indices) + out = self.scale(out, scalar_indices, without_scalars=without_scalars) return self.scale_by_node_weights(out, squash) diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 12a8f9fb..80459b8f 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -458,8 +458,7 @@ def calculate_val_metrics( metrics[f"{metric_name}/{mkey}/{rollout_step + 1}"] = metric( y_pred_postprocessed[..., indices], y_postprocessed[..., indices], - feature_indices=indices, - feature_scale=mkey == "all", + scalar_indices=[..., indices], ) return metrics diff --git a/tests/train/test_scaler.py b/tests/train/test_scalar.py similarity index 66% rename from tests/train/test_scaler.py rename to tests/train/test_scalar.py index f3bad0d2..9a37e353 100644 --- a/tests/train/test_scaler.py +++ b/tests/train/test_scalar.py @@ -49,6 +49,39 @@ def test_scale_contains_subset_by_dim_indexing() -> None: assert "test" not in scale +def test_add_existing_scalar() -> None: + scale = ScaleTensor(test=(0, torch.tensor([2.0]))) + with pytest.raises(ValueError, match=r".*already exists.*"): + scale.add_scalar(0, torch.tensor(3.0), name="test") + + +def test_update_scalar() -> None: + scale = ScaleTensor(test=(0, torch.ones(2))) + scale.update_scalar("test", torch.tensor([3.0])) + torch.testing.assert_close(scale.tensors["test"][1], torch.tensor([3.0])) + + +def test_update_missing_scalar() -> None: + scale = ScaleTensor(test=(0, torch.ones(2))) + with pytest.raises(ValueError, match=r".*not found in scalars.*"): + scale.update_scalar("test_missing", torch.tensor([3.0])) + assert "test" in scale + assert (0,) in scale + + +def test_update_scalar_wrong_dim() -> None: + scale = ScaleTensor(test=(0, torch.ones((2, 3)))) + with pytest.raises(ValueError, match=r".*does not match shape of saved scalar.*"): + scale.update_scalar("test", torch.ones((2, 2))) + assert "test" in scale + assert 0 in scale + + +def test_update_scalar_wrong_dim_allow_override() -> None: + scale = ScaleTensor(test=(0, torch.ones((2, 3)))) + assert scale.update_scalar("test", torch.ones((2, 2)), override=True) is None + + @pytest.mark.parametrize( ("scalars", "input_tensor", "output"), [ @@ -133,3 +166,35 @@ def test_scale_tensor_two_dim( output = torch.tensor(output, dtype=torch.float32) torch.testing.assert_close(scale.scale(input_tensor), output) + + +def test_scalar_subset() -> None: + scale = ScaleTensor(test=(0, torch.tensor([2.0])), wow=(0, torch.tensor([3.0]))) + subset = scale.subset("test") + assert "test" in subset + assert "wow" not in subset + assert 0 in subset + + +def test_scalar_subset_without() -> None: + scale = ScaleTensor(test=(0, torch.tensor([2.0])), wow=(0, torch.tensor([3.0]))) + subset = scale.without("test") + assert "test" not in subset + assert "wow" in subset + assert 0 in subset + + +def test_scalar_subset_by_dim() -> None: + scale = ScaleTensor(test=(0, torch.tensor([2.0])), wow=(1, torch.tensor([3.0]))) + subset = scale.subset_by_dim(0) + assert "test" in subset + assert "wow" not in subset + assert 0 in subset + + +def test_scalar_subset_by_dim_without() -> None: + scale = ScaleTensor(test=(0, torch.tensor([2.0])), wow=(1, torch.tensor([3.0]))) + subset = scale.without_by_dim(0) + assert "test" not in subset + assert "wow" in subset + assert 0 not in subset