Update pnn
curegit committed Feb 20, 2024
1 parent ac1177d commit 4b55f50
Showing 16 changed files with 584 additions and 190 deletions.
Empty file added descreen/__init__.py
5 changes: 5 additions & 0 deletions descreen/__main__.py
@@ -0,0 +1,5 @@
import sys
from .cli import main

if __name__ == "__main__":
    sys.exit(main())
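With this entry point, the package runs as a module and propagates the return value of main() as the process exit status:

python -m descreen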
53 changes: 53 additions & 0 deletions descreen/cli.py
@@ -0,0 +1,53 @@
import sys

import numpy as np
import torch
import typer

from .image import load_image, save_image
from .models.unet import UNetLikeModel
from .utilities import eprint, magickpng  # project helpers; exact import path assumed


def main():
    exit_code = 0
    try:
        # The try body is not shown in this diff; dispatching to con() via typer
        # is an assumption
        typer.run(con)
    except KeyboardInterrupt:
        eprint("KeyboardInterrupt")
        exit_code = 130
    return exit_code


def con():
    device = "cpu"

    model = UNetLikeModel()
    model.load_state_dict(torch.load(sys.argv[2]))
    model.to(device)
    model.eval()

    patch_size = model.output_size(512)
    with open(sys.argv[3], "rb") as fp:
        img = load_image(magickpng(fp.read(), png48=True), assert16=True)

    h, w = img.shape[1:3]
    # TODO: align to a multiple of 4
    ppp_h = h % 512
    ppp_w = w % 512
    a_h = h + ppp_h
    a_w = w + ppp_w
    img = img.reshape((1, 3, h, w))
    res = np.zeros((3, a_h, a_w), dtype="float32")
    p = model.required_padding(patch_size)

    # Mirror-pad so every tile sees a full receptive field of context
    img = np.pad(img, ((0, 0), (0, 0), (p, p + ppp_h), (p, p + ppp_w)), mode="symmetric")
    for (j, i), (k, l) in model.patch_slices(a_h, a_w, patch_size):
        x = img[:, :, j, i]
        t = torch.from_numpy(x.astype("float32")).to(device)
        with torch.no_grad():
            y = model(t)
        res[:, k, l] = y.detach().cpu().numpy()[0]
    res = res[:, :h, :w]
    save_image(res, sys.argv[4])
    # save_wide_gamut_uint16_array_as_srgb(res, sys.argv[4])
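The inference loop tiles the image: the input is mirror-padded by required_padding(patch_size) so every tile carries full context, each tile's valid output is written into res, and the canvas is cropped back to h x w at the end. A worked sketch of the arithmetic, with illustrative numbers rather than the real model's sizes:

# Suppose model.output_size(512) == 472, i.e. the network trims 20 px per side;
# then patch_size == 472 and p == model.required_padding(472) == 20.
# For h == 1000: ppp_h == 1000 % 512 == 488 and a_h == 1488, so np.pad appends
# enough mirrored rows for the final partial tiles. Each outer slice (j, i)
# reads a tile plus 2 * p rows/cols of context from the padded input, and the
# matching inner slice (k, l) writes the valid output into res.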
53 changes: 53 additions & 0 deletions descreen/image.py
@@ -0,0 +1,53 @@
import cv2
import numpy as np
from pathlib import Path
from numpy import ndarray
from PIL import Image
from .utilities.filesys import resolve_path


# Load an image as an array from a file path, path object, or raw bytes
def load_image(filelike: str | Path | bytes, *, transpose: bool = True, normalize: bool = True, orient: bool = True, assert16: bool = False) -> ndarray:
    match filelike:
        case str() | Path() as path:
            with open(resolve_path(path), "rb") as fp:
                buffer = fp.read()
        case bytes() as buffer:
            pass
        case _:
            raise ValueError()
    # Decode from an in-memory buffer to work around OpenCV handling only ASCII paths
    data = np.frombuffer(buffer, np.uint8)
    flags = cv2.IMREAD_COLOR | cv2.IMREAD_ANYDEPTH
    if not orient:
        flags |= cv2.IMREAD_IGNORE_ORIENTATION
    img = cv2.imdecode(data, flags)
    if img.ndim != 3 or img.shape[2] != 3:
        raise RuntimeError()
    if transpose:
        # BGR x H x W -> RGB x H x W
        img = img[:, :, [2, 1, 0]].transpose(2, 0, 1)
    match img.dtype:
        case np.uint8:
            if assert16:
                raise RuntimeError()
            if normalize:
                return (img / (2**8 - 1)).astype(np.float32)
            else:
                return img
        case np.uint16:
            if normalize:
                return (img / (2**16 - 1)).astype(np.float32)
            else:
                return img
        case _:
            raise RuntimeError()



def to_pil_image(array):
    srgb = np.rint(array * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(srgb.transpose(1, 2, 0), "RGB")


def save_image(array, filepath):
    to_pil_image(array).save(filepath)
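A minimal usage sketch (file names are placeholders): by default load_image returns a float32 array shaped RGB x H x W with values in [0, 1], and save_image quantizes back through an 8-bit sRGB PIL image:

from descreen.image import load_image, save_image

img = load_image("scan.png")   # float32, shape (3, H, W), values in [0, 1]
save_image(img, "out.png")     # rounded and clipped to 8-bit RGB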

Empty file added descreen/models/__init__.py
39 changes: 39 additions & 0 deletions descreen/models/abs.py
@@ -0,0 +1,39 @@
from abc import ABC, abstractmethod

from torch import nn

from ..utilities import range_chunks


class AbsModel(ABC, nn.Module):

    @abstractmethod
    def forward(self, x):
        raise NotImplementedError()

    @abstractmethod
    def output_size(self, input_size):
        raise NotImplementedError()

    @abstractmethod
    def input_size(self, output_size):
        raise NotImplementedError()

    def reduced_padding(self, input_size):
        output_size = self.output_size(input_size)
        return (input_size - output_size) // 2

    def required_padding(self, output_size):
        input_size = self.input_size(output_size)
        return (input_size - output_size) // 2

    def patch_slices(self, height, width, patch_size):
        for h_start, h_stop in range_chunks(height, patch_size):
            for w_start, w_stop in range_chunks(width, patch_size):
                h_pad = self.required_padding(h_stop - h_start)
                w_pad = self.required_padding(w_stop - w_start)
                h_s = slice(h_start, h_stop)
                w_s = slice(w_start, w_stop)
                # Outer slices index the padded input, whose leading pad already
                # offsets the origin by one padding amount, so each window
                # spans [start, stop + 2 * pad)
                h_slice = slice(h_start, h_stop + h_pad * 2)
                w_slice = slice(w_start, w_stop + w_pad * 2)
                yield (h_slice, w_slice), (h_s, w_s)
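The two padding helpers express the same bookkeeping from opposite ends: a model that needs input_size(s) input pixels to emit s valid output pixels loses required_padding(s) pixels on each side. A hypothetical check, assuming a model whose layers trim 6 pixels per side in total:

# model.input_size(100) == 112   =>  model.required_padding(100) == 6
# model.output_size(112) == 100  =>  model.reduced_padding(112) == 6
# patch_slices then pairs an outer slice into the padded input with an inner
# slice into the output canvas, as consumed by the loop in descreen/cli.py.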
154 changes: 154 additions & 0 deletions descreen/models/unet/__init__.py
@@ -0,0 +1,154 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..abs import AbsModel


def input_size(output_size, kernel_size, stride=1, padding=0):
    return ((output_size - 1) * stride) + kernel_size - 2 * padding


def output_size(input_size, kernel_size, stride=1, padding=0):
    return (input_size - kernel_size + 2 * padding) // stride + 1


def fit_to_smaller(x, y):
    # Center-crop both tensors to their common (smaller) spatial size
    b, c, h1, w1 = x.shape
    _, _, h2, w2 = y.shape

    h = min(h1, h2)
    w = min(w1, w2)

    h1_start = (h1 - h) // 2
    h1_end = h1_start + h
    w1_start = (w1 - w) // 2
    w1_end = w1_start + w

    h2_start = (h2 - h) // 2
    h2_end = h2_start + h
    w2_start = (w2 - w) // 2
    w2_end = w2_start + w

    x = x[:, :, h1_start:h1_end, w1_start:w1_end]
    y = y[:, :, h2_start:h2_end, w2_start:w2_end]

    return x, y



class UNetLikeModelLevel(AbsModel):
    def __init__(self, channels=256, bottom=False):
        super().__init__()
        self.bottom = bottom
        if not bottom:
            self.up = Lanczos2xUpsampler(n=2, pad=False)
        self.conv1 = nn.Conv2d(3 if bottom else 3 + channels, channels, kernel_size=3, stride=1, padding=0)
        self.a1 = nn.LeakyReLU(0.1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=5, stride=1, padding=0)
        self.a2 = nn.LeakyReLU(0.1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.conv3 = nn.Conv2d(channels, channels, kernel_size=5, stride=1, padding=0)

    def forward(self, x, y=None):
        if not self.bottom:
            # Upsample the coarser level and concatenate it with the skip input
            h1 = torch.cat(fit_to_smaller(x, self.up(y)), dim=1)
        else:
            assert y is None
            h1 = x
        h2 = self.a1(self.conv1(h1))
        h3 = self.a2(self.conv2(self.bn1(h2)))
        h4 = self.conv3(self.bn2(h3))
        # Residual connection between the first and last feature maps
        h2_, h4_ = fit_to_smaller(h2, h4)
        return h2_ + h4_

    def input_size(self, output_size):
        if self.bottom:
            return input_size(input_size(input_size(output_size, 5), 5), 3)
        else:
            return input_size(input_size(input_size(output_size, 5), 5), 3) // 2 + 4

    def output_size(self, input_size):
        if self.bottom:
            return output_size(output_size(output_size(input_size, 3), 5), 5)
        else:
            return output_size(output_size(output_size((input_size - 4) * 2, 3), 5), 5)






class UNetLikeModel(AbsModel):
    def __init__(self, channels=128, residual=False):
        super().__init__()
        self.residual = residual
        self.block3 = UNetLikeModelLevel(channels)
        self.block2 = UNetLikeModelLevel(channels)
        self.block1 = UNetLikeModelLevel(channels, bottom=True)
        self.av1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.av2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.layer4 = nn.Conv2d(channels, 3, kernel_size=3, stride=1, padding=0)

    def forward(self, x):
        # Encoder: two average-pooled scales below the full resolution
        m1 = x
        m2 = self.av1(m1)
        m3 = self.av2(m2)
        # Decoder: process the coarsest level first, feeding each into the next
        h1 = self.block1(m3)
        h2 = self.block2(m2, h1)
        h3 = self.block3(m1, h2)
        # Global residual: add the center-cropped input back to the output
        _h4, r = fit_to_smaller(self.layer4(h3), x)
        return _h4 + r

    def input_size(self, s):
        return self.block1.input_size(self.block2.input_size(self.block3.input_size(input_size(s, 3)))) * 4

    def output_size(self, s):
        s = s // 4
        return output_size(self.block3.output_size(self.block2.output_size(self.block1.output_size(s))), 3)


def lanczos(x, n):
    return 0.0 if abs(x) > n else np.sinc(x) * np.sinc(x / n)

class Lanczos2xUpsampler(nn.Module):
    def __init__(self, n=3, pad=True):
        super().__init__()
        # Four fixed separable kernels, one per output phase (quarter-pixel offsets)
        start = np.array([lanczos(i + 0.25, n) for i in range(-n, n)])
        end = np.array([lanczos(i + 0.75, n) for i in range(-n, n)])
        s = start / np.sum(start)
        e = end / np.sum(end)
        k1 = np.pad(s.reshape(1, n * 2) * s.reshape(n * 2, 1), ((0, 1), (0, 1)))
        k2 = np.pad(e.reshape(1, n * 2) * s.reshape(n * 2, 1), ((0, 1), (1, 0)))
        k3 = np.pad(s.reshape(1, n * 2) * e.reshape(n * 2, 1), ((1, 0), (0, 1)))
        k4 = np.pad(e.reshape(1, n * 2) * e.reshape(n * 2, 1), ((1, 0), (1, 0)))
        w = torch.tensor(np.array([[k1], [k2], [k3], [k4]], dtype=np.float32))
        self.register_buffer("w", w)
        self.n = n
        self.pad = pad

    def forward(self, x):
        b, c, h, w = x.shape
        h1 = x.view(b * c, 1, h, w)
        if self.pad:
            h2 = F.pad(h1, (self.n, self.n, self.n, self.n), mode="reflect")
        else:
            h2 = h1
        # Convolve with all four phase kernels, then interleave the phases
        h3 = F.conv2d(h2, self.w)
        h4 = F.pixel_shuffle(h3, 2)
        if self.pad:
            return h4.view(b, c, h * 2, w * 2)
        else:
            return h4.view(b, c, (h - 2 * self.n) * 2, (w - 2 * self.n) * 2)
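A quick shape check on the upsampler's geometry (a sketch, assuming torch is importable): with pad=False, each (2n + 1)-tap convolution consumes n pixels per margin before pixel_shuffle doubles the resolution, so the output is (h - 2n) * 2 per axis:

up = Lanczos2xUpsampler(n=2, pad=False)
y = up(torch.zeros(1, 3, 16, 16))
assert y.shape == (1, 3, 24, 24)  # (16 - 2 * 2) * 2 == 24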
Empty file added descreen/networks/__init__.py
6 changes: 6 additions & 0 deletions descreen/networks/utils.py
@@ -0,0 +1,6 @@
def input_size(output_size, kernel_size, stride=1, padding=0):
    return ((output_size - 1) * stride) + kernel_size - 2 * padding


def output_size(input_size, kernel_size, stride=1, padding=0):
    return (input_size - kernel_size + 2 * padding) // stride + 1
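These are the standard valid-convolution size formulas; with stride 1 they are exact inverses, while for stride > 1 the floor division makes output_size many-to-one:

assert output_size(224, 3) == 222 and input_size(222, 3) == 224
assert output_size(13, 3, stride=2) == output_size(14, 3, stride=2) == 6
assert input_size(6, 3, stride=2) == 13  # one canonical preimage of 6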