Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Portilla-Simoncelli model #225

Merged
merged 119 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
119 commits
Select commit Hold shift + click to select a range
cd7ac47
run black on portilla_simoncelli.py
billbrod Aug 25, 2023
a88e266
updates docstring
billbrod Aug 25, 2023
9ab9ef7
cleans up _get_rep_scales
billbrod Aug 25, 2023
e037240
ran black on steerpyr
billbrod Aug 25, 2023
6bbfdc4
adds type annotations to steerpyr
billbrod Aug 25, 2023
3b50b39
fix typo
billbrod Aug 25, 2023
1a576e9
adds type annotation for PS init
billbrod Aug 25, 2023
6e085ed
Merge branch 'psMinStats' of github.com:LabForComputationalVision/ple…
billbrod Aug 25, 2023
0731e9d
more type annotations
billbrod Aug 25, 2023
f58e529
fixes see also
billbrod Aug 28, 2023
7727b0b
updates type annotations
billbrod Aug 28, 2023
e936099
fix typo
billbrod Aug 28, 2023
b65e87e
updates docstrings in stats.py
billbrod Sep 5, 2023
1f3dbee
PS can use stats.skew/kurtosis rather than having own version
billbrod Sep 5, 2023
5e02309
adds support for multi-batch, multi-channel
billbrod Sep 14, 2023
eecb9a3
steerpyr.recon_pyr should return same dtype as it got
billbrod Sep 18, 2023
d9f3290
adds modulate_phase function
billbrod Sep 18, 2023
8645c0a
adds build/ dir to gitignore
billbrod Sep 18, 2023
f6c24a7
big refactor of PS model
billbrod Sep 18, 2023
911d438
functions in signal.py
billbrod Sep 18, 2023
ecdbb91
shrink uses torch.where now
billbrod Sep 19, 2023
67929fa
fix shrink for gpus
billbrod Sep 19, 2023
1df0617
Merge branch 'main' of github.com:LabForComputationalVision/plenoptic…
billbrod Sep 19, 2023
5632008
Merge branch 'ps_refactor' of github.com:LabForComputationalVision/pl…
billbrod Sep 19, 2023
276ac25
corrects type annotation for autocorrelation
billbrod Sep 20, 2023
11ddd7b
use non-downsampled pyramid
billbrod Sep 20, 2023
19cf521
makes new ps version gpu-compliant
billbrod Sep 20, 2023
b2813f4
go back to indexing method for shrink
billbrod Sep 20, 2023
9c7f2db
overhaul again! now lists of tensors
billbrod Sep 20, 2023
3d74ed9
fixes for plot_representation
billbrod Sep 20, 2023
8483277
removes attributes that are no longer here
billbrod Sep 20, 2023
fb7a924
modulate_phase: only call atan2 once
billbrod Sep 21, 2023
ad1b532
remove ues_true_correlation
billbrod Sep 21, 2023
b399ffd
remove always unnecessary stats
billbrod Sep 21, 2023
d13c2e0
cleans up plotting code
billbrod Sep 21, 2023
a3bf090
fix for multi-channel
billbrod Oct 30, 2023
3c50cc2
Adds scales_shape_dict
billbrod Oct 31, 2023
3e69429
Adds necessary stats mask
billbrod Nov 2, 2023
6675871
modulate_phase: only compute x.abs() once
billbrod Nov 2, 2023
c943577
updates PS tests to (mostly) work with refactor
billbrod Nov 3, 2023
723fe0b
fixes scale test
billbrod Nov 6, 2023
68ae172
remove torchvision: use center_crop
billbrod Nov 6, 2023
d55a5a2
remove negative of phase doubled real
billbrod Nov 6, 2023
6bfbc9a
adds torchvision back as nb dependency
billbrod Nov 6, 2023
efd69e7
use tools.center_crop in notebooks
billbrod Nov 6, 2023
cb52c38
adds PS tests for differently-shaped images
billbrod Nov 6, 2023
efdf4fc
tests, changes for plotting
billbrod Nov 6, 2023
cedec11
adds tests for shape and redundancies
billbrod Nov 7, 2023
f61d67d
adds tests for expand and shrink
billbrod Nov 7, 2023
fe3c1d8
updates tools tests
billbrod Nov 8, 2023
062a30a
bugfix for PS update_plot
billbrod Nov 8, 2023
9e0290b
updates docstrings and comments
billbrod Nov 8, 2023
6359d6c
Merge branch 'main' of github.com:LabForComputationalVision/plenoptic…
billbrod Nov 8, 2023
916bea5
updates metamer and display notebooks for new PS
billbrod Nov 8, 2023
9745dff
corrects indent
billbrod Nov 8, 2023
4a97218
reruns PS notebook through final section
billbrod Nov 8, 2023
c1817f0
Merge branch 'ps_refactor' of github.com:LabForComputationalVision/pl…
billbrod Nov 8, 2023
b602367
starts changing to true cross-corr
billbrod Nov 13, 2023
a3adb92
adds new section to tips
billbrod Nov 14, 2023
f5cc7fe
adds mags_std to representation
billbrod Nov 14, 2023
315e0f4
updates tests for real cross-corrs
billbrod Nov 14, 2023
f24caae
adds test for cross correlations
billbrod Nov 15, 2023
42a33c9
reruns most of PS notebook
billbrod Nov 15, 2023
71c23f2
Literal comes from typing_extensions in python 3.7
billbrod Nov 15, 2023
fe9eac9
Merge branch 'ps_refactor' of github.com:LabForComputationalVision/pl…
billbrod Nov 15, 2023
38c392e
Literal comes from typing_extensions in python 3.7
billbrod Nov 15, 2023
4afc98b
fixes some failing tests
billbrod Nov 16, 2023
18f1996
adds url for ps_synth_gpu refactor
billbrod Nov 16, 2023
25d3510
fixes to make tests run on GPU
billbrod Nov 16, 2023
e2e663d
Merge branch 'ps_refactor' of github.com:LabForComputationalVision/pl…
billbrod Nov 16, 2023
3b1b391
Merge branch 'ps_refactor' of github.com:LabForComputationalVision/pl…
billbrod Nov 16, 2023
a7a8061
fix: center_crop accepts only single int
billbrod Nov 16, 2023
fdc99ae
testing.array_equal for arrays, not testing.equal
billbrod Nov 16, 2023
eb1bd17
make shrink and expand error messages the same
billbrod Nov 16, 2023
e13b3b3
updates some tolerances
billbrod Nov 16, 2023
f60198a
fix failing test
billbrod Nov 16, 2023
3f78113
fixes failing tests
billbrod Nov 17, 2023
aef0696
fixes failing tests
billbrod Nov 17, 2023
d2ca2a8
fix for gpu tests
billbrod Nov 17, 2023
a9bdc65
Merge branch 'main' of github.com:LabForComputationalVision/plenoptic…
billbrod Nov 17, 2023
e709187
Merge branch 'main' of github.com:LabForComputationalVision/plenoptic…
billbrod Dec 11, 2023
fb479d3
Merge branch 'main' of github.com:LabForComputationalVision/plenoptic…
billbrod Dec 11, 2023
a2ab279
Make PS statistics tutorial work for refactored code
dherrera1911 Dec 14, 2023
b8cb360
Fix variable name typo
dherrera1911 Dec 14, 2023
f9f230a
Apply Edoardo's suggestions from code review
billbrod Jan 3, 2024
be7909d
add update_plot test
billbrod Jan 3, 2024
02a72ea
adds test for convert to dict errors
billbrod Jan 3, 2024
9327bcd
test even spatial_corr_width
billbrod Jan 3, 2024
493eb40
updates test_ps_torch_output
billbrod Jan 3, 2024
2dde6a9
adds test_vectors_refactor.tar.gz to osf_download
billbrod Jan 3, 2024
ce13d8b
updates test_ps_scales to include even spatial corr width
billbrod Jan 3, 2024
aebfcca
Merge branch 'ps_refactor' of github.com:LabForComputationalVision/pl…
billbrod Jan 3, 2024
919a0eb
runs isort
billbrod Jan 3, 2024
95ca562
makes pyramid attribute private
billbrod Jan 3, 2024
514acf3
replace B,C,S / B,C,H,W style notation
billbrod Jan 3, 2024
856df9d
adds Raises to PS.forward()
billbrod Jan 3, 2024
be22a16
breaks up line
billbrod Jan 3, 2024
f7a575d
change type annotations for Figure/Axes
billbrod Jan 3, 2024
c4ee4bd
adds explanation for einops.rearrange
billbrod Jan 4, 2024
15fbf5b
tries to straighten out type annotations
billbrod Jan 4, 2024
15dddf7
fix for last commit
billbrod Jan 4, 2024
78144d2
change threshold to be based on dtype.resolution
billbrod Jan 4, 2024
b23b5c0
switch from single letter var
billbrod Jan 4, 2024
28b3874
adds module level docstrings
billbrod Jan 4, 2024
ea17809
fix failing test on gpu
billbrod Jan 5, 2024
28ff190
fix failing notebook
billbrod Jan 5, 2024
bbbd2ac
Merge branch 'ps_refactor' of github.com:LabForComputationalVision/pl…
billbrod Jan 5, 2024
9f323f3
rename vector -> tensor, update notebook
billbrod Jan 5, 2024
9b09d27
Updates PS notebook with magnitude means
billbrod Jan 9, 2024
76e59fa
rearrange cells
billbrod Jan 11, 2024
29d46fd
adds more detailed explanation of scales shape dict
billbrod Jan 12, 2024
191bbe3
updates language in notebook
billbrod Jan 12, 2024
1bf81ef
updates description of expand/shrink
billbrod Feb 26, 2024
de5c970
removes unnecessary comments
billbrod Feb 26, 2024
9817859
updates PS notebook
billbrod Feb 28, 2024
f692242
Merge branch 'main' of github.com:LabForComputationalVision/plenoptic…
billbrod Feb 28, 2024
ea86a6a
fixes some rendering issues
billbrod Feb 28, 2024
99cdcae
fix typo
billbrod Feb 28, 2024
cad4606
fixes failing things from merge
billbrod Feb 28, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions docs/tips.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ methods.
- For metamers, this means double-checking that the difference between the model
representation of the metamer and the target image is small enough. If your
model's representation is multi-scale, trying coarse-to-fine optimization may
help (see `notebook <tutorials/06_Metamer.html#Coarse-to-fine-optimization>`_
help (see `notebook <tutorials/intro/06_Metamer.html#Coarse-to-fine-optimization>`_
for details).
- For MAD competition, this means double-checking that the reference metric is
constant and that the optimized metric has converged at a lower or higher
Expand All @@ -59,6 +59,40 @@ Additionally, it may be helpful to visualize the progression of synthesis, using
each synthesis method's ``animate`` or ``plot_synthesis_status`` helper
functions (e.g., :func:`plenoptic.synthesize.metamer.plot_synthesis_status`).

Tweaking the model
------------------

You can also improve your changes of finding a good synthesis by tweaking the
model. For example, the loss function used for metamer synthesis by default is
mean-squared error. This implicitly weights all aspects of the model's
representation equally. Thus, if there are portions of the representation whose
magnitudes are significantly smaller than the others, they might not be matched
at the same rate as the others. You can address this using coarse-to-fine
synthesis or picking a more suitable loss function, but it's generally a good
idea for all of a model's representation to have roughly the same magnitude. You
can do this in a principled or empirical manner:

- Principled: compose your representation of statistics that you know lie within
the same range. For example, use correlations instead of covariances (see the
Portilla-Simoncelli model, and in particular `how plenoptic's implementation
differs from matlab
<tutorials/models/Metamer-Portilla-Simoncelli#7.-Notable-differences-between-Matlab-and-Python-Implementations>`_
for an example of this).
- Empirical: measure your model's representation on a dataset of relevant
natural images and then use this output to z-score your model's representation
on each pass (see [Ziemba2021]_ for an example; this is what the Van Hateren
database is used for).
- In the middle: normalize statistics based on their value in the original image
(note: not the image the model is taking as input! this will likely make
optimization very difficult).

If you are computing a multi-channel representation, you may have a similar
problem where one channel is larger or smaller than the others. Here, tweaking
the loss function might be more useful. Using something like `logsumexp` (the
log of the sum of exponentials, a smooth approximation of the maximum function)
to combine across channels after using something like L2-norm to compute the
loss within each channel might help.

None of the existing synthesis methods meet my needs
====================================================

Expand All @@ -79,4 +113,4 @@ methods.

If you extend a method successfully or would like help making it work, please
let us know by posting a `discussion!
<https://github.com/Flatiron-CCN/plenoptic/discussions>`_
<https://github.com/LabForComputationalVision/plenoptic/discussions>`_
15 changes: 2 additions & 13 deletions examples/02_Eigendistortions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -572,21 +572,10 @@
}
],
"source": [
"# a couple helper functions\n",
"\n",
"def center_crop(im, n):\n",
" \"\"\"Crop an nxn image from the center of im\"\"\"\n",
" im_height, im_width = im.shape[-2:]\n",
" assert n<im_height and n<im_width\n",
"\n",
" im_crop = im[..., im_height//2-n//2:im_height//2+n//2,\n",
" im_width//2-n//2:im_width//2+n//2]\n",
" return im_crop\n",
"\n",
"n = 128 # this will be the img_height and width of the input, you can change this to accommodate your machine\n",
"img = po.data.color_wheel()\n",
"# center crop the image to nxn\n",
"img = center_crop(img, n)\n",
"img = po.tools.center_crop(img, n)\n",
"po.imshow(img, as_rgb=True, zoom=3);"
]
},
Expand Down Expand Up @@ -975,7 +964,7 @@
"img = po.data.curie()\n",
"\n",
"# center crop the image to nxn\n",
"img = center_crop(img, n)\n",
"img = po.tools.center_crop(img, n)\n",
"# because this is a grayscale image but ResNet expects a color image, \n",
"# need to duplicate along the color dimension\n",
"img3 = torch.repeat_interleave(img, 3, dim=1)\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/03_Steerable_Pyramid.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
"for k in pyr_coeffs.keys():\n",
" # we ignore the residual_highpass and residual_lowpass, since we're focusing on the filters here\n",
" if isinstance(k, tuple):\n",
" reconList.append(pyr.recon_pyr(pyr_coeffs, k[0], k[1]))\n",
" reconList.append(pyr.recon_pyr(pyr_coeffs, [k[0]], [k[1]]))\n",
" \n",
"po.imshow(reconList, col_wrap=order+1, vrange='indep1', zoom=2);"
]
Expand Down
9 changes: 4 additions & 5 deletions examples/05_Geodesics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
" \" please install it in your plenoptic environment \"\n",
" \"and restart the notebook kernel\")\n",
"import torchvision.transforms as transforms\n",
"from torchvision.transforms.functional import center_crop\n",
"from torchvision import models\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
Expand Down Expand Up @@ -111,7 +110,7 @@
"einstein = po.data.einstein()\n",
"einstein = po.tools.conv.blur_downsample(einstein, n_scales=2)\n",
"vid = po.tools.translation_sequence(einstein, n_steps=20)\n",
"vid = center_crop(vid, image_size // 2)\n",
"vid = po.tools.center_crop(vid, image_size // 2)\n",
"vid = po.tools.rescale(vid, 0, 1)\n",
"\n",
"imgA = vid[0:1]\n",
Expand Down Expand Up @@ -1066,9 +1065,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:plen_3.10]",
"display_name": "plenoptic",
"language": "python",
"name": "conda-env-plen_3.10-py"
"name": "plenoptic"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -1080,7 +1079,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
"version": "3.10.13"
},
"toc-autonumbering": true,
"toc-showtags": true
Expand Down
27,533 changes: 12,942 additions & 14,591 deletions examples/06_Metamer.ipynb

