Skip to content

Commit

Permalink
Non-commutative inner product support (pygae 426)
Browse files Browse the repository at this point in the history
Add separate `MultiVector.__ror__` (based on `__or__`), to calculate
inner product from swapped operands. This is NOT commutative operation.

Also add swapped operand operation test to improve pytest coverage.
  • Loading branch information
trundev committed Nov 26, 2024
1 parent db788e0 commit 6a0a06c
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
16 changes: 15 additions & 1 deletion clifford/_multivector.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,21 @@ def __or__(self, other) -> 'MultiVector':

return self._newMV(newValue)

__ror__ = __or__
def __ror__(self, other) -> 'MultiVector':
r"""Right-hand inner product, :math:`N \cdot M` """

other, mv = self._checkOther(other)

if mv:
newValue = self.layout.imt_func(other.value, self.value)
else:
if isinstance(other, np.ndarray):
obj = self.__array__()
return other|obj
# l * M = M * l = 0 for scalar l
return self._newMV(dtype=np.result_type(self.value.dtype, other))

return self._newMV(newValue)

def __add__(self, other) -> 'MultiVector':
""" ``self + other``, addition """
Expand Down
18 changes: 18 additions & 0 deletions clifford/test/test_clifford.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,24 @@ def test_right_multiplication_matrix(self, algebra, rng): # noqa: F811
res2 = layout.MultiVector(value=b_right@a.value)
np.testing.assert_almost_equal(res.value, res2.value)

@pytest.mark.parametrize('func', [
operator.add,
operator.sub,
operator.mul,
operator.xor, # outer product
operator.or_, # inner product
])
def test_swapped_operands(self, algebra, rng, func): # noqa: F811
layout = algebra
for _ in range(10):
mv = layout.randomMV(rng=rng)
mv2 = layout.randomMV(rng=rng)
# Convert first operand to MVArray. This provokes use of operation with
# swapped operands: MultiVector.__rmul__, __ror__, etc.
ma = clifford.MVArray(mv)
np.testing.assert_equal(func(ma, mv2), func(mv, mv2))
np.testing.assert_equal(func(mv2, ma), func(mv2, mv))


class TestPrettyRepr:
""" Test ipython pretty printing, with tidy line wrapping """
Expand Down

0 comments on commit 6a0a06c

Please sign in to comment.