Merge branch 'master' of github.com:curegit/deep-descreen
curegit committed Apr 26, 2024
2 parents 40ce0b9 + 9f831f1 commit 346086c
Showing 8 changed files with 35 additions and 42 deletions.
.gitignore (1 addition, 0 deletions)
@@ -15,6 +15,7 @@
 !/descreen/**/*.icc
 *.ddbin
 !/descreen/**/*.ddbin
+*.pdf
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
descreen/cli.py (4 additions, 4 deletions)
@@ -5,17 +5,17 @@
 from pathlib import Path
 from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
 from .image import load_image, save_image, magick_wide_png, magick_srgb_png
-from .networks.model import DescreenModel, pull, names
+from .networks.model import DescreenModel, pull, names, default_name
 from .utilities.args import upper, nonempty, file, filelike, backend_devices
 
 
 def main():
     exit_code = 0
     parser = ArgumentParser(prog="descreen", allow_abbrev=False, formatter_class=ArgumentDefaultsHelpFormatter, description="")
-    parser.add_argument("image", metavar="FILE", type=filelike(exist=True), help="describe directory")
-    parser.add_argument("output", metavar="FILE", type=filelike(exist=False), nargs="?", default=..., help="describe input image files (pass '-' to specify stdin)")
+    parser.add_argument("image", metavar="IN_FILE", type=filelike(exist=True), help="describe directory")
+    parser.add_argument("output", metavar="OUT_FILE", type=filelike(exist=False), nargs="?", default=..., help="describe input image files (pass '-' to specify stdin)")
     dest_group = parser.add_mutually_exclusive_group()
-    dest_group.add_argument("-m", "--model", metavar="NAME", choices=names, default=names[0], help=f"send output to standard output {names}")
+    dest_group.add_argument("-m", "--model", metavar="NAME", choices=names, default=default_name, help=f"send output to standard output {names}")
     dest_group.add_argument("-d", "--ddbin", metavar="FILE", type=file(exist=True), help="save output images in DIR directory")
     dest_group.add_argument("-x", "--onnx", metavar="FILE", type=file(exist=True), help="save output images in DIR directory")
     parser.add_argument("-q", "--quantize", "--depth", type=int, default=8, choices=[8, 16], help="color depth of output PNG")
descreen/networks/model/__init__.py (1 addition, 0 deletions)
@@ -8,6 +8,7 @@
 
 names = list(files.keys())
 
+default_name = names[0]
 
 def pull(name: str):
     f = files.get(name)
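Note: a plausible reading of this pair of changes is that exposing default_name from the model package gives the CLI (and any other entry point) a single source of truth for the fallback model, instead of each call site re-deriving names[0]. A minimal sketch of the pattern; the files dict contents here are hypothetical stand-ins for the real registry, which is not shown in this diff:

# Hypothetical registry contents for illustration only.
files = {"basic": "basic.ddbin", "unet": "unet.ddbin"}

# dicts preserve insertion order (Python 3.7+), so names[0] is
# the first registered model.
names = list(files.keys())
default_name = names[0]

# Call sites consume the shared default instead of hard-coding it:
# parser.add_argument("-m", "--model", choices=names, default=default_name, ...)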
descreen/networks/model/basic/__init__.py (1 addition, 1 deletion)
@@ -10,7 +10,7 @@ def __init__(self, internal_channels=128, N=8):
         super(TopLevelModel, self).__init__()
         in_channels = out_channels = 3
         self.conv1 = nn.Conv2d(in_channels, internal_channels, kernel_size=3, padding=0)
-        self.blocks = nn.ModuleList([ResidualBlock(internal_channels, 25, nn.LeakyReLU(0.2)) for _ in range(N)])
+        self.blocks = nn.ModuleList([ResidualBlock(internal_channels, 25, nn.ReLU()) for _ in range(N)])
         self.conv2 = nn.Conv2d(internal_channels, out_channels, kernel_size=3, padding=0)
 
     def forward(self, x):
descreen/networks/model/unet/__init__.py (1 addition, 1 deletion)
@@ -18,7 +18,7 @@ def __init__(self, channels=256, N=4, large_k=13, bottom=False):
         # self.conv2 = nn.Conv2d(channels, channels, kernel_size=5, stride=1, padding=0)
         # self.a2 = nn.LeakyReLU(0.1)
         # self.conv3 = nn.Conv2d(channels, channels, kernel_size=5, stride=1, padding=0)
-        self.blocks = nn.ModuleList([ResidualBlock(channels, ksize=large_k, activation=nn.LeakyReLU(0.1)) for _ in range(N)])
+        self.blocks = nn.ModuleList([ResidualBlock(channels, ksize=large_k, activation=nn.ReLU()) for _ in range(N)])
        # self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=0)
 
     def forward(self, x, y=None):
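Note: both model variants make the same change, swapping nn.LeakyReLU for plain nn.ReLU inside their residual blocks. ResidualBlock itself is not part of this diff; the sketch below is a hypothetical reconstruction showing how injecting the activation as a module makes this a one-argument change (the real block's padding and layer layout may differ):

import torch
from torch import nn

class ResidualBlock(nn.Module):
    # Hypothetical: two same-channel convs with an injected activation.
    def __init__(self, channels: int, ksize: int, activation: nn.Module):
        super().__init__()
        pad = ksize // 2  # "same" padding so the skip connection lines up
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=ksize, padding=pad)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=ksize, padding=pad)
        self.activation = activation

    def forward(self, x):
        return x + self.conv2(self.activation(self.conv1(x)))

block = ResidualBlock(128, 25, nn.ReLU())  # mirrors the call in basic/__init__.py
y = block(torch.randn(1, 128, 64, 64))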
descreen/training/__init__.py (3 additions, 3 deletions)
@@ -1,5 +1,5 @@
-batch_size: int = 4
+batch_size: int = 8
 
-patch_size: int = 384
+patch_size: int = 32
 
-num_images: int = batch_size * 3000
+num_images: int = batch_size * 5000
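In concrete terms, this trades patch area for throughput: patches shrink from 384×384 to 32×32 (144× less area per sample), the batch grows from 4 to 8, and an epoch grows from 4 × 3000 = 12,000 samples to 8 × 5000 = 40,000 samples.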
descreen/training/data.py (8 additions, 3 deletions)
@@ -3,6 +3,7 @@
 import numpy as np
 import torch
 from pathlib import Path
+from functools import cache
 from collections.abc import Iterator
 from numpy import ndarray
 from torch import Tensor
@@ -44,8 +45,7 @@ def __len__(self) -> int:
         return len(self.files)
 
     def __getitem__(self, idx: int) -> tuple[ndarray, ndarray]:
-        path = self.files[idx]
-        img = load_image(path, transpose=False, normalize=False)
+        img = self.load_image_cached(idx)
         assert img.ndim == 3
         assert img.shape[2] == 3
         height, width = img.shape[:2]
@@ -112,6 +112,11 @@ def __getitem__(self, idx: int) -> tuple[ndarray, ndarray]:
         assert y.shape[1] == y.shape[2]
         return x, y
 
+    @cache
+    def load_image_cached(self, idx: int) -> ndarray:
+        path = self.files[idx]
+        return load_image(path, transpose=False, normalize=False)
+
     @once
     def save_example_pair(self, idx: int, x_png: bytes, y_png: bytes) -> None:
         with open(self.debug_dir / f"example-{idx}-x.png", "wb") as fp:
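A caveat worth knowing about functools.cache on a method: the cache is keyed on (self, idx), entries are never evicted, and it pins both the dataset object and every decoded image in memory for the life of the process. With the num_workers=8 loaders configured in proc.py below, each worker process also builds its own independent copy of the cache. A minimal sketch of the behavior, with hypothetical names:

from functools import cache

class Example:
    def __init__(self, files):
        self.files = files

    @cache  # keyed on (self, idx); entries live until the process exits
    def load_cached(self, idx):
        print(f"loading {self.files[idx]}")  # runs once per (self, idx)
        return self.files[idx].upper()

e = Example(["a.png", "b.png"])
e.load_cached(0)  # prints "loading a.png"
e.load_cached(0)  # cache hit, no print, no disk access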
@@ -146,7 +151,7 @@ def enumerate_loader[T: tuple[Tensor, ...]](data_loader: DataLoader[T], *, device=None):
     for batch in data_loader:
         counts = epoch, iters, samples
         n = len(batch[0])
-        yield counts, (batch if device is None else tuple(x.to(device) for x in batch))
+        yield counts, (batch if device is None else tuple(x.to(device, non_blocking=True) for x in batch))
         samples += n
         iters += 1
     epoch += 1
descreen/training/proc.py (16 additions, 30 deletions)
@@ -32,15 +32,16 @@ def train[
 
     optimizer = torch.optim.RAdam(model.parameters(), lr=0.001)
     # optimizer = torch.optim.SGD(model.parameters(), lr=0.002)
+    # optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
 
-    p = model.reduced_padding(patch_size)
-    print(p)
+    input_size = model.input_size(patch_size)
+    padding = model.reduced_padding(input_size)
+    assert model.output_size(input_size) == patch_size
 
-    print("OutputSize:", model.output_size(patch_size))
+    print("InputSize:", input_size)
+    print("OutputSize:", patch_size)
 
-    training_data = HalftonePairDataset(train_data_dir, profile, patch_size, p, augment=True, debug=True, debug_dir=output_dir).as_tensor()
-    valid_data = HalftonePairDataset(valid_data_dir, profile, patch_size, p).as_tensor()
+    training_data = HalftonePairDataset(train_data_dir, profile, input_size, padding, augment=True, debug=True, debug_dir=output_dir).as_tensor()
+    valid_data = HalftonePairDataset(valid_data_dir, profile, input_size, padding).as_tensor()
 
     def train_loop(max_samples: int, *, graceful: bool = True):
         interrupted = False
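The rewritten setup makes the size bookkeeping explicit for a network built from valid (padding=0) convolutions: the model must be fed a patch larger than the target so that, after the borders are consumed, exactly patch_size pixels remain, which the new assert pins down. The real input_size/output_size/reduced_padding methods are not in this diff; the sketch below illustrates the arithmetic counting only the two outer k=3, padding=0 convolutions visible in basic/__init__.py (the residual blocks may trim more in the real model):

# Illustrative size arithmetic under the stated assumptions.
def output_size(size: int, ksizes=(3, 3)) -> int:
    for k in ksizes:
        size -= k - 1  # a valid conv trims k - 1 pixels per axis
    return size

def input_size_for(patch_size: int, ksizes=(3, 3)) -> int:
    return patch_size + sum(k - 1 for k in ksizes)

patch_size = 32
inp = input_size_for(patch_size)       # 36
assert output_size(inp) == patch_size  # mirrors the assert added above
padding = (inp - patch_size) // 2      # 2 pixels trimmed per side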
@@ -55,8 +56,8 @@ def interrupt(signum, frame):
         default_sigint = signal.getsignal(signal.SIGINT)
         signal.signal(signal.SIGINT, signal.SIG_IGN)
 
-        train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=8, prefetch_factor=4, persistent_workers=True)
-        valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=8, prefetch_factor=4, persistent_workers=True)
+        train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=8, prefetch_factor=4, persistent_workers=True, pin_memory=True)
+        valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=8, prefetch_factor=4, persistent_workers=True, pin_memory=True)
         signal.signal(signal.SIGINT, interrupt)
         last_epoch = 0
         for (epoch, iters, samples), (x, y) in enumerate_loader(train_dataloader, device=device):
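These two DataLoader changes pair with the x.to(device, non_blocking=True) change in data.py above: a non-blocking host-to-device copy can only overlap with compute when the source tensor lives in page-locked (pinned) host memory, which pin_memory=True provides; without it, non_blocking is effectively a no-op. A minimal sketch of the pattern, with a hypothetical stand-in dataset:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data; the real code uses HalftonePairDataset.
data = TensorDataset(torch.randn(64, 3, 36, 36), torch.randn(64, 3, 32, 32))
loader = DataLoader(data, batch_size=8, pin_memory=True)  # batches land in pinned RAM

device = "cuda" if torch.cuda.is_available() else "cpu"
for x, y in loader:
    # With a pinned source, these copies are asynchronous and can overlap
    # with loading the next batch.
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)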
@@ -79,28 +80,13 @@ def interrupt(signum, frame):
         test_step()
 
     def train_step(x, y):
-        if False:
-            loss = None
-
-            def clos():
-                global loss
-                pred = model(x)
-                loss = loss_fn(pred, y)
-                print(f"loss: {loss}")
-                # Backpropagation
-                optimizer.zero_grad()
-                loss.backward()
-                return loss
-
-            optimizer.step(clos)
-        else:
-            pred = model(x)
-            loss = descreen_loss(pred, y, tv=0.15)
-            optimizer.zero_grad()
-            loss.backward()
-            print(f"loss: {loss}")
-            optimizer.step()
-            ema_model.update_parameters(model)
+        pred = model(x)
+        loss = descreen_loss(pred, y, tv=0.01)
+        optimizer.zero_grad()
+        loss.backward()
+        print(f"loss: {loss}")
+        optimizer.step()
+        ema_model.update_parameters(model)
 
     def valid_step(dataloader):
         # Set the model to evaluation mode - important for batch normalization and dropout layers
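The simplified train_step drops the dead LBFGS-style closure branch, lowers the total-variation weight from 0.15 to 0.01, and keeps the EMA update as the final step of each iteration. ema_model is not defined in this diff; its update_parameters call matches torch.optim.swa_utils.AveragedModel, so one plausible reading is sketched below, with the 0.999 decay an assumption for illustration:

import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel

model = nn.Linear(4, 4)  # stand-in for the descreen model
# Exponential moving average of the weights; decay is assumed, not from the diff.
ema_model = AveragedModel(model, avg_fn=lambda avg, new, n: 0.999 * avg + 0.001 * new)

optimizer = torch.optim.RAdam(model.parameters(), lr=0.001)
for _ in range(10):
    loss = model(torch.randn(8, 4)).square().mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema_model.update_parameters(model)  # fold current weights into the EMA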
