diff --git a/doc/references/api.yml b/doc/references/api.yml index 62b226aca5..ed2ffcfe12 100644 --- a/doc/references/api.yml +++ b/doc/references/api.yml @@ -1021,7 +1021,63 @@ paths: "400": description: "An HTTP 400 is returned if the provided parameters are invalid" - + + /predict/lang: + get: + tags: + - Predict + summary: Predict the language of a text + parameters: + - name: text + in: query + required: true + description: The text to predict language of + schema: + type: string + example: "hello world" + - name: k + in: query + required: false + description: | + the number of predictions to return + schema: + type: integer + default: 10 + minimum: 1 + - name: threshold + in: query + required: false + description: | + the minimum probability for a language to be returned + schema: + type: number + default: 0.01 + minimum: 0 + maximum: 1 + responses: + "200": + description: the predicted languages + content: + application/json: + schema: + type: object + properties: + predictions: + type: array + description: a list of predicted languages, sorted by descending probability + items: + type: object + properties: + lang: + type: string + description: the predicted language (2-letter code) + example: "en" + confidence: + type: number + description: the probability of the predicted language + example: 0.9 + "400": + description: "An HTTP 400 is returned if the provided parameters are invalid" components: schemas: diff --git a/robotoff/app/api.py b/robotoff/app/api.py index 3bf59d3bbd..3964cf0ae3 100644 --- a/robotoff/app/api.py +++ b/robotoff/app/api.py @@ -7,7 +7,7 @@ import re import tempfile import uuid -from typing import Literal, Optional +from typing import Literal, Optional, cast import falcon import orjson @@ -36,6 +36,7 @@ get_predictions, save_annotation, update_logo_annotations, + validate_params, ) from robotoff.app.middleware import DBConnectionMiddleware from robotoff.elasticsearch import get_es_client @@ -67,6 +68,7 @@ ) from robotoff.prediction import ingredient_list from robotoff.prediction.category import predict_category +from robotoff.prediction.langid import predict_lang from robotoff.prediction.object_detection import ObjectDetectionModelRegistry from robotoff.products import get_image_id, get_product, get_product_dataset_etag from robotoff.taxonomy import is_prefixed_value, match_taxonomized_value @@ -98,8 +100,6 @@ settings.init_sentry(integrations=[FalconIntegration()]) -es_client = get_es_client() - TRANSLATION_STORE = TranslationStore() TRANSLATION_STORE.load() @@ -650,6 +650,26 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): resp.media = dataclasses.asdict(output) +class LanguagePredictorResource: + def on_get(self, req: falcon.Request, resp: falcon.Response): + """Predict language of a text.""" + params = validate_params( + { + "text": req.get_param("text"), + "k": req.get_param("k"), + "threshold": req.get_param("threshold"), + }, + schema.LanguagePredictorResourceParams, + ) + params = cast(schema.LanguagePredictorResourceParams, params) + language_predictions = predict_lang(params.text, params.k, params.threshold) + resp.media = { + "predictions": [ + dataclasses.asdict(prediction) for prediction in language_predictions + ] + } + + class UpdateDatasetResource: def on_post(self, req: falcon.Request, resp: falcon.Response): """Re-import the Product Opener product dump.""" @@ -1150,6 +1170,7 @@ def on_get( """ count = req.get_param_as_int("count", min_value=1, max_value=500, default=100) server_type = get_server_type_from_req(req) + es_client = get_es_client() if logo_id is None: logo_embeddings = list( @@ -1759,6 +1780,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): api.add_route("/api/v1/predict/ocr_prediction", OCRPredictionPredictorResource()) api.add_route("/api/v1/predict/category", CategoryPredictorResource()) api.add_route("/api/v1/predict/ingredient_list", IngredientListPredictorResource()) +api.add_route("/api/v1/predict/lang", LanguagePredictorResource()) api.add_route("/api/v1/products/dataset", UpdateDatasetResource()) api.add_route("/api/v1/webhook/product", WebhookProductResource()) api.add_route("/api/v1/images", ImageCollection()) diff --git a/robotoff/app/core.py b/robotoff/app/core.py index 6216f6e563..5e8cba3e2c 100644 --- a/robotoff/app/core.py +++ b/robotoff/app/core.py @@ -3,9 +3,11 @@ from enum import Enum from typing import Iterable, Literal, NamedTuple, Optional, Union +import falcon import peewee from openfoodfacts.types import COUNTRY_CODE_TO_NAME, Country from peewee import JOIN, SQL, fn +from pydantic import BaseModel, ValidationError from robotoff.app import events from robotoff.insights.annotate import ( @@ -27,7 +29,7 @@ ) from robotoff.off import OFFAuthentication from robotoff.taxonomy import match_taxonomized_value -from robotoff.types import InsightAnnotation, ServerType +from robotoff.types import InsightAnnotation, JSONType, ServerType from robotoff.utils import get_logger from robotoff.utils.text import get_tag @@ -580,3 +582,23 @@ def filter_question_insight_types(keep_types: Optional[list[str]]): set(keep_types) & set(QuestionFormatterFactory.get_available_types()) ) return keep_types + + +def validate_params(params: JSONType, schema: type) -> BaseModel: + """Validate the parameters passed to a Falcon resource. + + Either returns a validated params object or raises a falcon.HTTPBadRequest. + + :param params: the input parameters to validate, as a dict + :param schema: the pydantic schema to use for validation + :raises falcon.HTTPBadRequest: if the parameters are invalid + """ + # Remove None values from the params dict + params = {k: v for k, v in params.items() if v is not None} + try: + return schema.model_validate(params) # type: ignore + except ValidationError as e: + errors = e.errors(include_url=False) + plural = "s" if len(errors) > 1 else "" + description = f"{len(errors)} validation error{plural}: {errors}" + raise falcon.HTTPBadRequest(description=description) diff --git a/robotoff/app/schema.py b/robotoff/app/schema.py index 6262a4011b..5b9d07db43 100644 --- a/robotoff/app/schema.py +++ b/robotoff/app/schema.py @@ -1,3 +1,7 @@ +from typing import Annotated + +from pydantic import BaseModel, Field + from robotoff.types import JSONType, NeuralCategoryClassifierModel, ServerType IMAGE_PREDICTION_IMPORTER_SCHEMA: JSONType = { @@ -172,3 +176,16 @@ }, "required": ["annotations"], } + + +class LanguagePredictorResourceParams(BaseModel): + text: Annotated[ + str, Field(..., description="the text to predict language of", min_length=1) + ] + k: Annotated[ + int, Field(default=10, description="the number of predictions to return", ge=1) + ] + threshold: Annotated[ + float, + Field(default=0.01, description="the minimum confidence threshold", ge=0, le=1), + ] diff --git a/tests/integration/test_api.py b/tests/integration/test_api.py index 3548a5f4df..18999eb5f4 100644 --- a/tests/integration/test_api.py +++ b/tests/integration/test_api.py @@ -9,6 +9,7 @@ from robotoff.app.api import api from robotoff.models import AnnotationVote, LogoAnnotation, ProductInsight from robotoff.off import OFFAuthentication +from robotoff.prediction.langid import LanguagePrediction from robotoff.types import ProductIdentifier, ServerType from .models_utils import ( @@ -1205,3 +1206,47 @@ def test_logo_annotation_collection_pagination(client, peewee_db): "truffle cake-00", "truffle cake-01", ] + + +def test_predict_lang_invalid_params(client, mocker): + mocker.patch( + "robotoff.app.api.predict_lang", + return_value=[], + ) + # no text + result = client.simulate_get("/api/v1/predict/lang", params={"k": 2}) + assert result.status_code == 400 + assert result.json == { + "description": "1 validation error: [{'type': 'missing', 'loc': ('text',), 'msg': 'Field required', 'input': {'k': '2'}}]", + "title": "400 Bad Request", + } + + # invalid k and threshold parameters + result = client.simulate_get( + "/api/v1/predict/lang", + params={"text": "test", "k": "invalid", "threshold": 1.05}, + ) + assert result.status_code == 400 + assert result.json == { + "description": "2 validation errors: [{'type': 'int_parsing', 'loc': ('k',), 'msg': 'Input should be a valid integer, unable to parse string as an integer', 'input': 'invalid'}, {'type': 'less_than_equal', 'loc': ('threshold',), 'msg': 'Input should be less than or equal to 1', 'input': '1.05', 'ctx': {'le': 1.0}}]", + "title": "400 Bad Request", + } + + +def test_predict_lang(client, mocker): + mocker.patch( + "robotoff.app.api.predict_lang", + return_value=[ + LanguagePrediction("en", 0.9), + LanguagePrediction("fr", 0.1), + ], + ) + expected_predictions = [ + {"lang": "en", "confidence": 0.9}, + {"lang": "fr", "confidence": 0.1}, + ] + result = client.simulate_get( + "/api/v1/predict/lang", params={"text": "hello", "k": 2} + ) + assert result.status_code == 200 + assert result.json == {"predictions": expected_predictions}