From 2b563f1f8ac5db8b255fd4a6ef6b187cdb948a35 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Wed, 7 Feb 2024 15:18:44 -0700 Subject: [PATCH] Add some tests --- scico/test/linop/test_binop.py | 48 ++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 scico/test/linop/test_binop.py diff --git a/scico/test/linop/test_binop.py b/scico/test/linop/test_binop.py new file mode 100644 index 000000000..7a4fb0d23 --- /dev/null +++ b/scico/test/linop/test_binop.py @@ -0,0 +1,48 @@ +import operator as op + +import pytest + +import scico.numpy as snp +from scico import linop + + +class TestBinaryOp: + def setup_method(self, method): + self.input_shape = (5,) + self.input_dtype = snp.float32 + + @pytest.mark.parametrize("operator", [op.add, op.sub]) + def test_case1(self, operator): + A = linop.Convolve( + snp.ones((2,)), input_shape=self.input_shape, input_dtype=self.input_dtype, mode="same" + ) + B = linop.Identity(input_shape=self.input_shape, input_dtype=self.input_dtype) + + assert type(A + B) == linop.LinearOperator + assert type(B + A) == linop.LinearOperator + assert type(2.0 * A + 3.0 * B) == linop.LinearOperator + assert type(2.0 * B + 3.0 * A) == linop.LinearOperator + + @pytest.mark.parametrize("operator", [op.add, op.sub]) + def test_case2(self, operator): + A = linop.SingleAxisFiniteDifference( + input_shape=self.input_shape, input_dtype=self.input_dtype, circular=True + ) + B = linop.Identity(input_shape=self.input_shape, input_dtype=self.input_dtype) + + assert type(A + B) == linop.LinearOperator + assert type(B + A) == linop.LinearOperator + assert type(2.0 * A + 3.0 * B) == linop.LinearOperator + assert type(2.0 * B + 3.0 * A) == linop.LinearOperator + + @pytest.mark.parametrize("operator", [op.add, op.sub]) + def test_case3(self, operator): + A = linop.ScaledIdentity( + scalar=0.5, input_shape=self.input_shape, input_dtype=self.input_dtype + ) + B = linop.Identity(input_shape=self.input_shape, input_dtype=self.input_dtype) + + assert type(A + B) == linop.ScaledIdentity + assert type(B + A) == linop.ScaledIdentity + assert type(2.0 * A + 3.0 * B) == linop.ScaledIdentity + assert type(2.0 * B + 3.0 * A) == linop.ScaledIdentity