Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ability to list models from other providers #46

Merged
merged 1 commit into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions gptscript/gptscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,22 @@ class GPTScript:
def __init__(self, opts: GlobalOptions = None):
    """Create a GPTScript client, starting the shared sdkserver on first use.

    Args:
        opts: Client-wide options (API key, base URL, env, ...). A fresh
            GlobalOptions is used when omitted.
    """
    if opts is None:
        opts = GlobalOptions()
    # Keep the (possibly defaulted) options on the instance so later calls
    # (run/evaluate/list_models) can merge per-run options against them.
    self.opts = opts

    GPTScript.__gptscript_count += 1

    if GPTScript.__server_url == "":
        GPTScript.__server_url = os.environ.get("GPTSCRIPT_URL", "127.0.0.1:0")

    # Only the first live instance spawns the shared sdkserver subprocess;
    # it is skipped entirely when GPTSCRIPT_DISABLE_SERVER=true.
    if GPTScript.__gptscript_count == 1 and os.environ.get("GPTSCRIPT_DISABLE_SERVER", "") != "true":
        self.opts.toEnv()

        GPTScript.__process = Popen(
            [_get_command(), "--listen-address", GPTScript.__server_url, "sdkserver"],
            stdin=PIPE,
            stdout=PIPE,
            stderr=PIPE,
            env=self.opts.Env,
            text=True,
            encoding="utf-8",
        )
Expand Down Expand Up @@ -81,18 +83,28 @@ def evaluate(
opts: Options = None,
event_handlers: list[Callable[[Run, CallFrame | RunFrame | PromptFrame], Awaitable[None]]] = None
) -> Run:
return Run("evaluate", tool, opts, self._server_url, event_handlers=event_handlers).next_chat(
"" if opts is None else opts.input
)
opts = opts if opts is not None else Options()
return Run(
"evaluate",
tool,
opts.merge_global_opts(self.opts),
self._server_url,
event_handlers=event_handlers,
).next_chat("" if opts is None else opts.input)

def run(
        self, tool_path: str,
        opts: Options = None,
        event_handlers: list[Callable[[Run, CallFrame | RunFrame | PromptFrame], Awaitable[None]]] = None
) -> Run:
    """Run the tool (or script) at tool_path.

    Args:
        tool_path: Path or URL of the gptscript tool to run.
        opts: Per-run options; defaults to a fresh Options when omitted.
        event_handlers: Async callbacks invoked for each run/call/prompt frame.

    Returns:
        A Run seeded with the first chat turn (opts.input).
    """
    # Default the per-run options, then overlay them onto the client-wide
    # GlobalOptions so per-run values win where set.
    opts = opts if opts is not None else Options()
    return Run(
        "run",
        tool_path,
        opts.merge_global_opts(self.opts),
        self._server_url,
        event_handlers=event_handlers,
    ).next_chat(opts.input)  # opts is never None here, so use its input directly

async def parse(self, file_path: str, disable_cache: bool = False) -> list[Text | Tool]:
out = await self._run_basic_command("parse", {"file": file_path, "disableCache": disable_cache})
Expand Down Expand Up @@ -139,8 +151,16 @@ async def version(self) -> str:
async def list_tools(self) -> str:
return await self._run_basic_command("list-tools")

async def list_models(self, providers: list[str] = None, credential_overrides: list[str] = None) -> list[str]:
    """List the model names available from the given providers.

    Args:
        providers: Model providers to query. When the client's
            DefaultModelProvider is set, it is always queried as well.
        credential_overrides: Credential overrides forwarded to the sdkserver.

    Returns:
        The newline-separated server response split into a list of lines.
    """
    if self.opts.DefaultModelProvider != "":
        # Build a new list instead of appending so the caller's list
        # argument is never mutated.
        providers = list(providers or []) + [self.opts.DefaultModelProvider]

    return (await self._run_basic_command(
        "list-models",
        {"providers": providers, "credentialOverrides": credential_overrides}
    )).split("\n")


def _get_command():
Expand Down
41 changes: 37 additions & 4 deletions gptscript/opts.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,33 @@
import os
from typing import Mapping
from typing import Mapping, Self


class GlobalOptions:
def __init__(self,
apiKey: str = "", baseURL: str = "", defaultModelProvider: str = "", defaultModel: str = "",
env: Mapping[str, str] = None):
def __init__(
self,
apiKey: str = "",
baseURL: str = "",
defaultModelProvider: str = "",
defaultModel: str = "",
env: Mapping[str, str] = None,
):
self.APIKey = apiKey
self.BaseURL = baseURL
self.DefaultModel = defaultModel
self.DefaultModelProvider = defaultModelProvider
self.Env = env

