Skip to content

Commit

Permalink
fixed generate image
Browse files Browse the repository at this point in the history
  • Loading branch information
yoheinakajima committed Jan 24, 2025
1 parent f1ba4f7 commit 639e878
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 8 deletions.
4 changes: 2 additions & 2 deletions my_digital_being/activities/activity_draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
@activity(
name="draw",
energy_cost=0.6,
cooldown=7200, # 2 hours
cooldown=1000, # 2 hours
required_skills=["image_generation"],
)
class DrawActivity(ActivityBase):
def __init__(self):
super().__init__()
self.default_size = (512, 512)
self.default_size = (1024, 1024)
self.default_format = "png"

async def execute(self, shared_data) -> ActivityResult:
Expand Down
47 changes: 41 additions & 6 deletions my_digital_being/skills/skill_generate_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import logging
from typing import Dict, Any, Tuple
import random
import os
import openai
from openai import OpenAI
import asyncio
from framework.api_management import api_manager

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -38,7 +42,7 @@ async def can_generate(self) -> bool:
return True

async def generate_image(
self, prompt: str, size: Tuple[int, int] = (512, 512), format: str = "png"
self, prompt: str, size: Tuple[int, int] = (1024, 1024), format: str = "png"
) -> Dict[str, Any]:
"""Generate an image based on the prompt."""
if not await self.can_generate():
Expand All @@ -59,19 +63,48 @@ async def generate_image(
logger.error(error_msg)
return {"success": False, "error": error_msg}

# In a real implementation, this would use the API key to call the image generation service
logger.info(f"Generating image for prompt: {prompt}")
# Configure OpenAI with the retrieved API key
os.environ["OPENAI_API_KEY"] = api_key

client = OpenAI()

# Map the size tuple to OpenAI's expected string format
size_str = f"{size[0]}x{size[1]}"

logger.info(f"Generating image for prompt: {prompt} with size {size_str}")

# As OpenAI's library is synchronous, run it in a separate thread to avoid blocking
loop = asyncio.get_event_loop()
print(prompt)
print(size_str)
response = await loop.run_in_executor(
None,
lambda: client.images.generate(
model="dall-e-3",
prompt=prompt,
n=1,
size=size_str,
response_format="url", # You can change to "b64_json" if needed
),
)

# Extract the image URL from the response
image_url = response.data[0].url

# Increment counter only on successful generation
self.generations_count += 1

# Simulate generation result
# Generate a seed and generation_id for consistency with previous structure
seed = random.randint(1000, 9999)
generation_id = f"gen_{self.generations_count}"

image_data = {
"width": size[0],
"height": size[1],
"format": format,
"seed": random.randint(1000, 9999),
"generation_id": f"gen_{self.generations_count}",
"seed": seed,
"generation_id": generation_id,
"url": image_url, # Including the actual image URL from OpenAI
}

return {
Expand All @@ -82,6 +115,8 @@ async def generate_image(
"generation_number": self.generations_count,
},
}


except Exception as e:
logger.error(f"Failed to generate image: {e}")
return {"success": False, "error": str(e)}
Expand Down

0 comments on commit 639e878

Please sign in to comment.