diff --git a/pyproject.toml b/pyproject.toml
index 4332bb5..50c4023 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,6 +42,8 @@ dependencies = [
     "torchvision>=0.15.2",
     "unbabel-comet==2.2.3",
     "librosa>=0.10.2.post1",
+    "sentencepiece",
+    "transformers"
 ]
 
 [project.optional-dependencies]
diff --git a/src/m4st/ollama/client.py b/src/m4st/ollama/client.py
index d1c326a..b789236 100644
--- a/src/m4st/ollama/client.py
+++ b/src/m4st/ollama/client.py
@@ -2,14 +2,12 @@
 
 import requests  # type: ignore[import-untyped]
 
-model = "llama3.2"
-
 
-def generate(prompt, context):
+def generate(prompt, context, model_name="llama3.2"):
     r = requests.post(
         "http://localhost:11434/api/generate",
         json={
-            "model": model,
+            "model": model_name,
             "prompt": prompt,
             "context": context,
         },
diff --git a/src/m4st/translate/__init__.py b/src/m4st/translate/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/m4st/translate/model.py b/src/m4st/translate/model.py
new file mode 100644
index 0000000..988956a
--- /dev/null
+++ b/src/m4st/translate/model.py
@@ -0,0 +1,65 @@
+from abc import ABC, abstractmethod
+
+from transformers import T5ForConditionalGeneration, T5Tokenizer
+
+from m4st.ollama.client import generate
+
+language_iso_to_name = {
+    "eng": "English",
+    "deu": "German",
+    "fra": "French",
+    "jpn": "Japanese",
+    "spa": "Spanish",
+    "rou": "Romanian",
+    "zho": "Chinese",
+}
+
+
+class TranslationModel(ABC):
+    @abstractmethod
+    def __call__(self, text: str, source_lang_iso: str, target_lang_iso: str):
+        pass
+
+
+class T5TranslateModel(TranslationModel):
+    def __init__(self, model_tag: str = "google-t5/t5-small"):
+        self.model = T5ForConditionalGeneration.from_pretrained(model_tag)
+        self.tokenizer = T5Tokenizer.from_pretrained(model_tag)
+
+        self.supported_languages = ["eng", "deu", "fra"]
+
+    def __call__(self, text: str, source_lang_iso: str, target_lang_iso: str):
+        assert target_lang_iso in self.supported_languages, f"This model \
+only supports {self.supported_languages}, but got target {target_lang_iso}."
+        assert source_lang_iso in self.supported_languages, f"This model \
+only supports {self.supported_languages}, but got source {source_lang_iso}."
+
+        source_lang_name, target_lang_name = (
+            language_iso_to_name[source_lang_iso],
+            language_iso_to_name[target_lang_iso],
+        )
+        model_input_text = f"translate {source_lang_name} to {target_lang_name}: {text}"
+        model_input_tokens = self.tokenizer(
+            model_input_text, return_tensors="pt"
+        ).input_ids
+        model_output_tokens = self.model.generate(model_input_tokens)
+
+        return self.tokenizer.decode(model_output_tokens[0], skip_special_tokens=True)
+
+
+class OllamaTranslateModel(TranslationModel):
+    def __init__(self, model_tag: str = "llama3.2"):
+        self.model_tag = model_tag
+
+    def __call__(self, text: str, source_lang_iso: str, target_lang_iso: str):
+        source_lang_name, target_lang_name = (
+            language_iso_to_name[source_lang_iso],
+            language_iso_to_name[target_lang_iso],
+        )
+        prompt = f"""
+        Please translate the following sentence from {source_lang_name}\
+        to {target_lang_name}, and return only the translated sentence as your\
+        response: {text}
+        """
+
+        return generate(prompt, [], model_name=self.model_tag)
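The two concrete classes above share the abstract TranslationModel interface, so a local T5 checkpoint and an Ollama-served model are interchangeable from the caller's side. A minimal usage sketch, assuming the package is importable and, for the Ollama path, that a local server has been started with `ollama serve` and the llama3.2 model already pulled:

    from m4st.translate.model import OllamaTranslateModel, T5TranslateModel

    # Local seq2seq model; google-t5/t5-small weights download on first use.
    t5 = T5TranslateModel()
    print(t5("Hi, my name is Bob.", source_lang_iso="eng", target_lang_iso="fra"))

    # Delegates to the Ollama HTTP API at http://localhost:11434.
    llama = OllamaTranslateModel()
    print(llama("Hi, my name is Bob.", source_lang_iso="eng", target_lang_iso="fra"))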
diff --git a/tests/test_translate.py b/tests/test_translate.py
new file mode 100644
index 0000000..88d209b
--- /dev/null
+++ b/tests/test_translate.py
@@ -0,0 +1,18 @@
+from m4st.translate.model import T5TranslateModel
+
+eng_text = "Hi, my name is Bob."
+
+
+def test_t5():
+    t5 = T5TranslateModel()
+    translated_text_t5 = t5(eng_text, source_lang_iso="eng", target_lang_iso="fra")
+    print(translated_text_t5)
+
+
+# NOTE: Ollama server must be running for this test to work, skipping it for now.
+# def test_llama32():
+#     llama32translate = OllamaTranslateModel()
+#     translated_text_llama32 = llama32translate(
+#         eng_text, source_lang_iso="eng", target_lang_iso="fra"
+#     )
+#     print(translated_text_llama32)
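Rather than staying commented out, the Ollama test could be skipped automatically whenever the local server is unreachable. A possible sketch for tests/test_translate.py using pytest's skipif marker; the ollama_available probe and its one-second timeout are illustrative assumptions, not part of this patch (it reuses the module-level eng_text defined above):

    import pytest
    import requests

    from m4st.translate.model import OllamaTranslateModel


    def ollama_available() -> bool:
        # Probe the same local endpoint that m4st.ollama.client.generate posts to.
        try:
            return requests.get("http://localhost:11434", timeout=1).ok
        except requests.RequestException:
            return False


    @pytest.mark.skipif(not ollama_available(), reason="Ollama server is not running")
    def test_llama32():
        llama32translate = OllamaTranslateModel()
        translated_text_llama32 = llama32translate(
            eng_text, source_lang_iso="eng", target_lang_iso="fra"
        )
        print(translated_text_llama32)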