From 5f61360dc25104248e8107771b00a5ce69875e61 Mon Sep 17 00:00:00 2001 From: nsmccandlish Date: Wed, 23 Oct 2024 11:32:26 -0700 Subject: [PATCH] ruff? --- computer-use-demo/computer_use_demo/loop.py | 44 ++++++++++++------- .../computer_use_demo/streamlit.py | 7 ++- computer-use-demo/tests/streamlit_test.py | 7 ++- 3 files changed, 37 insertions(+), 21 deletions(-) diff --git a/computer-use-demo/computer_use_demo/loop.py b/computer-use-demo/computer_use_demo/loop.py index df0b4046..73040e7f 100644 --- a/computer-use-demo/computer_use_demo/loop.py +++ b/computer-use-demo/computer_use_demo/loop.py @@ -10,17 +10,15 @@ from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse from anthropic.types.beta import ( - BetaContentBlock, + BetaCacheControlEphemeralParam, BetaContentBlockParam, BetaImageBlockParam, BetaMessage, BetaMessageParam, + BetaTextBlock, BetaTextBlockParam, BetaToolResultBlockParam, BetaToolUseBlockParam, - BetaTextBlock, - BetaToolUseBlock, - BetaCacheControlEphemeralParam, ) from .tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult @@ -89,9 +87,8 @@ async def sampling_loop( ) while True: - enable_prompt_caching = False - betas=["computer-use-2024-10-22"] + betas = ["computer-use-2024-10-22"] image_truncation_threshold = 10 if provider == APIProvider.ANTHROPIC: client = Anthropic(api_key=api_key) @@ -100,7 +97,7 @@ async def sampling_loop( client = AnthropicVertex() elif provider == APIProvider.BEDROCK: client = AnthropicBedrock() - + if enable_prompt_caching: betas.append("prompt-caching-2024-07-31") _inject_prompt_caching(messages) @@ -108,8 +105,12 @@ async def sampling_loop( image_truncation_threshold = 50 if only_n_most_recent_images: - _maybe_filter_to_n_most_recent_images(messages, only_n_most_recent_images, min_removal_threshold=image_truncation_threshold) - + _maybe_filter_to_n_most_recent_images( + messages, + only_n_most_recent_images, + min_removal_threshold=image_truncation_threshold, + ) + # Call the API # we use raw_response to provide debug information to streamlit. Your # implementation may be able call the SDK directly with: @@ -119,10 +120,12 @@ async def sampling_loop( messages=messages, model=model, system=system, - tools=tool_collection.to_params(enable_prompt_caching=enable_prompt_caching), - betas=betas + tools=tool_collection.to_params( + enable_prompt_caching=enable_prompt_caching + ), + betas=betas, ) - + raw_response = cast(APIResponse[BetaMessage], raw_response) api_response_callback(raw_response) @@ -203,7 +206,10 @@ def _maybe_filter_to_n_most_recent_images( new_content.append(content) tool_result["content"] = new_content -def _response_to_params(response: BetaMessage) -> list[BetaTextBlockParam | BetaToolUseBlockParam]: + +def _response_to_params( + response: BetaMessage, +) -> list[BetaTextBlockParam | BetaToolUseBlockParam]: res: list[BetaTextBlockParam | BetaToolUseBlockParam] = [] for block in response.content: if isinstance(block, BetaTextBlock): @@ -212,6 +218,7 @@ def _response_to_params(response: BetaMessage) -> list[BetaTextBlockParam | Beta res.append(cast(BetaToolUseBlockParam, block.model_dump())) return res + def _inject_prompt_caching( messages: list[BetaMessageParam], ): @@ -221,20 +228,23 @@ def _inject_prompt_caching( images in place, with a chunk of min_removal_threshold to reduce the amount we break the implicit prompt cache. """ - + breakpoints_remaining = 3 for message in reversed(messages): - if message["role"] == "user" and isinstance(content := message["content"], list): + if message["role"] == "user" and isinstance( + content := message["content"], list + ): if breakpoints_remaining: breakpoints_remaining -= 1 - content[-1]["cache_control"] = BetaCacheControlEphemeralParam({"type": "ephemeral"}) + content[-1]["cache_control"] = BetaCacheControlEphemeralParam( + {"type": "ephemeral"} + ) else: content[-1].pop("cache_control", None) # we'll only every have one extra turn per loop break - def _make_api_tool_result( result: ToolResult, tool_use_id: str ) -> BetaToolResultBlockParam: diff --git a/computer-use-demo/computer_use_demo/streamlit.py b/computer-use-demo/computer_use_demo/streamlit.py index 0f29fb76..c0350dd2 100644 --- a/computer-use-demo/computer_use_demo/streamlit.py +++ b/computer-use-demo/computer_use_demo/streamlit.py @@ -12,10 +12,13 @@ from pathlib import PosixPath from typing import cast -from anthropic.types.beta.beta_tool_use_block_param import BetaToolUseBlockParam import streamlit as st from anthropic import APIResponse -from anthropic.types.beta import BetaContentBlockParam, BetaMessage, BetaTextBlockParam, BetaToolResultBlockParam, BetaToolUseBlock +from anthropic.types.beta import ( + BetaContentBlockParam, + BetaMessage, + BetaTextBlockParam, +) from streamlit.delta_generator import DeltaGenerator from computer_use_demo.loop import ( diff --git a/computer-use-demo/tests/streamlit_test.py b/computer-use-demo/tests/streamlit_test.py index 7a76c1fa..d9a42936 100644 --- a/computer-use-demo/tests/streamlit_test.py +++ b/computer-use-demo/tests/streamlit_test.py @@ -1,7 +1,7 @@ from unittest import mock -from anthropic.types import TextBlockParam import pytest +from anthropic.types import TextBlockParam from streamlit.testing.v1 import AppTest from computer_use_demo.streamlit import Sender @@ -19,6 +19,9 @@ def test_streamlit(streamlit_app: AppTest): streamlit_app.chat_input[0].set_value("Hello").run() assert patch.called assert patch.call_args.kwargs["messages"] == [ - {"role": Sender.USER, "content": [TextBlockParam(text="Hello", type="text")]} + { + "role": Sender.USER, + "content": [TextBlockParam(text="Hello", type="text")], + } ] assert not streamlit_app.exception