Skip to content

Commit

Permalink
update prompt image resizing and improve example app
Browse files Browse the repository at this point in the history
  • Loading branch information
tahouse committed Nov 22, 2024
1 parent ec54609 commit 1aae049
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 78 deletions.
70 changes: 36 additions & 34 deletions example.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,31 @@
import base64
from dataclasses import dataclass
from io import BytesIO
import time
from typing import Any, Dict, List
from typing import List

import streamlit as st
from streamlit_chat_prompt import PromptReturn, prompt, ImageData
from PIL import Image

from streamlit_chat_prompt import ImageData, PromptReturn, prompt

st.title("streamlit-chat-prompt")


@dataclass
class ChatMessage:
role: str
content: str | PromptReturn


if "messages" not in st.session_state:
messages: List[Dict[str, str | PromptReturn]] = [
{"role": "assistant", "content": "Hi there! What should we chat about?"}
messages: List[ChatMessage] = [
ChatMessage(role="assistant", content="Hi there! What should we chat about?")
]
st.session_state.messages = messages

if "new_default_input" not in st.session_state:
st.session_state.new_default_input = None


@st.dialog("Prompt in dialog")
def test_dg(default_input: str | PromptReturn | None = None, key="default_dialog_key"):
def dialog(default_input: str | PromptReturn | None = None, key="default_dialog_key"):
dialog_input = prompt(
"dialog_prompt",
key=key,
Expand All @@ -30,16 +35,13 @@ def test_dg(default_input: str | PromptReturn | None = None, key="default_dialog
)
if dialog_input:
st.write(dialog_input)
st.session_state.new_default_input = dialog_input
time.sleep(2)
st.rerun()


with st.sidebar:
st.header("Sidebar")

if st.button("Dialog Prompt", key=f"dialog_prompt_button"):
test_dg()
dialog()

if st.button(
"Dialog Prompt with Default Value", key=f"dialog_prompt_with_default_button"
Expand All @@ -48,49 +50,49 @@ def test_dg(default_input: str | PromptReturn | None = None, key="default_dialog
image_data = f.read()
image = Image.open(BytesIO(image_data))
base64_image = base64.b64encode(image_data).decode("utf-8")
test_dg(
dialog(
default_input=PromptReturn(
message="This is a test message with an image",
text="This is a test message with an image",
images=[
ImageData(data=base64_image, type="image/png", format="base64")
],
),
key="dialog_with_default",
)

for message in st.session_state.messages:
message: Dict[str, str | PromptReturn]
role: str = message["role"]
content: str | PromptReturn = message["content"]
for chat_message in st.session_state.messages:
chat_message: ChatMessage

with st.chat_message(role):
if isinstance(content, PromptReturn):
st.markdown(content.message)
if content.images:
for image in content.images:
with st.chat_message(chat_message.role):
if isinstance(chat_message.content, PromptReturn):
st.markdown(chat_message.content.text)
if chat_message.content.images:
for image_data in chat_message.content.images:
st.divider()
image_data: bytes = base64.b64decode(image.data)
st.markdown("Ussng `st.image`")
st.image(Image.open(BytesIO(image_data)))
st.markdown("Using `st.markdown`")
st.markdown(
f"![Image example](data:{image_data.type};{image_data.format},{image_data.data})"
)

# or use markdown
# or use PIL
st.divider()
st.markdown("Using `st.markdown`")
st.markdown(f"![Hello World](data:image/png;base64, {image.data})")
st.markdown("Using `st.image`")
image = Image.open(BytesIO(base64.b64decode(image_data.data)))
st.image(image)

else:
st.markdown(content)
st.markdown(chat_message.content)

prompt_return: PromptReturn | None = prompt(
name="foo",
key="chat_prompt",
placeholder="Hi there! What should we chat about?",
main_bottom=True,
default=st.session_state.new_default_input,
)

if prompt_return:
st.session_state.messages.append({"role": "user", "content": prompt_return})
st.session_state.messages.append(ChatMessage(role="user", content=prompt_return))
st.session_state.messages.append(
{"role": "assistant", "content": f"Echo: {prompt_return.message}"}
ChatMessage(role="assistant", content=f"Echo: {prompt_return.text}")
)
st.rerun()
12 changes: 6 additions & 6 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@

setuptools.setup(
name="streamlit-chat-prompt",
version="0.1.7",
version="0.2.0",
author="Tyler House",
author_email="[email protected]",
description="Streamlit component that allows you to create a chat prompt with paste and image attachment support",
long_description=long_description,
long_description_content_type="text/markdown",
url="https://github.com/tahouse/streamlit-chat-prompt",
project_urls={
"Documentation": "https://github.com/tahouse/streamlit-chat-prompt/blob/main/README.md",
"Issue Tracker": "https://github.com/tahouse/streamlit-chat-prompt/issues",
},
project_urls={
"Documentation": "https://github.com/tahouse/streamlit-chat-prompt/blob/main/README.md",
"Issue Tracker": "https://github.com/tahouse/streamlit-chat-prompt/issues",
},
packages=setuptools.find_packages(),
include_package_data=True,
license="Apache-2.0",
Expand All @@ -31,7 +31,7 @@
"Intended Audience :: Developers",
"Topic :: Desktop Environment",
"Topic :: Multimedia :: Graphics",
"Topic :: Software Development :: User Interfaces"
"Topic :: Software Development :: User Interfaces",
],
python_requires=">=3.7",
install_requires=[
Expand Down
13 changes: 7 additions & 6 deletions streamlit_chat_prompt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class ImageData(BaseModel):


class PromptReturn(BaseModel):
message: Optional[str] = None
text: Optional[str] = None
images: Optional[List[ImageData]] = None


Expand Down Expand Up @@ -98,7 +98,8 @@ def prompt(
"""
# Convert string default to PromptReturn if needed
if isinstance(default, str):
default = PromptReturn(message=default)
default = PromptReturn(text=default)


if f"chat_prompt_{key}_prev_uuid" not in st.session_state:
st.session_state[f"chat_prompt_{key}_prev_uuid"] = None
Expand All @@ -112,7 +113,7 @@ def prompt(
f"data:{img.type};{img.format},{img.data}" for img in default.images
]
default_value = {
"message": default.message or "",
"text": default.text or "",
"images": images,
"uuid": None, # No UUID for default value
}
Expand Down Expand Up @@ -184,7 +185,7 @@ def prompt(
component_value
and component_value["uuid"] != st.session_state[f"chat_prompt_{key}_prev_uuid"]
):
# we have a new message
# we have a new prompt return
st.session_state[f"chat_prompt_{key}_prev_uuid"] = component_value["uuid"]
images = []
# Process any images
Expand All @@ -199,11 +200,11 @@ def prompt(
)
# print(len(image_data) / 1024 / 1024)

if not images and not component_value.get("message"):
if not images and not component_value.get("text"):
return None

return PromptReturn(
message=component_value.get("message"),
text=component_value.get("text"),
images=images,
)
else:
Expand Down
96 changes: 64 additions & 32 deletions streamlit_chat_prompt/frontend/src/Prompt.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import CloseIcon from "@mui/icons-material/Close"

interface State {
uuid: string
message: string
text: string
images: File[]
isFocused: boolean
disabled: boolean
Expand All @@ -40,7 +40,7 @@ class ChatInput extends StreamlitComponentBase<State> {
this.maxImageSize = this.props.args?.max_image_size || 1024 * 1024 * 5
this.state = {
uuid: "",
message: this.props.args?.default?.message || "",
text: this.props.args?.default?.text || "",
images: [],
isFocused: false,
disabled: this.props.args?.disabled || false,
Expand All @@ -53,7 +53,7 @@ class ChatInput extends StreamlitComponentBase<State> {

// Initialize state with default values if provided
const defaultValue = this.props.args["default"] || {
message: "",
text: "",
images: [],
}

Expand Down Expand Up @@ -143,14 +143,22 @@ class ChatInput extends StreamlitComponentBase<State> {
console.log(`Processing file: ${file.name}`)
const processedImage = await this.processImage(file)
if (processedImage) {
console.log("Successfully processed image:", {
name: file.name,
originalSize: `${(file.size / 1024 / 1024).toFixed(2)}MB`,
processedSize: `${(processedImage.size / 1024 / 1024).toFixed(2)}MB`,
})
this.setState((prevState) => ({
images: [...prevState.images, processedImage],
}))
const sizeCheck = await this.checkFileSize(processedImage)
if (sizeCheck.isValid) {
console.log("Successfully processed image:", {
name: file.name,
originalSize: `${(file.size / 1024 / 1024).toFixed(2)}MB`,
finalFileSize: `${(sizeCheck.fileSize / 1024 / 1024).toFixed(2)}MB`,
finalBase64Size: `${(sizeCheck.base64Size / 1024 / 1024).toFixed(
2
)}MB`,
})
this.setState((prevState) => ({
images: [...prevState.images, processedImage],
}))
} else {
throw new Error("Processed image still exceeds size limits")
}
} else {
console.log(`Failed to process image: ${file.name}`)
this.showNotification(
Expand Down Expand Up @@ -192,13 +200,48 @@ class ChatInput extends StreamlitComponentBase<State> {
this.focusTextField()
}

private async checkFileSize(file: File): Promise<{
isValid: boolean
fileSize: number
base64Size: number
}> {
// Get base64 size
const base64Size = await new Promise<number>((resolve) => {
const reader = new FileReader()
reader.onloadend = () => {
const base64String = reader.result as string
resolve(base64String.length)
}
reader.readAsDataURL(file)
})

const fileSize = file.size
const isValid =
base64Size <= this.maxImageSize && fileSize <= this.maxImageSize

console.log("File size check:", {
fileName: file.name,
fileSize: `${(fileSize / 1024 / 1024).toFixed(2)}MB`,
base64Size: `${(base64Size / 1024 / 1024).toFixed(2)}MB`,
maxSize: `${(this.maxImageSize / 1024 / 1024).toFixed(2)}MB`,
isValid,
})

return {
isValid,
fileSize,
base64Size,
}
}
async processImage(file: File): Promise<File | null> {
console.log(`Processing image: ${file.name}`, {
originalSize: `${(file.size / 1024 / 1024).toFixed(2)}MB`,
type: file.type,
})

if (file.size <= this.maxImageSize) {
// Check if the original file is already small enough
const initialSizeCheck = await this.checkFileSize(file)
if (initialSizeCheck.isValid) {
console.log("Image already under size limit, returning original")
return file
}
Expand All @@ -221,12 +264,8 @@ class ChatInput extends StreamlitComponentBase<State> {
for (const quality of [1.0, 0.9, 0.8, 0.7]) {
console.log(`Trying compression only with quality=${quality}`)
const result = await this.compressImage(img, quality, 1.0)
console.log("Compression result:", {
quality,
size: `${(result.size / 1024 / 1024).toFixed(2)}MB`,
})

if (result.size <= this.maxImageSize) {
const sizeCheck = await this.checkFileSize(result)
if (sizeCheck.isValid) {
console.log("Successfully compressed without scaling")
return result
}
Expand All @@ -239,15 +278,8 @@ class ChatInput extends StreamlitComponentBase<State> {
`Trying scaling with scale=${scale.toFixed(2)} and quality=0.8`
)
const result = await this.compressImage(img, 0.8, scale)
console.log("Scaling result:", {
scale: scale.toFixed(2),
dimensions: `${Math.round(img.width * scale)}x${Math.round(
img.height * scale
)}`,
size: `${(result.size / 1024 / 1024).toFixed(2)}MB`,
})

if (result.size <= this.maxImageSize) {
const sizeCheck = await this.checkFileSize(result)
if (sizeCheck.isValid) {
console.log("Successfully compressed with scaling")
return result
}
Expand Down Expand Up @@ -327,7 +359,7 @@ class ChatInput extends StreamlitComponentBase<State> {
handleSubmit() {
if (this.state.disabled) return

if (!this.state.message && this.state.images.length === 0) return
if (!this.state.text && this.state.images.length === 0) return

const imagePromises = this.state.images.map((image) => {
return new Promise((resolve) => {
Expand All @@ -340,12 +372,12 @@ class ChatInput extends StreamlitComponentBase<State> {
Promise.all(imagePromises).then((imageData) => {
Streamlit.setComponentValue({
uuid: crypto.randomUUID(),
message: this.state.message,
text: this.state.text,
images: imageData,
})
this.setState({
uuid: "",
message: "",
text: "",
images: [],
})
})
Expand Down Expand Up @@ -463,8 +495,8 @@ class ChatInput extends StreamlitComponentBase<State> {
maxRows={11}
fullWidth
disabled={disabled}
value={this.state.message}
onChange={(e) => this.setState({ message: e.target.value })}
value={this.state.text}
onChange={(e) => this.setState({ text: e.target.value })}
onKeyDown={this.handleKeyDown}
placeholder={this.props.args["placeholder"]}
variant="standard"
Expand Down

0 comments on commit 1aae049

Please sign in to comment.