diff --git a/ChangeLog.md b/ChangeLog.md index 40cdc678..ad47e28d 100644 --- a/ChangeLog.md +++ b/ChangeLog.md @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Added +- `e3nn.reduced_symmetric_tensor_product_basis(irreps: Irreps, order: int)` ## [0.9.0] - 2022-09-04 ### Added diff --git a/docs/api/extra.rst b/docs/api/extra.rst index d141c8e9..f2868df3 100644 --- a/docs/api/extra.rst +++ b/docs/api/extra.rst @@ -24,3 +24,6 @@ Extra Stuff .. autofunction:: e3nn_jax.reduced_tensor_product_basis + + +.. autofunction:: e3nn_jax.reduced_symmetric_tensor_product_basis diff --git a/e3nn_jax/__init__.py b/e3nn_jax/__init__.py index cdd5fb5d..9711345d 100644 --- a/e3nn_jax/__init__.py +++ b/e3nn_jax/__init__.py @@ -58,7 +58,7 @@ from e3nn_jax._src.mlp import MultiLayerPerceptron from e3nn_jax._src.graph_util import index_add, radius_graph from e3nn_jax._src.poly_envelope import poly_envelope -from e3nn_jax._src.reduced_tensor_product import reduced_tensor_product_basis +from e3nn_jax._src.reduced_tensor_product import reduced_tensor_product_basis, reduced_symmetric_tensor_product_basis __all__ = [ "config", # not in docs @@ -132,4 +132,5 @@ "radius_graph", "poly_envelope", "reduced_tensor_product_basis", + "reduced_symmetric_tensor_product_basis", ] diff --git a/e3nn_jax/_src/reduced_tensor_product.py b/e3nn_jax/_src/reduced_tensor_product.py index 6d469c66..dc95072e 100644 --- a/e3nn_jax/_src/reduced_tensor_product.py +++ b/e3nn_jax/_src/reduced_tensor_product.py @@ -26,7 +26,7 @@ def reduced_tensor_product_basis( Returns: IrrepsArray: The change of basis - The shape is ``(d1, ..., dn, irreps.dim)`` + The shape is ``(d1, ..., dn, irreps_out.dim)`` where ``di`` is the dimension of the index ``i`` and ``n`` is the number of indices in the formula. Example: @@ -76,6 +76,28 @@ def reduced_tensor_product_basis( return _reduced_tensor_product_basis(irreps, formulas, epsilon) +def reduced_symmetric_tensor_product_basis( + irreps: e3nn.Irreps, + order: int, + *, + epsilon: float = 1e-5, +): + r"""Reduce a symmetric tensor product. + + Args: + irreps (Irreps): the irreps of each index. + order (int): the order of the tensor product. i.e. the number of indices. + + Returns: + IrrepsArray: The change of basis + The shape is ``(d, ..., d, irreps_out.dim)`` + where ``d`` is the dimension of ``irreps``. + """ + irreps = e3nn.Irreps(irreps) + formulas: FrozenSet[Tuple[int, Tuple[int, ...]]] = frozenset((1, p) for p in itertools.permutations(range(order))) + return _reduced_tensor_product_basis(tuple([irreps] * order), formulas, epsilon) + + @functools.lru_cache(maxsize=None) def _reduced_tensor_product_basis( irreps: Tuple[e3nn.Irreps], formulas: FrozenSet[Tuple[int, Tuple[int, ...]]], epsilon: float diff --git a/e3nn_jax/_src/reduced_tensor_product_test.py b/e3nn_jax/_src/reduced_tensor_product_test.py index f45e39fc..de8b28b7 100644 --- a/e3nn_jax/_src/reduced_tensor_product_test.py +++ b/e3nn_jax/_src/reduced_tensor_product_test.py @@ -27,3 +27,9 @@ def test_reduce_tensor_elasticity_tensor_parity(): np.testing.assert_allclose(Q.array, np.einsum("ijklx->jiklx", Q.array), atol=1e-6, rtol=1e-6) np.testing.assert_allclose(Q.array, np.einsum("ijklx->ijlkx", Q.array), atol=1e-6, rtol=1e-6) np.testing.assert_allclose(Q.array, np.einsum("ijklx->klijx", Q.array), atol=1e-6, rtol=1e-6) + + +def test_reduced_symmetric_tensor_product_basis(): + Q = e3nn.reduced_symmetric_tensor_product_basis("1e", 5) + P = e3nn.reduced_tensor_product_basis("ijklm=jiklm=jklmi", i="1e") + np.testing.assert_equal(Q.array, P.array)