Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
curegit committed Mar 13, 2024
1 parent 7ea8b43 commit 6161c97
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 25 deletions.
1 change: 1 addition & 0 deletions descreen/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = "0.5.0"
11 changes: 2 additions & 9 deletions descreen/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,15 @@

import numpy as np

#from .training import model
# from .training import model

from .networks.models import DescreenModel
from .networks.models import pull


def main():
device = "cpu"

#from torch.optim.swa_utils import AveragedModel

#global model

model = DescreenModel.deserialize(sys.argv[1])

#model.load_state_dict()

model.to(device)
model.eval()
print(model)
Expand Down
11 changes: 5 additions & 6 deletions descreen/networks/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from ...utilities.filesys import resolve_path, self_relpath as rel



files = {
"basic": rel("./unet/model.ddbin.xz"),
}
Expand All @@ -23,8 +22,6 @@ def pull(name: str):
return DescreenModel.load(files[name])




class DescreenModelType(ABCMeta):

aliases = {}
Expand All @@ -33,11 +30,11 @@ def __new__(meta, name, bases, attributes, **kwargs):
cls = super().__new__(meta, name, bases, attributes, **kwargs)
print("new")
try:
#alias = attributes["alias"]()
# alias = attributes["alias"]()
alias = cls.alias()
except Exception:
raise
#return cls
# return cls
if alias in DescreenModelType.aliases:
raise RuntimeError()
DescreenModelType.aliases[alias] = cls
Expand Down Expand Up @@ -131,7 +128,8 @@ def deserialize(filelike: str | Path | bytes | IOBase):
(i,) = struct.unpack(a := "!H", fp.read(struct.calcsize(a)))
alias = fp.read(i).decode()
cls = DescreenModelType.by_alias(alias)
return cls.load(fp)
model: DescreenModel = cls.load(fp)
return model

def serialize_weight(self) -> bytes:
return safetensors.torch.save(self.state_dict(), metadata=None)
Expand All @@ -153,6 +151,7 @@ def serialize(self, filelike: str | Path | IOBase):
fp.write(js)
fp.write(self.serialize_weight())


from .unet import UNetLikeModel

UNetLikeModel()
2 changes: 0 additions & 2 deletions descreen/networks/models/unet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ def forward(self, x, y=None):
out = block(out)
return out



def input_size(self, output_size):
for block in self.blocks:
output_size = block.input_size(output_size)
Expand Down
2 changes: 1 addition & 1 deletion descreen/training/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def clos():
optimizer.step(clos)
else:
pred = model(x)
loss = loss_fn(pred, y) + (0.5 * total_variation(pred)).mean()
loss = loss_fn(pred, y) + 10 * total_variation(pred)
optimizer.zero_grad()
loss.backward()
print(f"loss: {loss}")
Expand Down
19 changes: 12 additions & 7 deletions descreen/training/loss.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@

from torch import Tensor

def total_variation(x: Tensor) -> Tensor:
b, c, h, w = x.shape
pixel_dif1 = x[..., 1:, :] - x[..., :-1, :]
pixel_dif2 = x[..., :, 1:] - x[..., :, :-1]
reduce_axes = (-3, -2, -1)
return (pixel_dif1.abs().sum(dim=reduce_axes) + pixel_dif2.abs().sum(dim=reduce_axes)) / (c * h * w * b)

def total_variation(x: Tensor, mean=True) -> Tensor:
*b, c, h, w = x.shape
assert h >= 2 and w >= 2
diff1 = (x[..., 1:, :] - x[..., :-1, :]).abs()
diff2 = (x[..., :, 1:] - x[..., :, :-1]).abs()
reduce = (-3, -2, -1)
loss = (diff1.sum(dim=reduce) + diff2.sum(dim=reduce)) / (c * (h - 1) * (w - 1))
if mean:
return loss.mean()
else:
return loss.sum()
9 changes: 9 additions & 0 deletions descreen/utilities/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import operator as op
from functools import reduce


def prod(xs, start=1):
return reduce(op.mul, xs, start)


from collections.abc import Iterable

from typing import TypeVar


T = TypeVar("T")


Expand Down

0 comments on commit 6161c97

Please sign in to comment.