
2023-06-19

@mariogeiger released this on 19 Jun 2023 at 15:51

Highlights

Added Irreps.set_mul to set the multiplicity of every irrep. Note that it is not an in-place operation: it returns a new Irreps, so the result must be assigned.

import e3nn_jax as e3nn

irreps = e3nn.Irreps("0e + 1o")
irreps = irreps.set_mul(2)  # returns a new Irreps; the original is unchanged
# 2x0e+2x1o
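
As a rough mental model of why set_mul has to be reassigned, the sketch below represents irreps as an immutable tuple of (mul, l, parity) records and rebuilds it with new multiplicities. MulIrrep, parse_irreps, set_mul, to_str and dim are hypothetical helpers written for illustration only; they are not e3nn's API, and the real Irreps parser accepts more syntax than this toy one.

```python
from typing import NamedTuple, Tuple


class MulIrrep(NamedTuple):
    """Toy record (hypothetical, not e3nn's class): multiplicity, degree l, parity."""
    mul: int
    l: int
    p: str  # "e" (even) or "o" (odd)


def parse_irreps(s: str) -> Tuple[MulIrrep, ...]:
    """Parse strings such as "0e + 2x1o"; a missing multiplicity means 1."""
    out = []
    for term in s.replace(" ", "").split("+"):
        mul, _, ir = term.rpartition("x")
        out.append(MulIrrep(int(mul) if mul else 1, int(ir[:-1]), ir[-1]))
    return tuple(out)


def set_mul(irreps: Tuple[MulIrrep, ...], mul: int) -> Tuple[MulIrrep, ...]:
    """Return a NEW tuple with every multiplicity set to `mul`; the input is unchanged."""
    return tuple(ir._replace(mul=mul) for ir in irreps)


def to_str(irreps: Tuple[MulIrrep, ...]) -> str:
    return "+".join(f"{ir.mul}x{ir.l}{ir.p}" for ir in irreps)


def dim(irreps: Tuple[MulIrrep, ...]) -> int:
    # Each irrep of degree l has 2l + 1 components, repeated `mul` times
    return sum(ir.mul * (2 * ir.l + 1) for ir in irreps)


irreps = parse_irreps("0e + 1o")
doubled = set_mul(irreps, 2)
print(to_str(irreps), dim(irreps))    # 1x0e+1x1o 4
print(to_str(doubled), dim(doubled))  # 2x0e+2x1o 8
```

Since tuples cannot be mutated, any "setter" on such a type necessarily returns a fresh object, which is why the release snippet rebinds irreps.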

Added the lmax option to Irreps.filter and IrrepsArray.filter, which drops every irrep with l > lmax.

irreps = irreps.filter(lmax=0)
irreps
# 2x0e
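
IrrepsArray.filter(lmax=...) has to drop the matching components of the underlying array, not just the irreps. Assuming the usual layout, where each entry mul x irrep occupies mul * (2l + 1) consecutive components of the last axis, a NumPy sketch of that bookkeeping looks like this. filter_lmax is a hypothetical helper, not e3nn's implementation.

```python
import numpy as np

# Toy irreps: list of (mul, l, parity); 2x0e + 2x1o has dim 2*1 + 2*3 = 8
irreps = [(2, 0, "e"), (2, 1, "o")]


def filter_lmax(irreps, array, lmax):
    """Keep only entries with l <= lmax, plus the matching slices of `array`."""
    kept, chunks, start = [], [], 0
    for mul, l, p in irreps:
        size = mul * (2 * l + 1)  # components occupied by this entry
        if l <= lmax:
            kept.append((mul, l, p))
            chunks.append(array[..., start:start + size])
        start += size
    new_array = np.concatenate(chunks, axis=-1) if chunks else array[..., :0]
    return kept, new_array


x = np.arange(8.0)  # components: [0e, 0e, 1o (3 comps), 1o (3 comps)]
kept, y = filter_lmax(irreps, x, lmax=0)
print(kept)  # [(2, 0, 'e')]
print(y)     # [0. 1.]
```

Slicing with `...` keeps any leading batch axes, so the same helper works for arrays of shape (batch, 8).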

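e3nn.utils is now directly accessible as a submodule and has documentation. For example, e3nn.utils.equivariance_test compares applying a function to transformed inputs with transforming the function's output: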

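import jax
import jax.numpy as jnp
import e3nn_jax as e3nn

x1 = e3nn.IrrepsArray("1o", jnp.array([1.0, 3.0, 4.0]))
x2 = e3nn.IrrepsArray("1o", jnp.array([0.0, 1.0, 4.0]))
y1, y2 = e3nn.utils.equivariance_test(
    e3nn.tensor_product, jax.random.PRNGKey(0), x1, x2
)
# y1 = (R x1) ⊗ (R x2): transform the inputs, then take the tensor product
# y2 = R (x1 ⊗ x2): take the tensor product, then transform the output
# For an equivariant function, y1 and y2 agree up to numerical error.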
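
The property being checked can be illustrated without e3nn. The NumPy sketch below uses the dot and cross products of two vectors, which (up to normalization) are the 0e and 1e parts of 1o ⊗ 1o, and a random proper rotation for R. random_rotation and dot_and_cross are hypothetical helpers; this is not how e3nn.utils.equivariance_test is implemented, and it does not exercise parity (inversion).

```python
import numpy as np


def random_rotation(rng: np.random.Generator) -> np.ndarray:
    """Sample a 3x3 proper rotation matrix (orthogonal, det = +1) via QR."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))  # standard sign fix for a uniformly random orthogonal matrix
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]  # flip one column: reflection -> rotation
    return q


def dot_and_cross(x1: np.ndarray, x2: np.ndarray):
    """Two pieces of 1o x 1o (up to normalization): a 0e scalar and a 1e vector."""
    return np.dot(x1, x2), np.cross(x1, x2)


rng = np.random.default_rng(0)
R = random_rotation(rng)
x1 = np.array([1.0, 3.0, 4.0])
x2 = np.array([0.0, 1.0, 4.0])

# Like y1: transform the inputs first, then apply the operation
s1, v1 = dot_and_cross(R @ x1, R @ x2)
# Like y2: apply the operation first, then transform the output
s, v = dot_and_cross(x1, x2)
s2, v2 = s, R @ v  # a 0e scalar is invariant; a 1e vector rotates with R

print(np.allclose(s1, s2), np.allclose(v1, v2))  # True True
```

The cross product only commutes with R because det(R) = +1; under an improper transformation the 1e output would pick up a sign relative to a 1o vector, which is exactly the distinction the parity labels track.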

Changelog

Changed

  • [BREAKING] Renamed e3nn.util to e3nn.utils

Added

  • Irreps.set_mul(int) to set the multiplicity of all irreps
  • Irreps.filter(lmax=int) to filter out irreps with l > lmax
  • IrrepsArray.filter(lmax=int) to filter out irreps with l > lmax
  • IrrepsArray.__radd__ and IrrepsArray.__rsub__ to support scalar + IrrepsArray and scalar - IrrepsArray
  • 0 + IrrepsArray and 0 - IrrepsArray are now always accepted as special cases.
  • Support for IrrepsArray / array
  • e3nn.utils as a directly accessible submodule
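
The scalar special cases above rely on Python's reflected operators: for `0 + x`, `int.__add__` returns `NotImplemented`, so Python falls back to `x.__radd__(0)`. The toy class below is a hypothetical sketch of that mechanism, not e3nn's implementation; the rule that nonzero constants are only accepted when every irrep is 0e is an assumption made for illustration.

```python
import numpy as np


class ToyIrrepsArray:
    """Hypothetical stand-in for an irreps-tagged array (not e3nn's class).

    `is_scalar` is True when every irrep is 0e, i.e. the data is invariant.
    """

    def __init__(self, array, is_scalar):
        self.array = np.asarray(array, dtype=float)
        self.is_scalar = is_scalar

    def _check_constant(self, other):
        # Assumption: adding a nonzero constant is only equivariant for 0e data,
        # but adding 0 never changes anything, so 0 is accepted unconditionally.
        if not (isinstance(other, (int, float)) and (other == 0 or self.is_scalar)):
            raise ValueError("can only add a constant to scalar (0e) irreps, or 0")

    def __add__(self, other):
        if isinstance(other, ToyIrrepsArray):
            return ToyIrrepsArray(self.array + other.array, self.is_scalar and other.is_scalar)
        self._check_constant(other)
        return ToyIrrepsArray(self.array + other, self.is_scalar)

    # Called for `c + x` after the left operand's __add__ returns NotImplemented
    __radd__ = __add__

    def __rsub__(self, other):
        # Called for `c - x`; computes c - array elementwise
        self._check_constant(other)
        return ToyIrrepsArray(other - self.array, self.is_scalar)


v = ToyIrrepsArray([1.0, 2.0, 3.0], is_scalar=False)  # e.g. a 1o vector
s = ToyIrrepsArray([5.0], is_scalar=True)             # a 0e scalar

print((0 + v).array)         # [1. 2. 3.]
print((0 - v).array)         # [-1. -2. -3.]
print((1 + s).array)         # [6.]
print(sum([v, v, v]).array)  # [3. 6. 9.]
```

Accepting `0 + x` unconditionally is also what lets the built-in `sum()` work on a list of arrays, because `sum` starts from the integer 0.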

Fixed

  • e3nn.scatter operations now handle indices with ndim > 1
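
What "indices with ndim > 1" means is sketched below under an assumption: the index array covers the leading axes of the data (dst.shape == data.shape[:dst.ndim]) and the remaining axes are features. scatter_sum_nd is a hypothetical NumPy helper built on np.add.at, not e3nn's code; it flattens the indexed axes and then performs an ordinary 1-D scatter-sum.

```python
import numpy as np


def scatter_sum_nd(data: np.ndarray, dst: np.ndarray, output_size: int) -> np.ndarray:
    """Toy scatter-sum: out[k] = sum of data[i...] over all i with dst[i...] == k.

    `dst` may have any number of dimensions, as long as
    dst.shape == data.shape[:dst.ndim]; trailing axes of `data` are features.
    """
    dst = np.asarray(dst)
    assert data.shape[: dst.ndim] == dst.shape
    feat_shape = data.shape[dst.ndim:]
    flat_dst = dst.reshape(-1)                    # (n,)
    flat_data = data.reshape((-1,) + feat_shape)  # (n, *features), same C order as dst
    out = np.zeros((output_size,) + feat_shape, dtype=data.dtype)
    np.add.at(out, flat_dst, flat_data)           # unbuffered: repeated indices accumulate
    return out


data = np.arange(12.0).reshape(2, 3, 2)  # a 2x3 grid of 2-feature rows
dst = np.array([[0, 1, 0], [2, 1, 0]])   # index array with ndim = 2
out = scatter_sum_nd(data, dst, output_size=3)
# rows: [14, 17], [10, 12], [6, 7]
```

np.add.at is used instead of `out[flat_dst] += flat_data` because fancy-index assignment would keep only one contribution per repeated index.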