Please always check the ChangeLog for breaking changes.
To install the latest released version:
pip install --upgrade e3nn-jax
To install the latest GitHub version:
pip install git+https://github.com/e3nn/e3nn-jax.git
To install from a local copy for development, we recommend creating a virtual environment:
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
To check that the tests are running:
pip install pytest
pytest tests/tensor_products_test.py
- No more shared_weights and internal_weights in TensorProduct. Extensive use of jax.vmap instead (see the examples in the documentation).
- Support of the Python structure IrrepsArray, which contains a contiguous version of the data and a list of jnp.ndarray for the data. This avoids an unnecessary jnp.concatenate followed by indexing to reverse the concatenation (even though jax.jit is probably able to unroll the concatenations).
- Support of None in the list of jnp.ndarray to avoid unnecessary computation with zeros (basically imposing 0 * x = 0, which is not simplified by default by jax because 0 * nan = nan).
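As a rough illustration of the first point, the sketch below shows how jax.vmap can express both shared and per-sample weights through its in_axes argument. It uses plain JAX only: the function bilinear is a hypothetical stand-in for a weighted tensor product and is not part of the e3nn-jax API, and all shapes are chosen arbitrarily.

```python
import jax
import jax.numpy as jnp


def bilinear(w, x, y):
    # Hypothetical stand-in for a weighted tensor product (not the e3nn-jax API):
    # z_k = sum_ij w_ijk * x_i * y_j, written for a single, unbatched sample.
    return jnp.einsum("ijk,i,j->k", w, x, y)


batch, dx, dy, dz = 4, 3, 2, 5
x = jnp.ones((batch, dx))
y = jnp.ones((batch, dy))

# Shared weights: in_axes=None for w reuses one weight tensor for every sample.
w_shared = jnp.ones((dx, dy, dz))
z_shared = jax.vmap(bilinear, in_axes=(None, 0, 0))(w_shared, x, y)

# Per-sample weights: w carries a leading batch axis that vmap maps over.
w_per_sample = jnp.ones((batch, dx, dy, dz))
z_per_sample = jax.vmap(bilinear, in_axes=(0, 0, 0))(w_per_sample, x, y)

print(z_shared.shape, z_per_sample.shape)
```

With this pattern the function itself never needs a flag to say whether weights are shared; the caller decides by choosing in_axes.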
The examples have been moved to the documentation.
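The None convention from the list above can be sketched independently of the library. The helpers scale_chunks and add_chunks below are hypothetical names introduced only for this illustration; they assume a list of arrays in which None marks a chunk known to be exactly zero.

```python
import jax.numpy as jnp

# IEEE arithmetic gives 0 * nan = nan, so 0 * x cannot be simplified to 0 in general.
zero_times_nan = jnp.float32(0.0) * jnp.nan


def scale_chunks(chunks, factor):
    # A None chunk is known to be zero: keep it as None instead of multiplying zeros.
    return [None if c is None else factor * c for c in chunks]


def add_chunks(a, b):
    # Zero (None) plus v is v; None plus None stays None.
    out = []
    for u, v in zip(a, b):
        if u is None:
            out.append(v)
        elif v is None:
            out.append(u)
        else:
            out.append(u + v)
    return out


a = [jnp.ones(3), None, jnp.ones(2)]
b = [None, None, 2.0 * jnp.ones(2)]
summed = add_chunks(scale_chunks(a, 3.0), b)
```

Here the middle chunk is never materialized: it stays None through both operations, which is the computation the convention is meant to save.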