Skip to content

Commit

Permalink
Make quaternion scalar type immutable
Browse files Browse the repository at this point in the history
  • Loading branch information
kohlerjl committed Dec 17, 2023
1 parent 776ae08 commit 4c3267b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 13 deletions.
26 changes: 13 additions & 13 deletions src/numpy_quaternion.c
Original file line number Diff line number Diff line change
Expand Up @@ -359,13 +359,13 @@ QQ_QS_SQ_BINARY_QUATERNION_RETURNER(power)
Py_RETURN_NOTIMPLEMENTED; \
}
#define QQ_QS_SQ_BINARY_QUATERNION_INPLACE(name) QQ_QS_SQ_BINARY_QUATERNION_INPLACE_FULL(name, name)
QQ_QS_SQ_BINARY_QUATERNION_INPLACE(add)
QQ_QS_SQ_BINARY_QUATERNION_INPLACE(subtract)
QQ_QS_SQ_BINARY_QUATERNION_INPLACE(multiply)
QQ_QS_SQ_BINARY_QUATERNION_INPLACE(divide)
/* QQ_QS_SQ_BINARY_QUATERNION_INPLACE(add) */
/* QQ_QS_SQ_BINARY_QUATERNION_INPLACE(subtract) */
/* QQ_QS_SQ_BINARY_QUATERNION_INPLACE(multiply) */
/* QQ_QS_SQ_BINARY_QUATERNION_INPLACE(divide) */
/* QQ_QS_SQ_BINARY_QUATERNION_INPLACE_FULL(true_divide, divide) */
/* QQ_QS_SQ_BINARY_QUATERNION_INPLACE_FULL(floor_divide, divide) */
QQ_QS_SQ_BINARY_QUATERNION_INPLACE(power)
/* QQ_QS_SQ_BINARY_QUATERNION_INPLACE(power) */

static PyObject *
pyquaternion__reduce(PyQuaternion* self)
Expand Down Expand Up @@ -516,7 +516,7 @@ PyMethodDef pyquaternion_methods[] = {
};

static PyObject* pyquaternion_num_power(PyObject* a, PyObject* b, PyObject *c) { (void) c; return pyquaternion_power(a,b); }
static PyObject* pyquaternion_num_inplace_power(PyObject* a, PyObject* b, PyObject *c) { (void) c; return pyquaternion_inplace_power(a,b); }
/* static PyObject* pyquaternion_num_inplace_power(PyObject* a, PyObject* b, PyObject *c) { (void) c; return pyquaternion_inplace_power(a,b); } */
static PyObject* pyquaternion_num_negative(PyObject* a) { return pyquaternion_negative(a,NULL); }
static PyObject* pyquaternion_num_positive(PyObject* a) { return pyquaternion_positive(a,NULL); }
static PyObject* pyquaternion_num_absolute(PyObject* a) { return pyquaternion_absolute(a,NULL); }
Expand Down Expand Up @@ -572,23 +572,23 @@ static PyNumberMethods pyquaternion_as_number = {
pyquaternion_convert_oct, // nb_oct
pyquaternion_convert_hex, // nb_hex
#endif
pyquaternion_inplace_add, // nb_inplace_add
pyquaternion_inplace_subtract, // nb_inplace_subtract
pyquaternion_inplace_multiply, // nb_inplace_multiply
0, // nb_inplace_add
0, // nb_inplace_subtract
0, // nb_inplace_multiply
#if PY_MAJOR_VERSION < 3
pyquaternion_inplace_divide, // nb_inplace_divide
0, // nb_inplace_divide
#endif
0, // nb_inplace_remainder
pyquaternion_num_inplace_power, // nb_inplace_power
0, // nb_inplace_power
0, // nb_inplace_lshift
0, // nb_inplace_rshift
0, // nb_inplace_and
0, // nb_inplace_xor
0, // nb_inplace_or
pyquaternion_divide, // nb_floor_divide
pyquaternion_divide, // nb_true_divide
pyquaternion_inplace_divide, // nb_inplace_floor_divide
pyquaternion_inplace_divide, // nb_inplace_true_divide
0, // nb_inplace_floor_divide
0, // nb_inplace_true_divide
0, // nb_index
#if PY_MAJOR_VERSION >= 3
#if PY_MINOR_VERSION >= 5
Expand Down
38 changes: 38 additions & 0 deletions tests/test_quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,25 @@ def test_quaternion_add(Qs):
assert (s + q == quaternion.quaternion(q.w + s, q.x, q.y, q.z))


def test_quaternion_inplace_add(Qs):
for qo in Qs[Qs_finite]:
for po in Qs[Qs_finite]:
q = qo.copy() # copy to prevent mutating original
p = po.copy() # copy to prevent mutating original
q2 = q
q2 += p
assert q2 == quaternion.quaternion(qo.w + p.w, qo.x + p.x, qo.y + p.y, qo.z + p.z)
# Ensure value of object q has not changed
assert q == qo
for qo in Qs[Qs_nonnan]:
for s in [-3, -2.3, -1.2, -1.0, 0.0, 0, 1.0, 1, 1.2, 2.3, 3]:
q = qo.copy() # Make copy to prevent mutating original
q2 = q
q2 += s
assert (q2 == quaternion.quaternion(qo.w + s, qo.x, qo.y, qo.z))
assert (q == qo)


def test_quaternion_add_ufunc(Qs):
ufunc_binary_utility(Qs[Qs_finite], Qs[Qs_finite], operator.add)

Expand All @@ -790,6 +809,25 @@ def test_quaternion_subtract(Qs):
assert (s - q == quaternion.quaternion(s - q.w, -q.x, -q.y, -q.z))


def test_quaternion_inplace_subtract(Qs):
for qo in Qs[Qs_finite]:
for po in Qs[Qs_finite]:
q = qo.copy() # copy to prevent mutating original
p = po.copy() # copy to prevent mutating original
q2 = q
q2 -= p
assert q2 == quaternion.quaternion(qo.w - p.w, qo.x - p.x, qo.y - p.y, qo.z - p.z)
# Ensure value of object q has not changed
assert q == qo
for qo in Qs[Qs_nonnan]:
for s in [-3, -2.3, -1.2, -1.0, 0.0, 0, 1.0, 1, 1.2, 2.3, 3]:
q = qo.copy() # Make copy to prevent mutating original
q2 = q
q2 -= s
assert (q2 == quaternion.quaternion(qo.w - s, qo.x, qo.y, qo.z))
assert (q == qo)


def test_quaternion_subtract_ufunc(Qs):
ufunc_binary_utility(Qs[Qs_finite], Qs[Qs_finite], operator.sub)

Expand Down

0 comments on commit 4c3267b

Please sign in to comment.