Skip to content

Commit

Permalink
feat/docs: deprecate v2 and introduce v3
Browse files Browse the repository at this point in the history
The `v3` route has larger scaling capability. We no longer expose our batching capability at the Python level. Instead, we let `FastAPI` handle batching via multithreading. This works because `ctranslate2` drops the GIL when `translate_iterable` is called, therefore the entire operation can be asynchronous. Now batching is simplified to just making multiple API requests.
  • Loading branch information
winstxnhdw committed May 30, 2024
1 parent cca5e2d commit 42c0a17
Show file tree
Hide file tree
Showing 13 changed files with 66 additions and 108 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

A fast CPU-based API for Meta's [No Language Left Behind](https://huggingface.co/docs/transformers/model_doc/nllb) distilled 1.3B 8-bit quantised variant, hosted on Hugging Face Spaces. To achieve faster executions, we are using [CTranslate2](https://github.com/OpenNMT/CTranslate2) as our inference engine. Requests are cached and then served at the reverse proxy layer to reduce server load.

> [!WARNING]\
> NLLB has a max input length of 1024 tokens. This limit is imposed by the model's architecture and cannot be changed. If you need to translate longer texts, consider splitting your input into smaller chunks.
## Usage

Simply cURL the endpoint like in the following. The `source` and `target` languages must be specified using FLORES-200 codes.
Expand Down Expand Up @@ -230,11 +233,8 @@ Standard Malay | zsm_Latn
Zulu | zul_Latn
</details>

