diff --git a/real-esrgan/README.md b/real-esrgan/README.md new file mode 100644 index 00000000..d6631a28 --- /dev/null +++ b/real-esrgan/README.md @@ -0,0 +1,93 @@ +# Real-ESRGAN Truss + +This is a [Truss](https://truss.baseten.co/) for Real-ESRGAN which is an AI image upscaling model. +Open-source image generation models like Stable Diffusion 1.5 can sometime produce blurry or low resolution images. Using Real-ESRGAN, those low quality images can be upscaled making them look sharper and more detailed. + +## Deployment + +First, clone this repository: + +``` +git clone https://github.com/basetenlabs/truss-examples/ +cd real-esrgan-truss +``` + +Before deployment: + +1. Make sure you have a [Baseten account](https://app.baseten.co/signup) and [API key](https://app.baseten.co/settings/account/api_keys). +2. Install the latest version of Truss: `pip install --upgrade truss` + +With `real-esrgan-truss` as your working directory, you can deploy the model with: + +``` +truss push +``` + +Paste your Baseten API key if prompted. + +For more information, see [Truss documentation](https://truss.baseten.co). + +## API route: `predict` +The predict route is the primary method for upscaling an image. In order to send the image to our model, the image must first be converted into a base64 string. + +- __image__: The image converted to a base64 string + + +## Invoking the model + +```sh +truss predict -d '{"image": ""}' +``` + +You can also use python to call the model: + +```python +BASE64_PREAMBLE = "data:image/png;base64," + +def pil_to_b64(pil_img): + buffered = BytesIO() + pil_img.save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") + return img_str + +def b64_to_pil(b64_str): + return Image.open(BytesIO(base64.b64decode(b64_str.replace(BASE64_PREAMBLE, "")))) + +img = Image.open("/path/to/image/ship.jpeg") +b64_img = pil_to_b64(img) + +headers = {"Authorization": f"Api-Key "} +data = {"image": b64_img} +res = requests.post("https://app.baseten.co/model_versions/{model_version}/predict", headers=headers, json=data) +output = res.json() + +result_b64 = output.get("model_output").get("upscaled_image") +pil_img = b64_to_pil(result_b64) +pil_img.save("upscaled_output_img.png") +``` + +The model returns a JSON object containing the key `upscaled_image`, which is the upscaled image as a base64 string. + +## Results + +
+
+ original image +

Original Image Stable Diffusion 1.5

+
+
+ upscaled image +

Upscaled Image

+
+
+ +
+
+ original image +

Original Image SDXL

+
+
+ upscaled image +

Upscaled Image

+
+
diff --git a/real-esrgan/config.yaml b/real-esrgan/config.yaml new file mode 100644 index 00000000..44af1186 --- /dev/null +++ b/real-esrgan/config.yaml @@ -0,0 +1,29 @@ +environment_variables: {} +external_package_dirs: [] +model_metadata: + example_model_input: {"image": "BASE64-STRING-HERE"} +model_name: real-esrgan +python_version: py310 +requirements: + - numpy==1.23.5 + - torch==2.0.1 + - torchvision==0.15.2 + - facexlib==0.3.0 + - gfpgan==1.3.8 + - basicsr==1.4.2 + - opencv-python==4.8.0.76 + - opencv-python-headless==4.8.1.78 + - Pillow==9.4.0 + - tqdm==4.66.1 +resources: + cpu: "3" + memory: 14Gi + use_gpu: true + accelerator: T4 +secrets: {} +system_packages: + - libgl1-mesa-glx + - libglib2.0-0 +external_data: + - url: https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth + local_data_path: weights/RealESRGAN_x4plus.pth diff --git a/real-esrgan/model/__init__.py b/real-esrgan/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/real-esrgan/model/model.py b/real-esrgan/model/model.py new file mode 100644 index 00000000..63be51f1 --- /dev/null +++ b/real-esrgan/model/model.py @@ -0,0 +1,84 @@ +import base64 +import io +import os +import subprocess +import sys +from io import BytesIO +from typing import Dict + +import numpy as np +from PIL import Image + +git_repo_url = "https://github.com/xinntao/Real-ESRGAN.git" +git_clone_command = ["git", "clone", git_repo_url] +commit_hash = "5ca1078535923d485892caee7d7804380bfc87fd" +original_working_directory = os.getcwd() + +try: + subprocess.run(git_clone_command, check=True) + print("Git repository cloned successfully!") + + os.chdir(os.path.join(original_working_directory, "Real-ESRGAN")) + checkout_command = ["git", "checkout", commit_hash] + subprocess.run(checkout_command, check=True) + subprocess.run([sys.executable, "setup.py", "develop"], check=True) + +except Exception as e: + print(e) + raise Exception("Error cloning Real-ESRGAN repo :(") + +sys.path.append(os.path.join(os.getcwd())) + +from basicsr.archs.rrdbnet_arch import RRDBNet +from realesrgan import RealESRGANer + + +class Model: + def __init__(self, **kwargs): + self._data_dir = kwargs["data_dir"] + self.model_checkpoint_path = os.path.join( + original_working_directory, + self._data_dir, + "weights", + "RealESRGAN_x4plus.pth", + ) + self.model = None + + def pil_to_b64(self, pil_img): + buffered = BytesIO() + pil_img.save(buffered, format="PNG") + img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") + return img_str + + def load(self): + rrdb_net_model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=4, + ) + netscale = 4 + + self.model = RealESRGANer( + scale=netscale, + model_path=self.model_checkpoint_path, + model=rrdb_net_model, + tile=0, + tile_pad=10, + pre_pad=0, + half=True, + ) + + def predict(self, request: Dict) -> Dict: + image = request.get("image") + scale = 4 + + pil_img = Image.open(io.BytesIO(base64.decodebytes(bytes(image, "utf-8")))) + pil_image_array = np.asarray(pil_img) + + output, _ = self.model.enhance(pil_image_array, outscale=scale) + output = Image.fromarray(output) + output_b64 = self.pil_to_b64(output) + return {"upscaled_image": output_b64} diff --git a/real-esrgan/racecar.jpeg b/real-esrgan/racecar.jpeg new file mode 100644 index 00000000..fb180385 Binary files /dev/null and b/real-esrgan/racecar.jpeg differ diff --git a/real-esrgan/racecar_upscaled.jpeg b/real-esrgan/racecar_upscaled.jpeg new file mode 100644 index 00000000..558e5e16 Binary files /dev/null and b/real-esrgan/racecar_upscaled.jpeg differ diff --git a/real-esrgan/result_image.jpeg b/real-esrgan/result_image.jpeg new file mode 100644 index 00000000..bcfd3439 Binary files /dev/null and b/real-esrgan/result_image.jpeg differ diff --git a/real-esrgan/ship.jpeg b/real-esrgan/ship.jpeg new file mode 100644 index 00000000..7433d986 Binary files /dev/null and b/real-esrgan/ship.jpeg differ