Skip to content

Commit

Permalink
fix sam session
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgatis committed Oct 26, 2023
1 parent 0d0f6c6 commit c01b1e0
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 76 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,14 @@ rembg i -a path/to/input.png path/to/output.png
Passing extra parameters

```
rembg i -m sam -x '{"input_labels": [1], "input_points": [[100,100]]}' path/to/input.png path/to/output.png
SAM example
rembg i -m sam -x '{ "sam_prompt": [{"type": "point", "data": [724, 740], "label": 1}] }' examples/plants-1.jpg examples/plants-1.out.png
```

```
Custom model example
rembg i -m u2net_custom -x '{"model_path": "~/.u2net/u2net.onnx"}' path/to/input.png path/to/output.png
```

Expand Down
247 changes: 179 additions & 68 deletions rembg/sessions/sam.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import os
from copy import deepcopy
from typing import List

import cv2
import numpy as np
import onnxruntime as ort
import pooch
from jsonschema import validate
from PIL import Image
from PIL.Image import Image as PILImage

Expand All @@ -15,37 +18,58 @@ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
newh, neww = oldh * scale, oldw * scale
neww = int(neww + 0.5)
newh = int(newh + 0.5)

return (newh, neww)


def apply_coords(coords: np.ndarray, original_size, target_length) -> np.ndarray:
    """Rescale point coordinates from the original image to the resized image.

    Args:
        coords (np.ndarray): Array whose last axis holds coordinate pairs;
            index 0 is scaled by the width ratio and index 1 by the height
            ratio.
        original_size: ``(height, width)`` of the original image.
        target_length: Forwarded to ``get_preprocess_shape`` as the long-side
            length of the resized image.

    Returns:
        np.ndarray: A new float array with the scaled coordinates. The input
        array is left untouched.
    """
    old_h, old_w = original_size
    new_h, new_w = get_preprocess_shape(old_h, old_w, target_length)

    # ``astype`` returns a fresh array by default, so no separate copy is
    # needed to protect the caller's data from the in-place scaling below.
    coords = np.asarray(coords).astype(float)
    coords[..., 0] *= new_w / old_w
    coords[..., 1] *= new_h / old_h

    return coords


def resize_longes_side(img: "PILImage", size=1024):
    """Resize an image so that its longer side equals ``size``.

    The other side is scaled by the same factor and truncated to an integer.

    Args:
        img (PILImage): Image exposing ``size`` as ``(width, height)`` and a
            ``resize`` method.
        size (int, optional): Target length of the longer side. Defaults to 1024.

    Returns:
        The value returned by ``img.resize((new_w, new_h))``.
    """
    w, h = img.size
    if h > w:
        new_h, new_w = size, int(w * size / h)
    else:
        new_h, new_w = int(h * size / w), size

    return img.resize((new_w, new_h))


def get_input_points(prompt):
    """Convert a list of prompt marks into point and label arrays.

    Each mark is a mapping with a ``"type"`` key:

    * ``"point"``: ``mark["data"]`` is appended as one point and
      ``mark["label"]`` as its label.
    * ``"rectangle"``: ``mark["data"]`` holds four numbers; the first two and
      the last two are appended as two points, labelled ``2`` and ``3``.

    Marks of any other type are ignored.

    Args:
        prompt: Iterable of mark mappings as described above.

    Returns:
        tuple: ``(points, labels)`` as NumPy arrays. With no usable marks the
        points array has shape ``(0, 2)`` and the labels array shape ``(0,)``,
        so the result can still be concatenated with other ``(k, 2)`` point
        arrays.
    """
    points = []
    labels = []

    for mark in prompt:
        kind = mark["type"]
        if kind == "point":
            points.append(mark["data"])
            labels.append(mark["label"])
        elif kind == "rectangle":
            x0, y0, x1, y1 = mark["data"][:4]
            points.extend([[x0, y0], [x1, y1]])
            labels.extend([2, 3])

    if not points:
        # np.array([]) would be 1-D with shape (0,), which cannot be stacked
        # with 2-D point arrays; keep the (n, 2) layout even when n == 0.
        return np.empty((0, 2)), np.empty((0,))

    return np.array(points), np.array(labels)


def pad_to_square(img: np.ndarray, size=1024):
    """Zero-pad an image at the bottom and right up to ``size`` x ``size``.

    Args:
        img (np.ndarray): Array whose first two axes are height and width.
            Any trailing axes (e.g. channels) are left unpadded.
        size (int, optional): Target height and width. Defaults to 1024.

    Returns:
        np.ndarray: The padded image converted to ``float32``.

    Raises:
        ValueError: If the image is larger than ``size`` along its height or
            width.
    """
    h, w = img.shape[:2]
    if h > size or w > size:
        raise ValueError(
            f"image of shape {h}x{w} does not fit into a {size}x{size} square"
        )

    # Pad only the two spatial axes; every trailing axis gets (0, 0).
    pad_width = ((0, size - h), (0, size - w)) + ((0, 0),) * (img.ndim - 2)
    padded = np.pad(img, pad_width, mode="constant")
    return padded.astype(np.float32)