def merge(self, other: Self) -> Self:
cp = self.__class__()
if other is None:
return cp
cp.APIKey = other.APIKey if other.APIKey != "" else self.APIKey
cp.BaseURL = other.BaseURL if other.BaseURL != "" else self.BaseURL
cp.DefaultModel = other.DefaultModel if other.DefaultModel != "" else self.DefaultModel
cp.DefaultModelProvider = other.DefaultModelProvider if other.DefaultModelProvider != "" else self.DefaultModelProvider
cp.Env = (other.Env or []).extend(self.Env or [])
return cp

def toEnv(self):
if self.Env is None:
self.Env = os.environ.copy()
Expand Down Expand Up @@ -56,3 +72,20 @@ def __init__(self,
self.location = location
self.env = env
self.forceSequential = forceSequential

def merge_global_opts(self, other: GlobalOptions) -> Self:
cp = super().merge(other)
if other is None:
return cp
cp.input = self.input
cp.disableCache = self.disableCache
cp.subTool = self.subTool
cp.workspace = self.workspace
cp.chatState = self.chatState
cp.confirm = self.confirm
cp.prompt = self.prompt
cp.credentialOverrides = self.credentialOverrides
cp.location = self.location
cp.env = self.env
cp.forceSequential = self.forceSequential
return cp
40 changes: 36 additions & 4 deletions tests/test_gptscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@ def gptscript():
if os.getenv("OPENAI_API_KEY") is None:
pytest.fail("OPENAI_API_KEY not set", pytrace=False)
try:
# Start an initial GPTScript instance.
# This one doesn't have any options, but it's there to ensure that using another instance works as expected in all cases.
g_first = GPTScript()
gptscript = GPTScript(GlobalOptions(apiKey=os.getenv("OPENAI_API_KEY")))
yield gptscript
gptscript.close()
g_first.close()
except Exception as e:
pytest.fail(e, pytrace=False)

Expand Down Expand Up @@ -111,6 +115,33 @@ async def test_list_models(gptscript):
assert isinstance(models, list) and len(models) > 1, "Expected list_models to return a list"


@pytest.mark.asyncio
async def test_list_models_from_provider(gptscript):
    """Listing models from an explicitly named provider yields claude-3 models."""
    provider = "github.com/gptscript-ai/claude3-anthropic-provider"
    models = await gptscript.list_models(
        providers=[provider],
        credential_overrides=["github.com/gptscript-ai/claude3-anthropic-provider/credential:ANTHROPIC_API_KEY"],
    )
    assert isinstance(models, list) and len(models) > 1, "Expected list_models to return a list"
    # Every returned entry should be a claude-3 model tagged with its provider.
    for m in models:
        assert m.startswith("claude-3-"), "Unexpected model name"
        assert m.endswith("from github.com/gptscript-ai/claude3-anthropic-provider"), "Unexpected model name"


@pytest.mark.asyncio
async def test_list_models_from_default_provider():
    """With defaultModelProvider set, list_models needs no explicit providers."""
    g = GPTScript(GlobalOptions(defaultModelProvider="github.com/gptscript-ai/claude3-anthropic-provider"))
    try:
        models = await g.list_models(
            credential_overrides=["github.com/gptscript-ai/claude3-anthropic-provider/credential:ANTHROPIC_API_KEY"],
        )
        assert isinstance(models, list) and len(models) > 1, "Expected list_models to return a list"
        # The default provider should have been queried automatically.
        for m in models:
            assert m.startswith("claude-3-"), "Unexpected model name"
            assert m.endswith("from github.com/gptscript-ai/claude3-anthropic-provider"), "Unexpected model name"
    finally:
        g.close()


@pytest.mark.asyncio
async def test_list_tools(gptscript):
out = await gptscript.list_tools()
Expand Down Expand Up @@ -472,10 +503,11 @@ async def process_event(r: Run, frame: CallFrame | RunFrame | PromptFrame):
event_content += output.content

tool = ToolDef(tools=["sys.exec"], instructions="List the files in the current directory as '.'.")
out = await gptscript.evaluate(tool,
Options(confirm=True, disableCache=True),
event_handlers=[process_event],
).text()
out = await gptscript.evaluate(
tool,
Options(confirm=True, disableCache=True),
event_handlers=[process_event],
).text()

assert confirm_event_found, "No confirm event"
# Running the `dir` command in Windows will give the contents of the tests directory
Expand Down