Skip to content

Commit

Permalink
use ModuleList in steerpyr
Browse files Browse the repository at this point in the history
  • Loading branch information
billbrod committed Jan 9, 2024
1 parent 63c7ecb commit 7f724b0
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/plenoptic/simulate/models/portilla_simoncelli.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,15 @@ def __init__(
height=0, order=1,
tight_frame=False
)
for i in range(n_scales):
pyr = SteerablePyramidFreq(
self.unoriented_band_pyrs = torch.nn.ModuleList([
SteerablePyramidFreq(
getattr(self.pyr, f'_himasks_scale_{i}').shape[-2:],
height=1,
order=self.n_orientations - 1,
is_complex=False,
tight_frame=False,
)
setattr(self, f'unoriented_band_pyrs_scale_{i}', pyr)
) for i in range(n_scales)
])

self.use_true_correlations = use_true_correlations
self.scales = (
Expand Down Expand Up @@ -666,7 +666,7 @@ def _calculate_autocorrelation_skew_kurtosis(self):
reconstructed_image = reconstructed_image.unsqueeze(0).unsqueeze(0)

# reconstruct the unoriented band for this scale
unoriented_band_pyr = getattr(self, f'unoriented_band_pyrs_scale_{this_scale}')
unoriented_band_pyr = self.unoriented_band_pyrs[this_scale]
unoriented_pyr_coeffs = unoriented_band_pyr.forward(reconstructed_image)
for ii in range(self.n_orientations):
unoriented_pyr_coeffs[(0, ii)] = (
Expand Down

0 comments on commit 7f724b0

Please sign in to comment.