From 6a99cc8843461d106ed6e4d2f6ab534b76d12a39 Mon Sep 17 00:00:00 2001 From: Elron Bandel Date: Wed, 22 Jan 2025 12:19:39 +0200 Subject: [PATCH] Add Replicate inference support (#1544) Signed-off-by: elronbandel --- src/unitxt/inference.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 6e77b67bb3..02de4a48dd 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -2898,6 +2898,7 @@ def get_return_object(self, responses, return_meta_data): "rits", "azure", "vertex-ai", + "replicate", ] @@ -3026,6 +3027,28 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin): "llama-3-1-70b-instruct": "vertex_ai/meta/llama-3.1-70b-instruct-maas", "llama-3-1-405b-instruct": "vertex_ai/meta/llama-3.1-405b-instruct-maas", }, + "replicate": { + "granite-20b-code-instruct-8k": "replicate/ibm-granite/granite-20b-code-instruct-8k", + "granite-3-2b-instruct": "replicate/ibm-granite/granite-3.0-2b-instruct", + "granite-3-8b-instruct": "replicate/ibm-granite/granite-3.0-8b-instruct", + "granite-3-1-2b-instruct": "replicate/ibm-granite/granite-3.1-2b-instruct", + "granite-3-1-8b-instruct": "replicate/ibm-granite/granite-3.1-8b-instruct", + "granite-8b-code-instruct-128k": "replicate/ibm-granite/granite-8b-code-instruct-128k", + "llama-2-13b": "replicate/meta/llama-2-13b", + "llama-2-13b-chat": "replicate/meta/llama-2-13b-chat", + "llama-2-70b": "replicate/meta/llama-2-70b", + "llama-2-70b-chat": "replicate/meta/llama-2-70b-chat", + "llama-2-7b": "replicate/meta/llama-2-7b", + "llama-2-7b-chat": "replicate/meta/llama-2-7b-chat", + "llama-3-1-405b-instruct": "replicate/meta/meta-llama-3.1-405b-instruct", + "llama-3-70b": "replicate/meta/meta-llama-3-70b", + "llama-3-70b-instruct": "replicate/meta/meta-llama-3-70b-instruct", + "llama-3-8b": "replicate/meta/meta-llama-3-8b", + "llama-3-8b-instruct": "replicate/meta/meta-llama-3-8b-instruct", + "mistral-7b-instruct-v0.2": "replicate/mistralai/mistral-7b-instruct-v0.2", + "mistral-7b-v0.1": "replicate/mistralai/mistral-7b-v0.1", + "mixtral-8x7b-instruct-v0.1": "replicate/mistralai/mixtral-8x7b-instruct-v0.1", + }, } _provider_to_base_class = { @@ -3039,6 +3062,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin): "rits": RITSInferenceEngine, "azure": LiteLLMInferenceEngine, "vertex-ai": LiteLLMInferenceEngine, + "replicate": LiteLLMInferenceEngine, } _provider_param_renaming = {