> [!TIP]\
> You can translate multiple texts in a single batch by separating the texts with a `\n` character. See the [tests](tests/test_translate.py) for an example.
```bash
curl -N 'https://winstxnhdw-nllb-api.hf.space/api/v2/translate' \
curl -N 'https://winstxnhdw-nllb-api.hf.space/api/v3/translate' \
-H 'Content-Type: application/json' \
-d \
'{
Expand Down
3 changes: 2 additions & 1 deletion server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from server.api import v2
from server.api import v2, v3
from server.config import Config
from server.lifespans import lifespans
from server.middlewares import LoggingMiddleware
Expand Down Expand Up @@ -87,6 +87,7 @@ def initialise() -> Framework:
app = Framework(lifespan=lifespans, root_path=Config.server_root_path)
app.initialise_routes(join('server', 'api'))
app.include_router(v2)
app.include_router(v3)
app.add_middleware(LoggingMiddleware)
app.add_middleware(
CORSMiddleware,
Expand Down
1 change: 1 addition & 0 deletions server/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from server.api.v2 import v2 as v2
from server.api.v3 import v3 as v3
18 changes: 11 additions & 7 deletions server/api/v2/translate.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from typing import Annotated, Generator
from asyncio import gather

from fastapi import Depends
from starlette.responses import StreamingResponse
from starlette.responses import PlainTextResponse

from server.api.v2 import v2
from server.dependencies import translation
from server.features import Translator
from server.schemas.v1 import Translation


@v2.post('/translate')
def translate(result: Annotated[Generator[str, None, None], Depends(translation)]):
@v2.post('/translate', deprecated=True)
async def translate(request: Translation) -> PlainTextResponse:
"""
Summary
-------
the `/translate` route translates an input from a source language to a target language
"""
return StreamingResponse(result, media_type='text/event-stream')
results = await gather(
*(Translator.translate(line, request.source, request.target) for line in request.text.splitlines() if line)
)

return PlainTextResponse('\n'.join(results))
3 changes: 3 additions & 0 deletions server/api/v3/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from fastapi import APIRouter

v3 = APIRouter(prefix='/v3', tags=["v3"])
15 changes: 15 additions & 0 deletions server/api/v3/index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Literal

from starlette.responses import PlainTextResponse

from server.api.v3 import v3


@v3.get('/', response_model=Literal['Welcome to v3 of the API!'])
def index():
"""
Summary
-------
the `/` route
"""
return PlainTextResponse('Welcome to v2 of the API!')
13 changes: 13 additions & 0 deletions server/api/v3/translate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from server.api.v3 import v3
from server.features import Translator
from server.schemas.v1 import Translated, Translation


@v3.post('/translate')
async def translate(request: Translation) -> Translated:
"""
Summary
-------
the `/translate` route translates an input from a source language to a target language
"""
return Translated(result=await Translator.translate(request.text, request.source, request.target))
1 change: 0 additions & 1 deletion server/dependencies/__init__.py

This file was deleted.

35 changes: 0 additions & 35 deletions server/dependencies/translation.py

This file was deleted.

22 changes: 9 additions & 13 deletions server/features/translator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Generator

from ctranslate2 import Translator as CTranslator
from transformers.models.nllb.tokenization_nllb_fast import NllbTokenizerFast

Expand Down Expand Up @@ -50,11 +48,11 @@ def load(cls):


@classmethod
def translate(cls, text: str, source_language: str, target_language: str) -> Generator[str, None, None]:
async def translate(cls, text: str, source_language: str, target_language: str) -> str:
"""
Summary
-------
translate the input from the source language to the target language
translate the input from the source language to the target language without the Python GIL
Parameters
----------
Expand All @@ -68,13 +66,11 @@ def translate(cls, text: str, source_language: str, target_language: str) -> Gen
"""
cls.tokeniser.src_lang = source_language

lines = [line for line in text.splitlines() if line]
result = next(cls.translator.translate_iterable(
(cls.tokeniser(text).tokens(),),
([target_language],),
batch_type='tokens',
beam_size=1,
))

return (
f'{cls.tokeniser.decode(cls.tokeniser.convert_tokens_to_ids(result.hypotheses[0][1:]))}\n'
for result in cls.translator.translate_iterable(
(cls.tokeniser(line).tokens() for line in lines),
([target_language] for _ in lines),
beam_size=1
)
)
return cls.tokeniser.decode(cls.tokeniser.convert_tokens_to_ids(result.hypotheses[0][1:]))
4 changes: 2 additions & 2 deletions server/schemas/v1/translated.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ class Translated(BaseModel):
Attributes
----------
text (str) : the translated text
result (str) : the translated text
"""
text: str = Field(examples=['¡Hola, mundo!'])
result: str = Field(examples=['¡Hola, mundo!'])
43 changes: 2 additions & 41 deletions server/typings/ctranslate2/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# pylint: skip-file

from typing import Callable, Iterable, Literal, overload
from typing import Callable, Generator, Iterable, Literal, overload

ComputeTypes = Literal[
'default',
Expand Down Expand Up @@ -129,52 +129,13 @@ class Translator:
) -> list[AsyncTranslationResult]: ...


@overload
def translate_iterable(
self,
source: Iterable[list[str]],
target_prefix: Iterable[list[str]] | None = None,
max_batch_size: int = 32,
batch_type: str = 'examples',
*,
asynchronous: Literal[False] = False,
beam_size: int = 2,
patience: float = 1,
num_hypotheses: int = 1,
length_penalty: float = 1,
coverage_penalty: float = 0,
repetition_penalty: float = 1,
no_repeat_ngram_size: int = 0,
disable_unks: bool = False,
supress_sequences: list[list[str]] | None = None,
end_token: str | list[str] | list[int] | None = None,
return_end_token: bool = False,
prefix_bias_beta: float = 0,
max_input_length: int = 1024,
max_decoding_length: int = 256,
min_decoding_length: int = 1,
use_vmap: bool = False,
return_scores: bool = False,
return_attention: bool = False,
return_alternatives: bool = False,
min_alternative_expansion_prob: float = 0,
sampling_topk: int = 1,
sampling_topp: float = 1,
sampling_temperature: float = 1,
replace_unknowns: bool = False,
callback: Callable[[GenerationStepResult], bool] | None = None
) -> Iterable[TranslationResult]: ...


@overload
def translate_iterable(
self,
source: Iterable[list[str]],
target_prefix: Iterable[list[str]] | None = None,
max_batch_size: int = 32,
batch_type: str = 'examples',
*,
asynchronous: Literal[True],
beam_size: int = 2,
patience: float = 1,
num_hypotheses: int = 1,
Expand All @@ -200,4 +161,4 @@ class Translator:
sampling_temperature: float = 1,
replace_unknowns: bool = False,
callback: Callable[[GenerationStepResult], bool] | None = None
) -> Iterable[AsyncTranslationResult]: ...
) -> Generator[TranslationResult, None, None]: ...
8 changes: 4 additions & 4 deletions tests/test_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,23 @@ def client() -> Generator[TestClient, None, None]:


def test_generate(client: TestClient):
response = client.post('/v2/translate', json={
response = client.post('/v3/translate', json={
'text': 'Hello, world!',
'source': 'eng_Latn',
'target': 'spa_Latn'
})

assert response.text == '¡Hola, mundo!\n'
assert response.json()['result'] == '¡Hola, mundo!'


def test_generate_from_chinese(client: TestClient):
response = client.post('/v2/translate', json={
response = client.post('/v3/translate', json={
'text': '我是一名软件工程师!',
'source': 'zho_Hans',
'target': 'spa_Latn'
})

assert response.text == '¡Soy ingeniero de software!\n'
assert response.json()['result'] == '¡Soy ingeniero de software!'


def test_generate_stream(client: TestClient):
Expand Down

0 comments on commit 42c0a17

Please sign in to comment.