Large diffs are not rendered by default.

19,129 changes: 9,125 additions & 10,004 deletions examples/Display.ipynb

Large diffs are not rendered by default.

1,246 changes: 409 additions & 837 deletions examples/Metamer-Portilla-Simoncelli.ipynb

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,15 @@ dev = [
"pytest>=5.1.2",
'pytest-cov',
'pytest-xdist',
"torchvision>=0.3",
"requests>=2.21",
"pooch>=1.2.0",
]

nb = [
'jupyter',
'ipywidgets',
"torchvision>=0.3",
'nbclient>=0.5.5',
"torchvision>=0.3",
"pooch>=1.2.0",
]

Expand Down
10 changes: 9 additions & 1 deletion src/plenoptic/data/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
'sample_images.tar.gz': '0ba6fe668a61e9f3cb52032da740fbcf32399ffcc142ddb14380a8e404409bf5',
'test_images.tar.gz': 'eaf35f5f6136e2d51e513f00202a11188a85cae8c6f44141fb9666de25ae9554',
'tid2013.tar.gz': 'bc486ac749b6cfca8dc5f5340b04b9bb01ab24149a5f3a712f13e9d0489dcde0',
'portilla_simoncelli_test_vectors_refactor.tar.gz': '2ca60f1a60b192668567eb3d94c0cdc8679b23bf94a98331890c41eb9406503a',
'portilla_simoncelli_synthesize_torch_v1.12.0_ps-refactor.npz': '9525844b71cf81509b86ed9677172745353588c6bb54e4de8000d695598afa47',
'portilla_simoncelli_synthesize_gpu_ps-refactor.npz': '9fbb490f1548133f6aa49c54832130cf70f8dc6546af59688ead17f62ab94e61',
'portilla_simoncelli_scales_ps-refactor.npz': '1053790a37707627810482beb0dd059cbc193efd688b60c441b3548a75626fdf',
}

