From 3e12290a259b0d80e54b45fad0e4b7b0f531f07c Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 6 Nov 2023 15:37:39 +0100 Subject: [PATCH] fix tests --- e3nn_jax/_src/irreps.py | 11 ++++++----- e3nn_jax/_src/irreps_test.py | 28 +++++++++++++++++++++++++--- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/e3nn_jax/_src/irreps.py b/e3nn_jax/_src/irreps.py index adfd76e..f444ae7 100644 --- a/e3nn_jax/_src/irreps.py +++ b/e3nn_jax/_src/irreps.py @@ -1058,11 +1058,11 @@ def rot_y(phi): if b is not None: if l < len(Jd): - J = Jd[l] + J = Jd[l].astype(b.dtype) R += [J @ rot_y(b) @ J] else: X = generators(l) - R += [jax.scipy.linalg.expm(b * X[0])] + R += [jax.scipy.linalg.expm(b.astype(X.dtype) * X[0]).astype(b.dtype)] if c is not None: R += [rot_y(c)] @@ -1094,11 +1094,12 @@ def _wigner_D_from_log_coordinates(l: int, log_coordinates: jnp.ndarray) -> jnp. """ X = generators(l) - def func(log_coordinates): - return jax.scipy.linalg.expm(jnp.einsum("a,aij->ij", log_coordinates, X)) + def func(log): + log = log.astype(X.dtype) + return jax.scipy.linalg.expm(jnp.einsum("a,aij->ij", log, X)) f = func for _ in range(log_coordinates.ndim - 1): f = jax.vmap(f) - return f(log_coordinates) + return f(log_coordinates).astype(log_coordinates.dtype) diff --git a/e3nn_jax/_src/irreps_test.py b/e3nn_jax/_src/irreps_test.py index d0c14ad..64b8d57 100644 --- a/e3nn_jax/_src/irreps_test.py +++ b/e3nn_jax/_src/irreps_test.py @@ -1,7 +1,7 @@ import jax import numpy as np import pytest - +import jax.numpy as jnp import e3nn_jax as e3nn @@ -141,9 +141,31 @@ def test_D(keys, ir): jax.config.update("jax_enable_x64", True) ir = e3nn.Irrep(ir) - angles = e3nn.rand_angles(keys[0]) + angles = e3nn.rand_angles(keys[0], dtype=np.float64) Da = ir.D_from_angles(*angles) w = e3nn.angles_to_log_coordinates(*angles) Dw = ir.D_from_log_coordinates(w) - np.testing.assert_allclose(Da, Dw, atol=1e-10, rtol=0.0008) + assert Dw.dtype == np.float64, "D_from_log_coordinates should return float64" + assert Da.dtype == np.float64, "D_from_angles should return float64" + np.testing.assert_allclose(Da, Dw, atol=1e-10, rtol=0.002) + + +@pytest.mark.parametrize("ir", ["0e", "1e", "2e", "3e", "4e", "12e"]) +def test_dtype_D_from_angles(ir): + jax.config.update("jax_enable_x64", True) + + ir = e3nn.Irrep(ir) + e3nn.utils.assert_output_dtype_matches_input_dtype( + ir.D_from_angles, jnp.array(1.0), jnp.array(1.0), jnp.array(1.0) + ) + + +@pytest.mark.parametrize("ir", ["0e", "1e", "2e", "3e", "4e", "12e"]) +def test_dtype_D_from_log_coordinates(ir): + jax.config.update("jax_enable_x64", True) + + ir = e3nn.Irrep(ir) + e3nn.utils.assert_output_dtype_matches_input_dtype( + ir.D_from_log_coordinates, jnp.array([1.0, 1.0, 0.0]) + )