Skip to content

Commit

Permalink
Handle classes with __new__ method (#10)
Browse files Browse the repository at this point in the history
* handle classes with __new__ method

* fix release action
  • Loading branch information
cgarciae authored Mar 19, 2023
1 parent 4190cf2 commit a4bc7a8
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 13 deletions.
21 changes: 12 additions & 9 deletions .github/workflows/publish-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,19 @@ jobs:
run: |
poetry install --without dev
- name: Build Docs 🔨
run: |
cp README.md docs/index.md
poetry run mkdocs build
# ----------------------------------------
# No docs for now
# ----------------------------------------
# - name: Build Docs 🔨
# run: |
# cp README.md docs/index.md
# poetry run mkdocs build

- name: Deploy Page 🚀
uses: JamesIves/[email protected]
with:
branch: gh-pages
folder: site
# - name: Deploy Page 🚀
# uses: JamesIves/[email protected]
# with:
# branch: gh-pages
# folder: site

- name: Publish to PyPI
run: |
Expand Down
8 changes: 4 additions & 4 deletions simple_pytree/pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ def static_field(


class PytreeMeta(ABCMeta):
def __call__(self: tp.Type[P], *args: tp.Any, **kwds: tp.Any) -> P:
obj: P = self.__new__(self)
def __call__(self: tp.Type[P], *args: tp.Any, **kwargs: tp.Any) -> P:
obj: P = self.__new__(self, *args, **kwargs)
obj.__dict__["_pytree__initializing"] = True
try:
obj.__init__(*args, **kwds)
obj.__init__(*args, **kwargs)
finally:
del obj.__dict__["_pytree__initializing"]
return obj
Expand Down Expand Up @@ -172,7 +172,7 @@ def _pytree__unflatten(
) -> P:
node_names, static_fields = metadata
node_fields = dict(zip(node_names, node_values))
pytree = cls.__new__(cls)
pytree = object.__new__(cls)
pytree.__dict__.update(node_fields, **dict(static_fields))
return pytree

Expand Down
12 changes: 12 additions & 0 deletions tests/test_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,18 @@ class B(A):
leaves = jax.tree_util.tree_leaves(pytree)
assert leaves == [1, 3]

def test_pytree_with_new(self):
class A(Pytree):
def __init__(self, a):
self.a = a

def __new__(cls, a):
return super().__new__(cls)

pytree = A(a=1)

pytree = jax.tree_map(lambda x: x * 2, pytree)


class TestMutablePytree:
def test_pytree(self):
Expand Down

0 comments on commit a4bc7a8

Please sign in to comment.