Skip to content

Commit

Permalink
Merge pull request #16 from juglab/merger
Browse files Browse the repository at this point in the history
Merger
  • Loading branch information
tibuch authored Mar 9, 2021
2 parents 25ec06f + 8596e35 commit 4a0fff1
Show file tree
Hide file tree
Showing 13 changed files with 437 additions and 102 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Build Python package:
`python setup.py bdist_wheel`

Build singularity recipe:
`neurodocker generate singularity -b nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 -p apt --copy /home/tibuch/Gitrepos/FourierImageTransformer/dist/fourier_image_transformers-0.1.19-py3-none-any.whl /fourier_image_transformers-0.1.19-py3-none-any.whl --miniconda create_env=fit conda_install='python=3.7 astra-toolbox pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch -c astra-toolbox/label/dev' pip_install='/fourier_image_transformers-0.1.19-py3-none-any.whl' activate=true --entrypoint "/neurodocker/startup.sh python" > v0.1.19.Singularity`
`neurodocker generate singularity -b nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 -p apt --copy /home/tibuch/Gitrepos/FourierImageTransformer/dist/fourier_image_transformers-0.1.20-py3-none-any.whl /fourier_image_transformers-0.1.20-py3-none-any.whl --miniconda create_env=fit conda_install='python=3.7 astra-toolbox pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch -c astra-toolbox/label/dev' pip_install='/fourier_image_transformers-0.1.20-py3-none-any.whl' activate=true --entrypoint "/neurodocker/startup.sh python" > v0.1.20.Singularity`

Build singularity container:
`sudo singularity build fit_v0.1.19.simg v0.1.19.Singularity`
`sudo singularity build fit_v0.1.20.simg v0.1.20.Singularity`
11 changes: 6 additions & 5 deletions fit/datamodules/super_res/SRecDataModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


class MNISTSResFourierTargetDataModule(LightningDataModule):
IMG_SHAPE = 28
IMG_SHAPE = 27

