diff --git a/pyproject.toml b/pyproject.toml index 60a646f..a2d3458 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ classifiers = [ requires-python = ">=3.10" dependencies = [ "asyncio", + "openai", "tiktoken", "aiolimiter", "tqdm", diff --git a/src/coverup/coverup.py b/src/coverup/coverup.py index 4da6e80..c173059 100644 --- a/src/coverup/coverup.py +++ b/src/coverup/coverup.py @@ -2,6 +2,7 @@ import json import litellm # type: ignore import logging +import openai import subprocess import re import sys @@ -9,13 +10,6 @@ from pathlib import Path from datetime import datetime -from openai import ( - NotFoundError, - RateLimitError, - APITimeoutError, - OpenAIError, - BadRequestError, -) from .llm import * from .segment import * @@ -23,7 +17,7 @@ PREFIX = 'coverup' -DEFAULT_MODEL='gpt-4-1106-preview' +DEFAULT_MODEL='' # Model logic now in main() # Turn off most logging litellm.set_verbose = False @@ -118,6 +112,7 @@ def positive_int(value): def test_file_path(test_seq: int) -> Path: """Returns the Path for a test's file, given its sequence number.""" + global args return args.tests_dir / f"test_{PREFIX}_{test_seq}.py" @@ -432,7 +427,7 @@ async def do_chat(seg: CodeSegment, completion: dict) -> str: return await litellm.acreate(**completion) - except (RateLimitError, TimeoutError) as e: + except (openai.RateLimitError, openai.APITimeoutError) as e: # This message usually indicates out of money in account if 'You exceeded your current quota' in str(e): @@ -447,7 +442,7 @@ async def do_chat(seg: CodeSegment, completion: dict) -> str: state.inc_counter('R') await asyncio.sleep(sleep_time) - except BadRequestError as e: + except openai.BadRequestError as e: # usually "maximum context length" XXX check for this? log_write(seg, f"Error: {type(e)} {e}") return None # gives up this segment @@ -653,13 +648,16 @@ def main(): return 1 if 'OPENAI_API_KEY' in os.environ: - args.model = "openai/gpt-4" # FIXME - openai.key=os.environ['OPENAI_API_KEY'] - if 'OPENAI_ORGANIZATION' in os.environ: - openai.organization=os.environ['OPENAI_ORGANIZATION'] + if not args.model: + # args.model = "openai/gpt-4" + args.model = "openai/gpt-4-1106-preview" + # openai.key=os.environ['OPENAI_API_KEY'] + #if 'OPENAI_ORGANIZATION' in os.environ: + # openai.organization=os.environ['OPENAI_ORGANIZATION'] else: - # args.model = "bedrock/anthropic.claude-v2:1" - args.model = "bedrock/anthropic.claude-3-sonnet-20240229-v1:0" + # args.model = "bedrock/anthropic.claude-v2:1" + if not args.model: + args.model = "bedrock/anthropic.claude-3-sonnet-20240229-v1:0" log_write('startup', f"Command: {' '.join(sys.argv)}") # --- (1) load or measure initial coverage, figure out segmentation --- diff --git a/src/coverup/utils.py b/src/coverup/utils.py index e00fb2b..47b0601 100644 --- a/src/coverup/utils.py +++ b/src/coverup/utils.py @@ -85,7 +85,11 @@ async def subprocess_run(args: str, check: bool = False, timeout: T.Optional[int except asyncio.TimeoutError: process.terminate() await process.wait() - raise subprocess.TimeoutExpired(args, timeout) from None + if timeout: + timeout_f = float(timeout) + else: + timeout_f = 0.0 + raise subprocess.TimeoutExpired(args, timeout_f) from None if check and process.returncode != 0: raise subprocess.CalledProcessError(process.returncode, args, output=output)