From 245e17eb23deaccad9f2c9cfd40fe40515e3c074 Mon Sep 17 00:00:00 2001 From: Ameya Daigavane Date: Sat, 28 Sep 2024 17:53:25 -0400 Subject: [PATCH] Fix. --- e3nn_jax/_src/linear_flax.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/e3nn_jax/_src/linear_flax.py b/e3nn_jax/_src/linear_flax.py index dce4be6..5e282c4 100644 --- a/e3nn_jax/_src/linear_flax.py +++ b/e3nn_jax/_src/linear_flax.py @@ -74,9 +74,9 @@ class Linear(flax.linen.Module): gradient_normalization: Optional[Union[float, str]] = None path_normalization: Optional[Union[float, str]] = None biases: bool = False - parameter_initializer: Optional[ - Callable[[], jax.nn.initializers.Initializer] - ] = None + parameter_initializer: Optional[Callable[[], jax.nn.initializers.Initializer]] = ( + None + ) instructions: Optional[List[Tuple[int, int]]] = None num_indexed_weights: Optional[int] = None weights_per_channel: bool = False