From 1e7e3fc6f297f39bd379b6ca9814807f7f700345 Mon Sep 17 00:00:00 2001 From: Fengyuan Hu <127644049+HuFY-dev@users.noreply.github.com> Date: Thu, 28 Mar 2024 22:49:15 -0400 Subject: [PATCH] Fixed import & style --- sparse_autoencoder/optimizer/adam_with_reset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sparse_autoencoder/optimizer/adam_with_reset.py b/sparse_autoencoder/optimizer/adam_with_reset.py index 5620d110..390f47ec 100644 --- a/sparse_autoencoder/optimizer/adam_with_reset.py +++ b/sparse_autoencoder/optimizer/adam_with_reset.py @@ -12,10 +12,11 @@ from sparse_autoencoder.tensor_types import Axis # params_t was renamed to ParamsT in PyTorch 2.2, which caused import errors -# Copied from PyTorch 2.1 -from typing import Union, Iterable, Dict, Any +# Copied from PyTorch 2.1 with modifications for better style +from collections.abc import Iterable +from typing import Any from typing_extensions import TypeAlias -params_t: TypeAlias = Union[Iterable[Tensor], Iterable[Dict[str, Any]]] +params_t: TypeAlias = Iterable[Tensor] | Iterable[dict[str, Any]] class AdamWithReset(Adam): """Adam Optimizer with a reset method.