change tests to avoid duplicating code
billbrod committed Jan 9, 2024
1 parent e8ce40c commit 63c7ecb
Showing 2 changed files with 52 additions and 80 deletions.
112 changes: 46 additions & 66 deletions tests/test_models.py
@@ -72,6 +72,35 @@ def portilla_simoncelli_scales():
     return osf_download('portilla_simoncelli_scales.npz')


+def template_test_cuda(model, img):
+    model.cuda()
+    model(img)
+    # make sure it ends on same device it started, since it might be a fixture
+    model.to(DEVICE)
+
+def template_test_cpu_and_back(model, img):
+    model.cpu()
+    model.cuda()
+    model(img)
+    # make sure it ends on same device it started, since it might be a fixture
+    model.to(DEVICE)
+
+def template_test_cuda_and_back(model, img):
+    model.cuda()
+    model.cpu()
+    model(img.cpu())
+    # make sure it ends on same device it started, since it might be a fixture
+    img.to(DEVICE)
+    model.to(DEVICE)
+
+def template_test_cpu(model, img):
+    model.cpu()
+    model(img.cpu())
+    # make sure it ends on same device it started, since it might be a fixture
+    img.to(DEVICE)
+    model.to(DEVICE)
+
+
 class TestNonLinearities(object):
     def test_rectangular_to_polar_dict(self, basic_stim):
         spc = po.simul.SteerablePyramidFreq(basic_stim.shape[-2:], height=5,
@@ -133,31 +162,21 @@ def test_match_pyrtools(self, curie_img, n_scales):
     @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
     def test_cuda(self, einstein_img):
         lpyr = po.simul.LaplacianPyramid()
-        lpyr.cuda()
-        lpyr(einstein_img)
+        template_test_cuda(lpyr, einstein_img)

     @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
     def test_cpu_and_back(self, einstein_img):
         lpyr = po.simul.LaplacianPyramid()
-        lpyr.cpu()
-        lpyr.cuda()
-        lpyr(einstein_img)
+        template_test_cpu_and_back(lpyr, einstein_img)

     @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
     def test_cuda_and_back(self, einstein_img):
         lpyr = po.simul.LaplacianPyramid()
-        lpyr.cuda()
-        lpyr.cpu()
-        lpyr(einstein_img.cpu())
-        # make sure it ends on same device it started, since it's a fixture
-        einstein_img.to(DEVICE)
+        template_test_cuda_and_back(lpyr, einstein_img)

     def test_cpu(self, einstein_img):
         lpyr = po.simul.LaplacianPyramid()
-        lpyr.cpu()
-        lpyr(einstein_img.cpu())
-        # make sure it ends on same device it started, since it's a fixture
-        einstein_img.to(DEVICE)
+        template_test_cpu(lpyr, einstein_img)

 class TestFrontEnd:

@@ -199,37 +218,21 @@ def test_frontend_display_filters(self, model):
     @pytest.mark.parametrize("model", all_models, indirect=True)
     @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
     def test_cuda(self, model, einstein_img):
-        model.cuda()
-        model(einstein_img)
-        # make sure it ends on same device it started, since it's a fixture
-        model.to(DEVICE)
+        template_test_cuda(model, einstein_img)

     @pytest.mark.parametrize("model", all_models, indirect=True)
     @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
     def test_cpu_and_back(self, model, einstein_img):
-        model.cpu()
-        model.cuda()
-        model(einstein_img)
-        # make sure it ends on same device it started, since it's a fixture
-        model.to(DEVICE)
+        template_test_cpu_and_back(model, einstein_img)

     @pytest.mark.parametrize("model", all_models, indirect=True)
     @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
     def test_cuda_and_back(self, model, einstein_img):
-        model.cuda()
-        model.cpu()
-        model(einstein_img.cpu())
-        # make sure it ends on same device it started, since it's a fixture
-        model.to(DEVICE)
-        einstein_img.to(DEVICE)
+        template_test_cuda_and_back(model, einstein_img)

     @pytest.mark.parametrize("model", all_models, indirect=True)
     def test_cpu(self, model, einstein_img):
-        model.cpu()
-        model(einstein_img.cpu())
-        # make sure it ends on same device it started, since it's a fixture
-        model.to(DEVICE)
-        einstein_img.to(DEVICE)
+        template_test_cpu(model, einstein_img)

 class TestNaive(object):

@@ -279,37 +282,22 @@ def test_linear(self, basic_stim):
     @pytest.mark.parametrize("model", all_models, indirect=True)
     @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
     def test_cuda(self, model, einstein_img):
-        model.cuda()
-        model(einstein_img)
-        # make sure it ends on same device it started, since it's a fixture
-        model.to(DEVICE)
+        template_test_cuda(model, einstein_img)

     @pytest.mark.parametrize("model", all_models, indirect=True)
     @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
     def test_cpu_and_back(self, model, einstein_img):
-        model.cpu()
-        model.cuda()
-        model(einstein_img)
-        # make sure it ends on same device it started, since it's a fixture
-        model.to(DEVICE)
+        template_test_cpu_and_back(model, einstein_img)

     @pytest.mark.parametrize("model", all_models, indirect=True)
     @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
     def test_cuda_and_back(self, model, einstein_img):
-        model.cuda()
-        model.cpu()
-        model(einstein_img.cpu())
-        # make sure it ends on same device it started, since it's a fixture
-        model.to(DEVICE)
-        einstein_img.to(DEVICE)
+        template_test_cuda_and_back(model, einstein_img)

     @pytest.mark.parametrize("model", all_models, indirect=True)
     def test_cpu(self, model, einstein_img):
-        model.cpu()
-        model(einstein_img.cpu())
-        # make sure it ends on same device it started, since it's a fixture
-        model.to(DEVICE)
-        einstein_img.to(DEVICE)
+        template_test_cpu(model, einstein_img)

+
 class TestPortillaSimoncelli(object):
     @pytest.mark.parametrize("n_scales", [1, 2, 3, 4])
@@ -522,29 +510,21 @@ def test_ps_expand(self, im_shape):
     @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
     def test_cuda(self, einstein_img):
         ps = po.simul.PortillaSimoncelli(einstein_img.shape[-2:])
-        ps.cuda()
-        ps(einstein_img)
+        template_test_cuda(ps, einstein_img)

     @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
     def test_cpu_and_back(self, einstein_img):
         ps = po.simul.PortillaSimoncelli(einstein_img.shape[-2:])
-        ps.cpu()
-        ps.cuda()
-        ps(einstein_img)
+        template_test_cpu_and_back(ps, einstein_img)

     @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
     def test_cuda_and_back(self, einstein_img):
         ps = po.simul.PortillaSimoncelli(einstein_img.shape[-2:])
-        ps.cuda()
-        ps.cpu()
-        ps(einstein_img.cpu())
-        einstein_img.to(DEVICE)
+        template_test_cuda_and_back(ps, einstein_img)

     def test_cpu(self, einstein_img):
         ps = po.simul.PortillaSimoncelli(einstein_img.shape[-2:])
-        ps.cpu()
-        ps(einstein_img.cpu())
-        einstein_img.to(DEVICE)
+        template_test_cpu(ps, einstein_img)


 class TestFilters:
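The template functions added above in tests/test_models.py reference a module-level DEVICE and receive pytest fixtures such as einstein_img, neither of which is defined in this diff (they come from tests/conftest.py). A minimal, hypothetical sketch of the kind of setup the templates assume — the fixture body and scope here are placeholders, not the repository's actual code:

# hypothetical sketch of the conftest.py pieces the templates rely on
import pytest
import torch

# run tests on the GPU when one is available, otherwise on the CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@pytest.fixture(scope="package")  # broad scope assumed: the fixture is reused across tests
def einstein_img():
    # stand-in for the real fixture, which loads an image file; because the tensor
    # (and any model fixtures) are shared, each test has to leave them on DEVICE
    return torch.rand(1, 1, 256, 256).to(DEVICE)

Because shared fixtures persist between tests, a test that left one of them on a different device would leak that state into later tests, which is what the templates' trailing .to(DEVICE) calls guard against.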
20 changes: 6 additions & 14 deletions tests/test_steerable_pyr.py
@@ -8,6 +8,8 @@
 from itertools import product
 from plenoptic.tools.data import to_numpy
 from conftest import DEVICE, DATA_DIR
+from test_models import (template_test_cuda, template_test_cpu_and_back,
+                         template_test_cuda_and_back, template_test_cpu)


 def check_pyr_coeffs(coeff_1, coeff_2, rtol=1e-3, atol=1e-3):
@@ -294,31 +296,21 @@ def test_order_values(self, img, order):
     @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
     def test_cuda(self, img):
         pyr = po.simul.SteerablePyramidFreq(img.shape[-2:])
-        pyr.cuda()
-        pyr(img)
+        template_test_cuda(pyr, img)

     @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
     def test_cpu_and_back(self, img):
         pyr = po.simul.SteerablePyramidFreq(img.shape[-2:])
-        pyr.cpu()
-        pyr.cuda()
-        pyr(img)
+        template_test_cpu_and_back(pyr, img)

     @pytest.mark.skipif(DEVICE.type == 'cpu', reason="Can only test on cuda")
     def test_cuda_and_back(self, img):
         pyr = po.simul.SteerablePyramidFreq(img.shape[-2:])
-        pyr.cuda()
-        pyr.cpu()
-        pyr(img.cpu())
-        # make sure it ends on same device it started, since it's a fixture
-        img.to(DEVICE)
+        template_test_cuda_and_back(pyr, img)

     def test_cpu(self, img):
         pyr = po.simul.SteerablePyramidFreq(img.shape[-2:])
-        pyr.cpu()
-        pyr(img.cpu())
-        # make sure it ends on same device it started, since it's a fixture
-        img.to(DEVICE)
+        template_test_cpu(pyr, img)

     @pytest.mark.parametrize('order', range(1, 16))
     def test_buffers(self, order):
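A side note on the "make sure it ends on same device it started" lines in the new templates: torch.nn.Module.to() moves a module's parameters in place, so model.to(DEVICE) really does restore a model fixture, whereas torch.Tensor.to() and .cpu() return new tensors and leave the original untouched, so the image fixtures themselves never leave DEVICE. A small self-contained sketch of that difference (not from the repository; the body only runs when CUDA is available):

import torch

if torch.cuda.is_available():
    t = torch.zeros(3)                  # a CPU tensor
    m = torch.nn.Linear(3, 3)           # a CPU module
    t.to("cuda")                        # returns a new tensor; t itself stays on the CPU
    m.to("cuda")                        # moves m's parameters in place
    print(t.device)                     # cpu
    print(next(m.parameters()).device)  # cuda:0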
