diff --git a/rembg/bg.py b/rembg/bg.py index 071af09d..4ab03714 100644 --- a/rembg/bg.py +++ b/rembg/bg.py @@ -21,12 +21,12 @@ class ReturnType(Enum): def alpha_matting_cutout( - img: Image, - mask: Image, + img: PILImage, + mask: PILImage, foreground_threshold: int, background_threshold: int, erode_structure_size: int, -) -> Image: +) -> PILImage: img = np.asarray(img) mask = np.asarray(mask) @@ -59,20 +59,20 @@ def alpha_matting_cutout( return cutout -def naive_cutout(img: Image, mask: Image) -> Image: +def naive_cutout(img: PILImage, mask: PILImage) -> PILImage: empty = Image.new("RGBA", (img.size), 0) cutout = Image.composite(img, empty, mask) return cutout -def get_concat_v_multi(imgs: List[Image]) -> Image: +def get_concat_v_multi(imgs: List[PILImage]) -> PILImage: pivot = imgs.pop(0) for im in imgs: pivot = get_concat_v(pivot, im) return pivot -def get_concat_v(img1: Image, img2: Image) -> Image: +def get_concat_v(img1: PILImage, img2: PILImage) -> PILImage: dst = Image.new("RGBA", (img1.width, img1.height + img2.height)) dst.paste(img1, (0, 0)) dst.paste(img2, (0, img1.height)) diff --git a/rembg/session_base.py b/rembg/session_base.py index dceee2dc..aa98693b 100644 --- a/rembg/session_base.py +++ b/rembg/session_base.py @@ -3,6 +3,7 @@ import numpy as np import onnxruntime as ort from PIL import Image +from PIL.Image import Image as PILImage class BaseSession: @@ -12,7 +13,7 @@ def __init__(self, model_name: str, inner_session: ort.InferenceSession): def normalize( self, - img: Image, + img: PILImage, mean: Tuple[float, float, float], std: Tuple[float, float, float], size: Tuple[int, int], @@ -35,5 +36,5 @@ def normalize( .astype(np.float32) } - def predict(self, im: Image) -> List[Image]: + def predict(self, img: PILImage) -> List[PILImage]: raise NotImplementedError diff --git a/rembg/session_cloth.py b/rembg/session_cloth.py index ada82da8..11bcef74 100644 --- a/rembg/session_cloth.py +++ b/rembg/session_cloth.py @@ -2,6 +2,7 @@ import numpy as np from PIL import Image +from PIL.Image import Image as PILImage from scipy.special import log_softmax from .session_base import BaseSession @@ -53,7 +54,7 @@ class ClothSession(BaseSession): - def predict(self, img: Image) -> List[Image]: + def predict(self, img: PILImage) -> List[PILImage]: ort_outs = self.inner_session.run( None, self.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), (768, 768)) ) diff --git a/rembg/session_simple.py b/rembg/session_simple.py index 978c06ec..7ec31813 100644 --- a/rembg/session_simple.py +++ b/rembg/session_simple.py @@ -2,12 +2,13 @@ import numpy as np from PIL import Image +from PIL.Image import Image as PILImage from .session_base import BaseSession class SimpleSession(BaseSession): - def predict(self, img: Image) -> List[Image]: + def predict(self, img: PILImage) -> List[PILImage]: ort_outs = self.inner_session.run( None, self.normalize(