Skip to content

Commit

Permalink
e3nn.utils.vmap
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Jun 22, 2023
1 parent cbc80a2 commit f409705
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 16 deletions.
3 changes: 3 additions & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- `e3nn.utils.vmap`

### Changed
- Simplify the tetris examples

Expand Down
2 changes: 2 additions & 0 deletions docs/api/utils.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
Utils
=====

.. autofunction:: e3nn_jax.utils.vmap

.. autofunction:: e3nn_jax.utils.equivariance_test

.. autofunction:: e3nn_jax.utils.assert_equivariant
Expand Down
15 changes: 5 additions & 10 deletions e3nn_jax/_src/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,15 +230,10 @@ def linear_vanilla(
for ins in linear.instructions
]
f = lambda x: linear(w, x)
input0 = input
for _ in range(input.ndim - 1):
f = jax.vmap(f)
input0 = input0[0]
f = e3nn.utils.vmap(f)

# vmap will drop the zero_flags, so we need to recompute them:
output0 = linear(w, input0)
output = f(input)
return IrrepsArray(output.irreps, output.array, zero_flags=output0.zero_flags)
return f(input)


def linear_indexed(
Expand Down Expand Up @@ -267,7 +262,7 @@ def linear_indexed(

f = lin
for _ in range(input.ndim - 1):
f = jax.vmap(f)
f = e3nn.utils.vmap(f)
return f(w, input)


Expand Down Expand Up @@ -307,7 +302,7 @@ def linear_mixed(

f = lin
for _ in range(input.ndim - 1):
f = jax.vmap(f)
f = e3nn.utils.vmap(f)
return f(w, input) # (..., irreps)


Expand Down Expand Up @@ -348,5 +343,5 @@ def linear_mixed_per_channel(

f = lin
for _ in range(input.ndim - 1):
f = jax.vmap(f)
f = e3nn.utils.vmap(f)
return f(w, input) # (..., num_channels, irreps)
2 changes: 1 addition & 1 deletion e3nn_jax/_src/tensor_product_with_spherical_harmonics.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def tensor_product_with_spherical_harmonics(

f = impl
for _ in range(len(leading_shape)):
f = jax.vmap(f, in_axes=(0, 0, None), out_axes=0)
f = e3nn.utils.vmap(f, in_axes=(0, 0, None), out_axes=0)

return f(input, vector, degree)

Expand Down
3 changes: 1 addition & 2 deletions e3nn_jax/_src/tensor_products.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from functools import partial
from typing import List, Optional

import jax
import jax.numpy as jnp

import e3nn_jax as e3nn
Expand All @@ -16,7 +15,7 @@ def wrapper(*args):
args = [arg.broadcast_to(leading_shape + (-1,)) for arg in args]
f = func
for _ in range(len(leading_shape)):
f = jax.vmap(f)
f = e3nn.utils.vmap(f)
return f(*args)

return wrapper
Expand Down
73 changes: 73 additions & 0 deletions e3nn_jax/_src/utils/vmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from typing import Any, Callable, Sequence, Union

import jax
from attr import attrib, attrs

import e3nn_jax as e3nn


def vmap(
fun: Callable[..., Any],
in_axes: Union[int, None, Sequence[Any]] = 0,
out_axes: Any = 0,
):
r"""Wrapper around :func:`jax.vmap` that handles :class:`e3nn_jax.IrrepsArray` objects.
Args:
fun: Function to be mapped.
in_axes: Specifies which axes to map over for the input arguments. See :func:`jax.vmap` for details.
out_axes: Specifies which axes to map over for the output arguments. See :func:`jax.vmap` for details.
Returns:
Batched/vectorized version of ``fun``.
Example:
>>> x = e3nn.from_chunks("0e + 0e", [jnp.ones((100, 1, 1)), None], (100,))
>>> x.zero_flags
(False, True)
>>> y = vmap(e3nn.scalar_activation)(x)
>>> y.zero_flags
(False, True)
"""
if in_axes == -1:
raise ValueError("in_axes=-1 is not supported for e3nn.vmap")
if out_axes == -1:
raise ValueError("out_axes=-1 is not supported for e3nn.vmap")

def to_via(x):
return _VIA(x) if isinstance(x, e3nn.IrrepsArray) else x

def from_via(x):
return x.a if isinstance(x, _VIA) else x

def inside_fun(*args, **kwargs):
args, kwargs = jax.tree_util.tree_map(
from_via, (args, kwargs), is_leaf=lambda x: isinstance(x, _VIA)
)
out = fun(*args, **kwargs)
return jax.tree_util.tree_map(
to_via, out, is_leaf=lambda x: isinstance(x, e3nn.IrrepsArray)
)

def outside_fun(*args, **kwargs):
args, kwargs = jax.tree_util.tree_map(
to_via, (args, kwargs), is_leaf=lambda x: isinstance(x, e3nn.IrrepsArray)
)
out = jax.vmap(inside_fun, in_axes, out_axes)(*args, **kwargs)
return jax.tree_util.tree_map(
from_via, out, is_leaf=lambda x: isinstance(x, _VIA)
)

return outside_fun


@attrs(frozen=True)
class _VIA:
a: e3nn.IrrepsArray = attrib()


jax.tree_util.register_pytree_node(
_VIA,
lambda x: ((x.a.array,), (x.a.irreps, x.a.zero_flags)),
lambda attrs, data: _VIA(e3nn.IrrepsArray(attrs[0], data[0], zero_flags=attrs[1])),
)
3 changes: 1 addition & 2 deletions e3nn_jax/experimental/voxel_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import flax
import haiku as hk
import jax
import jax.numpy as jnp
from jax import lax

Expand Down Expand Up @@ -218,7 +217,7 @@ def _kernel(

tp_right = tp.right
for _ in range(3):
tp_right = jax.vmap(tp_right, (0, 0), 0)
tp_right = e3nn.utils.vmap(tp_right, (0, 0), 0)
k = tp_right(ws, sh) # [x, y, z, irreps_in.dim, irreps_out.dim]

# self-connection, center of the kernel
Expand Down
2 changes: 2 additions & 0 deletions e3nn_jax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
equivariance_test,
assert_output_dtype_matches_input_dtype,
)
from e3nn_jax._src.utils.vmap import vmap

__all__ = [
"assert_equivariant",
"equivariance_test",
"assert_output_dtype_matches_input_dtype",
"vmap",
]
2 changes: 1 addition & 1 deletion examples/tetris_voxel.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __call__(self, x):

g = e3nn.gate
for _ in range(1 + 3):
g = jax.vmap(g)
g = e3nn.utils.vmap(g)

# Shallower and wider convolutions also works

Expand Down

0 comments on commit f409705

Please sign in to comment.