Skip to content

Commit

Permalink
Utils (#33)
Browse files Browse the repository at this point in the history
* rename util into utils

* add doc

* add docstring

* fix

* changelog
  • Loading branch information
mariogeiger authored Jun 19, 2023
1 parent 0fc08ec commit 44a2203
Show file tree
Hide file tree
Showing 36 changed files with 87 additions and 27 deletions.
1 change: 1 addition & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `IrrepsArray.__radd__` and `IrrepsArray.__rsub__` to support `scalar + IrrepsArray` and `scalar - IrrepsArray`
- `0 + IrrepsArray` and `0 - IrrepsArray` are now always accepted as special cases.
- Support for `IrrepsArray / array`
- Add `utils` as a submodule

### Fixed
- `e3nn.scatter` operation handle indices with `ndim > 1`
Expand Down
3 changes: 2 additions & 1 deletion docs/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ API
haiku
flax
s2
extra
extra
utils
11 changes: 11 additions & 0 deletions docs/api/utils.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Utils
=====

.. autoclass:: e3nn_jax.utils.equivariance_test
:members:

.. autoclass:: e3nn_jax.utils.assert_equivariant
:members:

.. autoclass:: e3nn_jax.utils.assert_output_dtype_matches_input_dtype
:members:
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ It can for instance get rid of the dead code:
.. jupyter-execute::
:hide-code:

from e3nn_jax._src.util.jit import jit_code
from e3nn_jax._src.utils.jit import jit_code

.. jupyter-execute::

Expand Down
2 changes: 2 additions & 0 deletions e3nn_jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@

# make submodules flax and haiku available
from e3nn_jax import flax, haiku
from e3nn_jax import utils

__all__ = [
"config", # not in docs
Expand Down Expand Up @@ -201,4 +202,5 @@
"tensor_product_with_spherical_harmonics",
"flax",
"haiku",
"utils",
]
2 changes: 1 addition & 1 deletion e3nn_jax/_src/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax.scipy.special as jsp

import e3nn_jax as e3nn
from e3nn_jax._src.util.decorators import overload_for_irreps_without_array
from e3nn_jax._src.utils.decorators import overload_for_irreps_without_array


def soft_odd(x):
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/_src/batchnorm_haiku_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

import e3nn_jax as e3nn
from e3nn_jax.util import assert_equivariant
from e3nn_jax.utils import assert_equivariant


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/_src/core_tensor_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from e3nn_jax import Instruction, Irreps, IrrepsArray, clebsch_gordan, config
from e3nn_jax._src.einsum import einsum as opt_einsum
from e3nn_jax._src.util.dtype import get_pytree_dtype
from e3nn_jax._src.utils.dtype import get_pytree_dtype


class FunctionalTensorProduct:
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/_src/dropout_haiku_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import haiku as hk

import e3nn_jax as e3nn
from e3nn_jax.util import assert_equivariant
from e3nn_jax.utils import assert_equivariant


def test_dropout(keys):
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/_src/gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import e3nn_jax as e3nn
from e3nn_jax import IrrepsArray, scalar_activation
from e3nn_jax._src.util.decorators import overload_for_irreps_without_array
from e3nn_jax._src.utils.decorators import overload_for_irreps_without_array


@partial(jax.jit, static_argnums=(1, 2, 3, 4, 5))
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/_src/gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax
import jax.numpy as jnp
import pytest
from e3nn_jax.util import assert_equivariant
from e3nn_jax.utils import assert_equivariant

gate = jax.jit(jax.vmap(e3nn.gate))

Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/_src/grad_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import e3nn_jax as e3nn
import numpy as np
from e3nn_jax.util import assert_equivariant
from e3nn_jax.utils import assert_equivariant
from jax import random


Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/_src/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from e3nn_jax import Irreps, IrrepsArray, config
from e3nn_jax._src.core_tensor_product import _sum_tensors
from e3nn_jax._src.util.dtype import get_pytree_dtype
from e3nn_jax._src.utils.dtype import get_pytree_dtype


class Instruction(NamedTuple):
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/_src/linear_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp

import e3nn_jax as e3nn
from e3nn_jax._src.util.dtype import get_pytree_dtype
from e3nn_jax._src.utils.dtype import get_pytree_dtype

from .linear import (
FunctionalLinear,
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/_src/linear_flax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest

import e3nn_jax as e3nn
from e3nn_jax.util import assert_output_dtype_matches_input_dtype
from e3nn_jax.utils import assert_output_dtype_matches_input_dtype


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/_src/linear_haiku.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax.numpy as jnp

import e3nn_jax as e3nn
from e3nn_jax._src.util.dtype import get_pytree_dtype
from e3nn_jax._src.utils.dtype import get_pytree_dtype

from .linear import (
FunctionalLinear,
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/_src/radial_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import pytest
from e3nn_jax._src.radial import u
from e3nn_jax._src.util.test import assert_output_dtype_matches_input_dtype
from e3nn_jax._src.utils.test import assert_output_dtype_matches_input_dtype


def test_sus():
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/_src/reduced_tensor_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import e3nn_jax as e3nn
from e3nn_jax import perm
from e3nn_jax._src.util.math_numpy import basis_intersection, round_to_sqrt_rational
from e3nn_jax._src.utils.math_numpy import basis_intersection, round_to_sqrt_rational


def reduced_tensor_product_basis(
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/_src/s2grid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import e3nn_jax as e3nn
from e3nn_jax._src.s2grid import _irfft, _rfft, _spherical_harmonics_s2grid
from e3nn_jax.util import assert_output_dtype_matches_input_dtype
from e3nn_jax.utils import assert_output_dtype_matches_input_dtype


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/_src/spherical_harmonics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
import sympy
from e3nn_jax import Irreps, IrrepsArray, clebsch_gordan, config
from e3nn_jax._src.util.sympy import sqrtQarray_to_sympy
from e3nn_jax._src.utils.sympy import sqrtQarray_to_sympy


def sh(
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/_src/tensor_products.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jax
import jax.numpy as jnp
from e3nn_jax import FunctionalTensorProduct, Irrep, Irreps, IrrepsArray, config
from e3nn_jax._src.util.decorators import overload_for_irreps_without_array
from e3nn_jax._src.utils.decorators import overload_for_irreps_without_array
from e3nn_jax._src.basic import _align_two_irreps_arrays


Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from e3nn_jax._src.util.jit import jit_code
from e3nn_jax._src.utils.jit import jit_code


def test_jit_code():
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
49 changes: 47 additions & 2 deletions e3nn_jax/_src/util/test.py → e3nn_jax/_src/utils/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,31 @@
import numpy as np

import e3nn_jax as e3nn
from e3nn_jax._src.util.dtype import get_pytree_dtype
from e3nn_jax._src.utils.dtype import get_pytree_dtype


def equivariance_test(
fun: Callable[[e3nn.IrrepsArray], e3nn.IrrepsArray],
rng_key: jnp.ndarray,
*args,
):
r"""Test equivariance of a function.
Args:
fun: function to test
rng_key: random number generator key
*args: arguments to pass to fun
Returns:
out1, out2: outputs of R fun(args) and fun(R args) where R is a random rotation and inversion
Example:
>>> fun = e3nn.norm
>>> rng = jax.random.PRNGKey(0)
>>> x = e3nn.IrrepsArray("1e", jnp.array([0.0, 4.0, 3.0]))
>>> equivariance_test(fun, rng, x)
(1x0e [5.], 1x0e [5.])
"""
assert all(isinstance(arg, e3nn.IrrepsArray) for arg in args)
dtype = get_pytree_dtype(args)
if dtype.kind == "i":
Expand All @@ -35,6 +52,25 @@ def assert_equivariant(
atol: float = 1e-6,
rtol: float = 1e-6,
):
r"""Assert that a function is equivariant.
Args:
fun: function to test
rng_key: random number generator key
args_in (optional): inputs to pass to fun, irreps_in must be None
irreps_in (optional): irreps of inputs to pass to fun, args_in must be None
atol: absolute tolerance
rtol: relative tolerance
Examples:
>>> fun = e3nn.norm
>>> rng = jax.random.PRNGKey(0)
>>> x = e3nn.IrrepsArray("1e", jnp.array([0.0, 4.0, 3.0]))
>>> assert_equivariant(fun, rng, args_in=(x,))
We can also pass the irreps of the inputs instead of the inputs themselves:
>>> assert_equivariant(fun, rng, irreps_in=("1e",))
"""
if args_in is None and irreps_in is None:
raise ValueError("Either args_in or irreps_in must be provided")

Expand All @@ -50,7 +86,16 @@ def assert_(x, y):


def assert_output_dtype_matches_input_dtype(fun: Callable, *args, **kwargs):
"""Checks that the dtype of fun(*args, **kwargs) matches that of the input (*args, **kwargs)."""
"""Checks that the dtype of fun(*args, **kwargs) matches that of the input (*args, **kwargs).
Args:
fun: function to test
*args: arguments to pass to fun
**kwargs: keyword arguments to pass to fun
Raises:
AssertionError: if the dtype of fun(*args, **kwargs) does not match that of the input (*args, **kwargs).
"""
if not jax.config.read("jax_enable_x64"):
raise ValueError("This test requires jax_enable_x64=True")

Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/experimental/linear_shtp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import e3nn_jax as e3nn
from e3nn_jax.experimental.linear_shtp import LinearSHTP, shtp
from e3nn_jax.util import equivariance_test
from e3nn_jax.utils import equivariance_test


@pytest.mark.parametrize("mix", [True, False])
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/experimental/point_convolution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
MessagePassingConvolutionHaiku,
radial_basis,
)
from e3nn_jax.util import assert_equivariant, assert_output_dtype_matches_input_dtype
from e3nn_jax.utils import assert_equivariant, assert_output_dtype_matches_input_dtype


def test_point_convolution(keys):
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/experimental/transformer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import jax
import jax.numpy as jnp
from e3nn_jax.experimental.transformer import Transformer
from e3nn_jax.util import assert_equivariant
from e3nn_jax.utils import assert_equivariant


def test_transformer(keys):
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/experimental/voxel_convolution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import e3nn_jax as e3nn
from e3nn_jax.experimental.voxel_convolution import ConvolutionFlax
from e3nn_jax.util import assert_output_dtype_matches_input_dtype
from e3nn_jax.utils import assert_output_dtype_matches_input_dtype


def test_convolution(keys):
Expand Down
2 changes: 1 addition & 1 deletion e3nn_jax/util.py → e3nn_jax/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from e3nn_jax._src.util.test import (
from e3nn_jax._src.utils.test import (
assert_equivariant,
equivariance_test,
assert_output_dtype_matches_input_dtype,
Expand Down
2 changes: 1 addition & 1 deletion examples/tensor_product_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import jaxlib

import e3nn_jax as e3nn
from e3nn_jax._src.util.jit import jit_code
from e3nn_jax._src.utils.jit import jit_code


# https://stackoverflow.com/a/15008806/1008938
Expand Down

0 comments on commit 44a2203

Please sign in to comment.