Skip to content

Commit

Permalink
Add tailing detail resnet module
Browse files Browse the repository at this point in the history
  • Loading branch information
curegit committed Jun 1, 2024
1 parent 346086c commit 0770ff2
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 9 deletions.
9 changes: 9 additions & 0 deletions descreen/networks/model/abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pathlib import Path
from abc import ABCMeta, abstractmethod
from numpy import ndarray
from torch import Tensor
from .. import AbsModule
from ...utilities import range_chunks
from ...utilities.filesys import resolve_path
Expand Down Expand Up @@ -57,6 +58,14 @@ def __deepcopy__(self, memo):
DescreenModel.params_json[id(cp)] = DescreenModel.params_json[id(self)]
return cp

@abstractmethod
def forward_t(self, x: Tensor) -> tuple[Tensor, Tensor]:
raise NotImplementedError()

def forward(self, x: Tensor) -> Tensor:
a, b = self.forward_t(x)
return b

@classmethod
@abstractmethod
def alias(cls) -> str:
Expand Down
11 changes: 8 additions & 3 deletions descreen/networks/model/basic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import torch
import torch.nn as nn
from .. import DescreenModel
from ...resnet import RepeatedResidualBlock
from ...modules import ResidualBlock
from ....utilities.array import fit_to_smaller_add
from ....utilities.array import fit_to_smaller, fit_to_smaller_add


class TopLevelModel(DescreenModel):
Expand All @@ -12,14 +13,18 @@ def __init__(self, internal_channels=128, N=8):
self.conv1 = nn.Conv2d(in_channels, internal_channels, kernel_size=3, padding=0)
self.blocks = nn.ModuleList([ResidualBlock(internal_channels, 25, nn.ReLU()) for _ in range(N)])
self.conv2 = nn.Conv2d(internal_channels, out_channels, kernel_size=3, padding=0)
self.resnet = RepeatedResidualBlock(3, 3, internal_channels)

def forward(self, x):
def forward_t(self, x):
residual = x
out = self.conv1(x)
for block in self.blocks:
out = block(out)
out = self.conv2(out)
return fit_to_smaller_add(residual, out)
r = self.resnet(out)
h = fit_to_smaller_add(out, r)
m, ff = fit_to_smaller(out, h)
return m, ff

@classmethod
def alias(cls) -> str:
Expand Down
14 changes: 10 additions & 4 deletions descreen/networks/model/unet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .. import DescreenModel
from ... import AbsModule
from ...modules import ResidualBlock, Lanczos2xUpsampler
from ...resnet import RepeatedResidualBlock
from ...utils import input_size, output_size
from ....utilities.array import fit_to_smaller, fit_to_smaller_add

Expand Down Expand Up @@ -63,12 +64,17 @@ def __init__(self, channels=128):
# self.up2 = Lanczos2xUpsampler(n=2, pad=False)
# self.up3 = Lanczos2xUpsampler(n=2, pad=False)
self.out = nn.Conv2d(channels, 3, kernel_size=3, stride=1, padding=0)
self.resnet = RepeatedResidualBlock(3, 3, channels)

def forward(self, x):
def forward_t(self, x):
z = self.down(x)
h1 = self.lower_block(z)
h2 = self.upper_block(x, h1)
return fit_to_smaller_add(x, self.out(h2))
h = fit_to_smaller_add(x, self.out(h2))
r = self.resnet(h)
f = fit_to_smaller_add(h, r)
m, ff = fit_to_smaller(h, f)
return m, ff

@classmethod
def alias(cls) -> str:
Expand All @@ -79,7 +85,7 @@ def multiple_of(self) -> int:
return 2

def input_size_unchecked(self, s):
return self.lower_block.input_size(self.upper_block.input_size(input_size(s, 3))) * 2
return self.lower_block.input_size(self.upper_block.input_size(input_size(self.resnet.input_size(s), 3))) * 2

def output_size_unchecked(self, s):
return output_size(self.upper_block.output_size(self.lower_block.output_size(s // 2)), 3)
return self.resnet.output_size(output_size(self.upper_block.output_size(self.lower_block.output_size(s // 2)), 3))
22 changes: 22 additions & 0 deletions descreen/networks/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,25 @@ def input_size_unchecked(self, output_size: int) -> int:

def output_size_unchecked(self, input_size: int) -> int:
return output_size(output_size(output_size(input_size, self.ksize), 3), 3)


class SimpleResidualBlock(AbsModule):
def __init__(self, channels, activation) -> None:
super().__init__()
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=0)
self.activation = activation
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=0)

def forward(self, x: Tensor) -> Tensor:
residual = x
out = self.conv1(x)
out = self.activation(out)
out = self.conv2(out)
return fit_to_smaller_add(residual, out)

def input_size_unchecked(self, output_size: int) -> int:
return input_size(input_size(output_size, 3), 3)

def output_size_unchecked(self, input_size: int) -> int:
return output_size(output_size(input_size, 3), 3)

33 changes: 33 additions & 0 deletions descreen/networks/resnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@

import torch.nn as nn
from torch import Tensor
from . import AbsModule
from .modules import SimpleResidualBlock
from .utils import input_size, output_size

class RepeatedResidualBlock(AbsModule):
def __init__(self, in_channels, out_channels, inner_channels, activation=nn.ReLU(), n=8) -> None:
super().__init__()
self.in_conv = nn.Conv2d(in_channels, inner_channels, kernel_size=1)
self.out_conv = nn.Conv2d(inner_channels, out_channels, kernel_size=1)
self.blocks = nn.ModuleList([SimpleResidualBlock(inner_channels, activation) for _ in range(n)])

def forward(self, x: Tensor) -> Tensor:
out = self.in_conv(x)
for block in self.blocks:
out = block(out)
out = self.out_conv(out)
return out

def input_size_unchecked(self, output_size: int) -> int:
size = output_size
for _ in range(len(self.blocks)):
size = input_size(size, 3)
return size

def output_size_unchecked(self, input_size: int) -> int:
size = input_size
for _ in range(len(self.blocks)):
size = output_size(size, 3)
return size

4 changes: 2 additions & 2 deletions descreen/training/proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def interrupt(signum, frame):
test_step()

def train_step(x, y):
pred = model(x)
loss = descreen_loss(pred, y, tv=0.01)
pred_mid, pred_full = model.forward_t(x)
loss = descreen_loss(pred_mid, y, tv=0.01) + 1.5 * descreen_loss(pred_full, y, tv=0.01)
optimizer.zero_grad()
loss.backward()
print(f"loss: {loss}")
Expand Down

0 comments on commit 0770ff2

Please sign in to comment.