Skip to content

Commit

Permalink
Simplified Jaximal behavior.
Browse files Browse the repository at this point in the history
  • Loading branch information
ArvinSKushwaha committed Jun 5, 2024
1 parent 50bdfe1 commit 70d09ac
Showing 1 changed file with 20 additions and 19 deletions.
39 changes: 20 additions & 19 deletions src/jaximal/core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import typing

from dataclasses import make_dataclass
from dataclasses import dataclass, fields
from itertools import chain
from json import dumps, loads
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
Iterable,
Mapping,
Self,
Expand All @@ -20,6 +20,9 @@

from jaxtyping import AbstractArray, Array

if TYPE_CHECKING:
from _typeshed import DataclassInstance

type Static[T] = Annotated[T, 'jaximal::meta']


Expand Down Expand Up @@ -73,25 +76,21 @@ def forward(
"""

def __init_subclass__(cls) -> None:
cls2 = make_dataclass(
cls.__name__, list(cls.__annotations__), slots=True, frozen=True
)
# The `dataclass` decorator modifies `cls` itself, so we don't need to
# worry in the later steps. Note: This only holds true when
# `slots=False` in the `dataclass` decorator.
dataclass(frozen=True, eq=False)(cls)

setattr(cls, '__init__', cls2.__init__)
setattr(cls, '__repr__', cls2.__repr__)
setattr(cls, '__slots__', cls2.__slots__)
setattr(cls, '__setattr__', cls2.__setattr__)
setattr(cls, '__delattr__', cls2.__delattr__)
setattr(cls, '__getattribute__', cls2.__getattribute__)
cls_fields = fields(cast('DataclassInstance', cls))

data_fields = [
key for key, typ in cls.__annotations__.items() if get_origin(typ) != Static
]
meta_fields = [
key for key, typ in cls.__annotations__.items() if get_origin(typ) == Static
]
data_fields = []
meta_fields = []

jax.tree_util.register_dataclass(cls, data_fields, meta_fields)
for field in cls_fields:
if get_origin(field.type) != Static:
data_fields.append(field.name)
if get_origin(field.type) == Static:
meta_fields.append(field.name)

def cls_eq(self: Self, other: object) -> bool:
if type(other) != type(self):
Expand All @@ -117,7 +116,9 @@ def cls_eq(self: Self, other: object) -> bool:

return equal

cls.__eq__: Callable[[Self, object], bool] = cls_eq
setattr(cls, '__eq__', cls_eq)

jax.tree_util.register_dataclass(cls, data_fields, meta_fields)


def dictify(
Expand Down

0 comments on commit 70d09ac

Please sign in to comment.