Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
curegit committed Mar 12, 2024
1 parent 22330f2 commit 7ea8b43
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 138 deletions.
35 changes: 17 additions & 18 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from chainer import serializers

import numpy as np
#import cupy as cp

# import cupy as cp


i = Image.open("test2.bmp")
Expand All @@ -14,7 +14,7 @@
# 学習結果を読み込む
serializers.load_hdf5("nn2.hdf5", model)

#model.to_gpu()
# model.to_gpu()


# 出力画像
Expand All @@ -23,27 +23,26 @@
# 入力画像を分割
cur_x = 0
while cur_x <= i.size[0] - 64:
cur_y = 0
while cur_y <= i.size[1] - 64:
# 画像から切りだし
rect = (cur_x, cur_y, cur_x+64, cur_y+64)
cropimg = i.crop(rect)
cur_y = 0
while cur_y <= i.size[1] - 64:
# 画像から切りだし
rect = (cur_x, cur_y, cur_x + 64, cur_y + 64)
cropimg = i.crop(rect)

x = (np.array(cropimg, dtype=np.uint8).transpose((2, 0, 1)) / 255).astype(np.float32).reshape((1, 3, 64, 64))
x = (np.array(cropimg, dtype=np.uint8).transpose((2, 0, 1)) / 255).astype(np.float32).reshape((1, 3, 64, 64))

print(x.shape)
print(x.shape)

t = model(x)
t = model(x)

print(type(t.data))
arr = np.rint(t.data.astype(np.float64).reshape((3, 64, 64)).transpose((1, 2, 0)) * 255).astype(np.uint8).flatten()

print(type(t.data))
arr = np.rint(t.data.astype(np.float64).reshape((3, 64, 64)).transpose((1, 2, 0)) * 255).astype(np.uint8).flatten()
himg = Image.frombuffer("RGB", (64, 64), arr, "raw", "RGB", 0, 1)
# himg = Image.fromarray(bytes, 'raw')
dst.paste(himg, (cur_x, cur_y))

himg = Image.frombuffer("RGB", (64, 64), arr, "raw", "RGB", 0, 1)
#himg = Image.fromarray(bytes, 'raw')
dst.paste(himg, (cur_x, cur_y))

cur_y += 64
cur_x += 64
cur_y += 64
cur_x += 64

dst.save("dehalfa4.png")
18 changes: 10 additions & 8 deletions descreen/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,25 @@

import numpy as np

from .training import model
#from .training import model

from .networks.models import DescreenModel
from .networks.models import pull

def main():
device = "cpu"

from torch.optim.swa_utils import AveragedModel
#from torch.optim.swa_utils import AveragedModel

global model
model1 = AveragedModel(model)
model = model1
model.load_state_dict(torch.load(sys.argv[1]))
#global model

model = DescreenModel.deserialize(sys.argv[1])

#model.load_state_dict()

model.to(device)
model.eval()
print(model)
model = model.module

# img = read_uint16_image(sys.argv[3])

Expand All @@ -39,7 +41,7 @@ def main():
y = z.detach().cpu().numpy()[0]
print(y.shape)
dest[:, k, l] = y
result = dest[:, *crop]
result = dest[:, crop[0], crop[1]]

buf = BytesIO()
save_image(result, buf, prefer16=True)
Expand Down
4 changes: 2 additions & 2 deletions descreen/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ def magick_png(input_img: bytes, args: list[str], *, png48: bool = False) -> byt
e.returncode
match e.stderr:
case str() as stderr:
pass
print(stderr)
case bytes() as bstderr:
bstderr.decode()
print(bstderr.decode())
raise


Expand Down
58 changes: 2 additions & 56 deletions descreen/networks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,14 @@
import json
import numpy as np
import torch
from pathlib import Path
from abc import ABC, abstractmethod
from numpy import ndarray
from torch import Tensor
from torch.nn import Module
from ..utilities import range_chunks


class AbsModule(Module, ABC):


@abstractmethod
def forward(self, x: Tensor) -> Tensor:
raise NotImplementedError()

@property
def multiple_of(self) -> int:
return 1

@abstractmethod
def output_size(self, input_size: int) -> int:
raise NotImplementedError()
Expand All @@ -28,55 +17,12 @@ def output_size(self, input_size: int) -> int:
def input_size(self, output_size: int) -> int:
raise NotImplementedError()

