Adding truss files for real esrgan
htrivedi99 committed Nov 16, 2023
1 parent c91c354 commit c85cde4
Showing 8 changed files with 206 additions and 0 deletions.
93 changes: 93 additions & 0 deletions real-esrgan/README.md
@@ -0,0 +1,93 @@
# Real-ESRGAN Truss

This is a [Truss](https://truss.baseten.co/) for Real-ESRGAN, an AI image upscaling model.
Open-source image generation models like Stable Diffusion 1.5 can sometimes produce blurry or low-resolution images. Real-ESRGAN can upscale those images so they look sharper and more detailed.

## Deployment

First, clone this repository:

```
git clone https://github.com/basetenlabs/truss-examples/
cd truss-examples/real-esrgan
```

Before deployment:

1. Make sure you have a [Baseten account](https://app.baseten.co/signup) and [API key](https://app.baseten.co/settings/account/api_keys).
2. Install the latest version of Truss: `pip install --upgrade truss`

With `real-esrgan` as your working directory, you can deploy the model with:

```
truss push
```

Paste your Baseten API key if prompted.

For more information, see [Truss documentation](https://truss.baseten.co).

## API route: `predict`
The `predict` route is the primary method for upscaling an image. The request body is a JSON object with a single key, and the image must be converted to a base64 string before it is sent:

- __image__: The image converted to a base64 string
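
As a minimal sketch of the encoding step, the helper below builds the request body from raw image bytes. The name `build_predict_payload` is hypothetical and is not part of Truss or this repository. Because `model.py` decodes the string directly, the sketch sends plain base64 text without a `data:image/...;base64,` prefix.

```python
import base64
import json


def build_predict_payload(image_bytes: bytes) -> str:
    """Return a JSON request body with the image encoded as base64 text.

    Hypothetical helper for illustration; not part of Truss.
    """
    # The model reads the base64 string from the "image" key.
    encoded = base64.b64encode(image_bytes).decode("utf-8")
    return json.dumps({"image": encoded})
```

The returned string can be used as the JSON body of a `predict` request, for example after reading an image file in binary mode with `open(path, "rb").read()`.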


## Invoking the model

```sh
truss predict -d '{"image": "<BASE64-STRING-HERE>"}'
```

You can also call the model from Python:

```python
import base64
from io import BytesIO

import requests
from PIL import Image

BASE64_PREAMBLE = "data:image/png;base64,"


def pil_to_b64(pil_img):
    # Serialize the image as PNG and encode it as a base64 string.
    buffered = BytesIO()
    pil_img.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return img_str


def b64_to_pil(b64_str):
    # Strip the data-URI preamble if present, then decode back to an image.
    return Image.open(BytesIO(base64.b64decode(b64_str.replace(BASE64_PREAMBLE, ""))))


img = Image.open("/path/to/image/ship.jpeg")
b64_img = pil_to_b64(img)

# Replace <BASETEN-API-KEY> and {model_version} with your own values.
headers = {"Authorization": "Api-Key <BASETEN-API-KEY>"}
data = {"image": b64_img}
res = requests.post(
    "https://app.baseten.co/model_versions/{model_version}/predict",
    headers=headers,
    json=data,
)
output = res.json()

result_b64 = output.get("model_output").get("upscaled_image")
pil_img = b64_to_pil(result_b64)
pil_img.save("upscaled_output_img.png")
```

The model returns a JSON object containing the key `upscaled_image`, which is the upscaled image as a base64 string.
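
The client example above chains `.get()` calls, which raises an unhelpful `AttributeError` if the request fails and `model_output` is missing. A more defensive way to read the response might look like the sketch below. The helper name `extract_upscaled_image` is hypothetical, and the response shape `{"model_output": {"upscaled_image": ...}}` is assumed from the client example rather than from Baseten documentation.

```python
import base64

BASE64_PREAMBLE = "data:image/png;base64,"


def extract_upscaled_image(response_json: dict) -> bytes:
    """Return the decoded image bytes from a predict response.

    Hypothetical helper; assumes the response shape used in the
    client example: {"model_output": {"upscaled_image": "<base64>"}}.
    """
    model_output = response_json.get("model_output")
    if not isinstance(model_output, dict) or "upscaled_image" not in model_output:
        raise ValueError(f"Unexpected response: {response_json!r}")
    b64_str = model_output["upscaled_image"]
    # Tolerate an optional data-URI preamble before decoding.
    if b64_str.startswith(BASE64_PREAMBLE):
        b64_str = b64_str[len(BASE64_PREAMBLE):]
    return base64.b64decode(b64_str)
```

The returned bytes are PNG data, since the model encodes its output with `format="PNG"`, and can be written straight to a file opened in binary mode.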

## Results

<div style="display: flex; justify-content: space-between;">
<div style="flex: 1; margin-right: 10px;">
<img src="ship.jpeg" alt="original image" style="width: 100%;">
<p>Original Image Stable Diffusion 1.5</p>
</div>
<div style="flex: 1;">
<img src="result_image.jpeg" alt="upscaled image" style="width: 100%;">
<p>Upscaled Image</p>
</div>
</div>

<div style="display: flex; justify-content: space-between;">
<div style="flex: 1; margin-right: 10px;">
<img src="racecar.jpeg" alt="original image" style="width: 100%;">
<p>Original Image SDXL</p>
</div>
<div style="flex: 1;">
<img src="racecar_upscaled.jpeg" alt="upscaled image" style="width: 100%;">
<p>Upscaled Image</p>
</div>
</div>
29 changes: 29 additions & 0 deletions real-esrgan/config.yaml
@@ -0,0 +1,29 @@
environment_variables: {}
external_package_dirs: []
model_metadata:
  example_model_input: {"image": "BASE64-STRING-HERE"}
model_name: real-esrgan
python_version: py310
requirements:
- numpy==1.23.5
- torch==2.0.1
- torchvision==0.15.2
- facexlib==0.3.0
- gfpgan==1.3.8
- basicsr==1.4.2
- opencv-python==4.8.0.76
- opencv-python-headless==4.8.1.78
- Pillow==9.4.0
- tqdm==4.66.1
resources:
  cpu: "3"
  memory: 14Gi
  use_gpu: true
  accelerator: T4
secrets: {}
system_packages:
- libgl1-mesa-glx
- libglib2.0-0
external_data:
- url: https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth
  local_data_path: weights/RealESRGAN_x4plus.pth
Empty file added real-esrgan/model/__init__.py
Empty file.
84 changes: 84 additions & 0 deletions real-esrgan/model/model.py
@@ -0,0 +1,84 @@
import base64
import io
import os
import subprocess
import sys
from io import BytesIO
from typing import Dict

import numpy as np
from PIL import Image

git_repo_url = "https://github.com/xinntao/Real-ESRGAN.git"
git_clone_command = ["git", "clone", git_repo_url]
commit_hash = "5ca1078535923d485892caee7d7804380bfc87fd"
original_working_directory = os.getcwd()

try:
    subprocess.run(git_clone_command, check=True)
    print("Git repository cloned successfully!")

    os.chdir(os.path.join(original_working_directory, "Real-ESRGAN"))
    checkout_command = ["git", "checkout", commit_hash]
    subprocess.run(checkout_command, check=True)
    subprocess.run([sys.executable, "setup.py", "develop"], check=True)

except Exception as e:
    print(e)
    raise Exception("Error cloning Real-ESRGAN repo :(")

sys.path.append(os.path.join(os.getcwd()))

from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer


class Model:
    def __init__(self, **kwargs):
        self._data_dir = kwargs["data_dir"]
        self.model_checkpoint_path = os.path.join(
            original_working_directory,
            self._data_dir,
            "weights",
            "RealESRGAN_x4plus.pth",
        )
        self.model = None

    def pil_to_b64(self, pil_img):
        buffered = BytesIO()
        pil_img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return img_str

    def load(self):
        rrdb_net_model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )
        netscale = 4

        self.model = RealESRGANer(
            scale=netscale,
            model_path=self.model_checkpoint_path,
            model=rrdb_net_model,
            tile=0,
            tile_pad=10,
            pre_pad=0,
            half=True,
        )

    def predict(self, request: Dict) -> Dict:
        image = request.get("image")
        scale = 4

        pil_img = Image.open(io.BytesIO(base64.decodebytes(bytes(image, "utf-8"))))
        pil_image_array = np.asarray(pil_img)

        output, _ = self.model.enhance(pil_image_array, outscale=scale)
        output = Image.fromarray(output)
        output_b64 = self.pil_to_b64(output)
        return {"upscaled_image": output_b64}
Binary file added real-esrgan/racecar.jpeg
Binary file added real-esrgan/racecar_upscaled.jpeg
Binary file added real-esrgan/result_image.jpeg
Binary file added real-esrgan/ship.jpeg
