Skip to content

Commit

Permalink
Fix bugs when processing 2D datasets with main scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
fpicetti committed Aug 4, 2021
1 parent 48618d3 commit 2542890
Show file tree
Hide file tree
Showing 5 changed files with 532 additions and 11 deletions.
2 changes: 1 addition & 1 deletion data.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def extract_patches(args) -> List[dict]:

pe = _get_patch_extractor(original.shape, args.patch_shape, args.patch_stride, args.datadim, args.imgchannel)

if args.datadim == "2.5d":
if args.datadim == "2.5d" or (args.datadim == "2d" and pe.ndim == 3):
final_shape = (-1,) + pe.dim
else:
final_shape = (-1,) + pe.dim + (1,)
Expand Down
6 changes: 3 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,17 +172,17 @@ def optimization_loop(self):
if self.iiter == 0:
self.loss_min = self.history.loss[-1]
self.out_best = u.torch_to_np(out_, True) if out_.ndim > 4 else \
u.torch_to_np(out_, False).squeeze().transpose((1, 2, 0))
u.torch_to_np(out_, False)[0].transpose((1, 2, 0))
elif self.history.loss[-1] <= self.loss_min:
self.loss_min = self.history.loss[-1]
self.out_best = u.torch_to_np(out_, True) if out_.ndim > 4 else \
u.torch_to_np(out_, False).squeeze().transpose((1, 2, 0))
u.torch_to_np(out_, False)[0].transpose((1, 2, 0))
else:
pass

# saving intermediate outputs
if self.iiter in self.iter_to_be_saved and self.iiter != 0:
out_img = u.torch_to_np(out_, True) if out_.ndim > 4 else u.torch_to_np(out_, False).squeeze().transpose((1, 2, 0))
out_img = u.torch_to_np(out_, True) if out_.ndim > 4 else u.torch_to_np(out_, False)[0].transpose((1, 2, 0))
np.save(os.path.join(self.outpath,
self.image_name.split('.')[0] + '_output%s.npy' % str(self.iiter).zfill(self.zfill)),
out_img)
Expand Down
10 changes: 5 additions & 5 deletions main_pocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,18 +206,18 @@ def optimization_loop(self):
# save the output if the loss is decreasing
if self.iiter == 0:
self.loss_min = self.history.loss[-1]
self.out_best = u.torch_to_np(out_, True) if out_.ndim > 4 else u.torch_to_np(out_, False).squeeze().transpose(
(1, 2, 0))
self.out_best = u.torch_to_np(out_, True) if out_.ndim > 4 else \
u.torch_to_np(out_, False)[0].transpose((1, 2, 0))
elif self.history.loss[-1] <= self.loss_min:
self.loss_min = self.history.loss[-1]
self.out_best = u.torch_to_np(out_, True) if out_.ndim > 4 else u.torch_to_np(out_, False).squeeze().transpose(
(1, 2, 0))
self.out_best = u.torch_to_np(out_, True) if out_.ndim > 4 else \
u.torch_to_np(out_, False)[0].transpose((1, 2, 0))
else:
pass

# saving intermediate outputs
if self.iiter in self.iter_to_be_saved and self.iiter != 0:
out_img = u.torch_to_np(out_, True) if out_.ndim > 4 else u.torch_to_np(out_, False).squeeze().transpose((1, 2, 0))
out_img = u.torch_to_np(out_, True) if out_.ndim > 4 else u.torch_to_np(out_, False)[0].transpose((1, 2, 0))
np.save(os.path.join(self.outpath,
self.image_name.split('.')[0] + '_output%s.npy' % str(self.iiter).zfill(self.zfill)),
out_img)
Expand Down
521 changes: 521 additions & 0 deletions proof_of_concept_2D.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions proof_of_concept.ipynb → proof_of_concept_3D.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -611,7 +611,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
"version": "3.8.10"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 2542890

Please sign in to comment.