def reduced_padding(self, input_size: int):
def reduced_padding(self, input_size: int) -> int:
output_size = self.output_size(input_size)
assert (input_size - output_size) % 2 == 0
return (input_size - output_size) // 2

def required_padding(self, output_size: int):
def required_padding(self, output_size: int) -> int:
input_size = self.input_size(output_size)
assert (input_size - output_size) % 2 == 0
return (input_size - output_size) // 2

def patch(self, x: ndarray, input_patch_size: int, *, pad_mode: str = "symmetric", **kwargs):
if input_patch_size % self.multiple_of != 0:
raise RuntimeError()
out_patch_size = self.output_size(input_patch_size)
p = self.required_padding(out_patch_size)
height, width = x.shape[-2:]

#self.patch_slices(height, width, out_patch_size, p)

qh = input_patch_size - self.patch_slices_remainder(height + 2 * p, input_patch_size, p)
qw = input_patch_size - self.patch_slices_remainder(width + 2 * p, input_patch_size, p)
y = np.pad(x, (*([(0, 0)] * (x.ndim - 2)), (p, qh + p), (p, qw + p)), mode=pad_mode, **kwargs)
h_crop = slice(p, height + p)
w_crop = slice(p, width + p)
return y, self.patch_slices(height + qh, width + qw, out_patch_size, p), (h_crop, w_crop)

def patch_slices(self, height: int, width: int, output_patch_size: int, padding: int):
for h_start, h_stop in range_chunks(height, output_patch_size):
for w_start, w_stop in range_chunks(width, output_patch_size):
h_pad = padding #self.required_padding(h_stop - h_start)
w_pad = padding #self.required_padding(w_stop - w_start)
print(h_stop - h_start, w_stop - w_start, output_patch_size)
assert h_stop - h_start == w_stop - w_start == output_patch_size
h_slice = slice(h_start - h_pad + padding, h_stop + h_pad + padding)
w_slice = slice(w_start - w_pad + padding, w_stop + w_pad + padding)
h_dest_slice = slice(h_start + padding, h_stop + padding)
w_dest_slice = slice(w_start + padding, w_stop + padding)
yield (h_slice, w_slice), (h_dest_slice, w_dest_slice)

def patch_slices_remainder(self, length: int, input_patch_size: int, padding: int):
cur = 0
while length - cur >= input_patch_size:
cur += input_patch_size - padding * 2
return length - cur


#input_size = self.input_size(output_patch_size)
for start, stop in range_chunks(length, output_patch_size):
pad = self.required_padding(stop - start)
size = stop - start + 2 * pad
print(size % input_size)
return size % input_size
99 changes: 74 additions & 25 deletions descreen/networks/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import json
import struct
import safetensors.torch
import numpy as np
from io import IOBase, BytesIO
from pathlib import Path
from abc import ABC, abstractmethod
from .basic import TopLevelModel
from abc import ABCMeta, abstractmethod
from numpy import ndarray
from .. import AbsModule
from ...utilities import range_chunks
from ...utilities.filesys import resolve_path, self_relpath as rel



from .. import AbsModule


files = {
"basic": rel("./unet/model.ddbin.xz"),
}
Expand All @@ -20,49 +20,96 @@


def pull(name: str):
with open(files[name]):
m = cls.load(file)
return m
return DescreenModel.load(files[name])

def compressed

class DescreenModelType(type):


class DescreenModelType(ABCMeta):

aliases = {}

def __new__(meta, name, bases, attributes, **kwargs):
cls = super().__new__(meta, name, bases, attributes, **kwargs)
print("new")
try:
alias = attributes.alias()
#alias = attributes["alias"]()
alias = cls.alias()
except Exception:
return cls
raise
#return cls
if alias in DescreenModelType.aliases:
raise RuntimeError()
DescreenModelType.aliases[alias] = cls
return cls

@staticmethod
def M(alias: str):
return DescreenModelType.aliases["alias"]
def by_alias(alias: str):
print(DescreenModelType.aliases)
return DescreenModelType.aliases[alias]


class DescreenModel(AbsModule, ABC, metaclass=DescreenModelType):
class DescreenModel(AbsModule, metaclass=DescreenModelType):

