From 0d069d1c95cd31dd96d19ea61c85d7d6fd706e02 Mon Sep 17 00:00:00 2001 From: Kevin Ji Date: Tue, 22 Oct 2024 17:16:37 -0700 Subject: [PATCH] [computer-use-demo] Add prompt caching --- computer-use-demo/computer_use_demo/loop.py | 30 ++++++++++++++++--- .../computer_use_demo/streamlit.py | 2 +- computer-use-demo/tests/loop_test.py | 19 +++++++----- computer-use-demo/tests/streamlit_test.py | 5 ++-- 4 files changed, 42 insertions(+), 14 deletions(-) diff --git a/computer-use-demo/computer_use_demo/loop.py b/computer-use-demo/computer_use_demo/loop.py index bb959e4b..733fe0f6 100644 --- a/computer-use-demo/computer_use_demo/loop.py +++ b/computer-use-demo/computer_use_demo/loop.py @@ -8,7 +8,13 @@ from enum import StrEnum from typing import Any, cast -from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse +from anthropic import ( + Anthropic, + AnthropicBedrock, + AnthropicVertex, + APIResponse, + BaseModel, +) from anthropic.types import ( ToolResultBlockParam, ) @@ -24,8 +30,6 @@ from .tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult -BETA_FLAG = "computer-use-2024-10-22" - class APIProvider(StrEnum): ANTHROPIC = "anthropic" @@ -74,6 +78,7 @@ async def sampling_loop( api_key: str, only_n_most_recent_images: int | None = None, max_tokens: int = 4096, + prompt_caching: bool = True, ): """ Agentic sampling loop for the assistant/tool interaction of computer use. @@ -98,6 +103,23 @@ async def sampling_loop( elif provider == APIProvider.BEDROCK: client = AnthropicBedrock() + betas = ["computer-use-2024-10-22"] + if prompt_caching: + betas.append("prompt-caching-2024-07-31") + for message in messages: + if isinstance(message["content"], str): + continue + + params: list[BetaContentBlockParam] = [] + for content_block in message["content"]: + if isinstance(content_block, BaseModel): + content_block_param = content_block.to_dict() + else: + content_block_param = content_block + params.append(content_block_param) + content_block_param["cache_control"] = {"type": "ephemeral"} + message["content"] = params + # Call the API # we use raw_response to provide debug information to streamlit. Your # implementation may be able call the SDK directly with: @@ -108,7 +130,7 @@ async def sampling_loop( model=model, system=system, tools=tool_collection.to_params(), - betas=["computer-use-2024-10-22"], + betas=betas, ) api_response_callback(cast(APIResponse[BetaMessage], raw_response)) diff --git a/computer-use-demo/computer_use_demo/streamlit.py b/computer-use-demo/computer_use_demo/streamlit.py index 6750029c..97bb4fcb 100644 --- a/computer-use-demo/computer_use_demo/streamlit.py +++ b/computer-use-demo/computer_use_demo/streamlit.py @@ -194,7 +194,7 @@ def _reset_api_provider(): st.session_state.messages.append( { "role": Sender.USER, - "content": [TextBlock(type="text", text=new_message)], + "content": [BetaTextBlock(type="text", text=new_message)], } ) _render_message(Sender.USER, new_message) diff --git a/computer-use-demo/tests/loop_test.py b/computer-use-demo/tests/loop_test.py index 4985dbee..acce00ae 100644 --- a/computer-use-demo/tests/loop_test.py +++ b/computer-use-demo/tests/loop_test.py @@ -1,7 +1,11 @@ from unittest import mock -from anthropic.types import TextBlock, ToolUseBlock -from anthropic.types.beta import BetaMessage, BetaMessageParam +from anthropic.types.beta import ( + BetaMessage, + BetaMessageParam, + BetaTextBlock, + BetaToolUseBlock, +) from computer_use_demo.loop import APIProvider, sampling_loop @@ -13,13 +17,13 @@ async def test_loop(): mock.Mock( spec=BetaMessage, content=[ - TextBlock(type="text", text="Hello"), - ToolUseBlock( + BetaTextBlock(type="text", text="Hello"), + BetaToolUseBlock( type="tool_use", id="1", name="computer", input={"action": "test"} ), ], ), - mock.Mock(spec=BetaMessage, content=[TextBlock(type="text", text="Done!")]), + mock.Mock(spec=BetaMessage, content=[BetaTextBlock(type="text", text="Done!")]), ] tool_collection = mock.AsyncMock() @@ -49,7 +53,8 @@ async def test_loop(): ) assert len(result) == 4 - assert result[0] == {"role": "user", "content": "Test message"} + assert result[0]["role"] == "user" + assert result[0]["content"] == "Test message" assert result[1]["role"] == "assistant" assert result[2]["role"] == "user" assert result[3]["role"] == "assistant" @@ -58,7 +63,7 @@ async def test_loop(): tool_collection.run.assert_called_once_with( name="computer", tool_input={"action": "test"} ) - output_callback.assert_called_with(TextBlock(text="Done!", type="text")) + output_callback.assert_called_with(BetaTextBlock(text="Done!", type="text")) assert output_callback.call_count == 3 assert tool_output_callback.call_count == 1 assert api_response_callback.call_count == 2 diff --git a/computer-use-demo/tests/streamlit_test.py b/computer-use-demo/tests/streamlit_test.py index 25cd586b..8235a9dd 100644 --- a/computer-use-demo/tests/streamlit_test.py +++ b/computer-use-demo/tests/streamlit_test.py @@ -1,9 +1,10 @@ from unittest import mock import pytest +from anthropic.types.beta import BetaTextBlock from streamlit.testing.v1 import AppTest -from computer_use_demo.streamlit import Sender, TextBlock +from computer_use_demo.streamlit import Sender @pytest.fixture @@ -18,6 +19,6 @@ def test_streamlit(streamlit_app: AppTest): streamlit_app.chat_input[0].set_value("Hello").run() assert patch.called assert patch.call_args.kwargs["messages"] == [ - {"role": Sender.USER, "content": [TextBlock(text="Hello", type="text")]} + {"role": Sender.USER, "content": [BetaTextBlock(text="Hello", type="text")]} ] assert not streamlit_app.exception