diff --git a/CHANGELOG.md b/CHANGELOG.md
index b5b2c9010..cd193c52c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,12 @@
+## Unreleased
+
+- New "HTTP Retries" display (replacing the "HTTP Rate Limits" display) which counts all retries and does so much more consistently and accurately across providers.
+- The `ModelAPI` class now has a `should_retry()` method that replaces the deprecated `is_rate_limit()` method.
+- The "Generate..." progress message in the Running Samples view now shows the number of retries for the active call to `generate()`.
+- New `inspect trace http` command which will show all HTTP requests for a run.
+- More consistent use of `max_retries` and `timeout` configuration options. These options now exclusively control Inspect's outer retry handler; model providers use their default behaviour for the inner request, which is typically 2-4 retries and a service-appropriate timeout.
+- Logging: Inspect no longer sets the global log level nor does it allow its own messages to propagate to the global handler (eliminating the possibility of duplicate display). This should improve compatibility with applications that have their own custom logging configured.
+
 ## v0.3.72 (03 March 2025)
 
 - Computer: Updated tool definition to match improvements in Claude Sonnet 3.7.
diff --git a/docs/errors-and-limits.qmd b/docs/errors-and-limits.qmd
index 600b1936d..fa1f66e99 100644
--- a/docs/errors-and-limits.qmd
+++ b/docs/errors-and-limits.qmd
@@ -116,7 +116,7 @@ def intercode_ctf():
 Working time is computed based on total clock time minus time spent on (a) unsuccessful model generations (e.g. rate limited requests); and (b) waiting on shared resources (e.g. Docker containers or subprocess execution).
 
 ::: {.callout-note appearance="simple"}
-In order to distinguish successful generate requests from rate limited and retried requests, Inspect installs hooks into the HTTP client of various model packages. This is not possible for some models (`google`, `vertex`, `azureai`, and `goodfire`), and in these cases the `working_time` will include any internal retries that the model client performs.
+In order to distinguish successful generate requests from rate limited and retried requests, Inspect installs hooks into the HTTP client of various model packages. This is not possible for some models (`vertex`, `azureai`, and `goodfire`), and in these cases the `working_time` will include any internal retries that the model client performs.
 :::
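
Stated as a formula (an illustrative sketch of the working-time computation described above, not code from this changeset):

```python
# working time = total clock time, less time waiting on retried generations
# and on shared resources (e.g. Docker containers or subprocess execution)
def working_time(clock_time: float, retry_wait: float, resource_wait: float) -> float:
    return clock_time - retry_wait - resource_wait
```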
diff --git a/docs/options.qmd b/docs/options.qmd
index c878fa8a4..d0516877b 100644
--- a/docs/options.qmd
+++ b/docs/options.qmd
@@ -86,8 +86,8 @@ Below are sections for the various categories of options supported by `inspect e
 | `--parallel-tool-calls` | Whether to enable calling multiple functions during tool use (defaults to True) OpenAI and Groq only. |
 | `--max-tool-output` | Maximum size of tool output (in bytes). Defaults to 16 \* 1024. |
 | `--internal-tools` | Whether to automatically map tools to model internal implementations (e.g. 'computer' for Anthropic). |
-| `--max-retries` | Maximum number of times to retry request (defaults to 5) |
-| `--timeout` | Request timeout (in seconds). |
+| `--max-retries` | Maximum number of times to retry generate request (defaults to unlimited) |
+| `--timeout` | Generate timeout in seconds (defaults to no timeout) |
 
 ## Tasks and Solvers
diff --git a/docs/parallelism.qmd b/docs/parallelism.qmd
index 12d8e45c7..7611c7967 100644
--- a/docs/parallelism.qmd
+++ b/docs/parallelism.qmd
@@ -28,9 +28,13 @@ $ inspect eval --model openai/gpt-4 --max-connections 20
 
 The default value for max connections is 10. By increasing it we might get better performance due to higher parallelism, however we might get *worse* performance if this causes us to frequently hit rate limits (which are retried with exponential backoff). The "correct" max connections for your evaluations will vary based on your actual rate limit and the size and complexity of your evaluations.
 
+::: {.callout-note appearance="simple"}
+Note that max connections is applied per-model. This means that if you use a grader model from a provider distinct from the one you are evaluating you will get extra concurrency (as each model will enforce its own max connections).
+:::
+
 ### Rate Limits
 
-When you run an eval you'll see information reported on the current active connection usage as well as the number of HTTP rate limit errors that have been encountered (note that Inspect will automatically retry on rate limits and other errors likely to be transient):
+When you run an eval you'll see information reported on the current active connection usage as well as the number of HTTP retries that have occurred (Inspect will automatically retry on rate limits and other errors likely to be transient):
 
 ![](images/rate-limit.png){fig-alt="The Inspect task results displayed in the terminal. The number of HTTP rate limit errors that have occurred (25) is printed in the bottom right of the task results."}
 
@@ -40,20 +44,52 @@ You should experiment with various values for max connections at different times
 
 ### Limiting Retries
 
-By default, Inspect will continue to retry model API calls (with exponential backoff) indefinitely when a rate limit error (HTTP status 429) is returned. You can limit these retries by using the `max_retries` and `timeout` eval options. For example:
+By default, Inspect will retry model API calls indefinitely (with exponential backoff) when a recoverable HTTP error occurs. The initial backoff is 3 seconds, and exponentiation results in a roughly 25 minute wait before the 10th retry, with waits capped at 30 minutes for the 11th and subsequent retries (a rough sketch of this schedule appears below). You can limit Inspect's retries using the `--max-retries` option:
+
+``` bash
+inspect eval --model openai/gpt-4 --max-retries 10
+```
+
+Note that model interfaces themselves may have internal retry behavior (for example, the `openai` and `anthropic` packages both retry twice by default).
+
+You can put a limit on the total time for retries using the `--timeout` option:
 
 ``` bash
-$ inspect eval --model openai/gpt-4 --max-retries 10 --timeout 600
+inspect eval --model openai/gpt-4 --timeout 600
 ```
 
+### Debugging Retries
+
 If you want more insight into Model API connections and retries, specify `log_level=http`. For example:
 
 ``` bash
-$ inspect eval --model openai/gpt-4 --log-level=http
+inspect eval --model openai/gpt-4 --log-level=http
+```
+
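+As a rough sketch, the backoff schedule described under Limiting Retries above works out as follows (ignoring jitter; this snippet is illustrative rather than part of Inspect's API):
+
+``` python
+# approximate wait (in seconds) before the nth retry: exponential backoff
+# starting at 3 seconds, capped at 30 minutes (a small random jitter is added)
+def approx_wait(n: int) -> float:
+    return min(3 * 2 ** (n - 1), 30 * 60)
+
+# approx_wait(1) == 3, approx_wait(10) == 1536 (~25 minutes)
+```
+
+You can also view all of the HTTP requests for the current (or most recent) evaluation run using the `inspect trace http` command. For example: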
+
+``` bash
+inspect trace http          # show all http requests
+inspect trace http --failed # show only failed requests
 ```
 
 ::: {.callout-note appearance="simple"}
-Note that max connections is applied per-model. This means that if you use a grader model from a provider distinct from the one you are evaluating you will get extra concurrency (as each model will enforce its own max connections).
+Note that the `inspect trace http` command is currently available only in the development version of Inspect. To install the development version from GitHub:
+
+``` bash
+pip install git+https://github.com/UKGovernmentBEIS/inspect_ai
+```
 :::
 
 ## Multiple Models {#sec-multiple-models}
diff --git a/docs/tracing.qmd b/docs/tracing.qmd
index b68b1fcec..7e1c41cbf 100644
--- a/docs/tracing.qmd
+++ b/docs/tracing.qmd
@@ -69,6 +69,31 @@ As with the `inspect trace dump` command, you can apply a filter when listing an
 inspect trace anomalies --filter model
 ```
 
+## HTTP Requests
+
+::: {.callout-note appearance="simple"}
+Note that the `inspect trace http` command described below is currently available only in the development version of Inspect. To install the development version from GitHub:
+
+``` bash
+pip install git+https://github.com/UKGovernmentBEIS/inspect_ai
+```
+:::
+
+You can view all of the HTTP requests for the current (or most recent) evaluation run using the `inspect trace http` command. For example:
+
+``` bash
+inspect trace http          # show all http requests
+inspect trace http --failed # show only failed requests
+```
+
+The `--filter` parameter also works here, for example:
+
+``` bash
+inspect trace http --failed --filter bedrock
+```
+
+
+
 ## Tracing API {#tracing-api}
 
 In addition to the standard set of actions which are trace logged, you can do your own custom trace logging using the `trace_action()` and `trace_message()` APIs. Trace logging is a great way to make sure that logging context is *always captured* (since the last 10 trace logs are always available) without cluttering up the console or eval transcripts.
diff --git a/src/inspect_ai/_cli/eval.py b/src/inspect_ai/_cli/eval.py
index 6c3c28aea..f89e3eb8a 100644
--- a/src/inspect_ai/_cli/eval.py
+++ b/src/inspect_ai/_cli/eval.py
@@ -11,7 +11,6 @@
     DEFAULT_EPOCHS,
     DEFAULT_LOG_LEVEL_TRANSCRIPT,
     DEFAULT_MAX_CONNECTIONS,
-    DEFAULT_MAX_RETRIES,
 )
 from inspect_ai._util.file import filesystem
 from inspect_ai._util.samples import parse_sample_id, parse_samples_limit
@@ -47,9 +46,9 @@
 NO_SCORE_DISPLAY = "Do not display scoring metrics in realtime."
 MAX_CONNECTIONS_HELP = f"Maximum number of concurrent connections to Model API (defaults to {DEFAULT_MAX_CONNECTIONS})"
 MAX_RETRIES_HELP = (
-    f"Maximum number of times to retry request (defaults to {DEFAULT_MAX_RETRIES})"
+    "Maximum number of times to retry model API requests (defaults to unlimited)"
 )
-TIMEOUT_HELP = "Request timeout (in seconds)."
+TIMEOUT_HELP = "Model API request timeout in seconds (defaults to no timeout)"
 
 
 def eval_options(func: Callable[..., Any]) -> Callable[..., click.Context]:
diff --git a/src/inspect_ai/_cli/trace.py b/src/inspect_ai/_cli/trace.py
index 7e6ed947d..c0d2bca4a 100644
--- a/src/inspect_ai/_cli/trace.py
+++ b/src/inspect_ai/_cli/trace.py
@@ -15,6 +15,7 @@ from inspect_ai._util.error import PrerequisiteError
 from inspect_ai._util.trace import (
     ActionTraceRecord,
+    TraceRecord,
     inspect_trace_dir,
     list_trace_files,
     read_trace_file,
@@ -84,6 +85,41 @@ def dump_command(trace_file: str | None, filter: str | None) -> None:
     )
 
 
+@trace_command.command("http")
+@click.argument("trace-file", type=str, required=False)
+@click.option(
+    "--filter",
+    type=str,
+    help="Filter (applied to trace message field).",
+)
+@click.option(
+    "--failed",
+    type=bool,
+    is_flag=True,
+    default=False,
+    help="Show only failed HTTP requests (non-200 status)",
+)
+def http_command(trace_file: str | None, filter: str | None, failed: bool) -> None:
+    """View all HTTP requests in the trace log."""
+    _, traces = _read_traces(trace_file, "HTTP", filter)
+
+    last_timestamp = ""
+    table = Table(Column(), Column(), box=None)
+    for trace in traces:
+        if failed and "200 OK" in trace.message:
+            continue
+        timestamp = trace.timestamp.split(".")[0]
+        if timestamp == last_timestamp:
+            timestamp = ""
+        else:
+            last_timestamp = timestamp
+            timestamp = f"[{timestamp}]"
+        table.add_row(timestamp, trace.message)
+
+    if table.row_count > 0:
+        r_print(table)
+
+
 @trace_command.command("anomalies")
 @click.argument("trace-file", type=str, required=False)
 @click.option(
@@ -99,12 +135,7 @@ def dump_command(trace_file: str | None, filter: str | None) -> None:
 )
 def anomolies_command(trace_file: str | None, filter: str | None, all: bool) -> None:
     """Look for anomalies in a trace file (never completed or cancelled actions)."""
-    trace_file_path = _resolve_trace_file_path(trace_file)
-    traces = read_trace_file(trace_file_path)
-
-    if filter:
-        filter = filter.lower()
-        traces = [trace for trace in traces if filter in trace.message.lower()]
+    trace_file_path, traces = _read_traces(trace_file, None, filter)
 
     # Track started actions
     running_actions: dict[str, ActionTraceRecord] = {}
@@ -199,6 +230,22 @@ def print_fn(o: RenderableType) -> None:
     print(console.export_text(styles=True).strip())
 
 
+def _read_traces(
+    trace_file: str | None, level: str | None = None, filter: str | None = None
+) -> tuple[Path, list[TraceRecord]]:
+    trace_file_path = _resolve_trace_file_path(trace_file)
+    traces = read_trace_file(trace_file_path)
+
+    if level:
+        traces = [trace for trace in traces if trace.level == level]
+
+    if filter:
+        filter = filter.lower()
+        traces = [trace for trace in traces if filter in trace.message.lower()]
+
+    return (trace_file_path, traces)
+
+
 def _print_bucket(
     print_fn: Callable[[RenderableType], None],
     label: str,
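
For reference, the new `http` subcommand filters trace records to the `HTTP` level via `_read_traces()`. A minimal sketch of doing the same thing programmatically (the file selection below is an illustrative assumption, not code from this changeset):

```python
# Sketch: print HTTP trace records using the same helpers the CLI imports
# above. Assumes at least one trace file exists under inspect_trace_dir().
from inspect_ai._util.trace import inspect_trace_dir, read_trace_file

# take the most recently modified trace file as the "current" one
trace_files = sorted(inspect_trace_dir().iterdir(), key=lambda p: p.stat().st_mtime)
for record in read_trace_file(trace_files[-1]):
    if record.level == "HTTP":
        print(f"[{record.timestamp}] {record.message}")
```

diff --git a/src/inspect_ai/_display/core/footer.py b/src/inspect_ai/_display/core/footer.py
index 1d2d466a1..02081cfd9 100644
--- a/src/inspect_ai/_display/core/footer.py
+++ b/src/inspect_ai/_display/core/footer.py
@@ -1,7 +1,7 @@
 from rich.console import RenderableType
 from rich.text import Text
 
-from inspect_ai._util.logger import http_rate_limit_count
+from inspect_ai._util.retry import http_retries_count
 from inspect_ai.util._concurrency import concurrency_status
 from inspect_ai.util._throttle import throttle
 
@@ -26,12 +26,12 @@ def task_resources() -> str:
 
 
 def task_counters(counters: dict[str, str]) -> str:
-    return 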
task_dict(counters | task_http_rate_limits()) + return task_dict(counters | task_http_retries()) -def task_http_rate_limits() -> dict[str, str]: - return {"HTTP rate limits": f"{http_rate_limit_count():,}"} +def task_http_retries() -> dict[str, str]: + return {"HTTP retries": f"{http_retries_count():,}"} -def task_http_rate_limits_str() -> str: - return f"HTTP rate limits: {http_rate_limit_count():,}" +def task_http_retries_str() -> str: + return f"HTTP retries: {http_retries_count():,}" diff --git a/src/inspect_ai/_display/plain/display.py b/src/inspect_ai/_display/plain/display.py index 9c3f199c5..376b46683 100644 --- a/src/inspect_ai/_display/plain/display.py +++ b/src/inspect_ai/_display/plain/display.py @@ -22,7 +22,7 @@ TaskSpec, TaskWithResult, ) -from ..core.footer import task_http_rate_limits_str +from ..core.footer import task_http_retries_str from ..core.panel import task_panel, task_targets from ..core.results import task_metric, tasks_results @@ -182,7 +182,7 @@ def _print_status(self) -> None: status_parts.append(resources) # Add rate limits - rate_limits = task_http_rate_limits_str() + rate_limits = task_http_retries_str() if rate_limits: status_parts.append(rate_limits) diff --git a/src/inspect_ai/_display/textual/widgets/footer.py b/src/inspect_ai/_display/textual/widgets/footer.py index fb2e5f396..10997d9a3 100644 --- a/src/inspect_ai/_display/textual/widgets/footer.py +++ b/src/inspect_ai/_display/textual/widgets/footer.py @@ -36,3 +36,7 @@ def watch_left(self, new_left: RenderableType) -> None: def watch_right(self, new_right: RenderableType) -> None: footer_right = cast(Static, self.query_one("#footer-right")) footer_right.update(new_right) + if footer_right.tooltip is None: + footer_right.tooltip = ( + "Execute 'inspect trace http' for a log of all HTTP requests." + ) diff --git a/src/inspect_ai/_display/textual/widgets/samples.py b/src/inspect_ai/_display/textual/widgets/samples.py index be8a23942..c0a833f60 100644 --- a/src/inspect_ai/_display/textual/widgets/samples.py +++ b/src/inspect_ai/_display/textual/widgets/samples.py @@ -506,6 +506,7 @@ async def sync_sample(self, sample: ActiveSample | None) -> None: # track the sample self.sample = sample + status_group = self.query_one("#" + self.STATUS_GROUP) pending_status = self.query_one("#" + self.PENDING_STATUS) timeout_tool = self.query_one("#" + self.TIMEOUT_TOOL_CALL) clock = self.query_one(Clock) @@ -537,11 +538,19 @@ async def sync_sample(self, sample: ActiveSample | None) -> None: pending_caption = cast( Static, self.query_one("#" + self.PENDING_CAPTION) ) - pending_caption_text = ( - "Generating..." - if isinstance(last_event, ModelEvent) - else "Executing..." - ) + if isinstance(last_event, ModelEvent): + # see if there are retries in play + if sample.retry_count > 0: + suffix = "retry" if sample.retry_count == 1 else "retries" + pending_caption_text = ( + f"Generating ({sample.retry_count:,} {suffix})..." + ) + else: + pending_caption_text = "Generating..." + else: + pending_caption_text = "Executing..." 
+ status_group.styles.width = max(22, len(pending_caption_text)) + pending_caption.update( Text.from_markup(f"[italic]{pending_caption_text}[/italic]") ) diff --git a/src/inspect_ai/_eval/context.py b/src/inspect_ai/_eval/context.py index bbe4c8c60..c9b6337a4 100644 --- a/src/inspect_ai/_eval/context.py +++ b/src/inspect_ai/_eval/context.py @@ -1,6 +1,6 @@ from inspect_ai._util.dotenv import init_dotenv from inspect_ai._util.hooks import init_hooks -from inspect_ai._util.logger import init_http_rate_limit_count, init_logger +from inspect_ai._util.logger import init_logger from inspect_ai.approval._apply import have_tool_approval, init_tool_approval from inspect_ai.approval._human.manager import init_human_approval_manager from inspect_ai.approval._policy import ApprovalPolicy @@ -20,7 +20,6 @@ def init_eval_context( init_logger(log_level, log_level_transcript) init_concurrency() init_max_subprocesses(max_subprocesses) - init_http_rate_limit_count() init_hooks() init_active_samples() init_human_approval_manager() diff --git a/src/inspect_ai/_eval/task/sandbox.py b/src/inspect_ai/_eval/task/sandbox.py index 9f9bd89ae..a0309df12 100644 --- a/src/inspect_ai/_eval/task/sandbox.py +++ b/src/inspect_ai/_eval/task/sandbox.py @@ -15,10 +15,9 @@ from inspect_ai._eval.task.task import Task from inspect_ai._eval.task.util import task_run_dir -from inspect_ai._util.constants import DEFAULT_MAX_RETRIES, DEFAULT_TIMEOUT from inspect_ai._util.file import file, filesystem +from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt from inspect_ai._util.registry import registry_unqualified_name -from inspect_ai._util.retry import httpx_should_retry, log_retry_attempt from inspect_ai._util.url import data_uri_to_base64, is_data_uri, is_http_url from inspect_ai.dataset import Sample from inspect_ai.util._concurrency import concurrency @@ -186,14 +185,14 @@ async def _retrying_httpx_get( url: str, client: httpx.AsyncClient = httpx.AsyncClient(), timeout: int = 30, # per-attempt timeout - max_retries: int = DEFAULT_MAX_RETRIES, - total_timeout: int = DEFAULT_TIMEOUT, # timeout for the whole retry loop. not for an individual attempt + max_retries: int = 10, + total_timeout: int = 120, # timeout for the whole retry loop. 
not for an individual attempt ) -> bytes: @retry( wait=wait_exponential_jitter(), stop=(stop_after_attempt(max_retries) | stop_after_delay(total_timeout)), retry=retry_if_exception(httpx_should_retry), - before_sleep=log_retry_attempt(url), + before_sleep=log_httpx_retry_attempt(url), ) async def do_get() -> bytes: response = await client.get( diff --git a/src/inspect_ai/_util/constants.py b/src/inspect_ai/_util/constants.py index efb534e26..5bf90c732 100644 --- a/src/inspect_ai/_util/constants.py +++ b/src/inspect_ai/_util/constants.py @@ -6,8 +6,6 @@ PKG_NAME = Path(__file__).parent.parent.stem PKG_PATH = Path(__file__).parent.parent DEFAULT_EPOCHS = 1 -DEFAULT_MAX_RETRIES = 5 -DEFAULT_TIMEOUT = 120 DEFAULT_MAX_CONNECTIONS = 10 DEFAULT_MAX_TOKENS = 2048 DEFAULT_VIEW_PORT = 7575 diff --git a/src/inspect_ai/_util/http.py b/src/inspect_ai/_util/http.py index 430b95026..2826c7fe2 100644 --- a/src/inspect_ai/_util/http.py +++ b/src/inspect_ai/_util/http.py @@ -1,99 +1,3 @@ -import glob -import json -import os -import posixpath -from http import HTTPStatus -from http.server import SimpleHTTPRequestHandler -from io import BytesIO -from typing import Any -from urllib.parse import parse_qs, urlparse - -from .dev import is_dev_mode - - -class InspectHTTPRequestHandler(SimpleHTTPRequestHandler): - def __init__(self, *args: Any, directory: str, **kwargs: Any) -> None: - # note whether we are in dev mode (i.e. developing the package) - self.dev_mode = is_dev_mode() - - # initialize file serving directory - directory = os.path.abspath(directory) - super().__init__(*args, directory=directory, **kwargs) - - def do_GET(self) -> None: - if self.path.startswith("/api/events"): - self.handle_events() - else: - super().do_GET() - - def handle_events(self) -> None: - """Client polls for events (e.g. 
dev reload) ~ every 1 second.""" - query = parse_qs(urlparse(self.path).query) - params = dict(zip(query.keys(), [value[0] for value in query.values()])) - self.send_json(json.dumps(self.events_response(params))) - - def events_response(self, params: dict[str, str]) -> list[str]: - """Send back a 'reload' event if we have modified source files.""" - loaded_time = params.get("loaded_time", None) - return ( - ["reload"] if loaded_time and self.should_reload(int(loaded_time)) else [] - ) - - def translate_path(self, path: str) -> str: - """Ensure that paths don't escape self.directory.""" - translated = super().translate_path(path) - if not os.path.abspath(translated).startswith(self.directory): - return self.directory - else: - return translated - - def send_json(self, json: str | bytes) -> None: - if isinstance(json, str): - json = json.encode() - self.send_response(HTTPStatus.OK) - self.send_header("Content-type", "application/json") - self.end_headers() - self.copyfile(BytesIO(json), self.wfile) # type: ignore - - def send_response(self, code: int, message: str | None = None) -> None: - """No client side or proxy caches.""" - super().send_response(code, message) - self.send_header("Expires", "Fri, 01 Jan 1990 00:00:00 GMT") - self.send_header("Pragma", "no-cache") - self.send_header( - "Cache-Control", "no-cache, no-store, max-age=0, must-revalidate" - ) - - def guess_type(self, path: str | os.PathLike[str]) -> str: - _, ext = posixpath.splitext(path) - if not ext or ext == ".mjs" or ext == ".js": - return "application/javascript" - elif ext == ".md": - return "text/markdown" - else: - return super().guess_type(path) - - def log_error(self, format: str, *args: Any) -> None: - if self.dev_mode: - super().log_error(format, *args) - - def log_request(self, code: int | str = "-", size: int | str = "-") -> None: - """Don't log status 200 or 404 (too chatty).""" - if code not in [200, 404]: - super().log_request(code, size) - - def should_reload(self, loaded_time: int) -> bool: - if self.dev_mode: - for dir in self.reload_dirs(): - files = [ - os.stat(file).st_mtime - for file in glob.glob(f"{dir}/**/*", recursive=True) - ] - last_modified = max(files) * 1000 - if last_modified > loaded_time: - return True - - return False - - def reload_dirs(self) -> list[str]: - return [self.directory] +# see https://cloud.google.com/storage/docs/retry-strategy +def is_retryable_http_status(status_code: int) -> bool: + return status_code in [408, 429] or (500 <= status_code < 600) diff --git a/src/inspect_ai/_util/httpx.py b/src/inspect_ai/_util/httpx.py new file mode 100644 index 000000000..ec6b78cda --- /dev/null +++ b/src/inspect_ai/_util/httpx.py @@ -0,0 +1,60 @@ +import logging +from typing import Callable + +from httpx import ConnectError, ConnectTimeout, HTTPStatusError, ReadTimeout +from tenacity import RetryCallState + +from inspect_ai._util.constants import HTTP + +logger = logging.getLogger(__name__) + + +def httpx_should_retry(ex: BaseException) -> bool: + """Check whether an exception raised from httpx should be retried. 
+ + Implements the strategy described here: https://cloud.google.com/storage/docs/retry-strategy + + Args: + ex (BaseException): Exception to examine for retry behavior + + Returns: + True if a retry should occur + """ + # httpx status exception + if isinstance(ex, HTTPStatusError): + # request timeout + if ex.response.status_code == 408: + return True + # lock timeout + elif ex.response.status_code == 409: + return True + # rate limit + elif ex.response.status_code == 429: + return True + # internal errors + elif ex.response.status_code >= 500: + return True + else: + return False + + # connection error + elif is_httpx_connection_error(ex): + return True + + # don't retry + else: + return False + + +def log_httpx_retry_attempt(context: str) -> Callable[[RetryCallState], None]: + def log_attempt(retry_state: RetryCallState) -> None: + logger.log( + HTTP, + f"{context} connection retry {retry_state.attempt_number} after waiting for {retry_state.idle_for}", + ) + + return log_attempt + + +def is_httpx_connection_error(ex: BaseException) -> bool: + return isinstance(ex, ConnectTimeout | ConnectError | ConnectionError | ReadTimeout) diff --git a/src/inspect_ai/_util/logger.py b/src/inspect_ai/_util/logger.py index 6d649da65..ed65caf18 100644 --- a/src/inspect_ai/_util/logger.py +++ b/src/inspect_ai/_util/logger.py @@ -1,8 +1,6 @@ import atexit import os -import re from logging import ( - DEBUG, INFO, WARNING, FileHandler, @@ -44,10 +42,12 @@ # log handler that filters messages to stderr and the log file class LogHandler(RichHandler): - def __init__(self, levelno: int, transcript_levelno: int) -> None: - super().__init__(levelno, console=rich.get_console()) + def __init__( + self, capture_levelno: int, display_levelno: int, transcript_levelno: int + ) -> None: + super().__init__(capture_levelno, console=rich.get_console()) self.transcript_levelno = transcript_levelno - self.display_level = WARNING + self.display_level = display_levelno # log into an external file if requested via env var file_logger = os.environ.get("INSPECT_PY_LOGGER_FILE", None) self.file_logger = FileHandler(file_logger) if file_logger else None @@ -77,23 +77,6 @@ def __init__(self, levelno: int, transcript_levelno: int) -> None: @override def emit(self, record: LogRecord) -> None: - # demote httpx and return notifications to log_level http - if ( - record.name == "httpx" - or "http" in record.name - or "Retrying request" in record.getMessage() - ): - record.levelno = HTTP - record.levelname = HTTP_LOG_LEVEL - - # skip httpx event loop is closed errors - if "Event loop is closed" in record.getMessage(): - return - - # skip google-genai AFC message - if "AFC is enabled with max remote calls" in record.getMessage(): - return - # write to stderr if we are at or above the threshold if record.levelno >= self.display_level: super().emit(record) @@ -110,10 +93,9 @@ def emit(self, record: LogRecord) -> None: if self.trace_logger and record.levelno >= self.trace_logger_level: self.trace_logger.emit(record) - # eval log always gets info level and higher records - # eval log only gets debug or http if we opt-in - write = record.levelno >= self.transcript_levelno - notify_logger_record(record, write) + # eval log gets transcript level or higher + if record.levelno >= self.transcript_levelno: + log_to_transcript(record) @override def render_message(self, record: LogRecord, message: str) -> ConsoleRenderable: @@ -122,9 +104,7 @@ def render_message(self, record: LogRecord, message: str) -> ConsoleRenderable: # initialize logging -- this 
function can be called multiple times
 # in the lifetime of the process (the levelno will update globally)
-def init_logger(
-    log_level: str | None = None, log_level_transcript: str | None = None
-) -> None:
+def init_logger(log_level: str | None, log_level_transcript: str | None = None) -> None:
     # backwards compatibility for 'tools'
     if log_level == "sandbox" or log_level == "tools":
         log_level = "trace"
@@ -146,7 +126,7 @@ def validate_level(option: str, level: str) -> None:
     ).upper()
     validate_level("log level", log_level)
 
-    # reolve log file level
+    # resolve transcript log level
     log_level_transcript = (
         log_level_transcript
         if log_level_transcript
@@ -158,76 +138,40 @@ def validate_level(option: str, level: str) -> None:
     levelno = getLevelName(log_level)
     transcript_levelno = getLevelName(log_level_transcript)
 
+    # set capture level for our logs (we won't actually display/write all of them)
+    capture_level = min(TRACE, levelno, transcript_levelno)
+
     # init logging handler on demand
     global _logHandler
-    removed_root_handlers = False
     if not _logHandler:
-        removed_root_handlers = remove_non_pytest_root_logger_handlers()
-        _logHandler = LogHandler(min(DEBUG, levelno), transcript_levelno)
-        getLogger().addHandler(_logHandler)
-
-    # establish default capture level
-    capture_level = min(TRACE, levelno, transcript_levelno)
-
-    # see all the messages (we won't actually display/write all of them)
-    getLogger().setLevel(capture_level)
-    getLogger(PKG_NAME).setLevel(capture_level)
-    getLogger("httpx").setLevel(capture_level)
-    getLogger("botocore").setLevel(DEBUG)
-
-    if removed_root_handlers:
-        getLogger(PKG_NAME).warning(
-            "Inspect removed pre-existing root logger handlers and replaced them with its own handler."
+        _logHandler = LogHandler(
+            capture_levelno=capture_level,
+            display_levelno=levelno,
+            transcript_levelno=transcript_levelno,
         )
 
-    # set the levelno on the global handler
-    _logHandler.display_level = levelno
+        # set the log level for our package
+        getLogger(PKG_NAME).setLevel(capture_level)
+        getLogger(PKG_NAME).addHandler(_logHandler)
+        getLogger(PKG_NAME).propagate = False
 
+        # add our logger to the global handlers
+        getLogger().addHandler(_logHandler)
 
-_logHandler: LogHandler | None = None
+        # httpx currently logs all requests at the INFO level
+        # this is a bit aggressive and we already do this at
+        # our own HTTP level
+        getLogger("httpx").setLevel(WARNING)
 
-def remove_non_pytest_root_logger_handlers() -> bool:
-    root_logger = getLogger()
-    non_pytest_handlers = [
-        handler
-        for handler in root_logger.handlers
-        if handler.__module__ != "_pytest.logging"
-    ]
-    for handler in non_pytest_handlers:
-        root_logger.removeHandler(handler)
-    return len(non_pytest_handlers) > 0
+_logHandler: LogHandler | None = None
 
-def notify_logger_record(record: LogRecord, write: bool) -> None:
+def log_to_transcript(record: LogRecord) -> None:
     from inspect_ai.log._message import LoggingMessage
     from inspect_ai.log._transcript import LoggerEvent, transcript
 
-    if write:
-        transcript()._event(
-            LoggerEvent(message=LoggingMessage._from_log_record(record))
-        )
-    global _rate_limit_count
-    if (record.levelno <= INFO and re.search(r"\b429\b", record.getMessage())) or (
-        record.levelno == DEBUG
-        # See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html#validating-retry-attempts
-        # for boto retry logic / log messages (this is tracking standard or adapative retries)
-        and "botocore.retries.standard" in record.name
-        and "Retry needed, retrying request after delay of:" in 
record.getMessage() - ): - _rate_limit_count = _rate_limit_count + 1 - - -_rate_limit_count = 0 - - -def init_http_rate_limit_count() -> None: - global _rate_limit_count - _rate_limit_count = 0 - - -def http_rate_limit_count() -> int: - return _rate_limit_count + transcript()._event(LoggerEvent(message=LoggingMessage._from_log_record(record))) def warn_once(logger: Logger, message: str) -> None: diff --git a/src/inspect_ai/_util/retry.py b/src/inspect_ai/_util/retry.py index 8cc48ec9f..11b4d8211 100644 --- a/src/inspect_ai/_util/retry.py +++ b/src/inspect_ai/_util/retry.py @@ -1,67 +1,16 @@ -import logging -from typing import Callable +_http_retries_count: int = 0 -from httpx import ConnectError, ConnectTimeout, HTTPStatusError, ReadTimeout -from tenacity import RetryCallState -from inspect_ai._util.constants import HTTP +def report_http_retry() -> None: + from inspect_ai.log._samples import report_active_sample_retry -logger = logging.getLogger(__name__) + # bump global counter + global _http_retries_count + _http_retries_count = _http_retries_count + 1 + # report sample retry + report_active_sample_retry() -def httpx_should_retry(ex: BaseException) -> bool: - """Check whether an exception raised from httpx should be retried. - Implements the strategy described here: https://cloud.google.com/storage/docs/retry-strategy - - Args: - ex (BaseException): Exception to examine for retry behavior - - Returns: - True if a retry should occur - """ - # httpx status exception - if isinstance(ex, HTTPStatusError): - # request timeout - if ex.response.status_code == 408: - return True - # lock timeout - elif ex.response.status_code == 409: - return True - # rate limit - elif ex.response.status_code == 429: - return True - # internal errors - elif ex.response.status_code >= 500: - return True - else: - return False - - # connection error - elif is_httpx_connection_error(ex): - return True - - # don't retry - else: - return False - - -def log_rate_limit_retry(context: str, retry_state: RetryCallState) -> None: - logger.log( - HTTP, - f"{context} rate limit retry {retry_state.attempt_number} after waiting for {retry_state.idle_for}", - ) - - -def log_retry_attempt(context: str) -> Callable[[RetryCallState], None]: - def log_attempt(retry_state: RetryCallState) -> None: - logger.log( - HTTP, - f"{context} connection retry {retry_state.attempt_number} after waiting for {retry_state.idle_for}", - ) - - return log_attempt - - -def is_httpx_connection_error(ex: BaseException) -> bool: - return isinstance(ex, ConnectTimeout | ConnectError | ConnectionError | ReadTimeout) +def http_retries_count() -> int: + return _http_retries_count diff --git a/src/inspect_ai/log/_samples.py b/src/inspect_ai/log/_samples.py index 4088214b8..62cbb16bd 100644 --- a/src/inspect_ai/log/_samples.py +++ b/src/inspect_ai/log/_samples.py @@ -1,15 +1,16 @@ import contextlib from contextvars import ContextVar from datetime import datetime -from typing import AsyncGenerator, Literal +from typing import AsyncGenerator, Iterator, Literal from shortuuid import uuid +from inspect_ai._util.constants import SAMPLE_SUBTASK from inspect_ai.dataset._dataset import Sample from inspect_ai.util._sandbox import SandboxConnection from inspect_ai.util._sandbox.context import sandbox_connections -from ._transcript import Transcript +from ._transcript import Transcript, transcript class ActiveSample: @@ -44,6 +45,7 @@ def __init__( self.total_tokens = 0 self.transcript = transcript self.sandboxes = sandboxes + self.retry_count = 0 
self._interrupt_action: Literal["score", "error"] | None = None @property @@ -153,6 +155,29 @@ def set_active_sample_total_messages(total_messages: int) -> None: active.total_messages = total_messages +@contextlib.contextmanager +def track_active_sample_retries() -> Iterator[None]: + reset_active_sample_retries() + try: + yield + finally: + reset_active_sample_retries() + + +def reset_active_sample_retries() -> None: + active = sample_active() + if active: + active.retry_count = 0 + + +def report_active_sample_retry() -> None: + active = sample_active() + if active: + # only do this for the top level subtask + if transcript().name == SAMPLE_SUBTASK: + active.retry_count = active.retry_count + 1 + + _sample_active: ContextVar[ActiveSample | None] = ContextVar( "_sample_active", default=None ) diff --git a/src/inspect_ai/model/_model.py b/src/inspect_ai/model/_model.py index 1407b094d..c5fa6cd05 100644 --- a/src/inspect_ai/model/_model.py +++ b/src/inspect_ai/model/_model.py @@ -13,6 +13,7 @@ from pydantic_core import to_jsonable_python from tenacity import ( + RetryCallState, retry, retry_if_exception, stop_after_attempt, @@ -20,8 +21,9 @@ stop_never, wait_exponential_jitter, ) +from tenacity.stop import StopBaseT -from inspect_ai._util.constants import DEFAULT_MAX_CONNECTIONS +from inspect_ai._util.constants import DEFAULT_MAX_CONNECTIONS, HTTP from inspect_ai._util.content import ( Content, ContentImage, @@ -30,6 +32,7 @@ ) from inspect_ai._util.hooks import init_hooks, override_api_key, send_telemetry from inspect_ai._util.interrupt import check_sample_interrupt +from inspect_ai._util.logger import warn_once from inspect_ai._util.platform import platform_init from inspect_ai._util.registry import ( RegistryInfo, @@ -37,7 +40,7 @@ registry_info, registry_unqualified_name, ) -from inspect_ai._util.retry import log_rate_limit_retry +from inspect_ai._util.retry import report_http_retry from inspect_ai._util.trace import trace_action from inspect_ai._util.working import report_sample_waiting_time, sample_working_time from inspect_ai.tool import Tool, ToolChoice, ToolFunction, ToolInfo @@ -173,11 +176,11 @@ def connection_key(self) -> str: """Scope for enforcement of max_connections.""" return "default" - def is_rate_limit(self, ex: BaseException) -> bool: - """Is this exception a rate limit error. + def should_retry(self, ex: Exception) -> bool: + """Should this exception be retried? Args: - ex: Exception to check for rate limit. 
+            ex: Exception to check for retry
         """
         return False
 
@@ -331,14 +334,17 @@ async def generate(
         start_time = datetime.now()
         working_start = sample_working_time()
         async with self._connection_concurrency(config):
+            from inspect_ai.log._samples import track_active_sample_retries
+
             # generate
-            output = await self._generate(
-                input=input,
-                tools=tools,
-                tool_choice=tool_choice,
-                config=config,
-                cache=cache,
-            )
+            with track_active_sample_retries():
+                output = await self._generate(
+                    input=input,
+                    tools=tools,
+                    tool_choice=tool_choice,
+                    config=config,
+                    cache=cache,
+                )
 
             # update the most recent ModelEvent with the actual start/completed
             # times as well as a computation of working time (events are
@@ -418,27 +424,27 @@ async def _generate(
         if self.api.collapse_assistant_messages():
             input = collapse_consecutive_assistant_messages(input)
 
-        # retry for rate limit errors (max of 30 minutes)
+        # retry for transient http errors:
+        # - no default timeout or max_retries (try forever)
+        # - exponential backoff starting at 3 seconds (will wait 25 minutes
+        #   on the 10th retry, then will wait no longer than 30 minutes on
+        #   subsequent retries)
+        if config.max_retries is not None and config.timeout is not None:
+            stop: StopBaseT = stop_after_attempt(config.max_retries) | stop_after_delay(
+                config.timeout
+            )
+        elif config.max_retries is not None:
+            stop = stop_after_attempt(config.max_retries)
+        elif config.timeout is not None:
+            stop = stop_after_delay(config.timeout)
+        else:
+            stop = stop_never
+
         @retry(
-            wait=wait_exponential_jitter(max=(30 * 60), jitter=5),
-            retry=retry_if_exception(self.api.is_rate_limit),
-            stop=(
-                (
-                    stop_after_delay(config.timeout)
-                    | stop_after_attempt(config.max_retries)
-                )
-                if config.timeout and config.max_retries
-                else (
-                    stop_after_delay(config.timeout)
-                    if config.timeout
-                    else (
-                        stop_after_attempt(config.max_retries)
-                        if config.max_retries
-                        else stop_never
-                    )
-                )
-            ),
-            before_sleep=functools.partial(log_rate_limit_retry, self.api.model_name),
+            wait=wait_exponential_jitter(initial=3, max=(30 * 60), jitter=3),
+            retry=retry_if_exception(self.should_retry),
+            stop=stop,
+            before_sleep=functools.partial(log_model_retry, self.api.model_name),
         )
         async def generate() -> ModelOutput:
             check_sample_interrupt()
@@ -555,6 +561,30 @@ async def generate() -> ModelOutput:
         # return results
         return model_output
 
+    def should_retry(self, ex: BaseException) -> bool:
+        if isinstance(ex, Exception):
+            # check standard should_retry() method
+            retry = self.api.should_retry(ex)
+            if retry:
+                report_http_retry()
+                return True
+
+            # see if the API implements legacy is_rate_limit() method
+            is_rate_limit = getattr(self.api, "is_rate_limit", None)
+            if is_rate_limit:
+                warn_once(
+                    logger,
+                    f"provider '{self.name}' implements deprecated is_rate_limit() method, "
+                    + "please change to should_retry()",
+                )
+                retry = cast(bool, is_rate_limit(ex))
+                if retry:
+                    report_http_retry()
+                    return True
+
+        # no retry
+        return False
+
     # function to verify that its okay to call model apis
     def verify_model_apis(self) -> None:
         if (
@@ -1183,6 +1213,13 @@ def combine_messages(
     )
 
 
+def log_model_retry(model_name: str, retry_state: RetryCallState) -> None:
+    logger.log(
+        HTTP,
+        f"-> {model_name} retry {retry_state.attempt_number} after waiting for {retry_state.idle_for}",
+    )
+
+
 def init_active_model(model: Model, config: GenerateConfig) -> None:
     active_model_context_var.set(model)
     set_active_generate_config(config)
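
For provider authors, the migration from the deprecated `is_rate_limit()` to `should_retry()` looks roughly like the following sketch (hypothetical: `MyClientError` stands in for a real SDK exception type and is not part of this changeset):

```python
# Hypothetical migration sketch for a custom provider (illustrative only).
from inspect_ai._util.http import is_retryable_http_status
from inspect_ai.model import ModelAPI


class MyClientError(Exception):
    """Stand-in for an SDK exception that carries an HTTP status code."""

    def __init__(self, status_code: int) -> None:
        self.status_code = status_code


class MyProviderAPI(ModelAPI):
    def should_retry(self, ex: Exception) -> bool:
        # retry 408, 429, and 5xx responses; fail fast on everything else
        if isinstance(ex, MyClientError):
            return is_retryable_http_status(ex.status_code)
        return False
```

Providers that still implement `is_rate_limit()` keep working via the fallback in `Model.should_retry()` above, which calls it with a one-time deprecation warning.

diff --git a/src/inspect_ai/model/_providers/anthropic.py 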
b/src/inspect_ai/model/_providers/anthropic.py index 1e7c56faf..da6e47990 100644 --- a/src/inspect_ai/model/_providers/anthropic.py +++ b/src/inspect_ai/model/_providers/anthropic.py @@ -6,7 +6,9 @@ from logging import getLogger from typing import Any, Literal, Optional, Tuple, TypedDict, cast -from .util.tracker import HttpxTimeTracker +from inspect_ai._util.http import is_retryable_http_status + +from .util.hooks import HttpxHooks if sys.version_info >= (3, 11): from typing import NotRequired @@ -14,15 +16,13 @@ from typing_extensions import NotRequired from anthropic import ( - APIConnectionError, APIStatusError, + APITimeoutError, AsyncAnthropic, AsyncAnthropicBedrock, AsyncAnthropicVertex, BadRequestError, - InternalServerError, NotGiven, - RateLimitError, ) from anthropic._types import Body from anthropic.types import ( @@ -46,7 +46,6 @@ from inspect_ai._util.constants import ( BASE_64_DATA_REMOVED, - DEFAULT_MAX_RETRIES, NO_CONTENT, ) from inspect_ai._util.content import ( @@ -125,9 +124,6 @@ def collect_model_arg(name: str) -> Any | None: AsyncAnthropic | AsyncAnthropicBedrock | AsyncAnthropicVertex ) = AsyncAnthropicBedrock( base_url=base_url, - max_retries=( - config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES - ), aws_region=aws_region, **model_args, ) @@ -141,9 +137,6 @@ def collect_model_arg(name: str) -> Any | None: region=region, project_id=project_id, base_url=base_url, - max_retries=( - config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES - ), **model_args, ) else: @@ -156,14 +149,11 @@ def collect_model_arg(name: str) -> Any | None: self.client = AsyncAnthropic( base_url=base_url, api_key=self.api_key, - max_retries=( - config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES - ), **model_args, ) # create time tracker - self._time_tracker = HttpxTimeTracker(self.client._client) + self._http_hooks = HttpxHooks(self.client._client) @override async def close(self) -> None: @@ -183,7 +173,7 @@ async def generate( config: GenerateConfig, ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]: # allocate request_id (so we can see it from ModelCall) - request_id = self._time_tracker.start_request() + request_id = self._http_hooks.start_request() # setup request and response for ModelCall request: dict[str, Any] = {} @@ -194,7 +184,7 @@ def model_call() -> ModelCall: request=request, response=response, filter=model_call_filter, - time=self._time_tracker.end_request(request_id), + time=self._http_hooks.end_request(request_id), ) # generate @@ -223,7 +213,7 @@ def model_call() -> ModelCall: request = request | req # extra headers (for time tracker and computer use) - extra_headers = headers | {HttpxTimeTracker.REQUEST_ID_HEADER: request_id} + extra_headers = headers | {HttpxHooks.REQUEST_ID_HEADER: request_id} if computer_use: betas.append("computer-use-2025-01-24") if len(betas) > 0: @@ -291,8 +281,6 @@ def completion_config( betas.append("output-128k-2025-02-19") # config that applies to all models - if config.timeout is not None: - params["timeout"] = float(config.timeout) if config.stop_seqs is not None: params["stop_sequences"] = config.stop_seqs @@ -334,13 +322,13 @@ def connection_key(self) -> str: return str(self.api_key) @override - def is_rate_limit(self, ex: BaseException) -> bool: - # We have observed that anthropic will frequently return InternalServerError - # seemingly in place of RateLimitError (at the very least the errors seem to - # always be transient). 
Equating this to rate limit errors may occasionally - # result in retrying too many times, but much more often will avert a failed - # eval that just needed to survive a transient error - return isinstance(ex, RateLimitError | InternalServerError | APIConnectionError) + def should_retry(self, ex: Exception) -> bool: + if isinstance(ex, APIStatusError): + return is_retryable_http_status(ex.status_code) + elif isinstance(ex, APITimeoutError): + return True + else: + return False @override def collapse_user_messages(self) -> bool: diff --git a/src/inspect_ai/model/_providers/azureai.py b/src/inspect_ai/model/_providers/azureai.py index 79c8866ac..c7673f8c5 100644 --- a/src/inspect_ai/model/_providers/azureai.py +++ b/src/inspect_ai/model/_providers/azureai.py @@ -27,11 +27,16 @@ UserMessage, ) from azure.core.credentials import AzureKeyCredential -from azure.core.exceptions import AzureError, HttpResponseError +from azure.core.exceptions import ( + AzureError, + HttpResponseError, + ServiceResponseError, +) from typing_extensions import override from inspect_ai._util.constants import DEFAULT_MAX_TOKENS from inspect_ai._util.content import Content, ContentImage, ContentText +from inspect_ai._util.http import is_retryable_http_status from inspect_ai._util.images import file_as_data_uri from inspect_ai.tool import ToolChoice, ToolInfo from inspect_ai.tool._tool_call import ToolCall @@ -232,14 +237,11 @@ def max_tokens(self) -> int | None: return DEFAULT_MAX_TOKENS @override - def is_rate_limit(self, ex: BaseException) -> bool: - if isinstance(ex, HttpResponseError): - return ( - ex.status_code == 408 - or ex.status_code == 409 - or ex.status_code == 429 - or ex.status_code == 500 - ) + def should_retry(self, ex: Exception) -> bool: + if isinstance(ex, HttpResponseError) and ex.status_code is not None: + return is_retryable_http_status(ex.status_code) + elif isinstance(ex, ServiceResponseError): + return True else: return False diff --git a/src/inspect_ai/model/_providers/bedrock.py b/src/inspect_ai/model/_providers/bedrock.py index 852c518cb..a660d4f0e 100644 --- a/src/inspect_ai/model/_providers/bedrock.py +++ b/src/inspect_ai/model/_providers/bedrock.py @@ -1,14 +1,11 @@ import base64 +from logging import getLogger from typing import Any, Literal, Tuple, Union, cast from pydantic import BaseModel, Field from typing_extensions import override -from inspect_ai._util.constants import ( - DEFAULT_MAX_RETRIES, - DEFAULT_MAX_TOKENS, - DEFAULT_TIMEOUT, -) +from inspect_ai._util.constants import DEFAULT_MAX_TOKENS from inspect_ai._util.content import Content, ContentImage, ContentText from inspect_ai._util.error import pip_dependency_error from inspect_ai._util.images import file_as_data @@ -31,7 +28,9 @@ from .util import ( model_base_url, ) -from .util.tracker import BotoTimeTracker +from .util.hooks import ConverseHooks + +logger = getLogger(__name__) # Model for Bedrock Converse API (Response) # generated from: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html#converse @@ -258,7 +257,7 @@ def __init__( self.session = aioboto3.Session() # create time tracker - self._time_tracker = BotoTimeTracker(self.session) + self._http_hooks = ConverseHooks(self.session) except ImportError: raise pip_dependency_error("Bedrock API", ["aioboto3"]) @@ -288,15 +287,25 @@ def max_tokens(self) -> int | None: return DEFAULT_MAX_TOKENS @override - def is_rate_limit(self, ex: BaseException) -> bool: + def should_retry(self, ex: Exception) -> bool: from 
botocore.exceptions import ClientError # Look for an explicit throttle exception if isinstance(ex, ClientError): - if ex.response["Error"]["Code"] == "ThrottlingException": - return True - - return super().is_rate_limit(ex) + error_code = ex.response.get("Error", {}).get("Code", "") + return error_code in [ + "ThrottlingException", + "RequestLimitExceeded", + "Throttling", + "RequestThrottled", + "TooManyRequestsException", + "ProvisionedThroughputExceededException", + "TransactionInProgressException", + "RequestTimeout", + "ServiceUnavailable", + ] + else: + return False @override def collapse_user_messages(self) -> bool: @@ -317,20 +326,13 @@ async def generate( from botocore.exceptions import ClientError # The bedrock client - request_id = self._time_tracker.start_request() + request_id = self._http_hooks.start_request() async with self.session.client( # type: ignore[call-overload] service_name="bedrock-runtime", endpoint_url=self.base_url, config=Config( - connect_timeout=config.timeout if config.timeout else DEFAULT_TIMEOUT, - read_timeout=config.timeout if config.timeout else DEFAULT_TIMEOUT, - retries=dict( - max_attempts=config.max_retries - if config.max_retries - else DEFAULT_MAX_RETRIES, - mode="adaptive", - ), - user_agent_extra=self._time_tracker.user_agent_extra(request_id), + retries=dict(mode="adaptive"), + user_agent_extra=self._http_hooks.user_agent_extra(request_id), ), **self.model_args, ) as client: @@ -370,7 +372,7 @@ def model_call(response: dict[str, Any] | None = None) -> ModelCall: request.model_dump(exclude_none=True) ), response=response, - time=self._time_tracker.end_request(request_id), + time=self._http_hooks.end_request(request_id), ) try: diff --git a/src/inspect_ai/model/_providers/cloudflare.py b/src/inspect_ai/model/_providers/cloudflare.py index 80f5c5673..6f4516888 100644 --- a/src/inspect_ai/model/_providers/cloudflare.py +++ b/src/inspect_ai/model/_providers/cloudflare.py @@ -16,10 +16,10 @@ chat_api_input, chat_api_request, environment_prerequisite_error, - is_chat_api_rate_limit, model_base_url, + should_retry_chat_api_error, ) -from .util.tracker import HttpxTimeTracker +from .util.hooks import HttpxHooks # https://developers.cloudflare.com/workers-ai/models/#text-generation @@ -51,7 +51,7 @@ def __init__( if not self.api_key: raise environment_prerequisite_error("CloudFlare", CLOUDFLARE_API_TOKEN) self.client = httpx.AsyncClient() - self._time_tracker = HttpxTimeTracker(self.client) + self._http_hooks = HttpxHooks(self.client) base_url = model_base_url(base_url, "CLOUDFLARE_BASE_URL") self.base_url = ( base_url if base_url else "https://api.cloudflare.com/client/v4/accounts" @@ -79,7 +79,7 @@ async def generate( json["messages"] = chat_api_input(input, tools, self.chat_api_handler()) # request_id - request_id = self._time_tracker.start_request() + request_id = self._http_hooks.start_request() # setup response response: dict[str, Any] = {} @@ -88,7 +88,7 @@ def model_call() -> ModelCall: return ModelCall.create( request=json, response=response, - time=self._time_tracker.end_request(request_id), + time=self._http_hooks.end_request(request_id), ) # make the call @@ -98,10 +98,9 @@ def model_call() -> ModelCall: url=f"{chat_url}/{self.model_name}", headers={ "Authorization": f"Bearer {self.api_key}", - HttpxTimeTracker.REQUEST_ID_HEADER: request_id, + HttpxHooks.REQUEST_ID_HEADER: request_id, }, json=json, - config=config, ) # handle response @@ -127,8 +126,8 @@ def model_call() -> ModelCall: raise RuntimeError(f"Error calling {self.model_name}: 
{error}") @override - def is_rate_limit(self, ex: BaseException) -> bool: - return is_chat_api_rate_limit(ex) + def should_retry(self, ex: Exception) -> bool: + return should_retry_chat_api_error(ex) # cloudflare enforces rate limits by model for each account @override diff --git a/src/inspect_ai/model/_providers/goodfire.py b/src/inspect_ai/model/_providers/goodfire.py index 6327289e4..77bc51529 100644 --- a/src/inspect_ai/model/_providers/goodfire.py +++ b/src/inspect_ai/model/_providers/goodfire.py @@ -3,7 +3,11 @@ from goodfire import AsyncClient from goodfire.api.chat.interfaces import ChatMessage as GoodfireChatMessage -from goodfire.api.exceptions import InvalidRequestException, RateLimitException +from goodfire.api.exceptions import ( + InvalidRequestException, + RateLimitException, + ServerErrorException, +) from goodfire.variants.variants import SUPPORTED_MODELS, Variant from typing_extensions import override @@ -163,9 +167,9 @@ def handle_error(self, ex: Exception) -> ModelOutput | Exception: return ex @override - def is_rate_limit(self, ex: BaseException) -> bool: + def should_retry(self, ex: Exception) -> bool: """Check if exception is due to rate limiting.""" - return isinstance(ex, RateLimitException) + return isinstance(ex, RateLimitException | ServerErrorException) @override def connection_key(self) -> str: diff --git a/src/inspect_ai/model/_providers/google.py b/src/inspect_ai/model/_providers/google.py index 839c86aec..5c23b17ee 100644 --- a/src/inspect_ai/model/_providers/google.py +++ b/src/inspect_ai/model/_providers/google.py @@ -26,6 +26,7 @@ GenerationConfig, HarmBlockThreshold, HarmCategory, + HttpOptions, Part, SafetySetting, SafetySettingDict, @@ -49,6 +50,7 @@ ContentVideo, ) from inspect_ai._util.error import PrerequisiteError +from inspect_ai._util.http import is_retryable_http_status from inspect_ai._util.images import file_as_data from inspect_ai._util.kvstore import inspect_kvstore from inspect_ai._util.trace import trace_message @@ -69,6 +71,7 @@ ) from inspect_ai.model._model_call import ModelCall from inspect_ai.model._providers.util import model_base_url +from inspect_ai.model._providers.util.hooks import HttpHooks, urllib3_hooks from inspect_ai.tool import ( ToolCall, ToolChoice, @@ -199,11 +202,15 @@ async def generate( tool_choice: ToolChoice, config: GenerateConfig, ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]: + # generate request_id + request_id = urllib3_hooks().start_request() + # Create google-genai types. 
gemini_contents = await as_chat_messages(self.client, input) gemini_tools = chat_tools(tools) if len(tools) > 0 else None gemini_tool_config = chat_tool_config(tool_choice) if len(tools) > 0 else None parameters = GenerateContentConfig( + http_options=HttpOptions(headers={HttpHooks.REQUEST_ID_HEADER: request_id}), temperature=config.temperature, top_p=config.top_p, top_k=config.top_k, @@ -230,10 +237,9 @@ def model_call() -> ModelCall: tools=gemini_tools, tool_config=gemini_tool_config, response=response, + time=urllib3_hooks().end_request(request_id), ) - # TODO: would need to monkey patch AuthorizedSession.request - try: response = await self.client.aio.models.generate_content( model=self.model_name, @@ -252,11 +258,24 @@ def model_call() -> ModelCall: return output, model_call() @override - def is_rate_limit(self, ex: BaseException) -> bool: - # see https://cloud.google.com/storage/docs/retry-strategy - return isinstance(ex, APIError) and ( - ex.code in (408, 429, 429) or ex.code >= 500 - ) + def should_retry(self, ex: Exception) -> bool: + import requests # type: ignore + + # standard http errors + if isinstance(ex, APIError): + return is_retryable_http_status(ex.status) + + # low-level requests exceptions + elif isinstance(ex, requests.exceptions.RequestException): + return isinstance( + ex, + ( + requests.exceptions.ConnectTimeout + | requests.exceptions.ChunkedEncodingError + ), + ) + else: + return False @override def connection_key(self) -> str: @@ -296,6 +315,7 @@ def build_model_call( tools: list[Tool] | None, tool_config: ToolConfig | None, response: GenerateContentResponse | None, + time: float | None, ) -> ModelCall: return ModelCall.create( request=dict( @@ -307,6 +327,7 @@ def build_model_call( ), response=response if response is not None else {}, filter=model_call_filter, + time=time, ) diff --git a/src/inspect_ai/model/_providers/groq.py b/src/inspect_ai/model/_providers/groq.py index f7f44f454..f217507ac 100644 --- a/src/inspect_ai/model/_providers/groq.py +++ b/src/inspect_ai/model/_providers/groq.py @@ -5,8 +5,9 @@ import httpx from groq import ( + APIStatusError, + APITimeoutError, AsyncGroq, - RateLimitError, ) from groq.types.chat import ( ChatCompletion, @@ -25,10 +26,10 @@ from inspect_ai._util.constants import ( BASE_64_DATA_REMOVED, - DEFAULT_MAX_RETRIES, DEFAULT_MAX_TOKENS, ) from inspect_ai._util.content import Content, ContentReasoning, ContentText +from inspect_ai._util.http import is_retryable_http_status from inspect_ai._util.images import file_as_data_uri from inspect_ai._util.url import is_http_url from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo @@ -54,7 +55,7 @@ environment_prerequisite_error, model_base_url, ) -from .util.tracker import HttpxTimeTracker +from .util.hooks import HttpxHooks GROQ_API_KEY = "GROQ_API_KEY" @@ -84,18 +85,12 @@ def __init__( self.client = AsyncGroq( api_key=self.api_key, base_url=model_base_url(base_url, "GROQ_BASE_URL"), - max_retries=( - config.max_retries - if config.max_retries is not None - else DEFAULT_MAX_RETRIES - ), - timeout=config.timeout if config.timeout is not None else 60.0, **model_args, http_client=httpx.AsyncClient(limits=httpx.Limits(max_connections=None)), ) # create time tracker - self._time_tracker = HttpxTimeTracker(self.client._client) + self._http_hooks = HttpxHooks(self.client._client) @override async def close(self) -> None: @@ -109,7 +104,7 @@ async def generate( config: GenerateConfig, ) -> tuple[ModelOutput, ModelCall]: # allocate request_id (so we can see it from 
ModelCall) - request_id = self._time_tracker.start_request() + request_id = self._http_hooks.start_request() # setup request and response for ModelCall request: dict[str, Any] = {} @@ -120,7 +115,7 @@ def model_call() -> ModelCall: request=request, response=response, filter=model_call_filter, - time=self._time_tracker.end_request(request_id), + time=self._http_hooks.end_request(request_id), ) messages = await as_groq_chat_messages(input) @@ -137,7 +132,7 @@ def model_call() -> ModelCall: request = dict( messages=messages, model=self.model_name, - extra_headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id}, + extra_headers={HttpxHooks.REQUEST_ID_HEADER: request_id}, **params, ) @@ -215,8 +210,13 @@ def _chat_choices_from_response( ] @override - def is_rate_limit(self, ex: BaseException) -> bool: - return isinstance(ex, RateLimitError) + def should_retry(self, ex: Exception) -> bool: + if isinstance(ex, APIStatusError): + return is_retryable_http_status(ex.status_code) + elif isinstance(ex, APITimeoutError): + return True + else: + return False @override def connection_key(self) -> str: diff --git a/src/inspect_ai/model/_providers/mistral.py b/src/inspect_ai/model/_providers/mistral.py index c86c73444..d4618640e 100644 --- a/src/inspect_ai/model/_providers/mistral.py +++ b/src/inspect_ai/model/_providers/mistral.py @@ -38,11 +38,9 @@ # TODO: Migration guide: # https://github.com/mistralai/client-python/blob/main/MIGRATION.md -from inspect_ai._util.constants import ( - DEFAULT_TIMEOUT, - NO_CONTENT, -) +from inspect_ai._util.constants import NO_CONTENT from inspect_ai._util.content import Content, ContentImage, ContentText +from inspect_ai._util.http import is_retryable_http_status from inspect_ai._util.images import file_as_data_uri from inspect_ai.tool import ToolCall, ToolChoice, ToolFunction, ToolInfo @@ -61,7 +59,7 @@ StopReason, ) from .util import environment_prerequisite_error, model_base_url -from .util.tracker import HttpxTimeTracker +from .util.hooks import HttpxHooks AZURE_MISTRAL_API_KEY = "AZURE_MISTRAL_API_KEY" AZUREAI_MISTRAL_API_KEY = "AZUREAI_MISTRAL_API_KEY" @@ -127,16 +125,12 @@ async def generate( config: GenerateConfig, ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]: # create client - with Mistral( - api_key=self.api_key, - timeout_ms=(config.timeout if config.timeout else DEFAULT_TIMEOUT) * 1000, - **self.model_args, - ) as client: + with Mistral(api_key=self.api_key, **self.model_args) as client: # create time tracker - time_tracker = HttpxTimeTracker(client.sdk_configuration.async_client) + http_hooks = HttpxHooks(client.sdk_configuration.async_client) # build request - request_id = time_tracker.start_request() + request_id = http_hooks.start_request() request: dict[str, Any] = dict( model=self.model_name, messages=await mistral_chat_messages(input), @@ -144,7 +138,7 @@ async def generate( tool_choice=( mistral_chat_tool_choice(tool_choice) if len(tools) > 0 else None ), - http_headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id}, + http_headers={HttpxHooks.REQUEST_ID_HEADER: request_id}, ) if config.temperature is not None: request["temperature"] = config.temperature @@ -169,7 +163,7 @@ def model_call() -> ModelCall: return ModelCall.create( request=req, response=response, - time=time_tracker.end_request(request_id), + time=http_hooks.end_request(request_id), ) # send request @@ -205,12 +199,13 @@ def model_call() -> ModelCall: ), model_call() @override - def is_rate_limit(self, ex: BaseException) -> bool: - return ( - isinstance(ex, SDKError) - 
and ex.status_code == 429 - or isinstance(ex, ReadTimeout | AsyncReadTimeout) - ) + def should_retry(self, ex: Exception) -> bool: + if isinstance(ex, SDKError): + return is_retryable_http_status(ex.status_code) + elif isinstance(ex, ReadTimeout | AsyncReadTimeout): + return True + else: + return False @override def connection_key(self) -> str: diff --git a/src/inspect_ai/model/_providers/openai.py b/src/inspect_ai/model/_providers/openai.py index 64ab4e10f..9d481e057 100644 --- a/src/inspect_ai/model/_providers/openai.py +++ b/src/inspect_ai/model/_providers/openai.py @@ -7,12 +7,11 @@ from openai import ( DEFAULT_CONNECTION_LIMITS, DEFAULT_TIMEOUT, - APIConnectionError, + APIStatusError, APITimeoutError, AsyncAzureOpenAI, AsyncOpenAI, BadRequestError, - InternalServerError, RateLimitError, ) from openai._types import NOT_GIVEN @@ -21,11 +20,11 @@ ) from typing_extensions import override -from inspect_ai._util.constants import DEFAULT_MAX_RETRIES from inspect_ai._util.error import PrerequisiteError +from inspect_ai._util.http import is_retryable_http_status from inspect_ai._util.logger import warn_once from inspect_ai.model._openai import chat_choices_from_openai -from inspect_ai.model._providers.util.tracker import HttpxTimeTracker +from inspect_ai.model._providers.util.hooks import HttpxHooks from inspect_ai.tool import ToolChoice, ToolInfo from .._chat_message import ChatMessage @@ -130,9 +129,6 @@ def __init__( api_key=self.api_key, azure_endpoint=base_url, azure_deployment=model_name, - max_retries=( - config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES - ), http_client=http_client, **model_args, ) @@ -140,15 +136,12 @@ def __init__( self.client = AsyncOpenAI( api_key=self.api_key, base_url=model_base_url(base_url, "OPENAI_BASE_URL"), - max_retries=( - config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES - ), http_client=http_client, **model_args, ) # create time tracker - self._time_tracker = HttpxTimeTracker(self.client._client) + self._http_hooks = HttpxHooks(self.client._client) def is_azure(self) -> bool: return self.service == "azure" @@ -186,7 +179,7 @@ async def generate( ) # allocate request_id (so we can see it from ModelCall) - request_id = self._time_tracker.start_request() + request_id = self._http_hooks.start_request() # setup request and response for ModelCall request: dict[str, Any] = {} @@ -197,7 +190,7 @@ def model_call() -> ModelCall: request=request, response=response, filter=image_url_filter, - time=self._time_tracker.end_request(request_id), + time=self._http_hooks.end_request(request_id), ) # unlike text models, vision models require a max_tokens (and set it to a very low @@ -216,7 +209,7 @@ def model_call() -> ModelCall: tool_choice=openai_chat_tool_choice(tool_choice) if len(tools) > 0 else NOT_GIVEN, - extra_headers={HttpxTimeTracker.REQUEST_ID_HEADER: request_id}, + extra_headers={HttpxHooks.REQUEST_ID_HEADER: request_id}, **self.completion_params(config, len(tools) > 0), ) @@ -266,17 +259,21 @@ def _chat_choices_from_response( return chat_choices_from_openai(response, tools) @override - def is_rate_limit(self, ex: BaseException) -> bool: + def should_retry(self, ex: Exception) -> bool: if isinstance(ex, RateLimitError): # Do not retry on these rate limit errors # The quota exceeded one is related to monthly account quotas. 
- if "You exceeded your current quota" not in ex.message: + if "You exceeded your current quota" in ex.message: + warn_once(logger, f"OpenAI quota exceeded, not retrying: {ex.message}") + return False + else: return True - elif isinstance( - ex, (APIConnectionError | APITimeoutError | InternalServerError) - ): + elif isinstance(ex, APIStatusError): + return is_retryable_http_status(ex.status_code) + elif isinstance(ex, APITimeoutError): return True - return False + else: + return False @override def connection_key(self) -> str: @@ -315,8 +312,6 @@ def completion_params(self, config: GenerateConfig, tools: bool) -> dict[str, An params["temperature"] = 1 if config.top_p is not None: params["top_p"] = config.top_p - if config.timeout is not None: - params["timeout"] = float(config.timeout) if config.num_choices is not None: params["n"] = config.num_choices if config.logprobs is not None: diff --git a/src/inspect_ai/model/_providers/together.py b/src/inspect_ai/model/_providers/together.py index 1f68e2142..6ef167c82 100644 --- a/src/inspect_ai/model/_providers/together.py +++ b/src/inspect_ai/model/_providers/together.py @@ -34,8 +34,8 @@ chat_api_input, chat_api_request, environment_prerequisite_error, - is_chat_api_rate_limit, model_base_url, + should_retry_chat_api_error, ) @@ -186,7 +186,6 @@ async def generate( url=f"{chat_url}", headers={"Authorization": f"Bearer {self.api_key}"}, json=json, - config=config, ) if "error" in response: @@ -215,8 +214,8 @@ async def generate( return ModelOutput(model=model, choices=choices, usage=usage) @override - def is_rate_limit(self, ex: BaseException) -> bool: - return is_chat_api_rate_limit(ex) + def should_retry(self, ex: Exception) -> bool: + return should_retry_chat_api_error(ex) # cloudflare enforces rate limits by model for each account @override diff --git a/src/inspect_ai/model/_providers/util/__init__.py b/src/inspect_ai/model/_providers/util/__init__.py index f8b7aeb69..3bdc68e2c 100644 --- a/src/inspect_ai/model/_providers/util/__init__.py +++ b/src/inspect_ai/model/_providers/util/__init__.py @@ -5,7 +5,7 @@ ChatAPIMessage, chat_api_input, chat_api_request, - is_chat_api_rate_limit, + should_retry_chat_api_error, ) from .hf_handler import HFHandler from .llama31 import Llama31Handler @@ -19,7 +19,7 @@ "as_stop_reason", "chat_api_request", "chat_api_input", - "is_chat_api_rate_limit", + "should_retry_chat_api_error", "model_base_url", "parse_tool_call", "tool_parse_error_message", diff --git a/src/inspect_ai/model/_providers/util/chatapi.py b/src/inspect_ai/model/_providers/util/chatapi.py index 9695e37f2..768ad167c 100644 --- a/src/inspect_ai/model/_providers/util/chatapi.py +++ b/src/inspect_ai/model/_providers/util/chatapi.py @@ -7,17 +7,15 @@ retry, retry_if_exception, stop_after_attempt, - stop_after_delay, wait_exponential_jitter, ) -from inspect_ai._util.constants import DEFAULT_MAX_RETRIES -from inspect_ai._util.retry import httpx_should_retry, log_retry_attempt +from inspect_ai._util.http import is_retryable_http_status +from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt from inspect_ai.model._chat_message import ChatMessageAssistant, ChatMessageTool from inspect_ai.tool._tool_info import ToolInfo from ..._chat_message import ChatMessage -from ..._generate_config import GenerateConfig logger = getLogger(__name__) @@ -75,21 +73,13 @@ async def chat_api_request( url: str, headers: dict[str, Any], json: Any, - config: GenerateConfig, ) -> Any: - # provide default max_retries - max_retries = 
config.max_retries if config.max_retries else DEFAULT_MAX_RETRIES - # define call w/ retry policy @retry( wait=wait_exponential_jitter(), - stop=( - (stop_after_attempt(max_retries) | stop_after_delay(config.timeout)) - if config.timeout - else stop_after_attempt(max_retries) - ), + stop=(stop_after_attempt(2)), retry=retry_if_exception(httpx_should_retry), - before_sleep=log_retry_attempt(model_name), + before_sleep=log_httpx_retry_attempt(model_name), ) async def call_api() -> Any: response = await client.post(url=url, headers=headers, json=json) @@ -104,14 +94,11 @@ async def call_api() -> Any: # checking for rate limit errors needs to punch through the RetryError and # look at its `__cause__`. we've observed Cloudflare giving transient 500 # status as well as a ReadTimeout, so we count these as rate limit errors -def is_chat_api_rate_limit(ex: BaseException) -> bool: +def should_retry_chat_api_error(ex: BaseException) -> bool: return isinstance(ex, RetryError) and ( ( isinstance(ex.__cause__, httpx.HTTPStatusError) - and ( - ex.__cause__.response.status_code == 429 - or ex.__cause__.response.status_code == 500 - ) + and is_retryable_http_status(ex.__cause__.response.status_code) ) or isinstance(ex.__cause__, httpx.ReadTimeout) ) diff --git a/src/inspect_ai/model/_providers/util/hooks.py b/src/inspect_ai/model/_providers/util/hooks.py new file mode 100644 index 000000000..3326e2b52 --- /dev/null +++ b/src/inspect_ai/model/_providers/util/hooks.py @@ -0,0 +1,165 @@ +import re +import time +from logging import getLogger +from typing import Any, Mapping, NamedTuple, cast + +import httpx +from shortuuid import uuid + +from inspect_ai._util.constants import HTTP +from inspect_ai._util.retry import report_http_retry + +logger = getLogger(__name__) + + +class RequestInfo(NamedTuple): + attempts: int + last_request: float + + +class HttpHooks: + """Class which hooks various HTTP clients for improved tracking/logging. + + A special header is injected into requests which is then read from + a request event hook -- this creates a record of when the request + started. Note that with retries a single request_id could be started + several times; our request hook makes sure we always track the time of + the last request. 
+ + There is an 'end_request()' method which gets the total request time + for a request_id and then purges the request_id from our tracking (so + the dict doesn't grow unbounded) + + Additionally, an http response hook is installed and used for logging + requests for the 'http' log-level + """ + + REQUEST_ID_HEADER = "x-irid" + + def __init__(self) -> None: + # track request start times + self._requests: dict[str, RequestInfo] = {} + + def start_request(self) -> str: + request_id = uuid() + self._requests[request_id] = RequestInfo(0, time.monotonic()) + return request_id + + def end_request(self, request_id: str) -> float: + # read the request info (if available) and purge from dict + request_info = self._requests.pop(request_id, None) + if request_info is None: + raise RuntimeError(f"request_id not registered: {request_id}") + + # return elapsed time + return time.monotonic() - request_info.last_request + + def update_request_time(self, request_id: str) -> None: + request_info = self._requests.get(request_id, None) + if not request_info: + raise RuntimeError(f"No request registered for request_id: {request_id}") + + # update the attempts and last request time + request_info = RequestInfo(request_info.attempts + 1, time.monotonic()) + self._requests[request_id] = request_info + + # trace a retry if this is attempt > 1 + if request_info.attempts > 1: + report_http_retry() + + +class ConverseHooks(HttpHooks): + def __init__(self, session: Any) -> None: + from aiobotocore.session import AioSession + + super().__init__() + + # register hooks + session = cast(AioSession, session._session) + session.register( + "before-send.bedrock-runtime.Converse", self.converse_before_send + ) + session.register( + "after-call.bedrock-runtime.Converse", self.converse_after_call + ) + + def converse_before_send(self, **kwargs: Any) -> None: + user_agent = kwargs["request"].headers["User-Agent"].decode() + match = re.search(rf"{self.USER_AGENT_PREFIX}(\w+)", user_agent) + if match: + request_id = match.group(1) + self.update_request_time(request_id) + + def converse_after_call(self, http_response: Any, **kwargs: Any) -> None: + from botocore.awsrequest import AWSResponse + + response = cast(AWSResponse, http_response) + logger.log(HTTP, f"POST {response.url} - {response.status_code}") + + def user_agent_extra(self, request_id: str) -> str: + return f"{self.USER_AGENT_PREFIX}{request_id}" + + USER_AGENT_PREFIX = "ins/rid#" + + +class HttpxHooks(HttpHooks): + def __init__(self, client: httpx.AsyncClient): + super().__init__() + + # install hooks + client.event_hooks["request"].append(self.request_hook) + client.event_hooks["response"].append(self.response_hook) + + async def request_hook(self, request: httpx.Request) -> None: + # update the last request time for this request id (as there could be retries) + request_id = request.headers.get(self.REQUEST_ID_HEADER, None) + if request_id: + self.update_request_time(request_id) + + async def response_hook(self, response: httpx.Response) -> None: + message = f'{response.request.method} {response.request.url} "{response.http_version} {response.status_code} {response.reason_phrase}" ' + logger.log(HTTP, message) + + +def urllib3_hooks() -> HttpHooks: + import urllib3 + from urllib3.connectionpool import HTTPConnectionPool + from urllib3.response import BaseHTTPResponse + + class Urllib3Hooks(HttpHooks): + def request_hook(self, headers: Mapping[str, str]) -> None: + # update the last request time for this request id (as there could be retries) + request_id = 
headers.get(self.REQUEST_ID_HEADER, None)
+            if request_id:
+                self.update_request_time(request_id)
+
+        def response_hook(
+            self, method: str, url: str, response: BaseHTTPResponse
+        ) -> None:
+            message = f'{method} {url} "{response.version_string} {response.status} {response.reason}" '
+            logger.log(HTTP, message)
+
+    global _urllib3_hooks
+    if _urllib3_hooks is None:
+        # one-time patch of urlopen
+        urllib3_hooks_instance = Urllib3Hooks()
+        original_urlopen = urllib3.connectionpool.HTTPConnectionPool.urlopen
+
+        def patched_urlopen(
+            self: HTTPConnectionPool, method: str, url: str, **kwargs: Any
+        ) -> BaseHTTPResponse:
+            headers = kwargs.get("headers", {})
+            urllib3_hooks_instance.request_hook(headers)
+            response = original_urlopen(self, method, url, **kwargs)
+            urllib3_hooks_instance.response_hook(method, f"{self.host}{url}", response)
+            return response
+
+        urllib3.connectionpool.HTTPConnectionPool.urlopen = patched_urlopen  # type: ignore[assignment,method-assign]
+
+        # assign to global hooks instance
+        _urllib3_hooks = urllib3_hooks_instance
+
+    return _urllib3_hooks
+
+
+_urllib3_hooks: HttpHooks | None = None
diff --git a/src/inspect_ai/model/_providers/util/tracker.py b/src/inspect_ai/model/_providers/util/tracker.py
deleted file mode 100644
index 8412862b5..000000000
--- a/src/inspect_ai/model/_providers/util/tracker.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import re
-import time
-from typing import Any, cast
-
-import httpx
-from shortuuid import uuid
-
-
-class HttpTimeTracker:
-    def __init__(self) -> None:
-        # track request start times
-        self._requests: dict[str, float] = {}
-
-    def start_request(self) -> str:
-        request_id = uuid()
-        self._requests[request_id] = time.monotonic()
-        return request_id
-
-    def end_request(self, request_id: str) -> float:
-        # read the request time if (if available) and purge from dict
-        request_time = self._requests.pop(request_id, None)
-        if request_time is None:
-            raise RuntimeError(f"request_id not registered: {request_id}")
-
-        # return elapsed time
-        return time.monotonic() - request_time
-
-    def update_request_time(self, request_id: str) -> None:
-        request_time = self._requests.get(request_id, None)
-        if not request_time:
-            raise RuntimeError(f"No request registered for request_id: {request_id}")
-
-        # update the request time
-        self._requests[request_id] = time.monotonic()
-
-
-class BotoTimeTracker(HttpTimeTracker):
-    def __init__(self, session: Any) -> None:
-        from aiobotocore.session import AioSession
-
-        super().__init__()
-
-        # register hook
-        session = cast(AioSession, session._session)
-        session.register(
-            "before-send.bedrock-runtime.Converse", self.converse_before_send
-        )
-
-    def converse_before_send(self, **kwargs: Any) -> None:
-        user_agent = kwargs["request"].headers["User-Agent"].decode()
-        match = re.search(rf"{self.USER_AGENT_PREFIX}(\w+)", user_agent)
-        if match:
-            request_id = match.group(1)
-            self.update_request_time(request_id)
-
-    def user_agent_extra(self, request_id: str) -> str:
-        return f"{self.USER_AGENT_PREFIX}{request_id}"
-
-    USER_AGENT_PREFIX = "ins/rid#"
-
-
-class HttpxTimeTracker(HttpTimeTracker):
-    """Class which tracks the duration of successful (200 status) http requests.
-
-    A special header is injected into requests which is then read from
-    an httpx 'request' event hook -- this creates a record of when the request
-    started. Note that with retries a single request id could be started
-    several times; our request hook makes sure we always track the time of
-    the last request.
-
-    To determine the total time, we also install an httpx response hook.
In - this hook we look for 200 responses which have a registered request id. - When we find one, we update the end time of the request. - - There is an 'end_request()' method which gets the total requeset time - for a request_id and then purges the request_id from our tracking (so - the dict doesn't grow unbounded) - """ - - REQUEST_ID_HEADER = "x-irid" - - def __init__(self, client: httpx.AsyncClient): - super().__init__() - - # install httpx request hook - client.event_hooks["request"].append(self.request_hook) - - async def request_hook(self, request: httpx.Request) -> None: - # update the last request time for this request id (as there could be retries) - request_id = request.headers.get(self.REQUEST_ID_HEADER, None) - if request_id: - self.update_request_time(request_id) diff --git a/src/inspect_ai/model/_providers/vertex.py b/src/inspect_ai/model/_providers/vertex.py index c50e9bead..182e63f99 100644 --- a/src/inspect_ai/model/_providers/vertex.py +++ b/src/inspect_ai/model/_providers/vertex.py @@ -4,7 +4,13 @@ from typing import Any, cast import vertexai # type: ignore -from google.api_core.exceptions import TooManyRequests +from google.api_core.exceptions import ( + Aborted, + ClientError, + DeadlineExceeded, + ServiceUnavailable, +) +from google.api_core.retry import if_transient_error from google.protobuf.json_format import MessageToDict from pydantic import JsonValue from typing_extensions import override @@ -31,6 +37,7 @@ ContentText, ContentVideo, ) +from inspect_ai._util.http import is_retryable_http_status from inspect_ai._util.images import file_as_data from inspect_ai.tool import ToolCall, ToolChoice, ToolInfo @@ -169,8 +176,18 @@ async def generate( return output, call @override - def is_rate_limit(self, ex: BaseException) -> bool: - return isinstance(ex, TooManyRequests) + def should_retry(self, ex: Exception) -> bool: + # google API-specific errors + if isinstance(ex, Aborted | DeadlineExceeded | ServiceUnavailable): + return True + # standard HTTP errors + elif isinstance(ex, ClientError) and ex.code is not None: + return is_retryable_http_status(ex.code) + # additional errors flagged by google as transient + elif isinstance(ex, Exception): + return if_transient_error(ex) + else: + return False @override def connection_key(self) -> str: diff --git a/src/inspect_ai/tool/_tools/_web_search.py b/src/inspect_ai/tool/_tools/_web_search.py index c0e3ef15b..3863f85fc 100644 --- a/src/inspect_ai/tool/_tools/_web_search.py +++ b/src/inspect_ai/tool/_tools/_web_search.py @@ -13,7 +13,7 @@ ) from inspect_ai._util.error import PrerequisiteError -from inspect_ai._util.retry import httpx_should_retry, log_retry_attempt +from inspect_ai._util.httpx import httpx_should_retry, log_httpx_retry_attempt from inspect_ai.util._concurrency import concurrency from .._tool import Tool, ToolResult, tool @@ -204,7 +204,7 @@ async def search(query: str, start_idx: int) -> list[SearchLink]: wait=wait_exponential_jitter(), stop=stop_after_attempt(5) | stop_after_delay(60), retry=retry_if_exception(httpx_should_retry), - before_sleep=log_retry_attempt(search_url), + before_sleep=log_httpx_retry_attempt(search_url), ) async def execute_search() -> httpx.Response: return await client.get(search_url) diff --git a/tests/model/test_is_rate_limit.py b/tests/model/test_is_rate_limit.py new file mode 100644 index 000000000..6b7e960ac --- /dev/null +++ b/tests/model/test_is_rate_limit.py @@ -0,0 +1,43 @@ +import logging + +from inspect_ai._util.constants import PKG_NAME +from inspect_ai.model._chat_message 
import ChatMessage
+from inspect_ai.model._generate_config import GenerateConfig
+from inspect_ai.model._model import Model, ModelAPI
+from inspect_ai.model._model_call import ModelCall
+from inspect_ai.model._model_output import ModelOutput
+from inspect_ai.tool._tool_choice import ToolChoice
+from inspect_ai.tool._tool_info import ToolInfo
+
+
+def test_is_rate_limit_deprecation(caplog):
+    # class which implements the deprecated is_rate_limit() api
+    class TestAPI(ModelAPI):
+        def __init__(self) -> None:
+            super().__init__("model")
+
+        async def generate(
+            self,
+            input: list[ChatMessage],
+            tools: list[ToolInfo],
+            tool_choice: ToolChoice,
+            config: GenerateConfig,
+        ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
+            return ModelOutput.from_content("model", "foo")
+
+        def is_rate_limit(self, ex: BaseException) -> bool:
+            return True
+
+    # enable propagation so the package logger's records reach caplog's handler
+    logger = logging.getLogger(PKG_NAME)
+    logger.propagate = True
+    try:
+        # confirm that the deprecated api is called and triggers a warning
+        model = Model(TestAPI(), GenerateConfig())
+        assert model.should_retry(Exception()) is True
+        assert any(
+            "deprecated is_rate_limit() method" in record.message
+            for record in caplog.records
+        ), "Expected warning not found in logs"
+    finally:
+        logger.propagate = False
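
---

For third-party provider authors, the pattern across these hunks is uniform: each provider now classifies exceptions in `should_retry()` rather than `is_rate_limit()`, deferring HTTP status classification to the shared `is_retryable_http_status()` helper (which presumably treats 429 and transient 5xx statuses as retryable) and returning `True` for timeout-style errors. The following is a minimal sketch of adopting the new contract, not the shipped implementation of any provider above; `MyProviderError` and `MyProviderAPI` are hypothetical names invented for illustration, while `ModelAPI` and `is_retryable_http_status` are the real symbols used in this diff. The `generate()` implementation is elided for brevity.

``` python
# Hypothetical provider illustrating the new should_retry() contract.
from typing_extensions import override

from inspect_ai._util.http import is_retryable_http_status
from inspect_ai.model._model import ModelAPI


class MyProviderError(Exception):
    """Hypothetical SDK error carrying an HTTP status code."""

    def __init__(self, status_code: int | None = None) -> None:
        super().__init__(f"provider error (status={status_code})")
        self.status_code = status_code


class MyProviderAPI(ModelAPI):
    # generate() omitted for brevity -- see the providers above for examples

    @override
    def should_retry(self, ex: Exception) -> bool:
        # defer status-code classification to the shared helper
        if isinstance(ex, MyProviderError) and ex.status_code is not None:
            return is_retryable_http_status(ex.status_code)
        # treat client-side timeouts as transient
        elif isinstance(ex, TimeoutError):
            return True
        else:
            return False
```

Providers that implement only the deprecated `is_rate_limit()` continue to work (as the test above verifies) but trigger a deprecation warning.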
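The `x-irid` header is also what makes per-request timing and retry counting work: the request hook resets the clock on every attempt, so `end_request()` reports the duration of the final attempt rather than the sum of all retries. Below is a minimal standalone sketch of that flow under stated assumptions: real providers do this wiring inside `generate()` via their SDK's `extra_headers`/`http_headers` options, and the URL here is just a placeholder.

``` python
# Standalone sketch of the HttpxHooks timing flow (illustrative only).
import asyncio

import httpx

from inspect_ai.model._providers.util.hooks import HttpxHooks


async def timed_get(url: str) -> float:
    client = httpx.AsyncClient()
    hooks = HttpxHooks(client)  # installs request/response event hooks

    # allocate a request_id and tag the outgoing request with the tracking header
    request_id = hooks.start_request()
    try:
        await client.get(url, headers={HttpxHooks.REQUEST_ID_HEADER: request_id})
    finally:
        await client.aclose()

    # elapsed time of the last attempt (retries reset the clock via the request hook)
    return hooks.end_request(request_id)


print(asyncio.run(timed_get("https://example.com")))  # placeholder URL
```

Correlating attempts through a header (rather than timing around the SDK call) is what lets the hooks see retries that happen inside the provider SDK itself, which is how the new "HTTP Retries" display can count them consistently across providers.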