diff --git a/e3nn_jax/__init__.py b/e3nn_jax/__init__.py index f8326303..b0745704 100644 --- a/e3nn_jax/__init__.py +++ b/e3nn_jax/__init__.py @@ -55,7 +55,13 @@ from e3nn_jax._src.basic import concatenate, stack, mean, norm, normal, dot, cross from e3nn_jax._src.basic import sum_ as sum from e3nn_jax._src.spherical_harmonics import spherical_harmonics, sh, legendre -from e3nn_jax._src.radial import sus, soft_one_hot_linspace, bessel, poly_envelope, soft_envelope +from e3nn_jax._src.radial import ( + sus, + soft_one_hot_linspace, + bessel, + poly_envelope, + soft_envelope, +) from e3nn_jax._src.instruction import Instruction from e3nn_jax._src.linear import FunctionalLinear from e3nn_jax._src.core_tensor_product import FunctionalTensorProduct @@ -66,7 +72,12 @@ tensor_square, ) from e3nn_jax._src.grad import grad -from e3nn_jax._src.activation import soft_odd, scalar_activation, normalize_function, norm_activation +from e3nn_jax._src.activation import ( + soft_odd, + scalar_activation, + normalize_function, + norm_activation, +) from e3nn_jax._src.gate import gate from e3nn_jax._src.radius_graph import radius_graph from e3nn_jax._src.scatter import index_add, scatter_sum, scatter_max @@ -75,8 +86,17 @@ reduced_symmetric_tensor_product_basis, reduced_antisymmetric_tensor_product_basis, ) -from e3nn_jax._src.s2grid import s2_irreps, to_s2grid, to_s2point, from_s2grid, s2_dirac, SphericalSignal -from e3nn_jax._src.tensor_product_with_spherical_harmonics import tensor_product_with_spherical_harmonics +from e3nn_jax._src.s2grid import ( + s2_irreps, + to_s2grid, + to_s2point, + from_s2grid, + s2_dirac, + SphericalSignal, +) +from e3nn_jax._src.tensor_product_with_spherical_harmonics import ( + tensor_product_with_spherical_harmonics, +) # make submodules flax and haiku available from e3nn_jax import flax, haiku diff --git a/e3nn_jax/_src/J.py b/e3nn_jax/_src/J.py index 0dfbb833..a5291d96 100644 --- a/e3nn_jax/_src/J.py +++ b/e3nn_jax/_src/J.py @@ -71,52 +71,472 @@ J6 = np.array( [ - [0, 0, 0, 0, 0, 0, 0, -3 * sqrt(22) / 16, 0, sqrt(55) / 16, 0, -sqrt(3) / 16, 0], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -3 * sqrt(22) / 16, + 0, + sqrt(55) / 16, + 0, + -sqrt(3) / 16, + 0, + ], [0, 5 / 16, 0, -sqrt(165) / 16, 0, sqrt(66) / 16, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, -sqrt(3) / 4, 0, -sqrt(30) / 8, 0, sqrt(22) / 8, 0], [0, -sqrt(165) / 16, 0, 1 / 16, 0, 3 * sqrt(10) / 16, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, -sqrt(10) / 16, 0, -9 / 16, 0, -sqrt(165) / 16, 0], [0, sqrt(66) / 16, 0, 3 * sqrt(10) / 16, 0, 5 / 8, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, -5 / 16, 0, -sqrt(210) / 32, 0, -3 * sqrt(7) / 16, 0, -sqrt(462) / 32], - [-3 * sqrt(22) / 16, 0, -sqrt(3) / 4, 0, -sqrt(10) / 16, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, -sqrt(210) / 32, 0, -17 / 32, 0, -sqrt(30) / 32, 0, 3 * sqrt(55) / 32], + [ + 0, + 0, + 0, + 0, + 0, + 0, + -5 / 16, + 0, + -sqrt(210) / 32, + 0, + -3 * sqrt(7) / 16, + 0, + -sqrt(462) / 32, + ], + [ + -3 * sqrt(22) / 16, + 0, + -sqrt(3) / 4, + 0, + -sqrt(10) / 16, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + -sqrt(210) / 32, + 0, + -17 / 32, + 0, + -sqrt(30) / 32, + 0, + 3 * sqrt(55) / 32, + ], [sqrt(55) / 16, 0, -sqrt(30) / 8, 0, -9 / 16, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, -3 * sqrt(7) / 16, 0, -sqrt(30) / 32, 0, 13 / 16, 0, -sqrt(66) / 32], + [ + 0, + 0, + 0, + 0, + 0, + 0, + -3 * sqrt(7) / 16, + 0, + -sqrt(30) / 32, + 0, + 13 / 16, + 0, + -sqrt(66) / 32, + ], [-sqrt(3) / 16, 0, sqrt(22) / 8, 0, -sqrt(165) / 16, 0, 
0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, -sqrt(462) / 32, 0, 3 * sqrt(55) / 32, 0, -sqrt(66) / 32, 0, 1 / 32], + [ + 0, + 0, + 0, + 0, + 0, + 0, + -sqrt(462) / 32, + 0, + 3 * sqrt(55) / 32, + 0, + -sqrt(66) / 32, + 0, + 1 / 32, + ], ] ) J7 = np.array( [ - [0, 0, 0, 0, 0, 0, 0, sqrt(429) / 32, 0, -sqrt(2002) / 64, 0, sqrt(91) / 32, 0, -sqrt(14) / 64, 0], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + sqrt(429) / 32, + 0, + -sqrt(2002) / 64, + 0, + sqrt(91) / 32, + 0, + -sqrt(14) / 64, + 0, + ], [0, 3 / 16, 0, -sqrt(26) / 8, 0, sqrt(143) / 16, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, sqrt(231) / 32, 0, sqrt(22) / 64, 0, -25 / 32, 0, 5 * sqrt(26) / 64, 0], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + sqrt(231) / 32, + 0, + sqrt(22) / 64, + 0, + -25 / 32, + 0, + 5 * sqrt(26) / 64, + 0, + ], [0, -sqrt(26) / 8, 0, 1 / 2, 0, sqrt(22) / 8, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 3 * sqrt(21) / 32, 0, 19 * sqrt(2) / 64, 0, -sqrt(11) / 32, 0, -3 * sqrt(286) / 64, 0], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 3 * sqrt(21) / 32, + 0, + 19 * sqrt(2) / 64, + 0, + -sqrt(11) / 32, + 0, + -3 * sqrt(286) / 64, + 0, + ], [0, sqrt(143) / 16, 0, sqrt(22) / 8, 0, 5 / 16, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 5 * sqrt(7) / 32, 0, 15 * sqrt(6) / 64, 0, 3 * sqrt(33) / 32, 0, sqrt(858) / 64, 0], - [sqrt(429) / 32, 0, sqrt(231) / 32, 0, 3 * sqrt(21) / 32, 0, 5 * sqrt(7) / 32, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, -5 / 64, 0, -9 * sqrt(3) / 64, 0, -5 * sqrt(33) / 64, 0, -sqrt(3003) / 64], - [-sqrt(2002) / 64, 0, sqrt(22) / 64, 0, 19 * sqrt(2) / 64, 0, 15 * sqrt(6) / 64, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, -9 * sqrt(3) / 64, 0, -39 / 64, 0, -11 * sqrt(11) / 64, 0, sqrt(1001) / 64], - [sqrt(91) / 32, 0, -25 / 32, 0, -sqrt(11) / 32, 0, 3 * sqrt(33) / 32, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, -5 * sqrt(33) / 64, 0, -11 * sqrt(11) / 64, 0, 43 / 64, 0, -sqrt(91) / 64], - [-sqrt(14) / 64, 0, 5 * sqrt(26) / 64, 0, -3 * sqrt(286) / 64, 0, sqrt(858) / 64, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, -sqrt(3003) / 64, 0, sqrt(1001) / 64, 0, -sqrt(91) / 64, 0, 1 / 64], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 5 * sqrt(7) / 32, + 0, + 15 * sqrt(6) / 64, + 0, + 3 * sqrt(33) / 32, + 0, + sqrt(858) / 64, + 0, + ], + [ + sqrt(429) / 32, + 0, + sqrt(231) / 32, + 0, + 3 * sqrt(21) / 32, + 0, + 5 * sqrt(7) / 32, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -5 / 64, + 0, + -9 * sqrt(3) / 64, + 0, + -5 * sqrt(33) / 64, + 0, + -sqrt(3003) / 64, + ], + [ + -sqrt(2002) / 64, + 0, + sqrt(22) / 64, + 0, + 19 * sqrt(2) / 64, + 0, + 15 * sqrt(6) / 64, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -9 * sqrt(3) / 64, + 0, + -39 / 64, + 0, + -11 * sqrt(11) / 64, + 0, + sqrt(1001) / 64, + ], + [ + sqrt(91) / 32, + 0, + -25 / 32, + 0, + -sqrt(11) / 32, + 0, + 3 * sqrt(33) / 32, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -5 * sqrt(33) / 64, + 0, + -11 * sqrt(11) / 64, + 0, + 43 / 64, + 0, + -sqrt(91) / 64, + ], + [ + -sqrt(14) / 64, + 0, + 5 * sqrt(26) / 64, + 0, + -3 * sqrt(286) / 64, + 0, + sqrt(858) / 64, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + -sqrt(3003) / 64, + 0, + sqrt(1001) / 64, + 0, + -sqrt(91) / 64, + 0, + 1 / 64, + ], ] ) J8 = np.array( [ - [0, 0, 0, 0, 0, 0, 0, 0, 0, sqrt(715) / 32, 0, -sqrt(273) / 32, 0, sqrt(35) / 32, 0, -1 / 32, 0], - [0, 7 / 64, 0, -5 * sqrt(35) / 64, 
0, 3 * sqrt(273) / 64, 0, -sqrt(715) / 64, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, sqrt(858) / 64, 0, sqrt(910) / 64, 0, -7 * sqrt(42) / 64, 0, 3 * sqrt(30) / 64, 0], - [0, -5 * sqrt(35) / 64, 0, 45 / 64, 0, sqrt(195) / 64, 0, -sqrt(1001) / 64, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, sqrt(77) / 32, 0, 5 * sqrt(15) / 32, 0, 3 * sqrt(13) / 32, 0, -sqrt(455) / 32, 0], - [0, 3 * sqrt(273) / 64, 0, sqrt(195) / 64, 0, -17 / 64, 0, -sqrt(1155) / 64, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 0, sqrt(70) / 64, 0, 3 * sqrt(66) / 64, 0, sqrt(1430) / 64, 0, sqrt(2002) / 64, 0], - [0, -sqrt(715) / 64, 0, -sqrt(1001) / 64, 0, -sqrt(1155) / 64, 0, -35 / 64, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + sqrt(715) / 32, + 0, + -sqrt(273) / 32, + 0, + sqrt(35) / 32, + 0, + -1 / 32, + 0, + ], + [ + 0, + 7 / 64, + 0, + -5 * sqrt(35) / 64, + 0, + 3 * sqrt(273) / 64, + 0, + -sqrt(715) / 64, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + sqrt(858) / 64, + 0, + sqrt(910) / 64, + 0, + -7 * sqrt(42) / 64, + 0, + 3 * sqrt(30) / 64, + 0, + ], + [ + 0, + -5 * sqrt(35) / 64, + 0, + 45 / 64, + 0, + sqrt(195) / 64, + 0, + -sqrt(1001) / 64, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + sqrt(77) / 32, + 0, + 5 * sqrt(15) / 32, + 0, + 3 * sqrt(13) / 32, + 0, + -sqrt(455) / 32, + 0, + ], + [ + 0, + 3 * sqrt(273) / 64, + 0, + sqrt(195) / 64, + 0, + -17 / 64, + 0, + -sqrt(1155) / 64, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + sqrt(70) / 64, + 0, + 3 * sqrt(66) / 64, + 0, + sqrt(1430) / 64, + 0, + sqrt(2002) / 64, + 0, + ], + [ + 0, + -sqrt(715) / 64, + 0, + -sqrt(1001) / 64, + 0, + -sqrt(1155) / 64, + 0, + -35 / 64, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], [ 0, 0, @@ -136,14 +556,158 @@ 0, 3 * sqrt(715) / 128, ], - [sqrt(715) / 32, 0, sqrt(858) / 64, 0, sqrt(77) / 32, 0, sqrt(70) / 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 3 * sqrt(70) / 64, 0, 1 / 2, 0, sqrt(110) / 32, 0, 0, 0, -sqrt(2002) / 64], - [-sqrt(273) / 32, 0, sqrt(910) / 64, 0, 5 * sqrt(15) / 32, 0, 3 * sqrt(66) / 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 3 * sqrt(77) / 64, 0, sqrt(110) / 32, 0, -9 / 32, 0, -sqrt(546) / 32, 0, sqrt(455) / 64], - [sqrt(35) / 32, 0, -7 * sqrt(42) / 64, 0, 3 * sqrt(13) / 32, 0, sqrt(1430) / 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, sqrt(858) / 64, 0, 0, 0, -sqrt(546) / 32, 0, 1 / 2, 0, -sqrt(30) / 64], - [-1 / 32, 0, 3 * sqrt(30) / 64, 0, -sqrt(455) / 32, 0, sqrt(2002) / 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 0, 3 * sqrt(715) / 128, 0, -sqrt(2002) / 64, 0, sqrt(455) / 64, 0, -sqrt(30) / 64, 0, 1 / 128], + [ + sqrt(715) / 32, + 0, + sqrt(858) / 64, + 0, + sqrt(77) / 32, + 0, + sqrt(70) / 64, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 3 * sqrt(70) / 64, + 0, + 1 / 2, + 0, + sqrt(110) / 32, + 0, + 0, + 0, + -sqrt(2002) / 64, + ], + [ + -sqrt(273) / 32, + 0, + sqrt(910) / 64, + 0, + 5 * sqrt(15) / 32, + 0, + 3 * sqrt(66) / 64, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 3 * sqrt(77) / 64, + 0, + sqrt(110) / 32, + 0, + -9 / 32, + 0, + -sqrt(546) / 32, + 0, + sqrt(455) / 64, + ], + [ + sqrt(35) / 32, + 0, + -7 * sqrt(42) / 64, + 0, + 3 * sqrt(13) / 32, + 0, + sqrt(1430) / 64, + 0, + 0, + 0, + 0, + 
0, + 0, + 0, + 0, + 0, + 0, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + sqrt(858) / 64, + 0, + 0, + 0, + -sqrt(546) / 32, + 0, + 1 / 2, + 0, + -sqrt(30) / 64, + ], + [ + -1 / 32, + 0, + 3 * sqrt(30) / 64, + 0, + -sqrt(455) / 32, + 0, + sqrt(2002) / 64, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 3 * sqrt(715) / 128, + 0, + -sqrt(2002) / 64, + 0, + sqrt(455) / 64, + 0, + -sqrt(30) / 64, + 0, + 1 / 128, + ], ] ) @@ -170,7 +734,27 @@ -3 * sqrt(2) / 256, 0, ], - [0, 1 / 16, 0, -sqrt(102) / 32, 0, sqrt(119) / 16, 0, -sqrt(442) / 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ + 0, + 1 / 16, + 0, + -sqrt(102) / 32, + 0, + sqrt(119) / 16, + 0, + -sqrt(442) / 32, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], [ 0, 0, @@ -192,7 +776,27 @@ 7 * sqrt(34) / 256, 0, ], - [0, -sqrt(102) / 32, 0, 23 / 32, 0, -sqrt(42) / 32, 0, -3 * sqrt(39) / 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ + 0, + -sqrt(102) / 32, + 0, + 23 / 32, + 0, + -sqrt(42) / 32, + 0, + -3 * sqrt(39) / 32, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], [ 0, 0, @@ -214,7 +818,27 @@ -5 * sqrt(170) / 128, 0, ], - [0, sqrt(119) / 16, 0, -sqrt(42) / 32, 0, -9 / 16, 0, -sqrt(182) / 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ + 0, + sqrt(119) / 16, + 0, + -sqrt(42) / 32, + 0, + -9 / 16, + 0, + -sqrt(182) / 32, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], [ 0, 0, @@ -236,7 +860,27 @@ sqrt(9282) / 128, 0, ], - [0, -sqrt(442) / 32, 0, -3 * sqrt(39) / 32, 0, -sqrt(182) / 32, 0, -7 / 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ + 0, + -sqrt(442) / 32, + 0, + -3 * sqrt(39) / 32, + 0, + -sqrt(182) / 32, + 0, + -7 / 32, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], [ 0, 0, @@ -1036,7 +1680,31 @@ 9 * sqrt(42) / 1024, 0, ], - [0, -sqrt(70) / 64, 0, 1 / 2, 0, -3 * sqrt(190) / 64, 0, 0, 0, sqrt(323) / 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ + 0, + -sqrt(70) / 64, + 0, + 1 / 2, + 0, + -3 * sqrt(190) / 64, + 0, + 0, + 0, + sqrt(323) / 32, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], [ 0, 0, @@ -1112,7 +1780,31 @@ 5 * sqrt(13566) / 1024, 0, ], - [0, -sqrt(1938) / 64, 0, 0, 0, sqrt(714) / 64, 0, 1 / 2, 0, sqrt(105) / 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ + 0, + -sqrt(1938) / 64, + 0, + 0, + 0, + sqrt(714) / 64, + 0, + 1 / 2, + 0, + sqrt(105) / 32, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], [ 0, 0, diff --git a/e3nn_jax/_src/activation.py b/e3nn_jax/_src/activation.py index 77a11572..9a4b8dab 100644 --- a/e3nn_jax/_src/activation.py +++ b/e3nn_jax/_src/activation.py @@ -125,7 +125,10 @@ def scalar_activation( assert isinstance(input, e3nn.IrrepsArray) if acts is None: - acts = [{1: even_act, -1: odd_act}[ir.p] if ir.l == 0 else None for _, ir in input.irreps] + acts = [ + {1: even_act, -1: odd_act}[ir.p] if ir.l == 0 else None + for _, ir in input.irreps + ] assert len(input.irreps) == len(acts), (input.irreps, acts) @@ -227,7 +230,9 @@ def key_value_activation(phi, key, value): assert value.ndim == 1 d = value.shape[0] - key = key / jnp.sqrt(1 / 16 + jnp.sum(key**2)) # 1/16 is arbitrary small... but not too small... + key = key / jnp.sqrt( + 1 / 16 + jnp.sum(key**2) + ) # 1/16 is arbitrary small... but not too small... 
     scalar = jnp.sum(key * value)
     scalar = normalize_function(phi)(scalar)
     return d**0.5 * scalar * key  # component normalized
diff --git a/e3nn_jax/_src/activation_test.py b/e3nn_jax/_src/activation_test.py
index 9bc30887..f4fc560c 100644
--- a/e3nn_jax/_src/activation_test.py
+++ b/e3nn_jax/_src/activation_test.py
@@ -14,7 +14,9 @@ def test_errors(keys):
 
 
 def test_zero_in_zero():
-    x = IrrepsArray.from_list("0e + 0o + 0o + 0e", [jnp.ones((1, 1)), None, None, None], ())
+    x = IrrepsArray.from_list(
+        "0e + 0o + 0o + 0e", [jnp.ones((1, 1)), None, None, None], ()
+    )
     y = scalar_activation(x, [jnp.tanh, jnp.tanh, lambda x: x**2, jnp.cos])
 
     assert y.irreps == Irreps("0e + 0o + 0e + 0e")
@@ -24,9 +26,9 @@ def test_zero_in_zero():
 
 
 def test_irreps_argument():
-    assert scalar_activation("0e + 0o + 0o + 0e", [jnp.tanh, jnp.tanh, lambda x: x**2, jnp.cos]) == Irreps(
-        "0e + 0o + 0e + 0e"
-    )
+    assert scalar_activation(
+        "0e + 0o + 0o + 0e", [jnp.tanh, jnp.tanh, lambda x: x**2, jnp.cos]
+    ) == Irreps("0e + 0o + 0e + 0e")
 
 
 def test_norm_act():
diff --git a/e3nn_jax/_src/basic.py b/e3nn_jax/_src/basic.py
index 6f1ad6c3..ea987579 100644
--- a/e3nn_jax/_src/basic.py
+++ b/e3nn_jax/_src/basic.py
@@ -9,7 +9,9 @@
 from e3nn_jax._src.irreps_array import _infer_backend, _standardize_axis
 
 
-def _align_two_irreps_arrays(input1: e3nn.IrrepsArray, input2: e3nn.IrrepsArray) -> Tuple[e3nn.IrrepsArray, e3nn.IrrepsArray]:
+def _align_two_irreps_arrays(
+    input1: e3nn.IrrepsArray, input2: e3nn.IrrepsArray
+) -> Tuple[e3nn.IrrepsArray, e3nn.IrrepsArray]:
     assert input1.irreps.num_irreps == input2.irreps.num_irreps
 
     irreps_in1 = list(input1.irreps)
@@ -38,7 +40,10 @@ def _align_two_irreps_arrays(input1: e3nn.IrrepsArray, input2: e3nn.IrrepsArray)
 
 
 def _reduce(
-    op, array: e3nn.IrrepsArray, axis: Union[None, int, Tuple[int, ...]] = None, keepdims: bool = False
+    op,
+    array: e3nn.IrrepsArray,
+    axis: Union[None, int, Tuple[int, ...]] = None,
+    keepdims: bool = False,
 ) -> e3nn.IrrepsArray:
     axis = _standardize_axis(axis, array.ndim)
 
@@ -50,7 +55,10 @@
         return e3nn.IrrepsArray(
             array.irreps,
             op(array.array, axis=axis, keepdims=keepdims),
-            [None if x is None else op(x, axis=axis, keepdims=keepdims) for x in array.list],
+            [
+                None if x is None else op(x, axis=axis, keepdims=keepdims)
+                for x in array.list
+            ],
         )
 
     array = _reduce(op, array, axis=axis[:-1], keepdims=keepdims)
@@ -62,7 +70,11 @@
     )
 
 
-def mean(array: e3nn.IrrepsArray, axis: Union[None, int, Tuple[int, ...]] = None, keepdims: bool = False) -> e3nn.IrrepsArray:
+def mean(
+    array: e3nn.IrrepsArray,
+    axis: Union[None, int, Tuple[int, ...]] = None,
+    keepdims: bool = False,
+) -> e3nn.IrrepsArray:
     """Mean of IrrepsArray along the specified axis.
 
     Args:
@@ -87,7 +99,11 @@ def mean(array: e3nn.IrrepsArray, axis: Union[None, int, Tuple[int, ...]] = None
     return _reduce(jnp.mean, array, axis, keepdims)
 
 
-def sum_(array: e3nn.IrrepsArray, axis: Union[None, int, Tuple[int, ...]] = None, keepdims: bool = False) -> e3nn.IrrepsArray:
+def sum_(
+    array: e3nn.IrrepsArray,
+    axis: Union[None, int, Tuple[int, ...]] = None,
+    keepdims: bool = False,
+) -> e3nn.IrrepsArray:
     """Sum of IrrepsArray along the specified axis.
Args: @@ -157,7 +173,9 @@ def concatenate(arrays: List[e3nn.IrrepsArray], axis: int = -1) -> e3nn.IrrepsAr if {x.irreps for x in arrays} != {arrays[0].irreps}: raise ValueError("Irreps must be the same for all arrays") - arrays = [x.replace_none_with_zeros() for x in arrays] # TODO this could be optimized + arrays = [ + x.replace_none_with_zeros() for x in arrays + ] # TODO this could be optimized return e3nn.IrrepsArray( irreps=arrays[0].irreps, array=jnp.concatenate([x.array for x in arrays], axis=axis), @@ -209,7 +227,9 @@ def stack(arrays: List[e3nn.IrrepsArray], axis=0) -> e3nn.IrrepsArray: if {x.irreps for x in arrays} != {arrays[0].irreps}: raise ValueError("Irreps must be the same for all arrays") - arrays = [x.replace_none_with_zeros() for x in arrays] # TODO this could be optimized + arrays = [ + x.replace_none_with_zeros() for x in arrays + ] # TODO this could be optimized return e3nn.IrrepsArray( irreps=arrays[0].irreps, array=jnp.stack([x.array for x in arrays], axis=axis), @@ -217,7 +237,9 @@ def stack(arrays: List[e3nn.IrrepsArray], axis=0) -> e3nn.IrrepsArray: ) -def norm(array: e3nn.IrrepsArray, *, squared: bool = False, per_irrep: bool = True) -> e3nn.IrrepsArray: +def norm( + array: e3nn.IrrepsArray, *, squared: bool = False, per_irrep: bool = True +) -> e3nn.IrrepsArray: """Norm of IrrepsArray. Args: @@ -263,7 +285,9 @@ def f(x): return e3nn.IrrepsArray("0e", f(array.array)) -def dot(a: e3nn.IrrepsArray, b: e3nn.IrrepsArray, per_irrep: bool = False) -> e3nn.IrrepsArray: +def dot( + a: e3nn.IrrepsArray, b: e3nn.IrrepsArray, per_irrep: bool = False +) -> e3nn.IrrepsArray: """Dot product of two IrrepsArray. Args: @@ -289,7 +313,9 @@ def dot(a: e3nn.IrrepsArray, b: e3nn.IrrepsArray, per_irrep: bool = False) -> e3 b = b.simplify() if a.irreps != b.irreps: - raise ValueError("Dot product is only defined for IrrepsArray with the same irreps.") + raise ValueError( + "Dot product is only defined for IrrepsArray with the same irreps." + ) if per_irrep: out = [] @@ -341,7 +367,9 @@ def cross(a: e3nn.IrrepsArray, b: e3nn.IrrepsArray) -> e3nn.IrrepsArray: if any(ir.l != 1 for _, ir in b.irreps): raise ValueError(f"Cross product is only defined for vectors. Got {b.irreps}.") if a.irreps.num_irreps != b.irreps.num_irreps: - raise ValueError("Cross product is only defined for inputs with the same number of vectors.") + raise ValueError( + "Cross product is only defined for inputs with the same number of vectors." + ) a, b = _align_two_irreps_arrays(a, b) shape = jnp.broadcast_shapes(a.shape[:-1], b.shape[:-1]) @@ -350,7 +378,9 @@ def cross(a: e3nn.IrrepsArray, b: e3nn.IrrepsArray) -> e3nn.IrrepsArray: out = [] dtype = a.dtype - for ((mul, irx), x), ((_, iry), y) in zip(zip(a.irreps, a.list), zip(b.irreps, b.list)): + for ((mul, irx), x), ((_, iry), y) in zip( + zip(a.irreps, a.list), zip(b.irreps, b.list) + ): irreps_out.append((mul, (1, irx.p * iry.p))) if x is None or y is None: out.append(None) @@ -419,7 +449,9 @@ def normal( normalization = e3nn.config("irrep_normalization") if key is None: - warnings.warn("e3nn.normal: the key (random seed) is not provided, use the hash of the irreps as key!") + warnings.warn( + "e3nn.normal: the key (random seed) is not provided, use the hash of the irreps as key!" 
+ ) key = jax.random.PRNGKey(hash(irreps)) if normalize: @@ -432,7 +464,10 @@ def normal( return e3nn.IrrepsArray.from_list(irreps, list, leading_shape, dtype) else: if normalization == "component": - return e3nn.IrrepsArray(irreps, jax.random.normal(key, leading_shape + (irreps.dim,), dtype=dtype)) + return e3nn.IrrepsArray( + irreps, + jax.random.normal(key, leading_shape + (irreps.dim,), dtype=dtype), + ) elif normalization == "norm": list = [] for mul, ir in irreps: diff --git a/e3nn_jax/_src/batchnorm_haiku.py b/e3nn_jax/_src/batchnorm_haiku.py index c8f0dc3c..d2a847be 100644 --- a/e3nn_jax/_src/batchnorm_haiku.py +++ b/e3nn_jax/_src/batchnorm_haiku.py @@ -65,7 +65,9 @@ def _roll_avg(curr, update): field_mean = field.mean(1).reshape(batch, mul) # [batch, mul] else: field_mean = field.mean([0, 1]).reshape(mul) # [mul] - new_means.append(_roll_avg(running_mean[i_rmu : i_rmu + mul], field_mean)) + new_means.append( + _roll_avg(running_mean[i_rmu : i_rmu + mul], field_mean) + ) else: field_mean = running_mean[i_rmu : i_rmu + mul] i_rmu += mul @@ -79,7 +81,9 @@ def _roll_avg(curr, update): elif normalization == "component": field_norm = jnp.square(field).mean(3) # [batch, sample, mul] else: - raise ValueError("Invalid normalization option {}".format(normalization)) + raise ValueError( + "Invalid normalization option {}".format(normalization) + ) if reduce == "mean": field_norm = field_norm.mean(1) # [batch, mul] @@ -90,11 +94,15 @@ def _roll_avg(curr, update): if not is_instance: field_norm = field_norm.mean(0) # [mul] - new_vars.append(_roll_avg(running_var[i_wei : i_wei + mul], field_norm)) + new_vars.append( + _roll_avg(running_var[i_wei : i_wei + mul], field_norm) + ) else: field_norm = running_var[i_wei : i_wei + mul] - field_norm = jax.lax.rsqrt((1 - epsilon) * field_norm + epsilon) # [(batch,) mul] + field_norm = jax.lax.rsqrt( + (1 - epsilon) * field_norm + epsilon + ) # [(batch,) mul] if has_affine: sub_weight = weight[i_wei : i_wei + mul] # [mul] @@ -112,7 +120,9 @@ def _roll_avg(curr, update): fields.append(field) # [batch, sample, mul, repr] i_wei += mul - output = IrrepsArray.from_list(input.irreps, fields, (batch, prod(size)), input.dtype) + output = IrrepsArray.from_list( + input.irreps, fields, (batch, prod(size)), input.dtype + ) output = output.reshape((batch,) + tuple(size) + (-1,)) return output, new_means, new_vars @@ -163,7 +173,10 @@ def __init__( if normalization is None: normalization = config("irrep_normalization") - assert normalization in ["norm", "component"], "normalization needs to be 'norm' or 'component'" + assert normalization in [ + "norm", + "component", + ], "normalization needs to be 'norm' or 'component'" self.normalization = normalization def __repr__(self): @@ -186,8 +199,12 @@ def __call__(self, input: IrrepsArray, is_training: bool = True) -> IrrepsArray: num_features = input.irreps.num_irreps if not self.instance: - running_mean = hk.get_state("running_mean", shape=(num_scalar,), init=jnp.zeros) - running_var = hk.get_state("running_var", shape=(num_features,), init=jnp.ones) + running_mean = hk.get_state( + "running_mean", shape=(num_scalar,), init=jnp.zeros + ) + running_var = hk.get_state( + "running_var", shape=(num_features,), init=jnp.ones + ) else: running_mean = None running_var = None diff --git a/e3nn_jax/_src/batchnorm_haiku_test.py b/e3nn_jax/_src/batchnorm_haiku_test.py index 17f0389a..b78c219e 100644 --- a/e3nn_jax/_src/batchnorm_haiku_test.py +++ b/e3nn_jax/_src/batchnorm_haiku_test.py @@ -6,7 +6,9 @@ from e3nn_jax.util import 
assert_equivariant -@pytest.mark.parametrize("irreps", [e3nn.Irreps("3x0e + 3x0o + 4x1e"), e3nn.Irreps("3x0o + 3x0e + 4x1e")]) +@pytest.mark.parametrize( + "irreps", [e3nn.Irreps("3x0e + 3x0o + 4x1e"), e3nn.Irreps("3x0o + 3x0e + 4x1e")] +) def test_equivariant(keys, irreps): @hk.without_apply_rng @hk.transform_with_state @@ -19,9 +21,13 @@ def b(x, is_training=True): _, state = b.apply(params, state, e3nn.normal(irreps, next(keys), (16,))) m_train = lambda x: b.apply(params, state, x)[0] - assert_equivariant(m_train, next(keys), args_in=[e3nn.normal(irreps, next(keys), (16,))]) + assert_equivariant( + m_train, next(keys), args_in=[e3nn.normal(irreps, next(keys), (16,))] + ) m_eval = lambda x: b.apply(params, state, x, is_training=False)[0] - assert_equivariant(m_eval, next(keys), args_in=[e3nn.normal(irreps, next(keys), (16,))]) + assert_equivariant( + m_eval, next(keys), args_in=[e3nn.normal(irreps, next(keys), (16,))] + ) @pytest.mark.parametrize("affine", [True, False]) @@ -34,7 +40,13 @@ def test_modes(keys, affine, reduce, normalization, instance): @hk.without_apply_rng @hk.transform_with_state def b(x, is_training=True): - m = e3nn.haiku.BatchNorm(irreps=irreps, affine=affine, reduce=reduce, normalization=normalization, instance=instance) + m = e3nn.haiku.BatchNorm( + irreps=irreps, + affine=affine, + reduce=reduce, + normalization=normalization, + instance=instance, + ) return m(x, is_training) params, state = b.init(next(keys), e3nn.normal(irreps, next(keys), (20, 20))) @@ -71,12 +83,16 @@ def b(x, is_training=True): assert jnp.max(jnp.abs(jnp.square(a).mean([0, 1]) - 1)) < sqrt_float_tolerance a = x.list[1] # [batch, space, mul, repr] - assert jnp.max(jnp.abs(jnp.square(a).sum(3).mean([0, 1]) - 1)) < sqrt_float_tolerance + assert ( + jnp.max(jnp.abs(jnp.square(a).sum(3).mean([0, 1]) - 1)) < sqrt_float_tolerance + ) @hk.without_apply_rng @hk.transform_with_state def b(x, is_training=True): - m = e3nn.haiku.BatchNorm(irreps=irreps, normalization="component", instance=instance) + m = e3nn.haiku.BatchNorm( + irreps=irreps, normalization="component", instance=instance + ) return m(x, is_training) params, state = b.init(next(keys), e3nn.normal(irreps, next(keys), (16,))) @@ -89,4 +105,6 @@ def b(x, is_training=True): assert jnp.max(jnp.abs(jnp.square(a).mean([0, 1]) - 1)) < sqrt_float_tolerance a = x.list[1] # [batch, space, mul, repr] - assert jnp.max(jnp.abs(jnp.square(a).mean(3).mean([0, 1]) - 1)) < sqrt_float_tolerance + assert ( + jnp.max(jnp.abs(jnp.square(a).mean(3).mean([0, 1]) - 1)) < sqrt_float_tolerance + ) diff --git a/e3nn_jax/_src/core_tensor_product.py b/e3nn_jax/_src/core_tensor_product.py index 1795fa9d..e534decc 100644 --- a/e3nn_jax/_src/core_tensor_product.py +++ b/e3nn_jax/_src/core_tensor_product.py @@ -99,7 +99,9 @@ def __init__( if gradient_normalization is None: gradient_normalization = config("gradient_normalization") if isinstance(gradient_normalization, str): - gradient_normalization = {"element": 0.0, "path": 1.0}[gradient_normalization] + gradient_normalization = {"element": 0.0, "path": 1.0}[ + gradient_normalization + ] self.instructions = _normalize_instruction_path_weights( instructions, @@ -120,7 +122,9 @@ def __init__( [ jnp.ones(mul_ir.dim, dtype=bool) if any( - (ins.i_out == i_out) and (ins.path_weight != 0) and (0 not in ins.path_shape) + (ins.i_out == i_out) + and (ins.path_weight != 0) + and (0 not in ins.path_shape) for ins in self.instructions ) else jnp.zeros(mul_ir.dim, dtype=bool) @@ -247,15 +251,22 @@ def 
_normalize_instruction_path_weights( """Returns instructions with normalized path weights.""" def var(instruction): - return first_input_variance[instruction.i_in1] * second_input_variance[instruction.i_in2] * instruction.num_elements + return ( + first_input_variance[instruction.i_in1] + * second_input_variance[instruction.i_in2] + * instruction.num_elements + ) # Precompute normalization factors. path_normalization_sums = collections.defaultdict(lambda: 0.0) for instruction in instructions: - path_normalization_sums[instruction.i_out] += var(instruction) ** (1.0 - path_normalization_exponent) + path_normalization_sums[instruction.i_out] += var(instruction) ** ( + 1.0 - path_normalization_exponent + ) path_normalization_factors = { - instruction: var(instruction) ** path_normalization_exponent * path_normalization_sums[instruction.i_out] + instruction: var(instruction) ** path_normalization_exponent + * path_normalization_sums[instruction.i_out] for instruction in instructions } @@ -270,7 +281,11 @@ def update(instruction: Instruction) -> float: mul_ir_out = output_irreps[instruction.i_out] assert mul_ir_in1.ir.p * mul_ir_in2.ir.p == mul_ir_out.ir.p - assert abs(mul_ir_in1.ir.l - mul_ir_in2.ir.l) <= mul_ir_out.ir.l <= mul_ir_in1.ir.l + mul_ir_in2.ir.l + assert ( + abs(mul_ir_in1.ir.l - mul_ir_in2.ir.l) + <= mul_ir_out.ir.l + <= mul_ir_in1.ir.l + mul_ir_in2.ir.l + ) if irrep_normalization == "component": alpha = mul_ir_out.ir.dim @@ -297,7 +312,11 @@ def update(instruction: Instruction) -> float: return [update(instruction) for instruction in instructions] -@partial(jax.jit, static_argnums=(0,), static_argnames=("custom_einsum_jvp", "fused", "sparse")) +@partial( + jax.jit, + static_argnums=(0,), + static_argnames=("custom_einsum_jvp", "fused", "sparse"), +) @partial(jax.profiler.annotate_function, name="TensorProduct.left_right") def _left_right( self: FunctionalTensorProduct, @@ -317,7 +336,9 @@ def _left_right( return IrrepsArray.zeros(self.irreps_out, (), dtype) if sparse: - assert not custom_einsum_jvp, "custom_einsum_jvp does not support sparse tensors." + assert ( + not custom_einsum_jvp + ), "custom_einsum_jvp does not support sparse tensors." def einsum(op, *args): f = sparsify(lambda *args: jnp.einsum(op, *args)) @@ -327,7 +348,9 @@ def einsum(op, *args): einsum = opt_einsum if custom_einsum_jvp else jnp.einsum if isinstance(weights, list): - assert len(weights) == len([ins for ins in self.instructions if ins.has_weight]), ( + assert len(weights) == len( + [ins for ins in self.instructions if ins.has_weight] + ), ( len(weights), len([ins for ins in self.instructions if ins.has_weight]), ) @@ -345,15 +368,25 @@ def einsum(op, *args): assert i == weights.size del weights - assert input1.ndim == 1, f"input1 is shape {input1.shape}. Execting ndim to be 1. Use jax.vmap to map over input1" - assert input2.ndim == 1, f"input2 is shape {input2.shape}. Execting ndim to be 1. Use jax.vmap to map over input2" + assert ( + input1.ndim == 1 + ), f"input1 is shape {input1.shape}. Execting ndim to be 1. Use jax.vmap to map over input1" + assert ( + input2.ndim == 1 + ), f"input2 is shape {input2.shape}. Execting ndim to be 1. 
Use jax.vmap to map over input2" if fused: - output = _fused_left_right(self, weights_flat, input1, input2, einsum, sparse, dtype) + output = _fused_left_right( + self, weights_flat, input1, input2, einsum, sparse, dtype + ) else: - output = _block_left_right(self, weights_list, input1, input2, einsum, sparse, dtype) + output = _block_left_right( + self, weights_list, input1, input2, einsum, sparse, dtype + ) - assert output.dtype == dtype, f"output.dtype {output.dtype} != dtype {dtype}, Please report this bug." + assert ( + output.dtype == dtype + ), f"output.dtype {output.dtype} != dtype {dtype}, Please report this bug." return output @@ -466,7 +499,11 @@ def multiply(in1, in2, mode): out = [ _sum_tensors( - [out for ins, out in zip(self.instructions, out_list) if ins.i_out == i_out], + [ + out + for ins, out in zip(self.instructions, out_list) + if ins.i_out == i_out + ], shape=(mul_ir_out.mul, mul_ir_out.ir.dim), empty_return_none=True, dtype=dtype, @@ -569,13 +606,17 @@ def set_w3j(x, i, u, v, w): out = einsum("pijk,i,j->k", big_w3j, input1.array, input2.array) else: if has_path_with_no_weights: - weights_flat = jnp.concatenate([jnp.ones((1,), weights_flat.dtype), weights_flat]) + weights_flat = jnp.concatenate( + [jnp.ones((1,), weights_flat.dtype), weights_flat] + ) if sparse: f = sparsify(lambda w, w3j, x1, x2: einsum("p,pijk,i,j->k", w, w3j, x1, x2)) out = f(weights_flat, big_w3j, input1.array, input2.array) else: - out = einsum("p,pijk,i,j->k", weights_flat, big_w3j, input1.array, input2.array) + out = einsum( + "p,pijk,i,j->k", weights_flat, big_w3j, input1.array, input2.array + ) return IrrepsArray(self.irreps_out, out) @@ -650,7 +691,9 @@ def _right( out = einsum("uv,ijk,vj->uivk", w, w3j, x2) else: # not so useful operation because u is summed - out = einsum("ijk,vj,u->uivk", w3j, x2, jnp.ones((mul_ir_in1.mul,), dtype)) + out = einsum( + "ijk,vj,u->uivk", w3j, x2, jnp.ones((mul_ir_in1.mul,), dtype) + ) if ins.connection_mode == "uuw": assert mul_ir_in1.mul == mul_ir_in2.mul if ins.has_weight: @@ -681,7 +724,11 @@ def _right( jnp.concatenate( [ _sum_tensors( - [out for ins, out in zip(self.instructions, out_list) if (ins.i_in1, ins.i_out) == (i_in1, i_out)], + [ + out + for ins, out in zip(self.instructions, out_list) + if (ins.i_in1, ins.i_out) == (i_in1, i_out) + ], shape=(mul_ir_in1.dim, mul_ir_out.dim), dtype=dtype, ) diff --git a/e3nn_jax/_src/core_tensor_product_test.py b/e3nn_jax/_src/core_tensor_product_test.py index 60731c1b..84ba782a 100644 --- a/e3nn_jax/_src/core_tensor_product_test.py +++ b/e3nn_jax/_src/core_tensor_product_test.py @@ -37,7 +37,11 @@ def f(ws, x1, x2): g = tp.left_right - ws = [jax.random.normal(next(keys), ins.path_shape) for ins in tp.instructions if ins.has_weight] + ws = [ + jax.random.normal(next(keys), ins.path_shape) + for ins in tp.instructions + if ins.has_weight + ] x1 = e3nn.normal(tp.irreps_in1, next(keys), ()) x2 = e3nn.normal(tp.irreps_in2, next(keys), ()) @@ -138,7 +142,11 @@ def test_fused_mix_weight(keys): def test_fuse(keys): tp = e3nn.FunctionalFullyConnectedTensorProduct("2x0e+1e", "0e+1e", "1e+0e") - ws = [jax.random.normal(next(keys), ins.path_shape) for ins in tp.instructions if ins.has_weight] + ws = [ + jax.random.normal(next(keys), ins.path_shape) + for ins in tp.instructions + if ins.has_weight + ] wf = jnp.concatenate([w.flatten() for w in ws]) x1 = e3nn.normal(tp.irreps_in1, next(keys), ()) x2 = e3nn.normal(tp.irreps_in2, next(keys), ()) @@ -151,7 +159,9 @@ def test_fuse(keys): 
@pytest.mark.parametrize("gradient_normalization", ["element", "path", 0.5]) @pytest.mark.parametrize("path_normalization", ["element", "path", 0.5]) @pytest.mark.parametrize("irrep_normalization", ["component", "norm"]) -def test_normalization(keys, irrep_normalization, path_normalization, gradient_normalization): +def test_normalization( + keys, irrep_normalization, path_normalization, gradient_normalization +): tp = e3nn.FunctionalFullyConnectedTensorProduct( "5x0e+1x0e+10x1e", "2x0e+2x1e+10x1e", @@ -161,7 +171,11 @@ def test_normalization(keys, irrep_normalization, path_normalization, gradient_n gradient_normalization=gradient_normalization, ) - ws = [ins.weight_std * jax.random.normal(next(keys), ins.path_shape) for ins in tp.instructions if ins.has_weight] + ws = [ + ins.weight_std * jax.random.normal(next(keys), ins.path_shape) + for ins in tp.instructions + if ins.has_weight + ] x1 = e3nn.normal(tp.irreps_in1, next(keys), (), normalization=irrep_normalization) x2 = e3nn.normal(tp.irreps_in2, next(keys), (), normalization=irrep_normalization) diff --git a/e3nn_jax/_src/dropout_haiku.py b/e3nn_jax/_src/dropout_haiku.py index e58f8a1d..8ceffdd1 100644 --- a/e3nn_jax/_src/dropout_haiku.py +++ b/e3nn_jax/_src/dropout_haiku.py @@ -65,7 +65,9 @@ def __call__(self, rng, x: IrrepsArray, is_training=True) -> IrrepsArray: out_list.append(a) noises.append(jnp.ones((mul * ir.dim,), x.dtype)) else: - noise = jax.random.bernoulli(rng, p=1 - self.p, shape=(mul, 1)) / (1 - self.p) + noise = jax.random.bernoulli(rng, p=1 - self.p, shape=(mul, 1)) / ( + 1 - self.p + ) out_list.append(noise * a) noises.append(jnp.repeat(noise, ir.dim, axis=1).flatten()) diff --git a/e3nn_jax/_src/einsum.py b/e3nn_jax/_src/einsum.py index 1c7742cc..45c354c9 100644 --- a/e3nn_jax/_src/einsum.py +++ b/e3nn_jax/_src/einsum.py @@ -11,5 +11,9 @@ def einsum(eq, *xs): @einsum.defjvp -def einsum_jvp(eq: str, xs: Tuple[jnp.ndarray], x_dots: Tuple[jnp.ndarray]) -> jnp.ndarray: - return einsum(eq, *xs), sum(einsum(eq, *(xs[:i] + (x_dot,) + xs[i + 1 :])) for i, x_dot in enumerate(x_dots)) +def einsum_jvp( + eq: str, xs: Tuple[jnp.ndarray], x_dots: Tuple[jnp.ndarray] +) -> jnp.ndarray: + return einsum(eq, *xs), sum( + einsum(eq, *(xs[:i] + (x_dot,) + xs[i + 1 :])) for i, x_dot in enumerate(x_dots) + ) diff --git a/e3nn_jax/_src/fc_tp_haiku.py b/e3nn_jax/_src/fc_tp_haiku.py index 62916cfc..97e6ba4d 100644 --- a/e3nn_jax/_src/fc_tp_haiku.py +++ b/e3nn_jax/_src/fc_tp_haiku.py @@ -12,7 +12,9 @@ def __init__(self, irreps_out, *, irreps_in1=None, irreps_in2=None): self.irreps_in1 = e3nn.Irreps(irreps_in1) if irreps_in1 is not None else None self.irreps_in2 = e3nn.Irreps(irreps_in2) if irreps_in2 is not None else None - def __call__(self, x1: e3nn.IrrepsArray, x2: e3nn.IrrepsArray, **kwargs) -> e3nn.IrrepsArray: + def __call__( + self, x1: e3nn.IrrepsArray, x2: e3nn.IrrepsArray, **kwargs + ) -> e3nn.IrrepsArray: if self.irreps_in1 is not None: x1 = x1._convert(self.irreps_in1) if self.irreps_in2 is not None: @@ -21,7 +23,9 @@ def __call__(self, x1: e3nn.IrrepsArray, x2: e3nn.IrrepsArray, **kwargs) -> e3nn x1 = x1.remove_nones().simplify() x2 = x2.remove_nones().simplify() - tp = e3nn.FunctionalFullyConnectedTensorProduct(x1.irreps, x2.irreps, self.irreps_out.simplify()) + tp = e3nn.FunctionalFullyConnectedTensorProduct( + x1.irreps, x2.irreps, self.irreps_out.simplify() + ) ws = [ hk.get_parameter( ( @@ -33,6 +37,8 @@ def __call__(self, x1: e3nn.IrrepsArray, x2: e3nn.IrrepsArray, **kwargs) -> e3nn ) for ins in tp.instructions ] - f = 
naive_broadcast_decorator(lambda x1, x2: tp.left_right(ws, x1, x2, **kwargs)) + f = naive_broadcast_decorator( + lambda x1, x2: tp.left_right(ws, x1, x2, **kwargs) + ) output = f(x1, x2) return output._convert(self.irreps_out) diff --git a/e3nn_jax/_src/gate.py b/e3nn_jax/_src/gate.py index 3c9159b8..d57b5ff9 100644 --- a/e3nn_jax/_src/gate.py +++ b/e3nn_jax/_src/gate.py @@ -9,23 +9,40 @@ @partial(jax.jit, static_argnums=(1, 2, 3, 4, 5)) -def _gate(input: IrrepsArray, even_act, odd_act, even_gate_act, odd_gate_act, normalize_act) -> IrrepsArray: +def _gate( + input: IrrepsArray, even_act, odd_act, even_gate_act, odd_gate_act, normalize_act +) -> IrrepsArray: scalars = input.filtered(keep=["0e", "0o"]) vectors = input.filtered(drop=["0e", "0o"]) del input if vectors.shape[-1] == 0: - return scalar_activation(scalars, even_act=even_act, odd_act=odd_act, normalize_act=normalize_act) + return scalar_activation( + scalars, even_act=even_act, odd_act=odd_act, normalize_act=normalize_act + ) if scalars.irreps.dim < vectors.irreps.num_irreps: - raise ValueError("The input must have at least as many scalars as the number of non-scalar irreps") - - scalars_extra: e3nn.IrrepsArray = scalars.slice_by_mul[: scalars.irreps.dim - vectors.irreps.num_irreps] - scalars_gates: e3nn.IrrepsArray = scalars.slice_by_mul[scalars.irreps.dim - vectors.irreps.num_irreps :] + raise ValueError( + "The input must have at least as many scalars as the number of non-scalar irreps" + ) + + scalars_extra: e3nn.IrrepsArray = scalars.slice_by_mul[ + : scalars.irreps.dim - vectors.irreps.num_irreps + ] + scalars_gates: e3nn.IrrepsArray = scalars.slice_by_mul[ + scalars.irreps.dim - vectors.irreps.num_irreps : + ] del scalars - scalars_extra = scalar_activation(scalars_extra, even_act=even_act, odd_act=odd_act, normalize_act=normalize_act) - scalars_gates = scalar_activation(scalars_gates, even_act=even_gate_act, odd_act=odd_gate_act, normalize_act=normalize_act) + scalars_extra = scalar_activation( + scalars_extra, even_act=even_act, odd_act=odd_act, normalize_act=normalize_act + ) + scalars_gates = scalar_activation( + scalars_gates, + even_act=even_gate_act, + odd_act=odd_gate_act, + normalize_act=normalize_act, + ) return e3nn.concatenate([scalars_extra, scalars_gates * vectors], axis=-1) diff --git a/e3nn_jax/_src/grad.py b/e3nn_jax/_src/grad.py index 290b9556..866d35c4 100644 --- a/e3nn_jax/_src/grad.py +++ b/e3nn_jax/_src/grad.py @@ -46,16 +46,22 @@ def _grad(*args, **kwargs) -> e3nn.IrrepsArray: def naked_fun(*args, **kwargs) -> List[jnp.ndarray]: args = list(args) - args[argnums] = e3nn.IrrepsArray.from_list(irreps_in, args[argnums], leading_shape_in, x.dtype) + args[argnums] = e3nn.IrrepsArray.from_list( + irreps_in, args[argnums], leading_shape_in, x.dtype + ) if has_aux: y, aux = fun(*args, **kwargs) if not isinstance(y, e3nn.IrrepsArray): - raise TypeError(f"Expected equivariant function to return an e3nn.IrrepsArray, got {type(y)}.") + raise TypeError( + f"Expected equivariant function to return an e3nn.IrrepsArray, got {type(y)}." + ) return y.list, (y.irreps, y.shape[:-1], aux) else: y = fun(*args, **kwargs) if not isinstance(y, e3nn.IrrepsArray): - raise TypeError(f"Expected equivariant function to return an e3nn.IrrepsArray, got {type(y)}.") + raise TypeError( + f"Expected equivariant function to return an e3nn.IrrepsArray, got {type(y)}." 
+ ) return y.list, (y.irreps, y.shape[:-1]) output = jax.jacobian( @@ -74,20 +80,39 @@ def naked_fun(*args, **kwargs) -> List[jnp.ndarray]: for mir_out, y_list in zip(irreps_out, jac): for mir_in, z in zip(irreps_in, y_list): assert z.shape == ( - leading_shape_out + (mir_out.mul, mir_out.ir.dim) + leading_shape_in + (mir_in.mul, mir_in.ir.dim) + leading_shape_out + + (mir_out.mul, mir_out.ir.dim) + + leading_shape_in + + (mir_in.mul, mir_in.ir.dim) ) z = jnp.reshape( z, - (prod(leading_shape_out), mir_out.mul, mir_out.ir.dim, prod(leading_shape_in), mir_in.mul, mir_in.ir.dim), + ( + prod(leading_shape_out), + mir_out.mul, + mir_out.ir.dim, + prod(leading_shape_in), + mir_in.mul, + mir_in.ir.dim, + ), ) for ir in mir_out.ir * mir_in.ir: irreps.append((mir_out.mul * mir_in.mul, ir)) lst.append( jnp.einsum( - "auibvj,ijk->abuvk", z, jnp.sqrt(ir.dim) * e3nn.clebsch_gordan(mir_out.ir.l, mir_in.ir.l, ir.l) - ).reshape(leading_shape_out + leading_shape_in + (mir_out.mul * mir_in.mul, ir.dim)) + "auibvj,ijk->abuvk", + z, + jnp.sqrt(ir.dim) + * e3nn.clebsch_gordan(mir_out.ir.l, mir_in.ir.l, ir.l), + ).reshape( + leading_shape_out + + leading_shape_in + + (mir_out.mul * mir_in.mul, ir.dim) + ) ) - output = e3nn.IrrepsArray.from_list(irreps, lst, leading_shape_out + leading_shape_in, x.dtype) + output = e3nn.IrrepsArray.from_list( + irreps, lst, leading_shape_out + leading_shape_in, x.dtype + ) if regroup_output: output = output.regroup() if has_aux: diff --git a/e3nn_jax/_src/grad_test.py b/e3nn_jax/_src/grad_test.py index 76e8efd1..f9eb03bf 100644 --- a/e3nn_jax/_src/grad_test.py +++ b/e3nn_jax/_src/grad_test.py @@ -5,9 +5,17 @@ def test_equivariance(): - assert_equivariant(e3nn.grad(lambda x: e3nn.tensor_product(x, x)), random.PRNGKey(0), irreps_in=("2x0e + 1e",)) - assert_equivariant(e3nn.grad(lambda x: e3nn.norm(x)), random.PRNGKey(1), irreps_in=("2x0e + 1e",)) - assert_equivariant(e3nn.grad(lambda x: e3nn.sum(x)), random.PRNGKey(2), irreps_in=("2x0e + 1e",)) + assert_equivariant( + e3nn.grad(lambda x: e3nn.tensor_product(x, x)), + random.PRNGKey(0), + irreps_in=("2x0e + 1e",), + ) + assert_equivariant( + e3nn.grad(lambda x: e3nn.norm(x)), random.PRNGKey(1), irreps_in=("2x0e + 1e",) + ) + assert_equivariant( + e3nn.grad(lambda x: e3nn.sum(x)), random.PRNGKey(2), irreps_in=("2x0e + 1e",) + ) def test_simple_grad(): @@ -15,7 +23,9 @@ def fn(x): return e3nn.sum(0.5 * e3nn.norm(x, squared=True).simplify()) x = e3nn.normal("2e + 0e + 2x1o", random.PRNGKey(0), ()) - np.testing.assert_allclose(e3nn.grad(fn, regroup_output=False)(x).array, x.array, atol=1e-6, rtol=1e-6) + np.testing.assert_allclose( + e3nn.grad(fn, regroup_output=False)(x).array, x.array, atol=1e-6, rtol=1e-6 + ) def test_aux(): diff --git a/e3nn_jax/_src/instruction.py b/e3nn_jax/_src/instruction.py index 5369f51c..84fc1494 100644 --- a/e3nn_jax/_src/instruction.py +++ b/e3nn_jax/_src/instruction.py @@ -32,7 +32,9 @@ def __post_init__(self): "uvu bool: """Compare two irreps.""" @@ -318,7 +337,9 @@ def __lt__(self, other): return (self.ir, self.mul) < (other.ir, other.mul) -jax.tree_util.register_pytree_node(MulIrrep, lambda mulir: ((), mulir), lambda mulir, _: mulir) +jax.tree_util.register_pytree_node( + MulIrrep, lambda mulir: ((), mulir), lambda mulir, _: mulir +) IntoIrreps = Union[ None, @@ -526,7 +547,9 @@ def __mul__(self, other): 2x0e+2x1e """ if isinstance(other, Irreps): - raise NotImplementedError("Use e3nn.tensor_product for this, see the documentation") + raise NotImplementedError( + "Use e3nn.tensor_product for this, see 
the documentation" + ) return Irreps([(mul * other, ir) for mul, ir in self]) def __rmul__(self, other): @@ -623,7 +646,9 @@ def simplify(self) -> "Irreps": """ return self.remove_zero_multiplicities().unify() - def sort(self) -> NamedTuple("Sort", irreps="Irreps", p=Tuple[int, ...], inv=Tuple[int, ...]): + def sort( + self, + ) -> NamedTuple("Sort", irreps="Irreps", p=Tuple[int, ...], inv=Tuple[int, ...]): r"""Sort the representations. Returns: @@ -835,7 +860,11 @@ def D_from_log_coordinates(self, log_coordinates, k=0): `jax.numpy.ndarray`: array of shape :math:`(..., \mathrm{dim}, \mathrm{dim})` """ return jax.scipy.linalg.block_diag( - *[ir.D_from_log_coordinates(log_coordinates, k) for mul, ir in self for _ in range(mul)] + *[ + ir.D_from_log_coordinates(log_coordinates, k) + for mul, ir in self + for _ in range(mul) + ] ) def D_from_angles(self, alpha, beta, gamma, k=0): @@ -850,7 +879,13 @@ def D_from_angles(self, alpha, beta, gamma, k=0): Returns: `jax.numpy.ndarray`: array of shape :math:`(..., \mathrm{dim}, \mathrm{dim})` """ - return jax.scipy.linalg.block_diag(*[ir.D_from_angles(alpha, beta, gamma, k) for mul, ir in self for _ in range(mul)]) + return jax.scipy.linalg.block_diag( + *[ + ir.D_from_angles(alpha, beta, gamma, k) + for mul, ir in self + for _ in range(mul) + ] + ) def D_from_quaternion(self, q, k=0): r"""Matrix of the representation. @@ -879,7 +914,9 @@ def D_from_matrix(self, R): return self.D_from_angles(*matrix_to_angles(R), k) def D_from_axis_angle(self, axis, angle, k=0): - return self.D_from_log_coordinates(axis_angle_to_log_coordinates(axis, angle), k) + return self.D_from_log_coordinates( + axis_angle_to_log_coordinates(axis, angle), k + ) def generators(self) -> jnp.ndarray: r"""Generators of the representation. @@ -887,10 +924,14 @@ def generators(self) -> jnp.ndarray: Returns: `jax.numpy.ndarray`: array of shape :math:`(3, \mathrm{dim}, \mathrm{dim})` """ - return jax.vmap(jax.scipy.linalg.block_diag)(*[ir.generators() for mul, ir in self for _ in range(mul)]) + return jax.vmap(jax.scipy.linalg.block_diag)( + *[ir.generators() for mul, ir in self for _ in range(mul)] + ) -jax.tree_util.register_pytree_node(Irreps, lambda irreps: ((), irreps), lambda irreps, _: irreps) +jax.tree_util.register_pytree_node( + Irreps, lambda irreps: ((), irreps), lambda irreps, _: irreps +) class _MulIndexSliceHelper: @@ -963,7 +1004,10 @@ def __getitem__(self, index: slice) -> Irreps: def _wigner_D_from_angles( - l: int, alpha: Optional[jnp.ndarray], beta: Optional[jnp.ndarray], gamma: Optional[jnp.ndarray] + l: int, + alpha: Optional[jnp.ndarray], + beta: Optional[jnp.ndarray], + gamma: Optional[jnp.ndarray], ) -> jnp.ndarray: r"""The Wigner-D matrix of the real irreducible representations of :math:`SO(3)`. 
diff --git a/e3nn_jax/_src/irreps_array.py b/e3nn_jax/_src/irreps_array.py index 5b51a4e2..6d3e81eb 100644 --- a/e3nn_jax/_src/irreps_array.py +++ b/e3nn_jax/_src/irreps_array.py @@ -14,7 +14,9 @@ def _infer_backend(pytree): - any_numpy = any(isinstance(x, np.ndarray) for x in jax.tree_util.tree_leaves(pytree)) + any_numpy = any( + isinstance(x, np.ndarray) for x in jax.tree_util.tree_leaves(pytree) + ) any_jax = any(isinstance(x, jnp.ndarray) for x in jax.tree_util.tree_leaves(pytree)) if any_numpy and any_jax: raise ValueError("Cannot mix numpy and jax arrays") @@ -69,10 +71,17 @@ class IrrepsArray: irreps: Irreps array: jnp.ndarray # this field is mendatory because it contains the shape - _list: List[Optional[jnp.ndarray]] # this field is lazy, it is computed only when needed + _list: List[ + Optional[jnp.ndarray] + ] # this field is lazy, it is computed only when needed def __init__( - self, irreps: IntoIrreps, array: jnp.ndarray, list: List[Optional[jnp.ndarray]] = None, *, _perform_checks: bool = True + self, + irreps: IntoIrreps, + array: jnp.ndarray, + list: List[Optional[jnp.ndarray]] = None, + *, + _perform_checks: bool = True, ): """Create an IrrepsArray.""" self.irreps = Irreps(irreps) @@ -81,7 +90,9 @@ def __init__( if _perform_checks: if not isinstance(self.array, (np.ndarray, jnp.ndarray)): - raise ValueError(f"IrrepsArray: Array must be a jax.numpy.ndarray, got {type(self.array)}") + raise ValueError( + f"IrrepsArray: Array must be a jax.numpy.ndarray, got {type(self.array)}" + ) if self.array.shape[-1] != self.irreps.dim: raise ValueError( f"IrrepsArray: Array shape {self.array.shape} incompatible with irreps {self.irreps}. " @@ -89,7 +100,9 @@ def __init__( ) if self._list is not None: if len(self._list) != len(self.irreps): - raise ValueError(f"IrrepsArray: List length {len(self._list)} incompatible with irreps {self.irreps}.") + raise ValueError( + f"IrrepsArray: List length {len(self._list)} incompatible with irreps {self.irreps}." + ) for x, (mul, ir) in zip(self._list, self.irreps): if x is not None: if x.shape != self.array.shape[:-1] + (mul, ir.dim): @@ -98,14 +111,21 @@ def __init__( f"incompatible with array shape {self.array.shape} and irreps {self.irreps}. " f"Expecting {[self.array.shape[:-1] + (mul, ir.dim) for (mul, ir) in self.irreps]}." ) - assert all(x.dtype == self.array.dtype for x in self._list if x is not None), ( + assert all( + x.dtype == self.array.dtype for x in self._list if x is not None + ), ( f"IrrepsArray: List dtypes {[None if x is None else x.dtype for x in self._list]} " f"incompatible with array dtype {self.array.dtype}." ) @staticmethod def from_list( - irreps: IntoIrreps, list: List[Optional[jnp.ndarray]], leading_shape: Tuple[int], dtype=None, *, backend=None + irreps: IntoIrreps, + list: List[Optional[jnp.ndarray]], + leading_shape: Tuple[int], + dtype=None, + *, + backend=None, ) -> "IrrepsArray": r"""Create an IrrepsArray from a list of arrays. 
@@ -121,12 +141,19 @@ def from_list(
         irreps = Irreps(irreps)
 
         if len(irreps) != len(list):
-            raise ValueError(f"IrrepsArray.from_list: len(irreps) != len(list), {len(irreps)} != {len(list)}")
+            raise ValueError(
+                f"IrrepsArray.from_list: len(irreps) != len(list), {len(irreps)} != {len(list)}"
+            )
 
         if not all(x is None or isinstance(x, jnp.ndarray) for x in list):
-            raise ValueError(f"IrrepsArray.from_list: list contains non-array elements type={[type(x) for x in list]}")
+            raise ValueError(
+                f"IrrepsArray.from_list: list contains non-array elements type={[type(x) for x in list]}"
+            )
 
-        if not all(x is None or x.shape == leading_shape + (mul, ir.dim) for x, (mul, ir) in zip(list, irreps)):
+        if not all(
+            x is None or x.shape == leading_shape + (mul, ir.dim)
+            for x, (mul, ir) in zip(list, irreps)
+        ):
             raise ValueError(
                 f"IrrepsArray.from_list: list shapes {[None if x is None else x.shape for x in list]} "
                 f"incompatible with leading shape {leading_shape} and irreps {irreps}. "
@@ -139,12 +166,16 @@
                 break
 
         if dtype is None:
-            raise ValueError("IrrepsArray.from_list: Need to specify dtype if list is empty or contains only None.")
+            raise ValueError(
+                "IrrepsArray.from_list: Need to specify dtype if list is empty or contains only None."
+            )
 
         if irreps.dim > 0:
             array = jnp.concatenate(
                 [
-                    jnp.zeros(leading_shape + (mul_ir.dim,), dtype) if x is None else x.reshape(leading_shape + (mul_ir.dim,))
+                    jnp.zeros(leading_shape + (mul_ir.dim,), dtype)
+                    if x is None
+                    else x.reshape(leading_shape + (mul_ir.dim,))
                     for mul_ir, x in zip(irreps, list)
                 ],
                 axis=-1,
@@ -170,7 +201,9 @@ def as_irreps_array(array: Union[jnp.ndarray, "IrrepsArray"], *, backend=None):
 
         array = jnp.asarray(array)
 
         if array.ndim == 0:
-            raise ValueError("IrrepsArray.as_irreps_array: Cannot convert an array of rank 0 to an IrrepsArray.")
+            raise ValueError(
+                "IrrepsArray.as_irreps_array: Cannot convert an array of rank 0 to an IrrepsArray."
+            )
 
         return IrrepsArray(f"{array.shape[-1]}x0e", array)
@@ -184,13 +217,17 @@ def zeros(irreps: IntoIrreps, leading_shape, dtype=None) -> "IrrepsArray":
         r"""Create an IrrepsArray of zeros."""
         irreps = Irreps(irreps)
         return IrrepsArray(
-            irreps=irreps, array=jnp.zeros(leading_shape + (irreps.dim,), dtype=dtype), list=[None] * len(irreps)
+            irreps=irreps,
+            array=jnp.zeros(leading_shape + (irreps.dim,), dtype=dtype),
+            list=[None] * len(irreps),
         )
 
     @staticmethod
     def zeros_like(irreps_array: "IrrepsArray") -> "IrrepsArray":
         r"""Create an IrrepsArray of zeros with the same shape as another IrrepsArray."""
-        return IrrepsArray.zeros(irreps_array.irreps, irreps_array.shape[:-1], irreps_array.dtype)
+        return IrrepsArray.zeros(
+            irreps_array.irreps, irreps_array.shape[:-1], irreps_array.dtype
+        )
 
     @staticmethod
     def ones(irreps: IntoIrreps, leading_shape, dtype=None) -> "IrrepsArray":
@@ -269,12 +306,16 @@ def __repr__(self):  # noqa: D105
 
     def __len__(self):  # noqa: D105
         return len(self.array)
 
-    def __eq__(self: "IrrepsArray", other: Union["IrrepsArray", jnp.ndarray]) -> "IrrepsArray":  # noqa: D105
+    def __eq__(
+        self: "IrrepsArray", other: Union["IrrepsArray", jnp.ndarray]
+    ) -> "IrrepsArray":  # noqa: D105
         jnp = _infer_backend(self.array)
 
         if isinstance(other, IrrepsArray):
             if self.irreps != other.irreps:
-                raise ValueError("IrrepsArray({self.irreps}) == IrrepsArray({other.irreps}) is not equivariant.")
+                raise ValueError(
+                    f"IrrepsArray({self.irreps}) == IrrepsArray({other.irreps}) is not equivariant."
@@ -269,12 +306,16 @@ def __repr__(self):  # noqa: D105
     def __len__(self):  # noqa: D105
         return len(self.array)
 
-    def __eq__(self: "IrrepsArray", other: Union["IrrepsArray", jnp.ndarray]) -> "IrrepsArray":  # noqa: D105
+    def __eq__(
+        self: "IrrepsArray", other: Union["IrrepsArray", jnp.ndarray]
+    ) -> "IrrepsArray":  # noqa: D105
         jnp = _infer_backend(self.array)
 
         if isinstance(other, IrrepsArray):
             if self.irreps != other.irreps:
-                raise ValueError("IrrepsArray({self.irreps}) == IrrepsArray({other.irreps}) is not equivariant.")
+                raise ValueError(
+                    "IrrepsArray({self.irreps}) == IrrepsArray({other.irreps}) is not equivariant."
+                )
 
             leading_shape = jnp.broadcast_shapes(self.shape[:-1], other.shape[:-1])
 
@@ -288,18 +329,31 @@ def eq(mul: int, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
                 return jnp.all(x == y, axis=-1)
 
-            list = [eq(mul, x, y)[..., None] for (mul, ir), x, y in zip(self.irreps, self.list, other.list)]
-            return IrrepsArray.from_list([(mul, "0e") for mul, _ in self.irreps], list, leading_shape, bool)
+            list = [
+                eq(mul, x, y)[..., None]
+                for (mul, ir), x, y in zip(self.irreps, self.list, other.list)
+            ]
+            return IrrepsArray.from_list(
+                [(mul, "0e") for mul, _ in self.irreps], list, leading_shape, bool
+            )
 
         other = jnp.asarray(other)
         if self.irreps.lmax > 0 or (other.ndim > 0 and other.shape[-1] != 1):
-            raise ValueError(f"IrrepsArray({self.irreps}) == scalar(shape={other.shape}) is not equivariant.")
+            raise ValueError(
+                f"IrrepsArray({self.irreps}) == scalar(shape={other.shape}) is not equivariant."
+            )
         return IrrepsArray(irreps=self.irreps, array=self.array == other)
 
     def __neg__(self: "IrrepsArray") -> "IrrepsArray":
-        return IrrepsArray(irreps=self.irreps, array=-self.array, list=[-x if x is not None else None for x in self.list])
+        return IrrepsArray(
+            irreps=self.irreps,
+            array=-self.array,
+            list=[-x if x is not None else None for x in self.list],
+        )
 
-    def __add__(self: "IrrepsArray", other: Union["IrrepsArray", jnp.ndarray]) -> "IrrepsArray":  # noqa: D105
+    def __add__(
+        self: "IrrepsArray", other: Union["IrrepsArray", jnp.ndarray]
+    ) -> "IrrepsArray":  # noqa: D105
         jnp = _infer_backend(self.array)
 
         if not isinstance(other, IrrepsArray):
@@ -309,12 +363,21 @@ def __add__(self: "IrrepsArray", other: Union["IrrepsArray", jnp.ndarray]) -> "I
             raise ValueError(f"IrrepsArray({self.irreps}) + scalar is not equivariant.")
 
         if self.irreps != other.irreps:
-            raise ValueError(f"IrrepsArray({self.irreps}) + IrrepsArray({other.irreps}) is not equivariant.")
+            raise ValueError(
+                f"IrrepsArray({self.irreps}) + IrrepsArray({other.irreps}) is not equivariant."
+            )
 
-        list = [x if y is None else (y if x is None else x + y) for x, y in zip(self.list, other.list)]
-        return IrrepsArray(irreps=self.irreps, array=self.array + other.array, list=list)
+        list = [
+            x if y is None else (y if x is None else x + y)
+            for x, y in zip(self.list, other.list)
+        ]
+        return IrrepsArray(
+            irreps=self.irreps, array=self.array + other.array, list=list
+        )
 
-    def __sub__(self: "IrrepsArray", other: Union["IrrepsArray", jnp.ndarray]) -> "IrrepsArray":  # noqa: D105
+    def __sub__(
+        self: "IrrepsArray", other: Union["IrrepsArray", jnp.ndarray]
+    ) -> "IrrepsArray":  # noqa: D105
         jnp = _infer_backend(self.array)
 
         if not isinstance(other, IrrepsArray):
@@ -324,11 +387,20 @@ def __sub__(self: "IrrepsArray", other: Union["IrrepsArray", jnp.ndarray]) -> "I
             raise ValueError(f"IrrepsArray({self.irreps}) - scalar is not equivariant.")
 
         if self.irreps != other.irreps:
-            raise ValueError(f"IrrepsArray({self.irreps}) - IrrepsArray({other.irreps}) is not equivariant.")
-        list = [x if y is None else (-y if x is None else x - y) for x, y in zip(self.list, other.list)]
-        return IrrepsArray(irreps=self.irreps, array=self.array - other.array, list=list)
+            raise ValueError(
+                f"IrrepsArray({self.irreps}) - IrrepsArray({other.irreps}) is not equivariant."
+            )
+        list = [
+            x if y is None else (-y if x is None else x - y)
+            for x, y in zip(self.list, other.list)
+        ]
+        return IrrepsArray(
+            irreps=self.irreps, array=self.array - other.array, list=list
+        )
 
-    def __mul__(self: "IrrepsArray", other: Union["IrrepsArray", jnp.ndarray]) -> "IrrepsArray":  # noqa: D105
+    def __mul__(
+        self: "IrrepsArray", other: Union["IrrepsArray", jnp.ndarray]
+    ) -> "IrrepsArray":  # noqa: D105
         jnp = _infer_backend(self.array)
 
         if isinstance(other, IrrepsArray):
@@ -351,51 +423,85 @@ def __mul__(self: "IrrepsArray", other: Union["IrrepsArray", jnp.ndarray]) -> "I
             return e3nn.elementwise_tensor_product(self, other)
 
         if self.irreps.lmax > 0 and other.ndim > 0 and other.shape[-1] != 1:
-            raise ValueError(f"IrrepsArray({self.irreps}) * scalar(shape={other.shape}) is not equivariant.")
+            raise ValueError(
+                f"IrrepsArray({self.irreps}) * scalar(shape={other.shape}) is not equivariant."
+            )
         list = [None if x is None else x * other[..., None] for x in self.list]
         return IrrepsArray(irreps=self.irreps, array=self.array * other, list=list)
 
-    def __rmul__(self: "IrrepsArray", other: jnp.ndarray) -> "IrrepsArray":  # noqa: D105
+    def __rmul__(
+        self: "IrrepsArray", other: jnp.ndarray
+    ) -> "IrrepsArray":  # noqa: D105
         return self * other
 
-    def __truediv__(self: "IrrepsArray", other: Union["IrrepsArray", jnp.ndarray]) -> "IrrepsArray":  # noqa: D105
+    def __truediv__(
+        self: "IrrepsArray", other: Union["IrrepsArray", jnp.ndarray]
+    ) -> "IrrepsArray":  # noqa: D105
         jnp = _infer_backend(self.array)
 
         if isinstance(other, IrrepsArray):
-            if len(other.irreps) == 0 or other.irreps.lmax > 0 or self.irreps.num_irreps != other.irreps.num_irreps:
-                raise ValueError(f"IrrepsArray({self.irreps}) / IrrepsArray({other.irreps}) is not equivariant.")
+            if (
+                len(other.irreps) == 0
+                or other.irreps.lmax > 0
+                or self.irreps.num_irreps != other.irreps.num_irreps
+            ):
+                raise ValueError(
+                    f"IrrepsArray({self.irreps}) / IrrepsArray({other.irreps}) is not equivariant."
+                )
 
             if any(x is None for x in other.list):
-                raise ValueError("There are deterministic Zeros in the array of the lhs. Cannot divide by Zero.")
+                raise ValueError(
+                    "There are deterministic Zeros in the array of the lhs. Cannot divide by Zero."
+                )
 
             other = 1.0 / other
             return e3nn.elementwise_tensor_product(self, other)
 
         other = jnp.asarray(other)
         if self.irreps.lmax > 0 and other.ndim > 0 and other.shape[-1] != 1:
-            raise ValueError(f"IrrepsArray({self.irreps}) / scalar(shape={other.shape}) is not equivariant.")
+            raise ValueError(
+                f"IrrepsArray({self.irreps}) / scalar(shape={other.shape}) is not equivariant."
+            )
         list = [None if x is None else x / other[..., None] for x in self.list]
         return IrrepsArray(irreps=self.irreps, array=self.array / other, list=list)
 
-    def __rtruediv__(self: "IrrepsArray", other: jnp.ndarray) -> "IrrepsArray":  # noqa: D105
+    def __rtruediv__(
+        self: "IrrepsArray", other: jnp.ndarray
+    ) -> "IrrepsArray":  # noqa: D105
         jnp = _infer_backend((self.array, other))
         other = jnp.asarray(other)
 
         if self.irreps.lmax > 0:
-            raise ValueError(f"scalar(shape={other.shape}) / IrrepsArray({self.irreps}) is not equivariant.")
+            raise ValueError(
+                f"scalar(shape={other.shape}) / IrrepsArray({self.irreps}) is not equivariant."
+            )
        if any(x is None for x in self.list):
-            raise ValueError("There are deterministic Zeros in the array of the lhs. Cannot divide by Zero.")
+            raise ValueError(
+                "There are deterministic Zeros in the array of the lhs. Cannot divide by Zero."
+            )
 
-        return IrrepsArray(irreps=self.irreps, array=other / self.array, list=[other[..., None] / x for x in self.list])
+        return IrrepsArray(
+            irreps=self.irreps,
+            array=other / self.array,
+            list=[other[..., None] / x for x in self.list],
+        )
 
     def __pow__(self, exponent) -> "IrrepsArray":  # noqa: D105
         if all(ir == "0e" for _, ir in self.irreps):
-            return IrrepsArray(irreps=self.irreps, array=self.array**exponent, list=[x**exponent for x in self.list])
+            return IrrepsArray(
+                irreps=self.irreps,
+                array=self.array**exponent,
+                list=[x**exponent for x in self.list],
+            )
 
         if exponent % 1.0 == 0.0 and self.irreps.lmax == 0:
             irreps = self.irreps
             if exponent % 2.0 == 0.0:
                 irreps = [(mul, "0e") for mul, ir in self.irreps]
-            return IrrepsArray(irreps, array=self.array**exponent, list=[x**exponent for x in self.list])
+            return IrrepsArray(
+                irreps,
+                array=self.array**exponent,
+                list=[x**exponent for x in self.list],
+            )
 
         raise ValueError(f"IrrepsArray({self.irreps}) ** scalar is not equivariant.")
 
@@ -418,7 +524,11 @@ def __getitem__(self, index) -> "IrrepsArray":  # noqa: D105
             irreps = Irreps(index[-1])
 
-            ii = [i for i in range(len(self.irreps)) if self.irreps[i : i + len(irreps)] == irreps]
+            ii = [
+                i
+                for i in range(len(self.irreps))
+                if self.irreps[i : i + len(irreps)] == irreps
+            ]
             if len(ii) != 1:
                 raise IndexError(
                     f"Error in IrrepsArray.__getitem__, Can't slice with {irreps} "
@@ -428,7 +538,9 @@ def __getitem__(self, index) -> "IrrepsArray":  # noqa: D105
             return IrrepsArray(
                 irreps,
-                self.array[..., self.irreps[:i].dim : self.irreps[: i + len(irreps)].dim],
+                self.array[
+                    ..., self.irreps[:i].dim : self.irreps[: i + len(irreps)].dim
+                ],
                 self.list[i : i + len(irreps)],
             )[index[:-1] + (slice(None),)]
 
@@ -456,7 +568,9 @@ def __getitem__(self, index) -> "IrrepsArray":  # noqa: D105
                     if (start - self.irreps[: i - 1].dim) % ir.dim == 0:
                         mul1 = (start - self.irreps[: i - 1].dim) // ir.dim
                         return self._convert(
-                            self.irreps[: i - 1] + e3nn.Irreps([(mul1, ir), (mul - mul1, ir)]) + self.irreps[i:]
+                            self.irreps[: i - 1]
+                            + e3nn.Irreps([(mul1, ir), (mul - mul1, ir)])
+                            + self.irreps[i:]
                         )[index]
 
             if self.irreps[:i].dim == stop:
@@ -469,14 +583,20 @@ def __getitem__(self, index) -> "IrrepsArray":  # noqa: D105
                     if (stop - self.irreps[: i - 1].dim) % ir.dim == 0:
                         mul1 = (stop - self.irreps[: i - 1].dim) // ir.dim
                         return self._convert(
-                            self.irreps[: i - 1] + e3nn.Irreps([(mul1, ir), (mul - mul1, ir)]) + self.irreps[i:]
+                            self.irreps[: i - 1]
+                            + e3nn.Irreps([(mul1, ir), (mul - mul1, ir)])
+                            + self.irreps[i:]
                         )[index]
 
             if irreps_start is None or irreps_stop is None:
-                raise IndexError(f"Error in IrrepsArray.__getitem__, unable to slice {self.irreps} with {start}:{stop}.")
+                raise IndexError(
+                    f"Error in IrrepsArray.__getitem__, unable to slice {self.irreps} with {start}:{stop}."
+                )
 
             return IrrepsArray(
-                self.irreps[irreps_start:irreps_stop], self.array[..., start:stop], self.list[irreps_start:irreps_stop]
+                self.irreps[irreps_start:irreps_stop],
+                self.array[..., start:stop],
+                self.list[irreps_start:irreps_stop],
            )[index[:-1] + (slice(None),)]
 
         if len(index) == self.ndim or any(map(_is_ellipse, index)):
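__getitem__ above only allows slices that fall on irrep boundaries; a slice that would cut through a single irrep raises IndexError. A minimal sketch of the behavior exercised by test_indexing further down in this patch:

    import jax.numpy as jnp
    import e3nn_jax as e3nn

    x = e3nn.IrrepsArray("2x1e + 2x1e", jnp.arange(12.0))
    # Dropping one 1e (3 components) on each side is a valid slice:
    assert x[3:-3].irreps == e3nn.Irreps("1e + 1e")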
@@ -524,8 +644,15 @@ def reshape(self, shape) -> "IrrepsArray":
         """
         assert shape[-1] == self.irreps.dim or shape[-1] == -1
         shape = shape[:-1]
-        list = [None if x is None else x.reshape(shape + (mul, ir.dim)) for (mul, ir), x in zip(self.irreps, self.list)]
-        return IrrepsArray(irreps=self.irreps, array=self.array.reshape(shape + (self.irreps.dim,)), list=list)
+        list = [
+            None if x is None else x.reshape(shape + (mul, ir.dim))
+            for (mul, ir), x in zip(self.irreps, self.list)
+        ]
+        return IrrepsArray(
+            irreps=self.irreps,
+            array=self.array.reshape(shape + (self.irreps.dim,)),
+            list=list,
+        )
 
     def astype(self, dtype) -> "IrrepsArray":
         r"""Change the dtype of the array.
@@ -537,7 +664,12 @@ def astype(self, dtype) -> "IrrepsArray":
             IrrepsArray: new IrrepsArray
         """
         list = [None if x is None else x.astype(dtype) for x in self.list]
-        return IrrepsArray(irreps=self.irreps, array=self.array.astype(dtype), list=list, _perform_checks=False)
+        return IrrepsArray(
+            irreps=self.irreps,
+            array=self.array.astype(dtype),
+            list=list,
+            _perform_checks=False,
+        )
 
     def replace_none_with_zeros(self) -> "IrrepsArray":
         r"""Replace all None in ``.list`` with zeros."""
@@ -552,7 +684,9 @@ def replace_none_with_zeros(self) -> "IrrepsArray":
     def remove_nones(self) -> "IrrepsArray":
         r"""Remove all None in ``.list`` and ``.irreps``."""
         if any(x is None for x in self.list):
-            irreps = [mul_ir for mul_ir, x in zip(self.irreps, self.list) if x is not None]
+            irreps = [
+                mul_ir for mul_ir, x in zip(self.irreps, self.list) if x is not None
+            ]
             list = [x for x in self.list if x is not None]
             return IrrepsArray.from_list(irreps, list, self.shape[:-1], self.dtype)
         return self
@@ -587,7 +721,11 @@ def sort(self) -> "IrrepsArray":
         """
         irreps, p, inv = self.irreps.sort()
         return IrrepsArray.from_list(
-            irreps, [self.list[i] for i in inv], self.shape[:-1], self.dtype, backend=_infer_backend(self.array)
+            irreps,
+            [self.list[i] for i in inv],
+            self.shape[:-1],
+            self.dtype,
+            backend=_infer_backend(self.array),
         )
 
     sorted = sort
@@ -605,9 +743,13 @@ def regroup(self) -> "IrrepsArray":
 
     def filter(
         self,
-        keep: Union[e3nn.Irreps, List[e3nn.Irrep], Callable[[e3nn.MulIrrep], bool]] = None,
+        keep: Union[
+            e3nn.Irreps, List[e3nn.Irrep], Callable[[e3nn.MulIrrep], bool]
+        ] = None,
         *,
-        drop: Union[e3nn.Irreps, List[e3nn.Irrep], Callable[[e3nn.MulIrrep], bool]] = None,
+        drop: Union[
+            e3nn.Irreps, List[e3nn.Irrep], Callable[[e3nn.MulIrrep], bool]
+        ] = None,
         lmax: int = None,
     ) -> "IrrepsArray":
         r"""Filter the irreps.
@@ -687,7 +829,9 @@ def irreps_to_axis(self) -> "IrrepsArray":  # noqa: D102
 
     # Move multiplicity to the previous last axis and back
 
-    def mul_to_axis(self, factor: Optional[int] = None, axis: int = -2) -> "IrrepsArray":
+    def mul_to_axis(
+        self, factor: Optional[int] = None, axis: int = -2
+    ) -> "IrrepsArray":
         r"""Create a new axis in the previous last position by factoring the multiplicities.
 
         Increase the dimension of the array by 1.
@@ -706,13 +850,17 @@ def mul_to_axis(self, factor: Optional[int] = None, axis: int = -2) -> "IrrepsAr
         """
         axis = _standardize_axis(axis, self.ndim + 1)
         if axis == self.ndim:
-            raise ValueError("axis cannot be the last axis. The last axis is reserved for the irreps dimension.")
+            raise ValueError(
+                "axis cannot be the last axis. The last axis is reserved for the irreps dimension."
+            )
 
         if factor is None:
             factor = functools.reduce(math.gcd, (mul for mul, _ in self.irreps))
 
         if not all(mul % factor == 0 for mul, _ in self.irreps):
-            raise ValueError(f"factor {factor} does not divide all multiplicities: {self.irreps}")
+            raise ValueError(
+                f"factor {factor} does not divide all multiplicities: {self.irreps}"
+            )
 
         irreps = Irreps([(mul // factor, ir) for mul, ir in self.irreps])
         new_list = [
@@ -720,7 +868,9 @@ def mul_to_axis(self, factor: Optional[int] = None, axis: int = -2) -> "IrrepsAr
             for (mul, ir), x in zip(irreps, self.list)
         ]
         new_list = [None if x is None else jnp.moveaxis(x, -3, axis) for x in new_list]
-        return IrrepsArray.from_list(irreps, new_list, self.shape[:-1] + (factor,), self.dtype)
+        return IrrepsArray.from_list(
+            irreps, new_list, self.shape[:-1] + (factor,), self.dtype
+        )
 
     def axis_to_mul(self, axis: int = -2) -> "IrrepsArray":
         r"""Repeat the multiplicity by the previous last axis of the array.
@@ -739,7 +889,9 @@ def axis_to_mul(self, axis: int = -2) -> "IrrepsArray":
         axis = _standardize_axis(axis, self.ndim)[0]
 
         if axis == self.ndim - 1:
-            raise ValueError("The last axis is the irreps dimension and therefore cannot be converted to multiplicity.")
+            raise ValueError(
+                "The last axis is the irreps dimension and therefore cannot be converted to multiplicity."
+            )
 
         new_list = [None if x is None else jnp.moveaxis(x, axis, -3) for x in self.list]
         new_irreps = Irreps([(self.shape[-2] * mul, ir) for mul, ir in self.irreps])
@@ -752,7 +904,9 @@ def axis_to_mul(self, axis: int = -2) -> "IrrepsArray":
     repeat_mul_by_last_axis = axis_to_mul
     factor_mul_to_last_axis = mul_to_axis
 
-    def transform_by_log_coordinates(self, log_coordinates: jnp.ndarray, k: int = 0) -> "IrrepsArray":
+    def transform_by_log_coordinates(
+        self, log_coordinates: jnp.ndarray, k: int = 0
+    ) -> "IrrepsArray":
         r"""Rotate data by a rotation given by log coordinates.
 
         Args:
@@ -763,14 +917,23 @@ def transform_by_log_coordinates(self, log_coordinates: jnp.ndarray, k: int = 0)
             `IrrepsArray`: rotated data
         """
         log_coordinates = log_coordinates.astype(self.dtype)
-        D = {ir: ir.D_from_log_coordinates(log_coordinates, k) for ir in {ir for _, ir in self.irreps}}
+        D = {
+            ir: ir.D_from_log_coordinates(log_coordinates, k)
+            for ir in {ir for _, ir in self.irreps}
+        }
         new_list = [
-            jnp.reshape(jnp.einsum("ij,...uj->...ui", D[ir], x), self.shape[:-1] + (mul, ir.dim)) if x is not None else None
+            jnp.reshape(
+                jnp.einsum("ij,...uj->...ui", D[ir], x), self.shape[:-1] + (mul, ir.dim)
+            )
+            if x is not None
+            else None
             for (mul, ir), x in zip(self.irreps, self.list)
         ]
         return IrrepsArray.from_list(self.irreps, new_list, self.shape[:-1], self.dtype)
 
-    def transform_by_angles(self, alpha: float, beta: float, gamma: float, k: int = 0, inverse: bool = False) -> "IrrepsArray":
+    def transform_by_angles(
+        self, alpha: float, beta: float, gamma: float, k: int = 0, inverse: bool = False
+    ) -> "IrrepsArray":
         r"""Rotate the data by angles according to the irreps.
 
         Args:
@@ -789,14 +952,33 @@ def transform_by_angles(self, alpha: float, beta: float, gamma: float, k: int =
             >>> x.transform_by_angles(jnp.pi, 0, 0)
             1x2e [ 0.1 -2. 1. -1. 1. ]
         """
-        alpha = alpha if isinstance(alpha, (int, float)) else jnp.asarray(alpha, dtype=self.dtype)
-        beta = beta if isinstance(beta, (int, float)) else jnp.asarray(beta, dtype=self.dtype)
-        gamma = gamma if isinstance(gamma, (int, float)) else jnp.asarray(gamma, dtype=self.dtype)
-        D = {ir: ir.D_from_angles(alpha, beta, gamma, k) for ir in {ir for _, ir in self.irreps}}
+        alpha = (
+            alpha
+            if isinstance(alpha, (int, float))
+            else jnp.asarray(alpha, dtype=self.dtype)
+        )
+        beta = (
+            beta
+            if isinstance(beta, (int, float))
+            else jnp.asarray(beta, dtype=self.dtype)
+        )
+        gamma = (
+            gamma
+            if isinstance(gamma, (int, float))
+            else jnp.asarray(gamma, dtype=self.dtype)
+        )
+        D = {
+            ir: ir.D_from_angles(alpha, beta, gamma, k)
+            for ir in {ir for _, ir in self.irreps}
+        }
         if inverse:
             D = {ir: jnp.swapaxes(D[ir], -2, -1) for ir in D}
         new_list = [
-            jnp.reshape(jnp.einsum("ij,...uj->...ui", D[ir], x), self.shape[:-1] + (mul, ir.dim)) if x is not None else None
+            jnp.reshape(
+                jnp.einsum("ij,...uj->...ui", D[ir], x), self.shape[:-1] + (mul, ir.dim)
+            )
+            if x is not None
+            else None
             for (mul, ir), x in zip(self.irreps, self.list)
         ]
         return IrrepsArray.from_list(self.irreps, new_list, self.shape[:-1], self.dtype)
 
@@ -811,9 +993,13 @@ def transform_by_quaternion(self, q: jnp.ndarray, k: int = 0) -> "IrrepsArray":
         Returns:
             `IrrepsArray`: rotated data
         """
-        return self.transform_by_log_coordinates(e3nn.quaternion_to_log_coordinates(q), k)
+        return self.transform_by_log_coordinates(
+            e3nn.quaternion_to_log_coordinates(q), k
+        )
 
-    def transform_by_axis_angle(self, axis: jnp.ndarray, angle: float, k: int = 0) -> "IrrepsArray":
+    def transform_by_axis_angle(
+        self, axis: jnp.ndarray, angle: float, k: int = 0
+    ) -> "IrrepsArray":
         r"""Rotate data by a rotation given by an axis and an angle.
 
         Args:
@@ -824,7 +1010,9 @@ def transform_by_axis_angle(self, axis: jnp.ndarray, angle: float, k: int = 0) -
         Returns:
             `IrrepsArray`: rotated data
         """
-        return self.transform_by_log_coordinates(e3nn.axis_angle_to_log_coordinates(axis, angle), k)
+        return self.transform_by_log_coordinates(
+            e3nn.axis_angle_to_log_coordinates(axis, angle), k
+        )
 
     def transform_by_matrix(self, R: jnp.ndarray) -> "IrrepsArray":
         r"""Rotate data by a rotation given by a matrix.
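The transform_by_* family all reduce to Wigner D matrices applied chunk by chunk, as the einsum above shows. A small sketch consistent with the doctest kept in transform_by_angles (the input values are assumed, since the doctest's setup line falls outside the hunk context):

    import jax.numpy as jnp
    import e3nn_jax as e3nn

    x = e3nn.IrrepsArray("2e", jnp.array([0.1, 2.0, 1.0, 1.0, 1.0]))
    # Rotation by pi sign-flips the m = +-1 components of an l = 2 irrep:
    y = x.transform_by_angles(jnp.pi, 0, 0)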
@@ -904,7 +1092,13 @@ def _convert(self, irreps: IntoIrreps) -> "IrrepsArray":
                     current_array += m
                 else:
                     current_array = jnp.concatenate(
-                        [current_array, jnp.zeros(leading_shape + (m, mul_ir.ir.dim), self.dtype)], axis=-2
+                        [
+                            current_array,
+                            jnp.zeros(
+                                leading_shape + (m, mul_ir.ir.dim), self.dtype
+                            ),
+                        ],
+                        axis=-2,
                     )
             else:
                 if isinstance(current_array, int):
@@ -912,7 +1106,14 @@ def _convert(self, irreps: IntoIrreps) -> "IrrepsArray":
                         current_array = x
                     else:
                         current_array = jnp.concatenate(
-                            [jnp.zeros(leading_shape + (current_array, mul_ir.ir.dim), self.dtype), x], axis=-2
+                            [
+                                jnp.zeros(
+                                    leading_shape + (current_array, mul_ir.ir.dim),
+                                    self.dtype,
+                                ),
+                                x,
+                            ],
+                            axis=-2,
                         )
                 else:
                     current_array = jnp.concatenate([current_array, x], axis=-2)
@@ -932,8 +1133,13 @@ def _convert(self, irreps: IntoIrreps) -> "IrrepsArray":
         assert current_array == 0
 
         assert len(new_list) == len(irreps)
-        assert all(x is None or isinstance(x, jnp.ndarray) for x in new_list), [type(x) for x in new_list]
-        assert all(x is None or x.shape[-2:] == (mul, ir.dim) for x, (mul, ir) in zip(new_list, irreps))
+        assert all(x is None or isinstance(x, jnp.ndarray) for x in new_list), [
+            type(x) for x in new_list
+        ]
+        assert all(
+            x is None or x.shape[-2:] == (mul, ir.dim)
+            for x, (mul, ir) in zip(new_list, irreps)
+        )
 
         return IrrepsArray(irreps=irreps, array=self.array, list=new_list)
 
@@ -955,11 +1161,15 @@ def broadcast_to(self, shape) -> "IrrepsArray":
 jax.tree_util.register_pytree_node(
     IrrepsArray,
     lambda x: ((x.array, x._list), x.irreps),
-    lambda x, data: IrrepsArray(irreps=x, array=data[0], list=data[1], _perform_checks=False),
+    lambda x, data: IrrepsArray(
+        irreps=x, array=data[0], list=data[1], _perform_checks=False
+    ),
 )
 
 
-def _standardize_axis(axis: Union[None, int, Tuple[int, ...]], result_ndim: int) -> Tuple[int, ...]:
+def _standardize_axis(
+    axis: Union[None, int, Tuple[int, ...]], result_ndim: int
+) -> Tuple[int, ...]:
     if axis is None:
         return tuple(range(result_ndim))
     try:
@@ -1011,20 +1221,27 @@ def set(self, values: Any) -> IrrepsArray:
 
         if len(index) == self.ndim or any(map(_is_ellipse, index)):
             if not (_is_ellipse(index[-1]) or _is_none_slice(index[-1])):
-                raise IndexError(f"Indexing with {index[-1]} in the irreps dimension is not supported.")
+                raise IndexError(
+                    f"Indexing with {index[-1]} in the irreps dimension is not supported."
+                )
 
             # Support of x.at[index, :].set(0)
             if isinstance(values, (int, float)) and values == 0:
                 return IrrepsArray(
                     self.irreps,
                     array=self.array.at[index].set(0),
-                    list=[None if x is None else x.at[index + (slice(None),)].set(0) for x in self.list],
+                    list=[
+                        None if x is None else x.at[index + (slice(None),)].set(0)
+                        for x in self.list
+                    ],
                 )
 
             # Support of x.at[index, :].set(IrrArray(...))
             if isinstance(values, IrrepsArray):
                 if self.irreps.simplify() != values.irreps.simplify():
-                    raise ValueError("The irreps of the array and the values to set must be the same.")
+                    raise ValueError(
+                        "The irreps of the array and the values to set must be the same."
+                    )
 
                 values = values._convert(self.irreps)
 
@@ -1034,17 +1251,26 @@ def fn(x, y, mul, ir):
                 if x is not None and y is None:
                     return x.at[index + (slice(None),)].set(0)
                 if x is None and y is not None:
-                    return jnp.zeros(self.shape[:-1] + (mul, ir.dim), self.dtype).at[index + (slice(None),)].set(y)
+                    return (
+                        jnp.zeros(self.shape[:-1] + (mul, ir.dim), self.dtype)
+                        .at[index + (slice(None),)]
+                        .set(y)
+                    )
                 if x is None and y is None:
                     return None
 
            return IrrepsArray(
                 self.irreps,
                 array=self.array.at[index].set(values.array),
-                list=[fn(x, y, mul, ir) for (mul, ir), x, y in zip(self.irreps, self.list, values.list)],
+                list=[
+                    fn(x, y, mul, ir)
+                    for (mul, ir), x, y in zip(self.irreps, self.list, values.list)
+                ],
             )
 
-        raise NotImplementedError(f"x.at[i].set(v) with v={type(values)} is not implemented.")
+        raise NotImplementedError(
+            f"x.at[i].set(v) with v={type(values)} is not implemented."
+        )
 
     def add(self, values: Any) -> IrrepsArray:
         index = self.index
@@ -1070,12 +1296,16 @@ def add(self, values: Any) -> IrrepsArray:
 
         if len(index) == self.ndim or any(map(_is_ellipse, index)):
             if not (_is_ellipse(index[-1]) or _is_none_slice(index[-1])):
-                raise IndexError(f"Indexing with {index[-1]} in the irreps dimension is not supported.")
+                raise IndexError(
+                    f"Indexing with {index[-1]} in the irreps dimension is not supported."
+                )
 
             # Support of x.at[index, :].add(IrrArray(...))
             if isinstance(values, IrrepsArray):
                 if self.irreps.simplify() != values.irreps.simplify():
-                    raise ValueError("The irreps of the array and the values to add must be the same.")
+                    raise ValueError(
+                        "The irreps of the array and the values to add must be the same."
+                    )
 
                 values = values._convert(self.irreps)
 
@@ -1085,17 +1315,26 @@ def fn(x, y, mul, ir):
                 if x is not None and y is None:
                     return x
                 if x is None and y is not None:
-                    return jnp.zeros(self.shape[:-1] + (mul, ir.dim), self.dtype).at[index + (slice(None),)].add(y)
+                    return (
+                        jnp.zeros(self.shape[:-1] + (mul, ir.dim), self.dtype)
+                        .at[index + (slice(None),)]
+                        .add(y)
+                    )
                 if x is None and y is None:
                     return None
 
             return IrrepsArray(
                 self.irreps,
                 array=self.array.at[index].add(values.array),
-                list=[fn(x, y, mul, ir) for (mul, ir), x, y in zip(self.irreps, self.list, values.list)],
+                list=[
+                    fn(x, y, mul, ir)
+                    for (mul, ir), x, y in zip(self.irreps, self.list, values.list)
+                ],
            )
 
-        raise NotImplementedError(f"x.at[i].add(v) with v={type(values)} is not implemented.")
+        raise NotImplementedError(
+            f"x.at[i].add(v) with v={type(values)} is not implemented."
+        )
 
 
 class _MulIndexSliceHelper:
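The .at[...].set / .add helpers above mirror the jax.numpy ndarray.at interface while keeping the per-irrep chunk list consistent. A sketch under assumed shapes:

    import jax
    import e3nn_jax as e3nn

    x = e3nn.normal("0e + 1e", jax.random.PRNGKey(0), (4,))
    v = e3nn.normal("0e + 1e", jax.random.PRNGKey(1), ())
    y = x.at[0].set(0)   # zero one row, also zeroing the matching chunks
    z = x.at[0].add(v)   # same result as updating x.array.at[0] directly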
@@ -1106,10 +1345,14 @@ def __init__(self, irreps_array) -> None:
 
     def __getitem__(self, index: slice) -> Irreps:
         if not isinstance(index, slice):
-            raise IndexError("IrrepsArray.slice_by_mul only supports one slices (like IrrepsArray.slice_by_mul[2:4]).")
+            raise IndexError(
+                "IrrepsArray.slice_by_mul only supports one slices (like IrrepsArray.slice_by_mul[2:4])."
+            )
 
         start, stop, stride = index.indices(self.irreps_array.irreps.num_irreps)
         if stride != 1:
-            raise NotImplementedError("IrrepsArray.slice_by_mul does not support strides.")
+            raise NotImplementedError(
+                "IrrepsArray.slice_by_mul does not support strides."
+            )
 
         irreps = []
         list = []
@@ -1140,7 +1383,9 @@ def __init__(self, irreps_array) -> None:
 
     def __getitem__(self, index: slice) -> Irreps:
         if not isinstance(index, slice):
-            raise IndexError("IrrepsArray.slice_by_dim only supports slices (like IrrepsArray.slice_by_dim[2:4]).")
+            raise IndexError(
+                "IrrepsArray.slice_by_dim only supports slices (like IrrepsArray.slice_by_dim[2:4])."
+            )
 
         return self.irreps_array[..., index]
 
@@ -1152,7 +1397,9 @@ def __init__(self, irreps_array) -> None:
 
     def __getitem__(self, index: slice) -> Irreps:
         if not isinstance(index, slice):
-            raise IndexError("IrrepsArray.slice_by_chunk only supports slices (like IrrepsArray.slice_by_chunk[2:4]).")
+            raise IndexError(
+                "IrrepsArray.slice_by_chunk only supports slices (like IrrepsArray.slice_by_chunk[2:4])."
+            )
 
         start, stop, stride = index.indices(len(self.irreps_array.irreps))
 
         return IrrepsArray.from_list(
diff --git a/e3nn_jax/_src/irreps_array_test.py b/e3nn_jax/_src/irreps_array_test.py
index 47be86c3..74721a5c 100644
--- a/e3nn_jax/_src/irreps_array_test.py
+++ b/e3nn_jax/_src/irreps_array_test.py
@@ -16,11 +16,17 @@ def test_empty():
 
 def test_convert():
     id = e3nn.IrrepsArray.from_list("10x0e + 10x0e", [None, jnp.ones((1, 10, 1))], (1,))
-    assert jax.tree_util.tree_map(lambda x: x.shape, id._convert("0x0e + 20x0e + 0x0e")).list == [None, (1, 20, 1), None]
-    assert jax.tree_util.tree_map(lambda x: x.shape, id._convert("7x0e + 4x0e + 9x0e")).list == [None, (1, 4, 1), (1, 9, 1)]
+    assert jax.tree_util.tree_map(
+        lambda x: x.shape, id._convert("0x0e + 20x0e + 0x0e")
+    ).list == [None, (1, 20, 1), None]
+    assert jax.tree_util.tree_map(
+        lambda x: x.shape, id._convert("7x0e + 4x0e + 9x0e")
+    ).list == [None, (1, 4, 1), (1, 9, 1)]
 
     id = e3nn.IrrepsArray.from_list("10x0e + 10x1e", [None, jnp.ones((1, 10, 3))], (1,))
-    assert jax.tree_util.tree_map(lambda x: x.shape, id._convert("5x0e + 5x0e + 5x1e + 5x1e")).list == [
+    assert jax.tree_util.tree_map(
+        lambda x: x.shape, id._convert("5x0e + 5x0e + 5x1e + 5x1e")
+    ).list == [
         None,
         None,
         (1, 5, 3),
@@ -32,7 +38,14 @@ def test_convert():
 
     a = e3nn.IrrepsArray.from_list(
         " 10x0e + 0x0e +1x1e + 0x0e + 9x1e + 0x0e",
-        [jnp.ones((2, 10, 1)), None, None, jnp.ones((2, 0, 1)), jnp.ones((2, 9, 3)), None],
+        [
+            jnp.ones((2, 10, 1)),
+            None,
+            None,
+            jnp.ones((2, 0, 1)),
+            jnp.ones((2, 9, 3)),
+            None,
+        ],
         (2,),
     )
     b = a._convert("5x0e + 0x2e + 5x0e + 0x2e + 5x1e + 5x1e")
@@ -66,24 +79,39 @@ def test_indexing():
     with pytest.raises(IndexError):
         x[..., :2]
 
-    x = e3nn.IrrepsArray("2x1e + 2x1e", jnp.array([0.1, 0.2, 0.3, 1.1, 1.2, 1.3, 2.1, 2.2, 2.3, 3.1, 3.2, 3.3]))
+    x = e3nn.IrrepsArray(
+        "2x1e + 2x1e",
+        jnp.array([0.1, 0.2, 0.3, 1.1, 1.2, 1.3, 2.1, 2.2, 2.3, 3.1, 3.2, 3.3]),
+    )
     assert x[3:-3].irreps == "1e + 1e"
     np.testing.assert_allclose(x[3:-3].array, jnp.array([1.1, 1.2, 1.3, 2.1, 2.2, 2.3]))
 
 
 def test_reductions():
-    x = e3nn.IrrepsArray("2x0e + 1x1e", jnp.array([[1.0, 2, 3, 4, 5], [4.0, 5, 6, 6, 6]]))
+    x = e3nn.IrrepsArray(
+        "2x0e + 1x1e", jnp.array([[1.0, 2, 3, 4, 5], [4.0, 5, 6, 6, 6]])
+    )
     assert e3nn.sum(x).irreps == "0e + 1e"
     np.testing.assert_allclose(e3nn.sum(x).array, jnp.array([12.0, 9, 10, 11]))
-    np.testing.assert_allclose(e3nn.sum(x, axis=0).array, jnp.array([5.0, 7, 9, 10, 11]))
-    np.testing.assert_allclose(e3nn.sum(x, axis=1).array, jnp.array([[3.0, 3, 4, 5], [9.0, 6, 6, 6]]))
+    np.testing.assert_allclose(
+        e3nn.sum(x, axis=0).array, jnp.array([5.0, 7, 9, 10, 11])
+    )
+    np.testing.assert_allclose(
+        e3nn.sum(x, axis=1).array, jnp.array([[3.0, 3, 4, 5], [9.0, 6, 6, 6]])
+    )
 
-    np.testing.assert_allclose(e3nn.mean(x, axis=1).array, jnp.array([[1.5, 3, 4, 5], [4.5, 6, 6, 6]]))
+    np.testing.assert_allclose(
+        e3nn.mean(x, axis=1).array, jnp.array([[1.5, 3, 4, 5], [4.5, 6, 6, 6]])
+    )
 
 
 def test_operators():
-    x = e3nn.IrrepsArray("2x0e + 1x1e", jnp.array([[1.0, 2, 3, 4, 5], [4.0, 5, 6, 6, 6]]))
-    y = e3nn.IrrepsArray("2x0e + 1x1o", jnp.array([[1.0, 2, 3, 4, 5], [4.0, 5, 6, 6, 6]]))
+    x = e3nn.IrrepsArray(
+        "2x0e + 1x1e", jnp.array([[1.0, 2, 3, 4, 5], [4.0, 5, 6, 6, 6]])
+    )
+    y = e3nn.IrrepsArray(
+        "2x0e + 1x1o", jnp.array([[1.0, 2, 3, 4, 5], [4.0, 5, 6, 6, 6]])
+    )
 
     with pytest.raises(ValueError):
         x + 1
@@ -151,7 +179,9 @@ def f(*shape):
         [None, None, f(2, 1, 1), f(2, 1, 1)],
         (2,),
     )
-    v = e3nn.IrrepsArray.from_list("1e + 0e + 0e + 0e", [None, f(1, 1), None, f(1, 1)], ())
+    v = e3nn.IrrepsArray.from_list(
+        "1e + 0e + 0e + 0e", [None, f(1, 1), None, f(1, 1)], ()
+    )
     y1 = x.at[0].add(v)
     y2 = e3nn.IrrepsArray(x.irreps, x.array.at[0].add(v.array))
     np.testing.assert_array_equal(y1.array, y2.array)
@@ -178,25 +208,35 @@ def test_slice_by_mul():
 
 
 def test_norm():
-    x = e3nn.IrrepsArray("2x0e + 1x1e", jnp.array([[1.0, 2, 3, 4, 5], [4.0, 5, 6, 6, 6]]))
+    x = e3nn.IrrepsArray(
+        "2x0e + 1x1e", jnp.array([[1.0, 2, 3, 4, 5], [4.0, 5, 6, 6, 6]])
+    )
     assert e3nn.norm(x).shape == (2, 3)
     assert e3nn.norm(x, per_irrep=True).shape == (2, 3)
     assert e3nn.norm(x, per_irrep=False).shape == (2, 1)
 
-    x = e3nn.IrrepsArray.from_list("2x0e + 1x1e", [None, None], (2,), dtype=jnp.complex64)
+    x = e3nn.IrrepsArray.from_list(
+        "2x0e + 1x1e", [None, None], (2,), dtype=jnp.complex64
+    )
     assert e3nn.norm(x).shape == (2, 3)
 
 
 def test_dot():
-    x = e3nn.IrrepsArray("2x0e + 1x1e", jnp.array([[1.0, 2, 3, 4, 5], [4.0, 5, 6, 6, 6]]))
-    y = e3nn.IrrepsArray("2x0e + 1x1e", jnp.array([[1.0j, 2, 3, 4, 5], [4.0, 5, 6, 6, 6]]))
+    x = e3nn.IrrepsArray(
+        "2x0e + 1x1e", jnp.array([[1.0, 2, 3, 4, 5], [4.0, 5, 6, 6, 6]])
+    )
+    y = e3nn.IrrepsArray(
+        "2x0e + 1x1e", jnp.array([[1.0j, 2, 3, 4, 5], [4.0, 5, 6, 6, 6]])
+    )
     assert e3nn.dot(x, y).shape == (2, 1)
     assert e3nn.dot(x, y, per_irrep=True).shape == (2, 3)
     assert e3nn.dot(x, y, per_irrep=False).shape == (2, 1)
 
-    y = e3nn.IrrepsArray.from_list("2x0e + 1x1e", [None, None], (2,), dtype=jnp.complex64)
+    y = e3nn.IrrepsArray.from_list(
+        "2x0e + 1x1e", [None, None], (2,), dtype=jnp.complex64
+    )
     assert e3nn.dot(x, y).shape == (2, 1)
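slice_by_mul, slice_by_dim and slice_by_chunk (the helper classes above) differ only in the unit they count: multiplicities, entries of the flattened feature axis, or whole (mul, ir) chunks. The Irreps-level tests that follow pin this down; for instance:

    import e3nn_jax as e3nn

    assert e3nn.Irreps("10x0e + 10x1e").slice_by_mul[5:15] == e3nn.Irreps("5x0e + 5x1e")
    assert e3nn.Irreps("10x0e").slice_by_dim[1:4] == e3nn.Irreps("3x0e")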
diff --git a/e3nn_jax/_src/irreps_test.py b/e3nn_jax/_src/irreps_test.py
index 943fd715..7771dd07 100644
--- a/e3nn_jax/_src/irreps_test.py
+++ b/e3nn_jax/_src/irreps_test.py
@@ -58,7 +58,9 @@ def test_arithmetic():
     assert 2 * e3nn.Irreps("2x2e + 4x1o") == e3nn.Irreps("4x2e + 8x1o")
     assert e3nn.Irreps("2x2e + 4x1o") * 2 == e3nn.Irreps("4x2e + 8x1o")
 
-    assert e3nn.Irreps("1o + 4o") + e3nn.Irreps("1o + 7e") == e3nn.Irreps("1o + 4o + 1o + 7e")
+    assert e3nn.Irreps("1o + 4o") + e3nn.Irreps("1o + 7e") == e3nn.Irreps(
+        "1o + 4o + 1o + 7e"
+    )
 
 
 def test_empty_irreps():
@@ -115,13 +117,17 @@ def test_ordering():
 def test_slice_by_mul():
     assert e3nn.Irreps("10x0e").slice_by_mul[1:4] == e3nn.Irreps("3x0e")
     assert e3nn.Irreps("10x0e + 10x1e").slice_by_mul[5:15] == e3nn.Irreps("5x0e + 5x1e")
-    assert e3nn.Irreps("10x0e + 2e + 10x1e").slice_by_mul[5:15] == e3nn.Irreps("5x0e + 2e + 4x1e")
+    assert e3nn.Irreps("10x0e + 2e + 10x1e").slice_by_mul[5:15] == e3nn.Irreps(
+        "5x0e + 2e + 4x1e"
+    )
 
 
 def test_slice_by_dim():
     assert e3nn.Irreps("10x0e").slice_by_dim[1:4] == e3nn.Irreps("3x0e")
     assert e3nn.Irreps("10x0e + 10x1e").slice_by_dim[5:13] == e3nn.Irreps("5x0e + 1e")
-    assert e3nn.Irreps("10x0e + 2e + 10x1e").slice_by_dim[5:18] == e3nn.Irreps("5x0e + 2e + 1e")
+    assert e3nn.Irreps("10x0e + 2e + 10x1e").slice_by_dim[5:18] == e3nn.Irreps(
+        "5x0e + 2e + 1e"
+    )
 
 
 def test_slice_by_chunk():
diff --git a/e3nn_jax/_src/linear.py b/e3nn_jax/_src/linear.py
index 91af64a0..d5b904ff 100644
--- a/e3nn_jax/_src/linear.py
+++ b/e3nn_jax/_src/linear.py
@@ -41,7 +41,9 @@ def __init__(
         if gradient_normalization is None:
             gradient_normalization = config("gradient_normalization")
         if isinstance(gradient_normalization, str):
-            gradient_normalization = {"element": 0.0, "path": 1.0}[gradient_normalization]
+            gradient_normalization = {"element": 0.0, "path": 1.0}[
+                gradient_normalization
+            ]
 
         irreps_in = Irreps(irreps_in)
         irreps_out = Irreps(irreps_out)
@@ -68,7 +70,9 @@ def __init__(
 
         def alpha(this):
             x = irreps_in[this.i_in].mul ** path_normalization * sum(
-                irreps_in[other.i_in].mul ** (1.0 - path_normalization) for other in instructions if other.i_out == this.i_out
+                irreps_in[other.i_in].mul ** (1.0 - path_normalization)
+                for other in instructions
+                if other.i_out == this.i_out
             )
             return 1 / x if x > 0 else 1.0
 
@@ -92,7 +96,13 @@ def alpha(this):
         assert all(ir.is_scalar() or (not b) for b, (_, ir) in zip(biases, irreps_out))
 
         instructions += [
-            Instruction(i_in=-1, i_out=i_out, path_shape=(mul_ir.dim,), path_weight=1.0, weight_std=0.0)
+            Instruction(
+                i_in=-1,
+                i_out=i_out,
+                path_shape=(mul_ir.dim,),
+                path_weight=1.0,
+                weight_std=0.0,
+            )
             for i_out, (bias, mul_ir) in enumerate(zip(biases, irreps_out))
             if bias
         ]
@@ -102,7 +112,10 @@ def alpha(this):
         output_mask = jnp.concatenate(
             [
                 jnp.ones(mul_ir.dim, bool)
-                if any((ins.i_out == i_out) and (0 not in ins.path_shape) for ins in instructions)
+                if any(
+                    (ins.i_out == i_out) and (0 not in ins.path_shape)
+                    for ins in instructions
+                )
                 else jnp.zeros(mul_ir.dim, bool)
                 for i_out, mul_ir in enumerate(irreps_out)
             ]
@@ -122,7 +135,11 @@ def num_weights(self) -> int:
     def aggregate_paths(self, paths, output_shape, output_dtype) -> IrrepsArray:
         output = [
             _sum_tensors(
-                [out for ins, out in zip(self.instructions, paths) if ins.i_out == i_out],
+                [
+                    out
+                    for ins, out in zip(self.instructions, paths)
+                    if ins.i_out == i_out
+                ],
                 shape=output_shape
                 + (
                     mul_ir_out.mul,
@@ -132,20 +149,28 @@ def aggregate_paths(self, paths, output_shape, output_dtype) -> IrrepsArray:
             )
             for i_out, mul_ir_out in enumerate(self.irreps_out)
         ]
-        return IrrepsArray.from_list(self.irreps_out, output, output_shape, output_dtype)
+        return IrrepsArray.from_list(
+            self.irreps_out, output, output_shape, output_dtype
+        )
 
     def split_weights(self, weights: jnp.ndarray) -> List[jnp.ndarray]:
         ws = []
         cursor = 0
         for i in self.instructions:
-            ws += [weights[cursor : cursor + np.prod(i.path_shape)].reshape(i.path_shape)]
+            ws += [
+                weights[cursor : cursor + np.prod(i.path_shape)].reshape(i.path_shape)
+            ]
             cursor += np.prod(i.path_shape)
         return ws
 
-    def __call__(self, ws: Union[List[jnp.ndarray], jnp.ndarray], input: IrrepsArray) -> IrrepsArray:
+    def __call__(
+        self, ws: Union[List[jnp.ndarray], jnp.ndarray], input: IrrepsArray
+    ) -> IrrepsArray:
         input = input._convert(self.irreps_in)
         if input.ndim != 1:
-            raise ValueError(f"FunctionalLinear does not support broadcasting, input shape is {input.shape}")
+            raise ValueError(
+                f"FunctionalLinear does not support broadcasting, input shape is {input.shape}"
+            )
 
         if not isinstance(ws, list):
             ws = self.split_weights(ws)
@@ -153,7 +178,11 @@ def __call__(self, ws: Union[List[jnp.ndarray], jnp.ndarray], input: IrrepsArray
         paths = [
             ins.path_weight * w
             if ins.i_in == -1
-            else (None if input.list[ins.i_in] is None else ins.path_weight * jnp.einsum("uw,ui->wi", w, input.list[ins.i_in]))
+            else (
+                None
+                if input.list[ins.i_in] is None
+                else ins.path_weight * jnp.einsum("uw,ui->wi", w, input.list[ins.i_in])
+            )
             for ins, w in zip(self.instructions, ws)
         ]
         return self.aggregate_paths(paths, input.shape[:-1], input.dtype)
@@ -173,7 +202,9 @@ def matrix(self, ws: List[jnp.ndarray]) -> jnp.ndarray:
             assert ins.i_in != -1
             mul_in, ir_in = self.irreps_in[ins.i_in]
             mul_out, ir_out = self.irreps_out[ins.i_out]
-            output = output.at[self.irreps_in.slices()[ins.i_in], self.irreps_out.slices()[ins.i_out]].add(
+            output = output.at[
+                self.irreps_in.slices()[ins.i_in], self.irreps_out.slices()[ins.i_out]
+            ].add(
                 ins.path_weight
                 * jnp.einsum("uw,ij->uiwj", w, jnp.eye(ir_in.dim, dtype=dtype)).reshape(
                     (mul_in * ir_in.dim, mul_out * ir_out.dim)
@@ -183,7 +214,9 @@ def matrix(self, ws: List[jnp.ndarray]) -> jnp.ndarray:
 
 
 def linear_vanilla(
-    input: IrrepsArray, linear: FunctionalLinear, get_parameter: Callable[[str, Tuple[int, ...], float, Any], jnp.ndarray]
+    input: IrrepsArray,
+    linear: FunctionalLinear,
+    get_parameter: Callable[[str, Tuple[int, ...], float, Any], jnp.ndarray],
 ) -> IrrepsArray:
     w = [
         get_parameter(
@@ -261,7 +294,8 @@ def linear_mixed(
     ]  # List of shape (d, *path_shape)
     weights = weights.astype(input.array.dtype)
     w = [
-        jnp.sqrt(alpha) ** gradient_normalization * jax.lax.dot_general(weights, wi, (((weights.ndim - 1,), (0,)), ((), ())))
+        jnp.sqrt(alpha) ** gradient_normalization
+        * jax.lax.dot_general(weights, wi, (((weights.ndim - 1,), (0,)), ((), ())))
         for wi in w
     ]  # List of shape (..., *path_shape)
 
@@ -301,7 +335,8 @@ def linear_mixed_per_channel(
     ]  # List of shape (d, num_channels, *path_shape)
     weights = weights.astype(input.array.dtype)
     w = [
-        jnp.sqrt(alpha) ** gradient_normalization * jax.lax.dot_general(weights, wi, (((weights.ndim - 1,), (0,)), ((), ())))
+        jnp.sqrt(alpha) ** gradient_normalization
+        * jax.lax.dot_general(weights, wi, (((weights.ndim - 1,), (0,)), ((), ())))
         for wi in w
     ]  # List of shape (..., num_channels, *path_shape)
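FunctionalLinear is the weight-less core: the caller supplies one weight array per Instruction, with the path_shape constructed above. A rough sketch, with the toy irreps assumed for illustration:

    import jax
    import e3nn_jax as e3nn

    lin = e3nn.FunctionalLinear("2x0e + 1e", "0e + 1e")
    x = e3nn.normal(lin.irreps_in, jax.random.PRNGKey(0), ())
    # One weight array per instruction; shapes come from ins.path_shape.
    ws = [
        jax.random.normal(jax.random.PRNGKey(i), ins.path_shape)
        for i, ins in enumerate(lin.instructions)
    ]
    y = lin(ws, x)  # IrrepsArray with irreps lin.irreps_out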
diff --git a/e3nn_jax/_src/linear_flax.py b/e3nn_jax/_src/linear_flax.py
index e471a5d9..931d2c7c 100644
--- a/e3nn_jax/_src/linear_flax.py
+++ b/e3nn_jax/_src/linear_flax.py
@@ -6,7 +6,13 @@ import e3nn_jax as e3nn
 from e3nn_jax._src.util.dtype import get_pytree_dtype
 
-from .linear import FunctionalLinear, linear_indexed, linear_mixed, linear_mixed_per_channel, linear_vanilla
+from .linear import (
+    FunctionalLinear,
+    linear_indexed,
+    linear_mixed,
+    linear_mixed_per_channel,
+    linear_vanilla,
+)
 
 
 class Linear(flax.linen.Module):
@@ -115,7 +121,9 @@ def __call__(self, weights_or_input, input_or_none=None) -> e3nn.IrrepsArray:
         )
 
         def param(name, shape, std, dtype):
-            return self.param(name, flax.linen.initializers.normal(stddev=std), shape, dtype)
+            return self.param(
+                name, flax.linen.initializers.normal(stddev=std), shape, dtype
+            )
 
         if weights is None:
             assert not self.weights_per_channel  # Not implemented yet
@@ -128,19 +136,27 @@ def param(name, shape, std, dtype):
 
         if weights.dtype.kind == "i" and self.num_indexed_weights is not None:
             assert not self.weights_per_channel  # Not implemented yet
-            output = linear_indexed(input, lin, param, weights, self.num_indexed_weights)
+            output = linear_indexed(
+                input, lin, param, weights, self.num_indexed_weights
+            )
 
         elif weights.dtype.kind in "fc" and self.num_indexed_weights is None:
             gradient_normalization = self.gradient_normalization
             if gradient_normalization is None:
                 gradient_normalization = e3nn.config("gradient_normalization")
             if isinstance(gradient_normalization, str):
-                gradient_normalization = {"element": 0.0, "path": 1.0}[gradient_normalization]
+                gradient_normalization = {"element": 0.0, "path": 1.0}[
+                    gradient_normalization
+                ]
 
             if self.weights_per_channel:
-                output = linear_mixed_per_channel(input, lin, param, weights, gradient_normalization)
+                output = linear_mixed_per_channel(
+                    input, lin, param, weights, gradient_normalization
+                )
             else:
-                output = linear_mixed(input, lin, param, weights, gradient_normalization)
+                output = linear_mixed(
+                    input, lin, param, weights, gradient_normalization
+                )
 
         else:
             raise ValueError(
diff --git a/e3nn_jax/_src/linear_flax_test.py b/e3nn_jax/_src/linear_flax_test.py
index 939a4be0..553b8ed9 100644
--- a/e3nn_jax/_src/linear_flax_test.py
+++ b/e3nn_jax/_src/linear_flax_test.py
@@ -6,8 +6,12 @@ from e3nn_jax.util import assert_output_dtype_matches_input_dtype
 
 
-@pytest.mark.parametrize("irreps_in", ["5x0e", "1e + 2e + 4x1e + 3x3o", "2x1o + 0x3e", "0x0e"])
-@pytest.mark.parametrize("irreps_out", ["5x0e", "1e + 2e + 3x3o + 3x1e", "2x1o + 0x3e", "0x0e"])
+@pytest.mark.parametrize(
+    "irreps_in", ["5x0e", "1e + 2e + 4x1e + 3x3o", "2x1o + 0x3e", "0x0e"]
+)
+@pytest.mark.parametrize(
+    "irreps_out", ["5x0e", "1e + 2e + 3x3o + 3x1e", "2x1o + 0x3e", "0x0e"]
+)
 def test_linear_vanilla(keys, irreps_in, irreps_out):
     linear = e3nn.flax.Linear(irreps_in=irreps_in, irreps_out=irreps_out)
     x = e3nn.normal(irreps_in, next(keys), (128,))
@@ -16,10 +20,16 @@ def test_linear_vanilla(keys, irreps_in, irreps_out):
     assert y.shape == (128, e3nn.Irreps(irreps_out).dim)
 
 
-@pytest.mark.parametrize("irreps_in", ["5x0e", "1e + 2e + 4x1e + 3x3o", "2x1o + 0x3e", "0x0e"])
-@pytest.mark.parametrize("irreps_out", ["5x0e", "1e + 2e + 3x3o + 3x1e", "2x1o + 0x3e", "0x0e"])
+@pytest.mark.parametrize(
+    "irreps_in", ["5x0e", "1e + 2e + 4x1e + 3x3o", "2x1o + 0x3e", "0x0e"]
+)
+@pytest.mark.parametrize(
+    "irreps_out", ["5x0e", "1e + 2e + 3x3o + 3x1e", "2x1o + 0x3e", "0x0e"]
+)
 def test_linear_indexed(keys, irreps_in, irreps_out):
-    linear = e3nn.flax.Linear(irreps_in=irreps_in, irreps_out=irreps_out, num_indexed_weights=10)
+    linear = e3nn.flax.Linear(
+        irreps_in=irreps_in, irreps_out=irreps_out, num_indexed_weights=10
+    )
     x = e3nn.normal(irreps_in, next(keys), (128,))
     i = jnp.arange(128) % 10
     w = linear.init(next(keys), i, x)
@@ -27,8 +37,12 @@ def test_linear_indexed(keys, irreps_in, irreps_out):
     assert y.shape == (128, e3nn.Irreps(irreps_out).dim)
 
 
-@pytest.mark.parametrize("irreps_in", ["5x0e", "1e + 2e + 4x1e + 3x3o", "2x1o + 0x3e", "0x0e"])
-@pytest.mark.parametrize("irreps_out", ["5x0e", "1e + 2e + 3x3o + 3x1e", "2x1o + 0x3e", "0x0e"])
+@pytest.mark.parametrize(
+    "irreps_in", ["5x0e", "1e + 2e + 4x1e + 3x3o", "2x1o + 0x3e", "0x0e"]
+)
+@pytest.mark.parametrize(
+    "irreps_out", ["5x0e", "1e + 2e + 3x3o + 3x1e", "2x1o + 0x3e", "0x0e"]
+)
 def test_linear_mixed(keys, irreps_in, irreps_out):
     linear = e3nn.flax.Linear(irreps_in=irreps_in, irreps_out=irreps_out)
     x = e3nn.normal(irreps_in, next(keys), (128,))
@@ -38,10 +52,16 @@ def test_linear_mixed(keys, irreps_in, irreps_out):
     assert y.shape == (128, e3nn.Irreps(irreps_out).dim)
 
 
-@pytest.mark.parametrize("irreps_in", ["5x0e", "1e + 2e + 4x1e + 3x3o", "2x1o + 0x3e", "0x0e"])
-@pytest.mark.parametrize("irreps_out", ["5x0e", "1e + 2e + 3x3o + 3x1e", "2x1o + 0x3e", "0x0e"])
+@pytest.mark.parametrize(
+    "irreps_in", ["5x0e", "1e + 2e + 4x1e + 3x3o", "2x1o + 0x3e", "0x0e"]
+)
+@pytest.mark.parametrize(
+    "irreps_out", ["5x0e", "1e + 2e + 3x3o + 3x1e", "2x1o + 0x3e", "0x0e"]
+)
 def test_linear_mixed_per_channel(keys, irreps_in, irreps_out):
-    linear = e3nn.flax.Linear(irreps_in=irreps_in, irreps_out=irreps_out, weights_per_channel=True)
+    linear = e3nn.flax.Linear(
+        irreps_in=irreps_in, irreps_out=irreps_out, weights_per_channel=True
+    )
     x = e3nn.normal(irreps_in, next(keys), (128,))
     e = jax.random.normal(next(keys), (10,))
     w = jax.jit(linear.init)(next(keys), e, x)
@@ -49,8 +69,12 @@ def test_linear_mixed_per_channel(keys, irreps_in, irreps_out):
     assert y.shape == (128, e3nn.Irreps(irreps_out).dim)
 
 
-@pytest.mark.parametrize("irreps_in", ["5x0e", "1e + 2e + 4x1e + 3x3o", "2x1o + 0x3e", "0x0e"])
-@pytest.mark.parametrize("irreps_out", ["5x0e", "1e + 2e + 3x3o + 3x1e", "2x1o + 0x3e", "0x0e"])
+@pytest.mark.parametrize(
+    "irreps_in", ["5x0e", "1e + 2e + 4x1e + 3x3o", "2x1o + 0x3e", "0x0e"]
+)
+@pytest.mark.parametrize(
+    "irreps_out", ["5x0e", "1e + 2e + 3x3o + 3x1e", "2x1o + 0x3e", "0x0e"]
+)
 def test_linear_dtype(keys, irreps_in, irreps_out):
     jax.config.update("jax_enable_x64", True)
diff --git a/e3nn_jax/_src/linear_haiku.py b/e3nn_jax/_src/linear_haiku.py
index 62b16d51..8cb5a88a 100644
--- a/e3nn_jax/_src/linear_haiku.py
+++ b/e3nn_jax/_src/linear_haiku.py
@@ -6,7 +6,13 @@ import e3nn_jax as e3nn
 from e3nn_jax._src.util.dtype import get_pytree_dtype
 
-from .linear import FunctionalLinear, linear_indexed, linear_mixed, linear_mixed_per_channel, linear_vanilla
+from .linear import (
+    FunctionalLinear,
+    linear_indexed,
+    linear_mixed,
+    linear_mixed_per_channel,
+    linear_vanilla,
+)
 
 
 class Linear(hk.Module):
@@ -80,7 +86,9 @@ def __init__(
         biases: bool = False,
         path_normalization: Union[str, float] = None,
         gradient_normalization: Union[str, float] = None,
-        get_parameter: Optional[Callable[[str, Tuple[int, ...], float, Any], jnp.ndarray]] = None,
+        get_parameter: Optional[
+            Callable[[str, Tuple[int, ...], float, Any], jnp.ndarray]
+        ] = None,
         num_indexed_weights: Optional[int] = None,
         weights_per_channel: bool = False,
         name: Optional[str] = None,
@@ -98,12 +106,19 @@ def __init__(
         if gradient_normalization is None:
             gradient_normalization = e3nn.config("gradient_normalization")
         if isinstance(gradient_normalization, str):
-            gradient_normalization = {"element": 0.0, "path": 1.0}[gradient_normalization]
+            gradient_normalization = {"element": 0.0, "path": 1.0}[
+                gradient_normalization
+            ]
         self.gradient_normalization = gradient_normalization
 
         if get_parameter is None:
 
-            def get_parameter(name: str, path_shape: Tuple[int, ...], weight_std: float, dtype: jnp.dtype = jnp.float32):
+            def get_parameter(
+                name: str,
+                path_shape: Tuple[int, ...],
+                weight_std: float,
+                dtype: jnp.dtype = jnp.float32,
+            ):
                 return hk.get_parameter(
                     name,
                     shape=path_shape,
@@ -174,13 +189,27 @@ def __call__(self, weights_or_input, input_or_none=None) -> e3nn.IrrepsArray:
 
         if weights.dtype.kind == "i" and self.num_indexed_weights is not None:
             assert not self.weights_per_channel  # Not implemented yet
-            output = linear_indexed(input, lin, self.get_parameter, weights, self.num_indexed_weights)
+            output = linear_indexed(
+                input, lin, self.get_parameter, weights, self.num_indexed_weights
+            )
 
         elif weights.dtype.kind in "fc" and self.num_indexed_weights is None:
             if self.weights_per_channel:
-                output = linear_mixed_per_channel(input, lin, self.get_parameter, weights, self.gradient_normalization)
+                output = linear_mixed_per_channel(
+                    input,
+                    lin,
+                    self.get_parameter,
+                    weights,
+                    self.gradient_normalization,
+                )
             else:
-                output = linear_mixed(input, lin, self.get_parameter, weights, self.gradient_normalization)
+                output = linear_mixed(
+                    input,
+                    lin,
+                    self.get_parameter,
+                    weights,
+                    self.gradient_normalization,
+                )
 
         else:
             raise ValueError(
diff --git a/e3nn_jax/_src/linear_haiku_test.py b/e3nn_jax/_src/linear_haiku_test.py
index fd5aeb29..62a6cfab 100644
--- a/e3nn_jax/_src/linear_haiku_test.py
+++ b/e3nn_jax/_src/linear_haiku_test.py
@@ -42,8 +42,12 @@ def __call__(self, ws, x):
         return self.tp.left_right(ws, x, ones)
 
 
-@pytest.mark.parametrize("irreps_in", ["5x0e", "1e + 2e + 4x1e + 3x3o", "2x1o + 0x3e", "0x0e"])
-@pytest.mark.parametrize("irreps_out", ["5x0e", "1e + 2e + 3x3o + 3x1e", "2x1o + 0x3e", "0x0e"])
+@pytest.mark.parametrize(
+    "irreps_in", ["5x0e", "1e + 2e + 4x1e + 3x3o", "2x1o + 0x3e", "0x0e"]
+)
+@pytest.mark.parametrize(
+    "irreps_out", ["5x0e", "1e + 2e + 3x3o + 3x1e", "2x1o + 0x3e", "0x0e"]
+)
 def test_linear_like_tp(keys, irreps_in, irreps_out):
     """Test that Linear gives the same results as the corresponding TensorProduct."""
     m = e3nn.FunctionalLinear(irreps_in, irreps_out)
@@ -55,8 +59,12 @@ def test_linear_like_tp(keys, irreps_in, irreps_out):
     assert jnp.allclose(m(ws, x).array, m_tp(ws_tp, x).array)
 
 
-@pytest.mark.parametrize("irreps_in", ["5x0e", "1e + 2e + 4x1e + 3x3o", "2x1o + 0x3e", "0x0e"])
-@pytest.mark.parametrize("irreps_out", ["5x0e", "1e + 2e + 3x3o + 3x1e", "2x1o + 0x3e", "0x0e"])
+@pytest.mark.parametrize(
+    "irreps_in", ["5x0e", "1e + 2e + 4x1e + 3x3o", "2x1o + 0x3e", "0x0e"]
+)
+@pytest.mark.parametrize(
+    "irreps_out", ["5x0e", "1e + 2e + 3x3o + 3x1e", "2x1o + 0x3e", "0x0e"]
+)
 def test_linear_matrix(keys, irreps_in, irreps_out):
     m = e3nn.FunctionalLinear(irreps_in, irreps_out)
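The flax and haiku Linear wrappers tested above share the same calling conventions; a sketch of the vanilla flax path, mirroring test_linear_vanilla:

    import jax
    import e3nn_jax as e3nn

    linear = e3nn.flax.Linear(irreps_in="2x0e + 1e", irreps_out="0e + 1e")
    x = e3nn.normal("2x0e + 1e", jax.random.PRNGKey(0), (128,))
    w = linear.init(jax.random.PRNGKey(1), x)
    y = linear.apply(w, x)  # shape (128, e3nn.Irreps("0e + 1e").dim)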
diff --git a/e3nn_jax/_src/mlp_flax.py b/e3nn_jax/_src/mlp_flax.py
index 441c69e4..a4436295 100644
--- a/e3nn_jax/_src/mlp_flax.py
+++ b/e3nn_jax/_src/mlp_flax.py
@@ -25,7 +25,9 @@ class MultiLayerPerceptron(flax.linen.Module):
     output_activation: Union[Callable, bool] = True
 
     @flax.linen.compact
-    def __call__(self, x: Union[jnp.ndarray, e3nn.IrrepsArray]) -> Union[jnp.ndarray, e3nn.IrrepsArray]:
+    def __call__(
+        self, x: Union[jnp.ndarray, e3nn.IrrepsArray]
+    ) -> Union[jnp.ndarray, e3nn.IrrepsArray]:
         """Evaluate the MLP
 
         Input and output are either `jax.numpy.ndarray` or `IrrepsArray`.
@@ -50,7 +52,9 @@ def __call__(self, x: Union[jnp.ndarray, e3nn.IrrepsArray]) -> Union[jnp.ndarray
         if gradient_normalization is None:
             gradient_normalization = e3nn.config("gradient_normalization")
         if isinstance(gradient_normalization, str):
-            gradient_normalization = {"element": 0.0, "path": 1.0}[gradient_normalization]
+            gradient_normalization = {"element": 0.0, "path": 1.0}[
+                gradient_normalization
+            ]
 
         if isinstance(x, e3nn.IrrepsArray):
             if not x.irreps.is_scalar():
@@ -61,14 +65,20 @@ def __call__(self, x: Union[jnp.ndarray, e3nn.IrrepsArray]) -> Union[jnp.ndarray
             output_irrepsarray = False
 
         act = None if self.act is None else e3nn.normalize_function(self.act)
-        last_act = None if output_activation is None else e3nn.normalize_function(output_activation)
+        last_act = (
+            None
+            if output_activation is None
+            else e3nn.normalize_function(output_activation)
+        )
 
         for i, h in enumerate(self.list_neurons):
             alpha = 1 / x.shape[-1]
             d = flax.linen.Dense(
                 features=h,
                 use_bias=False,
-                kernel_init=flax.linen.initializers.normal(stddev=jnp.sqrt(alpha) ** (1.0 - gradient_normalization)),
+                kernel_init=flax.linen.initializers.normal(
+                    stddev=jnp.sqrt(alpha) ** (1.0 - gradient_normalization)
+                ),
                 param_dtype=x.dtype,
             )
             x = jnp.sqrt(alpha) ** gradient_normalization * d(x)
diff --git a/e3nn_jax/_src/mlp_haiku.py b/e3nn_jax/_src/mlp_haiku.py
index 2316d8ae..6323b2dd 100644
--- a/e3nn_jax/_src/mlp_haiku.py
+++ b/e3nn_jax/_src/mlp_haiku.py
@@ -45,10 +45,14 @@ def __init__(
         if gradient_normalization is None:
             gradient_normalization = e3nn.config("gradient_normalization")
         if isinstance(gradient_normalization, str):
-            gradient_normalization = {"element": 0.0, "path": 1.0}[gradient_normalization]
+            gradient_normalization = {"element": 0.0, "path": 1.0}[
+                gradient_normalization
+            ]
         self.gradient_normalization = gradient_normalization
 
-    def __call__(self, x: Union[jnp.ndarray, e3nn.IrrepsArray]) -> Union[jnp.ndarray, e3nn.IrrepsArray]:
+    def __call__(
+        self, x: Union[jnp.ndarray, e3nn.IrrepsArray]
+    ) -> Union[jnp.ndarray, e3nn.IrrepsArray]:
         """Evaluate the MLP
 
         Input and output are either `jax.numpy.ndarray` or `IrrepsArray`.
@@ -69,14 +73,20 @@ def __call__(self, x: Union[jnp.ndarray, e3nn.IrrepsArray]) -> Union[jnp.ndarray
             output_irrepsarray = False
 
         act = None if self.act is None else e3nn.normalize_function(self.act)
-        last_act = None if self.output_activation is None else e3nn.normalize_function(self.output_activation)
+        last_act = (
+            None
+            if self.output_activation is None
+            else e3nn.normalize_function(self.output_activation)
+        )
 
         for i, h in enumerate(self.list_neurons):
             alpha = 1 / x.shape[-1]
             d = hk.Linear(
                 h,
                 with_bias=False,
-                w_init=hk.initializers.RandomNormal(stddev=jnp.sqrt(alpha) ** (1.0 - self.gradient_normalization)),
+                w_init=hk.initializers.RandomNormal(
+                    stddev=jnp.sqrt(alpha) ** (1.0 - self.gradient_normalization)
+                ),
                 name=f"linear_{i}",
             )
             x = jnp.sqrt(alpha) ** self.gradient_normalization * d(x)
diff --git a/e3nn_jax/_src/perm_test.py b/e3nn_jax/_src/perm_test.py
index 9a628de8..a57eafea 100644
--- a/e3nn_jax/_src/perm_test.py
+++ b/e3nn_jax/_src/perm_test.py
@@ -39,9 +39,13 @@ def test_rand(n):
 
 def test_not_group():
     assert not perm.is_group(set())  # empty
-    assert not perm.is_group({(1, 0, 2), (0, 2, 1), (1, 2, 0), (2, 0, 1), (2, 1, 0)})  # missing neutral
+    assert not perm.is_group(
+        {(1, 0, 2), (0, 2, 1), (1, 2, 0), (2, 0, 1), (2, 1, 0)}
+    )  # missing neutral
     assert not perm.is_group({(0, 1, 2), (1, 2, 0)})  # missing inverse
-    assert not perm.is_group({(0, 1, 2, 3), (3, 0, 1, 2), (1, 2, 3, 0)})  # g1 . g2 not in G
+    assert not perm.is_group(
+        {(0, 1, 2, 3), (3, 0, 1, 2), (1, 2, 3, 0)}
+    )  # g1 . g2 not in G
 
 
 def test_to_cycles():
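Both MLP variants above normalize activations with normalize_function and split the sqrt(alpha) scaling between weights and forward pass according to gradient_normalization. A usage sketch for the flax version (constructor keywords assumed from the module fields shown above):

    import jax
    import jax.numpy as jnp
    import e3nn_jax as e3nn

    mlp = e3nn.flax.MultiLayerPerceptron(list_neurons=(32, 32, 8), act=jax.nn.gelu)
    x = jnp.ones((16, 4))
    w = mlp.init(jax.random.PRNGKey(0), x)
    y = mlp.apply(w, x)  # shape (16, 8)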
diff --git a/e3nn_jax/_src/radial.py b/e3nn_jax/_src/radial.py
index d07aca6e..48a43288 100644
--- a/e3nn_jax/_src/radial.py
+++ b/e3nn_jax/_src/radial.py
@@ -192,12 +198,18 @@ def soft_one_hot_linspace(
         x = (input[..., None] - start) / (end - start)
         if start_zero and end_zero:
             i = jnp.arange(1, number + 1)
-            return jnp.where((0.0 < x) & (x < 1.0), jnp.sin(jnp.pi * i * x) / jnp.sqrt(0.25 + number / 2), 0.0)
+            return jnp.where(
+                (0.0 < x) & (x < 1.0),
+                jnp.sin(jnp.pi * i * x) / jnp.sqrt(0.25 + number / 2),
+                0.0,
+            )
         elif not start_zero and not end_zero:
             i = jnp.arange(0, number)
             return jnp.cos(jnp.pi * i * x) / jnp.sqrt(0.25 + number / 2)
         else:
-            raise ValueError("when using fourier basis, start_zero and end_zero must be the same")
+            raise ValueError(
+                "when using fourier basis, start_zero and end_zero must be the same"
+            )
 
     raise ValueError(f'basis="{basis}" is not a valid entry')
 
@@ -237,18 +243,33 @@ def bessel(x: jnp.ndarray, n: int, x_max: float = 1.0) -> jnp.ndarray:
 
 def u(p: int, x: jnp.ndarray) -> jnp.ndarray:
     r"""Equivalent to :func:`poly_envelope` with ``n0 = p-1`` and ``n1 = 2``."""
-    return 1 - (p + 1) * (p + 2) / 2 * x**p + p * (p + 2) * x ** (p + 1) - p * (p + 1) / 2 * x ** (p + 2)
+    return (
+        1
+        - (p + 1) * (p + 2) / 2 * x**p
+        + p * (p + 2) * x ** (p + 1)
+        - p * (p + 1) / 2 * x ** (p + 2)
+    )
 
 
 def _constraint(x: float, derivative: int, degree: int):
-    return [0 if derivative > N else factorial(N) // factorial(N - derivative) * x ** (N - derivative) for N in range(degree)]
+    return [
+        0
+        if derivative > N
+        else factorial(N) // factorial(N - derivative) * x ** (N - derivative)
+        for N in range(degree)
+    ]
 
 
 @lru_cache(maxsize=None)
 def solve_polynomial(constraints) -> jnp.ndarray:
     with jax.ensure_compile_time_eval():
         degree = len(constraints)
-        A = np.array([_constraint(x, derivative, degree) for x, derivative, _ in sorted(constraints)])
+        A = np.array(
+            [
+                _constraint(x, derivative, degree)
+                for x, derivative, _ in sorted(constraints)
+            ]
+        )
         B = np.array([y for _, _, y in sorted(constraints)])
 
         c = np.linalg.solve(A, B)[::-1]
diff --git a/e3nn_jax/_src/radial_test.py b/e3nn_jax/_src/radial_test.py
index a4f5a0e9..679cf7ee 100644
--- a/e3nn_jax/_src/radial_test.py
+++ b/e3nn_jax/_src/radial_test.py
@@ -25,7 +25,15 @@ def test_soft_one_hot_linspace(basis: str, start_zero: bool, end_zero: bool):
         pytest.skip()
 
     x = jnp.linspace(0.2, 0.8, 100)
-    y = e3nn.soft_one_hot_linspace(x, start=0.0, end=1.0, number=5, basis=basis, start_zero=start_zero, end_zero=end_zero)
+    y = e3nn.soft_one_hot_linspace(
+        x,
+        start=0.0,
+        end=1.0,
+        number=5,
+        basis=basis,
+        start_zero=start_zero,
+        end_zero=end_zero,
+    )
     assert y.shape == (100, 5)
 
     np.testing.assert_allclose(jnp.sum(y**2, axis=1), 1.0, atol=0.4)
@@ -33,7 +41,13 @@ def test_soft_one_hot_linspace(basis: str, start_zero: bool, end_zero: bool):
     jax.config.update("jax_enable_x64", True)
     assert_output_dtype_matches_input_dtype(
         lambda x: e3nn.soft_one_hot_linspace(
-            x, start=0.0, end=1.0, number=5, basis=basis, start_zero=start_zero, end_zero=end_zero
+            x,
+            start=0.0,
+            end=1.0,
+            number=5,
+            basis=basis,
+            start_zero=start_zero,
+            end_zero=end_zero,
         ),
         x,
    )
diff --git a/e3nn_jax/_src/radius_graph.py b/e3nn_jax/_src/radius_graph.py
index 42b485a9..59789f58 100644
--- a/e3nn_jax/_src/radius_graph.py
+++ b/e3nn_jax/_src/radius_graph.py
@@ -42,7 +42,9 @@ def radius_graph(
     if isinstance(pos, e3nn.IrrepsArray):
         pos = pos.array
 
-    r = jax.vmap(jax.vmap(lambda x, y: jnp.linalg.norm(x - y), (None, 0), 0), (0, None), 0)(pos, pos)
+    r = jax.vmap(
+        jax.vmap(lambda x, y: jnp.linalg.norm(x - y), (None, 0), 0), (0, None), 0
+    )(pos, pos)
     if loop:
         mask = r < r_max
     else:
diff --git a/e3nn_jax/_src/radius_graph_test.py b/e3nn_jax/_src/radius_graph_test.py
index 79899d9e..4e12bded 100644
--- a/e3nn_jax/_src/radius_graph_test.py
+++ b/e3nn_jax/_src/radius_graph_test.py
@@ -3,7 +3,9 @@
 
 def test_radius_graph():
-    pos = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
+    pos = jnp.array(
+        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
+    )
     r_max = 1.01
     batch = jnp.array([0, 0, 0, 0])
 
@@ -15,6 +17,8 @@ def test_radius_graph():
     assert src.shape == (6 + 4,)
     assert dst.shape == (6 + 4,)
 
-    src, dst = e3nn.radius_graph(pos, r_max, batch=batch, size=12, fill_src=-1, fill_dst=-1)
+    src, dst = e3nn.radius_graph(
+        pos, r_max, batch=batch, size=12, fill_src=-1, fill_dst=-1
+    )
     assert src.shape == (12,)
     assert dst.shape == (12,)
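radius_graph computes all pairwise distances with the nested vmap above, and the size argument pads the edge list to a static length so the call stays jit-friendly. A sketch taken straight from the test:

    import jax.numpy as jnp
    import e3nn_jax as e3nn

    pos = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
    batch = jnp.array([0, 0, 0, 0])
    src, dst = e3nn.radius_graph(pos, 1.01, batch=batch, size=12, fill_src=-1, fill_dst=-1)
    assert src.shape == (12,)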
in p) for i, j in zip(f0, f): - if i in irreps_dict and j in irreps_dict and irreps_dict[i] != irreps_dict[j]: + if ( + i in irreps_dict + and j in irreps_dict + and irreps_dict[i] != irreps_dict[j] + ): raise RuntimeError(f"irreps of {i} and {j} should be the same") if i in irreps_dict: irreps_dict[j] = irreps_dict[i] @@ -95,11 +103,15 @@ def reduced_tensor_product_basis( for i in irreps_dict: if i not in f0: - raise RuntimeError(f"index {i} has an irreps but does not appear in the fomula") + raise RuntimeError( + f"index {i} has an irreps but does not appear in the fomula" + ) irreps_tuple = tuple(irreps_dict[i] for i in f0) - return _reduced_tensor_product_basis(irreps_tuple, perm_repr, keep_ir, epsilon, _use_optimized_implementation) + return _reduced_tensor_product_basis( + irreps_tuple, perm_repr, keep_ir, epsilon, _use_optimized_implementation + ) def _symmetric_perm_repr(n: int): @@ -136,7 +148,13 @@ def reduced_symmetric_tensor_product_basis( irreps = e3nn.Irreps(irreps) perm_repr: FrozenSet[Tuple[int, Tuple[int, ...]]] = _symmetric_perm_repr(degree) - return _reduced_tensor_product_basis(tuple([irreps] * degree), perm_repr, keep_ir, epsilon, _use_optimized_implementation) + return _reduced_tensor_product_basis( + tuple([irreps] * degree), + perm_repr, + keep_ir, + epsilon, + _use_optimized_implementation, + ) def _antisymmetric_perm_repr(n: int): @@ -173,7 +191,13 @@ def reduced_antisymmetric_tensor_product_basis( irreps = e3nn.Irreps(irreps) perm_repr: FrozenSet[Tuple[int, Tuple[int, ...]]] = _antisymmetric_perm_repr(degree) - return _reduced_tensor_product_basis(tuple([irreps] * degree), perm_repr, keep_ir, epsilon, _use_optimized_implementation) + return _reduced_tensor_product_basis( + tuple([irreps] * degree), + perm_repr, + keep_ir, + epsilon, + _use_optimized_implementation, + ) @functools.lru_cache(maxsize=None) @@ -205,7 +229,13 @@ def _reduced_tensor_product_basis( frozenset({i}), e3nn.IrrepsArray( irreps, - np.reshape(np.eye(irreps.dim), (1,) * i + (irreps.dim,) + (1,) * (len(irreps_tuple) - i - 1) + (irreps.dim,)), + np.reshape( + np.eye(irreps.dim), + (1,) * i + + (irreps.dim,) + + (1,) * (len(irreps_tuple) - i - 1) + + (irreps.dim,), + ), ), ) for i, irreps in enumerate(irreps_tuple) @@ -257,8 +287,12 @@ def _reduced_tensor_product_basis( for i, irreps in enumerate(irreps_tuple): if i not in f: keep = _tp_ir_seq(keep, irreps) - ab = _reduced_tensor_product_basis(sub_irreps, sub_perm_repr, keep, epsilon, _use_optimized_implementation) - ab = ab.reshape(tuple(dims[i] if i in f else 1 for i in range(len(dims))) + (-1,)) + ab = _reduced_tensor_product_basis( + sub_irreps, sub_perm_repr, keep, epsilon, _use_optimized_implementation + ) + ab = ab.reshape( + tuple(dims[i] if i in f else 1 for i in range(len(dims))) + (-1,) + ) bases = [(f, ab)] + bases @@ -337,7 +371,10 @@ def compute_padding_for_irrep_index(irrep_index: int): dims_after = inp_irreps_dims_cumsum_after[irrep_index] return (dims_before, dims_after) - return [compute_padding_for_irrep_index(irrep_index) for irrep_index in irrep_indices] + [(0, 0)] + return [ + compute_padding_for_irrep_index(irrep_index) + for irrep_index in irrep_indices + ] + [(0, 0)] def repeat_indices(indices: Sequence[int], powers: Sequence[int]) -> List[int]: """Given [i1, i2, ...] and [p1, p2, ...], returns [i1, i1, ... (p1 times), i2, i2, ... 
(p2 times), ...]""" @@ -346,7 +383,9 @@ def repeat_indices(indices: Sequence[int], powers: Sequence[int]) -> List[int]: repeated_indices.extend([index] * power) return repeated_indices - def generate_permutations(seq: Sequence[float]) -> Iterator[Tuple[Sequence[float], Sequence[int]]]: + def generate_permutations( + seq: Sequence[float], + ) -> Iterator[Tuple[Sequence[float], Sequence[int]]]: """Generates permutations of a sequence along with the indices used to create the permutation.""" indices = range(len(seq)) for permuted_indices in itertools.permutations(indices): @@ -367,7 +406,9 @@ def cumsum_after(seq: Sequence[float]) -> np.ndarray: """ return cumsum_before(seq[::-1])[::-1] - def reshape_for_basis_product(terms: Sequence[e3nn.IrrepsArray], non_zero_powers: Sequence[float]): + def reshape_for_basis_product( + terms: Sequence[e3nn.IrrepsArray], non_zero_powers: Sequence[float] + ): """Adds extra axes to each term to be compatible for reduce_basis_product().""" term_powers_cumsum_before = cumsum_before(non_zero_powers) term_powers_cumsum_after = cumsum_after(non_zero_powers) @@ -381,7 +422,10 @@ def reshape_term_for_basis_product(index, term): ) return term.reshape(new_shape) - return [reshape_term_for_basis_product(index, term) for index, term in enumerate(terms)] + return [ + reshape_term_for_basis_product(index, term) + for index, term in enumerate(terms) + ] irreps = e3nn.Irreps(irreps) irreps = e3nn.Irreps([(1, ir) for mul, ir in irreps for _ in range(mul)]) @@ -392,7 +436,9 @@ def reshape_term_for_basis_product(index, term): irreps_powers[i] = [e3nn.IrrepsArray("0e", np.asarray([1.0]))] for n in range(1, degree + 1): keep = _tp_ir_seq(keep_ir, _tp_ir_seq_pow(irreps, degree - n)) - power = reduced_symmetric_tensor_product_basis(mul_ir, n, epsilon=epsilon, keep_ir=keep) + power = reduced_symmetric_tensor_product_basis( + mul_ir, n, epsilon=epsilon, keep_ir=keep + ) irreps_powers[i].append(power) # Take all products of irreps whose powers sum up to degree. @@ -416,11 +462,17 @@ def reshape_term_for_basis_product(index, term): non_zero_indices = [i for i, n in enumerate(term_powers) if n != 0] non_zero_powers = [n for n in term_powers if n != 0] - non_zero_indices_repeated = tuple(repeat_indices(non_zero_indices, non_zero_powers)) + non_zero_indices_repeated = tuple( + repeat_indices(non_zero_indices, non_zero_powers) + ) # Add axes to all terms, so that they have the same number of input axes. - non_zero_terms = [irreps_powers[i][n] for i, n in zip(non_zero_indices, non_zero_powers)] - non_zero_terms_reshaped = reshape_for_basis_product(non_zero_terms, non_zero_powers) + non_zero_terms = [ + irreps_powers[i][n] for i, n in zip(non_zero_indices, non_zero_powers) + ] + non_zero_terms_reshaped = reshape_for_basis_product( + non_zero_terms, non_zero_powers + ) # Compute basis product, two terms at a time. if len(non_zero_terms_reshaped) == 1: @@ -430,7 +482,9 @@ def reshape_term_for_basis_product(index, term): for next_term in non_zero_terms_reshaped[1:-1]: current_term = reduce_basis_product(current_term, next_term) last_term = non_zero_terms_reshaped[-1] - product_basis = reduce_basis_product(current_term, last_term, filter_ir_out=keep_ir) + product_basis = reduce_basis_product( + current_term, last_term, filter_ir_out=keep_ir + ) if product_basis.irreps.dim == 0: continue @@ -440,7 +494,9 @@ def reshape_term_for_basis_product(index, term): seen_permutations = set() # Now, average over the different permutations. 
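Note for review: the hunk below only reflows the permutation-averaging loop, and the logic is easy to lose in the reformatting noise. A minimal standalone sketch (not part of the patch, standard library only, helper name mirrors the one defined above) of the behaviour the deduplication relies on:

import itertools

def generate_permutations(seq):
    # Yields each permutation of `seq` together with the index
    # permutation that produced it, mirroring the helper above.
    for permuted_indices in itertools.permutations(range(len(seq))):
        yield tuple(seq[i] for i in permuted_indices), permuted_indices

# For the repeated indices (0, 0, 1), i.e. powers [2, 1], the six index
# permutations collapse to three distinct value permutations; the
# `seen_permutations` set below performs exactly this deduplication
# before the sum is normalized by sqrt(len(seen_permutations)).
print(sorted({p for p, _ in generate_permutations((0, 0, 1))}))
# -> [(0, 0, 1), (0, 1, 0), (1, 0, 0)]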
- for permuted_indices_repeated, permuted_axes in generate_permutations(non_zero_indices_repeated): + for permuted_indices_repeated, permuted_axes in generate_permutations( + non_zero_indices_repeated + ): # Keep track of which permutations we have seen. # Don't repeat permutations! if permuted_indices_repeated in seen_permutations: @@ -448,17 +504,26 @@ def reshape_term_for_basis_product(index, term): seen_permutations.add(permuted_indices_repeated) # Permute axes according to this term. - permuted_product_basis_array = np.transpose(product_basis.array, permuted_axes + (len(permuted_axes),)) + permuted_product_basis_array = np.transpose( + product_basis.array, permuted_axes + (len(permuted_axes),) + ) # Add padding. padding = compute_padding_for_term(permuted_indices_repeated) - slices = tuple(slice(start, total - stop) for (start, stop), total in zip(padding, shape)) + slices = tuple( + slice(start, total - stop) + for (start, stop), total in zip(padding, shape) + ) sum_of_permuted_bases[slices] += permuted_product_basis_array # Normalize the sum of bases. - symmetrized_sum_of_permuted_bases = sum_of_permuted_bases / np.sqrt(len(seen_permutations)) - product_basis = e3nn.IrrepsArray(product_basis.irreps, symmetrized_sum_of_permuted_bases) + symmetrized_sum_of_permuted_bases = sum_of_permuted_bases / np.sqrt( + len(seen_permutations) + ) + product_basis = e3nn.IrrepsArray( + product_basis.irreps, symmetrized_sum_of_permuted_bases + ) symmetric_product.append(product_basis) # Filter out irreps, if needed. @@ -469,9 +534,13 @@ def reshape_term_for_basis_product(index, term): @functools.lru_cache(maxsize=None) -def germinate_perm_repr(formula: str) -> Tuple[str, FrozenSet[Tuple[int, Tuple[int, ...]]]]: +def germinate_perm_repr( + formula: str, +) -> Tuple[str, FrozenSet[Tuple[int, Tuple[int, ...]]]]: """Convert the formula (generators) into a group.""" - formulas = [(-1 if f.startswith("-") else 1, f.replace("-", "")) for f in formula.split("=")] + formulas = [ + (-1 if f.startswith("-") else 1, f.replace("-", "")) for f in formula.split("=") + ] s0, f0 = formulas[0] assert s0 == 1 @@ -483,7 +552,9 @@ def germinate_perm_repr(formula: str) -> Tuple[str, FrozenSet[Tuple[int, Tuple[i # `perm_repr` is a list of (sign, permutation of indices) # each formula can be viewed as a permutation of the original formula - perm_repr = {(s, tuple(f.index(i) for i in f0)) for s, f in formulas} # set of generators (permutations) + perm_repr = { + (s, tuple(f.index(i) for i in f0)) for s, f in formulas + } # set of generators (permutations) # they can be composed, for instance if you have ijk=jik=ikj # you also have ijk=jki @@ -491,7 +562,13 @@ def germinate_perm_repr(formula: str) -> Tuple[str, FrozenSet[Tuple[int, Tuple[i while True: n = len(perm_repr) perm_repr = perm_repr.union([(s, perm.inverse(p)) for s, p in perm_repr]) - perm_repr = perm_repr.union([(s1 * s2, perm.compose(p1, p2)) for s1, p1 in perm_repr for s2, p2 in perm_repr]) + perm_repr = perm_repr.union( + [ + (s1 * s2, perm.compose(p1, p2)) + for s1, p1 in perm_repr + for s2, p2 in perm_repr + ] + ) if len(perm_repr) == n: break # we break when the set is stable => it is now a group \o/ @@ -525,7 +602,11 @@ def reduce_basis_product( new_list.append(x) new = e3nn.IrrepsArray.from_list( - new_irreps, new_list, np.broadcast_shapes(basis1.shape[:-1], basis2.shape[:-1]), np.float64, backend=np + new_irreps, + new_list, + np.broadcast_shapes(basis1.shape[:-1], basis2.shape[:-1]), + np.float64, + backend=np, ) return new.regroup() @@ -548,7 +629,10 @@ 
def constrain_rotation_basis_by_permutation_basis( """ assert rotation_basis.shape[:-1] == permutation_basis.shape[1:] - perm = np.reshape(permutation_basis, (permutation_basis.shape[0], prod(permutation_basis.shape[1:]))) # (free, dim) + perm = np.reshape( + permutation_basis, + (permutation_basis.shape[0], prod(permutation_basis.shape[1:])), + ) # (free, dim) new_irreps: List[Tuple[int, e3nn.Irrep]] = [] new_list: List[np.ndarray] = [] @@ -569,7 +653,9 @@ def constrain_rotation_basis_by_permutation_basis( new_irreps.append((len(P), ir)) new_list.append(round_fn(np.einsum("vu,...ui->...vi", P, rot_basis))) - return e3nn.IrrepsArray.from_list(new_irreps, new_list, rotation_basis.shape[:-1], np.float64, backend=np) + return e3nn.IrrepsArray.from_list( + new_irreps, new_list, rotation_basis.shape[:-1], np.float64, backend=np + ) def subrepr_permutation( @@ -586,7 +672,10 @@ def subrepr_permutation( def reduce_subgroup_permutation( - sub_f0: FrozenSet[int], perm_repr: FrozenSet[Tuple[int, Tuple[int, ...]]], dims: Tuple[int, ...], return_dim: bool = False + sub_f0: FrozenSet[int], + perm_repr: FrozenSet[Tuple[int, Tuple[int, ...]]], + dims: Tuple[int, ...], + return_dim: bool = False, ) -> np.ndarray: sub_perm_repr = subrepr_permutation(sub_f0, perm_repr) sub_dims = tuple(dims[i] for i in sub_f0) @@ -598,7 +687,10 @@ def reduce_subgroup_permutation( if return_dim: return len(base) permutation_basis = reduce_permutation_matrix(base, sub_dims) - return np.reshape(permutation_basis, (-1,) + tuple(dims[i] if i in sub_f0 else 1 for i in range(len(dims)))) + return np.reshape( + permutation_basis, + (-1,) + tuple(dims[i] if i in sub_f0 else 1 for i in range(len(dims))), + ) @functools.lru_cache(maxsize=None) @@ -633,7 +725,8 @@ def reduce_permutation_base( @functools.lru_cache(maxsize=None) def reduce_permutation_matrix( - base: FrozenSet[FrozenSet[FrozenSet[Tuple[int, Tuple[int, ...]]]]], dims: Tuple[int, ...] 
+ base: FrozenSet[FrozenSet[FrozenSet[Tuple[int, Tuple[int, ...]]]]], + dims: Tuple[int, ...], ) -> np.ndarray: base = sorted( [sorted([sorted(xs) for xs in x]) for x in base] diff --git a/e3nn_jax/_src/reduced_tensor_product_test.py b/e3nn_jax/_src/reduced_tensor_product_test.py index e7da1e18..609c5538 100644 --- a/e3nn_jax/_src/reduced_tensor_product_test.py +++ b/e3nn_jax/_src/reduced_tensor_product_test.py @@ -11,8 +11,12 @@ def test_reduce_tensor_Levi_Civita_symbol(): Q = e3nn.reduced_tensor_product_basis("ijk=-ikj=-jik", i="1e") assert Q.irreps == "0e" - np.testing.assert_allclose(Q.array, -np.einsum("ijkx->ikjx", Q.array), atol=1e-6, rtol=1e-6) - np.testing.assert_allclose(Q.array, -np.einsum("ijkx->jikx", Q.array), atol=1e-6, rtol=1e-6) + np.testing.assert_allclose( + Q.array, -np.einsum("ijkx->ikjx", Q.array), atol=1e-6, rtol=1e-6 + ) + np.testing.assert_allclose( + Q.array, -np.einsum("ijkx->jikx", Q.array), atol=1e-6, rtol=1e-6 + ) def test_reduce_tensor_elasticity_tensor(): @@ -25,9 +29,15 @@ def test_reduce_tensor_elasticity_tensor_parity(): assert Q.irreps.dim == 21 assert all(ir.p == 1 for _, ir in Q.irreps) - np.testing.assert_allclose(Q.array, np.einsum("ijklx->jiklx", Q.array), atol=1e-6, rtol=1e-6) - np.testing.assert_allclose(Q.array, np.einsum("ijklx->ijlkx", Q.array), atol=1e-6, rtol=1e-6) - np.testing.assert_allclose(Q.array, np.einsum("ijklx->klijx", Q.array), atol=1e-6, rtol=1e-6) + np.testing.assert_allclose( + Q.array, np.einsum("ijklx->jiklx", Q.array), atol=1e-6, rtol=1e-6 + ) + np.testing.assert_allclose( + Q.array, np.einsum("ijklx->ijlkx", Q.array), atol=1e-6, rtol=1e-6 + ) + np.testing.assert_allclose( + Q.array, np.einsum("ijklx->klijx", Q.array), atol=1e-6, rtol=1e-6 + ) def test_reduced_symmetric_tensor_product_basis(): @@ -61,8 +71,12 @@ def test_tensor_product_basis_equivariance(keys): def test_optimized_reduced_symmetric_tensor_product_basis_order_2(): - Q = e3nn.reduced_symmetric_tensor_product_basis("1e + 2o", 2, _use_optimized_implementation=True) - P = e3nn.reduced_symmetric_tensor_product_basis("1e + 2o", 2, _use_optimized_implementation=False) + Q = e3nn.reduced_symmetric_tensor_product_basis( + "1e + 2o", 2, _use_optimized_implementation=True + ) + P = e3nn.reduced_symmetric_tensor_product_basis( + "1e + 2o", 2, _use_optimized_implementation=False + ) assert Q.irreps == P.irreps np.testing.assert_almost_equal(Q.array, P.array) @@ -72,8 +86,12 @@ def test_optimized_reduced_symmetric_tensor_product_basis_order_2(): def test_optimized_reduced_symmetric_tensor_product_basis_order_3a(): - Q = e3nn.reduced_symmetric_tensor_product_basis("0e + 1e", 3, _use_optimized_implementation=True) - P = e3nn.reduced_symmetric_tensor_product_basis("0e + 1e", 3, _use_optimized_implementation=False) + Q = e3nn.reduced_symmetric_tensor_product_basis( + "0e + 1e", 3, _use_optimized_implementation=True + ) + P = e3nn.reduced_symmetric_tensor_product_basis( + "0e + 1e", 3, _use_optimized_implementation=False + ) assert Q.irreps == P.irreps Q = Q.array.reshape(-1, Q.irreps.dim) @@ -82,8 +100,12 @@ def test_optimized_reduced_symmetric_tensor_product_basis_order_3a(): def test_optimized_reduced_symmetric_tensor_product_basis_order_3b(): - Q = e3nn.reduced_symmetric_tensor_product_basis("3x0e + 1e", 3, _use_optimized_implementation=True) - P = e3nn.reduced_symmetric_tensor_product_basis("3x0e + 1e", 3, _use_optimized_implementation=False) + Q = e3nn.reduced_symmetric_tensor_product_basis( + "3x0e + 1e", 3, _use_optimized_implementation=True + ) + P = 
e3nn.reduced_symmetric_tensor_product_basis( + "3x0e + 1e", 3, _use_optimized_implementation=False + ) assert Q.irreps == P.irreps Q = Q.array.reshape(-1, Q.irreps.dim) @@ -93,8 +115,12 @@ def test_optimized_reduced_symmetric_tensor_product_basis_order_3b(): def test_optimized_reduced_symmetric_tensor_product_basis_order_3c(): irreps = "1o + 2e + 4e" - Q = e3nn.reduced_symmetric_tensor_product_basis(irreps, 3, keep_ir="0e + 1o", _use_optimized_implementation=True) - P = e3nn.reduced_symmetric_tensor_product_basis(irreps, 3, keep_ir="0e + 1o", _use_optimized_implementation=False) + Q = e3nn.reduced_symmetric_tensor_product_basis( + irreps, 3, keep_ir="0e + 1o", _use_optimized_implementation=True + ) + P = e3nn.reduced_symmetric_tensor_product_basis( + irreps, 3, keep_ir="0e + 1o", _use_optimized_implementation=False + ) assert Q.irreps == P.irreps Q = Q.array.reshape(-1, Q.irreps.dim) @@ -103,8 +129,12 @@ def test_optimized_reduced_symmetric_tensor_product_basis_order_3c(): def test_optimized_reduced_symmetric_tensor_product_basis_order_4(): - Q = e3nn.reduced_symmetric_tensor_product_basis("1e + 2o", 4, _use_optimized_implementation=True) - P = e3nn.reduced_symmetric_tensor_product_basis("1e + 2o", 4, _use_optimized_implementation=False) + Q = e3nn.reduced_symmetric_tensor_product_basis( + "1e + 2o", 4, _use_optimized_implementation=True + ) + P = e3nn.reduced_symmetric_tensor_product_basis( + "1e + 2o", 4, _use_optimized_implementation=False + ) assert Q.irreps == P.irreps Q = Q.array.reshape(-1, Q.irreps.dim) diff --git a/e3nn_jax/_src/rotation.py b/e3nn_jax/_src/rotation.py index ae4f5f07..64113c2d 100644 --- a/e3nn_jax/_src/rotation.py +++ b/e3nn_jax/_src/rotation.py @@ -172,10 +172,22 @@ def compose_quaternion(q1, q2): q1, q2 = jnp.broadcast_arrays(q1, q2) return jnp.stack( [ - q1[..., 0] * q2[..., 0] - q1[..., 1] * q2[..., 1] - q1[..., 2] * q2[..., 2] - q1[..., 3] * q2[..., 3], - q1[..., 1] * q2[..., 0] + q1[..., 0] * q2[..., 1] + q1[..., 2] * q2[..., 3] - q1[..., 3] * q2[..., 2], - q1[..., 0] * q2[..., 2] - q1[..., 1] * q2[..., 3] + q1[..., 2] * q2[..., 0] + q1[..., 3] * q2[..., 1], - q1[..., 0] * q2[..., 3] + q1[..., 1] * q2[..., 2] - q1[..., 2] * q2[..., 1] + q1[..., 3] * q2[..., 0], + q1[..., 0] * q2[..., 0] + - q1[..., 1] * q2[..., 1] + - q1[..., 2] * q2[..., 2] + - q1[..., 3] * q2[..., 3], + q1[..., 1] * q2[..., 0] + + q1[..., 0] * q2[..., 1] + + q1[..., 2] * q2[..., 3] + - q1[..., 3] * q2[..., 2], + q1[..., 0] * q2[..., 2] + - q1[..., 1] * q2[..., 3] + + q1[..., 2] * q2[..., 0] + + q1[..., 3] * q2[..., 1], + q1[..., 0] * q2[..., 3] + + q1[..., 1] * q2[..., 2] + - q1[..., 2] * q2[..., 1] + + q1[..., 3] * q2[..., 0], ], axis=-1, ) @@ -246,7 +258,10 @@ def compose_axis_angle(axis1, angle1, axis2, angle2): angle (`jax.numpy.ndarray`): array of shape :math:`(...)` """ return quaternion_to_axis_angle( - compose_quaternion(axis_angle_to_quaternion(axis1, angle1), axis_angle_to_quaternion(axis2, angle2)) + compose_quaternion( + axis_angle_to_quaternion(axis1, angle1), + axis_angle_to_quaternion(axis2, angle2), + ) ) @@ -307,7 +322,9 @@ def compose_log_coordinates(log1, log2): log coordinates (`jax.numpy.ndarray`): array of shape :math:`(..., 3)` """ return quaternion_to_log_coordinates( - compose_quaternion(log_coordinates_to_quaternion(log1), log_coordinates_to_quaternion(log2)) + compose_quaternion( + log_coordinates_to_quaternion(log1), log_coordinates_to_quaternion(log2) + ) ) @@ -401,7 +418,14 @@ def matrix_z(angle): s = jnp.sin(angle) o = jnp.ones_like(angle) z = 
jnp.zeros_like(angle) - return jnp.stack([jnp.stack([c, -s, z], axis=-1), jnp.stack([s, c, z], axis=-1), jnp.stack([z, z, o], axis=-1)], axis=-2) + return jnp.stack( + [ + jnp.stack([c, -s, z], axis=-1), + jnp.stack([s, c, z], axis=-1), + jnp.stack([z, z, o], axis=-1), + ], + axis=-2, + ) def angles_to_matrix(alpha, beta, gamma): diff --git a/e3nn_jax/_src/rotation_test.py b/e3nn_jax/_src/rotation_test.py index a7beb4ea..b274bf54 100644 --- a/e3nn_jax/_src/rotation_test.py +++ b/e3nn_jax/_src/rotation_test.py @@ -32,7 +32,9 @@ def test_xyz(keys): for r in rs: a, b = e3nn.xyz_to_angles(r) R = e3nn.angles_to_matrix(a, -b, -a) - np.testing.assert_allclose(R @ r, np.array([0.0, 1.0, 0.0]), atol=float_tolerance) + np.testing.assert_allclose( + R @ r, np.array([0.0, 1.0, 0.0]), atol=float_tolerance + ) Ja, Jb = jax.jacobian(e3nn.xyz_to_angles)(jnp.array([0.0, 1.0, 0.0])) np.testing.assert_allclose(Ja, 0.0, atol=float_tolerance) @@ -178,7 +180,9 @@ def f(key): return e3nn.axis_angle_to_matrix(axis, angle) @ jnp.array([0.2, 0.5, 0.3]) x = f(next(keys)) - np.testing.assert_allclose(jnp.mean(x, axis=0), jnp.array([0.0, 0.0, 0.0]), atol=0.005) + np.testing.assert_allclose( + jnp.mean(x, axis=0), jnp.array([0.0, 0.0, 0.0]), atol=0.005 + ) jax.config.update("jax_enable_x64", False) diff --git a/e3nn_jax/_src/s2grid.py b/e3nn_jax/_src/s2grid.py index 0e1fb231..95eeac85 100644 --- a/e3nn_jax/_src/s2grid.py +++ b/e3nn_jax/_src/s2grid.py @@ -110,16 +110,24 @@ def __init__( ) -> None: if _perform_checks: if len(grid_values.shape) < 2: - raise ValueError(f"Grid values should have atleast 2 axes. Got grid_values of shape {grid_values.shape}.") + raise ValueError( + f"Grid values should have at least 2 axes. Got grid_values of shape {grid_values.shape}." + ) if quadrature not in ["soft", "gausslegendre"]: - raise ValueError(f"Invalid quadrature for SphericalSignal: {quadrature}") + raise ValueError( + f"Invalid quadrature for SphericalSignal: {quadrature}" + ) if p_val not in (-1, 1): - raise ValueError(f"Parity p_val must be either +1 or -1. Received: {p_val}") + raise ValueError( + f"Parity p_val must be either +1 or -1. Received: {p_val}" + ) if p_arg not in (-1, 1): - raise ValueError(f"Parity p_arg must be either +1 or -1. Received: {p_arg}") + raise ValueError( + f"Parity p_arg must be either +1 or -1. Received: {p_arg}" + ) self.grid_values = grid_values self.quadrature = quadrature @@ -161,11 +169,17 @@ def __mul__(self, scalar: Union[float, "SphericalSignal"]) -> "SphericalSignal": if isinstance(scalar, SphericalSignal): other = scalar if self.quadrature != other.quadrature: - raise ValueError("Multiplication of SphericalSignals with different quadrature is not supported.") + raise ValueError( + "Multiplication of SphericalSignals with different quadrature is not supported." + ) if self.grid_resolution != other.grid_resolution: - raise ValueError("Multiplication of SphericalSignals with different grid resolution is not supported.") + raise ValueError( + "Multiplication of SphericalSignals with different grid resolution is not supported." + ) if self.p_arg != other.p_arg: - raise ValueError("Multiplication of SphericalSignals with different p_arg is not equivariant.") + raise ValueError( + "Multiplication of SphericalSignals with different p_arg is not equivariant."
+ ) return SphericalSignal( self.grid_values * other.grid_values, @@ -220,7 +234,9 @@ def __sub__(self, other: "SphericalSignal") -> "SphericalSignal": def __neg__(self) -> "SphericalSignal": """Negate SphericalSignal.""" - return SphericalSignal(-self.grid_values, self.quadrature, p_val=self.p_val, p_arg=self.p_arg) + return SphericalSignal( + -self.grid_values, self.quadrature, p_val=self.p_val, p_arg=self.p_arg + ) @property def shape(self) -> Tuple[int, ...]: @@ -276,7 +292,9 @@ def grid_resolution(self) -> Tuple[int, int]: """Grid resolution for (beta, alpha).""" return (self.res_beta, self.res_alpha) - def resample(self, res_beta: int, res_alpha: int, lmax: int, quadrature: Optional[str] = None) -> "SphericalSignal": + def resample( + self, res_beta: int, res_alpha: int, lmax: int, quadrature: Optional[str] = None + ) -> "SphericalSignal": """Resamples a signal via the spherical harmonic coefficients. Args: @@ -323,7 +341,9 @@ def _transform_by( p_arg=self.p_arg, ) - def transform_by_angles(self, alpha: float, beta: float, gamma: float, lmax: int) -> "SphericalSignal": + def transform_by_angles( + self, alpha: float, beta: float, gamma: float, lmax: int + ) -> "SphericalSignal": """Rotate the signal by the given Euler angles.""" return self._transform_by( "angles", @@ -335,9 +355,13 @@ def transform_by_matrix(self, R: jnp.ndarray, lmax: int) -> "SphericalSignal": """Rotate the signal by the given rotation matrix.""" return self._transform_by("matrix", transform_kwargs=dict(R=R), lmax=lmax) - def transform_by_axis_angle(self, axis: jnp.ndarray, angle: float, lmax: int) -> "SphericalSignal": + def transform_by_axis_angle( + self, axis: jnp.ndarray, angle: float, lmax: int + ) -> "SphericalSignal": """Rotate the signal by the given angle around an axis.""" - return self._transform_by("axis_angle", transform_kwargs=dict(axis=axis, angle=angle), lmax=lmax) + return self._transform_by( + "axis_angle", transform_kwargs=dict(axis=axis, angle=angle), lmax=lmax + ) def transform_by_quaternion(self, q: jnp.ndarray, lmax: int) -> "SphericalSignal": """Rotate the signal by the given quaternion.""" @@ -350,7 +374,9 @@ def apply(self, func: Callable[[jnp.ndarray], jnp.ndarray]): raise ValueError( "Activation: the parity is violated! The input scalar is odd but the activation is neither even nor odd." 
) - return SphericalSignal(func(self.grid_values), self.quadrature, p_val=new_p_val, p_arg=self.p_arg) + return SphericalSignal( + func(self.grid_values), self.quadrature, p_val=new_p_val, p_arg=self.p_arg + ) @staticmethod def _find_peaks_2d(x: np.ndarray) -> List[Tuple[int, int]]: @@ -393,7 +419,9 @@ def find_peaks(self, lmax: int) -> Tuple[np.ndarray, np.ndarray]: f2p = np.stack([f2[i, j] for i, j in ij]) # Union of the results - mask = scipy.spatial.distance.cdist(x1p, x2p) < 2 * np.pi / max(*grid_resolution) + mask = scipy.spatial.distance.cdist(x1p, x2p) < 2 * np.pi / max( + *grid_resolution + ) x = np.concatenate([x1p[mask.sum(axis=1) == 0], x2p]) f = np.concatenate([f1p[mask.sum(axis=1) == 0], f2p]) @@ -430,7 +458,9 @@ def pad_to_plot( one = jnp.ones_like(y, shape=(1,)) ones = jnp.ones_like(f, shape=(1, len(alpha))) y = jnp.concatenate([-one, y, one]) # [res_beta + 2] - f = jnp.concatenate([jnp.mean(f[0]) * ones, f, jnp.mean(f[-1]) * ones], axis=0) # [res_beta + 2, res_alpha] + f = jnp.concatenate( + [jnp.mean(f[0]) * ones, f, jnp.mean(f[-1]) * ones], axis=0 + ) # [res_beta + 2, res_alpha] # alpha: [0, 2pi] alpha = jnp.concatenate([alpha, alpha[:1]]) # [res_alpha + 1] @@ -561,14 +591,18 @@ def f(k, p_ya): # single signal only k1, k2 = jax.random.split(k) p_y = self.quadrature_weights * jnp.sum(p_ya, axis=1) # [y] y_index = jax.random.choice(k1, jnp.arange(self.res_beta), p=p_y) # [] - alpha_index = jax.random.choice(k2, jnp.arange(self.res_alpha), p=p_ya[y_index]) # [] + alpha_index = jax.random.choice( + k2, jnp.arange(self.res_alpha), p=p_ya[y_index] + ) # [] return y_index, alpha_index vf = f for _ in range(self.ndim - 2): vf = jax.vmap(vf) - keys = jax.random.split(key, math.prod(self.shape[:-2])).reshape(self.shape[:-2] + key.shape) + keys = jax.random.split(key, math.prod(self.shape[:-2])).reshape( + self.shape[:-2] + key.shape + ) return vf(keys, self.grid_values) def __getitem__(self, index) -> "SphericalSignal": @@ -602,7 +636,9 @@ def __getitem__(self, index) -> "SphericalSignal": ) -def s2_dirac(position: Union[jnp.ndarray, e3nn.IrrepsArray], lmax: int, *, p_val: int, p_arg: int) -> e3nn.IrrepsArray: +def s2_dirac( + position: Union[jnp.ndarray, e3nn.IrrepsArray], lmax: int, *, p_val: int, p_arg: int +) -> e3nn.IrrepsArray: r"""Spherical harmonics expansion of a Dirac delta on the sphere. The integral of the Dirac delta is 1. 
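Note for review: the next hunk reflows the tail of the `s2_dirac` docstring, whose example builds a weighted sum of Dirac deltas. A runnable restatement of that example, a sketch assuming only the public `e3nn_jax` API exercised elsewhere in this diff:

import jax.numpy as jnp

import e3nn_jax as e3nn

# Two Dirac deltas on the sphere (at +z and +x), expanded up to lmax=4
# and combined with weights, as in the docstring example shown below.
positions = jnp.array([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
weights = jnp.array([1.0, -1.0])
coeffs = e3nn.sum(
    e3nn.s2_dirac(positions, 4, p_val=1, p_arg=-1) * weights[:, None], axis=0
)
print(coeffs.irreps)  # 0e + 1o + 2e + 3o + 4e, i.e. e3nn.s2_irreps(4, p_val=1, p_arg=-1)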
@@ -682,7 +718,9 @@ def s2_dirac(position: Union[jnp.ndarray, e3nn.IrrepsArray], lmax: int, *, p_val e3nn.sum(e3nn.s2_dirac(positions, 4, p_val=1, p_arg=-1) * weights[:, None], axis=0) """ irreps = s2_irreps(lmax, p_val, p_arg) - coeffs = e3nn.spherical_harmonics(irreps, position, normalize=True, normalization="integral") # [dim] + coeffs = e3nn.spherical_harmonics( + irreps, position, normalize=True, normalization="integral" + ) # [dim] return coeffs / jnp.sqrt(4 * jnp.pi) @@ -745,7 +783,9 @@ def from_s2grid( if lmax_in is None: lmax_in = lmax - _, _, sh_y, sha, qw = _spherical_harmonics_s2grid(lmax, res_beta, res_alpha, quadrature=x.quadrature, dtype=x.dtype) + _, _, sh_y, sha, qw = _spherical_harmonics_s2grid( + lmax, res_beta, res_alpha, quadrature=x.quadrature, dtype=x.dtype + ) # sh_y: (res_beta, (l+1)(l+2)/2) n = _normalization(lmax, normalization, x.dtype, "from_s2", lmax_in) @@ -760,7 +800,9 @@ def from_s2grid( if fft: int_a = _rfft(x.grid_values, lmax) / res_alpha # [..., res_beta, 2*l+1] else: - int_a = jnp.einsum("...ba,am->...bm", x.grid_values, sha) / res_alpha # [..., res_beta, 2*l+1] + int_a = ( + jnp.einsum("...ba,am->...bm", x.grid_values, sha) / res_alpha + ) # [..., res_beta, 2*l+1] # integrate over beta int_b = jnp.einsum("mbi,...bm->...i", sh_y, int_a) # [..., irreps] @@ -860,9 +902,13 @@ def to_s2grid( p_val, p_arg = _check_parities(coeffs.irreps, p_val, p_arg) if p_val is None or p_arg is None: - raise ValueError(f"p_val and p_arg cannot be determined from the irreps {coeffs.irreps}, please specify them.") + raise ValueError( + f"p_val and p_arg cannot be determined from the irreps {coeffs.irreps}, please specify them." + ) - _, _, sh_y, sha, _ = _spherical_harmonics_s2grid(lmax, res_beta, res_alpha, quadrature=quadrature, dtype=coeffs.dtype) + _, _, sh_y, sha, _ = _spherical_harmonics_s2grid( + lmax, res_beta, res_alpha, quadrature=quadrature, dtype=coeffs.dtype + ) n = _normalization(lmax, normalization, coeffs.dtype, "to_s2") @@ -881,7 +927,9 @@ def to_s2grid( signal = _irfft(signal_b, res_alpha) * res_alpha # [..., res_beta, res_alpha] else: - signal = jnp.einsum("...bm,am->...ba", signal_b, sha) # [..., res_beta, res_alpha] + signal = jnp.einsum( + "...bm,am->...ba", signal_b, sha + ) # [..., res_beta, res_alpha] return SphericalSignal(signal, quadrature=quadrature, p_val=p_val, p_arg=p_arg) @@ -918,8 +966,12 @@ def to_s2point( p_arg = point.irreps[0].ir.p p_val, _ = _check_parities(coeffs.irreps, None, p_arg) - sh = e3nn.spherical_harmonics(coeffs.irreps.ls, point, True, "integral") # [*shape2, irreps] - n = _normalization(sh.irreps.lmax, normalization, coeffs.dtype, "to_s2")[jnp.array(sh.irreps.ls)] # [num_irreps] + sh = e3nn.spherical_harmonics( + coeffs.irreps.ls, point, True, "integral" + ) # [*shape2, irreps] + n = _normalization(sh.irreps.lmax, normalization, coeffs.dtype, "to_s2")[ + jnp.array(sh.irreps.ls) + ] # [num_irreps] sh = sh * n shape1 = coeffs.shape[:-1] @@ -959,21 +1011,28 @@ def _quadrature_weights_soft(b: int) -> np.ndarray: r"""function copied from ``lie_learn.spaces.S3`` Compute quadrature weights for the grid used by Kostelec & Rockmore [1, 2]. 
""" - assert b % 2 == 0, "res_beta needs to be even for soft quadrature weights to be computed properly" + assert ( + b % 2 == 0 + ), "res_beta needs to be even for soft quadrature weights to be computed properly" k = np.arange(b // 2) return np.array( [ ( (4.0 / b) * np.sin(np.pi * (2.0 * j + 1.0) / (2.0 * b)) - * ((1.0 / (2 * k + 1)) * np.sin((2 * j + 1) * (2 * k + 1) * np.pi / (2.0 * b))).sum() + * ( + (1.0 / (2 * k + 1)) + * np.sin((2 * j + 1) * (2 * k + 1) * np.pi / (2.0 * b)) + ).sum() ) for j in np.arange(b) ], ) -def _s2grid(res_beta: int, res_alpha: int, quadrature: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: +def _s2grid( + res_beta: int, res_alpha: int, quadrature: str +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: r"""Returns arrays describing the grid on the sphere. Args: @@ -1032,17 +1091,23 @@ def _spherical_harmonics_s2grid( qw (`jax.numpy.ndarray`): array of shape ``(res_beta)`` """ y, alphas, qw = _s2grid(res_beta, res_alpha, quadrature) - y, alphas, qw = jax.tree_util.tree_map(lambda x: jnp.asarray(x, dtype), (y, alphas, qw)) + y, alphas, qw = jax.tree_util.tree_map( + lambda x: jnp.asarray(x, dtype), (y, alphas, qw) + ) sh_alpha = _sh_alpha(lmax, alphas) # [..., 2 * l + 1] sh_y = _sh_beta(lmax, y) # [..., (lmax + 1) * (lmax + 2) // 2] return y, alphas, sh_y, sh_alpha, qw -def _check_parities(irreps: e3nn.Irreps, p_val: Optional[int] = None, p_arg: Optional[int] = None) -> Tuple[int, int]: +def _check_parities( + irreps: e3nn.Irreps, p_val: Optional[int] = None, p_arg: Optional[int] = None +) -> Tuple[int, int]: p_even = {ir.p for mul, ir in irreps if ir.l % 2 == 0} p_odd = {ir.p for mul, ir in irreps if ir.l % 2 == 1} if not (p_even in [{1}, {-1}, set()] and p_odd in [{1}, {-1}, set()]): - raise ValueError("irrep parities should be of the form (p_val * p_arg**l) for all l, where p_val and p_arg are ±1") + raise ValueError( + "irrep parities should be of the form (p_val * p_arg**l) for all l, where p_val and p_arg are ±1" + ) p_even = p_even.pop() if p_even else None p_odd = p_odd.pop() if p_odd else None @@ -1058,7 +1123,9 @@ def _check_parities(irreps: e3nn.Irreps, p_val: Optional[int] = None, p_arg: Opt if p_even is None: p_even = p_val if p_even != p_val: - raise ValueError(f"irrep ({irreps}) parities are not compatible with the given p_val ({p_val}).") + raise ValueError( + f"irrep ({irreps}) parities are not compatible with the given p_val ({p_val})." + ) if p_arg is not None: if p_odd is None and p_even is not None: @@ -1067,7 +1134,9 @@ def _check_parities(irreps: e3nn.Irreps, p_val: Optional[int] = None, p_arg: Opt p_even = p_odd * p_arg elif p_odd is not None and p_even is not None: if p_odd != p_even * p_arg: - raise ValueError(f"irrep ({irreps}) parities are not compatible with the given p_arg ({p_arg}).") + raise ValueError( + f"irrep ({irreps}) parities are not compatible with the given p_arg ({p_arg})." 
+ ) if p_even is not None and p_odd is not None: return p_even, p_even * p_odd @@ -1075,7 +1144,9 @@ def _check_parities(irreps: e3nn.Irreps, p_val: Optional[int] = None, p_arg: Opt return p_even, None -def _normalization(lmax: int, normalization: str, dtype, direction: str, lmax_in: Optional[int] = None) -> jnp.ndarray: +def _normalization( + lmax: int, normalization: str, dtype, direction: str, lmax_in: Optional[int] = None +) -> jnp.ndarray: """Handles normalization of different components of IrrepsArrays.""" assert direction in ["to_s2", "from_s2"] @@ -1100,7 +1171,9 @@ def _normalization(lmax: int, normalization: str, dtype, direction: str, lmax_in if direction == "to_s2": return jnp.sqrt(4 * jnp.pi) * jnp.ones(lmax + 1, dtype) / jnp.sqrt(lmax + 1) else: - return jnp.sqrt(4 * jnp.pi) * jnp.ones(lmax + 1, dtype) * jnp.sqrt(lmax_in + 1) + return ( + jnp.sqrt(4 * jnp.pi) * jnp.ones(lmax + 1, dtype) * jnp.sqrt(lmax_in + 1) + ) if normalization == "integral": # normalize such that the coefficient L=0 is equal to 4 pi the integral of the function # for "integral" normalization, the direction does not matter. @@ -1171,7 +1244,9 @@ def _expand_matrix(ls: List[int]) -> np.ndarray: m = np.zeros((lmax + 1, 2 * lmax + 1, sum(2 * l + 1 for l in ls)), np.float64) i = 0 for l in ls: - m[l, lmax - l : lmax + l + 1, i : i + 2 * l + 1] = np.eye(2 * l + 1, dtype=np.float64) + m[l, lmax - l : lmax + l + 1, i : i + 2 * l + 1] = np.eye( + 2 * l + 1, dtype=np.float64 + ) i += 2 * l + 1 return m diff --git a/e3nn_jax/_src/s2grid_test.py b/e3nn_jax/_src/s2grid_test.py index 16fc4e9f..e3b32900 100644 --- a/e3nn_jax/_src/s2grid_test.py +++ b/e3nn_jax/_src/s2grid_test.py @@ -8,14 +8,18 @@ from e3nn_jax.util import assert_output_dtype_matches_input_dtype -@pytest.mark.parametrize("irreps", ["0e", "0e + 1o", "1o + 2e", "2e + 0e", e3nn.s2_irreps(4)]) +@pytest.mark.parametrize( + "irreps", ["0e", "0e + 1o", "1o + 2e", "2e + 0e", e3nn.s2_irreps(4)] +) @pytest.mark.parametrize("quadrature", ["soft", "gausslegendre"]) @pytest.mark.parametrize("fft_to", [False, True]) @pytest.mark.parametrize("fft_from", [False, True]) def test_s2grid_transforms(keys, irreps, quadrature, fft_to, fft_from): @jax.jit def f(c): - res = e3nn.to_s2grid(c, 30, 51, quadrature=quadrature, fft=fft_to, p_val=1, p_arg=-1) + res = e3nn.to_s2grid( + c, 30, 51, quadrature=quadrature, fft=fft_to, p_val=1, p_arg=-1 + ) return e3nn.from_s2grid(res, c.irreps, fft=fft_from) a = e3nn.normal(irreps, keys[0]) @@ -35,7 +39,9 @@ def test_fft(keys): @pytest.mark.parametrize("quadrature", ["soft", "gausslegendre"]) def test_grid_vectors(quadrature): - _, _, sh_y, sh_alpha, _ = _spherical_harmonics_s2grid(lmax=1, res_beta=4, res_alpha=5, quadrature=quadrature) + _, _, sh_y, sh_alpha, _ = _spherical_harmonics_s2grid( + lmax=1, res_beta=4, res_alpha=5, quadrature=quadrature + ) r = e3nn.SphericalSignal(jnp.empty((4, 5)), quadrature).grid_vectors sh_y = np.stack([sh_y[:, 2], sh_y[:, 1], sh_y[:, 2]], axis=1) # for l=1 @@ -62,7 +68,16 @@ def test_to_s2grid_dtype(normalization, quadrature, fft): jax.config.update("jax_enable_x64", True) assert_output_dtype_matches_input_dtype( - lambda x: e3nn.to_s2grid(x, 4, 5, normalization=normalization, quadrature=quadrature, fft=fft, p_val=1, p_arg=1), + lambda x: e3nn.to_s2grid( + x, + 4, + 5, + normalization=normalization, + quadrature=quadrature, + fft=fft, + p_val=1, + p_arg=1, + ), e3nn.IrrepsArray("0e", jnp.array([1.0])), ) @@ -87,10 +102,18 @@ def test_inverse(keys, normalization, quadrature, fft, irreps): 
jax.config.update("jax_enable_x64", True) coeffs_orig = e3nn.normal(irreps, keys[0], (12,)) - sigs = jax.vmap(lambda x: e3nn.to_s2grid(x, 100, 99, normalization=normalization, quadrature=quadrature))(coeffs_orig) - coeffs_new = jax.vmap(lambda y: e3nn.from_s2grid(y, irreps, normalization=normalization, fft=fft))(sigs) - - np.testing.assert_allclose(coeffs_orig.array, coeffs_new.array, atol=1e-7, rtol=1e-7) + sigs = jax.vmap( + lambda x: e3nn.to_s2grid( + x, 100, 99, normalization=normalization, quadrature=quadrature + ) + )(coeffs_orig) + coeffs_new = jax.vmap( + lambda y: e3nn.from_s2grid(y, irreps, normalization=normalization, fft=fft) + )(sigs) + + np.testing.assert_allclose( + coeffs_orig.array, coeffs_new.array, atol=1e-7, rtol=1e-7 + ) def test_fft_dtype(): @@ -107,11 +130,15 @@ def test_to_s2point(keys, irreps, normalization, quadrature): jax.config.update("jax_enable_x64", True) coeffs = e3nn.normal(irreps, keys[0], ()) - s = e3nn.to_s2grid(coeffs, 20, 19, normalization=normalization, quadrature=quadrature) + s = e3nn.to_s2grid( + coeffs, 20, 19, normalization=normalization, quadrature=quadrature + ) vec = e3nn.IrrepsArray({1: "1e", -1: "1o"}[s.p_arg], s.grid_vectors) values = e3nn.to_s2point(coeffs, vec, normalization=normalization) - np.testing.assert_allclose(values.array[..., 0], s.grid_values, atol=1e-7, rtol=1e-7) + np.testing.assert_allclose( + values.array[..., 0], s.grid_values, atol=1e-7, rtol=1e-7 + ) jax.config.update("jax_enable_x64", False) @@ -129,7 +156,9 @@ def test_transform_by_angles(keys, irreps, alpha, beta, gamma): rotated_coeffs = e3nn.from_s2grid(rotated_sig, irreps) expected_rotated_coeffs = coeffs.transform_by_angles(alpha, beta, gamma) - np.testing.assert_allclose(rotated_coeffs.array, expected_rotated_coeffs.array, atol=1e-5, rtol=1e-5) + np.testing.assert_allclose( + rotated_coeffs.array, expected_rotated_coeffs.array, atol=1e-5, rtol=1e-5 + ) @pytest.mark.parametrize("alpha", [0.1, 0.2]) @@ -146,7 +175,9 @@ def test_transform_by_matrix(keys, irreps, alpha, beta, gamma): rotated_coeffs = e3nn.from_s2grid(rotated_sig, irreps) expected_rotated_coeffs = coeffs.transform_by_angles(alpha, beta, gamma) - np.testing.assert_allclose(rotated_coeffs.array, expected_rotated_coeffs.array, atol=1e-5, rtol=1e-5) + np.testing.assert_allclose( + rotated_coeffs.array, expected_rotated_coeffs.array, atol=1e-5, rtol=1e-5 + ) @pytest.mark.parametrize("alpha", [0.1, 0.2]) @@ -163,7 +194,9 @@ def test_transform_by_axis_angle(keys, irreps, alpha, beta, gamma): rotated_coeffs = e3nn.from_s2grid(rotated_sig, irreps) expected_rotated_coeffs = coeffs.transform_by_angles(alpha, beta, gamma) - np.testing.assert_allclose(rotated_coeffs.array, expected_rotated_coeffs.array, atol=1e-5, rtol=1e-5) + np.testing.assert_allclose( + rotated_coeffs.array, expected_rotated_coeffs.array, atol=1e-5, rtol=1e-5 + ) @pytest.mark.parametrize("alpha", [0.1, 0.2]) @@ -180,7 +213,9 @@ def test_transform_by_quaternion(keys, irreps, alpha, beta, gamma): rotated_coeffs = e3nn.from_s2grid(rotated_sig, irreps) expected_rotated_coeffs = coeffs.transform_by_angles(alpha, beta, gamma) - np.testing.assert_allclose(rotated_coeffs.array, expected_rotated_coeffs.array, atol=1e-5, rtol=1e-5) + np.testing.assert_allclose( + rotated_coeffs.array, expected_rotated_coeffs.array, atol=1e-5, rtol=1e-5 + ) def test_s2_dirac(): @@ -201,7 +236,15 @@ def test_s2_dirac(): @pytest.mark.parametrize("quadrature", ["soft", "gausslegendre"]) def test_integrate_scalar(lmax, quadrature): coeffs = e3nn.normal(e3nn.s2_irreps(lmax, 
p_val=1, p_arg=-1), jax.random.PRNGKey(0)) - sig = e3nn.to_s2grid(coeffs, 100, 99, normalization="integral", quadrature=quadrature, p_val=1, p_arg=-1) + sig = e3nn.to_s2grid( + coeffs, + 100, + 99, + normalization="integral", + quadrature=quadrature, + p_val=1, + p_arg=-1, + ) integral = sig.integrate().array.squeeze() scalar_term = coeffs["0e"].array[0] @@ -220,7 +263,12 @@ def test_integrate_polynomials(degree): def test_sample(keys): - p = e3nn.to_s2grid(0.5 * e3nn.normal(e3nn.s2_irreps(4), keys[0]), 30, 51, quadrature="gausslegendre").apply(jnp.exp) + p = e3nn.to_s2grid( + 0.5 * e3nn.normal(e3nn.s2_irreps(4), keys[0]), + 30, + 51, + quadrature="gausslegendre", + ).apply(jnp.exp) p: e3nn.SphericalSignal = p / p.integrate() keys = jax.random.split(keys[1], 100_000) beta_index, alpha_index = jax.vmap(lambda k: p.sample(k))(keys) @@ -237,7 +285,9 @@ def test_sample(keys): @pytest.mark.parametrize("lmax", [2, 4, 10]) def test_find_peaks(lmax): - pytest.skip("Still has the bug `ValueError: buffer source array is read-only`") # TODO + pytest.skip( + "Still has the bug `ValueError: buffer source array is read-only`" + ) # TODO pos = jnp.asarray( [ @@ -251,7 +301,9 @@ def test_find_peaks(lmax): -1.0, ] ) - coeffs = e3nn.s2_sum_of_diracs(positions=pos, weights=val, lmax=lmax, p_val=1, p_arg=-1) + coeffs = e3nn.s2_sum_of_diracs( + positions=pos, weights=val, lmax=lmax, p_val=1, p_arg=-1 + ) sig = e3nn.to_s2grid(coeffs, 50, 49, quadrature="gausslegendre") x, f = sig.find_peaks(lmax) diff --git a/e3nn_jax/_src/scatter.py b/e3nn_jax/_src/scatter.py index 5d3f81e8..e669d536 100644 --- a/e3nn_jax/_src/scatter.py +++ b/e3nn_jax/_src/scatter.py @@ -19,7 +19,9 @@ def _distinct_but_small(x: jnp.ndarray) -> jnp.ndarray: shape = x.shape x = jnp.ravel(x) unique = jnp.unique(x, size=x.shape[0]) # Pigeonhole principle - x = jax.lax.scan(lambda _, i: (None, jnp.where(i == unique, size=1)[0][0]), None, x)[1] + x = jax.lax.scan( + lambda _, i: (None, jnp.where(i == unique, size=1)[0][0]), None, x + )[1] return jnp.reshape(x, shape) @@ -48,7 +50,16 @@ def scatter_sum( Returns: `jax.numpy.ndarray` or `IrrepsArray`: output array of shape ``(output_size, ...)`` """ - return _scatter_op("sum", 0.0, data, dst=dst, nel=nel, output_size=output_size, map_back=map_back, mode=mode) + return _scatter_op( + "sum", + 0.0, + data, + dst=dst, + nel=nel, + output_size=output_size, + map_back=map_back, + mode=mode, + ) def scatter_max( @@ -87,7 +98,16 @@ def scatter_max( if not data.irreps.is_scalar(): raise ValueError("scatter_max only works with scalar IrrepsArray") - return _scatter_op("max", initial, data, dst=dst, nel=nel, output_size=output_size, map_back=map_back, mode=mode) + return _scatter_op( + "max", + initial, + data, + dst=dst, + nel=nel, + output_size=output_size, + map_back=map_back, + mode=mode, + ) def _scatter_op( @@ -186,5 +206,9 @@ def index_add( >>> index_add(i, x, out_dim=4) Array([-9., 0., 5., 0.], dtype=float32) """ - warnings.warn("e3nn.index_add is deprecated, use e3nn.scatter_sum instead", DeprecationWarning) - return scatter_sum(input, dst=indices, nel=n_elements, output_size=out_dim, map_back=map_back) + warnings.warn( + "e3nn.index_add is deprecated, use e3nn.scatter_sum instead", DeprecationWarning + ) + return scatter_sum( + input, dst=indices, nel=n_elements, output_size=out_dim, map_back=map_back + ) diff --git a/e3nn_jax/_src/scatter_test.py b/e3nn_jax/_src/scatter_test.py index 5e4804c5..12b8574c 100644 --- a/e3nn_jax/_src/scatter_test.py +++ b/e3nn_jax/_src/scatter_test.py @@ -25,7 +25,10 @@ 
def test_scatter_sum(): ) np.testing.assert_allclose( # nel - e3nn.scatter_sum(jnp.array([1.0, 2.0, 1.0, 0.5, 0.5, 0.7, 0.2, 0.1]), nel=jnp.array([3, 2, 3])), + e3nn.scatter_sum( + jnp.array([1.0, 2.0, 1.0, 0.5, 0.5, 0.7, 0.2, 0.1]), + nel=jnp.array([3, 2, 3]), + ), jnp.array([4.0, 1.0, 1.0]), ) @@ -46,6 +49,9 @@ def test_scatter_max(): ) np.testing.assert_allclose( # nel - e3nn.scatter_max(jnp.array([-1.0, -2.0, -1.0, 0.5, 0.5, 0.7, 0.2, 0.1]), nel=jnp.array([3, 2, 3])), + e3nn.scatter_max( + jnp.array([-1.0, -2.0, -1.0, 0.5, 0.5, 0.7, 0.2, 0.1]), + nel=jnp.array([3, 2, 3]), + ), jnp.array([-1.0, 0.5, 0.7]), ) diff --git a/e3nn_jax/_src/so3.py b/e3nn_jax/_src/so3.py index 595f44cf..aa6b6f43 100644 --- a/e3nn_jax/_src/so3.py +++ b/e3nn_jax/_src/so3.py @@ -13,7 +13,9 @@ def change_basis_real_to_complex(l: int) -> np.ndarray: for m in range(1, l + 1): q[l + m, l + abs(m)] = (-1) ** m / np.sqrt(2) q[l + m, l - abs(m)] = 1j * (-1) ** m / np.sqrt(2) - return (-1j) ** l * q # Added factor of 1j**l to make the Clebsch-Gordan coefficients real + return ( + -1j + ) ** l * q # Added factor of 1j**l to make the Clebsch-Gordan coefficients real def clebsch_gordan(l1: int, l2: int, l3: int) -> np.ndarray: diff --git a/e3nn_jax/_src/so3_test.py b/e3nn_jax/_src/so3_test.py index 6c833946..72a9f63e 100644 --- a/e3nn_jax/_src/so3_test.py +++ b/e3nn_jax/_src/so3_test.py @@ -7,11 +7,23 @@ def test_clebsch_gordan_symmetry(): - assert jnp.allclose(clebsch_gordan(1, 2, 3), jnp.swapaxes(clebsch_gordan(1, 3, 2), 1, 2)) - assert jnp.allclose(clebsch_gordan(1, 2, 3), jnp.swapaxes(clebsch_gordan(2, 1, 3), 0, 1)) - assert jnp.allclose(clebsch_gordan(1, 2, 3), jnp.swapaxes(clebsch_gordan(3, 2, 1), 0, 2)) - assert jnp.allclose(clebsch_gordan(1, 2, 3), jnp.swapaxes(jnp.swapaxes(clebsch_gordan(3, 1, 2), 0, 1), 1, 2)) - assert jnp.allclose(clebsch_gordan(1, 2, 3), jnp.swapaxes(jnp.swapaxes(clebsch_gordan(2, 3, 1), 0, 2), 1, 2)) + assert jnp.allclose( + clebsch_gordan(1, 2, 3), jnp.swapaxes(clebsch_gordan(1, 3, 2), 1, 2) + ) + assert jnp.allclose( + clebsch_gordan(1, 2, 3), jnp.swapaxes(clebsch_gordan(2, 1, 3), 0, 1) + ) + assert jnp.allclose( + clebsch_gordan(1, 2, 3), jnp.swapaxes(clebsch_gordan(3, 2, 1), 0, 2) + ) + assert jnp.allclose( + clebsch_gordan(1, 2, 3), + jnp.swapaxes(jnp.swapaxes(clebsch_gordan(3, 1, 2), 0, 1), 1, 2), + ) + assert jnp.allclose( + clebsch_gordan(1, 2, 3), + jnp.swapaxes(jnp.swapaxes(clebsch_gordan(2, 3, 1), 0, 2), 1, 2), + ) def unique_triplets(lmax): diff --git a/e3nn_jax/_src/spherical_harmonics.py b/e3nn_jax/_src/spherical_harmonics.py index bfc7a613..394ef387 100644 --- a/e3nn_jax/_src/spherical_harmonics.py +++ b/e3nn_jax/_src/spherical_harmonics.py @@ -34,7 +34,9 @@ def sh( `jax.numpy.ndarray`: polynomials of the spherical harmonics """ input = IrrepsArray("1e", input) - return spherical_harmonics(irreps_out, input, normalize, normalization, algorithm=algorithm).array + return spherical_harmonics( + irreps_out, input, normalize, normalization, algorithm=algorithm + ).array def _check_is_vector(irreps: Irreps): @@ -110,13 +112,17 @@ def spherical_harmonics( if isinstance(irreps_out, int): l = irreps_out if not isinstance(input, IrrepsArray): - raise ValueError("If irreps_out is an int, input must be an IrrepsArray.") + raise ValueError( + "If irreps_out is an int, input must be an IrrepsArray." 
+ ) vec_p = _check_is_vector(input.irreps) irreps_out = Irreps([(1, (l, vec_p**l))]) if all(isinstance(l, int) for l in irreps_out): if not isinstance(input, IrrepsArray): - raise ValueError("If irreps_out is a list of int, input must be an IrrepsArray.") + raise ValueError( + "If irreps_out is a list of int, input must be an IrrepsArray." + ) vec_p = _check_is_vector(input.irreps) irreps_out = Irreps([(1, (l, vec_p**l)) for l in irreps_out]) @@ -128,7 +134,9 @@ def spherical_harmonics( if isinstance(input, IrrepsArray): vec_p = _check_is_vector(input.irreps) if not all([vec_p == p for _, (l, p) in irreps_out if l % 2 == 1]): - raise ValueError(f"Input ({input.irreps}) and output ({irreps_out}) must have a compatible parity.") + raise ValueError( + f"Input ({input.irreps}) and output ({irreps_out}) must have a compatible parity." + ) x = input.array else: @@ -147,7 +155,10 @@ def spherical_harmonics( else: algorithm = config("spherical_harmonics_algorithm") - assert all(keyword in ["legendre", "recursive", "dense", "sparse", "custom_jvp"] for keyword in algorithm) + assert all( + keyword in ["legendre", "recursive", "dense", "sparse", "custom_jvp"] + for keyword in algorithm + ) assert x.shape[-1] == 3 if normalize: @@ -155,8 +166,13 @@ def spherical_harmonics( r2 = jnp.where(r2 == 0.0, 1.0, r2) x = x / jnp.sqrt(r2) - sh = _jited_spherical_harmonics(tuple(ir.l for _, ir in irreps_out), x, normalization, algorithm) - sh = [jnp.repeat(y[..., None, :], mul, -2) if mul != 1 else y[..., None, :] for (mul, ir), y in zip(irreps_out, sh)] + sh = _jited_spherical_harmonics( + tuple(ir.l for _, ir in irreps_out), x, normalization, algorithm + ) + sh = [ + jnp.repeat(y[..., None, :], mul, -2) if mul != 1 else y[..., None, :] + for (mul, ir), y in zip(irreps_out, sh) + ] return IrrepsArray.from_list(irreps_out, sh, x.shape[:-1], x.dtype) @@ -170,7 +186,9 @@ def _jited_spherical_harmonics( return _spherical_harmonics(ls, x, normalization, algorithm) -def _spherical_harmonics(ls: Tuple[int, ...], x: jnp.ndarray, normalization: str, algorithm: Tuple[str]) -> List[jnp.ndarray]: +def _spherical_harmonics( + ls: Tuple[int, ...], x: jnp.ndarray, normalization: str, algorithm: Tuple[str] +) -> List[jnp.ndarray]: if "legendre" in algorithm: out = _legendre_spherical_harmonics(max(ls), x, False, normalization) return [out[..., l**2 : (l + 1) ** 2] for l in ls] @@ -191,7 +209,11 @@ def _custom_jvp_spherical_harmonics( @_custom_jvp_spherical_harmonics.defjvp def _jvp( - ls: Tuple[int, ...], normalization: str, algorithm: Tuple[str], primals: Tuple[jnp.ndarray], tangents: Tuple[jnp.ndarray] + ls: Tuple[int, ...], + normalization: str, + algorithm: Tuple[str], + primals: Tuple[jnp.ndarray], + tangents: Tuple[jnp.ndarray], ) -> List[jnp.ndarray]: (x,) = primals (x_dot,) = tangents @@ -214,7 +236,12 @@ def h(l: int, r: jnp.ndarray) -> jnp.ndarray: return jnp.stack( [ sum( - [w[i, j, k] * r[..., i] * x_dot[..., k] for i in range(2 * l - 1) for k in range(3) if w[i, j, k] != 0] + [ + w[i, j, k] * r[..., i] * x_dot[..., k] + for i in range(2 * l - 1) + for k in range(3) + if w[i, j, k] != 0 + ] ) for j in range(2 * l + 1) ], @@ -227,14 +254,20 @@ def h(l: int, r: jnp.ndarray) -> jnp.ndarray: def _recursive_spherical_harmonics( - l: int, context: Dict[int, jnp.ndarray], input: jnp.ndarray, normalization: str, algorithm: Tuple[str] + l: int, + context: Dict[int, jnp.ndarray], + input: jnp.ndarray, + normalization: str, + algorithm: Tuple[str], ) -> sympy.Array: context.update(dict(jnp=jnp, clebsch_gordan=clebsch_gordan)) 
if l == 0: if 0 not in context: if normalization == "integral": - context[0] = math.sqrt(1 / (4 * math.pi)) * jnp.ones_like(input[..., :1]) + context[0] = math.sqrt(1 / (4 * math.pi)) * jnp.ones_like( + input[..., :1] + ) elif normalization == "component": context[0] = jnp.ones_like(input[..., :1]) else: @@ -262,13 +295,21 @@ def sh_var(l): w = sqrtQarray_to_sympy(clebsch_gordan(l1, l2, l)) yx = sympy.Array( [ - sum(sh_var(l1)[i] * sh_var(l2)[j] * w[i, j, k] for i in range(2 * l1 + 1) for j in range(2 * l2 + 1)) + sum( + sh_var(l1)[i] * sh_var(l2)[j] * w[i, j, k] + for i in range(2 * l1 + 1) + for j in range(2 * l2 + 1) + ) for k in range(2 * l + 1) ] ) - sph_1_l1 = _recursive_spherical_harmonics(l1, context, input, normalization, algorithm) - sph_1_l2 = _recursive_spherical_harmonics(l2, context, input, normalization, algorithm) + sph_1_l1 = _recursive_spherical_harmonics( + l1, context, input, normalization, algorithm + ) + sph_1_l2 = _recursive_spherical_harmonics( + l2, context, input, normalization, algorithm + ) y1 = yx.subs(zip(sh_var(l1), sph_1_l1)).subs(zip(sh_var(l2), sph_1_l2)) norm = sympy.sqrt(sum(y1.applyfunc(lambda x: x**2))) @@ -277,7 +318,8 @@ def sh_var(l): if l not in context: if normalization == "integral": x = math.sqrt((2 * l + 1) / (4 * math.pi)) / ( - math.sqrt((2 * l1 + 1) / (4 * math.pi)) * math.sqrt((2 * l2 + 1) / (4 * math.pi)) + math.sqrt((2 * l1 + 1) / (4 * math.pi)) + * math.sqrt((2 * l2 + 1) / (4 * math.pi)) ) elif normalization == "component": x = math.sqrt((2 * l + 1) / ((2 * l1 + 1) * (2 * l2 + 1))) @@ -362,7 +404,9 @@ def f2(l, m): p = jax.lax.fori_loop( 2, lmax + 1, - lambda l, p: p.at[k(l, 0)].set(f1(l, 0) * x * p[k(l - 1, 0)] - f2(l, 0) * p[k(l - 2, 0)]), + lambda l, p: p.at[k(l, 0)].set( + f1(l, 0) * x * p[k(l - 1, 0)] - f2(l, 0) * p[k(l - 2, 0)] + ), p, ) @@ -380,7 +424,9 @@ def g(m, vals): # Calculate P(l,m) def f(l, p): - p = p.at[k(l, m)].set(f1(l, m) * x * p[k(l - 1, m)] - f2(l, m) * p[k(l - 2, m)]) + p = p.at[k(l, m)].set( + f1(l, m) * x * p[k(l - 1, m)] - f2(l, m) * p[k(l - 2, m)] + ) p = p.at[k(l - 2, m)].multiply(rescalem) return p @@ -444,7 +490,12 @@ def _sh_beta(lmax: int, cos_betas: jnp.ndarray) -> jnp.ndarray: sh_y = sh_y * np.array( [ - math.sqrt(fractions.Fraction((2 * l + 1) * math.factorial(l - m), 4 * math.factorial(l + m)) / math.pi) + math.sqrt( + fractions.Fraction( + (2 * l + 1) * math.factorial(l - m), 4 * math.factorial(l + m) + ) + / math.pi + ) for l in range(lmax + 1) for m in range(l + 1) ], @@ -453,7 +504,9 @@ def _sh_beta(lmax: int, cos_betas: jnp.ndarray) -> jnp.ndarray: return sh_y -def _legendre_spherical_harmonics(lmax: int, x: jnp.ndarray, normalize: bool, normalization: str) -> jnp.ndarray: +def _legendre_spherical_harmonics( + lmax: int, x: jnp.ndarray, normalize: bool, normalization: str +) -> jnp.ndarray: alpha = jnp.arctan2(x[..., 0], x[..., 2]) sh_alpha = _sh_alpha(lmax, alpha) # [..., 2 * l + 1] diff --git a/e3nn_jax/_src/spherical_harmonics_test.py b/e3nn_jax/_src/spherical_harmonics_test.py index 4a083e7e..2aba1cc6 100644 --- a/e3nn_jax/_src/spherical_harmonics_test.py +++ b/e3nn_jax/_src/spherical_harmonics_test.py @@ -22,8 +22,12 @@ def test_equivariance(keys, algorithm, l): input = e3nn.normal("1o", keys[0], (10,)) abc = e3nn.rand_angles(keys[1], ()) - output1 = e3nn.spherical_harmonics(l, input.transform_by_angles(*abc), False, algorithm=algorithm) - output2 = e3nn.spherical_harmonics(l, input, False, algorithm=algorithm).transform_by_angles(*abc) + output1 = e3nn.spherical_harmonics( + l, 
input.transform_by_angles(*abc), False, algorithm=algorithm + ) + output2 = e3nn.spherical_harmonics( + l, input, False, algorithm=algorithm + ).transform_by_angles(*abc) np.testing.assert_allclose(output1.array, output2.array, atol=1e-2, rtol=1e-2) @@ -52,7 +56,11 @@ def test_normalization_integral(keys, algorithm, l): n = jnp.mean( e3nn.spherical_harmonics( - irreps, jax.random.normal(keys[l + 0], (3,)), normalize=True, normalization="integral", algorithm=algorithm + irreps, + jax.random.normal(keys[l + 0], (3,)), + normalize=True, + normalization="integral", + algorithm=algorithm, ).array ** 2 ) @@ -65,7 +73,11 @@ def test_normalization_norm(keys, algorithm, l): n = jnp.sum( e3nn.spherical_harmonics( - irreps, jax.random.normal(keys[l + 1], (3,)), normalize=True, normalization="norm", algorithm=algorithm + irreps, + jax.random.normal(keys[l + 1], (3,)), + normalize=True, + normalization="norm", + algorithm=algorithm, ).array ** 2 ) @@ -78,7 +90,11 @@ def test_normalization_component(keys, algorithm, l): n = jnp.mean( e3nn.spherical_harmonics( - irreps, jax.random.normal(keys[l + 2], (3,)), normalize=True, normalization="component", algorithm=algorithm + irreps, + jax.random.normal(keys[l + 2], (3,)), + normalize=True, + normalization="component", + algorithm=algorithm, ).array ** 2 ) @@ -90,8 +106,12 @@ def test_parity(keys, algorithm, l): irreps = e3nn.Irreps([l]) x = jax.random.normal(next(keys), (3,)) - y1 = (-1) ** l * e3nn.spherical_harmonics(irreps, x, normalize=True, normalization="integral", algorithm=algorithm) - y2 = e3nn.spherical_harmonics(irreps, -x, normalize=True, normalization="integral", algorithm=algorithm) + y1 = (-1) ** l * e3nn.spherical_harmonics( + irreps, x, normalize=True, normalization="integral", algorithm=algorithm + ) + y2 = e3nn.spherical_harmonics( + irreps, -x, normalize=True, normalization="integral", algorithm=algorithm + ) np.testing.assert_allclose(y1.array, y2.array, atol=1e-6, rtol=1e-6) @@ -99,12 +119,24 @@ def test_parity(keys, algorithm, l): def test_recurrence_relation(keys, algorithm, l): x = jax.random.normal(next(keys), (3,)) - y1 = e3nn.spherical_harmonics(e3nn.Irreps([l + 1]), x, normalize=True, normalization="integral", algorithm=algorithm).array + y1 = e3nn.spherical_harmonics( + e3nn.Irreps([l + 1]), + x, + normalize=True, + normalization="integral", + algorithm=algorithm, + ).array y2 = jnp.einsum( "ijk,i,j->k", e3nn.clebsch_gordan(1, l, l + 1), x, - e3nn.spherical_harmonics(e3nn.Irreps([l]), x, normalize=True, normalization="integral", algorithm=algorithm).array, + e3nn.spherical_harmonics( + e3nn.Irreps([l]), + x, + normalize=True, + normalization="integral", + algorithm=algorithm, + ).array, ) y1 = y1 / jnp.linalg.norm(y1) @@ -116,7 +148,9 @@ def test_recurrence_relation(keys, algorithm, l): @pytest.mark.parametrize("irreps", ["3x1o+2e+2x4e", "2x0e", "10e"]) def test_check_grads(keys, algorithm, irreps, normalization): check_grads( - lambda x: e3nn.spherical_harmonics(irreps, x, normalize=False, normalization=normalization, algorithm=algorithm).array, + lambda x: e3nn.spherical_harmonics( + irreps, x, normalize=False, normalization=normalization, algorithm=algorithm + ).array, (jax.random.normal(keys[0], (10, 3)),), 1, modes=["fwd", "rev"], @@ -129,10 +163,14 @@ def test_check_grads(keys, algorithm, irreps, normalization): def test_normalize(keys, algorithm, l): x = jax.random.normal(keys[0], (10, 3)) y1 = ( - e3nn.spherical_harmonics(e3nn.Irreps([l]), x, normalize=True, algorithm=algorithm).array + e3nn.spherical_harmonics( + 
e3nn.Irreps([l]), x, normalize=True, algorithm=algorithm + ).array * jnp.linalg.norm(x, axis=1, keepdims=True) ** l ) - y2 = e3nn.spherical_harmonics(e3nn.Irreps([l]), x, normalize=False, algorithm=algorithm).array + y2 = e3nn.spherical_harmonics( + e3nn.Irreps([l]), x, normalize=False, algorithm=algorithm + ).array np.testing.assert_allclose(y1, y2, atol=1e-6, rtol=1e-5) diff --git a/e3nn_jax/_src/su2.py b/e3nn_jax/_src/su2.py index 55c05a85..d12faf00 100644 --- a/e3nn_jax/_src/su2.py +++ b/e3nn_jax/_src/su2.py @@ -32,7 +32,9 @@ def su2_clebsch_gordan(j1: float, j2: float, j3: float) -> np.ndarray: for m1 in (x / 2 for x in range(-int(2 * j1), int(2 * j1) + 1, 2)): for m2 in (x / 2 for x in range(-int(2 * j2), int(2 * j2) + 1, 2)): if abs(m1 + m2) <= j3: - mat[int(j1 + m1), int(j2 + m2), int(j3 + m1 + m2)] = _su2_cg((j1, m1), (j2, m2), (j3, m1 + m2)) + mat[int(j1 + m1), int(j2 + m2), int(j3 + m1 + m2)] = _su2_cg( + (j1, m1), (j2, m2), (j3, m1 + m2) + ) return mat / math.sqrt(2 * j3 + 1) @@ -59,7 +61,11 @@ def f(n): C = ( (2.0 * j3 + 1.0) * Fraction( - f(j3 + j1 - j2) * f(j3 - j1 + j2) * f(j1 + j2 - j3) * f(j3 + m3) * f(j3 - m3), + f(j3 + j1 - j2) + * f(j3 - j1 + j2) + * f(j1 + j2 - j3) + * f(j3 + m3) + * f(j3 - m3), f(j1 + j2 + j3 + 1) * f(j1 - m1) * f(j1 + m1) * f(j2 - m2) * f(j2 + m2), ) ) ** 0.5 @@ -67,7 +73,8 @@ def f(n): S = 0 for v in range(vmin, vmax + 1): S += (-1.0) ** (v + j2 + m2) * Fraction( - f(j2 + j3 + m1 - v) * f(j1 - m1 + v), f(v) * f(j3 - j1 + j2 - v) * f(j3 + m3 - v) * f(v + j1 - j2 - m3) + f(j2 + j3 + m1 - v) * f(j1 - m1 + v), + f(v) * f(j3 - j1 + j2 - v) * f(j3 + m3 - v) * f(v + j1 - j2 - m3), ) C = C * S return C diff --git a/e3nn_jax/_src/symmetric_tensor_product_haiku.py b/e3nn_jax/_src/symmetric_tensor_product_haiku.py index 295ad13e..40e24820 100644 --- a/e3nn_jax/_src/symmetric_tensor_product_haiku.py +++ b/e3nn_jax/_src/symmetric_tensor_product_haiku.py @@ -38,7 +38,9 @@ def __init__( self, orders: Tuple[int, ...], keep_irrep_out: Optional[Set[e3nn.Irrep]] = None, - get_parameter: Optional[Callable[[str, Tuple[int, ...], Any], jnp.ndarray]] = None, + get_parameter: Optional[ + Callable[[str, Tuple[int, ...], Any], jnp.ndarray] + ] = None, ): super().__init__() @@ -57,7 +59,9 @@ def __init__( self.keep_irrep_out = keep_irrep_out if get_parameter is None: - get_parameter = lambda name, shape, dtype: hk.get_parameter(name, shape, dtype, hk.initializers.RandomNormal()) + get_parameter = lambda name, shape, dtype: hk.get_parameter( + name, shape, dtype, hk.initializers.RandomNormal() + ) self.get_parameter = get_parameter @@ -80,7 +84,9 @@ def fn(x: e3nn.IrrepsArray): out = dict() for order in range(max(self.orders), 0, -1): # max(orders), ..., 1 - U = e3nn.reduced_symmetric_tensor_product_basis(x.irreps, order, keep_ir=self.keep_irrep_out) + U = e3nn.reduced_symmetric_tensor_product_basis( + x.irreps, order, keep_ir=self.keep_irrep_out + ) # ((w3 x + w2) x + w1) x # \-----------/ @@ -89,7 +95,9 @@ def fn(x: e3nn.IrrepsArray): if order in self.orders: for (mul, ir_out), u in zip(U.irreps, U.list): # u: ndarray [(irreps_x.dim)^order, multiplicity, ir_out.dim] - u = u / u.shape[-2] # normalize both U and the contraction with w + u = ( + u / u.shape[-2] + ) # normalize both U and the contraction with w w = self.get_parameter( # parameters initialized with a normal distribution (variance 1) f"w{order}_{ir_out}", @@ -127,7 +135,10 @@ def fn(x: e3nn.IrrepsArray): # out[irrep_out] : [num_channel, ir_out.dim] irreps_out = e3nn.Irreps(sorted(out.keys())) return 
diff --git a/e3nn_jax/_src/tensor_product_with_spherical_harmonics.py b/e3nn_jax/_src/tensor_product_with_spherical_harmonics.py
index fd04d527..cb1d23b6 100644
--- a/e3nn_jax/_src/tensor_product_with_spherical_harmonics.py
+++ b/e3nn_jax/_src/tensor_product_with_spherical_harmonics.py
@@ -37,7 +37,9 @@ def tensor_product_with_spherical_harmonics(
     input = e3nn.IrrepsArray.as_irreps_array(input)
 
     if not (vector.irreps == "1o" or vector.irreps == "1e"):
-        raise ValueError("tensor_product_with_spherical_harmonics: vector must be a vector.")
+        raise ValueError(
+            "tensor_product_with_spherical_harmonics: vector must be a vector."
+        )
 
     leading_shape = jnp.broadcast_shapes(input.shape[:-1], vector.shape[:-1])
     input = input.broadcast_to(leading_shape + (-1,))
@@ -50,7 +52,9 @@ def tensor_product_with_spherical_harmonics(
     return f(input, vector, degree)
 
 
-def impl(input: e3nn.IrrepsArray, vector: e3nn.IrrepsArray, degree: int) -> e3nn.IrrepsArray:
+def impl(
+    input: e3nn.IrrepsArray, vector: e3nn.IrrepsArray, degree: int
+) -> e3nn.IrrepsArray:
     """
     This implementation looks like a lot of operations, but actually only a few lines are traced by JAX.
     They are indicated by the comment `# <-- ops`.
@@ -64,7 +68,9 @@ def impl(input: e3nn.IrrepsArray, vector: e3nn.IrrepsArray, degree: int) -> e3nn
 
     def fix_gimbal_lock(array, inverse):
         array_rot = array.transform_by_angles(0.0, jnp.pi / 2.0, 0.0, inverse=inverse)
-        return jax.tree_util.tree_map(lambda x_rot, x: jnp.where(gimbal_lock, x_rot, x), array_rot, array)
+        return jax.tree_util.tree_map(
+            lambda x_rot, x: jnp.where(gimbal_lock, x_rot, x), array_rot, array
+        )
 
     input = fix_gimbal_lock(input, inverse=True)  # <-- ops
     vector = fix_gimbal_lock(vector, inverse=True)  # <-- ops
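fix_gimbal_lock above works around the singularity of the Euler-angle parametrization: when the direction is nearly aligned with the y axis, extracting the angles is ill-conditioned, so the computation is pre-rotated by beta = pi/2 and un-rotated at the end, with jnp.where selecting the rotated branch only where needed. A standalone sketch of the same selection pattern (illustrative code, not the module's API):

import jax
import jax.numpy as jnp
import e3nn_jax as e3nn

v = e3nn.IrrepsArray("1o", jnp.array([0.0, 1.0, 0.0]))  # worst case: along +y
gimbal_lock = jnp.abs(v.array[1]) > 0.99
v_rot = v.transform_by_angles(0.0, jnp.pi / 2.0, 0.0, inverse=True)
# keep the pre-rotated value only where the parametrization is unstable
v_safe = jax.tree_util.tree_map(lambda a, b: jnp.where(gimbal_lock, a, b), v_rot, v)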
"none": alpha = 1 else: - raise ValueError(f"irrep_normalization={irrep_normalization}") + raise ValueError( + f"irrep_normalization={irrep_normalization}" + ) else: if irrep_normalization == "component": if ir_out.l == 0: @@ -287,7 +303,9 @@ def tensor_square( elif irrep_normalization == "none": alpha = 1 else: - raise ValueError(f"irrep_normalization={irrep_normalization}") + raise ValueError( + f"irrep_normalization={irrep_normalization}" + ) i_out = len(irreps_out) irreps_out.append((mul, ir_out)) @@ -296,7 +314,10 @@ def tensor_square( irreps_out = Irreps(irreps_out) irreps_out, p, _ = irreps_out.sort() - instructions = [(i_1, i_2, p[i_out], mode, train, alpha) for i_1, i_2, i_out, mode, train, alpha in instructions] + instructions = [ + (i_1, i_2, p[i_out], mode, train, alpha) + for i_1, i_2, i_out, mode, train, alpha in instructions + ] tp = FunctionalTensorProduct( input.irreps, @@ -306,7 +327,9 @@ def tensor_square( irrep_normalization="none", ) - output = naive_broadcast_decorator(partial(tp.left_right, fused=fused, custom_einsum_jvp=custom_einsum_jvp))(input, input) + output = naive_broadcast_decorator( + partial(tp.left_right, fused=fused, custom_einsum_jvp=custom_einsum_jvp) + )(input, input) if regroup_output: output = output.regroup() return output diff --git a/e3nn_jax/_src/tensor_products_test.py b/e3nn_jax/_src/tensor_products_test.py index b146c941..0a6b8b4e 100644 --- a/e3nn_jax/_src/tensor_products_test.py +++ b/e3nn_jax/_src/tensor_products_test.py @@ -60,7 +60,9 @@ def test_tensor_square_normalization(keys): x = e3nn.normal("2x0e + 2x0o + 1o + 1e", keys[1], (10_000,), normalize=True) y = e3nn.tensor_square(x, normalized_input=True, irrep_normalization="norm") - np.testing.assert_allclose(e3nn.mean(e3nn.norm(y, squared=True), axis=0).array, 1.0, rtol=0.1) + np.testing.assert_allclose( + e3nn.mean(e3nn.norm(y, squared=True), axis=0).array, 1.0, rtol=0.1 + ) x = e3nn.normal("2x0e + 2x0o + 1o + 1e", keys[1], (10_000,), normalize=True) y = e3nn.tensor_square(x, normalized_input=True, irrep_normalization="component") @@ -72,7 +74,9 @@ def test_tensor_square_normalization(keys): x = e3nn.normal("2x0e + 2x0o + 1o + 1e", keys[1], (10_000,), normalization="norm") y = e3nn.tensor_square(x, irrep_normalization="norm") - np.testing.assert_allclose(e3nn.mean(e3nn.norm(y, squared=True), axis=0).array, 1.0, rtol=0.1) + np.testing.assert_allclose( + e3nn.mean(e3nn.norm(y, squared=True), axis=0).array, 1.0, rtol=0.1 + ) def test_tensor_square_and_spherical_harmonics(keys): @@ -82,11 +86,15 @@ def test_tensor_square_and_spherical_harmonics(keys): y2 = e3nn.spherical_harmonics("2e", x, normalize=False, normalization="norm") np.testing.assert_allclose(y1.array, y2.array) - y1 = e3nn.tensor_square(x, normalized_input=True, irrep_normalization="component")["2e"] + y1 = e3nn.tensor_square(x, normalized_input=True, irrep_normalization="component")[ + "2e" + ] y2 = e3nn.spherical_harmonics("2e", x, normalize=False, normalization="component") np.testing.assert_allclose(y1.array, y2.array, atol=1e-5) # normalize the input - y1 = e3nn.tensor_square(x / e3nn.norm(x), normalized_input=True, irrep_normalization="component")["2e"] + y1 = e3nn.tensor_square( + x / e3nn.norm(x), normalized_input=True, irrep_normalization="component" + )["2e"] y2 = e3nn.spherical_harmonics("2e", x, normalize=True, normalization="component") np.testing.assert_allclose(y1.array, y2.array, atol=1e-5) diff --git a/e3nn_jax/_src/util/decorators.py b/e3nn_jax/_src/util/decorators.py index 1e8f957e..a12a7fa3 100644 --- 
diff --git a/e3nn_jax/_src/util/decorators.py b/e3nn_jax/_src/util/decorators.py
index 1e8f957e..a12a7fa3 100644
--- a/e3nn_jax/_src/util/decorators.py
+++ b/e3nn_jax/_src/util/decorators.py
@@ -5,25 +5,43 @@
 from e3nn_jax import Irreps, IrrepsArray
 
 
-def overload_for_irreps_without_array(irrepsarray_argnums=None, irrepsarray_argnames=None, shape=()):
+def overload_for_irreps_without_array(
+    irrepsarray_argnums=None, irrepsarray_argnames=None, shape=()
+):
     def decorator(func):
         # TODO: this is very bad to use a function from the internal API
         try:
-            from jax._src.api import _infer_argnums_and_argnames as infer_argnums_and_argnames
+            from jax._src.api import (
+                _infer_argnums_and_argnames as infer_argnums_and_argnames,
+            )
         except ImportError:
             from jax._src.api_util import infer_argnums_and_argnames
 
-        argnums, argnames = infer_argnums_and_argnames(inspect.signature(func), irrepsarray_argnums, irrepsarray_argnames)
+        argnums, argnames = infer_argnums_and_argnames(
+            inspect.signature(func), irrepsarray_argnums, irrepsarray_argnames
+        )
 
         @wraps(func)
         def wrapper(*args, **kwargs):
-            concerned_args = [args[i] for i in argnums if i < len(args)] + [kwargs[k] for k in argnames if k in kwargs]
+            concerned_args = [args[i] for i in argnums if i < len(args)] + [
+                kwargs[k] for k in argnames if k in kwargs
+            ]
             if any(isinstance(arg, (Irreps, str)) for arg in concerned_args):
                 # assume arguments are Irreps (not IrrepsArray)
-                converted_args = {i: IrrepsArray.ones(a, shape) for i, a in enumerate(args) if i in argnums}
-                converted_args.update({k: IrrepsArray.ones(v, shape) for k, v in kwargs.items() if k in argnames})
+                converted_args = {
+                    i: IrrepsArray.ones(a, shape)
+                    for i, a in enumerate(args)
+                    if i in argnums
+                }
+                converted_args.update(
+                    {
+                        k: IrrepsArray.ones(v, shape)
+                        for k, v in kwargs.items()
+                        if k in argnames
+                    }
+                )
 
                 def fn(converted_args):
                     args_ = [converted_args.get(i, a) for i, a in enumerate(args)]
@@ -35,8 +53,12 @@ def fn(converted_args):
                 if isinstance(output, IrrepsArray):
                     return output.irreps
                 if isinstance(output, tuple):
-                    return tuple(o.irreps if isinstance(o, IrrepsArray) else o for o in output)
-                raise TypeError(f"{func.__name__} returned {type(output)} which is not supported by `overload_irrep_no_data`.")
+                    return tuple(
+                        o.irreps if isinstance(o, IrrepsArray) else o for o in output
+                    )
+                raise TypeError(
+                    f"{func.__name__} returned {type(output)} which is not supported by `overload_irrep_no_data`."
+                )
 
             # otherwise, assume arguments are IrrepsArray
             return func(*args, **kwargs)
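overload_for_irreps_without_array lets a function written for IrrepsArray inputs also be called with bare Irreps (or an irreps string): the inputs are replaced by dummy ones-filled arrays, the function runs once, and only the output irreps are returned. A hypothetical usage sketch (the decorated toy function is illustrative, not part of the library):

import e3nn_jax as e3nn
from e3nn_jax._src.util.decorators import overload_for_irreps_without_array

@overload_for_irreps_without_array((0,))
def double(x: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
    return e3nn.concatenate([x, x])

print(double("0e + 1o"))  # called with irreps only: returns the output irreps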
diff --git a/e3nn_jax/_src/util/dtype.py b/e3nn_jax/_src/util/dtype.py
index 4a114712..09bb8f99 100644
--- a/e3nn_jax/_src/util/dtype.py
+++ b/e3nn_jax/_src/util/dtype.py
@@ -8,5 +8,7 @@ def get_pytree_dtype(*args, default_dtype=jnp.float32, real_part=False):
         return default_dtype
 
     if real_part:
-        return jax.eval_shape(lambda xs: sum(jnp.sum(jnp.real(x)) for x in xs), leaves).dtype
+        return jax.eval_shape(
+            lambda xs: sum(jnp.sum(jnp.real(x)) for x in xs), leaves
+        ).dtype
     return jax.eval_shape(lambda xs: sum(jnp.sum(x) for x in xs), leaves).dtype
diff --git a/e3nn_jax/_src/util/math_numpy.py b/e3nn_jax/_src/util/math_numpy.py
index db14beaa..419c35a4 100644
--- a/e3nn_jax/_src/util/math_numpy.py
+++ b/e3nn_jax/_src/util/math_numpy.py
@@ -38,7 +38,12 @@ def as_approx_integer_ratio(x):
 def limit_denominator(n, d, max_denominator=1_000_000):
     # (n, d) = must be normalized
     n0, d0 = n, d
-    p0, q0, p1, q1 = np.zeros_like(n), np.ones_like(n), np.ones_like(n), np.zeros_like(n)
+    p0, q0, p1, q1 = (
+        np.zeros_like(n),
+        np.ones_like(n),
+        np.ones_like(n),
+        np.zeros_like(n),
+    )
     while True:
         a = n // d
         q2 = q0 + a * q1
@@ -72,7 +77,9 @@ def round_to_sqrt_rational(x: np.ndarray, max_denominator=4096) -> np.ndarray:
     """Round a number to the closest number of the form ``sqrt(p)/q`` for ``q <= max_denominator``"""
     x = np.array(x)
     if np.iscomplex(x).any():
-        return _round_to_sqrt_rational(np.real(x), max_denominator) + 1j * _round_to_sqrt_rational(np.imag(x), max_denominator)
+        return _round_to_sqrt_rational(
+            np.real(x), max_denominator
+        ) + 1j * _round_to_sqrt_rational(np.imag(x), max_denominator)
     return _round_to_sqrt_rational(np.real(x), max_denominator)
 
 
@@ -81,7 +88,10 @@ def gram_schmidt(A: np.ndarray, *, epsilon=1e-5, round_fn=lambda x: x) -> np.nda
     Orthogonalize a matrix using the Gram-Schmidt process.
     """
     assert A.ndim == 2, "Gram-Schmidt process only works for matrices."
-    assert A.dtype in [np.float64, np.complex128], "Gram-Schmidt process only works for float64 matrices."
+    assert A.dtype in [
+        np.float64,
+        np.complex128,
+    ], "Gram-Schmidt process only works for float64 matrices."
     Q = []
     for i in range(A.shape[0]):
         v = A[i]
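round_to_sqrt_rational is what keeps change-of-basis matrices exact: floating-point entries close to a value of the form sqrt(p)/q are snapped back to it. A quick worked check, assuming the helper behaves as its docstring states:

import numpy as np
from e3nn_jax._src.util.math_numpy import round_to_sqrt_rational

x = np.sqrt(6) / 4 + 1e-10     # a perturbed Clebsch-Gordan-like entry
y = round_to_sqrt_rational(x)  # snaps back to sqrt(6)/4 (p = 6, q = 4)
assert abs(float(y) - np.sqrt(6) / 4) < 1e-14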
diff --git a/e3nn_jax/_src/util/optimize_jaxpr.py b/e3nn_jax/_src/util/optimize_jaxpr.py
index 677dfac9..6343d92e 100644
--- a/e3nn_jax/_src/util/optimize_jaxpr.py
+++ b/e3nn_jax/_src/util/optimize_jaxpr.py
@@ -14,7 +14,9 @@ def curry(f):
 
 @curry
 @curry
-def closed_jaxpr_transform_to_fn_transform(closed_jaxpr_transform, fn, *args):  # pragma: no cover
+def closed_jaxpr_transform_to_fn_transform(
+    closed_jaxpr_transform, fn, *args
+):  # pragma: no cover
     f = lu.wrap_init(fn)
 
     in_flat, in_tree = jax.tree_util.tree_flatten(args)
@@ -44,7 +46,9 @@ def re(vars):
     )
 
 
-def remove_deadcode(jaxpr: Jaxpr, output_indices=None) -> ClosedJaxpr:  # pragma: no cover
+def remove_deadcode(
+    jaxpr: Jaxpr, output_indices=None
+) -> ClosedJaxpr:  # pragma: no cover
     if output_indices is None:
         output_indices = range(len(jaxpr.outvars))
 
@@ -58,13 +62,16 @@ def remove_deadcode(jaxpr: Jaxpr, output_indices=None) -> ClosedJaxpr:  # pragma
         if eqn.primitive in [xla_call_p]:
             xla_call_jaxpr = eqn.params["call_jaxpr"]
             xla_call_jaxpr, input_indices = remove_deadcode(
-                xla_call_jaxpr, [i for i, out in enumerate(eqn.outvars) if out in needed]
+                xla_call_jaxpr,
+                [i for i, out in enumerate(eqn.outvars) if out in needed],
             )
             eqn = eqn.replace(
                 params={
                     **eqn.params,
                     "call_jaxpr": xla_call_jaxpr,
-                    "donated_invars": tuple(eqn.params["donated_invars"][i] for i in input_indices),
+                    "donated_invars": tuple(
+                        eqn.params["donated_invars"][i] for i in input_indices
+                    ),
                 },
                 outvars=[out for out in eqn.outvars if out in needed],
                 invars=[inv for i, inv in enumerate(eqn.invars) if i in input_indices],
@@ -84,7 +91,9 @@ def remove_deadcode(jaxpr: Jaxpr, output_indices=None) -> ClosedJaxpr:  # pragma
     return (jaxpr, input_indices)
 
 
-def remove_duplicate_constants(closed_jaxpr: ClosedJaxpr) -> ClosedJaxpr:  # pragma: no cover
+def remove_duplicate_constants(
+    closed_jaxpr: ClosedJaxpr,
+) -> ClosedJaxpr:  # pragma: no cover
     for i, cst1 in enumerate(closed_jaxpr.consts):
         for j, cst2 in enumerate(closed_jaxpr.consts[:i]):
             if type(cst1) is np.ndarray and type(cst2) is np.ndarray:
@@ -99,7 +108,10 @@ def remove_duplicate_constants(closed_jaxpr: ClosedJaxpr) -> ClosedJaxpr:  # pra
     closed_jaxpr.jaxpr = closed_jaxpr.jaxpr.replace(
         eqns=[
             eqn.replace(
-                params={k: remove_duplicate_constants(v) if type(v) is ClosedJaxpr else v for k, v in eqn.params.items()}
+                params={
+                    k: remove_duplicate_constants(v) if type(v) is ClosedJaxpr else v
+                    for k, v in eqn.params.items()
+                }
             )
             for eqn in closed_jaxpr.jaxpr.eqns
         ]
@@ -108,7 +120,9 @@ def remove_duplicate_constants(closed_jaxpr: ClosedJaxpr) -> ClosedJaxpr:  # pra
     return closed_jaxpr
 
 
-def remove_duplicate_equations(jaxpr: Jaxpr, skip_first=0) -> ClosedJaxpr:  # pragma: no cover
+def remove_duplicate_equations(
+    jaxpr: Jaxpr, skip_first=0
+) -> ClosedJaxpr:  # pragma: no cover
     def atom_key(a: Atom):
         if type(a) is Literal:
             return a.val
@@ -146,7 +160,12 @@ def param_key(p: Any):
     # apply recursively
     jaxpr = jaxpr.replace(
         eqns=[
-            eqn.replace(params={k: remove_duplicate_equations(v) if type(v) is Jaxpr else v for k, v in eqn.params.items()})
+            eqn.replace(
+                params={
+                    k: remove_duplicate_equations(v) if type(v) is Jaxpr else v
+                    for k, v in eqn.params.items()
+                }
+            )
             for eqn in jaxpr.eqns
         ]
     )
@@ -154,7 +173,9 @@ def param_key(p: Any):
         eqns=[
             eqn.replace(
                 params={
-                    k: v.replace(jaxpr=remove_duplicate_equations(v.jaxpr)) if type(v) is ClosedJaxpr else v
+                    k: v.replace(jaxpr=remove_duplicate_equations(v.jaxpr))
+                    if type(v) is ClosedJaxpr
+                    else v
                     for k, v in eqn.params.items()
                 }
             )
diff --git a/e3nn_jax/_src/util/test.py b/e3nn_jax/_src/util/test.py
index 6606941a..6aaf41f6 100644
--- a/e3nn_jax/_src/util/test.py
+++ b/e3nn_jax/_src/util/test.py
@@ -55,7 +55,12 @@ def assert_output_dtype_matches_input_dtype(fun: Callable, *args, **kwargs):
         raise ValueError("This test requires jax_enable_x64=True")
 
     dtype = get_pytree_dtype(args, kwargs, real_part=True)
-    assert get_pytree_dtype(jax.eval_shape(fun, *args, **kwargs), default_dtype=dtype, real_part=True) == dtype
+    assert (
+        get_pytree_dtype(
+            jax.eval_shape(fun, *args, **kwargs), default_dtype=dtype, real_part=True
+        )
+        == dtype
+    )
 
     def astype(x, dtype):
         if x.dtype.kind == "f":
@@ -73,4 +78,6 @@ def astype(x, dtype):
         in_dtype = jax.tree_util.tree_map(lambda x: x.dtype, args)
         out_dtype = jax.tree_util.tree_map(lambda x: x.dtype, out)
 
-        raise AssertionError(f"Expected {dtype} -> {dtype}. Got {in_dtype} -> {out_dtype}")
+        raise AssertionError(
+            f"Expected {dtype} -> {dtype}. Got {in_dtype} -> {out_dtype}"
+        )
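The helpers in this file are the library's public test utilities (re-exported from e3nn_jax.util further down in this diff). A usage sketch of the equivariance check, in the same style as the experimental tests below:

import jax
import e3nn_jax as e3nn
from e3nn_jax.util import assert_equivariant

x = e3nn.normal("1o + 2e", jax.random.PRNGKey(1), (8,))
# applies a random rotation to inputs and outputs and compares the results
assert_equivariant(e3nn.tensor_square, jax.random.PRNGKey(0), args_in=[x], atol=1e-5)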
diff --git a/e3nn_jax/experimental/linear_shtp.py b/e3nn_jax/experimental/linear_shtp.py
index 807f3287..a76af86c 100644
--- a/e3nn_jax/experimental/linear_shtp.py
+++ b/e3nn_jax/experimental/linear_shtp.py
@@ -35,7 +35,9 @@ class LinearSHTP(flax.linen.Module):
     mix: bool = True
 
     @flax.linen.compact
-    def __call__(self, input: e3nn.IrrepsArray, direction: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
+    def __call__(
+        self, input: e3nn.IrrepsArray, direction: e3nn.IrrepsArray
+    ) -> e3nn.IrrepsArray:
         assert input.shape == (input.irreps.dim,)
         assert direction.shape == (3,)
 
@@ -49,8 +51,12 @@ def __call__(self, input: e3nn.IrrepsArray, direction: e3nn.IrrepsArray) -> e3nn
         gimbal_lock = jnp.abs(direction.array[1]) > 0.99
 
         def fix_gimbal_lock(array, inverse):
-            array_rot = array.transform_by_angles(0.0, jnp.pi / 2.0, 0.0, inverse=inverse)
-            return jax.tree_util.tree_map(lambda x_rot, x: jnp.where(gimbal_lock, x_rot, x), array_rot, array)
+            array_rot = array.transform_by_angles(
+                0.0, jnp.pi / 2.0, 0.0, inverse=inverse
+            )
+            return jax.tree_util.tree_map(
+                lambda x_rot, x: jnp.where(gimbal_lock, x_rot, x), array_rot, array
+            )
 
         input = fix_gimbal_lock(input, inverse=True)
         direction = fix_gimbal_lock(direction, inverse=True)
@@ -138,7 +144,9 @@ def fix_gimbal_lock(array, inverse):
                 outputs.append(z)
 
             if self.mix:
-                z = sum_tensors(zs, (mulz, irz.dim), empty_return_none=True, dtype=x.dtype)
+                z = sum_tensors(
+                    zs, (mulz, irz.dim), empty_return_none=True, dtype=x.dtype
+                )
                 if z is not None:
                     z = z / jnp.sqrt(dim)
                     irreps_out_.append((z.shape[0], irz))
@@ -156,7 +164,11 @@ def fix_gimbal_lock(array, inverse):
     return out
 
 
-def shtp(input: e3nn.IrrepsArray, direction: e3nn.IrrepsArray, filter_irreps_out: Sequence[e3nn.Irrep]) -> e3nn.IrrepsArray:
+def shtp(
+    input: e3nn.IrrepsArray,
+    direction: e3nn.IrrepsArray,
+    filter_irreps_out: Sequence[e3nn.Irrep],
+) -> e3nn.IrrepsArray:
     assert input.shape == (input.irreps.dim,)
     assert direction.shape == (3,)
 
@@ -171,7 +183,9 @@ def shtp(input: e3nn.IrrepsArray, direction: e3nn.IrrepsArray, filter_irreps_out
     def fix_gimbal_lock(array, inverse):
         array_rot = array.transform_by_angles(0.0, jnp.pi / 2.0, 0.0, inverse=inverse)
-        return jax.tree_util.tree_map(lambda x_rot, x: jnp.where(gimbal_lock, x_rot, x), array_rot, array)
+        return jax.tree_util.tree_map(
+            lambda x_rot, x: jnp.where(gimbal_lock, x_rot, x), array_rot, array
+        )
 
     input = fix_gimbal_lock(input, inverse=True)
     direction = fix_gimbal_lock(direction, inverse=True)
@@ -204,8 +218,12 @@ def fix_gimbal_lock(array, inverse):
                 ly = (irx.l + irz.l) % 2
                 if ird.p**ly == py:
                     zeros = jnp.zeros_like(x, shape=(mulx, l + 1, irz.dim))
-                    z = zeros.at[:, jnp.arange(l + 1), jnp.arange(irz.l, irz.l + l + 1)].set(x[:, l:])
-                    z = z.at[:, jnp.arange(l, 0, -1), jnp.arange(irz.l - l, irz.l)].set(x[:, :l])
+                    z = zeros.at[
+                        :, jnp.arange(l + 1), jnp.arange(irz.l, irz.l + l + 1)
+                    ].set(x[:, l:])
+                    z = z.at[:, jnp.arange(l, 0, -1), jnp.arange(irz.l - l, irz.l)].set(
+                        x[:, :l]
+                    )
                     z = jnp.reshape(z, (mulx * (l + 1), irz.dim))
                     irreps_out.append((z.shape[0], irz))
@@ -215,8 +233,12 @@ def fix_gimbal_lock(array, inverse):
                 ly = (irx.l + irz.l + 1) % 2
                 if ird.p**ly == py and l > 0:
                     zeros = jnp.zeros_like(x, shape=(mulx, l, irz.dim))
-                    z = zeros.at[:, jnp.arange(l), jnp.arange(irz.l - 1, irz.l - l - 1, -1)].set(x[:, l + 1 :])
-                    z = z.at[:, jnp.arange(l), jnp.arange(irz.l + 1, irz.l + l + 1)].set(-x[:, :l][:, ::-1])
+                    z = zeros.at[
+                        :, jnp.arange(l), jnp.arange(irz.l - 1, irz.l - l - 1, -1)
+                    ].set(x[:, l + 1 :])
+                    z = z.at[:, jnp.arange(l), jnp.arange(irz.l + 1, irz.l + l + 1)].set(
+                        -x[:, :l][:, ::-1]
+                    )
                     z = jnp.reshape(z, (mulx * l, irz.dim))
                     irreps_out.append((z.shape[0], irz))
diff --git a/e3nn_jax/experimental/linear_shtp_test.py b/e3nn_jax/experimental/linear_shtp_test.py
index 019c3de7..48ead01b 100644
--- a/e3nn_jax/experimental/linear_shtp_test.py
+++ b/e3nn_jax/experimental/linear_shtp_test.py
@@ -19,11 +19,18 @@ def test_equivariance_linear(keys, mix: bool):
     def f(x, d):
         return conv.apply(w, x, d)
 
-    z1, z2 = equivariance_test(f, next(keys), x, e3nn.IrrepsArray("1o", jnp.array([0.1, -0.2, -0.4])))
+    z1, z2 = equivariance_test(
+        f, next(keys), x, e3nn.IrrepsArray("1o", jnp.array([0.1, -0.2, -0.4]))
+    )
     np.testing.assert_allclose(z1.array, z2.array, atol=1e-5)
 
     # Test Gimbal Lock
-    z1, z2 = equivariance_test(e3nn.grad(f, 1), next(keys), x, e3nn.IrrepsArray("1o", jnp.array([0.0, 1.0, 0.0])))
+    z1, z2 = equivariance_test(
+        e3nn.grad(f, 1),
+        next(keys),
+        x,
+        e3nn.IrrepsArray("1o", jnp.array([0.0, 1.0, 0.0])),
+    )
     np.testing.assert_allclose(z1.array, z2.array, atol=1e-5)
 
     if mix:
@@ -36,9 +43,16 @@ def test_equivariance(keys):
     def f(x, d):
         return shtp(x, d, "0e + 0o + 4x1o + 1e + 2e + 2o")
 
-    z1, z2 = equivariance_test(f, next(keys), x, e3nn.IrrepsArray("1o", jnp.array([0.1, -0.2, -0.4])))
+    z1, z2 = equivariance_test(
+        f, next(keys), x, e3nn.IrrepsArray("1o", jnp.array([0.1, -0.2, -0.4]))
+    )
     np.testing.assert_allclose(z1.array, z2.array, atol=1e-5)
 
     # Test Gimbal Lock
-    z1, z2 = equivariance_test(e3nn.grad(f, 1), next(keys), x, e3nn.IrrepsArray("1o", jnp.array([0.0, 1.0, 0.0])))
+    z1, z2 = equivariance_test(
+        e3nn.grad(f, 1),
+        next(keys),
+        x,
+        e3nn.IrrepsArray("1o", jnp.array([0.0, 1.0, 0.0])),
+    )
     np.testing.assert_allclose(z1.array, z2.array, atol=1e-5)
diff --git a/e3nn_jax/experimental/point_convolution.py b/e3nn_jax/experimental/point_convolution.py
index 64cfd943..e6ac95fc 100644
--- a/e3nn_jax/experimental/point_convolution.py
+++ b/e3nn_jax/experimental/point_convolution.py
@@ -52,13 +52,17 @@ def radial_basis(r, cutoff, num_radial_basis):
     return e3nn.bessel(r, num_radial_basis) * e3nn.soft_envelope(r)[:, None]
 
 
-def _call(self, positions, node_feats, senders, receivers, Linear, MultiLayerPerceptron):
+def _call(
+    self, positions, node_feats, senders, receivers, Linear, MultiLayerPerceptron
+):
     if not isinstance(positions, e3nn.IrrepsArray):
         raise TypeError(
             f"positions must be an e3nn.IrrepsArray with shape (n_nodes, 3) and irreps '1o' or '1e'. Got {type(positions)}"
         )
     if not isinstance(node_feats, e3nn.IrrepsArray):
-        raise TypeError(f"node_feats must be an e3nn.IrrepsArray with shape (n_nodes, irreps). Got {type(node_feats)}")
+        raise TypeError(
+            f"node_feats must be an e3nn.IrrepsArray with shape (n_nodes, irreps). Got {type(node_feats)}"
+        )
 
     assert positions.ndim == 2
     assert node_feats.ndim == 2
@@ -97,7 +101,9 @@ def _call(self, positions, node_feats, senders, receivers, Linear, MultiLayerPer
     messages = messages * mix  # [n_edges, irreps]
 
-    zeros = e3nn.IrrepsArray.zeros(messages.irreps, node_feats.shape[:1], messages.dtype)
+    zeros = e3nn.IrrepsArray.zeros(
+        messages.irreps, node_feats.shape[:1], messages.dtype
+    )
     node_feats = zeros.at[receivers].add(messages)  # [n_nodes, irreps]
 
     node_feats = node_feats / jnp.sqrt(self.avg_num_neighbors)
diff --git a/e3nn_jax/experimental/point_convolution_test.py b/e3nn_jax/experimental/point_convolution_test.py
index 7602ae1a..c092cba1 100644
--- a/e3nn_jax/experimental/point_convolution_test.py
+++ b/e3nn_jax/experimental/point_convolution_test.py
@@ -3,7 +3,10 @@
 import jax.numpy as jnp
 
 import e3nn_jax as e3nn
-from e3nn_jax.experimental.point_convolution import MessagePassingConvolutionHaiku, radial_basis
+from e3nn_jax.experimental.point_convolution import (
+    MessagePassingConvolutionHaiku,
+    radial_basis,
+)
 from e3nn_jax.util import assert_equivariant, assert_output_dtype_matches_input_dtype
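radial_basis above embeds each edge length into num_radial_basis smooth features: Bessel functions windowed by soft_envelope, so every feature decays to zero at the cutoff. A standalone sketch using only the public API; the broadcast over the last axis is the point to notice:

import jax.numpy as jnp
import e3nn_jax as e3nn

r = jnp.linspace(0.1, 1.9, 7)  # 7 edge lengths
emb = e3nn.bessel(r, 8) * e3nn.soft_envelope(r)[:, None]
print(emb.shape)  # (7, 8): one 8-dimensional radial embedding per edge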
diff --git a/e3nn_jax/experimental/transformer.py b/e3nn_jax/experimental/transformer.py
index 786acd05..d649e371 100644
--- a/e3nn_jax/experimental/transformer.py
+++ b/e3nn_jax/experimental/transformer.py
@@ -48,9 +48,16 @@ def __call__(
         """
 
         def f(x, y, filter_ir_out=None, name=None):
-            out1 = e3nn.concatenate([x, e3nn.tensor_product(x, y.filter(drop="0e"))]).regroup().filter(keep=filter_ir_out)
+            out1 = (
+                e3nn.concatenate([x, e3nn.tensor_product(x, y.filter(drop="0e"))])
+                .regroup()
+                .filter(keep=filter_ir_out)
+            )
             out2 = e3nn.haiku.MultiLayerPerceptron(
-                self.list_neurons + [out1.irreps.num_irreps], self.act, output_activation=False, name=name
+                self.list_neurons + [out1.irreps.num_irreps],
+                self.act,
+                output_activation=False,
+                name=name,
             )(y.filter(keep="0e"))
             return out1 * out2
 
@@ -59,15 +66,25 @@ def f(x, y, filter_ir_out=None, name=None):
             e3nn.tensor_product(node_feat[edge_dst], edge_key, filter_ir_out="0e")
         ).array  # [E, H]
         node_logit_max = _index_max(edge_dst, edge_logit, node_feat.shape[0])  # [N, H]
-        exp = edge_weight_cutoff[:, None] * jnp.exp(edge_logit - node_logit_max[edge_dst])  # [E, H]
-        z = e3nn.scatter_sum(exp, dst=edge_dst, output_size=node_feat.shape[0])  # [N, H]
+        exp = edge_weight_cutoff[:, None] * jnp.exp(
+            edge_logit - node_logit_max[edge_dst]
+        )  # [E, H]
+        z = e3nn.scatter_sum(
+            exp, dst=edge_dst, output_size=node_feat.shape[0]
+        )  # [N, H]
         z = jnp.where(z == 0.0, 1.0, z)
         alpha = exp / z[edge_dst]  # [E, H]
 
-        edge_v = f(node_feat[edge_src], edge_attr, self.irreps_node_output, "mlp_val")  # [E, D]
+        edge_v = f(
+            node_feat[edge_src], edge_attr, self.irreps_node_output, "mlp_val"
+        )  # [E, D]
         edge_v = edge_v.mul_to_axis(self.num_heads)  # [E, H, D]
         edge_v = edge_v * jnp.sqrt(jax.nn.relu(alpha))[:, :, None]  # [E, H, D]
         edge_v = edge_v.axis_to_mul()  # [E, D]
 
-        node_out = e3nn.scatter_sum(edge_v, dst=edge_dst, output_size=node_feat.shape[0])  # [N, D]
-        return e3nn.haiku.Linear(self.irreps_node_output, name="linear_out")(node_out)  # [N, D]
+        node_out = e3nn.scatter_sum(
+            edge_v, dst=edge_dst, output_size=node_feat.shape[0]
+        )  # [N, D]
+        return e3nn.haiku.Linear(self.irreps_node_output, name="linear_out")(
+            node_out
+        )  # [N, D]
diff --git a/e3nn_jax/experimental/transformer_test.py b/e3nn_jax/experimental/transformer_test.py
index ed3ed145..e78cd709 100644
--- a/e3nn_jax/experimental/transformer_test.py
+++ b/e3nn_jax/experimental/transformer_test.py
@@ -14,7 +14,14 @@ def model(pos, src, dst, node_feat):
         edge_weight_cutoff = e3nn.sus(3.0 * (2.0 - edge_distance))
         edge_attr = e3nn.concatenate(
             [
-                e3nn.soft_one_hot_linspace(edge_distance, start=0.0, end=2.0, number=5, basis="smooth_finite", cutoff=True),
+                e3nn.soft_one_hot_linspace(
+                    edge_distance,
+                    start=0.0,
+                    end=2.0,
+                    number=5,
+                    basis="smooth_finite",
+                    cutoff=True,
+                ),
                 e3nn.spherical_harmonics("1e + 2e", pos[dst] - pos[src], True),
             ]
         )
@@ -43,5 +50,8 @@ def model(pos, src, dst, node_feat):
     apply(w, pos, src, dst, node_feat)
 
     assert_equivariant(
-        lambda pos, node_feat: apply(w, pos, src, dst, node_feat), jax.random.PRNGKey(0), args_in=[pos, node_feat], atol=1e-4
+        lambda pos, node_feat: apply(w, pos, src, dst, node_feat),
+        jax.random.PRNGKey(0),
+        args_in=[pos, node_feat],
+        atol=1e-4,
     )
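The attention in the transformer block above is a numerically stabilized segment softmax over incoming edges, with the smooth radial cutoff folded into the exponentials so that attention vanishes at the cutoff. A minimal sketch of the pattern with plain arrays (the variable names are illustrative, not the module's API):

import jax.numpy as jnp
import e3nn_jax as e3nn

logit = jnp.array([1.0, 2.0, 0.5])   # one logit per edge
cutoff = jnp.array([1.0, 0.5, 0.0])  # edge_weight_cutoff
dst = jnp.array([0, 0, 1])           # receiving node of each edge
num_nodes = 2
m = jnp.full((num_nodes,), -jnp.inf).at[dst].max(logit)  # per-node max logit
exp = cutoff * jnp.exp(logit - m[dst])
z = e3nn.scatter_sum(exp, dst=dst, output_size=num_nodes)
z = jnp.where(z == 0.0, 1.0, z)      # guard nodes whose edges are all cut off
alpha = exp / z[dst]                 # attention weights, summing to 1 per node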
diff --git a/e3nn_jax/experimental/voxel_convolution.py b/e3nn_jax/experimental/voxel_convolution.py
index e787186b..72dfacf2 100644
--- a/e3nn_jax/experimental/voxel_convolution.py
+++ b/e3nn_jax/experimental/voxel_convolution.py
@@ -61,11 +61,18 @@ def kernel(
         """
 
         def _get_params(name: str, shape: Tuple[int, ...], weight_std: float):
-            return hk.get_parameter(name, shape=shape, init=hk.initializers.RandomNormal(weight_std), dtype=dtype)
+            return hk.get_parameter(
+                name,
+                shape=shape,
+                init=hk.initializers.RandomNormal(weight_std),
+                dtype=dtype,
+            )
 
         return _kernel(self, irreps_in, irreps_out, steps, _get_params, dtype)
 
-    def __call__(self, input: e3nn.IrrepsArray, steps: Optional[jnp.ndarray] = None) -> e3nn.IrrepsArray:
+    def __call__(
+        self, input: e3nn.IrrepsArray, steps: Optional[jnp.ndarray] = None
+    ) -> e3nn.IrrepsArray:
         r"""Evaluate the convolution.
 
         Args:
@@ -95,12 +102,16 @@ def kernel(
         dtype: jnp.dtype = jnp.float32,
     ) -> jnp.ndarray:
         def _get_params(name: str, shape: Tuple[int, ...], weight_std: float):
-            return self.param(name, flax.linen.initializers.normal(stddev=weight_std), shape, dtype)
+            return self.param(
+                name, flax.linen.initializers.normal(stddev=weight_std), shape, dtype
+            )
 
         return _kernel(self, irreps_in, irreps_out, steps, _get_params, dtype)
 
     @flax.linen.compact
-    def __call__(self, input: e3nn.IrrepsArray, steps: Optional[jnp.ndarray] = None) -> e3nn.IrrepsArray:
+    def __call__(
+        self, input: e3nn.IrrepsArray, steps: Optional[jnp.ndarray] = None
+    ) -> e3nn.IrrepsArray:
         return _call(self, input, steps)
 
 
@@ -122,8 +133,16 @@ def _tp_weight(
     weight_std: float,
     get_parameter,
 ) -> jnp.ndarray:
-    number = self.num_radial_basis if isinstance(self.num_radial_basis, int) else self.num_radial_basis[ir_sh.l]
-    start = self.relative_starts if isinstance(self.relative_starts, (float, int)) else self.relative_starts[ir_sh.l]
+    number = (
+        self.num_radial_basis
+        if isinstance(self.num_radial_basis, int)
+        else self.num_radial_basis[ir_sh.l]
+    )
+    start = (
+        self.relative_starts
+        if isinstance(self.relative_starts, (float, int))
+        else self.relative_starts[ir_sh.l]
+    )
 
     embedding = e3nn.soft_one_hot_linspace(
         jnp.linalg.norm(lattice, ord=2, axis=-1),
@@ -135,7 +154,11 @@ def _tp_weight(
         end_zero=True,
     )  # [x, y, z, number]
 
-    w = get_parameter(f"w[{i_in},{i_sh},{i_out}] {mul_ir_in},{ir_sh},{mul_ir_out}", (number,) + path_shape, weight_std)
+    w = get_parameter(
+        f"w[{i_in},{i_sh},{i_out}] {mul_ir_in},{ir_sh},{mul_ir_out}",
+        (number,) + path_shape,
+        weight_std,
+    )
 
     return jnp.einsum("xyzk,k...->xyz...", embedding, w) / (
         lattice.shape[0] * lattice.shape[1] * lattice.shape[2]
@@ -168,7 +191,9 @@ def _kernel(
     lattice = lattice.astype(dtype)
 
     # convolution kernel
-    tp = e3nn.FunctionalFullyConnectedTensorProduct(irreps_in, self.irreps_sh, irreps_out)
+    tp = e3nn.FunctionalFullyConnectedTensorProduct(
+        irreps_in, self.irreps_sh, irreps_out
+    )
 
     ws = [
         _tp_weight(
@@ -187,7 +212,9 @@ def _kernel(
         for i in tp.instructions
     ]
 
-    sh = e3nn.spherical_harmonics(irreps_out=self.irreps_sh, input=lattice, normalize=True)  # [x, y, z, irreps_sh.dim]
+    sh = e3nn.spherical_harmonics(
+        irreps_out=self.irreps_sh, input=lattice, normalize=True
+    )  # [x, y, z, irreps_sh.dim]
 
     tp_right = tp.right
     for _ in range(3):
@@ -210,7 +237,9 @@ def _kernel(
 
 
 def _call(
-    self: Union[ConvolutionHaiku, ConvolutionFlax], input: e3nn.IrrepsArray, steps: Optional[jnp.ndarray] = None
+    self: Union[ConvolutionHaiku, ConvolutionFlax],
+    input: e3nn.IrrepsArray,
+    steps: Optional[jnp.ndarray] = None,
 ) -> e3nn.IrrepsArray:
     if not isinstance(input, e3nn.IrrepsArray):
         raise ValueError("Convolution: input should be of type IrrepsArray")
@@ -221,7 +250,11 @@ def _call(
         [
             (mul, ir)
             for (mul, ir) in e3nn.Irreps(self.irreps_out)
-            if any(ir in ir_in * ir_sh for _, ir_in in input.irreps for _, ir_sh in e3nn.Irreps(self.irreps_sh))
+            if any(
+                ir in ir_in * ir_sh
+                for _, ir_in in input.irreps
+                for _, ir_sh in e3nn.Irreps(self.irreps_sh)
+            )
         ]
     )
 
@@ -245,6 +278,8 @@ def _call(
             i += 1
         else:
             list.append(None)
-    output = e3nn.IrrepsArray.from_list(self.irreps_out, list, output.shape[:-1], output.dtype)
+    output = e3nn.IrrepsArray.from_list(
+        self.irreps_out, list, output.shape[:-1], output.dtype
+    )
 
     return output
diff --git a/e3nn_jax/experimental/voxel_convolution_test.py b/e3nn_jax/experimental/voxel_convolution_test.py
index e46803bc..9c73165e 100644
--- a/e3nn_jax/experimental/voxel_convolution_test.py
+++ b/e3nn_jax/experimental/voxel_convolution_test.py
@@ -25,7 +25,12 @@ def test_convolution(keys):
     f = jax.jit(c.apply)
 
     x0 = e3nn.normal(irreps_in, next(keys), (3, 8, 8, 8))
-    x0 = jax.tree_util.tree_map(lambda x: jnp.pad(x, ((0, 0), (4, 4), (4, 4), (4, 4)) + ((0, 0),) * (x.ndim - 4)), x0)
+    x0 = jax.tree_util.tree_map(
+        lambda x: jnp.pad(
+            x, ((0, 0), (4, 4), (4, 4), (4, 4)) + ((0, 0),) * (x.ndim - 4)
+        ),
+        x0,
+    )
 
     w = c.init(next(keys), x0, jnp.array([1.0, 1.0, 1.0]))
 
@@ -60,7 +65,12 @@ def test_convolution_defaults(keys):
     f = jax.jit(c.apply)
 
     x0 = e3nn.normal(irreps_in, next(keys), (3, 8, 8, 8))
-    x0 = jax.tree_util.tree_map(lambda x: jnp.pad(x, ((0, 0), (4, 4), (4, 4), (4, 4)) + ((0, 0),) * (x.ndim - 4)), x0)
+    x0 = jax.tree_util.tree_map(
+        lambda x: jnp.pad(
+            x, ((0, 0), (4, 4), (4, 4), (4, 4)) + ((0, 0),) * (x.ndim - 4)
+        ),
+        x0,
+    )
 
     w = c.init(next(keys), x0)
     y0 = f(w, x0)
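_call above keeps only the output irreps reachable from the input through the spherical-harmonic kernel, via the selection rule ir in ir_in * ir_sh: the product of two irreps runs over |l1 - l2| <= l <= l1 + l2 with parity p1 * p2. A quick check of that rule:

import e3nn_jax as e3nn

print(list(e3nn.Irrep("1o") * e3nn.Irrep("1e")))  # [0o, 1o, 2o]
print(e3nn.Irrep("2e") in e3nn.Irrep("1o") * e3nn.Irrep("1e"))  # False: wrong parity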
diff --git a/e3nn_jax/experimental/voxel_pooling.py b/e3nn_jax/experimental/voxel_pooling.py
index 5d858395..d3173887 100644
--- a/e3nn_jax/experimental/voxel_pooling.py
+++ b/e3nn_jax/experimental/voxel_pooling.py
@@ -5,7 +5,9 @@
 import jax.numpy as jnp
 
 
-def interpolate_trilinear(input: jnp.ndarray, x: float, y: float, z: float) -> jnp.ndarray:
+def interpolate_trilinear(
+    input: jnp.ndarray, x: float, y: float, z: float
+) -> jnp.ndarray:
     r"""Interpolate voxels in coordinate (x, y, z).
 
     Args:
@@ -54,7 +56,9 @@ def zclip(z):
     return wa * Ia + wb * Ib + wc * Ic + wd * Id + we * Ie + wf * If + wg * Ig + wh * Ih
 
 
-def interpolate_nearest(input: jnp.ndarray, x: float, y: float, z: float) -> jnp.ndarray:
+def interpolate_nearest(
+    input: jnp.ndarray, x: float, y: float, z: float
+) -> jnp.ndarray:
     r"""Interpolate voxels in coordinate (x, y, z).
 
     Args:
@@ -106,7 +110,9 @@ def f(n_src, n_dst):
         if interpolation == "nearest":
             interp = interpolate_nearest
 
-        output = jax.vmap(interp, (None, 0, 0, 0), -1)(input, xg.flatten(), yg.flatten(), zg.flatten())
+        output = jax.vmap(interp, (None, 0, 0, 0), -1)(
+            input, xg.flatten(), yg.flatten(), zg.flatten()
+        )
         output = output.reshape(*input.shape[:-3], *output_size)
         return output
 
@@ -136,7 +142,11 @@ def zoom(
     if isinstance(resize_rate, (float, int)):
         resize_rate = (resize_rate,) * 3
 
-    output_size = (round(nx * resize_rate[0]), round(ny * resize_rate[1]), round(nz * resize_rate[2]))
+    output_size = (
+        round(nx * resize_rate[0]),
+        round(ny * resize_rate[1]),
+        round(nz * resize_rate[2]),
+    )
 
     assert isinstance(output_size, tuple)
diff --git a/e3nn_jax/util.py b/e3nn_jax/util.py
index edb12886..c8e77b1a 100644
--- a/e3nn_jax/util.py
+++ b/e3nn_jax/util.py
@@ -1,3 +1,11 @@
-from e3nn_jax._src.util.test import assert_equivariant, equivariance_test, assert_output_dtype_matches_input_dtype
+from e3nn_jax._src.util.test import (
+    assert_equivariant,
+    equivariance_test,
+    assert_output_dtype_matches_input_dtype,
+)
 
-__all__ = ["assert_equivariant", "equivariance_test", "assert_output_dtype_matches_input_dtype"]
+__all__ = [
+    "assert_equivariant",
+    "equivariance_test",
+    "assert_output_dtype_matches_input_dtype",
+]
diff --git a/pyproject.toml b/pyproject.toml
index 6f115045..37f0a801 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,5 @@
 [tool.black]
-line-length = 127
-target-version = ['py37']
+target-version = ['py311']
 include = '\.pyi?$'
 exclude = '''
 /(