Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bedrock support [WIP] #2

Merged
merged 6 commits into from
Mar 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ classifiers = [
requires-python = ">=3.10"
dependencies = [
"asyncio",
"openai==0.28",
"openai",
"tiktoken",
"aiolimiter",
"tqdm",
"llm_utils",
"slipcover>=1.0.3"
"slipcover>=1.0.3",
"litellm>=1.33.1"
]

[project.scripts]
Expand Down
71 changes: 53 additions & 18 deletions src/coverup/coverup.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
import asyncio
import openai
import json
import litellm # type: ignore
import logging
import openai
import subprocess
from pathlib import Path
import typing as T
import re
import sys
import typing as T

from pathlib import Path
from datetime import datetime

from .llm import *
from .segment import *
from .testrunner import *


PREFIX = 'coverup'
DEFAULT_MODEL='gpt-4-1106-preview'

# Turn off most logging
litellm.set_verbose = False
logging.getLogger().setLevel(logging.ERROR)
# Ignore unavailable parameters
litellm.drop_params=True

def parse_args(args=None):
import argparse
Expand All @@ -40,7 +48,7 @@ def Path_dir(value):
ap.add_argument('--no-checkpoint', action='store_const', const=None, dest='checkpoint', default=argparse.SUPPRESS,
help=f'disables checkpoint')

ap.add_argument('--model', type=str, default=DEFAULT_MODEL,
ap.add_argument('--model', type=str,
help='OpenAI model to use')

ap.add_argument('--model-temperature', type=str, default=0,
Expand Down Expand Up @@ -103,6 +111,7 @@ def positive_int(value):

def test_file_path(test_seq: int) -> Path:
"""Returns the Path for a test's file, given its sequence number."""
global args
return args.tests_dir / f"test_{PREFIX}_{test_seq}.py"


Expand Down Expand Up @@ -413,11 +422,9 @@ async def do_chat(seg: CodeSegment, completion: dict) -> str:
log_write(seg, f"Error: too many tokens for rate limit ({e})")
return None # gives up this segment

return await openai.ChatCompletion.acreate(**completion)
return await litellm.acreate(**completion)

except (openai.error.ServiceUnavailableError,
openai.error.RateLimitError,
openai.error.Timeout) as e:
except (openai.RateLimitError, openai.APITimeoutError) as e:

# This message usually indicates out of money in account
if 'You exceeded your current quota' in str(e):
Expand All @@ -432,13 +439,12 @@ async def do_chat(seg: CodeSegment, completion: dict) -> str:
state.inc_counter('R')
await asyncio.sleep(sleep_time)

except openai.error.InvalidRequestError as e:
except openai.BadRequestError as e:
# usually "maximum context length" XXX check for this?
log_write(seg, f"Error: {type(e)} {e}")
return None # gives up this segment

except (openai.error.APIConnectionError,
openai.error.APIError) as e:
except (ConnectionError) as e:
log_write(seg, f"Error: {type(e)} {e}")
# usually a server-side error... just retry right away
state.inc_counter('R')
Expand Down Expand Up @@ -589,6 +595,7 @@ def add_to_pythonpath(source_dir: Path):


def main():

from collections import defaultdict
import os

Expand All @@ -612,14 +619,42 @@ def main():
token_rate_limit = AsyncLimiter(*limit)
# TODO also add request limit, and use 'await asyncio.gather(t.acquire(tokens), r.acquire())' to acquire both

if 'OPENAI_API_KEY' not in os.environ:
print("Please place your OpenAI key in an environment variable named OPENAI_API_KEY and try again.")
return 1

openai.key=os.environ['OPENAI_API_KEY']
if 'OPENAI_ORGANIZATION' in os.environ:
openai.organization=os.environ['OPENAI_ORGANIZATION']
# Check for an API key for OpenAI or Amazon Bedrock.
if 'OPENAI_API_KEY' not in os.environ:
if not all(x in os.environ for x in ['AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'AWS_REGION_NAME']):
print("You need a key (or keys) from an AI service to use CoverUp.")
print()
print("OpenAI:")
print(" You can get a key here: https://platform.openai.com/api-keys")
print(" Set the environment variable OPENAI_API_KEY to your key value:")
print(" export OPENAI_API_KEY=<your key>")
print()
print()
print("Bedrock:")
print(" To use Bedrock, you need an AWS account.")
print(" Set the following environment variables:")
print(" export AWS_ACCESS_KEY_ID=<your key id>")
print(" export AWS_SECRET_ACCESS_KEY=<your secret key>")
print(" export AWS_REGION_NAME=us-west-2")
print(" You also need to request access to Claude:")
print(
" https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html#manage-model-access"
)
print()
return 1

if 'OPENAI_API_KEY' in os.environ:
if not args.model:
# args.model = "openai/gpt-4"
args.model = "openai/gpt-4-1106-preview"
# openai.key=os.environ['OPENAI_API_KEY']
#if 'OPENAI_ORGANIZATION' in os.environ:
# openai.organization=os.environ['OPENAI_ORGANIZATION']
else:
# args.model = "bedrock/anthropic.claude-v2:1"
if not args.model:
args.model = "bedrock/anthropic.claude-3-sonnet-20240229-v1:0"
log_write('startup', f"Command: {' '.join(sys.argv)}")

# --- (1) load or measure initial coverage, figure out segmentation ---
Expand Down
6 changes: 5 additions & 1 deletion src/coverup/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ async def subprocess_run(args: str, check: bool = False, timeout: T.Optional[int
except asyncio.TimeoutError:
process.terminate()
await process.wait()
raise subprocess.TimeoutExpired(args, timeout) from None
if timeout:
timeout_f = float(timeout)
else:
timeout_f = 0.0
raise subprocess.TimeoutExpired(args, timeout_f) from None
emeryberger marked this conversation as resolved.
Show resolved Hide resolved

if check and process.returncode != 0:
raise subprocess.CalledProcessError(process.returncode, args, output=output)
Expand Down
Loading