OSF_TEMPLATE = "https://osf.io/{}/download"
Expand All @@ -40,8 +44,12 @@
'sample_images.tar.gz': OSF_TEMPLATE.format('6drmy'),
'test_images.tar.gz': OSF_TEMPLATE.format('au3b8'),
'tid2013.tar.gz': OSF_TEMPLATE.format('uscgv'),
'portilla_simoncelli_test_vectors_refactor.tar.gz': OSF_TEMPLATE.format('ca7qt'),
'portilla_simoncelli_synthesize_torch_v1.12.0_ps-refactor.npz': OSF_TEMPLATE.format('vmwzd'),
'portilla_simoncelli_synthesize_gpu_ps-refactor.npz': OSF_TEMPLATE.format('mqs6y'),
'portilla_simoncelli_scales_ps-refactor.npz': OSF_TEMPLATE.format('nvpr4'),
}
DOWNLOADABLE_FILES = list(REGISTRY.keys())
DOWNLOADABLE_FILES = list(REGISTRY_URLS.keys())

import pathlib
from typing import List
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def forward(self, x):

Parameters
----------
x: torch.Tensor of shape (B, C, H, W)
x: torch.Tensor of shape (batch, channel, height, width)
Image, or batch of images. If there are multiple channels,
the Laplacian is computed separately for each of them

