Skip to content

Commit

Permalink
Merge branch '11-translation-models' into 5-audio
Browse files Browse the repository at this point in the history
  • Loading branch information
boykovdn committed Jan 27, 2025
2 parents 63908c6 + ed4339e commit 3d03893
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 4 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ dependencies = [
"torchvision>=0.15.2",
"unbabel-comet==2.2.3",
"librosa>=0.10.2.post1",
"sentencepiece",
"transformers"
]

[project.optional-dependencies]
Expand Down
6 changes: 2 additions & 4 deletions src/m4st/ollama/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@

import requests # type: ignore[import-untyped]

model = "llama3.2"


def generate(prompt, context):
def generate(prompt, context, model_name="llama3.2"):
r = requests.post(
"http://localhost:11434/api/generate",
json={
"model": model,
"model": model_name,
"prompt": prompt,
"context": context,
},
Expand Down
Empty file added src/m4st/translate/__init__.py
Empty file.
65 changes: 65 additions & 0 deletions src/m4st/translate/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from abc import ABC, abstractmethod

from transformers import T5ForConditionalGeneration, T5Tokenizer

from m4st.ollama.client import generate

# Map from ISO 639-3 language codes to the English language names that the
# translation models/prompts below expect (T5 and LLM prompts are phrased
# with full language names, e.g. "translate English to French: ...").
language_iso_to_name = {
    "eng": "English",
    "deu": "German",
    "fra": "French",
    "jpn": "Japanese",
    "spa": "Spanish",
    "rou": "Romanian",
    "zho": "Chinese",
}


class TranslationModel(ABC):
    """Common interface for text-to-text translation backends.

    A concrete model is a callable that takes a source string and a pair
    of ISO 639-3 language codes and returns the translated string.
    """

    @abstractmethod
    def __call__(self, text: str, source_lang_iso: str, target_lang_iso: str):
        """Translate ``text`` from ``source_lang_iso`` to ``target_lang_iso``."""
        ...


class T5TranslateModel(TranslationModel):
    """Translation backend using a Hugging Face T5 checkpoint.

    T5 performs translation when prompted with the pattern
    ``"translate <SourceName> to <TargetName>: <text>"``.
    """

    def __init__(self, model_tag: str = "google-t5/t5-small"):
        """Download/load the weights and tokenizer for ``model_tag``."""
        self.model = T5ForConditionalGeneration.from_pretrained(model_tag)
        self.tokenizer = T5Tokenizer.from_pretrained(model_tag)

        # ISO 639-3 codes this checkpoint is expected to handle.
        # NOTE(review): t5-small's advertised pairs also include Romanian —
        # confirm whether "rou" should be listed here.
        self.supported_languages = ["eng", "deu", "fra"]

    def __call__(self, text: str, source_lang_iso: str, target_lang_iso: str) -> str:
        """Translate ``text`` from ``source_lang_iso`` to ``target_lang_iso``.

        Raises:
            ValueError: if either language code is not in
                ``self.supported_languages``.
        """
        # Validate with explicit exceptions rather than `assert`: asserts
        # are stripped when Python runs with -O, which would silently
        # disable this input checking.
        if target_lang_iso not in self.supported_languages:
            raise ValueError(
                f"This model only supports {self.supported_languages}, "
                f"but got target {target_lang_iso}."
            )
        if source_lang_iso not in self.supported_languages:
            raise ValueError(
                f"This model only supports {self.supported_languages}, "
                f"but got source {source_lang_iso}."
            )

        source_lang_name = language_iso_to_name[source_lang_iso]
        target_lang_name = language_iso_to_name[target_lang_iso]

        model_input_text = f"translate {source_lang_name} to {target_lang_name}: {text}"
        model_input_tokens = self.tokenizer(
            model_input_text, return_tensors="pt"
        ).input_ids
        model_output_tokens = self.model.generate(model_input_tokens)

        return self.tokenizer.decode(model_output_tokens[0], skip_special_tokens=True)


class OllamaTranslateModel(TranslationModel):
    """Translation backend that prompts a locally served Ollama LLM.

    Requires a running Ollama server (see ``m4st.ollama.client``).
    """

    def __init__(self, model_tag: str = "llama3.2"):
        """Store the Ollama model tag to request, e.g. ``"llama3.2"``."""
        self.model_tag = model_tag

    def __call__(self, text: str, source_lang_iso: str, target_lang_iso: str):
        """Translate ``text`` between the given ISO 639-3 codes via the LLM.

        NOTE(review): the result depends entirely on the model obeying the
        "return only the translated sentence" instruction; no validation or
        post-processing is performed here.
        """
        source_lang_name = language_iso_to_name[source_lang_iso]
        target_lang_name = language_iso_to_name[target_lang_iso]

        # Build the prompt with implicit string concatenation. The previous
        # version used backslash-newline continuations inside a triple-quoted
        # f-string, which delete the newline and splice adjacent words
        # together (e.g. "from Frenchto English", "as yourresponse:").
        prompt = (
            f"Please translate the following sentence from {source_lang_name} "
            f"to {target_lang_name}, and return only the translated sentence "
            f"as your response: {text}"
        )

        # Fresh (empty) conversation context for every call; the model name
        # is forwarded so different Ollama models can be used per instance.
        return generate(prompt, [], model_name=self.model_tag)
18 changes: 18 additions & 0 deletions tests/test_translate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from m4st.translate.model import T5TranslateModel

eng_text = "Hi, my name is Bob."


def test_t5():
    """T5 returns a non-empty string when translating English to French.

    The previous version only printed the output, so the test could never
    fail on a wrong result. We cannot pin the exact translation (it may
    change with checkpoint weights), so assert the verifiable contract:
    a non-empty string is returned.
    """
    t5 = T5TranslateModel()
    translated_text_t5 = t5(eng_text, source_lang_iso="eng", target_lang_iso="fra")
    assert isinstance(translated_text_t5, str)
    assert translated_text_t5.strip()


# NOTE: Ollama server must be running for this test to work, skipping it for now.
# def test_llama32():
# llama32translate = OllamaTranslateModel()
# translated_text_llama32 = llama32translate(
# eng_text, source_lang_iso="eng", target_lang_iso="fra"
# )
# print(translated_text_llama32)

0 comments on commit 3d03893

Please sign in to comment.