diff --git a/.github/workflows/inference_tests.yml b/.github/workflows/inference_tests.yml
index 9ba07690cb..00d7125c8f 100644
--- a/.github/workflows/inference_tests.yml
+++ b/.github/workflows/inference_tests.yml
@@ -32,7 +32,6 @@ jobs:
       WX_PROJECT_ID: ${{ secrets.WML_PROJECT_ID }}
       WX_API_KEY: ${{ secrets.WML_APIKEY }}
       GENAI_KEY: ${{ secrets.GENAI_KEY }}
-
     steps:
       - uses: actions/checkout@v4
@@ -45,4 +44,4 @@ jobs:
       - run: huggingface-cli login --token ${{ secrets.UNITXT_READ_HUGGINGFACE_HUB_FOR_TESTS }}
       - name: Run Tests
-        run: python -m unittest discover -s tests/inference -p "test_*.py"
\ No newline at end of file
+        run: python -m unittest discover -s tests/inference -p "test_*.py"
diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py
index cc528ee771..c0be052c95 100644
--- a/src/unitxt/inference.py
+++ b/src/unitxt/inference.py
@@ -19,6 +19,7 @@
     Optional,
     Sequence,
     Tuple,
+    TypedDict,
     Union,
 )
 
@@ -1407,6 +1408,11 @@ def get_options_log_probs(self, dataset):
         return dataset
 
 
+class CredentialsOpenAi(TypedDict, total=False):
+    api_key: str
+    api_url: str
+
+
 class OpenAiInferenceEngineParamsMixin(Artifact):
     frequency_penalty: Optional[float] = None
     presence_penalty: Optional[float] = None
@@ -1453,27 +1459,40 @@ class OpenAiInferenceEngine(
     }
     data_classification_policy = ["public"]
     parameters: Optional[OpenAiInferenceEngineParams] = None
+    base_url: Optional[str] = None
+    default_headers: Dict[str, str] = {}
+    credentials: CredentialsOpenAi = {}
 
-    def get_engine_id(self):
+    def get_engine_id(self) -> str:
         return get_model_and_label_id(self.model_name, self.label)
 
-    @classmethod
-    def get_api_param(cls, inference_engine: str, api_param_env_var_name: str):
-        api_key = os.environ.get(api_param_env_var_name)
-        assert api_key is not None, (
-            f"Error while trying to run {inference_engine}."
-            f" Please set the environment param '{api_param_env_var_name}'."
+    def _prepare_credentials(self) -> CredentialsOpenAi:
+        api_key = self.credentials.get(
+            "api_key", os.environ.get(f"{self.label.upper()}_API_KEY", None)
         )
-        return api_key
+        assert api_key, (
+            f"Error while trying to run {self.label}. "
" + f"Please set the env variable: '{self.label.upper()}_API_KEY'" + ) + + api_url = self.credentials.get( + "api_url", os.environ.get(f"{self.label.upper()}_API_URL", None) + ) + + return {"api_key": api_key, "api_url": api_url} + + def get_default_headers(self) -> Dict[str, str]: + return self.default_headers def create_client(self): from openai import OpenAI - api_key = self.get_api_param( - inference_engine="OpenAiInferenceEngine", - api_param_env_var_name="OPENAI_API_KEY", + self.credentials = self._prepare_credentials() + return OpenAI( + api_key=self.credentials["api_key"], + base_url=self.base_url or self.credentials["api_url"], + default_headers=self.get_default_headers(), ) - return OpenAI(api_key=api_key) def prepare_engine(self): self.client = self.create_client() @@ -1553,6 +1572,32 @@ def get_return_object(self, predict_result, response, return_meta_data): return predict_result +class VLLMRemoteInferenceEngine(OpenAiInferenceEngine): + label: str = "vllm" + + +class RITSInferenceEngine(OpenAiInferenceEngine): + label: str = "rits" + + def get_default_headers(self): + return {"RITS_API_KEY": self.credentials["api_key"]} + + def prepare_engine(self): + base_url_template = "https://inference-3scale-apicast-production.apps.rits.fmaas.res.ibm.com/{}/v1" + self.base_url = base_url_template.format(self._get_model_name_for_endpoint()) + logger.info(f"Created RITS inference engine with endpoint: {self.base_url}") + super().prepare_engine() + + def _get_model_name_for_endpoint(self): + return ( + self.model_name.split("/")[-1] + .lower() + .replace("v0.1", "v01") + .replace("vision-", "") + .replace(".", "-") + ) + + class TogetherAiInferenceEngineParamsMixin(Artifact): max_tokens: Optional[int] = None stop: Optional[List[str]] = None @@ -1652,23 +1697,6 @@ def _infer( return outputs -class VLLMRemoteInferenceEngine(OpenAiInferenceEngine): - label: str = "vllm" - - def create_client(self): - from openai import OpenAI - - api_key = self.get_api_param( - inference_engine="VLLMRemoteInferenceEngine", - api_param_env_var_name="VLLM_API_KEY", - ) - api_url = self.get_api_param( - inference_engine="VLLMRemoteInferenceEngine", - api_param_env_var_name="VLLM_API_URL", - ) - return OpenAI(api_key=api_key, base_url=api_url) - - @deprecation( version="2.0.0", msg=" You can specify inference parameters directly when initializing an inference engine.", @@ -2667,7 +2695,7 @@ def _infer( _supported_apis = Literal[ - "watsonx", "together-ai", "open-ai", "aws", "ollama", "bam", "watsonx-sdk" + "watsonx", "together-ai", "open-ai", "aws", "ollama", "bam", "watsonx-sdk", "rits" ] @@ -2698,6 +2726,8 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin): "granite-3-8b-instruct": "watsonx/ibm/granite-3-8b-instruct", "flan-t5-xxl": "watsonx/google/flan-t5-xxl", "llama-3-2-1b-instruct": "watsonx/meta-llama/llama-3-2-1b-instruct", + "llama-3-2-11b-vision-instruct": "watsonx/meta-llama/llama-3-2-11b-vision-instruct", + "llama-3-2-90b-vision-instruct": "watsonx/meta-llama/llama-3-2-90b-vision-instruct", }, "watsonx-sdk": { "llama-3-8b-instruct": "meta-llama/llama-3-8b-instruct", @@ -2723,6 +2753,15 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin): "llama-3-2-1b-instruct": "meta-llama/llama-3-2-1b-instruct", "flan-t5-xxl": "google/flan-t5-xxl", }, + "rits": { + "granite-3-8b-instruct": "ibm-granite/granite-3.0-8b-instruct", + "llama-3-1-8b-instruct": "meta-llama/llama-3-1-8b-instruct", + "llama-3-1-70b-instruct": "meta-llama/llama-3-1-70b-instruct", + 
"llama-3-2-11b-vision-instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct", + "llama-3-2-90b-vision-instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct", + "mistral-large-instruct": "mistralai/mistral-large-instruct-2407", + "mixtral-8x7b-instruct": "mistralai/mixtral-8x7B-instruct-v0.1", + }, } _provider_to_base_class = { @@ -2733,11 +2772,13 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin): "ollama": OllamaInferenceEngine, "bam": IbmGenAiInferenceEngine, "watsonx-sdk": WMLInferenceEngine, + "rits": RITSInferenceEngine, } _provider_param_renaming = { "bam": {"max_tokens": "max_new_tokens", "model": "model_name"}, "watsonx-sdk": {"max_tokens": "max_new_tokens", "model": "model_name"}, + "rits": {"model": "model_name"}, } def get_provider_name(self): @@ -2747,7 +2788,7 @@ def prepare_engine(self): provider = self.get_provider_name() if provider not in self._provider_to_base_class: raise UnitxtError( - f"{provider} a known API. Supported apis: {','.join(self.provider_model_map.keys())}" + f"{provider} is not a configured API for CrossProviderInferenceEngine. Supported apis: {','.join(self.provider_model_map.keys())}" ) if self.model not in self.provider_model_map[provider]: raise UnitxtError( diff --git a/src/unitxt/parsing_utils.py b/src/unitxt/parsing_utils.py index d9bcc028b2..a5487d455c 100644 --- a/src/unitxt/parsing_utils.py +++ b/src/unitxt/parsing_utils.py @@ -45,7 +45,7 @@ def consume_name_val(instring: str) -> Tuple[Any, str]: name_val = "" for char in instring: - if char in "[],:{}=": + if char in "[],{}=": break name_val += char instring = instring[len(name_val) :].strip() diff --git a/src/unitxt/standard.py b/src/unitxt/standard.py index b3af5e530f..b4373c3fc1 100644 --- a/src/unitxt/standard.py +++ b/src/unitxt/standard.py @@ -140,6 +140,10 @@ def verify(self): f"post processors must be a list of post processor. Got postprocessors = {self.postprocessors}" ) + if self.format is not None and not isinstance(self.format, Format): + raise ValueError( + f"format parameter must be a list of of class derived from Format. Got format = {self.format}" + ) if self.template is None: raise ValueError( "You must set in the recipe either `template`, `template_card_index`." 
diff --git a/tests/inference/test_inference_engine.py b/tests/inference/test_inference_engine.py
index 331bb078e9..ca63a6c925 100644
--- a/tests/inference/test_inference_engine.py
+++ b/tests/inference/test_inference_engine.py
@@ -12,16 +12,19 @@
     IbmGenAiInferenceEngine,
     LiteLLMInferenceEngine,
     OptionSelectingByLogProbsInferenceEngine,
+    RITSInferenceEngine,
     TextGenerationInferenceOutput,
     WMLInferenceEngineChat,
     WMLInferenceEngineGeneration,
 )
+from unitxt.logging_utils import get_logger
 from unitxt.settings_utils import get_settings
 from unitxt.text_utils import print_dict
 from unitxt.type_utils import isoftype
 
 from tests.utils import UnitxtInferenceTestCase
 
+logger = get_logger()
 settings = get_settings()
@@ -140,6 +143,41 @@ def test_watsonx_inference(self):
             result = {**inp, "prediction": prediction}
             print_dict(result, keys_to_print=["source", "prediction"])
 
+    def test_rits_inference(self):
+        import os
+
+        if os.environ.get("RITS_API_KEY") is None:
+            logger.warning(
+                "Skipping test_rits_inference because RITS_API_KEY is not defined"
+            )
+            return
+
+        rits_engine = RITSInferenceEngine(
+            model_name="meta-llama/llama-3-1-70b-instruct",
+            max_tokens=128,
+        )
+        # The defined rits_engine is equivalent to:
+        # rits_engine = OpenAiInferenceEngine(
+        #     model_name="meta-llama/llama-3-1-70b-instruct",
+        #     max_tokens=128,
+        #     credentials={"api_key": "", "api_url": "https://inference-3scale-apicast-production.apps.rits.fmaas.res.ibm.com/llama-3-1-70b-instruct/v1"},
+        #     default_headers={"RITS_API_KEY": ""},
+        # )
+
+        # Loading dataset:
+        dataset = load_dataset(
+            card="cards.go_emotions.simplified",
+            template="templates.classification.multi_label.empty",
+            loader_limit=3,
+        )
+        test_data = dataset["test"]
+
+        # Performing inference:
+        predictions = rits_engine.infer(test_data)
+        for inp, prediction in zip(test_data, predictions):
+            result = {**inp, "prediction": prediction}
+            print_dict(result, keys_to_print=["source", "prediction"])
+
     def test_option_selecting_by_log_prob_inference_engines(self):
         dataset = [
             {
diff --git a/utils/.secrets.baseline b/utils/.secrets.baseline
index d4f8067c29..91de0fa545 100644
--- a/utils/.secrets.baseline
+++ b/utils/.secrets.baseline
@@ -133,7 +133,7 @@
         "filename": "src/unitxt/inference.py",
         "hashed_secret": "aa6cd2a77de22303be80e1f632195d62d211a729",
         "is_verified": false,
-        "line_number": 1225,
+        "line_number": 1226,
         "is_secret": false
       },
       {
@@ -141,7 +141,7 @@
         "filename": "src/unitxt/inference.py",
         "hashed_secret": "c8f16a194efc59559549c7bd69f7bea038742e79",
         "is_verified": false,
-        "line_number": 1589,
+        "line_number": 1634,
         "is_secret": false
       }
     ],
@@ -184,5 +184,5 @@
       }
     ]
   },
-  "generated_at": "2024-11-21T09:46:00Z"
+  "generated_at": "2024-11-25T11:33:41Z"
}
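Usage note: with the provider maps and `_provider_to_base_class` wiring above, a RITS model should be reachable either directly through `RITSInferenceEngine` (as in the test) or through `CrossProviderInferenceEngine`. A minimal sketch follows, assuming `RITS_API_KEY` is exported and that `provider` and `max_tokens` are accepted as standard `StandardAPIParamsMixin` fields (not shown in this diff):

```python
from unitxt.inference import CrossProviderInferenceEngine

# Minimal sketch, assuming RITS_API_KEY is set in the environment.
# "rits" resolves through _provider_to_base_class to RITSInferenceEngine,
# and _provider_param_renaming maps "model" -> "model_name" for it.
engine = CrossProviderInferenceEngine(
    model="llama-3-1-70b-instruct",  # a key in provider_model_map["rits"]
    provider="rits",
    max_tokens=128,
)

# engine.infer(dataset) then behaves like the direct RITSInferenceEngine
# construction shown in test_rits_inference above.
```

The cross-provider route keeps recipes provider-agnostic: switching `provider` to "watsonx" or "together-ai" reuses the same short model key wherever the maps define it.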