Skip to content

Commit

Permalink
Adds a fun image telephone app to demonstrate the hub
Browse files Browse the repository at this point in the history
This corresponds to a streamlit app (image-telephone.streamlit.app).
It runs hamilton dataflows in a recursive loop between gpt-4 and
dallE-3. It saves the results and displays them on s3.
  • Loading branch information
elijahbenizzy committed Dec 17, 2023
1 parent 2f7236f commit 681eaa3
Show file tree
Hide file tree
Showing 6 changed files with 1,323 additions and 0 deletions.
16 changes: 16 additions & 0 deletions examples/LLM_Workflows/image_telephone/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Image Telephone

See the [streamlit app](https://image-telephone.streamlit.app) for documentation.

This example uses dataflows from the hub to do something fun with image captioning and generation.
Note that the Hamilton dataflow code is pulled from the hub rather than defined in this example.


# Contents

There are two main files in this example:

1. image_telephone.ipynb
2. streamlit.py

The first is a notebook that generates images and captions, and the second is a streamlit app that displays them.
100 changes: 100 additions & 0 deletions examples/LLM_Workflows/image_telephone/adapters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import dataclasses
import io
import json
import logging
from typing import Any, Collection, Dict, Type
from urllib import parse

import boto3
import requests
from PIL import Image

from hamilton.io.data_adapters import DataSaver
from hamilton.registry import register_adapter

# Module-level S3 client, created once at import time and used by the S3 savers in this file.
client = boto3.client("s3")

# Logger named after this module.
logger = logging.getLogger(__name__)


@dataclasses.dataclass
class JSONS3DataSaver(DataSaver):
    """Data saver that serializes a dict to JSON and uploads it to S3.

    Registered under the name ``json_s3``.
    """

    bucket: str  # destination S3 bucket
    key: str  # destination object key

    def save_data(self, data: dict) -> Dict[str, Any]:
        """Upload ``data`` as a JSON document to ``bucket``/``key``.

        :param data: JSON-serializable dictionary to store.
        :return: the ``key`` and ``bucket`` the object was written to, matching
            the metadata returned by ``ImageS3DataSaver``.
        """
        payload = json.dumps(data).encode()
        client.put_object(Body=payload, Bucket=self.bucket, Key=self.key)
        # The declared return type is a metadata dict; previously this returned None.
        return {"key": self.key, "bucket": self.bucket}

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        """Return the types this saver accepts (dicts only)."""
        return [dict]

    @classmethod
    def name(cls) -> str:
        """Return the name this saver is registered under."""
        return "json_s3"


def _load_image(uri: str, format: str) -> Image:
parsed = parse.urlparse(uri)
if parsed.scheme.strip() == "": # local file to upload
with open(uri, "rb") as f:
data = f.read()
elif parsed.scheme.strip() in ("https", "http"): # URL to copy over
response = requests.get(uri)
data = response.content
image = Image.open(io.BytesIO(data))
if format in ("jpeg", "jpg"): # TODO -- add more formats if they don't support it
if image.mode in ("RGBA", "P"):
image = image.convert("RGB")
return image


@dataclasses.dataclass
class ImageS3DataSaver(DataSaver):
    """Data saver that loads an image from a URL or local path and uploads it to S3.

    The image is re-encoded in ``format`` before upload. Registered under the
    name ``image_s3``.
    """

    bucket: str  # destination S3 bucket
    key: str  # destination object key
    format: str  # image format passed to PIL when encoding
    # image_convert_params: Optional[Dict[str, Any]] = None

    def save_data(self, data: str) -> Dict[str, Any]:
        """Encode the image referenced by ``data`` in memory and put it in S3.

        :param data: URL or local path of the image to store.
        :return: the ``key`` and ``bucket`` the object was written to.
        """
        buffer = io.BytesIO()
        _load_image(data, self.format).save(buffer, format=self.format)
        buffer.seek(0)  # rewind so the upload reads from the start
        client.put_object(Body=buffer, Bucket=self.bucket, Key=self.key)
        return dict(key=self.key, bucket=self.bucket)

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        """Return the types this saver accepts: a string holding a URL or local path."""
        accepted = [str]
        return accepted

    @classmethod
    def name(cls) -> str:
        """Return the name this saver is registered under."""
        return "image_s3"


@dataclasses.dataclass
class LocalImageSaver(DataSaver):
    """Data saver that loads an image from a URL or local path and writes it to disk.

    Registered under the name ``image``.
    """

    path: str  # destination file path
    format: str  # image format passed to PIL when writing
    # image_convert_params: Optional[Dict[str, Any]] = dataclasses.field(default_factory=dict)

    def save_data(self, data: str) -> Dict[str, Any]:
        """Load the image referenced by ``data`` and save it at ``path``.

        :param data: URL or local path of the image to store.
        :return: a dict holding the ``path`` the image was written to.
        """
        destination = self.path
        _load_image(data, self.format).save(destination, format=self.format)
        return dict(path=destination)

    @classmethod
    def applicable_types(cls) -> Collection[Type]:
        """Return the types this saver accepts: a string holding a URL or local path."""
        accepted = [str]
        return accepted

    @classmethod
    def name(cls) -> str:
        """Return the name this saver is registered under."""
        return "image"


# Savers defined in this module; each one is registered at import time.
adapters = [JSONS3DataSaver, ImageS3DataSaver, LocalImageSaver]

# Importing this module is what performs the registration.
for adapter in adapters:
    register_adapter(adapter)
246 changes: 246 additions & 0 deletions examples/LLM_Workflows/image_telephone/image_telephone.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "3f76f071-0090-4d11-89e9-4f07c73bd405",
"metadata": {},
"source": [
"# Image Telephone"
]
},
{
"cell_type": "markdown",
"id": "32fe92a2-cbd6-405b-a15f-b66dd13f2526",
"metadata": {},
"source": [
"# Environment management\n",
"- import relevant modules\n",
"- make environment assertions\n",
"- set up variables for dataflow to use\n",
"\n",
"You'll want to tune the variables for later use (the image, the s3 bucket/data directory...)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ea9190a8-bac5-4473-8cc8-7205108746d9",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import sys\n",
"import urllib\n",
"from io import BytesIO\n",
"from typing import Tuple\n",
"\n",
"import boto3\n",
"from PIL import Image\n",
"from tenacity import retry, stop_after_delay\n",
"from run import determine_state\n",
"\n",
"from hamilton import driver\n",
"from hamilton.io.materialization import to\n",
"\n",
"assert \"OPENAI_API_KEY\" in os.environ, \"Must have OpenAI key set for this to work!\""
]
},
{
"cell_type": "markdown",
"id": "aead7fdb-7544-46bc-ab53-d1e19f54a254",
"metadata": {},
"source": [
"# Ensure your state is correct\n",
"\n",
"You'll want your initial image in png format -- change the `INITIAL_IMAGE_PATH` to specify it!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ca4fbc12-c070-4a38-9ef7-6f13376128e1",
"metadata": {},
"outputs": [],
"source": [
"STORAGE_ENGINE = \"local\" # s3 or local\n",
"S3_BUCKET = \"dagworks-image-telephone\" # TODO -- put your bucket\n",
"\n",
"DATA_DIR = \"./results\" # For local mode, unset for now\n",
"\n",
"\n",
"INITIAL_IMAGE_PATH = \"./seed_images/test_wikipedia_image_20231213.png\"\n",
"UNIQUE_IMAGE_NAME = \"test_wikipedia_image_20231213\"\n",
"NUM_ITERATIONS = 3\n",
"\n",
"DESCRIPTIVENESS = \"obsessively\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1dae8592-3f87-4cb4-8cd6-4aeff9eed7b8",
"metadata": {},
"outputs": [],
"source": [
"if STORAGE_ENGINE == \"s3\":\n",
" assert S3_BUCKET is not None, \"Must provide S3_BUCKET to use S3\"\n",
"\n",
"if STORAGE_ENGINE == \"local\":\n",
" assert DATA_DIR is not None, \"Must provide data directory for results when using local mode\"\n",
" \n",
"BASE_SAVE_LOCATION = os.path.join(DATA_DIR, UNIQUE_IMAGE_NAME) if STORAGE_ENGINE == \"local\" else os.path.join(f\"s3://{S3_BUCKET}/{UNIQUE_IMAGE_NAME}\")\n",
"\n",
"if STORAGE_ENGINE == \"local\":\n",
" if not os.path.exists(BASE_SAVE_LOCATION):\n",
" os.makedirs(BASE_SAVE_LOCATION)\n",
" "
]
},
{
"cell_type": "markdown",
"id": "bf9aade2-28db-4373-bf51-6409e598965a",
"metadata": {},
"source": [
"# Pull dataflows from the Hub\n",
"\n",
"These two dataflows have everything we need to play image telephone. We're going to download two dataflows:\n",
"\n",
"1. `caption_images` -- this has the ability to provide a caption given an image\n",
"2. `generate_images` -- this has the ability to generate an image, given a caption\n",
"\n",
"We use the hub API to download the modules, then do a quick visualization to ensure we're happy with what we've got. We've combined these into the same driver, although one could easily run two drivers. The DAGs are actually independent."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51534640-431e-432f-a470-0ca604916d7d",
"metadata": {},
"outputs": [],
"source": [
"from hamilton import dataflows\n",
"caption_images = dataflows.import_module(\"caption_images\", \"elijahbenizzy\")\n",
"generate_images = dataflows.import_module(\"generate_images\", \"elijahbenizzy\")\n",
"import caption_images\n",
"import generate_images\n",
"dr = driver.Driver({\"include_embeddings\" : True}, caption_images, generate_images)\n",
"dr.display_all_functions(orient=\"TB\")"
]
},
{
"cell_type": "markdown",
"id": "f3bc4ef0-b86c-4a91-b5c6-7bc1e4400741",
"metadata": {},
"source": [
"# Define our Capabilities (chains)\n",
"\n",
"We define some pretty basic functions that allow us to run components of the DAG. We'll be running these in a loop, displaying the results in-between to track progress. We do two calls to `.materialize(...)` -- this allows us to run/execute the DAG.\n",
"\n",
"1. Generate captions\n",
"2. Generate images\n",
"\n",
"We then update the state, and run again!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5b808f9b-f8e5-4128-b64a-60f132c3c8c2",
"metadata": {},
"outputs": [],
"source": [
"# This allows execution to start where it left off\n",
"iteration, image_url, has_original = determine_state(\n",
"    INITIAL_IMAGE_PATH,\n",
"    STORAGE_ENGINE,\n",
"    UNIQUE_IMAGE_NAME,\n",
"    {\"base_dir\": DATA_DIR, \"s3_bucket\": S3_BUCKET}\n",
")\n",
"\n",
"# Loop until we're there\n",
"while iteration < NUM_ITERATIONS:\n",
"    print(f\" Beginning iteration: {iteration} with image URL: {image_url}\")\n",
"    metadata_save_path = os.path.join(BASE_SAVE_LOCATION, f\"metadata_{iteration}.json\")\n",
"    # Run the caption generation step\n",
"    _, results = dr.materialize(\n",
"        to.json(\n",
"            path=metadata_save_path,\n",
"            dependencies=[\"metadata\"],\n",
"            id=\"save_metadata\",\n",
"        ),\n",
"        *([] if has_original else [\n",
"            to.image(\n",
"                # BASE_SAVE_LOCATION already ends with UNIQUE_IMAGE_NAME, so save directly into it\n",
"                path=os.path.join(BASE_SAVE_LOCATION, \"original.png\"),\n",
"                dependencies=[\"image_url\"],\n",
"                id=\"save_original_image\",\n",
"                format=\"png\",\n",
"            )\n",
"        ]),\n",
"        additional_vars=[\"generated_caption\"],\n",
"        inputs={\n",
"            \"image_url\" : image_url,\n",
"            \"descriptiveness\" : DESCRIPTIVENESS,\n",
"            \"additional_metadata\" : {\n",
"                \"descriptiveness\" : DESCRIPTIVENESS,\n",
"                \"iteration\" : iteration,\n",
"            }\n",
"        }\n",
"    )\n",
"\n",
"    generated_caption = results[\"generated_caption\"]\n",
"    print(f\"Captioned image: {image_url} with caption: {generated_caption}. \\n\\n Saved metadata (caption + embeddings) at: {metadata_save_path}\")\n",
"    image_save_path = os.path.join(BASE_SAVE_LOCATION, f\"image_{iteration}.png\")\n",
"\n",
"    # Run the image generation step\n",
"    _, results = dr.materialize(\n",
"        to.image(\n",
"            path=image_save_path,\n",
"            dependencies=[\"generated_image\"],\n",
"            id=\"save_image\",\n",
"            format=\"png\",\n",
"        ),\n",
"        inputs={\"image_generation_prompt\" : generated_caption},\n",
"        additional_vars=[\"generated_image\"]\n",
"    )\n",
"    # Only generated_image was requested from this run; the caption from the step above is still in scope\n",
"    generated_image = results[\"generated_image\"]\n",
"    print(f\"Generated image, saved at: {image_save_path}\")\n",
"    iteration += 1\n",
"    image_url = image_save_path\n",
"    has_original = True\n",
"    # Display the copy the materializer just wrote to image_save_path\n",
"    img = Image.open(image_save_path)\n",
"    display(img)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b60679f8-0485-47a7-b7ce-63c3b26b6425",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
14 changes: 14 additions & 0 deletions examples/LLM_Workflows/image_telephone/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
altair
boto3
openai
opentsne
pandas
pillow
s3fs
scikit-learn
sf-hamilton
sf-hamilton-contrib
st-files-connection
streamlit
streamlit-super-slider
tenacity
Loading

0 comments on commit 681eaa3

Please sign in to comment.