# Review preamble for this patch. Everything from the first "diff --git"
# header onward is the original patch text, reproduced unchanged.
#
# README.md
#   - Replaces the SAM CLI example that passed "input_labels"/"input_points"
#     with one that passes a "sam_prompt" list, and adds the captions
#     "SAM example" and "Custom model example".
#   - NOTE(review): both captions are added right after an opening ``` fence,
#     so they appear to land inside the code blocks -- confirm against the
#     rendered README.
#
# rembg/sessions/sam.py
#   - Adds imports: copy.deepcopy, cv2, jsonschema.validate.
#   - apply_coords(): drops the "-> np.ndarray" return annotation and copies
#     the input with deepcopy() instead of ndarray.copy().
#   - Removes resize_longes_side() and pad_to_square().
#   - Adds get_input_points(): a "point" mark contributes its "data" and
#     "label"; a "rectangle" mark contributes two points built from
#     data[0:2] and data[2:4], labelled 2 and 3. Returns two numpy arrays.
#   - Adds transform_masks(): runs cv2.warpAffine on every
#     masks[batch, mask_id] with transform_matrix[:2] and an output size of
#     (original_size[1], original_size[0]).
#   - SamSession.__init__ now forwards *args/**kwargs to download_models().
#   - normalize(): mean/std/size default to empty tuples and the method
#     returns img unchanged.
#   - predict(): reads kwargs["sam_prompt"], validates it against an inline
#     JSON schema, converts the image to RGB, scales it uniformly with
#     cv2.warpAffine into an output of size (1024, 684) as passed to cv2,
#     runs the encoder, pushes the prompt points through the same transform
#     matrix, runs the decoder, warps the masks back with the inverse
#     matrix, merges them into one image converted to mode "L", and returns
#     [mask].
#   - NOTE(review): the default for "sam_prompt" is the string "{}" while the
#     schema declares type "array"; presumably an omitted prompt is meant to
#     fail validation -- confirm.
#   - NOTE(review): the schema has no "required" keys and does not restrict
#     "type" values; an empty prompt list would reach np.concatenate with a
#     1-D empty array next to a (1, 2) array -- verify the intended handling.
#   - download_models(): file names are built from kwargs "sam_model"
#     (default "sam_vit_b_01ec64") and "sam_quant"; both pooch.retrieve calls
#     now pass None where the md5 strings used to be. When the encoder file
#     is "sam_vit_h_4b8939.encoder.onnx" and "sam_vit_h_4b8939.encoder_data.bin"
#     is missing, parts 1..3 are downloaded, concatenated into that file, and
#     each part is removed.
#   - NOTE(review): with None in place of the md5 strings the downloads look
#     unverified -- confirm this is intended.
#   - NOTE(review): open(fbin, "rb").read() is not wrapped in a with-block,
#     so the handle is not closed explicitly before os.remove(fbin).
#
# setup.py
#   - Adds "jsonschema" to install_requires.
#
# tests
#   - Three tests/results/*.sam.png files change (binary).
#   - tests/test_remove.py switches the "sam" kwargs to "sam_prompt" entries
#     and adds the "plants-1" picture.
#
diff --git a/README.md b/README.md index b5f695cc..34591005 100644 --- a/README.md +++ b/README.md @@ -159,10 +159,14 @@ rembg i -a path/to/input.png path/to/output.png Passing extras parameters ``` -rembg i -m sam -x '{"input_labels": [1], "input_points": [[100,100]]}' path/to/input.png path/to/output.png +SAM example + +rembg i -m sam -x '{ "sam_prompt": [{"type": "point", "data": [724, 740], "label": 1}] }' examples/plants-1.jpg examples/plants-1.out.png ``` ``` +Custom model example + rembg i -m u2net_custom -x '{"model_path": "~/.u2net/u2net.onnx"}' path/to/input.png path/to/output.png ``` diff --git a/rembg/sessions/sam.py b/rembg/sessions/sam.py index 5e524b0f..8f2d4ca4 100644 --- a/rembg/sessions/sam.py +++ b/rembg/sessions/sam.py @@ -1,9 +1,12 @@ import os +from copy import deepcopy from typing import List +import cv2 import numpy as np import onnxruntime as ort import pooch +from jsonschema import validate from PIL import Image from PIL.Image import Image as PILImage @@ -15,37 +18,58 @@ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int): newh, neww = oldh * scale, oldw * scale neww = int(neww + 0.5) newh = int(newh + 0.5) + return (newh, neww) -def apply_coords(coords: np.ndarray, original_size, target_length) -> np.ndarray: +def apply_coords(coords: np.ndarray, original_size, target_length): old_h, old_w = original_size new_h, new_w = get_preprocess_shape( original_size[0], original_size[1], target_length ) - coords = coords.copy().astype(float) + + coords = deepcopy(coords).astype(float) coords[..., 0] = coords[..., 0] * (new_w / old_w) coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords -def resize_longes_side(img: PILImage, size=1024): - w, h = img.size - if h > w: - new_h, new_w = size, int(w * size / h) - else: - new_h, new_w = int(h * size / w), size +def get_input_points(prompt): + points = [] + labels = [] + + for mark in prompt: + if mark["type"] == "point": + points.append(mark["data"]) + 
labels.append(mark["label"]) + elif mark["type"] == "rectangle": + points.append([mark["data"][0], mark["data"][1]]) + points.append([mark["data"][2], mark["data"][3]]) + labels.append(2) + labels.append(3) - return img.resize((new_w, new_h)) + points, labels = np.array(points), np.array(labels) + return points, labels -def pad_to_square(img: np.ndarray, size=1024): - h, w = img.shape[:2] - padh = size - h - padw = size - w - img = np.pad(img, ((0, padh), (0, padw), (0, 0)), mode="constant") - img = img.astype(np.float32) - return img +def transform_masks(masks, original_size, transform_matrix): + output_masks = [] + + for batch in range(masks.shape[0]): + batch_masks = [] + for mask_id in range(masks.shape[1]): + mask = masks[batch, mask_id] + mask = cv2.warpAffine( + mask, + transform_matrix[:2], + (original_size[1], original_size[0]), + flags=cv2.INTER_LINEAR, + ) + batch_masks.append(mask) + output_masks.append(batch_masks) + + return np.array(output_masks) class SamSession(BaseSession): @@ -70,7 +94,7 @@ def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwar **kwargs: Arbitrary keyword arguments. """ self.model_name = model_name - paths = self.__class__.download_models() + paths = self.__class__.download_models(*args, **kwargs) self.encoder = ort.InferenceSession( str(paths[0]), providers=ort.get_available_providers(), @@ -85,9 +109,9 @@ def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwar def normalize( self, img: np.ndarray, - mean=(123.675, 116.28, 103.53), - std=(58.395, 57.12, 57.375), - size=(1024, 1024), + mean=(), + std=(), + size=(), *args, **kwargs, ): @@ -96,19 +120,16 @@ def normalize( Args: img (np.ndarray): The input image. - mean (tuple, optional): The mean values for normalization. Defaults to (123.675, 116.28, 103.53). - std (tuple, optional): The standard deviation values for normalization. Defaults to (58.395, 57.12, 57.375). - size (tuple, optional): The target size of the image. 
Defaults to (1024, 1024). + mean (tuple, optional): The mean values for normalization. Defaults to (). + std (tuple, optional): The standard deviation values for normalization. Defaults to (). + size (tuple, optional): The target size of the image. Defaults to (). *args: Variable length argument list. **kwargs: Arbitrary keyword arguments. Returns: np.ndarray: The normalized image. """ - pixel_mean = np.array([*mean]).reshape(1, 1, -1) - pixel_std = np.array([*std]).reshape(1, 1, -1) - x = (img - pixel_mean) / pixel_std - return x + return img def predict( self, @@ -129,36 +150,89 @@ def predict( Returns: List[PILImage]: A list of masks generated by the decoder. """ - # Preprocess image - image = resize_longes_side(img) - image = np.array(image) - image = self.normalize(image) - image = pad_to_square(image) - - input_labels = kwargs.get("input_labels") - input_points = kwargs.get("input_points") - - if input_labels is None: - raise ValueError("input_labels is required") - if input_points is None: - raise ValueError("input_points is required") - - # Transpose - image = image.transpose(2, 0, 1)[None, :, :, :] - # Run encoder (Image embedding) - encoded = self.encoder.run(None, {"x": image}) - image_embedding = encoded[0] - - # Add a batch index, concatenate a padding point, and transform. 
+ prompt = kwargs.get("sam_prompt", "{}") + schema = { + "type": "array", + "items": { + "type": "object", + "properties": { + "type": {"type": "string"}, + "label": {"type": "integer"}, + "data": { + "type": "array", + "items": {"type": "number"}, + }, + }, + }, + } + + validate(instance=prompt, schema=schema) + + target_size = 1024 + input_size = (684, 1024) + encoder_input_name = self.encoder.get_inputs()[0].name + + img = img.convert("RGB") + cv_image = np.array(img) + original_size = cv_image.shape[:2] + + scale_x = input_size[1] / cv_image.shape[1] + scale_y = input_size[0] / cv_image.shape[0] + scale = min(scale_x, scale_y) + + transform_matrix = np.array( + [ + [scale, 0, 0], + [0, scale, 0], + [0, 0, 1], + ] + ) + + cv_image = cv2.warpAffine( + cv_image, + transform_matrix[:2], + (input_size[1], input_size[0]), + flags=cv2.INTER_LINEAR, + ) + + ## encoder + + encoder_inputs = { + encoder_input_name: cv_image.astype(np.float32), + } + + encoder_output = self.encoder.run(None, encoder_inputs) + image_embedding = encoder_output[0] + + embedding = { + "image_embedding": image_embedding, + "original_size": original_size, + "transform_matrix": transform_matrix, + } + + ## decoder + + input_points, input_labels = get_input_points(prompt) onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[ None, :, : ] onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[ None, : ].astype(np.float32) - onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32) + onnx_coord = apply_coords(onnx_coord, input_size, target_size).astype( + np.float32 + ) + + onnx_coord = np.concatenate( + [ + onnx_coord, + np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32), + ], + axis=2, + ) + onnx_coord = np.matmul(onnx_coord, transform_matrix.T) + onnx_coord = onnx_coord[:, :, :2].astype(np.float32) - # Create an empty mask input and an indicator for no mask. 
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32) onnx_has_mask_input = np.zeros(1, dtype=np.float32) @@ -168,17 +242,19 @@ def predict( "point_labels": onnx_label, "mask_input": onnx_mask_input, "has_mask_input": onnx_has_mask_input, - "orig_im_size": np.array(img.size[::-1], dtype=np.float32), + "orig_im_size": np.array(input_size, dtype=np.float32), } - masks, _, low_res_logits = self.decoder.run(None, decoder_inputs) - masks = masks > 0.0 - masks = [ - Image.fromarray((masks[i, 0] * 255).astype(np.uint8)) - for i in range(masks.shape[0]) - ] + masks, _, _ = self.decoder.run(None, decoder_inputs) + inv_transform_matrix = np.linalg.inv(transform_matrix) + masks = transform_masks(masks, original_size, inv_transform_matrix) + + mask = np.zeros((masks.shape[2], masks.shape[3], 3), dtype=np.uint8) + for m in masks[0, :, :, :]: + mask[m > 0.0] = [255, 255, 255] - return masks + mask = Image.fromarray(mask).convert("L") + return [mask] @classmethod def download_models(cls, *args, **kwargs): @@ -195,29 +271,64 @@ def download_models(cls, *args, **kwargs): Returns: tuple: A tuple containing the file paths of the downloaded encoder and decoder models. 
""" - fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx" - fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx" + model_name = kwargs.get("sam_model", "sam_vit_b_01ec64") + quant = kwargs.get("sam_quant", False) + + fname_encoder = f"{model_name}.encoder.onnx" + fname_decoder = f"{model_name}.decoder.onnx" + + if quant: + fname_encoder = f"{model_name}.encoder.quant.onnx" + fname_decoder = f"{model_name}.decoder.quant.onnx" pooch.retrieve( - "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx", - None - if cls.checksum_disabled(*args, **kwargs) - else "md5:13d97c5c79ab13ef86d67cbde5f1b250", + f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_encoder}", + None, fname=fname_encoder, path=cls.u2net_home(*args, **kwargs), progressbar=True, ) pooch.retrieve( - "https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx", - None - if cls.checksum_disabled(*args, **kwargs) - else "md5:fa3d1c36a3187d3de1c8deebf33dd127", + f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_decoder}", + None, fname=fname_decoder, path=cls.u2net_home(*args, **kwargs), progressbar=True, ) + if fname_encoder == "sam_vit_h_4b8939.encoder.onnx" and not os.path.exists( + os.path.join( + cls.u2net_home(*args, **kwargs), "sam_vit_h_4b8939.encoder_data.bin" + ) + ): + content = bytearray() + + for i in range(1, 4): + pooch.retrieve( + f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/sam_vit_h_4b8939.encoder_data.{i}.bin", + None, + fname=f"sam_vit_h_4b8939.encoder_data.{i}.bin", + path=cls.u2net_home(*args, **kwargs), + progressbar=True, + ) + + fbin = os.path.join( + cls.u2net_home(*args, **kwargs), + f"sam_vit_h_4b8939.encoder_data.{i}.bin", + ) + content.extend(open(fbin, "rb").read()) + os.remove(fbin) + + with open( + os.path.join( + cls.u2net_home(*args, **kwargs), + "sam_vit_h_4b8939.encoder_data.bin", + ), + "wb", + ) as fp: + fp.write(content) + return ( 
os.path.join(cls.u2net_home(*args, **kwargs), fname_encoder), os.path.join(cls.u2net_home(*args, **kwargs), fname_decoder), diff --git a/setup.py b/setup.py index dccc6b61..7b44aa0f 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ long_description = (here / "README.md").read_text(encoding="utf-8") install_requires = [ + "jsonschema", "numpy", "onnxruntime", "opencv-python-headless", diff --git a/tests/results/anime-girl-1.sam.png b/tests/results/anime-girl-1.sam.png index daf576f0..46ae2635 100644 Binary files a/tests/results/anime-girl-1.sam.png and b/tests/results/anime-girl-1.sam.png differ diff --git a/tests/results/car-1.sam.png b/tests/results/car-1.sam.png index f36b6695..cbfa1f92 100644 Binary files a/tests/results/car-1.sam.png and b/tests/results/car-1.sam.png differ diff --git a/tests/results/cloth-1.sam.png b/tests/results/cloth-1.sam.png index 664a7dcf..42d853fe 100644 Binary files a/tests/results/cloth-1.sam.png and b/tests/results/cloth-1.sam.png differ diff --git a/tests/test_remove.py b/tests/test_remove.py index 7f9901c4..b35caa38 100644 --- a/tests/test_remove.py +++ b/tests/test_remove.py @@ -12,18 +12,19 @@ def test_remove(): kwargs = { "sam": { "anime-girl-1" : { - "input_points": [[400, 165]], - "input_labels": [1], + "sam_prompt" :[{"type": "point", "data": [400, 165], "label": 1}], }, "car-1" : { - "input_points": [[250, 200]], - "input_labels": [1], + "sam_prompt" :[{"type": "point", "data": [250, 200], "label": 1}], }, "cloth-1" : { - "input_points": [[370, 495]], - "input_labels": [1], + "sam_prompt" :[{"type": "point", "data": [370, 495], "label": 1}], + }, + + "plants-1" : { + "sam_prompt" :[{"type": "point", "data": [724, 740], "label": 1}], }, } } @@ -38,7 +39,7 @@ def test_remove(): "isnet-anime", "sam" ]: - for picture in ["anime-girl-1", "car-1", "cloth-1"]: + for picture in ["anime-girl-1", "car-1", "cloth-1", "plants-1"]: image_path = Path(here / "fixtures" / f"{picture}.jpg") image = image_path.read_bytes()