Skip to content

Commit

Permalink
add reduced_symmetric_tensor_product_basis
Browse files Browse the repository at this point in the history
  • Loading branch information
mariogeiger committed Sep 12, 2022
1 parent 5195089 commit 894b083
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 2 deletions.
2 changes: 2 additions & 0 deletions ChangeLog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- `e3nn.reduced_symmetric_tensor_product_basis(irreps: Irreps, order: int)`

## [0.9.0] - 2022-09-04
### Added
Expand Down
3 changes: 3 additions & 0 deletions docs/api/extra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,6 @@ Extra Stuff


.. autofunction:: e3nn_jax.reduced_tensor_product_basis


.. autofunction:: e3nn_jax.reduced_symmetric_tensor_product_basis
3 changes: 2 additions & 1 deletion e3nn_jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
from e3nn_jax._src.mlp import MultiLayerPerceptron
from e3nn_jax._src.graph_util import index_add, radius_graph
from e3nn_jax._src.poly_envelope import poly_envelope
from e3nn_jax._src.reduced_tensor_product import reduced_tensor_product_basis
from e3nn_jax._src.reduced_tensor_product import reduced_tensor_product_basis, reduced_symmetric_tensor_product_basis

__all__ = [
"config", # not in docs
Expand Down Expand Up @@ -132,4 +132,5 @@
"radius_graph",
"poly_envelope",
"reduced_tensor_product_basis",
"reduced_symmetric_tensor_product_basis",
]
24 changes: 23 additions & 1 deletion e3nn_jax/_src/reduced_tensor_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def reduced_tensor_product_basis(
Returns:
IrrepsArray: The change of basis
The shape is ``(d1, ..., dn, irreps.dim)``
The shape is ``(d1, ..., dn, irreps_out.dim)``
where ``di`` is the dimension of the index ``i`` and ``n`` is the number of indices in the formula.
Example:
Expand Down Expand Up @@ -76,6 +76,28 @@ def reduced_tensor_product_basis(
return _reduced_tensor_product_basis(irreps, formulas, epsilon)


def reduced_symmetric_tensor_product_basis(
irreps: e3nn.Irreps,
order: int,
*,
epsilon: float = 1e-5,
):
r"""Reduce a symmetric tensor product.
Args:
irreps (Irreps): the irreps of each index.
order (int): the order of the tensor product. i.e. the number of indices.
Returns:
IrrepsArray: The change of basis
The shape is ``(d, ..., d, irreps_out.dim)``
where ``d`` is the dimension of ``irreps``.
"""
irreps = e3nn.Irreps(irreps)
formulas: FrozenSet[Tuple[int, Tuple[int, ...]]] = frozenset((1, p) for p in itertools.permutations(range(order)))
return _reduced_tensor_product_basis(tuple([irreps] * order), formulas, epsilon)


@functools.lru_cache(maxsize=None)
def _reduced_tensor_product_basis(
irreps: Tuple[e3nn.Irreps], formulas: FrozenSet[Tuple[int, Tuple[int, ...]]], epsilon: float
Expand Down
6 changes: 6 additions & 0 deletions e3nn_jax/_src/reduced_tensor_product_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,9 @@ def test_reduce_tensor_elasticity_tensor_parity():
np.testing.assert_allclose(Q.array, np.einsum("ijklx->jiklx", Q.array), atol=1e-6, rtol=1e-6)
np.testing.assert_allclose(Q.array, np.einsum("ijklx->ijlkx", Q.array), atol=1e-6, rtol=1e-6)
np.testing.assert_allclose(Q.array, np.einsum("ijklx->klijx", Q.array), atol=1e-6, rtol=1e-6)


def test_reduced_symmetric_tensor_product_basis():
Q = e3nn.reduced_symmetric_tensor_product_basis("1e", 5)
P = e3nn.reduced_tensor_product_basis("ijklm=jiklm=jklmi", i="1e")
np.testing.assert_equal(Q.array, P.array)

0 comments on commit 894b083

Please sign in to comment.