Skip to content

Commit

Permalink
added wigner_d
Browse files Browse the repository at this point in the history
  • Loading branch information
mitkotak committed Mar 27, 2024
1 parent f103d88 commit 0ee1f01
Show file tree
Hide file tree
Showing 7 changed files with 1,994 additions and 0 deletions.
1 change: 1 addition & 0 deletions BasisLib/so3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
from .common import _cartesian_permutation_for_degree
from .normalization import normalization_constant
from .rotations import random_rotation
from .rotations import wigner_d
1,608 changes: 1,608 additions & 0 deletions BasisLib/so3/_wigner_d_lut.npz

Large diffs are not rendered by default.

40 changes: 40 additions & 0 deletions BasisLib/so3/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,22 @@ def _cartesian_permutation(
return p


def _cartesian_permutation_wigner_d_entries(
max_degree: int,
) -> np.array:
"""Generates a permutation to Cartesian order for Wigner-D matrix entries."""
_check_degree_is_positive_or_zero(max_degree)
permutations = []
offset = 0
for l in range(max_degree + 1):
pvec = _cartesian_permutation_for_degree(l)
num = pvec.size
pmat = np.arange(num * num, dtype=pvec.dtype).reshape(num, num)
pmat = np.reshape(pmat[pvec, :][:, pvec], -1)
permutations.append(pmat + offset)
offset += pmat.size
return np.concatenate(permutations)

def _total_number_of_spherical_harmonics(max_degree: int) -> int:
"""Calculates total number of spherical harmonics."""
max_degree_plus_one = max_degree + 1
Expand All @@ -46,6 +62,23 @@ def _total_number_of_cartesian_monomials(max_degree: int) -> int:
"""Calculates total number of Cartesian monomials."""
return ((max_degree + 1) * (max_degree + 2) * (max_degree + 3)) // 6

def _total_number_of_rotation_matrix_monomials(max_degree: int) -> int:
"""Calculates total number of monomials of 9 variables up to max_degree."""
return ((max_degree + 1) * math.comb(max_degree + 9, max_degree + 1)) // 9

def _number_of_rotation_matrix_monomials_of_degree(degree: int) -> int:
"""Calculates number of monomials of 9 variables of a given degree."""
return math.comb(degree + 8, degree)


def _number_of_wigner_d_entries_of_degree(degree: int) -> int:
"""Calculates number of Wigner-D matrix entries of a given degree."""
num = 2 * degree + 1
return num * num

def _total_number_of_wigner_d_entries(max_degree: int) -> int:
"""Calculates total number of Wigner-D matrix entries."""
return ((max_degree + 1) * (2 * max_degree + 1) * (2 * max_degree + 3)) // 3

def _partitions(n: int, k: int, l: int = 0) -> Iterator[tuple[int, ...]]:
"""Yields all k-tuples of integers >= l that sum to n."""
Expand Down Expand Up @@ -73,6 +106,13 @@ def _monomial_powers_of_degree(degree: int) -> Iterator[tuple[int, int, int]]:
yield multicombination


def _rotation_matrix_powers_of_degree(
degree: int,
) -> Iterator[tuple[int, int, int, int, int, int, int, int, int]]:
"""Yields all power combinations of 9 variable monomials with given degree."""
for multicombination in _multicombinations(n=degree, k=9):
yield multicombination

def _integer_powers(x: np.array, max_degree: int) -> np.array:
"""Calculates all integer powers up to max_degree of x along axis -2."""
return np.cumprod(
Expand Down
90 changes: 90 additions & 0 deletions BasisLib/so3/rotations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,25 @@
import numpy as np
from BasisLib.config import Config
from .common import _integer_powers
from .common import _cartesian_permutation_wigner_d_entries
from .wigner_d_lookup_lut import _generate_wigner_d_lookup_table

def _check_rotation_matrix_shape(rot: np.array) -> None:
"""Helper function to check the shape of a rotation matrix.
Args:
rot: Array that should be checked for the correct shape.
Raises:
ValueError: If the shape is invalid for a rotation matrix.
"""
if rot.shape[-2:] != (3, 3):
raise (
ValueError(
'rotation matrices must have shape (..., 3, 3), received '
f'shape {rot.shape}'
)
)

def random_rotation(
perturbation: float = 1.0,
Expand Down Expand Up @@ -64,3 +85,72 @@ def random_rotation(
row3 = np.stack((2 * (ik - jr), 2 * (jk + ir), 1 - 2 * (i2 + j2)), axis=-1)
rot = np.squeeze(np.stack((row1, row2, row3), axis=-1))
return rot

def wigner_d(
rot: np.array,
max_degree: int,
cartesian_order: bool = Config.cartesian_order,
) -> np.array:
r"""Wigner-D matrix corresponding to a given :math:`3\times3` rotation matrix.
Transform :math:`3\times3` rotation matrices to
:math:`(\mathrm{max\_degree}+1)^2 \times (\mathrm{max\_degree}+1)^2` Wigner-D
matrices that can be used to rotate irreducible representations of
:math:`\mathrm{SO}(3)`.
Args:
rot: An Array of shape :math:`(\dots, 3, 3)` representing :math:`3\times3`
rotation matrices.
max_degree: Maximum degree of the irreducible representations.
cartesian_order: If True, Cartesian order is assumed.
Returns:
An Array of shape
:math:`(\dots, (\mathrm{max\_degree}+1)^2,(\mathrm{max\_degree}+1)^2)`
representing Wigner-D matrices corresponding to the input rotations.
Raises:
ValueError: If ``rot`` does not have shape `(..., 3, 3)`.
"""
_check_rotation_matrix_shape(rot) # Raise if shape is not (..., 3, 3).

# Load/Generate lookup table and convert to jax arrays.
lookup_table = _generate_wigner_d_lookup_table(max_degree)
cm = lookup_table['cm']
ls = lookup_table['ls']
# Optionally reorder to Cartesian order.
if cartesian_order:
cm = cm[:, _cartesian_permutation_wigner_d_entries(max_degree)]

# Calculate all relevant monomials of the rotation matrix entries.
# Note: This is done via integer powers and indexing on purpose! Using
# jnp.power or the "**"-operator for this operation leads to NaNs in the
# gradients for some inputs (jnp.power is not NaN-safe).
rot_powers = _integer_powers(rot.reshape(*rot.shape[:-2], 1, -1), max_degree)
monomials = (
rot_powers[..., 0][..., ls[:, 0]] # R_00**l_00.
* rot_powers[..., 1][..., ls[:, 1]] # R_01**l_01.
* rot_powers[..., 2][..., ls[:, 2]] # R_02**l_02.
* rot_powers[..., 3][..., ls[:, 3]] # R_10**l_10.
* rot_powers[..., 4][..., ls[:, 4]] # R_11**l_11.
* rot_powers[..., 5][..., ls[:, 5]] # R_12**l_12.
* rot_powers[..., 6][..., ls[:, 6]] # R_20**l_20.
* rot_powers[..., 7][..., ls[:, 7]] # R_21**l_21.
* rot_powers[..., 8][..., ls[:, 8]] # R_22**l_22.
)

# Entries of the Wigner-D matrix are linear combinations of the monomials.
dmat_entries = np.matmul(monomials, cm)

# Assemble Wigner-D matrix.
dmat = np.zeros_like( # Initialize Wigner-D matrix to zeros.
rot, shape=(*rot.shape[:-2], (max_degree + 1) ** 2, (max_degree + 1) ** 2)
)
for l in range(max_degree + 1): # Set entries of non-zero blocks on diagonal.
i = l**2 # Start index Wigner-D slice.
j = (l + 1) ** 2 # Stop index Wigner-D slice.
b = ((l + 1) * (2 * l + 1) * (2 * l + 3)) // 3 # Start index entries.
a = b - (2 * l + 1) ** 2 # Stop index entries.
num = 2 * l + 1 # Matrix block has shape (..., 2*l+1, 2*l+1).
dmat[..., i:j, i:j] = dmat_entries[..., a:b].reshape((*rot.shape[:-2], num, num))
return dmat
231 changes: 231 additions & 0 deletions BasisLib/so3/wigner_d_lookup_lut.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
"""Code for generating Wigner-D matrix lookup tables."""

import argparse
import multiprocessing as mp
from typing import Any, Dict, IO, Tuple, TypedDict, cast
from absl import logging
from etils import epath
import numpy as np
import sympy as sp

from BasisLib.config import Config
from .common import _check_degree_is_positive_or_zero
from .common import _number_of_rotation_matrix_monomials_of_degree
from .common import _number_of_wigner_d_entries_of_degree
from .common import _rotation_matrix_powers_of_degree
from .common import _total_number_of_rotation_matrix_monomials
from .common import _total_number_of_wigner_d_entries
from .lookup_table_generation_utility import _load_lookup_table_from_disk
from .lookup_table_generation_utility import _print_cache_usage_information
from .lookup_table_generation_utility import _save_lookup_table_to_disk
from .symbolic import _polynomial_dot_product
from .symbolic import _rotate_xyz_polynomial
from .symbolic import _spherical_harmonics
# pylint: enable=g-importing-member



_wigner_d_lut_name = 'Wigner-D matrix'
_wigner_d_lut_path = '_wigner_d_lut.npz'


class WignerDLookupTable(TypedDict):
"""A lookup table with coefficients for computing Wigner-D matrices.
Attributes:
max_degree: Maximum degree of Wigner-D matrices for which coefficients are
stored in the table.
ls: Vector containing the powers for the rotation matrix entry monomials.
cm: Coefficient matrix for computing the Wigner-D matrix entries by matrix
multiplication with a vector containing rotation matrix entry monomials.
"""

max_degree: int
ls: np.array
cm: np.array


def _generate_wigner_d_lookup_table(
max_degree: int, num_processes: int = 1
) -> WignerDLookupTable:
"""Generates a table with Wigner-D matrix coefficients."""

_check_degree_is_positive_or_zero(max_degree)

def _init_empty_lookup_table(
max_degree: int,
) -> WignerDLookupTable:
"""Initializes a lookup table of the correct size containing only zeros."""
num_rot = _total_number_of_rotation_matrix_monomials(max_degree)
num_wig = _total_number_of_wigner_d_entries(max_degree)
return WignerDLookupTable(
max_degree=max_degree,
cm=np.zeros((num_rot, num_wig), dtype=np.float64),
ls=np.zeros((num_rot, 9), dtype=np.int64),
)

def _load_from_cache(
f: IO[bytes],
) -> Tuple[int, WignerDLookupTable]:
"""Loads a (compressed) lookup table from the cache and uncompresses it."""
lookup_table = _init_empty_lookup_table(max_degree)
with np.load(f) as cache:
cached_max_degree = cache['max_degree']
if cached_max_degree < 0: # Lookup table contains nothing.
return -1, lookup_table
irot = 0
iwig = 0
for l in range(min(cached_max_degree, max_degree) + 1):
nrot = _number_of_rotation_matrix_monomials_of_degree(l)
nwig = _number_of_wigner_d_entries_of_degree(l)
cm_for_l = np.zeros((nrot, nwig), dtype=np.float64)
cm_for_l[cache[f'i0{l}'], cache[f'i1{l}']] = cache[f'cm{l}']
lookup_table['cm'][irot : irot + nrot, iwig : iwig + nwig] = cm_for_l
lookup_table['ls'][irot : irot + nrot] = cache[f'ls{l}']
irot += nrot
iwig += nwig
return cached_max_degree, lookup_table

def _compress(lookup_table: WignerDLookupTable) -> Dict[str, Any]:
"""Compress a lookup table to store only non-zero entries in lists."""
cache = {'max_degree': lookup_table['max_degree']}
irot = 0
iwig = 0
for l in range(cache['max_degree'] + 1):
nrot = _number_of_rotation_matrix_monomials_of_degree(l)
nwig = _number_of_wigner_d_entries_of_degree(l)
cm = lookup_table['cm'][irot : irot + nrot, iwig : iwig + nwig]
i0, i1 = np.nonzero(cm)
cache[f'cm{l}'] = cm[i0, i1]
cache[f'i0{l}'] = i0
cache[f'i1{l}'] = i1
cache[f'ls{l}'] = lookup_table['ls'][irot : irot + nrot]
irot += nrot
iwig += nwig
return cache

# Load cache stored on disk.
cached_max_degree, lookup_table = _load_lookup_table_from_disk(
max_degree=max_degree,
lookup_table_name=_wigner_d_lut_name,
config_cache_path=Config.wigner_d_cache,
package_cache_path=_wigner_d_lut_path,
load_from_cache=_load_from_cache,
init_empty_lookup_table=_init_empty_lookup_table,
)
lookup_table = cast(WignerDLookupTable, lookup_table)

# Return immediately if all values are contained.
if max_degree <= cached_max_degree:
return lookup_table

lstart = cached_max_degree + 1 # Start generation from degree=lstart.

# Inform user that it might be preferable to cache the results.
_print_cache_usage_information(
lstart=lstart,
max_degree=max_degree,
config_cache_path=Config.wigner_d_cache,
set_cache_method_name='set_wigner_d_cache',
lookup_table_name=_wigner_d_lut_name,
pregeneration_name=__name__,
)

# Calculate all combinations of degrees and orders.
degrees_and_orders = []
for l in range(lstart, max_degree + 1):
for m in range(-l, l + 1):
degrees_and_orders.append((l, m))

def _construct_polynomial_pairs(sph_polynomials, rot_polynomials):
"""Helper function to create pairs of (un)rotated polynomials."""
poly_pairs = []
for l in range(lstart, max_degree + 1):
offset = l**2 + l - lstart**2
for mrot in range(-l, l + 1):
irot = offset + mrot
for msph in range(-l, l + 1):
isph = offset + msph
poly_pairs.append((sph_polynomials[isph], rot_polynomials[irot]))
return poly_pairs

# Calculate Wigner-D entries.
if num_processes > 1: # Use multiple processes in parallel.
with mp.Pool(num_processes) as pool:
sph_polynomials = pool.starmap(_spherical_harmonics, degrees_and_orders)
rot_polynomials = pool.map(_rotate_xyz_polynomial, sph_polynomials)
poly_pairs = _construct_polynomial_pairs(sph_polynomials, rot_polynomials)
wigner_d_entries = pool.starmap(_polynomial_dot_product, poly_pairs)
else: # Sequential computation.
sph_polynomials = [
_spherical_harmonics(*args) for args in degrees_and_orders
]
rot_polynomials = [_rotate_xyz_polynomial(poly) for poly in sph_polynomials]
poly_pairs = _construct_polynomial_pairs(sph_polynomials, rot_polynomials)
wigner_d_entries = [_polynomial_dot_product(*args) for args in poly_pairs]

# Create index mapping for the rotation matrix monomials and store
# corresponding powers.
if lstart > 0:
idx = _total_number_of_rotation_matrix_monomials(lstart - 1)
else:
idx = 0
monomial_map = {}
for l in range(lstart, max_degree + 1):
for powers in _rotation_matrix_powers_of_degree(l):
monomial_map[powers] = idx
lookup_table['ls'][idx] = np.asarray(powers, dtype=np.int64)
idx += 1

# Store results in lookup table.
if lstart > 0:
offset = _total_number_of_wigner_d_entries(lstart - 1)
else:
offset = 0
for i, polynomial in enumerate(wigner_d_entries):
iwig = i + offset
for monomial, coefficient in polynomial.terms():
irot = monomial_map[monomial]
lookup_table['cm'][irot, iwig] = sp.simplify(coefficient)

# Save lookup table to disk cache.
_save_lookup_table_to_disk(
lookup_table=_compress(lookup_table),
lookup_table_name=_wigner_d_lut_name,
config_cache_path=Config.wigner_d_cache,
)

return lookup_table


if __name__ == '__main__':
mp.freeze_support() # Might be necessary for Windows support.
parser = argparse.ArgumentParser(
description='Generates lookup tables for computing Wigner-D matrices.'
)
parser.add_argument(
'--max_degree',
required=True,
type=int,
help='Maximum degree of the Wigner-D matrices.',
)
parser.add_argument(
'--path',
required=False,
type=str,
default=epath.Path(__file__).parent / _wigner_d_lut_path,
help='Path to .npz file for storing the lookup table.',
)
parser.add_argument(
'--num_processes',
required=False,
type=int,
default=mp.cpu_count(),
help='Number of processes for parallel computation.',
)
args = parser.parse_args()
logging.set_verbosity(logging.INFO)
Config.set_wigner_d_cache(args.path)
_generate_wigner_d_lookup_table(
max_degree=args.max_degree, num_processes=args.num_processes
)
File renamed without changes.
Loading

0 comments on commit 0ee1f01

Please sign in to comment.