_params_json: dict = {}

def __new__(cls, **kwargs):
obj = super().__new__(cls)
params_json = json.dumps(kwargs, skipkeys=False, ensure_ascii=True, allow_nan=False)
AbsModule._params_json[id(obj)] = params_json
DescreenModel._params_json[id(obj)] = params_json
return obj

def load_weight(self, bytes):
self.load_state_dict(bytes)
@classmethod
@abstractmethod
def alias(cls) -> str:
print(cls.__name__)
return cls.__name__

@property
def multiple_of(self) -> int:
return 1

def patch(self, x: ndarray, input_patch_size: int, *, pad_mode: str = "symmetric", **kwargs):
if input_patch_size % self.multiple_of != 0:
raise RuntimeError()
out_patch_size = self.output_size(input_patch_size)
p = self.required_padding(out_patch_size)
height, width = x.shape[-2:]

# self.patch_slices(height, width, out_patch_size, p)

qh = input_patch_size - self.patch_slices_remainder(height + 2 * p, input_patch_size, p)
qw = input_patch_size - self.patch_slices_remainder(width + 2 * p, input_patch_size, p)
y = np.pad(x, (*([(0, 0)] * (x.ndim - 2)), (p, qh + p), (p, qw + p)), mode=pad_mode, **kwargs)
h_crop = slice(p, height + p)
w_crop = slice(p, width + p)
return y, self.patch_slices(height + qh, width + qw, out_patch_size, p), (h_crop, w_crop)

def patch_slices(self, height: int, width: int, output_patch_size: int, padding: int):
for h_start, h_stop in range_chunks(height, output_patch_size):
for w_start, w_stop in range_chunks(width, output_patch_size):
h_pad = padding # self.required_padding(h_stop - h_start)
w_pad = padding # self.required_padding(w_stop - w_start)
print(h_stop - h_start, w_stop - w_start, output_patch_size)
assert h_stop - h_start == w_stop - w_start == output_patch_size
h_slice = slice(h_start - h_pad + padding, h_stop + h_pad + padding)
w_slice = slice(w_start - w_pad + padding, w_stop + w_pad + padding)
h_dest_slice = slice(h_start + padding, h_stop + padding)
w_dest_slice = slice(w_start + padding, w_stop + padding)
yield (h_slice, w_slice), (h_dest_slice, w_dest_slice)

def patch_slices_remainder(self, length: int, input_patch_size: int, padding: int):
cur = 0
while length - cur >= input_patch_size:
cur += input_patch_size - padding * 2
return length - cur

def load_weight(self, buffer: bytes):
self.load_state_dict(safetensors.torch.load(buffer))

@classmethod
def load(cls, byteslike: ReadableBuffer):
l, = struct.unpack_from("!I", buffer)
def load(cls, byteslike: bytes | IOBase):
(l,) = struct.unpack(f := "!I", byteslike.read(struct.calcsize(f)))
js = byteslike.read(l).decode()
kwargs = json.loads(js)
model = cls(**kwargs)
Expand All @@ -81,16 +128,14 @@ def deserialize(filelike: str | Path | bytes | IOBase):
case _:
raise TypeError()
with fp:
i, = struct.unpack_from("!H", fp)
(i,) = struct.unpack(a := "!H", fp.read(struct.calcsize(a)))
alias = fp.read(i).decode()
cls = DescreenModelType.by_alias(alias)
return cls.load(fp)


def serialize_weight(self) -> bytes:
return safetensors.torch.save(self.state_dict(), metadata=None)


def serialize(self, filelike: str | Path | IOBase):
match filelike:
case str() | Path() as path:
Expand All @@ -100,10 +145,14 @@ def serialize(self, filelike: str | Path | IOBase):
case _:
raise TypeError()
with fp:
ab = self.alias.encode()
ab = self.alias().encode()
fp.write(struct.pack("!H", len(ab)))
fp.write(ab)
js = json.dumps(self.kwargs).encode()
js = DescreenModel._params_json[id(self)].encode()
fp.write(struct.pack("!I", len(js)))
fp.write(js)
fp.write(self.serialize_weight())

from .unet import UNetLikeModel

UNetLikeModel()
Loading

0 comments on commit 7ea8b43

Please sign in to comment.