e3nn/e3nn-jax

Latest commit: a1330a2 · Sep 27, 2022

e3nn-jax

💥 Warning 💥

Please always check the ChangeLog for breaking changes.

Installation

To install the latest released version:

pip install --upgrade e3nn-jax

To install the latest GitHub version:

pip install git+https://github.com/e3nn/e3nn-jax.git

To install from a local copy for development, we recommend creating a virtual environment:

python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt

To check that the tests pass:

pip install pytest
pytest tests/tensor_products_test.py

What is different from the PyTorch version?

  • No more shared_weights and internal_weights in TensorProduct. jax.vmap is used extensively instead (see the examples in the documentation).
  • Support for IrrepsArray, a Python structure that holds both a contiguous version of the data and a list of jnp.ndarray for the same data. This avoids an unnecessary jnp.concatenate followed by indexing to reverse the concatenation (even though jax.jit is probably able to unroll the concatenations).
  • Support for None in the list of jnp.ndarray to avoid unnecessary computation with zeros (essentially imposing 0 * x = 0, which jax does not simplify by default because 0 * nan = nan).
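The ideas in the list above can be sketched in plain JAX. This is an illustration only and does not use the e3nn-jax API: weighted_product and add_chunks are hypothetical helpers invented for this sketch, and the bilinear map merely stands in for a weighted tensor product. The first part shows how the in_axes argument of jax.vmap can take over the role of a shared_weights flag; the second part shows why 0 * x cannot be simplified automatically and how None can mark an all-zero chunk.

```python
import jax
import jax.numpy as jnp


def weighted_product(w, x, y):
    """Hypothetical stand-in for a weighted tensor product (a bilinear map).

    w: (dim_out, dim_x, dim_y), x: (dim_x,), y: (dim_y,) -> (dim_out,)
    """
    return jnp.einsum("kij,i,j->k", w, x, y)


batch, dim_x, dim_y, dim_out = 4, 3, 3, 5
kw, kx, ky = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (batch, dim_x))
y = jax.random.normal(ky, (batch, dim_y))

# Shared weights: one weight tensor reused for every batch entry (in_axes=None).
w_shared = jax.random.normal(kw, (dim_out, dim_x, dim_y))
out_shared = jax.vmap(weighted_product, in_axes=(None, 0, 0))(w_shared, x, y)

# Per-sample weights: one weight tensor per batch entry (in_axes=0).
w_per_sample = jnp.stack([w_shared] * batch)
out_per_sample = jax.vmap(weighted_product, in_axes=(0, 0, 0))(w_per_sample, x, y)

# Why 0 * x is not simplified to 0: the identity fails when x is nan.
zero_times_nan = jnp.float32(0.0) * jnp.float32(jnp.nan)


def add_chunks(a, b):
    """Add two lists of chunks, where None stands for an all-zero chunk."""
    out = []
    for u, v in zip(a, b):
        if u is None:
            out.append(v)  # 0 + v = v, so no arithmetic is performed
        elif v is None:
            out.append(u)
        else:
            out.append(u + v)
    return out


chunks = add_chunks([jnp.ones(2), None, None], [jnp.ones(2), jnp.ones(3), None])
```

Here the choice between shared and per-sample weights lives entirely in the in_axes argument, so the function itself needs no flag. In add_chunks, a chunk that is None on both sides stays None, so later operations can keep skipping it.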

Examples

The examples have moved to the documentation.