From 3c70a07e4dcd41741e1502e21ff506748ac9fd9a Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Tue, 31 Mar 2020 14:31:00 +0100 Subject: [PATCH] Split up table computation to compute each table separately (#297) This means that subclasses can override the `gmt` attribute, and will automatically get correct `imt`, `omt`, and `lcmt` tables. Performance impact seems negligeable. --- clifford/_layout.py | 172 +++++++++++++++++++++++++------------------- docs/changelog.rst | 3 + 2 files changed, 101 insertions(+), 74 deletions(-) diff --git a/clifford/_layout.py b/clifford/_layout.py index e3dae328..bdd1d90b 100644 --- a/clifford/_layout.py +++ b/clifford/_layout.py @@ -29,6 +29,7 @@ class _cached_property: def __init__(self, getter): self.fget = getter self.__name__ = getter.__name__ + self.__doc__ = getter.__doc__ def __get__(self, obj, cls): if obj is None: @@ -73,6 +74,8 @@ def imt_check(grade_v, grade_i, grade_j): """ A check used in imt table generation """ + # A_r . B_s = _|r-s| + # if r, s != 0 return (grade_v == abs(grade_i - grade_j)) and (grade_i != 0) and (grade_j != 0) @@ -81,6 +84,7 @@ def omt_check(grade_v, grade_i, grade_j): """ A check used in omt table generation """ + # A_r ^ B_s = _|r+s| return grade_v == (grade_i + grade_j) @@ -89,75 +93,87 @@ def lcmt_check(grade_v, grade_i, grade_j): """ A check used in lcmt table generation """ + # A_r _| B_s = _(s-r) if s-r >= 0 return grade_v == (grade_j - grade_i) @_numba_utils.njit(parallel=NUMBA_PARALLEL, nogil=True) -def _numba_construct_tables( - index_to_grade, index_to_bitmap, bitmap_to_index, signature +def _numba_construct_gmt( + index_to_bitmap, bitmap_to_index, signature ): - array_length = int(len(index_to_grade) * len(index_to_grade)) - indices = np.zeros((3, array_length), dtype=np.uint64) - k_list = indices[0, :] - l_list = indices[1, :] - m_list = indices[2, :] - - imt_prod_mask = np.zeros(array_length, dtype=np.bool_) - - omt_prod_mask = np.zeros(array_length, dtype=np.bool_) - - lcmt_prod_mask = np.zeros(array_length, dtype=np.bool_) + n = len(index_to_bitmap) + array_length = int(n * n) + coords = np.zeros((3, array_length), dtype=np.uint64) + k_list = coords[0, :] + l_list = coords[1, :] + m_list = coords[2, :] # use as small a type as possible to minimize type promotion mult_table_vals = np.zeros(array_length, dtype=np.int8) - for i, grade_i in enumerate(index_to_grade): + for i in range(n): bitmap_i = index_to_bitmap[i] - for j, grade_j in enumerate(index_to_grade): + for j in range(n): bitmap_j = index_to_bitmap[j] bitmap_v, mul = gmt_element(bitmap_i, bitmap_j, signature) v = bitmap_to_index[bitmap_v] - list_ind = i * len(index_to_grade) + j + list_ind = i * n + j k_list[list_ind] = i l_list[list_ind] = v m_list[list_ind] = j mult_table_vals[list_ind] = mul - grade_v = index_to_grade[v] - - # A_r . B_s = _|r-s| - # if r, s != 0 - imt_prod_mask[list_ind] = imt_check(grade_v, grade_i, grade_j) - - # A_r ^ B_s = _|r+s| - omt_prod_mask[list_ind] = omt_check(grade_v, grade_i, grade_j) - - # A_r _| B_s = _(s-r) if s-r >= 0 - lcmt_prod_mask[list_ind] = lcmt_check(grade_v, grade_i, grade_j) - return indices, mult_table_vals, imt_prod_mask, omt_prod_mask, lcmt_prod_mask + return coords, mult_table_vals -def construct_tables( +def construct_gmt( blade_order: BasisBladeOrder, signature -) -> Tuple[sparse.COO, sparse.COO, sparse.COO, sparse.COO]: +) -> sparse.COO: # wrap the numba one - indices, *arrs = _numba_construct_tables( - blade_order.grades, + coords, mult_table_vals = _numba_construct_gmt( blade_order.index_to_bitmap, blade_order.bitmap_to_index, signature ) dims = len(blade_order.grades) - return tuple( - sparse.COO( - coords=indices, data=arr, shape=(dims, dims, dims), - prune=True - ) - for arr in arrs + return sparse.COO(coords=coords, data=mult_table_vals, shape=(dims, dims, dims)) + + +@_numba_utils.njit(parallel=NUMBA_PARALLEL, nogil=True) +def _numba_construct_graded_mt( + index_to_grade, coords, gmt_vals, check_func +): + n_elems = coords.shape[1] + + mask = np.zeros(n_elems, dtype=np.bool_) + + for ind in range(coords.shape[1]): + k, l, m = coords[:, ind] + + grade_k = index_to_grade[k] + grade_l = index_to_grade[l] + grade_m = index_to_grade[m] + + mask[ind] = check_func(grade_l, grade_k, grade_m) + + return coords[:, mask], gmt_vals[mask] + + +def construct_graded_mt( + blade_order: BasisBladeOrder, gmt: sparse.COO, check_func +) -> sparse.COO: + # wrap the numba one + coords, mult_table_vals = _numba_construct_graded_mt( + blade_order.grades, + gmt.coords, + gmt.data, + check_func ) + dims = len(blade_order.grades) + return sparse.COO(coords=coords, data=mult_table_vals, shape=(dims, dims, dims)) @_utils.set_module('clifford') @@ -261,20 +277,6 @@ class Layout(object): 2**dims names : pretty-printing symbols for the blades - gmt : - multiplication table for geometric product - imt : - multiplication table for inner product - omt : - multiplication table for outer product - lcmt : - multiplication table for the left-contraction - - Notes - ----- - The multiplication tables :math:`M` are tensors of rank 3 such that - :math:`a = b \operatorname{op} c` can be computed as - :math:`a_j = \sum_{i,k} b_i \mathit{M}_{ijk} c_k`. """ # old signature def __init__(self, sig, bladeTupList, firstIdx=1, names=None): @@ -361,8 +363,11 @@ def __init__(self, *args, **kw): "names list of length %i needs to be of length %i" % (len(names), self.gaDims)) - self._genTables() # preload these lazy properties. Not doing this would likely be faster. + self.gmt_func + self.imt_func + self.omt_func + self.lcmt_func self.adjoint_func self.left_complement_func self.right_complement_func @@ -370,6 +375,30 @@ def __init__(self, *args, **kw): self.vee_func self.inv_func + @_cached_property + def gmt(self): + r""" Multiplication table for the geometric product. + + This is a tensor of rank 3 such that + :math:`a = b c` can be computed as + :math:`a_j = \sum_{i,k} b_i \mathit{M}_{ijk} c_k`.""" + return construct_gmt(self._basis_blade_order, self.sig) + + @_cached_property + def omt(self): + """ Multiplication table for the inner product, stored in the same way as :attr:`gmt` """ + return construct_graded_mt(self._basis_blade_order, self.gmt, omt_check) + + @_cached_property + def imt(self): + """ Multiplication table for the outer product, stored in the same way as :attr:`gmt` """ + return construct_graded_mt(self._basis_blade_order, self.gmt, imt_check) + + @_cached_property + def lcmt(self): + """ Multiplication table for the left-contraction, stored in the same way as :attr:`gmt` """ + return construct_graded_mt(self._basis_blade_order, self.gmt, lcmt_check) + @_cached_property def bladeTupList(self): return self._basis_vector_ids.order_as_tuples(self._basis_blade_order) @@ -490,27 +519,6 @@ def parse_multivector(self, mv_string: str) -> MultiVector: from ._parser import parse_multivector return parse_multivector(self, mv_string) - def _genTables(self): - "Generate the multiplication tables." - self.gmt, imt_prod_mask, omt_prod_mask, lcmt_prod_mask = construct_tables( - self._basis_blade_order, - self.sig - ) - self.omt = sparse.where(omt_prod_mask, self.gmt, self.gmt.dtype.type(0)) - self.imt = sparse.where(imt_prod_mask, self.gmt, self.gmt.dtype.type(0)) - self.lcmt = sparse.where(lcmt_prod_mask, self.gmt, self.gmt.dtype.type(0)) - - # This generates the functions that will perform the various products - self.gmt_func = get_mult_function(self.gmt, self.gradeList) - self.imt_func = get_mult_function(self.imt, self.gradeList) - self.omt_func = get_mult_function(self.omt, self.gradeList) - self.lcmt_func = get_mult_function(self.lcmt, self.gradeList) - - # these are probably not useful, but someone might want them - self.imt_prod_mask = imt_prod_mask - self.omt_prod_mask = omt_prod_mask - self.lcmt_prod_mask = lcmt_prod_mask - def gmt_func_generator(self, grades_a=None, grades_b=None, filter_mask=None): return get_mult_function( self.gmt, self.gradeList, @@ -566,6 +574,22 @@ def comp_func(Xval): return Yval return comp_func + @_cached_property + def gmt_func(self): + return get_mult_function(self.gmt, self.gradeList) + + @_cached_property + def imt_func(self): + return get_mult_function(self.imt, self.gradeList) + + @_cached_property + def omt_func(self): + return get_mult_function(self.omt, self.gradeList) + + @_cached_property + def lcmt_func(self): + return get_mult_function(self.lcmt, self.gradeList) + @_cached_property def left_complement_func(self): return self._gen_complement_func(omt=self.omt) diff --git a/docs/changelog.rst b/docs/changelog.rst index 198bb3bd..033739bb 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -42,6 +42,9 @@ Compatibility notes numpy array of ``bytes``. The result now matches the construction order, rather than being sorted alphabetically. The order of :meth:`Layout.metric` has been adjusted for consistency. + * The ``imt_prod_mask``, ``omt_prod_mask``, and ``lcmt_prod_mask`` attributes + of :class:`Layout` objects have been removed, as these were an unnecessary + intermediate computation that had no need to be public. Changes in 1.2.x