def __init__(self, root_dir, batch_size):
"""
Expand All @@ -38,8 +38,9 @@ def setup(self, stage: Optional[str] = None):
mnist_train_val = MNIST(self.root_dir, train=True, download=True).data.type(torch.float32)
np.random.seed(1612)
perm = np.random.permutation(mnist_train_val.shape[0])
mnist_train = mnist_train_val[perm[:55000]]
mnist_val = mnist_train_val[perm[55000:]]
mnist_train = mnist_train_val[perm[:55000], 1:, 1:]
mnist_val = mnist_train_val[perm[55000:], 1:, 1:]
mnist_test = mnist_test[:, 1:, 1:]

assert mnist_train.shape[1] == MNISTSResFourierTargetDataModule.IMG_SHAPE
assert mnist_train.shape[2] == MNISTSResFourierTargetDataModule.IMG_SHAPE
Expand Down Expand Up @@ -77,7 +78,7 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]
return DataLoader(
SResFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='test',
img_shape=MNISTSResFourierTargetDataModule.IMG_SHAPE),
batch_size=1)
batch_size=self.batch_size)


class CelebASResFourierTargetDataModule(LightningDataModule):
Expand Down Expand Up @@ -134,4 +135,4 @@ def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]
return DataLoader(
SResFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='test',
img_shape=self.gt_shape),
batch_size=1)
batch_size=self.batch_size)
2 changes: 1 addition & 1 deletion fit/datamodules/super_res/SResFCDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __getitem__(self, item):

img_mag = 2 * (img_mag - self.mag_min) / (self.mag_max - self.mag_min) - 1

img_phi = 2 * img_phi / (2 * np.pi) - 1
img_phi = img_phi / np.pi

img_fft = torch.stack([img_mag.flatten(), img_phi.flatten()], dim=-1)
return img_fft, (self.mag_min.unsqueeze(-1), self.mag_max.unsqueeze(-1))
Expand Down
95 changes: 94 additions & 1 deletion fit/datamodules/tomo_rec/TRecDataModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,100 @@ def setup(self, stage: Optional[str] = None):

self.gt_ds = get_projection_dataset(
GroundTruthDataset(gt_train, gt_val, gt_test),
num_angles=self.num_angles, im_shape=450, impl='astra_cpu', inner_circle=self.inner_circle)
num_angles=self.num_angles, im_shape=self.gt_shape + (self.gt_shape // 2 - 7), impl='astra_cpu',
inner_circle=self.inner_circle)

tmp_fcds = TRecFourierCoefficientDataset(self.gt_ds, mag_min=None, mag_max=None, part='train',
img_shape=self.gt_shape)
self.mag_min = tmp_fcds.mag_min
self.mag_max = tmp_fcds.mag_max

def train_dataloader(self, *args, **kwargs) -> DataLoader:
return DataLoader(
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='train',
img_shape=self.gt_shape),
batch_size=self.batch_size, num_workers=1)

def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='validation',
img_shape=self.gt_shape),
batch_size=self.batch_size, num_workers=1)

def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
return DataLoader(
TRecFourierCoefficientDataset(self.gt_ds, mag_min=self.mag_min, mag_max=self.mag_max, part='test',
img_shape=self.gt_shape),
batch_size=1)


class CropLoDoPaBFourierTargetDataModule(LightningDataModule):
IMG_SHAPE = 361

def __init__(self, batch_size, gt_shape=361, num_angles=15):
"""
:param root_dir:
:param batch_size:
:param num_angles:
"""
super().__init__()
self.batch_size = batch_size
self.gt_shape = gt_shape
self.num_angles = num_angles
self.inner_circle = True
self.gt_ds = None
self.mean = None
self.std = None

def setup(self, stage: Optional[str] = None):
lodopab = dival.get_standard_dataset('lodopab', impl='astra_cpu')
assert self.gt_shape <= self.IMG_SHAPE, 'GT is larger than original images.'
if self.gt_shape < self.IMG_SHAPE:
crop_off = (362 - self.gt_shape) // 2
gt_train = np.array([lodopab.get_sample(i, part='train', out=(False, True))[1][crop_off:-(crop_off + 1),
crop_off:-(crop_off + 1)] for i in
range(4000)])
gt_val = np.array([lodopab.get_sample(i, part='validation', out=(False, True))[1][crop_off:-(crop_off + 1),
crop_off:-(crop_off + 1)] for i in
range(400)])
gt_test = np.array([lodopab.get_sample(i, part='test', out=(False, True))[1][crop_off:-(crop_off + 1),
crop_off:-(crop_off + 1)] for i in
range(3553)])
else:
gt_train = np.array(
[lodopab.get_sample(i, part='train', out=(False, True))[1][1:, 1:] for i in range(4000)])
gt_val = np.array(
[lodopab.get_sample(i, part='validation', out=(False, True))[1][1:, 1:] for i in range(400)])
gt_test = np.array(
[lodopab.get_sample(i, part='test', out=(False, True))[1][1:, 1:] for i in range(3553)])

gt_train = torch.from_numpy(gt_train)
gt_val = torch.from_numpy(gt_val)
gt_test = torch.from_numpy(gt_test)

assert gt_train.shape[1] == self.gt_shape
assert gt_train.shape[2] == self.gt_shape
x, y = torch.meshgrid(torch.arange(-self.gt_shape // 2 + 1,
self.gt_shape // 2 + 1),
torch.arange(-self.gt_shape // 2 + 1,
self.gt_shape // 2 + 1))

self.mean = gt_train.mean()
self.std = gt_train.std()

gt_train = normalize(gt_train, self.mean, self.std)
gt_val = normalize(gt_val, self.mean, self.std)
gt_test = normalize(gt_test, self.mean, self.std)

circle = torch.sqrt(x ** 2. + y ** 2.) <= self.gt_shape // 2
gt_train *= circle
gt_val *= circle
gt_test *= circle

self.gt_ds = get_projection_dataset(
GroundTruthDataset(gt_train, gt_val, gt_test),
num_angles=self.num_angles, im_shape=self.gt_shape + (self.gt_shape // 2 - 7), impl='astra_cpu',
inner_circle=self.inner_circle)

tmp_fcds = TRecFourierCoefficientDataset(self.gt_ds, mag_min=None, mag_max=None, part='train',
img_shape=self.gt_shape)
Expand Down
4 changes: 2 additions & 2 deletions fit/datamodules/tomo_rec/TRecFCDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def __getitem__(self, item):
sino_mag = 2 * (sino_mag - self.mag_min) / (self.mag_max - self.mag_min) - 1
img_mag = 2 * (img_mag - self.mag_min) / (self.mag_max - self.mag_min) - 1

sino_phi = 2 * sino_phi / (2 * np.pi) - 1
img_phi = 2 * img_phi / (2 * np.pi) - 1
sino_phi = sino_phi / np.pi
img_phi = img_phi / np.pi

sino_fft = torch.stack([sino_mag.flatten(), sino_phi.flatten()], dim=-1)
img_fft = torch.stack([img_mag.flatten(), img_phi.flatten()], dim=-1)
Expand Down
Loading

0 comments on commit 4a0fff1

Please sign in to comment.