-
Notifications
You must be signed in to change notification settings - Fork 133
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds a fun image telephone app to demonstrate the hub
This corresponds to a streamlit app (image-telephone.streamlit.app). It runs hamilton dataflows in a recursive loop between gpt-4 and dallE-3. It saves the results and displays them on s3.
- Loading branch information
1 parent
2f7236f
commit 681eaa3
Showing
6 changed files
with
1,323 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Image Telephone | ||
|
||
See the [streamlit app](https://image-telephone.streamlit.app) for documentation. | ||
|
||
This example uses dataflows from the hub to do something fun with image captioning and generation. | ||
Note that hamilton code is used rather than defined here. | ||
|
||
|
||
# Contents | ||
|
||
There are two files in this: | ||
|
||
1. generate_images.ipynb | ||
2. streamlit.py | ||
|
||
The first is a notebook that generates images and captions, and the second is a streamlit app that displays them. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import dataclasses | ||
import io | ||
import json | ||
import logging | ||
from typing import Any, Collection, Dict, Type | ||
from urllib import parse | ||
|
||
import boto3 | ||
import requests | ||
from PIL import Image | ||
|
||
from hamilton.io.data_adapters import DataSaver | ||
from hamilton.registry import register_adapter | ||
|
||
client = boto3.client("s3") | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@dataclasses.dataclass | ||
class JSONS3DataSaver(DataSaver): | ||
bucket: str | ||
key: str | ||
|
||
def save_data(self, data: dict) -> Dict[str, Any]: | ||
data = json.dumps(data).encode() | ||
client.put_object(Body=data, Bucket=self.bucket, Key=self.key) | ||
|
||
@classmethod | ||
def applicable_types(cls) -> Collection[Type]: | ||
return [dict] | ||
|
||
@classmethod | ||
def name(cls) -> str: | ||
return "json_s3" | ||
|
||
|
||
def _load_image(uri: str, format: str) -> Image: | ||
parsed = parse.urlparse(uri) | ||
if parsed.scheme.strip() == "": # local file to upload | ||
with open(uri, "rb") as f: | ||
data = f.read() | ||
elif parsed.scheme.strip() in ("https", "http"): # URL to copy over | ||
response = requests.get(uri) | ||
data = response.content | ||
image = Image.open(io.BytesIO(data)) | ||
if format in ("jpeg", "jpg"): # TODO -- add more formats if they don't support it | ||
if image.mode in ("RGBA", "P"): | ||
image = image.convert("RGB") | ||
return image | ||
|
||
|
||
@dataclasses.dataclass | ||
class ImageS3DataSaver(DataSaver): | ||
bucket: str | ||
key: str | ||
format: str | ||
# image_convert_params: Optional[Dict[str, Any]] = None | ||
|
||
def save_data(self, data: str) -> Dict[str, Any]: | ||
image = _load_image(data, self.format) | ||
in_mem_file = io.BytesIO() | ||
image.save(in_mem_file, format=self.format) | ||
in_mem_file.seek(0) | ||
client.put_object(Body=in_mem_file, Bucket=self.bucket, Key=self.key) | ||
return {"key": self.key, "bucket": self.bucket} | ||
|
||
@classmethod | ||
def applicable_types(cls) -> Collection[Type]: | ||
return [str] # URL or local path | ||
|
||
@classmethod | ||
def name(cls) -> str: | ||
return "image_s3" | ||
|
||
|
||
@dataclasses.dataclass | ||
class LocalImageSaver(DataSaver): | ||
path: str | ||
format: str | ||
# image_convert_params: Optional[Dict[str, Any]] = dataclasses.field(default_factory=dict) | ||
|
||
def save_data(self, data: str) -> Dict[str, Any]: | ||
image = _load_image(data, self.format) | ||
image.save(self.path, format=self.format) | ||
return {"path": self.path} | ||
|
||
@classmethod | ||
def applicable_types(cls) -> Collection[Type]: | ||
return [str] # URL or local path | ||
|
||
@classmethod | ||
def name(cls) -> str: | ||
return "image" | ||
|
||
|
||
adapters = [JSONS3DataSaver, ImageS3DataSaver, LocalImageSaver] | ||
|
||
for adapter in adapters: | ||
register_adapter(adapter) |
246 changes: 246 additions & 0 deletions
246
examples/LLM_Workflows/image_telephone/image_telephone.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "3f76f071-0090-4d11-89e9-4f07c73bd405", | ||
"metadata": {}, | ||
"source": [ | ||
"# Image Telephone" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "32fe92a2-cbd6-405b-a15f-b66dd13f2526", | ||
"metadata": {}, | ||
"source": [ | ||
"# Environment management\n", | ||
"- import relevant modules\n", | ||
"- make environment assertions\n", | ||
"- set up variables for dataflow to use\n", | ||
"\n", | ||
"You'll want to tune the variables for later use (the image, the s3 bucket/data directory...)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "ea9190a8-bac5-4473-8cc8-7205108746d9", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import sys\n", | ||
"import urllib\n", | ||
"from io import BytesIO\n", | ||
"from typing import Tuple\n", | ||
"\n", | ||
"import boto3\n", | ||
"from PIL import Image\n", | ||
"from tenacity import retry, stop_after_delay\n", | ||
"from run import determine_state\n", | ||
"\n", | ||
"from hamilton import driver\n", | ||
"from hamilton.io.materialization import to\n", | ||
"\n", | ||
"assert \"OPENAI_API_KEY\" in os.environ, \"Must have OpenAI key set for this to work!\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "aead7fdb-7544-46bc-ab53-d1e19f54a254", | ||
"metadata": {}, | ||
"source": [ | ||
"# Ensure your state is correct\n", | ||
"\n", | ||
"You'll want your initial image in png format -- change the `INITIAL_IMAGE_PATH` to specify it!" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "ca4fbc12-c070-4a38-9ef7-6f13376128e1", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"STORAGE_ENGINE = \"local\" # s3 or local\n", | ||
"S3_BUCKET = \"dagworks-image-telephone\" # TODO -- put your bucket\n", | ||
"\n", | ||
"DATA_DIR = \"./results\" # For local mode, unset for now\n", | ||
"\n", | ||
"\n", | ||
"INITIAL_IMAGE_PATH = \"./seed_images/test_wikipedia_image_20231213.png\"\n", | ||
"UNIQUE_IMAGE_NAME = \"test_wikipedia_image_20231213\"\n", | ||
"NUM_ITERATIONS = 3\n", | ||
"\n", | ||
"DESCRIPTIVENESS = \"obsessively\"" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "1dae8592-3f87-4cb4-8cd6-4aeff9eed7b8", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"if STORAGE_ENGINE == \"s3\":\n", | ||
" assert S3_BUCKET is not None, \"Must provide S3_BUCKET to use S3\"\n", | ||
"\n", | ||
"if STORAGE_ENGINE == \"local\":\n", | ||
" assert DATA_DIR is not None, \"Must provide data directory for results when using local mode\"\n", | ||
" \n", | ||
"BASE_SAVE_LOCATION = os.path.join(DATA_DIR, UNIQUE_IMAGE_NAME) if STORAGE_ENGINE == \"local\" else os.path.join(f\"s3://{S3_BUCKET}/{UNIQUE_IMAGE_NAME}\")\n", | ||
"\n", | ||
"if STORAGE_ENGINE == \"local\":\n", | ||
" if not os.path.exists(BASE_SAVE_LOCATION):\n", | ||
" os.makedirs(BASE_SAVE_LOCATION)\n", | ||
" " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "bf9aade2-28db-4373-bf51-6409e598965a", | ||
"metadata": {}, | ||
"source": [ | ||
"# Pull dataflows from the Hub\n", | ||
"\n", | ||
"These two dataflows have everything we need to play image telephone. We're going to download two dataflows:\n", | ||
"\n", | ||
"1. `caption_images` -- this has the ability to provide a caption given an image\n", | ||
"2. `generate_images` -- this has the ability to generate an image, given a caption\n", | ||
"\n", | ||
"We use the hub API to download the modules, then do a quick visualization to ensure we're happy with what we've got. We've combined these into the same driver, although one could easily run two drivers. The DAG's are actually independent" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "51534640-431e-432f-a470-0ca604916d7d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from hamilton import dataflows\n", | ||
"caption_images = dataflows.import_module(\"caption_images\", \"elijahbenizzy\")\n", | ||
"generate_images = dataflows.import_module(\"generate_images\", \"elijahbenizzy\")\n", | ||
"import caption_images\n", | ||
"import generate_images\n", | ||
"dr = driver.Driver({\"include_embeddings\" : True}, caption_images, generate_images)\n", | ||
"dr.display_all_functions(orient=\"TB\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "f3bc4ef0-b86c-4a91-b5c6-7bc1e4400741", | ||
"metadata": {}, | ||
"source": [ | ||
"# Define our Capabilities (chains)\n", | ||
"\n", | ||
"We define some pretty basic functions that allow us to run components of the DAG. We'll be running these in a loop, displaying the results in-between to track progress. We do two calls to `.materialize(...)` -- this allows us to run/execute the DAG.\n", | ||
"\n", | ||
"1. Generate captions\n", | ||
"2. Generate images\n", | ||
"\n", | ||
"We then update the state, and run again!" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "5b808f9b-f8e5-4128-b64a-60f132c3c8c2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# This allows execution to start where it left off\n", | ||
"iteration, image_url, has_original = determine_state(\n", | ||
" INITIAL_IMAGE_PATH,\n", | ||
" STORAGE_ENGINE,\n", | ||
" UNIQUE_IMAGE_NAME,\n", | ||
" {\"base_dir\": DATA_DIR, \"s3_bucket\": S3_BUCKET}\n", | ||
")\n", | ||
"\n", | ||
"# Loop until we're there\n", | ||
"while iteration < NUM_ITERATIONS:\n", | ||
" print(f\" Beginning iteration: {iteration} with image URL: {image_url}\")\n", | ||
" metadata_save_path = os.path.join(BASE_SAVE_LOCATION, f\"metadata_{iteration}.json\")\n", | ||
" # Run the caption generation step\n", | ||
" _, results = dr.materialize(\n", | ||
" to.json(\n", | ||
" path=metadata_save_path,\n", | ||
" dependencies=[\"metadata\"],\n", | ||
" id=\"save_metadata\",\n", | ||
" ),\n", | ||
" *([] if has_original else [\n", | ||
" to.image(\n", | ||
" path=os.path.join(BASE_SAVE_LOCATION, f\"{UNIQUE_IMAGE_NAME}/original.png\"),\n", | ||
" dependencies=[\"image_url\"],\n", | ||
" id=f\"save_original_image\",\n", | ||
" format=\"png\",\n", | ||
" )\n", | ||
" ]),\n", | ||
" additional_vars=[\"generated_caption\"],\n", | ||
" inputs={\n", | ||
" \"image_url\" : image_url,\n", | ||
" \"descriptiveness\" : DESCRIPTIVENESS,\n", | ||
" \"additional_metadata\" : {\n", | ||
" \"descriptiveness\" : DESCRIPTIVENESS,\n", | ||
" \"iteration\" : iteration,\n", | ||
" }\n", | ||
" }\n", | ||
" )\n", | ||
"\n", | ||
" generated_caption = results[\"generated_caption\"]\n", | ||
" print(f\"Captioned image: {image_url} with caption: {generated_caption}. \\n\\n Saved metadata (caption + embeddings) at: {metadata_save_path}\")\n", | ||
" image_save_path = os.path.join(BASE_SAVE_LOCATION, f\"image_{iteration}.png\")\n", | ||
"\n", | ||
" # Run the image generation step\n", | ||
" _, results = dr.materialize(\n", | ||
" to.image(\n", | ||
" path=image_save_path,\n", | ||
" dependencies=[\"generated_image\"],\n", | ||
" id=f\"save_image\",\n", | ||
" format=\"png\",\n", | ||
" ),\n", | ||
" inputs={\"image_generation_prompt\" : generated_caption},\n", | ||
" additional_vars=[\"generated_image\"]\n", | ||
" )\n", | ||
" generated_caption = results[\"generated_caption\"]\n", | ||
" generated_image = results[\"generated_image\"]\n", | ||
" print(f\"Generated image, saved at: {image_save_path}\")\n", | ||
" iteration += 1\n", | ||
" image_url = image_save_path\n", | ||
" has_original = True\n", | ||
" with open(generated_image) as url:\n", | ||
" img = Image.open(BytesIO(url.read()))\n", | ||
" display(img)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b60679f8-0485-47a7-b7ce-63c3b26b6425", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
altair | ||
boto3 | ||
openai | ||
opentsne | ||
pandas | ||
pillow | ||
s3fs | ||
scikit-learn | ||
sf-hamilton | ||
sf-hamilton-contrib | ||
st-files-connection | ||
streamlit | ||
streamlit-super-slider | ||
tenacity |
Oops, something went wrong.