def transform_masks(masks, original_size, transform_matrix):
    """Warp every mask with an affine transform onto the original image size.

    Args:
        masks: Array indexed as ``masks[batch, mask_id]``; each such entry is
            handed to ``cv2.warpAffine`` as a single image.
        original_size: ``(height, width)`` of the output masks.
        transform_matrix: Matrix whose first two rows are used as the affine
            transform passed to ``cv2.warpAffine``.

    Returns:
        np.ndarray: The warped masks, nested in the same batch / mask order as
        the input.
    """
    affine = transform_matrix[:2]
    # cv2.warpAffine takes the output size as (width, height).
    dsize = (original_size[1], original_size[0])

    return np.array(
        [
            [
                cv2.warpAffine(mask, affine, dsize, flags=cv2.INTER_LINEAR)
                for mask in batch_masks
            ]
            for batch_masks in masks
        ]
    )


class SamSession(BaseSession):
Expand All @@ -70,7 +94,7 @@ def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwar
**kwargs: Arbitrary keyword arguments.
"""
self.model_name = model_name
paths = self.__class__.download_models()
paths = self.__class__.download_models(*args, **kwargs)
self.encoder = ort.InferenceSession(
str(paths[0]),
providers=ort.get_available_providers(),
Expand All @@ -85,9 +109,9 @@ def __init__(self, model_name: str, sess_opts: ort.SessionOptions, *args, **kwar
def normalize(
self,
img: np.ndarray,
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
size=(1024, 1024),
mean=(),
std=(),
size=(),
*args,
**kwargs,
):
Expand All @@ -96,19 +120,16 @@ def normalize(
Args:
img (np.ndarray): The input image.
mean (tuple, optional): The mean values for normalization. Defaults to (123.675, 116.28, 103.53).
std (tuple, optional): The standard deviation values for normalization. Defaults to (58.395, 57.12, 57.375).
size (tuple, optional): The target size of the image. Defaults to (1024, 1024).
mean (tuple, optional): The mean values for normalization. Defaults to ().
std (tuple, optional): The standard deviation values for normalization. Defaults to ().
size (tuple, optional): The target size of the image. Defaults to ().
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Returns:
np.ndarray: The normalized image.
"""
pixel_mean = np.array([*mean]).reshape(1, 1, -1)
pixel_std = np.array([*std]).reshape(1, 1, -1)
x = (img - pixel_mean) / pixel_std
return x
return img

