Modified init_state to return a partial initialization
ArvinSKushwaha committed Jul 26, 2024
1 parent 5d220e8 commit 89e468a
Showing 1 changed file with 38 additions and 26 deletions.
src/jaximal/nn.py: 64 changes (38 additions & 26 deletions)
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from enum import Enum, auto
-from typing import Any, Callable
+from typing import Any, Callable, Self

 import jax

@@ -68,25 +68,23 @@ def init(

 class JaximalModule(Jaximal, ABC):
     @abstractmethod
-    @staticmethod
-    def init_state(key: PRNGKeyArray, *args: Any, **kwargs: Any) -> 'JaximalModule': ...
+    @classmethod
+    def init_state(
+        cls, *args: Any, **kwargs: Any
+    ) -> Callable[[PRNGKeyArray], Self]: ...

     @abstractmethod
     def __call__(self, data: PyTree) -> PyTree: ...

-    @classmethod
-    def partial_init(
-        cls, *args: Any, **kwargs: Any
-    ) -> Callable[[PRNGKeyArray], 'JaximalModule']:
-        return lambda key: cls.init_state(key, *args, **kwargs)
-

 class Activation(JaximalModule):
     func: Static[Callable[[Array], Array]]

-    @staticmethod
-    def init_state(key: PRNGKeyArray, func: Callable[[Array], Array]) -> 'Activation':
-        return Activation(func)
+    @classmethod
+    def init_state(
+        cls, func: Callable[[Array], Array]
+    ) -> Callable[[PRNGKeyArray], Self]:
+        return lambda key: cls(func)

     def __call__(self, data: PyTree) -> PyTree:
         return jax.tree.map(self.func, data)
@@ -99,19 +97,24 @@ class Linear(JaximalModule):
     weights: Float[Array, 'in_dim out_dim']
     biases: Float[Array, 'out_dim']

-    @staticmethod
+    @classmethod
     def init_state(
-        key: PRNGKeyArray,
+        cls,
         in_dim: int,
         out_dim: int,
         weight_initialization: WeightInitialization = WeightInitialization.GlorotUniform,
         bias_initialization: WeightInitialization = WeightInitialization.Zero,
-    ) -> 'Linear':
-        w_key, b_key = jax.random.split(key)
-        weights = weight_initialization.init(w_key, (in_dim, out_dim), in_dim, out_dim)
-        biases = weight_initialization.init(b_key, (out_dim,), 1, out_dim)
+    ) -> Callable[[PRNGKeyArray], Self]:
+        def init(key: PRNGKeyArray) -> Self:
+            w_key, b_key = jax.random.split(key)
+            weights = weight_initialization.init(
+                w_key, (in_dim, out_dim), in_dim, out_dim
+            )
+            biases = weight_initialization.init(b_key, (out_dim,), 1, out_dim)

-        return Linear(in_dim, out_dim, weights, biases)
+            return cls(in_dim, out_dim, weights, biases)
+
+        return init

     def __call__(self, data: PyTree) -> PyTree:
         return data @ self.weights + self.biases
@@ -120,14 +123,17 @@ def __call__(self, data: PyTree) -> PyTree:
 class Sequential(JaximalModule):
     modules: list[JaximalModule]

-    @staticmethod
+    @classmethod
     def init_state(
-        key: PRNGKeyArray, partials: list[Callable[[PRNGKeyArray], JaximalModule]]
-    ) -> 'Sequential':
-        keys = jax.random.split(key, len(partials))
+        cls, partials: list[Callable[[PRNGKeyArray], JaximalModule]]
+    ) -> Callable[[PRNGKeyArray], Self]:
+        def init(key: PRNGKeyArray) -> Self:
+            keys = jax.random.split(key, len(partials))

-        modules = list(partial(key) for key, partial in zip(keys, partials))
-        return Sequential(modules)
+            modules = list(partial(key) for key, partial in zip(keys, partials))
+            return cls(modules)
+
+        return init

     def __call__(self, data: PyTree, *args: dict[str, Any]) -> PyTree:
         assert len(args) == len(self.modules), (
@@ -140,4 +146,10 @@ def __call__(self, data: PyTree, *args: dict[str, Any]) -> PyTree:
         return data


-__all__ = ['JaximalModule', 'WeightInitialization', 'Linear', 'Sequential']
+__all__ = [
+    'JaximalModule',
+    'WeightInitialization',
+    'Linear',
+    'Sequential',
+    'Activation',
+]
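
For context, a minimal usage sketch of the partial-initialization pattern introduced by this commit: each init_state call now returns Callable[[PRNGKeyArray], Self], so a model can be described without a key and materialized later with a single key. The jaximal.nn import path follows the file shown above; the layer sizes, the jax.nn.relu activation, and the empty per-module keyword dictionaries passed to Sequential.__call__ are illustrative assumptions, not taken from this commit.

    # Illustrative sketch only: layer sizes, activation choice, and the empty
    # per-module kwargs dicts are assumptions, not part of this commit.
    import jax
    import jax.numpy as jnp

    from jaximal.nn import Activation, Linear, Sequential

    # Describe the model as a list of partial initializers; no key needed yet.
    make_model = Sequential.init_state([
        Linear.init_state(3, 16),
        Activation.init_state(jax.nn.relu),
        Linear.init_state(16, 1),
    ])

    # Materialize all parameters from a single key (split across submodules).
    model = make_model(jax.random.PRNGKey(0))

    x = jnp.ones((8, 3))
    # Sequential.__call__ asserts one kwargs dict per submodule.
    y = model(x, {}, {}, {})

The design lets module configuration (shapes, activations) be composed ahead of time while deferring all randomness to one place, which is why Sequential can accept plain partials and split a single key internally.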
