From 1b5b36ab89b38f9518ae1c8f52f27d8314fdd312 Mon Sep 17 00:00:00 2001 From: "William F. Broderick" Date: Mon, 26 Aug 2024 15:22:30 -0400 Subject: [PATCH 1/2] make penalize_rnage more efficient --- src/plenoptic/tools/optim.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/plenoptic/tools/optim.py b/src/plenoptic/tools/optim.py index 439cc8c3..ae98cd2a 100644 --- a/src/plenoptic/tools/optim.py +++ b/src/plenoptic/tools/optim.py @@ -126,9 +126,8 @@ def penalize_range( penalty Penalty for values outside range """ - # the indexing should flatten it - below_min = synth_img[synth_img < allowed_range[0]] - below_min = torch.pow(below_min - allowed_range[0], 2) - above_max = synth_img[synth_img > allowed_range[1]] - above_max = torch.pow(above_max - allowed_range[1], 2) - return torch.sum(torch.cat([below_min, above_max])) + # Using clip like this is equivalent to using boolean indexing (e.g., + # synth_img[synth_img < allowed_range[0]]) but much faster + below_min = torch.clip(synth_img - allowed_range[0], max=0).pow(2).sum() + above_max = torch.clip(synth_img - allowed_range[1], min=0).pow(2).sum() + return below_min + above_max From c9230230a5f0f076fd66b097359a375c05d44015 Mon Sep 17 00:00:00 2001 From: "William F. Broderick" Date: Mon, 26 Aug 2024 17:21:13 -0400 Subject: [PATCH 2/2] adds some tests for penalize_range --- tests/test_tools.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_tools.py b/tests/test_tools.py index 541bbf66..8147e6a1 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -458,3 +458,16 @@ def test_validate_metric_identical(self): @pytest.mark.parametrize('model', ['frontend.OnOff.nograd'], indirect=True) def test_remove_grad(self, model): po.tools.validate.validate_model(model, device=DEVICE) + + +class TestOptim(object): + + def test_penalize_range_above(self): + img = .5 * torch.ones((1, 1, 4, 4)) + img[..., 0, :] = 2 + assert po.tools.optim.penalize_range(img).item() == 4 + + def test_penalize_range_below(self): + img = .5 * torch.ones((1, 1, 4, 4)) + img[..., 0, :] = -1 + assert po.tools.optim.penalize_range(img).item() == 4