Skip to content

Commit

Permalink
Improve truss examples CI script (#373)
Browse files Browse the repository at this point in the history
Improves the truss examples to do the following:
* Fail immediately if there are problems (instead of having to wait 20
minutes)
* Doesn't have a 20 minute timeout if the model is still building

# Testing

Ran
https://github.com/basetenlabs/truss-examples/actions/runs/11633933987/job/32400112601
squidarth authored Nov 4, 2024
1 parent 13325c0 commit 7786b53
Showing 1 changed file with 34 additions and 38 deletions.
72 changes: 34 additions & 38 deletions bin/test_example.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
import time

import requests
import truss
from tenacity import (
Retrying,
retry,
@@ -12,7 +13,6 @@
)
from truss.cli.cli import _get_truss_from_directory
from truss.remote.remote_factory import RemoteConfig, RemoteFactory
from truss.truss_handle import TrussHandle

REMOTE_NAME = "ci"
BASETEN_HOST = "https://app.staging.baseten.co"
@@ -30,62 +30,58 @@ def write_trussrc_file(api_key: str):
RemoteFactory.update_remote_config(ci_user)


@retry(wait=wait_fixed(60), stop=stop_after_attempt(20), reraise=True)
def attempt_inference(truss_handle, model_version_id, api_key):
@retry(wait=wait_fixed(30), stop=stop_after_attempt(8), reraise=True)
def attempt_inference(truss_handle, model_id, model_version_id, api_key):
"""
Retry every 20 seconds to call inference on the example, using the `example_model_input`
from the Truss config to invoke the model. We return success if there is a 200 response,
and if not retry, ultimately throwing an exception if we don't get a response after 200
seconds.
"""
print("Started attempt inference")
try:
if "example_model_input" in truss_handle.spec.config.model_metadata:
example_model_input = truss_handle.spec.config.model_metadata[
"example_model_input"
]
else:
example_model_input = json.loads(
(
truss_handle._truss_dir
/ truss_handle.spec.config.model_metadata[
"example_model_input_file"
]
).read_text()
)
except KeyError:

if "example_model_input" in truss_handle.spec.config.model_metadata:
example_model_input = truss_handle.spec.config.model_metadata[
"example_model_input"
]
elif "example_model_input_file" in truss_handle.spec.config.model_metadata:
example_model_input = json.loads(
(
truss_handle._truss_dir
/ truss_handle.spec.config.model_metadata["example_model_input_file"]
).read_text()
)
else:
raise Exception("No example_model_input defined in Truss config")

url = f"{BASETEN_HOST}/model_versions/{model_version_id}/predict"
url = f"https://model-{model_id}.api.staging.baseten.co/deployment/{model_version_id}/predict"

headers = {"Authorization": f"Api-Key {api_key}"}
response = requests.post(url, headers=headers, json=example_model_input, timeout=30)

print(response.content)
if response.status_code != 200:
raise Exception(f"Request failed with status code {response.status_code}")


def deploy_truss(truss_handle: TrussHandle) -> str:
remote_provider = RemoteFactory.create(remote=REMOTE_NAME)
model_name = truss_handle.spec.config.model_name
for _ in Retrying(
wait=wait_random_exponential(multiplier=1, max=120),
stop=stop_after_attempt(5),
reraise=True,
):
service = remote_provider.push(
truss_handle, model_name, publish=True, trusted=True
)
return service.model_version_id
response.raise_for_status()


@retry(
wait=wait_random_exponential(multiplier=1, max=120),
stop=stop_after_attempt(3),
reraise=True,
)
def deploy_truss(target_directory: str) -> str:
model_deployment = truss.push(
target_directory, remote=REMOTE_NAME, trusted=True, publish=True
)
model_deployment.wait_for_active()
return model_deployment.model_id, model_deployment.model_deployment_id


def main(api_key: str, target_directory: str):
write_trussrc_file(api_key)
truss_handle = _get_truss_from_directory(target_directory)
model_version_id = deploy_truss(truss_handle)
model_id, model_version_id = deploy_truss(target_directory)
print(f"Deployed Truss {model_version_id}")
time.sleep(20)
attempt_inference(truss_handle, model_version_id, api_key)
attempt_inference(truss_handle, model_id, model_version_id, api_key)


if __name__ == "__main__":

0 comments on commit 7786b53

Please sign in to comment.