Expand Down Expand Up @@ -71,7 +71,7 @@ def recon_pyr(self, y):

Returns
-------
x: torch.Tensor of shape (B, C, H, W)
x: torch.Tensor of shape (batch, channel, height, width)
Image, or batch of images
"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def local_gain_control(x, epsilon=1e-8):
Parameters
----------
x : torch.Tensor
Tensor of shape (B,C,H,W)
Tensor of shape (batch, channel, height, width)
epsilon: float, optional
Small constant to avoid division by zero.

Expand Down Expand Up @@ -134,7 +134,7 @@ def local_gain_release(norm, direction, epsilon=1e-8):
Returns
-------
x : torch.Tensor
Tensor of shape (B,C,H,W)
Tensor of shape (batch, channel, height, width)

Notes
-----
Expand Down Expand Up @@ -163,7 +163,7 @@ def local_gain_control_dict(coeff_dict, residuals=True):
Parameters
----------
coeff_dict : dict
A dictionary containing tensors of shape (B,C,H,W)
A dictionary containing tensors of shape (batch, channel, height, width)
residuals: bool, optional
An option to carry around residuals in the energy dict.
Note that the transformation is not applied to the residuals,
Expand Down Expand Up @@ -219,7 +219,7 @@ def local_gain_release_dict(energy, state, residuals=True):
Returns
-------
coeff_dict : dict
A dictionary containing tensors of shape (B,C,H,W)
A dictionary containing tensors of shape (batch, channel, height, width)

Notes
-----
Expand Down
Loading