Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Nov 6, 2023
1 parent e366b3c commit 3e12290
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
11 changes: 6 additions & 5 deletions e3nn_jax/_src/irreps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,11 +1058,11 @@ def rot_y(phi):

if b is not None:
if l < len(Jd):
J = Jd[l]
J = Jd[l].astype(b.dtype)
R += [J @ rot_y(b) @ J]
else:
X = generators(l)
R += [jax.scipy.linalg.expm(b * X[0])]
R += [jax.scipy.linalg.expm(b.astype(X.dtype) * X[0]).astype(b.dtype)]

if c is not None:
R += [rot_y(c)]
Expand Down Expand Up @@ -1094,11 +1094,12 @@ def _wigner_D_from_log_coordinates(l: int, log_coordinates: jnp.ndarray) -> jnp.
"""
X = generators(l)

def func(log_coordinates):
return jax.scipy.linalg.expm(jnp.einsum("a,aij->ij", log_coordinates, X))
def func(log):
log = log.astype(X.dtype)
return jax.scipy.linalg.expm(jnp.einsum("a,aij->ij", log, X))

f = func
for _ in range(log_coordinates.ndim - 1):
f = jax.vmap(f)

return f(log_coordinates)
return f(log_coordinates).astype(log_coordinates.dtype)
28 changes: 25 additions & 3 deletions e3nn_jax/_src/irreps_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jax
import numpy as np
import pytest

import jax.numpy as jnp
import e3nn_jax as e3nn


Expand Down Expand Up @@ -141,9 +141,31 @@ def test_D(keys, ir):
jax.config.update("jax_enable_x64", True)

ir = e3nn.Irrep(ir)
angles = e3nn.rand_angles(keys[0])
angles = e3nn.rand_angles(keys[0], dtype=np.float64)
Da = ir.D_from_angles(*angles)
w = e3nn.angles_to_log_coordinates(*angles)
Dw = ir.D_from_log_coordinates(w)

np.testing.assert_allclose(Da, Dw, atol=1e-10, rtol=0.0008)
assert Dw.dtype == np.float64, "D_from_log_coordinates should return float64"
assert Da.dtype == np.float64, "D_from_angles should return float64"
np.testing.assert_allclose(Da, Dw, atol=1e-10, rtol=0.002)


@pytest.mark.parametrize("ir", ["0e", "1e", "2e", "3e", "4e", "12e"])
def test_dtype_D_from_angles(ir):
jax.config.update("jax_enable_x64", True)

ir = e3nn.Irrep(ir)
e3nn.utils.assert_output_dtype_matches_input_dtype(
ir.D_from_angles, jnp.array(1.0), jnp.array(1.0), jnp.array(1.0)
)


@pytest.mark.parametrize("ir", ["0e", "1e", "2e", "3e", "4e", "12e"])
def test_dtype_D_from_log_coordinates(ir):
jax.config.update("jax_enable_x64", True)

ir = e3nn.Irrep(ir)
e3nn.utils.assert_output_dtype_matches_input_dtype(
ir.D_from_log_coordinates, jnp.array([1.0, 1.0, 0.0])
)

0 comments on commit 3e12290

Please sign in to comment.