Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added model kwargs support to AzureMLChatOnlineEndpoint #313

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/ragas/llms/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import typing as t

from langchain.chat_models import AzureChatOpenAI, BedrockChat, ChatOpenAI, ChatVertexAI
from langchain.chat_models.azureml_endpoint import AzureMLChatOnlineEndpoint
from langchain.chat_models.base import SimpleChatModel
from langchain.chat_models.base import BaseChatModel
from langchain.llms import AmazonAPIGateway, AzureOpenAI, Bedrock, OpenAI, VertexAI
from langchain.llms.base import BaseLLM
Expand All @@ -26,6 +28,9 @@ def isBedrock(llm: BaseLLM | BaseChatModel) -> bool:
return isinstance(llm, Bedrock) or isinstance(llm, BedrockChat)


def isAzureMLEndpoint(llm: BaseLLM | BaseChatModel) -> bool:
    """Return True if *llm* is an Azure ML online-endpoint chat model.

    Signature matches the sibling predicates (``isBedrock``,
    ``isAmazonAPIGateway``): annotated ``BaseLLM | BaseChatModel`` rather
    than the previous ``SimpleChatModel``, which was inconsistent and
    narrower than what callers actually pass.
    """
    return isinstance(llm, AzureMLChatOnlineEndpoint)

def isAmazonAPIGateway(llm: BaseLLM | BaseChatModel) -> bool:
    """Return True if *llm* is an ``AmazonAPIGateway`` LLM instance."""
    return isinstance(llm, AmazonAPIGateway)

Expand All @@ -38,9 +43,10 @@ def isAmazonAPIGateway(llm: BaseLLM | BaseChatModel) -> bool:
AzureChatOpenAI,
ChatVertexAI,
VertexAI,
AzureMLChatOnlineEndpoint,
]
MultipleCompletionSupportedLLM = t.Union[
OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI, ChatVertexAI, VertexAI
OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI, ChatVertexAI, VertexAI, AzureMLChatOnlineEndpoint
]


Expand Down Expand Up @@ -146,6 +152,8 @@ async def agenerate(
temperature = 0.2 if n > 1 else 0
if isBedrock(self.llm) and ("model_kwargs" in self.llm.__dict__):
self.llm.model_kwargs = {"temperature": temperature}
elif isAzureMLEndpoint(self.llm) and ("model_kwargs" in self.llm.__dict__):
self.llm.model_kwargs['temperature'] = temperature
else:
self.llm.temperature = temperature

Expand Down Expand Up @@ -200,6 +208,8 @@ def generate(
temperature = 0.2 if n > 1 else 1e-8
if isBedrock(self.llm) and ("model_kwargs" in self.llm.__dict__):
self.llm.model_kwargs = {"temperature": temperature}
elif isAzureMLEndpoint(self.llm) and ("model_kwargs" in self.llm.__dict__):
self.llm.model_kwargs['temperature'] = temperature
elif isAmazonAPIGateway(self.llm) and ("model_kwargs" in self.llm.__dict__):
self.llm.model_kwargs = {"temperature": temperature}
else:
Expand Down
Loading