Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OpenLLM support #81

Merged
merged 2 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion alfred/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(
"google",
"groq",
"torch",
"openllm",
"dummy",
], f"Invalid model type: {self.model_type}"
else:
Expand All @@ -99,7 +100,7 @@ def __init__(
self.run = self.cache.cached_query(self.run)

self.grpcClient = None
if end_point:
if end_point and model_type not in ["dummy", "openllm", ]:
end_point_pieces = end_point.split(":")
self.end_point_ip, self.end_point_port = (
"".join(end_point_pieces[:-1]),
Expand Down Expand Up @@ -180,6 +181,11 @@ def __init__(
from ..fm.openai import OpenAIModel

self.model = OpenAIModel(self.model, **kwargs)
elif self.model_type == "openllm":
from ..fm.openllm import OpenLLMModel

base_url = kwargs.get("base_url", end_point)
self.model = OpenLLMModel(self.model, base_url=base_url, **kwargs)
elif self.model_type == "cohere":
from ..fm.cohere import CohereModel

Expand Down
148 changes: 148 additions & 0 deletions alfred/fm/openllm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import json
import logging
from typing import Optional, List, Any, Union, Tuple

import openai
from openai._exceptions import (
AuthenticationError,
APIError,
APITimeoutError,
RateLimitError,
BadRequestError,
APIConnectionError,
APIStatusError,
)

from .model import APIAccessFoundationModel
from .response import CompletionResponse, RankedResponse
from .utils import retry

logger = logging.getLogger(__name__)

class OpenLLMModel(APIAccessFoundationModel):
    """
    A wrapper for the OpenLLM Models using OpenAI's Python package
    """

    def __init__(
        self, model_string: str = "", api_key: Optional[str] = None, **kwargs: Any
    ):
        """
        Initialize the OpenLLM API wrapper.

        Builds an OpenAI-compatible client pointed at the OpenLLM server
        (via the optional ``base_url`` keyword argument).

        :param model_string: The model to be used for generating completions.
        :type model_string: str
        :param api_key: The API key to be used for the OpenAI API.
        :type api_key: Optional[str]
        """
        self.model_string = model_string
        endpoint = kwargs.get("base_url", None)
        # OpenLLM servers typically ignore the key, but the OpenAI client
        # requires a non-empty value, hence the "na" placeholder.
        resolved_key = api_key if api_key else "na"
        self.openai_client = openai.OpenAI(base_url=endpoint, api_key=resolved_key)
        super().__init__(model_string, {"api_key": resolved_key, "base_url": endpoint})

    @retry(
        num_retries=3,
        wait_time=0.1,
        exceptions=(
            AuthenticationError,
            APIConnectionError,
            APITimeoutError,
            RateLimitError,
            APIError,
            BadRequestError,
            APIStatusError,
        ),
    )
    def _api_query(
        self,
        query: Union[str, List, Tuple],
        temperature: float = 0.0,
        max_tokens: int = 64,
        **kwargs: Any,
    ) -> str:
        """
        Run a single query through the foundation model using OpenAI's Python package

        :param query: The prompt to be used for the query
        :type query: Union[str, List, Tuple]
        :param temperature: The temperature of the model
        :type temperature: float
        :param max_tokens: The maximum number of tokens to be returned
        :type max_tokens: int
        :param kwargs: Additional keyword arguments
        :type kwargs: Any
        :return: The generated completion
        :rtype: str
        """
        if kwargs.get("chat", False):
            # Chat endpoint: a bare string becomes a single user turn,
            # while a list is assumed to already be a message list.
            if isinstance(query, list):
                messages = query
            else:
                messages = [{"role": "user", "content": query}]
            result = self.openai_client.chat.completions.create(
                model=self.model_string,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
            return result.choices[0].message.content

        # Plain completion endpoint: a message list is reduced to the
        # content of its first entry.
        if isinstance(query, list):
            prompt_text = query[0]["content"]
        else:
            prompt_text = query
        result = self.openai_client.completions.create(
            model=self.model_string,
            prompt=prompt_text,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        return result.choices[0].text

    def _generate_batch(
        self,
        batch_instance: Union[List[str], Tuple],
        **kwargs,
    ) -> List[CompletionResponse]:
        """
        Generate completions for a batch of prompts using the OpenAI API.

        :param batch_instance: A list of prompts for which to generate completions.
        :type batch_instance: List[str] or List[Tuple]
        :param kwargs: Additional keyword arguments to pass to the API.
        :type kwargs: Any
        :return: A list of `CompletionResponse` objects containing the generated completions.
        :rtype: List[CompletionResponse]
        """
        return [
            CompletionResponse(prediction=self._api_query(prompt, **kwargs))
            for prompt in batch_instance
        ]

    def _score_batch(
        self,
        batch_instance: Union[List[Tuple[str, str]], List[str]],
        scoring_instruction: str = "Instruction: Given the query, choose your answer from [[label_space]]:\nQuery:\n",
        **kwargs,
    ) -> List[RankedResponse]:
        """
        Score candidates using the OpenAI API.

        :param batch_instance: A list of prompts for which to generate candidate preferences.
        :type batch_instance: List[str] or List[Tuple]
        :param scoring_instruction: The instruction prompt for scoring
        :type scoring_instruction: str
        """
        ranked = []
        for item in batch_instance:
            # Substitute the candidate labels into the instruction template,
            # then append the query's own prompt.
            label_space = ",".join(item.candidates)
            full_prompt = (
                scoring_instruction.replace("[[label_space]]", label_space)
                + item.prompt
            )
            ranked.append(
                RankedResponse(
                    prediction=self._api_query(full_prompt, **kwargs), scores={}
                )
            )
        return ranked
1 change: 1 addition & 0 deletions docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ A full list of `Alfred` project modules.
- [Model](alfred/fm/model.md#model)
- [Onnx](alfred/fm/onnx.md#onnx)
- [Openai](alfred/fm/openai.md#openai)
- [Openllm](alfred/fm/openllm.md#openllm)
- [Query](alfred/fm/query/index.md#query)
- [CompletionQuery](alfred/fm/query/completion_query.md#completionquery)
- [Query](alfred/fm/query/query.md#query)
Expand Down
16 changes: 8 additions & 8 deletions docs/alfred/client/client.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class Client:

### Client().__call__

[Show source in client.py:313](../../../alfred/client/client.py#L313)
[Show source in client.py:319](../../../alfred/client/client.py#L319)

__call__() function to run the model on the queries.
Equivalent to run() function.
Expand All @@ -71,7 +71,7 @@ def __call__(

### Client().calibrate

[Show source in client.py:329](../../../alfred/client/client.py#L329)
[Show source in client.py:335](../../../alfred/client/client.py#L335)

calibrate are used to calibrate foundation models contextually given the template.
A voter class may be passed to calibrate the model with a specific voter.
Expand Down Expand Up @@ -115,7 +115,7 @@ def calibrate(

### Client().chat

[Show source in client.py:427](../../../alfred/client/client.py#L427)
[Show source in client.py:433](../../../alfred/client/client.py#L433)

Chat with the model APIs.
Currently, Alfred supports Chat APIs from Anthropic and OpenAI
Expand All @@ -133,7 +133,7 @@ def chat(self, log_save_path: Optional[str] = None, **kwargs: Any): ...

### Client().encode

[Show source in client.py:401](../../../alfred/client/client.py#L401)
[Show source in client.py:407](../../../alfred/client/client.py#L407)

embed() function to embed the queries.

Expand All @@ -155,7 +155,7 @@ def encode(

### Client().generate

[Show source in client.py:272](../../../alfred/client/client.py#L272)
[Show source in client.py:278](../../../alfred/client/client.py#L278)

Wrapper function to generate the response(s) from the model. (For completion)

Expand Down Expand Up @@ -183,7 +183,7 @@ def generate(

### Client().remote_run

[Show source in client.py:246](../../../alfred/client/client.py#L246)
[Show source in client.py:252](../../../alfred/client/client.py#L252)

Wrapper function for running the model on the queries thru a gRPC Server.

Expand All @@ -209,7 +209,7 @@ def remote_run(

### Client().run

[Show source in client.py:226](../../../alfred/client/client.py#L226)
[Show source in client.py:232](../../../alfred/client/client.py#L232)

Run the model on the queries.

Expand All @@ -235,7 +235,7 @@ def run(

### Client().score

[Show source in client.py:289](../../../alfred/client/client.py#L289)
[Show source in client.py:295](../../../alfred/client/client.py#L295)

Wrapper function to score the response(s) from the model. (For ranking)

Expand Down
1 change: 1 addition & 0 deletions docs/alfred/fm/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
- [Model](./model.md)
- [Onnx](./onnx.md)
- [Openai](./openai.md)
- [Openllm](./openllm.md)
- [Query](query/index.md)
- [Remote](remote/index.md)
- [Response](response/index.md)
Expand Down
Loading