diff --git a/bin/test_example.py b/bin/test_example.py
index 90439f0f..b521597a 100644
--- a/bin/test_example.py
+++ b/bin/test_example.py
@@ -3,6 +3,7 @@
 import time
 
 import requests
+import truss
 from tenacity import (
     Retrying,
     retry,
@@ -12,7 +13,6 @@
 )
 from truss.cli.cli import _get_truss_from_directory
 from truss.remote.remote_factory import RemoteConfig, RemoteFactory
-from truss.truss_handle import TrussHandle
 
 REMOTE_NAME = "ci"
 BASETEN_HOST = "https://app.staging.baseten.co"
@@ -30,8 +30,8 @@ def write_trussrc_file(api_key: str):
     RemoteFactory.update_remote_config(ci_user)
 
 
-@retry(wait=wait_fixed(60), stop=stop_after_attempt(20), reraise=True)
-def attempt_inference(truss_handle, model_version_id, api_key):
+@retry(wait=wait_fixed(30), stop=stop_after_attempt(8), reraise=True)
+def attempt_inference(truss_handle, model_id, model_version_id, api_key):
     """
     Retry every 20 seconds to call inference on the example, using the `example_model_input`
     from the Truss config to invoke the model. We return success if there is a 200 response,
@@ -39,53 +39,49 @@ def attempt_inference(truss_handle, model_version_id, api_key):
     seconds.
     """
     print("Started attempt inference")
-    try:
-        if "example_model_input" in truss_handle.spec.config.model_metadata:
-            example_model_input = truss_handle.spec.config.model_metadata[
-                "example_model_input"
-            ]
-        else:
-            example_model_input = json.loads(
-                (
-                    truss_handle._truss_dir
-                    / truss_handle.spec.config.model_metadata[
-                        "example_model_input_file"
-                    ]
-                ).read_text()
-            )
-    except KeyError:
+
+    if "example_model_input" in truss_handle.spec.config.model_metadata:
+        example_model_input = truss_handle.spec.config.model_metadata[
+            "example_model_input"
+        ]
+    elif "example_model_input_file" in truss_handle.spec.config.model_metadata:
+        example_model_input = json.loads(
+            (
+                truss_handle._truss_dir
+                / truss_handle.spec.config.model_metadata["example_model_input_file"]
+            ).read_text()
+        )
+    else:
         raise Exception("No example_model_input defined in Truss config")
 
-    url = f"{BASETEN_HOST}/model_versions/{model_version_id}/predict"
+    url = f"https://model-{model_id}.api.staging.baseten.co/deployment/{model_version_id}/predict"
+
     headers = {"Authorization": f"Api-Key {api_key}"}
     response = requests.post(url, headers=headers, json=example_model_input, timeout=30)
     print(response.content)
 
-    if response.status_code != 200:
-        raise Exception(f"Request failed with status code {response.status_code}")
-
-
-def deploy_truss(truss_handle: TrussHandle) -> str:
-    remote_provider = RemoteFactory.create(remote=REMOTE_NAME)
-    model_name = truss_handle.spec.config.model_name
-    for _ in Retrying(
-        wait=wait_random_exponential(multiplier=1, max=120),
-        stop=stop_after_attempt(5),
-        reraise=True,
-    ):
-        service = remote_provider.push(
-            truss_handle, model_name, publish=True, trusted=True
-        )
-    return service.model_version_id
+    response.raise_for_status()
+
+
+@retry(
+    wait=wait_random_exponential(multiplier=1, max=120),
+    stop=stop_after_attempt(3),
+    reraise=True,
+)
+def deploy_truss(target_directory: str) -> tuple:
+    model_deployment = truss.push(
+        target_directory, remote=REMOTE_NAME, trusted=True, publish=True
+    )
+    model_deployment.wait_for_active()
+    return model_deployment.model_id, model_deployment.model_deployment_id
 
 
 def main(api_key: str, target_directory: str):
     write_trussrc_file(api_key)
     truss_handle = _get_truss_from_directory(target_directory)
-    model_version_id = deploy_truss(truss_handle)
+    model_id, model_version_id = deploy_truss(target_directory)
     print(f"Deployed Truss {model_version_id}")
-    time.sleep(20)
-    attempt_inference(truss_handle, model_version_id, api_key)
+    attempt_inference(truss_handle, model_id, model_version_id, api_key)
 
 
 if __name__ == "__main__":
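For reference, a minimal sketch of calling the deployment-scoped predict endpoint that attempt_inference() switches to in this diff; the model id, deployment id, API key, and input payload below are hypothetical placeholders, not values from the change:

import requests

MODEL_ID = "abc123"        # hypothetical; comes from model_deployment.model_id after truss.push()
DEPLOYMENT_ID = "def456"   # hypothetical; comes from model_deployment.model_deployment_id
API_KEY = "..."            # Baseten API key for the staging workspace

# Same URL shape and auth header as the updated attempt_inference()
url = f"https://model-{MODEL_ID}.api.staging.baseten.co/deployment/{DEPLOYMENT_ID}/predict"
headers = {"Authorization": f"Api-Key {API_KEY}"}
response = requests.post(url, headers=headers, json={"prompt": "hello"}, timeout=30)
response.raise_for_status()
print(response.content)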