From 89e468ad573cc7be16f2adcd11e209dc9c3d051d Mon Sep 17 00:00:00 2001 From: Arvin Kushwaha Date: Fri, 26 Jul 2024 16:12:41 +0200 Subject: [PATCH] Modified `init_state` to return a partial initialization --- src/jaximal/nn.py | 64 ++++++++++++++++++++++++++++------------------- 1 file changed, 38 insertions(+), 26 deletions(-) diff --git a/src/jaximal/nn.py b/src/jaximal/nn.py index 1f09c0c..4facaf3 100644 --- a/src/jaximal/nn.py +++ b/src/jaximal/nn.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from enum import Enum, auto -from typing import Any, Callable +from typing import Any, Callable, Self import jax @@ -68,25 +68,23 @@ def init( class JaximalModule(Jaximal, ABC): @abstractmethod - @staticmethod - def init_state(key: PRNGKeyArray, *args: Any, **kwargs: Any) -> 'JaximalModule': ... + @classmethod + def init_state( + cls, *args: Any, **kwargs: Any + ) -> Callable[[PRNGKeyArray], Self]: ... @abstractmethod def __call__(self, data: PyTree) -> PyTree: ... - @classmethod - def partial_init( - cls, *args: Any, **kwargs: Any - ) -> Callable[[PRNGKeyArray], 'JaximalModule']: - return lambda key: cls.init_state(key, *args, **kwargs) - class Activation(JaximalModule): func: Static[Callable[[Array], Array]] - @staticmethod - def init_state(key: PRNGKeyArray, func: Callable[[Array], Array]) -> 'Activation': - return Activation(func) + @classmethod + def init_state( + cls, func: Callable[[Array], Array] + ) -> Callable[[PRNGKeyArray], Self]: + return lambda key: cls(func) def __call__(self, data: PyTree) -> PyTree: return jax.tree.map(self.func, data) @@ -99,19 +97,24 @@ class Linear(JaximalModule): weights: Float[Array, 'in_dim out_dim'] biases: Float[Array, 'out_dim'] - @staticmethod + @classmethod def init_state( - key: PRNGKeyArray, + cls, in_dim: int, out_dim: int, weight_initialization: WeightInitialization = WeightInitialization.GlorotUniform, bias_initialization: WeightInitialization = WeightInitialization.Zero, - ) -> 'Linear': - w_key, b_key = jax.random.split(key) - weights = weight_initialization.init(w_key, (in_dim, out_dim), in_dim, out_dim) - biases = weight_initialization.init(b_key, (out_dim,), 1, out_dim) + ) -> Callable[[PRNGKeyArray], Self]: + def init(key: PRNGKeyArray) -> Self: + w_key, b_key = jax.random.split(key) + weights = weight_initialization.init( + w_key, (in_dim, out_dim), in_dim, out_dim + ) + biases = weight_initialization.init(b_key, (out_dim,), 1, out_dim) - return Linear(in_dim, out_dim, weights, biases) + return cls(in_dim, out_dim, weights, biases) + + return init def __call__(self, data: PyTree) -> PyTree: return data @ self.weights + self.biases @@ -120,14 +123,17 @@ def __call__(self, data: PyTree) -> PyTree: class Sequential(JaximalModule): modules: list[JaximalModule] - @staticmethod + @classmethod def init_state( - key: PRNGKeyArray, partials: list[Callable[[PRNGKeyArray], JaximalModule]] - ) -> 'Sequential': - keys = jax.random.split(key, len(partials)) + cls, partials: list[Callable[[PRNGKeyArray], JaximalModule]] + ) -> Callable[[PRNGKeyArray], Self]: + def init(key: PRNGKeyArray) -> Self: + keys = jax.random.split(key, len(partials)) + + modules = list(partial(key) for key, partial in zip(keys, partials)) + return cls(modules) - modules = list(partial(key) for key, partial in zip(keys, partials)) - return Sequential(modules) + return init def __call__(self, data: PyTree, *args: dict[str, Any]) -> PyTree: assert len(args) == len(self.modules), ( @@ -140,4 +146,10 @@ def __call__(self, data: PyTree, *args: dict[str, Any]) -> PyTree: return data -__all__ = ['JaximalModule', 'WeightInitialization', 'Linear', 'Sequential'] +__all__ = [ + 'JaximalModule', + 'WeightInitialization', + 'Linear', + 'Sequential', + 'Activation', +]