Add support for OpenAI custom base url and default headers + RITS Inference engine (#1385)

* Add support for OpenAI custom base url and default headers

Signed-off-by: Martín Santillán Cooper <[email protected]>

* Add RITSInferenceEngine

This commit also refactors how OpenAI credentials are prepared. The change allows credentials to be passed by parameter or via environment variables. The mechanism is similar to, but simpler than, the way the WML inference engine gets its credentials.

Signed-off-by: Martín Santillán Cooper <[email protected]>
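
For illustration, a minimal sketch of the two ways credentials can now reach the engine (values are placeholders; assumes the engine label resolves to "openai", so the variables read from the environment are OPENAI_API_KEY and OPENAI_API_URL):

    from unitxt.inference import OpenAiInferenceEngine

    # Option 1: pass credentials explicitly by parameter.
    engine = OpenAiInferenceEngine(
        model_name="gpt-4o",
        credentials={"api_key": "<api_key>", "api_url": "https://example.com/v1"},
    )

    # Option 2: let _prepare_credentials() fall back to environment variables,
    # e.g. export OPENAI_API_KEY=... (and optionally OPENAI_API_URL).
    engine = OpenAiInferenceEngine(model_name="gpt-4o")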

* Fixed artifact override to allow ":" in override values

Signed-off-by: Yoav Katz <[email protected]>

* Moved get_api_param to where it was used

Changed to use standard TypedDict typing in WMLCredentials.

Signed-off-by: Yoav Katz <[email protected]>

* Added RITS to CrossProviderInferenceEngine

Signed-off-by: Yoav Katz <[email protected]>

* Added error message when providing wrong format param

Signed-off-by: Yoav Katz <[email protected]>

* Added vision models to RITS

Signed-off-by: Yoav Katz <[email protected]>

* Update inference_tests.yml with RITS_API_KEY_SUPPORT

* Update inference_tests.yml

* Skip testing RITS if RITS_API_KEY not defined.

Signed-off-by: Yoav Katz <[email protected]>

* Removed RITS from external tests because they cannot run in actions

* fixes

* fixes

* fixes

* fixes

---------

Signed-off-by: Martín Santillán Cooper <[email protected]>
Signed-off-by: Yoav Katz <[email protected]>
Co-authored-by: OfirArviv <[email protected]>
Co-authored-by: Yoav Katz <[email protected]>
Co-authored-by: Yoav Katz <[email protected]>
4 people authored Nov 25, 2024
1 parent bf700b8 commit d116a0b
Showing 6 changed files with 119 additions and 37 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/inference_tests.yml
@@ -32,7 +32,6 @@ jobs:
WX_PROJECT_ID: ${{ secrets.WML_PROJECT_ID }}
WX_API_KEY: ${{ secrets.WML_APIKEY }}
GENAI_KEY: ${{ secrets.GENAI_KEY }}

steps:
- uses: actions/checkout@v4

@@ -45,4 +44,4 @@ jobs:
- run: huggingface-cli login --token ${{ secrets.UNITXT_READ_HUGGINGFACE_HUB_FOR_TESTS }}

- name: Run Tests
-      run: python -m unittest discover -s tests/inference -p "test_*.py"
+      run: python -m unittest discover -s tests/inference -p "test_*.py"
103 changes: 72 additions & 31 deletions src/unitxt/inference.py
@@ -19,6 +19,7 @@
Optional,
Sequence,
Tuple,
TypedDict,
Union,
)

@@ -1407,6 +1408,11 @@ def get_options_log_probs(self, dataset):
return dataset


class CredentialsOpenAi(TypedDict, total=False):
api_key: str
api_url: str


class OpenAiInferenceEngineParamsMixin(Artifact):
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
@@ -1453,27 +1459,40 @@ class OpenAiInferenceEngine(
}
data_classification_policy = ["public"]
parameters: Optional[OpenAiInferenceEngineParams] = None
base_url: Optional[str] = None
default_headers: Dict[str, str] = {}
credentials: CredentialsOpenAi = {}

-    def get_engine_id(self):
+    def get_engine_id(self) -> str:
         return get_model_and_label_id(self.model_name, self.label)

-    @classmethod
-    def get_api_param(cls, inference_engine: str, api_param_env_var_name: str):
-        api_key = os.environ.get(api_param_env_var_name)
-        assert api_key is not None, (
-            f"Error while trying to run {inference_engine}."
-            f" Please set the environment param '{api_param_env_var_name}'."
-        )
-        return api_key
+    def _prepare_credentials(self) -> CredentialsOpenAi:
+        api_key = self.credentials.get(
+            "api_key", os.environ.get(f"{self.label.upper()}_API_KEY", None)
+        )
+        assert api_key, (
+            f"Error while trying to run {self.label}. "
+            f"Please set the env variable: '{self.label.upper()}_API_KEY'"
+        )
+
+        api_url = self.credentials.get(
+            "api_url", os.environ.get(f"{self.label.upper()}_API_URL", None)
+        )
+
+        return {"api_key": api_key, "api_url": api_url}
+
+    def get_default_headers(self) -> Dict[str, str]:
+        return self.default_headers

def create_client(self):
from openai import OpenAI

-        api_key = self.get_api_param(
-            inference_engine="OpenAiInferenceEngine",
-            api_param_env_var_name="OPENAI_API_KEY",
-        )
-        return OpenAI(api_key=api_key)
+        self.credentials = self._prepare_credentials()
+        return OpenAI(
+            api_key=self.credentials["api_key"],
+            base_url=self.base_url or self.credentials["api_url"],
+            default_headers=self.get_default_headers(),
+        )

def prepare_engine(self):
self.client = self.create_client()
@@ -1553,6 +1572,32 @@ def get_return_object(self, predict_result, response, return_meta_data):
return predict_result


class VLLMRemoteInferenceEngine(OpenAiInferenceEngine):
label: str = "vllm"


class RITSInferenceEngine(OpenAiInferenceEngine):
label: str = "rits"

def get_default_headers(self):
return {"RITS_API_KEY": self.credentials["api_key"]}

def prepare_engine(self):
base_url_template = "https://inference-3scale-apicast-production.apps.rits.fmaas.res.ibm.com/{}/v1"
self.base_url = base_url_template.format(self._get_model_name_for_endpoint())
logger.info(f"Created RITS inference engine with endpoint: {self.base_url}")
super().prepare_engine()

def _get_model_name_for_endpoint(self):
return (
self.model_name.split("/")[-1]
.lower()
.replace("v0.1", "v01")
.replace("vision-", "")
.replace(".", "-")
)


class TogetherAiInferenceEngineParamsMixin(Artifact):
max_tokens: Optional[int] = None
stop: Optional[List[str]] = None
@@ -1652,23 +1697,6 @@ def _infer(
return outputs


-class VLLMRemoteInferenceEngine(OpenAiInferenceEngine):
-    label: str = "vllm"
-
-    def create_client(self):
-        from openai import OpenAI
-
-        api_key = self.get_api_param(
-            inference_engine="VLLMRemoteInferenceEngine",
-            api_param_env_var_name="VLLM_API_KEY",
-        )
-        api_url = self.get_api_param(
-            inference_engine="VLLMRemoteInferenceEngine",
-            api_param_env_var_name="VLLM_API_URL",
-        )
-        return OpenAI(api_key=api_key, base_url=api_url)


@deprecation(
version="2.0.0",
msg=" You can specify inference parameters directly when initializing an inference engine.",
@@ -2667,7 +2695,7 @@ def _infer(


_supported_apis = Literal[
"watsonx", "together-ai", "open-ai", "aws", "ollama", "bam", "watsonx-sdk"
"watsonx", "together-ai", "open-ai", "aws", "ollama", "bam", "watsonx-sdk", "rits"
]


@@ -2698,6 +2726,8 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
"granite-3-8b-instruct": "watsonx/ibm/granite-3-8b-instruct",
"flan-t5-xxl": "watsonx/google/flan-t5-xxl",
"llama-3-2-1b-instruct": "watsonx/meta-llama/llama-3-2-1b-instruct",
"llama-3-2-11b-vision-instruct": "watsonx/meta-llama/llama-3-2-11b-vision-instruct",
"llama-3-2-90b-vision-instruct": "watsonx/meta-llama/llama-3-2-90b-vision-instruct",
},
"watsonx-sdk": {
"llama-3-8b-instruct": "meta-llama/llama-3-8b-instruct",
@@ -2723,6 +2753,15 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
"llama-3-2-1b-instruct": "meta-llama/llama-3-2-1b-instruct",
"flan-t5-xxl": "google/flan-t5-xxl",
},
"rits": {
"granite-3-8b-instruct": "ibm-granite/granite-3.0-8b-instruct",
"llama-3-1-8b-instruct": "meta-llama/llama-3-1-8b-instruct",
"llama-3-1-70b-instruct": "meta-llama/llama-3-1-70b-instruct",
"llama-3-2-11b-vision-instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"llama-3-2-90b-vision-instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
"mistral-large-instruct": "mistralai/mistral-large-instruct-2407",
"mixtral-8x7b-instruct": "mistralai/mixtral-8x7B-instruct-v0.1",
},
}

_provider_to_base_class = {
Expand All @@ -2733,11 +2772,13 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
"ollama": OllamaInferenceEngine,
"bam": IbmGenAiInferenceEngine,
"watsonx-sdk": WMLInferenceEngine,
"rits": RITSInferenceEngine,
}

_provider_param_renaming = {
"bam": {"max_tokens": "max_new_tokens", "model": "model_name"},
"watsonx-sdk": {"max_tokens": "max_new_tokens", "model": "model_name"},
"rits": {"model": "model_name"},
}

def get_provider_name(self):
Expand All @@ -2747,7 +2788,7 @@ def prepare_engine(self):
provider = self.get_provider_name()
if provider not in self._provider_to_base_class:
raise UnitxtError(
f"{provider} a known API. Supported apis: {','.join(self.provider_model_map.keys())}"
f"{provider} is not a configured API for CrossProviderInferenceEngine. Supported apis: {','.join(self.provider_model_map.keys())}"
)
if self.model not in self.provider_model_map[provider]:
raise UnitxtError(
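
To make the new RITS wiring concrete, a minimal usage sketch (model names come from the "rits" provider map above; assumes the RITS_API_KEY environment variable is set and the endpoint host shown in prepare_engine() is reachable):

    from unitxt.inference import CrossProviderInferenceEngine, RITSInferenceEngine

    # Through the cross-provider engine: "rits" now routes to RITSInferenceEngine,
    # and "llama-3-1-70b-instruct" resolves via provider_model_map["rits"].
    engine = CrossProviderInferenceEngine(
        model="llama-3-1-70b-instruct",
        provider="rits",
        max_tokens=128,
    )

    # Or directly: the base_url is derived from the model name by
    # _get_model_name_for_endpoint(), here ".../llama-3-1-70b-instruct/v1".
    rits_engine = RITSInferenceEngine(
        model_name="meta-llama/llama-3-1-70b-instruct",
        max_tokens=128,
    )
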
2 changes: 1 addition & 1 deletion src/unitxt/parsing_utils.py
@@ -45,7 +45,7 @@
def consume_name_val(instring: str) -> Tuple[Any, str]:
name_val = ""
for char in instring:
if char in "[],:{}=":
if char in "[],{}=":
break
name_val += char
instring = instring[len(name_val) :].strip()
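
A standalone sketch of what this one-character fix changes (the function is re-implemented here only for illustration; the real one lives in src/unitxt/parsing_utils.py): with ":" removed from the break set, override values containing colons, such as URLs, are consumed whole instead of being cut at the first colon.

    def consume_name_val(instring: str):
        # Same loop as above, with the new break set that no longer includes ":".
        name_val = ""
        for char in instring:
            if char in "[],{}=":
                break
            name_val += char
        return name_val, instring[len(name_val):].strip()

    # Before the fix this stopped at "https"; now the full URL survives.
    print(consume_name_val("https://example.com/v1,other=1"))
    # -> ('https://example.com/v1', ',other=1')
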
4 changes: 4 additions & 0 deletions src/unitxt/standard.py
@@ -140,6 +140,10 @@ def verify(self):
f"post processors must be a list of post processor. Got postprocessors = {self.postprocessors}"
)

if self.format is not None and not isinstance(self.format, Format):
raise ValueError(
f"format parameter must be a list of of class derived from Format. Got format = {self.format}"
)
if self.template is None:
raise ValueError(
"You must set in the recipe either `template`, `template_card_index`."
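
A sketch of what the new guard rejects, mirroring the added check outside the recipe class (assumes only that Format is importable from unitxt.formats):

    from unitxt.formats import Format

    def verify_format(fmt):
        # Mirrors the isinstance check added to verify() above.
        if fmt is not None and not isinstance(fmt, Format):
            raise ValueError(
                f"format parameter must be of a class derived from Format. Got format = {fmt}"
            )

    verify_format(None)                # OK: format is optional
    verify_format("formats.chat_api")  # raises ValueError: a string, not a Format object
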
38 changes: 38 additions & 0 deletions tests/inference/test_inference_engine.py
@@ -12,16 +12,19 @@
IbmGenAiInferenceEngine,
LiteLLMInferenceEngine,
OptionSelectingByLogProbsInferenceEngine,
RITSInferenceEngine,
TextGenerationInferenceOutput,
WMLInferenceEngineChat,
WMLInferenceEngineGeneration,
)
from unitxt.logging_utils import get_logger
from unitxt.settings_utils import get_settings
from unitxt.text_utils import print_dict
from unitxt.type_utils import isoftype

from tests.utils import UnitxtInferenceTestCase

logger = get_logger()
settings = get_settings()


@@ -140,6 +143,41 @@ def test_watsonx_inference(self):
result = {**inp, "prediction": prediction}
print_dict(result, keys_to_print=["source", "prediction"])

def test_rits_inference(self):
import os

if os.environ.get("RITS_API_KEY") is None:
logger.warning(
"Skipping test_rits_inference because RITS_API_KEY not defined"
)
return

rits_engine = RITSInferenceEngine(
model_name="meta-llama/llama-3-1-70b-instruct",
max_tokens=128,
)
# The defined rits_engine is equivalent to:
# rits_engine = OpenAiInferenceEngine(
# model_name="meta-llama/llama-3-1-70b-instruct",
# max_tokens=128,
# credentials={"api_key": "<api_key>", "api_url": "https://inference-3scale-apicast-production.apps.rits.fmaas.res.ibm.com/llama-3-1-70b-instruct/v1"},
# default_headers = {"RITS_API_KEY": "<api_key>"}
# )

# Loading dataset:
dataset = load_dataset(
card="cards.go_emotions.simplified",
template="templates.classification.multi_label.empty",
loader_limit=3,
)
test_data = dataset["test"]

# Performing inference:
predictions = rits_engine.infer(test_data)
for inp, prediction in zip(test_data, predictions):
result = {**inp, "prediction": prediction}
print_dict(result, keys_to_print=["source", "prediction"])

def test_option_selecting_by_log_prob_inference_engines(self):
dataset = [
{
6 changes: 3 additions & 3 deletions utils/.secrets.baseline
@@ -133,15 +133,15 @@
"filename": "src/unitxt/inference.py",
"hashed_secret": "aa6cd2a77de22303be80e1f632195d62d211a729",
"is_verified": false,
"line_number": 1225,
"line_number": 1226,
"is_secret": false
},
{
"type": "Secret Keyword",
"filename": "src/unitxt/inference.py",
"hashed_secret": "c8f16a194efc59559549c7bd69f7bea038742e79",
"is_verified": false,
"line_number": 1589,
"line_number": 1634,
"is_secret": false
}
],
@@ -184,5 +184,5 @@
}
]
},
"generated_at": "2024-11-21T09:46:00Z"
"generated_at": "2024-11-25T11:33:41Z"
}
