# Hekaya7 / models / image_generation.py
# (uploaded by XA7 — initial commit "First", ab3578e)
import base64
import functools
import io
import os
import time
import warnings
from datetime import datetime

from google.generativeai import GenerativeModel
from openai import OpenAI
from PIL import Image

import config
# Silence the warning google.generativeai emits when Gemini blocks a result
# for image-safety reasons (it reports an unknown finish reason).
warnings.filterwarnings("ignore", message="IMAGE_SAFETY is not a valid FinishReason")
# Module-level state written by the generate_image_* functions below.
global_image_data_url = None
# global_image_data_url: "data:image/png;base64,..." URL of the last generated image.
global_image_prompt = None
# global_image_prompt: prompt used for the most recent generation attempt.
global_image_description = None
# global_image_description: never assigned in the visible code — presumably
# written by another module; confirm before removing.
def log_execution(func):
    """Decorator that reports wall-clock start/end time and duration of *func*.

    The original version computed the timestamps and duration but discarded
    them without logging anything; they are now printed. ``functools.wraps``
    preserves the wrapped function's name and docstring.

    Args:
        func (callable): Function to wrap.

    Returns:
        callable: Wrapper that times the call, prints the result, and returns
        whatever *func* returns.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        start_str = datetime.fromtimestamp(start_time).strftime('%Y-%m-%d %H:%M:%S')
        result = func(*args, **kwargs)
        end_time = time.time()
        end_str = datetime.fromtimestamp(end_time).strftime('%Y-%m-%d %H:%M:%S')
        duration = end_time - start_time
        # Actually emit the timing info (previously computed and thrown away).
        print(f"[{func.__name__}] started {start_str}, finished {end_str} ({duration:.2f}s)")
        return result
    return wrapper
@log_execution
def generate_image_fn_deprecated(selected_prompt, model="gpt-image-1", output_path="models/benchmark"):
    """
    Generate an image from the prompt via the OpenAI API using gpt-image-1.
    Convert the image to a data URL and optionally save it to a file.
    Args:
        selected_prompt (str): The prompt to generate the image from.
        model (str): Should be "gpt-image-1". Parameter kept for compatibility.
        output_path (str, optional): If provided, saves the image to this path.
            Defaults to "models/benchmark". (The previous default
            "models\\benchmark" contained a literal backspace: "\\b" is an
            escape sequence in a non-raw string.)
    Returns:
        PIL.Image.Image or None: The generated image as a PIL Image object, or None on error.
    """
    global global_image_data_url, global_image_prompt
    MAX_PROMPT_LENGTH = 32000
    if len(selected_prompt) > MAX_PROMPT_LENGTH:
        selected_prompt = smart_truncate_prompt(selected_prompt, MAX_PROMPT_LENGTH)
        print(f"Warning: Prompt was smartly truncated to {len(selected_prompt)} characters while preserving critical details")
    global_image_prompt = selected_prompt
    # Force the supported model regardless of the argument (kept for compatibility).
    model = "gpt-image-1"
    try:
        client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY", config.OPENAI_API_KEY))
        api_params = {
            "model": model,
            "prompt": selected_prompt,
            "size": "1024x1536",
            "quality": "high",
            "moderation": "low"
        }
        result = client.images.generate(**api_params)
        # BUG FIX: the base64 payload was never extracted from the response,
        # so `image_base64` was an undefined name (NameError on every call).
        # gpt-image-1 returns the image as base64 in data[0].b64_json.
        image_base64 = result.data[0].b64_json
        image_bytes = base64.b64decode(image_base64)
        image = Image.open(io.BytesIO(image_bytes))
        if output_path:
            try:
                os.makedirs(os.path.dirname(output_path), exist_ok=True)
                with open(output_path, "wb") as f:
                    f.write(image_bytes)
                print(f"Successfully saved image to {output_path}")
            except Exception as e:
                # Saving is best-effort; generation still succeeds.
                print(f"Error saving image to {output_path}: {str(e)}")
        # Re-encode as PNG for the module-level data URL.
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_bytes = buffered.getvalue()
        img_b64 = base64.b64encode(img_bytes).decode("utf-8")
        global_image_data_url = f"data:image/png;base64,{img_b64}"
        print(f"Successfully generated image with prompt: {selected_prompt[:50]}...")
        return image
    except Exception as e:
        print(f"Error generating image: {str(e)}")
        return None
@log_execution
def generate_image_fn(selected_prompt, model="gemini-2.5-flash-image-preview", output_path="models/benchmark"):
    """
    Generate an image from the prompt via the Google Gemini API.
    Convert the image to a data URL and optionally save it to a file.
    Args:
        selected_prompt (str): The prompt to generate the image from.
        model (str): The Gemini model to use. Defaults to "gemini-2.5-flash-image-preview".
        output_path (str, optional): If provided, saves the image to this path. Defaults to "models/benchmark".
    Returns:
        PIL.Image.Image or None: The generated image as a PIL Image object, or None on error.
    """
    global global_image_data_url, global_image_prompt
    MAX_PROMPT_LENGTH = 32000
    if len(selected_prompt) > MAX_PROMPT_LENGTH:
        selected_prompt = smart_truncate_prompt(selected_prompt, MAX_PROMPT_LENGTH)
        print(f"Warning: Prompt was smartly truncated to {len(selected_prompt)} characters while preserving critical details")
    global_image_prompt = selected_prompt
    try:
        # NOTE(review): the redundant function-local re-imports of io/base64/os/
        # Image/GenerativeModel were removed — all are already imported at module
        # level (see top of file).
        # Initialize the Gemini model and generate content for the prompt.
        gemini_model = GenerativeModel(model)
        response = gemini_model.generate_content([selected_prompt])
        # Extract the generated image from the response parts.
        image = None
        image_bytes = None
        has_text_response = False
        for part in response.candidates[0].content.parts:
            # Check for text responses (ignore these).
            if hasattr(part, 'text') and part.text:
                has_text_response = True
                print(f"Ignoring text response from API: {part.text[:100]}...")
                continue
            # Look for image data.
            if hasattr(part, 'inline_data') and part.inline_data is not None:
                image_bytes = part.inline_data.data
                # Verify we have valid (non-empty) data.
                if not image_bytes or len(image_bytes) == 0:
                    print("Warning: inline_data.data is empty, skipping...")
                    continue
                # Try to parse the image; force a full load to catch truncated data.
                try:
                    img_io = io.BytesIO(image_bytes)
                    image = Image.open(img_io)
                    image.load()  # Force load to verify it's valid
                    print(f"Successfully loaded image: {len(image_bytes)} bytes")
                    break
                except Exception as img_error:
                    print(f"Invalid image data received, skipping: {img_error}")
                    continue
        # If we only got text and no image, return None.
        if image is None:
            if has_text_response:
                print("API returned text instead of image - skipping this response")
            else:
                print("No image data found in response")
            return None
        # Save image to file if output_path is provided (best-effort).
        if output_path:
            try:
                os.makedirs(os.path.dirname(output_path), exist_ok=True)
                # Ensure output_path has an image extension.
                if not output_path.lower().endswith(('.png', '.jpg', '.jpeg')):
                    output_path = f"{output_path}.png"
                image.save(output_path)
                print(f"Successfully saved image to {output_path}")
            except Exception as e:
                print(f"Error saving image to {output_path}: {str(e)}")
        # Create a data URL for the image (always re-encoded as PNG).
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        img_bytes = buffered.getvalue()
        img_b64 = base64.b64encode(img_bytes).decode("utf-8")
        global_image_data_url = f"data:image/png;base64,{img_b64}"
        print(f"Successfully generated image with prompt: {selected_prompt[:50]}...")
        return image
    except Exception as e:
        print(f"Error generating image: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
@log_execution
def smart_truncate_prompt(prompt, max_length):
    """
    Truncate *prompt* to roughly *max_length* characters while preserving the
    sections most important for visual consistency.

    The prompt is treated as a " || "-separated list of sections.  Sections
    containing one of the first eight "critical" markers are always kept;
    other sections are kept while they fit within the budget (200 characters
    are reserved for the closing mandate).  At most one overflowing section is
    shortened with a "..." suffix before scanning stops.

    Args:
        prompt (str): Full " || "-separated prompt.
        max_length (int): Character budget to aim for.

    Returns:
        str: The original prompt if it already fits, otherwise a reduced one.
    """
    if len(prompt) <= max_length:
        return prompt
    # Known section markers.  Only the first eight participate in the
    # "always keep" check below; the remainder document the other headers.
    critical_sections = [
        "CRITICAL LAYOUT:",
        "🎭 CRITICAL CHARACTER CONSISTENCY PROTOCOL:",
        "CHARACTER 1",
        "CHARACTER 2",
        "CHARACTER 3",
        "STORY CONTENT:",
        "πŸ—οΈ ENVIRONMENTAL CONSISTENCY PROTOCOL:",
        "🎨 COMIC BOOK STYLE MASTERY:",
        "🎨 AUTHENTIC MANGA STYLE:",
        "🎨 PHOTOREALISTIC EXCELLENCE:",
        "🎨 CINEMATIC VISUAL MASTERY:",
        "🎨 HIGH-QUALITY ILLUSTRATION:",
        "πŸ“ PANEL COMPOSITION MASTERY:",
        "πŸ” DETAIL PRESERVATION PROTOCOL:",
        "⚑ ADVANCED QUALITY REQUIREMENTS:"
    ]
    kept = []
    used = 0
    for raw_section in prompt.split(" || "):
        section = raw_section.strip()
        if not section:
            continue
        must_keep = any(marker in section for marker in critical_sections[:8])
        # "+ 4" accounts for the " || " separator; 200 chars are held back
        # for the final mandate appended below.
        fits = used + len(section) + 4 < max_length - 200
        if must_keep or fits:
            kept.append(section)
            used += len(section) + 4
        elif used < max_length * 0.7:
            # Budget mostly unused: keep a shortened copy of this section
            # (if there is meaningful room), then stop scanning.
            room = max_length - used - 200
            if room > 100:
                kept.append(section[:room - 20] + "...")
            break
    result = " || ".join(kept)
    final_mandate = " || FINAL MANDATE: Create a masterpiece with perfect character consistency and narrative clarity"
    if len(result) + len(final_mandate) <= max_length:
        result += final_mandate
    return result