Skip to content

Commit

Permalink
ruff?
Browse files Browse the repository at this point in the history
  • Loading branch information
nsmccandlish committed Oct 23, 2024
1 parent 18e497f commit 5f61360
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 21 deletions.
44 changes: 27 additions & 17 deletions computer-use-demo/computer_use_demo/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,15 @@

from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse
from anthropic.types.beta import (
BetaContentBlock,
BetaCacheControlEphemeralParam,
BetaContentBlockParam,
BetaImageBlockParam,
BetaMessage,
BetaMessageParam,
BetaTextBlock,
BetaTextBlockParam,
BetaToolResultBlockParam,
BetaToolUseBlockParam,
BetaTextBlock,
BetaToolUseBlock,
BetaCacheControlEphemeralParam,
)

from .tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult
Expand Down Expand Up @@ -89,9 +87,8 @@ async def sampling_loop(
)

while True:

enable_prompt_caching = False
betas=["computer-use-2024-10-22"]
betas = ["computer-use-2024-10-22"]
image_truncation_threshold = 10
if provider == APIProvider.ANTHROPIC:
client = Anthropic(api_key=api_key)
Expand All @@ -100,16 +97,20 @@ async def sampling_loop(
client = AnthropicVertex()
elif provider == APIProvider.BEDROCK:
client = AnthropicBedrock()

if enable_prompt_caching:
betas.append("prompt-caching-2024-07-31")
_inject_prompt_caching(messages)
# Is it ever worth it to bust the cache with prompt caching?
image_truncation_threshold = 50

if only_n_most_recent_images:
_maybe_filter_to_n_most_recent_images(messages, only_n_most_recent_images, min_removal_threshold=image_truncation_threshold)

_maybe_filter_to_n_most_recent_images(
messages,
only_n_most_recent_images,
min_removal_threshold=image_truncation_threshold,
)

# Call the API
# we use raw_response to provide debug information to streamlit. Your
# implementation may be able call the SDK directly with:
Expand All @@ -119,10 +120,12 @@ async def sampling_loop(
messages=messages,
model=model,
system=system,
tools=tool_collection.to_params(enable_prompt_caching=enable_prompt_caching),
betas=betas
tools=tool_collection.to_params(
enable_prompt_caching=enable_prompt_caching
),
betas=betas,
)

raw_response = cast(APIResponse[BetaMessage], raw_response)

api_response_callback(raw_response)
Expand Down Expand Up @@ -203,7 +206,10 @@ def _maybe_filter_to_n_most_recent_images(
new_content.append(content)
tool_result["content"] = new_content

def _response_to_params(response: BetaMessage) -> list[BetaTextBlockParam | BetaToolUseBlockParam]:

def _response_to_params(
response: BetaMessage,
) -> list[BetaTextBlockParam | BetaToolUseBlockParam]:
res: list[BetaTextBlockParam | BetaToolUseBlockParam] = []
for block in response.content:
if isinstance(block, BetaTextBlock):
Expand All @@ -212,6 +218,7 @@ def _response_to_params(response: BetaMessage) -> list[BetaTextBlockParam | Beta
res.append(cast(BetaToolUseBlockParam, block.model_dump()))
return res


def _inject_prompt_caching(
messages: list[BetaMessageParam],
):
Expand All @@ -221,20 +228,23 @@ def _inject_prompt_caching(
images in place, with a chunk of min_removal_threshold to reduce the amount we
break the implicit prompt cache.
"""

breakpoints_remaining = 3
for message in reversed(messages):
if message["role"] == "user" and isinstance(content := message["content"], list):
if message["role"] == "user" and isinstance(
content := message["content"], list
):
if breakpoints_remaining:
breakpoints_remaining -= 1
content[-1]["cache_control"] = BetaCacheControlEphemeralParam({"type": "ephemeral"})
content[-1]["cache_control"] = BetaCacheControlEphemeralParam(
{"type": "ephemeral"}
)
else:
content[-1].pop("cache_control", None)
# we'll only every have one extra turn per loop
break



def _make_api_tool_result(
result: ToolResult, tool_use_id: str
) -> BetaToolResultBlockParam:
Expand Down
7 changes: 5 additions & 2 deletions computer-use-demo/computer_use_demo/streamlit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
from pathlib import PosixPath
from typing import cast

from anthropic.types.beta.beta_tool_use_block_param import BetaToolUseBlockParam
import streamlit as st
from anthropic import APIResponse
from anthropic.types.beta import BetaContentBlockParam, BetaMessage, BetaTextBlockParam, BetaToolResultBlockParam, BetaToolUseBlock
from anthropic.types.beta import (
BetaContentBlockParam,
BetaMessage,
BetaTextBlockParam,
)
from streamlit.delta_generator import DeltaGenerator

from computer_use_demo.loop import (
Expand Down
7 changes: 5 additions & 2 deletions computer-use-demo/tests/streamlit_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from unittest import mock

from anthropic.types import TextBlockParam
import pytest
from anthropic.types import TextBlockParam
from streamlit.testing.v1 import AppTest

from computer_use_demo.streamlit import Sender
Expand All @@ -19,6 +19,9 @@ def test_streamlit(streamlit_app: AppTest):
streamlit_app.chat_input[0].set_value("Hello").run()
assert patch.called
assert patch.call_args.kwargs["messages"] == [
{"role": Sender.USER, "content": [TextBlockParam(text="Hello", type="text")]}
{
"role": Sender.USER,
"content": [TextBlockParam(text="Hello", type="text")],
}
]
assert not streamlit_app.exception

0 comments on commit 5f61360

Please sign in to comment.