Skip to content

Commit

Permalink
fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgatis committed Apr 23, 2022
1 parent 7a4935a commit 77ce4d7
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 10 deletions.
12 changes: 6 additions & 6 deletions rembg/bg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ class ReturnType(Enum):


def alpha_matting_cutout(
img: Image,
mask: Image,
img: PILImage,
mask: PILImage,
foreground_threshold: int,
background_threshold: int,
erode_structure_size: int,
) -> Image:
) -> PILImage:
img = np.asarray(img)
mask = np.asarray(mask)

Expand Down Expand Up @@ -59,20 +59,20 @@ def alpha_matting_cutout(
return cutout


def naive_cutout(img: Image, mask: Image) -> Image:
def naive_cutout(img: PILImage, mask: PILImage) -> PILImage:
empty = Image.new("RGBA", (img.size), 0)
cutout = Image.composite(img, empty, mask)
return cutout


def get_concat_v_multi(imgs: List[Image]) -> Image:
def get_concat_v_multi(imgs: List[PILImage]) -> PILImage:
pivot = imgs.pop(0)
for im in imgs:
pivot = get_concat_v(pivot, im)
return pivot


def get_concat_v(img1: Image, img2: Image) -> Image:
def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage:
dst = Image.new("RGBA", (img1.width, img1.height + img2.height))
dst.paste(img1, (0, 0))
dst.paste(img2, (0, img1.height))
Expand Down
5 changes: 3 additions & 2 deletions rembg/session_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import onnxruntime as ort
from PIL import Image
from PIL.Image import Image as PILImage


class BaseSession:
Expand All @@ -12,7 +13,7 @@ def __init__(self, model_name: str, inner_session: ort.InferenceSession):

def normalize(
self,
img: Image,
img: PILImage,
mean: Tuple[float, float, float],
std: Tuple[float, float, float],
size: Tuple[int, int],
Expand All @@ -35,5 +36,5 @@ def normalize(
.astype(np.float32)
}

def predict(self, im: Image) -> List[Image]:
def predict(self, img: PILImage) -> List[PILImage]:
raise NotImplementedError
3 changes: 2 additions & 1 deletion rembg/session_cloth.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
from PIL import Image
from PIL.Image import Image as PILImage
from scipy.special import log_softmax

from .session_base import BaseSession
Expand Down Expand Up @@ -53,7 +54,7 @@


class ClothSession(BaseSession):
def predict(self, img: Image) -> List[Image]:
def predict(self, img: PILImage) -> List[PILImage]:
ort_outs = self.inner_session.run(
None, self.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), (768, 768))
)
Expand Down
3 changes: 2 additions & 1 deletion rembg/session_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import numpy as np
from PIL import Image
from PIL.Image import Image as PILImage

from .session_base import BaseSession


class SimpleSession(BaseSession):
def predict(self, img: Image) -> List[Image]:
def predict(self, img: PILImage) -> List[PILImage]:
ort_outs = self.inner_session.run(
None,
self.normalize(
Expand Down

0 comments on commit 77ce4d7

Please sign in to comment.