Skip to content

Commit

Permalink
Update GAN to new requirements (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
bolasim authored Nov 15, 2023
1 parent 0926f54 commit c91c354
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 77 deletions.
35 changes: 24 additions & 11 deletions bin/test_example.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
import json
import sys
import time

import requests
from tenacity import retry, stop_after_attempt, wait_fixed
from truss.cli.cli import _get_truss_from_directory
from truss.remote.remote_factory import USER_TRUSSRC_PATH, RemoteFactory
from truss.remote.remote_factory import RemoteConfig, RemoteFactory
from truss.truss_handle import TrussHandle

REMOTE_NAME = "ci"
BASETEN_HOST = "https://app.staging.baseten.co"


def write_trussrc_file(api_key: str):
file_contents = f"""
[{REMOTE_NAME}]
remote_provider=baseten
remote_url={BASETEN_HOST}
api_key={api_key}"""
with open(USER_TRUSSRC_PATH, "w") as f:
f.write(file_contents)
ci_user = RemoteConfig(
name=REMOTE_NAME,
configs={
"api_key": api_key,
"remote_url": BASETEN_HOST,
"remote_provider": "baseten",
},
)
RemoteFactory.update_remote_config(ci_user)


@retry(wait=wait_fixed(60), stop=stop_after_attempt(20), reraise=True)
Expand All @@ -31,9 +34,19 @@ def attempt_inference(truss_handle, model_version_id, api_key):
"""
print("Started attempt inference")
try:
example_model_input = truss_handle.spec.config.model_metadata[
"example_model_input"
]
if "example_model_input" in truss_handle.spec.config.model_metadata:
example_model_input = truss_handle.spec.config.model_metadata[
"example_model_input"
]
else:
example_model_input = json.loads(
(
truss_handle._truss_dir
/ truss_handle.spec.config.model_metadata[
"example_model_input_file"
]
).read_text()
)
except KeyError:
raise Exception("No example_model_input defined in Truss config")

Expand Down
1 change: 1 addition & 0 deletions ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ tests:
- mistral/mistral-7b-instruct
- jina-embeddings/jina-embeddings-v2-base-en
- whisper/whisper-v3-truss
- gfp-gan
69 changes: 4 additions & 65 deletions gfp-gan/config.yaml
Original file line number Diff line number Diff line change
@@ -1,78 +1,17 @@
data_dir: data
environment_variables: {}
examples_filename: examples.yaml
input_type: Any
description: Restore photos with this image-to-image model.
model_class_filename: restoration_model.py
model_class_name: RestorationModel
model_framework: custom
model_metadata:
avatar_url: https://cdn.baseten.co/production/static/explore/tencent.png
cover_image_url: https://cdn.baseten.co/production/static/explore/gfp-gan.png
example_model_input_file: input.json
tags:
- image-restoration
model_module_dir: model
model_name: GFP-GAN
model_type: custom
python_version: py39
requirements:
- absl-py==1.1.0
- addict==2.4.0
- boto3==1.24.30
- cachetools==5.2.0
- certifi==2022.6.15
- charset-normalizer==2.1.0
- cycler==0.11.0
- facexlib==0.2.4
- filterpy==1.4.5
- fonttools==4.34.4
- future==0.18.2
- google-auth==2.9.1
- google-auth-oauthlib==0.4.6
- grpcio==1.47.0
- idna==3.3
- imageio==2.19.3
- importlib-metadata==4.12.0
- kiwisolver==1.4.3
- llvmlite==0.38.1
- lmdb==1.3.0
- Markdown==3.3.7
- matplotlib==3.5.2
- networkx==2.8.4
- numba==0.55.2
- numpy==1.20.3
- oauthlib==3.2.0
- opencv-python==4.6.0.66
- packaging==21.3
- Pillow==9.2.0
- protobuf==3.19.4
- pyasn1==0.4.8
- pyasn1-modules==0.2.8
- pyparsing==3.0.9
- python-dateutil==2.8.2
- PyWavelets==1.3.0
- PyYAML==6.0
- requests==2.28.1
- requests-oauthlib==1.3.1
- rsa==4.8
- scikit-image==0.19.3
- scipy==1.8.1
- six==1.16.0
- tb-nightly==2.10.0a20220713
- tensorboard-data-server==0.6.1
- tensorboard-plugin-wit==1.8.1
- tifffile==2022.5.4
- torch==1.12.0
- torchvision==0.13.0
- tqdm==4.64.0
- typing_extensions==4.3.0
- urllib3==1.26.10
- Werkzeug==2.1.2
- yapf==0.32.0
- zipp==3.8.1
- git+https://github.com/basetenlabs/BasicSR.git
- gfpgan==1.3.4
- realesrgan==0.2.5.0
- gfpgan==1.3.8
- realesrgan==0.3.0
- basicsr==1.4.2
resources:
cpu: "3"
memory: 8Gi
Expand Down
3 changes: 3 additions & 0 deletions gfp-gan/input.json

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
RESIZE_DEFAULT_MAX = 1400


class RestorationModel:
class Model:
def __init__(self, **kwargs) -> None:
self._data_dir = kwargs.get("data_dir")
self._config = kwargs.get("config")
Expand Down

0 comments on commit c91c354

Please sign in to comment.