def predict(
self,
Expand All @@ -129,36 +150,89 @@ def predict(
Returns:
List[PILImage]: A list of masks generated by the decoder.
"""
# Preprocess image
image = resize_longes_side(img)
image = np.array(image)
image = self.normalize(image)
image = pad_to_square(image)

input_labels = kwargs.get("input_labels")
input_points = kwargs.get("input_points")

if input_labels is None:
raise ValueError("input_labels is required")
if input_points is None:
raise ValueError("input_points is required")

# Transpose
image = image.transpose(2, 0, 1)[None, :, :, :]
# Run encoder (Image embedding)
encoded = self.encoder.run(None, {"x": image})
image_embedding = encoded[0]

# Add a batch index, concatenate a padding point, and transform.
prompt = kwargs.get("sam_prompt", "{}")
schema = {
"type": "array",
"items": {
"type": "object",
"properties": {
"type": {"type": "string"},
"label": {"type": "integer"},
"data": {
"type": "array",
"items": {"type": "number"},
},
},
},
}

validate(instance=prompt, schema=schema)

target_size = 1024
input_size = (684, 1024)
encoder_input_name = self.encoder.get_inputs()[0].name

img = img.convert("RGB")
cv_image = np.array(img)
original_size = cv_image.shape[:2]

scale_x = input_size[1] / cv_image.shape[1]
scale_y = input_size[0] / cv_image.shape[0]
scale = min(scale_x, scale_y)

transform_matrix = np.array(
[
[scale, 0, 0],
[0, scale, 0],
[0, 0, 1],
]
)

cv_image = cv2.warpAffine(
cv_image,
transform_matrix[:2],
(input_size[1], input_size[0]),
flags=cv2.INTER_LINEAR,
)

## encoder

encoder_inputs = {
encoder_input_name: cv_image.astype(np.float32),
}

encoder_output = self.encoder.run(None, encoder_inputs)
image_embedding = encoder_output[0]

embedding = {
"image_embedding": image_embedding,
"original_size": original_size,
"transform_matrix": transform_matrix,
}

## decoder

input_points, input_labels = get_input_points(prompt)
onnx_coord = np.concatenate([input_points, np.array([[0.0, 0.0]])], axis=0)[
None, :, :
]
onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
None, :
].astype(np.float32)
onnx_coord = apply_coords(onnx_coord, img.size[::1], 1024).astype(np.float32)
onnx_coord = apply_coords(onnx_coord, input_size, target_size).astype(
np.float32
)

onnx_coord = np.concatenate(
[
onnx_coord,
np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32),
],
axis=2,
)
onnx_coord = np.matmul(onnx_coord, transform_matrix.T)
onnx_coord = onnx_coord[:, :, :2].astype(np.float32)

# Create an empty mask input and an indicator for no mask.
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)

Expand All @@ -168,17 +242,19 @@ def predict(
"point_labels": onnx_label,
"mask_input": onnx_mask_input,
"has_mask_input": onnx_has_mask_input,
"orig_im_size": np.array(img.size[::-1], dtype=np.float32),
"orig_im_size": np.array(input_size, dtype=np.float32),
}

masks, _, low_res_logits = self.decoder.run(None, decoder_inputs)
masks = masks > 0.0
masks = [
Image.fromarray((masks[i, 0] * 255).astype(np.uint8))
for i in range(masks.shape[0])
]
masks, _, _ = self.decoder.run(None, decoder_inputs)
inv_transform_matrix = np.linalg.inv(transform_matrix)
masks = transform_masks(masks, original_size, inv_transform_matrix)

mask = np.zeros((masks.shape[2], masks.shape[3], 3), dtype=np.uint8)
for m in masks[0, :, :, :]:
mask[m > 0.0] = [255, 255, 255]

return masks
mask = Image.fromarray(mask).convert("L")
return [mask]

@classmethod
def download_models(cls, *args, **kwargs):
Expand All @@ -195,29 +271,64 @@ def download_models(cls, *args, **kwargs):
Returns:
tuple: A tuple containing the file paths of the downloaded encoder and decoder models.
"""
fname_encoder = f"{cls.name(*args, **kwargs)}_encoder.onnx"
fname_decoder = f"{cls.name(*args, **kwargs)}_decoder.onnx"
model_name = kwargs.get("sam_model", "sam_vit_b_01ec64")
quant = kwargs.get("sam_quant", False)

fname_encoder = f"{model_name}.encoder.onnx"
fname_decoder = f"{model_name}.decoder.onnx"

if quant:
fname_encoder = f"{model_name}.encoder.quant.onnx"
fname_decoder = f"{model_name}.decoder.quant.onnx"

pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx",
None
if cls.checksum_disabled(*args, **kwargs)
else "md5:13d97c5c79ab13ef86d67cbde5f1b250",
f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_encoder}",
None,
fname=fname_encoder,
path=cls.u2net_home(*args, **kwargs),
progressbar=True,
)

pooch.retrieve(
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx",
None
if cls.checksum_disabled(*args, **kwargs)
else "md5:fa3d1c36a3187d3de1c8deebf33dd127",
f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/{fname_decoder}",
None,
fname=fname_decoder,
path=cls.u2net_home(*args, **kwargs),
progressbar=True,
)

if fname_encoder == "sam_vit_h_4b8939.encoder.onnx" and not os.path.exists(
os.path.join(
cls.u2net_home(*args, **kwargs), "sam_vit_h_4b8939.encoder_data.bin"
)
):
content = bytearray()

for i in range(1, 4):
pooch.retrieve(
f"https://github.com/danielgatis/rembg/releases/download/v0.0.0/sam_vit_h_4b8939.encoder_data.{i}.bin",
None,
fname=f"sam_vit_h_4b8939.encoder_data.{i}.bin",
path=cls.u2net_home(*args, **kwargs),
progressbar=True,
)

fbin = os.path.join(
cls.u2net_home(*args, **kwargs),
f"sam_vit_h_4b8939.encoder_data.{i}.bin",
)
content.extend(open(fbin, "rb").read())
os.remove(fbin)

with open(
os.path.join(
cls.u2net_home(*args, **kwargs),
"sam_vit_h_4b8939.encoder_data.bin",
),
"wb",
) as fp:
fp.write(content)

return (
os.path.join(cls.u2net_home(*args, **kwargs), fname_encoder),
os.path.join(cls.u2net_home(*args, **kwargs), fname_decoder),
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
long_description = (here / "README.md").read_text(encoding="utf-8")

install_requires = [
"jsonschema",
"numpy",
"onnxruntime",
"opencv-python-headless",
Expand Down
Binary file modified tests/results/anime-girl-1.sam.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/results/car-1.sam.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/results/cloth-1.sam.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit c01b1e0

Please sign in to comment.