diff --git a/docker/ml-gpu.yml b/docker/ml-gpu.yml
index 5a3bfbda5e..b38fffe373 100644
--- a/docker/ml-gpu.yml
+++ b/docker/ml-gpu.yml
@@ -22,7 +22,7 @@ services:
     # https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_management.md
     # for more information
     entrypoint: "/opt/nvidia/nvidia_entrypoint.sh tritonserver --model-repository=/models --model-control-mode=explicit --load-model=*"
-    mem_limit: 20g
+    mem_limit: 30g
     runtime: nvidia
     deploy:
       resources:
diff --git a/docker/ml.yml b/docker/ml.yml
index c348138b70..3537e2c520 100644
--- a/docker/ml.yml
+++ b/docker/ml.yml
@@ -14,7 +14,7 @@ services:
     # https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_management.md
     # for more information
     entrypoint: "tritonserver --model-repository=/models --model-control-mode=explicit --load-model=*"
-    mem_limit: 20g
+    mem_limit: 30g
 
   fasttext:
     restart: $RESTART_POLICY
diff --git a/robotoff/cli/main.py b/robotoff/cli/main.py
index 7db5bcb579..d4d1f4d599 100644
--- a/robotoff/cli/main.py
+++ b/robotoff/cli/main.py
@@ -670,6 +670,9 @@ def run_nutrition_extraction(
         None,
         help="URI of the Triton Inference Server to use. If not provided, the default value from settings is used.",
     ),
+    model_version: Optional[str] = typer.Option(
+        None, help="Version of the model to use, defaults to the latest"
+    ),
 ) -> None:
     """Run nutrition extraction on a product image.
 
@@ -693,7 +696,9 @@ def run_nutrition_extraction(
     image = cast(Image.Image, get_image_from_url(image_url))
     ocr_result = cast(OCRResult, OCRResult.from_url(image_url.replace(".jpg", ".json")))
 
-    prediction = predict(image, ocr_result, triton_uri=triton_uri)
+    prediction = predict(
+        image, ocr_result, model_version=model_version, triton_uri=triton_uri
+    )
     if prediction is not None:
         pprint(prediction)
     else:
diff --git a/robotoff/prediction/nutrition_extraction.py b/robotoff/prediction/nutrition_extraction.py
index 1cc099ceba..1ff628e3b4 100644
--- a/robotoff/prediction/nutrition_extraction.py
+++ b/robotoff/prediction/nutrition_extraction.py
@@ -60,7 +60,7 @@ class NutritionExtractionPrediction:
 def predict(
     image: Image.Image,
     ocr_result: OCRResult,
-    model_version: str = "1",
+    model_version: str | None = None,
     triton_uri: str | None = None,
 ) -> NutritionExtractionPrediction | None:
     """Predict the nutrient values from an image and an OCR result.
@@ -77,7 +77,7 @@ def predict(
 
     :param image: the *original* image (not resized)
     :param ocr_result: the OCR result
-    :param model_version: the version of the model to use, defaults to "1"
+    :param model_version: the version of the model to use, defaults to None (latest)
     :param triton_uri: the URI of the Triton Inference Server, if not
         provided, the default value from settings is used
     :return: a `NutritionExtractionPrediction` object
@@ -619,7 +619,7 @@ def send_infer_request(
     pixel_values: np.ndarray,
     model_name: str,
     triton_stub: GRPCInferenceServiceStub,
-    model_version: str = "1",
+    model_version: str | None = None,
 ) -> np.ndarray:
     """Send a NER infer request to the Triton inference server.
 
@@ -634,7 +634,7 @@ def send_infer_request(
     :param pixel_values: pixel values of the image, generated using the
         transformers tokenizer.
     :param model_name: the name of the model to use
-    :param model_version: version of the model model to use, defaults to "1"
+    :param model_version: version of the model to use, defaults to None (latest).
     :return: the predicted logits
     """
     request = build_triton_request(
@@ -660,7 +660,7 @@ def build_triton_request(
     bbox: np.ndarray,
     pixel_values: np.ndarray,
     model_name: str,
-    model_version: str = "1",
+    model_version: str | None = None,
 ):
     """Build a Triton ModelInferRequest gRPC request for LayoutLMv3 models.
 
@@ -672,12 +672,14 @@ def build_triton_request(
     :param pixel_values: pixel values of the image, generated using the
         transformers tokenizer.
     :param model_name: the name of the model to use.
-    :param model_version: version of the model model to use, defaults to "1".
+    :param model_version: version of the model to use, defaults to None (latest).
     :return: the gRPC ModelInferRequest
     """
     request = service_pb2.ModelInferRequest()
     request.model_name = model_name
-    request.model_version = model_version
+
+    if model_version:
+        request.model_version = model_version
 
     add_triton_infer_input_tensor(request, "input_ids", input_ids, "INT64")
     add_triton_infer_input_tensor(request, "attention_mask", attention_mask, "INT64")
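Note on the new `model_version` handling: when the option is omitted, the `model_version` field of the gRPC `ModelInferRequest` is never assigned and stays an empty string, in which case Triton resolves the version according to the model's version policy (the latest available version by default). On the CLI side, the typer parameter surfaces as `--model-version`. A minimal sketch of the fallthrough behavior (the model name below is a placeholder, not taken from this diff):

    # Sketch only: how an unset model_version defers to Triton's version policy.
    from tritonclient.grpc import service_pb2

    request = service_pb2.ModelInferRequest()
    request.model_name = "some_model"  # placeholder, not a name from this PR

    model_version = None  # e.g. --model-version was not passed on the CLI
    if model_version:
        request.model_version = model_version
    # request.model_version is still "" here, so the server picks the version
    # allowed by the model's version policy (latest by default).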