From 583aad0a263a70ef7dd3b8dc58f07f5154be1c42 Mon Sep 17 00:00:00 2001
From: Donnie Adams
Date: Fri, 16 Aug 2024 18:08:36 -0400
Subject: [PATCH] feat: add ability to list models from other providers

Signed-off-by: Donnie Adams
---
 gptscript/gptscript.py  | 40 ++++++++++++++++++++++++++++++----------
 gptscript/opts.py       | 41 +++++++++++++++++++++++++++++++++++++----
 tests/test_gptscript.py | 40 ++++++++++++++++++++++++++++++++++++----
 3 files changed, 103 insertions(+), 18 deletions(-)

diff --git a/gptscript/gptscript.py b/gptscript/gptscript.py
index 8f20a2f..5cb3362 100644
--- a/gptscript/gptscript.py
+++ b/gptscript/gptscript.py
@@ -26,20 +26,22 @@ class GPTScript:
     def __init__(self, opts: GlobalOptions = None):
         if opts is None:
            opts = GlobalOptions()
+        self.opts = opts
+
         GPTScript.__gptscript_count += 1
 
         if GPTScript.__server_url == "":
             GPTScript.__server_url = os.environ.get("GPTSCRIPT_URL", "127.0.0.1:0")
 
         if GPTScript.__gptscript_count == 1 and os.environ.get("GPTSCRIPT_DISABLE_SERVER", "") != "true":
-            opts.toEnv()
+            self.opts.toEnv()
 
             GPTScript.__process = Popen(
                 [_get_command(), "--listen-address", GPTScript.__server_url, "sdkserver"],
                 stdin=PIPE,
                 stdout=PIPE,
                 stderr=PIPE,
-                env=opts.Env,
+                env=self.opts.Env,
                 text=True,
                 encoding="utf-8",
             )
@@ -81,18 +83,28 @@ def evaluate(
             self,
             tool: ToolDef | list[ToolDef],
             opts: Options = None,
             event_handlers: list[Callable[[Run, CallFrame | RunFrame | PromptFrame], Awaitable[None]]] = None
     ) -> Run:
-        return Run("evaluate", tool, opts, self._server_url, event_handlers=event_handlers).next_chat(
-            "" if opts is None else opts.input
-        )
+        opts = opts if opts is not None else Options()
+        return Run(
+            "evaluate",
+            tool,
+            opts.merge_global_opts(self.opts),
+            self._server_url,
+            event_handlers=event_handlers,
+        ).next_chat("" if opts is None else opts.input)
 
     def run(
             self, tool_path: str,
             opts: Options = None,
             event_handlers: list[Callable[[Run, CallFrame | RunFrame | PromptFrame], Awaitable[None]]] = None
     ) -> Run:
-        return Run("run", tool_path, opts, self._server_url, event_handlers=event_handlers).next_chat(
-            "" if opts is None else opts.input
-        )
+        opts = opts if opts is not None else Options()
+        return Run(
+            "run",
+            tool_path,
+            opts.merge_global_opts(self.opts),
+            self._server_url,
+            event_handlers=event_handlers,
+        ).next_chat("" if opts is None else opts.input)
 
     async def parse(self, file_path: str, disable_cache: bool = False) -> list[Text | Tool]:
         out = await self._run_basic_command("parse", {"file": file_path, "disableCache": disable_cache})
@@ -139,8 +151,16 @@ async def version(self) -> str:
     async def list_tools(self) -> str:
         return await self._run_basic_command("list-tools")
 
-    async def list_models(self) -> list[str]:
-        return (await self._run_basic_command("list-models")).split("\n")
+    async def list_models(self, providers: list[str] = None, credential_overrides: list[str] = None) -> list[str]:
+        if self.opts.DefaultModelProvider != "":
+            if providers is None:
+                providers = []
+            providers.append(self.opts.DefaultModelProvider)
+
+        return (await self._run_basic_command(
+            "list-models",
+            {"providers": providers, "credentialOverrides": credential_overrides}
+        )).split("\n")
 
 
 def _get_command():
diff --git a/gptscript/opts.py b/gptscript/opts.py
index 7b00a99..fb64468 100644
--- a/gptscript/opts.py
+++ b/gptscript/opts.py
@@ -1,17 +1,33 @@
 import os
-from typing import Mapping
+from typing import Mapping, Self
 
 
 class GlobalOptions:
-    def __init__(self,
-                 apiKey: str = "", baseURL: str = "", defaultModelProvider: str = "", defaultModel: str = "",
-                 env: Mapping[str, str] = None):
+    def __init__(
+            self,
+            apiKey: str = "",
+            baseURL: str = "",
+            defaultModelProvider: str = "",
+            defaultModel: str = "",
+            env: Mapping[str, str] = None,
+    ):
         self.APIKey = apiKey
         self.BaseURL = baseURL
         self.DefaultModel = defaultModel
         self.DefaultModelProvider = defaultModelProvider
         self.Env = env
 
+    def merge(self, other: Self) -> Self:
+        cp = self.__class__()
+        if other is None:
+            return cp
+        cp.APIKey = other.APIKey if other.APIKey != "" else self.APIKey
+        cp.BaseURL = other.BaseURL if other.BaseURL != "" else self.BaseURL
+        cp.DefaultModel = other.DefaultModel if other.DefaultModel != "" else self.DefaultModel
+        cp.DefaultModelProvider = other.DefaultModelProvider if other.DefaultModelProvider != "" else self.DefaultModelProvider
+        cp.Env = {**(self.Env or {}), **(other.Env or {})}
+        return cp
+
     def toEnv(self):
         if self.Env is None:
             self.Env = os.environ.copy()
@@ -56,3 +72,20 @@ def __init__(self,
         self.location = location
         self.env = env
         self.forceSequential = forceSequential
+
+    def merge_global_opts(self, other: GlobalOptions) -> Self:
+        cp = super().merge(other)
+        if other is None:
+            return cp
+        cp.input = self.input
+        cp.disableCache = self.disableCache
+        cp.subTool = self.subTool
+        cp.workspace = self.workspace
+        cp.chatState = self.chatState
+        cp.confirm = self.confirm
+        cp.prompt = self.prompt
+        cp.credentialOverrides = self.credentialOverrides
+        cp.location = self.location
+        cp.env = self.env
+        cp.forceSequential = self.forceSequential
+        return cp
diff --git a/tests/test_gptscript.py b/tests/test_gptscript.py
index 69dcf6e..dc89a94 100644
--- a/tests/test_gptscript.py
+++ b/tests/test_gptscript.py
@@ -25,9 +25,13 @@ def gptscript():
     if os.getenv("OPENAI_API_KEY") is None:
         pytest.fail("OPENAI_API_KEY not set", pytrace=False)
     try:
+        # Start an initial GPTScript instance.
+        # This one doesn't have any options, but it's there to ensure that using another instance works as expected in all cases.
+        g_first = GPTScript()
         gptscript = GPTScript(GlobalOptions(apiKey=os.getenv("OPENAI_API_KEY")))
         yield gptscript
         gptscript.close()
+        g_first.close()
     except Exception as e:
         pytest.fail(e, pytrace=False)
 
@@ -111,6 +115,33 @@ async def test_list_models(gptscript):
     assert isinstance(models, list) and len(models) > 1, "Expected list_models to return a list"
 
 
+@pytest.mark.asyncio
+async def test_list_models_from_provider(gptscript):
+    models = await gptscript.list_models(
+        providers=["github.com/gptscript-ai/claude3-anthropic-provider"],
+        credential_overrides=["github.com/gptscript-ai/claude3-anthropic-provider/credential:ANTHROPIC_API_KEY"],
+    )
+    assert isinstance(models, list) and len(models) > 1, "Expected list_models to return a list"
+    for model in models:
+        assert model.startswith("claude-3-"), "Unexpected model name"
+        assert model.endswith("from github.com/gptscript-ai/claude3-anthropic-provider"), "Unexpected model name"
+
+
+@pytest.mark.asyncio
+async def test_list_models_from_default_provider():
+    g = GPTScript(GlobalOptions(defaultModelProvider="github.com/gptscript-ai/claude3-anthropic-provider"))
+    try:
+        models = await g.list_models(
+            credential_overrides=["github.com/gptscript-ai/claude3-anthropic-provider/credential:ANTHROPIC_API_KEY"],
+        )
+        assert isinstance(models, list) and len(models) > 1, "Expected list_models to return a list"
+        for model in models:
+            assert model.startswith("claude-3-"), "Unexpected model name"
+            assert model.endswith("from github.com/gptscript-ai/claude3-anthropic-provider"), "Unexpected model name"
+    finally:
+        g.close()
+
+
 @pytest.mark.asyncio
 async def test_list_tools(gptscript):
     out = await gptscript.list_tools()
@@ -472,10 +503,11 @@ async def process_event(r: Run, frame: CallFrame | RunFrame | PromptFrame):
             event_content += output.content
 
     tool = ToolDef(tools=["sys.exec"], instructions="List the files in the current directory as '.'.")
-    out = await gptscript.evaluate(tool,
-                                   Options(confirm=True, disableCache=True),
-                                   event_handlers=[process_event],
-                                   ).text()
+    out = await gptscript.evaluate(
+        tool,
+        Options(confirm=True, disableCache=True),
+        event_handlers=[process_event],
+    ).text()
 
     assert confirm_event_found, "No confirm event"
 
     # Running the `dir` command in Windows will give the contents of the tests directory
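
For reference, a minimal usage sketch of the API this patch adds (illustrative only, not part of the diff). The provider URL and the credential-override string are taken from the tests above; the import paths and the OPENAI_API_KEY/ANTHROPIC_API_KEY environment variables are assumptions about the surrounding setup.

    # usage_sketch.py -- mirrors what tests/test_gptscript.py exercises
    import asyncio
    import os

    from gptscript.gptscript import GPTScript
    from gptscript.opts import GlobalOptions


    async def main():
        # GlobalOptions set on the instance are merged into each run's Options.
        g = GPTScript(GlobalOptions(apiKey=os.getenv("OPENAI_API_KEY")))
        try:
            # New in this patch: list models from an explicit provider, passing the
            # credential override that the provider tool needs (here, ANTHROPIC_API_KEY).
            models = await g.list_models(
                providers=["github.com/gptscript-ai/claude3-anthropic-provider"],
                credential_overrides=[
                    "github.com/gptscript-ai/claude3-anthropic-provider/credential:ANTHROPIC_API_KEY"
                ],
            )
            print("\n".join(models))
        finally:
            g.close()


    asyncio.run(main())

Alternatively, constructing the client with GlobalOptions(defaultModelProvider=...) makes list_models() query that provider without an explicit providers argument, as exercised by test_list_models_from_default_provider above.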