Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| import os | |
| import sys | |
| import tempfile | |
| import zipfile | |
| import json | |
| from pathlib import Path | |
| from typing import Optional, Tuple | |
| import uuid | |
| from datetime import datetime | |
| import traceback | |
| import gradio as gr | |
# Put the directory that contains this file at the front of sys.path so the
# `src.*` imports below are looked up there first.
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))

try:
    from src.state.poster_state import create_state
    from src.workflow.pipeline import create_workflow_graph
except ImportError as e:
    # The pipeline modules are mandatory: report the failure and exit with a
    # non-zero status instead of continuing with missing names.
    print(f"Error importing modules: {e}")
    sys.exit(1)
def set_temp_api_keys(anthropic_key, openai_key, anthropic_base_url=None, openai_base_url=None):
    """Export the given keys / base URLs to ``os.environ`` and return an undo callback.

    Only arguments that are non-empty after stripping whitespace are exported;
    the stripped text is what gets stored. For every variable that is
    overwritten, the value it held beforehand (or its absence) is remembered.

    Returns:
        A zero-argument function that puts each touched variable back to its
        previous value, removing it if it did not exist before.
    """
    candidates = (
        ("ANTHROPIC_API_KEY", anthropic_key),
        ("OPENAI_API_KEY", openai_key),
        ("ANTHROPIC_BASE_URL", anthropic_base_url),
        ("OPENAI_BASE_URL", openai_base_url),
    )

    previous = {}
    for env_name, raw_value in candidates:
        if not raw_value:
            continue
        stripped = raw_value.strip()
        if not stripped:
            continue
        # Remember what was there (None means "unset") before overwriting it.
        previous[env_name] = os.environ.get(env_name)
        os.environ[env_name] = stripped

    def cleanup():
        """Restore every variable touched above to its earlier state."""
        for env_name, old_value in previous.items():
            if old_value is not None:
                os.environ[env_name] = old_value
            else:
                os.environ.pop(env_name, None)

    return cleanup
# Model identifiers offered by the text/vision dropdowns. validate_inputs
# rejects any model name that is not listed here, and the first entry is the
# default selection for both dropdowns.
AVAILABLE_MODELS = [
    "claude-sonnet-4-20250514",
    "gpt-4o-2024-08-06",
    "gpt-4.1-2025-04-14",
    "gpt-4.1-mini-2025-04-14"
]
def create_job_directory() -> Path:
    """Create and return a fresh temporary working directory for one job.

    The directory name starts with ``job_<YYYYmmdd_HHMMSS>_<8 hex chars>_``
    followed by the random suffix that ``tempfile.mkdtemp`` appends.
    """
    stamp = f"{datetime.now():%Y%m%d_%H%M%S}"
    # The first eight characters of a UUID's text form are its first eight hex digits.
    short_id = uuid.uuid4().hex[:8]
    prefix = "job_" + stamp + "_" + short_id + "_"
    return Path(tempfile.mkdtemp(prefix=prefix))
def validate_inputs(pdf_file, logo_file, aff_logo_file, text_model, vision_model, poster_width, poster_height, anthropic_key, openai_key, anthropic_base_url, openai_base_url):
    """Check the form inputs and return the first problem found.

    Checks run in this order: required uploads, model names against
    ``AVAILABLE_MODELS``, presence of at least one API key, a matching key for
    each selected model family, poster dimensions / aspect ratio, and finally
    the PDF file extension when the upload exposes a ``name``.

    Args:
        pdf_file, logo_file, aff_logo_file: Uploaded files; any falsy value
            counts as "missing".
        text_model, vision_model: Model identifiers chosen in the UI.
        poster_width, poster_height: Poster size; width / height must lie in
            the inclusive range 1.4-2.0.
        anthropic_key, openai_key: API keys; blank or whitespace-only counts
            as absent.
        anthropic_base_url, openai_base_url: Accepted for signature parity
            with the caller; not inspected here.

    Returns:
        A human-readable error message, or ``None`` when every check passes.
    """
    if not pdf_file:
        return "Please upload PDF paper"
    if not logo_file:
        return "Please upload conference logo"
    if not aff_logo_file:
        return "Please upload affiliation logo"

    if text_model not in AVAILABLE_MODELS:
        return f"Invalid text model: {text_model}"
    if vision_model not in AVAILABLE_MODELS:
        return f"Invalid vision model: {vision_model}"

    # Check API keys
    has_anthropic = bool(anthropic_key and anthropic_key.strip())
    has_openai = bool(openai_key and openai_key.strip())
    if not has_anthropic and not has_openai:
        return "Please provide at least one API key (Anthropic or OpenAI)"

    # Check if selected models have corresponding API keys (text model first,
    # then vision model, so the reported error order is unchanged).
    for model in (text_model, vision_model):
        if model.startswith("claude") and not has_anthropic:
            return "Anthropic API key required for Claude models"
        if model.startswith("gpt") and not has_openai:
            return "OpenAI API key required for GPT models"

    # Dimensions may arrive as something other than a usable number
    # (presumably None when a numeric field is emptied - confirm against the
    # UI component). Report that as a validation message instead of letting
    # the division below raise TypeError / ZeroDivisionError.
    try:
        width = float(poster_width)
        height = float(poster_height)
    except (TypeError, ValueError):
        return "Poster width and height must be numbers"
    if width <= 0 or height <= 0:
        return "Poster width and height must be positive"

    ratio = width / height
    # Written as a chained comparison so that a NaN ratio is rejected too.
    if not 1.4 <= ratio <= 2.0:
        return f"Poster ratio {ratio:.2f} out of range (1.4-2.0)"

    # Check file type - only possible when the upload carries a name attribute
    if hasattr(pdf_file, 'name') and not pdf_file.name.lower().endswith('.pdf'):
        return "Paper must be PDF format"

    return None
def _read_upload(upload):
    """Return an upload's payload: ``upload.read()`` if it is file-like, else the value itself."""
    return upload.read() if hasattr(upload, 'read') else upload


def generate_poster(pdf_file, logo_file, aff_logo_file, text_model, vision_model, poster_width, poster_height, anthropic_key, openai_key, anthropic_base_url, openai_base_url, progress=gr.Progress()):
    """Run the poster pipeline for one submission and package its output.

    Steps: validate the inputs, export the API keys / base URLs to the
    environment, write the three uploads into a new job directory, build and
    invoke the workflow graph, then zip every file found under the workflow's
    ``output_dir`` into the job directory.

    Args:
        pdf_file, logo_file, aff_logo_file: Uploads, either raw bytes or
            objects with a ``read()`` method.
        text_model, vision_model: Model identifiers passed to ``create_state``.
        poster_width, poster_height: Poster size; truncated to ``int`` before
            being passed to ``create_state``.
        anthropic_key, openai_key, anthropic_base_url, openai_base_url:
            Credentials / endpoints exported for the duration of the call.
        progress: Progress callback, invoked as ``progress(fraction, desc=...)``.

    Returns:
        ``(zip_path, message)`` on success, where ``zip_path`` is a string, or
        ``(None, message)`` when validation fails, the workflow reports
        errors, or an exception is raised (the message then includes the
        traceback).
    """
    # Stays None until the environment has actually been modified, so the
    # finally clause never calls an unbound name if something fails earlier.
    cleanup_api_keys = None
    try:
        # Validation does not read the environment, so run it before any key
        # is exported.
        error_msg = validate_inputs(pdf_file, logo_file, aff_logo_file, text_model, vision_model, poster_width, poster_height, anthropic_key, openai_key, anthropic_base_url, openai_base_url)
        if error_msg:
            return None, f"β {error_msg}"

        # Set API keys temporarily
        # NOTE(review): os.environ is process-wide, so overlapping requests
        # would see each other's keys - not addressed here.
        cleanup_api_keys = set_temp_api_keys(anthropic_key, openai_key, anthropic_base_url, openai_base_url)

        progress(0.1, desc="Initializing...")
        job_dir = create_job_directory()
        pdf_path = job_dir / "paper.pdf"
        logo_path = job_dir / "logo.png"
        aff_logo_path = job_dir / "aff_logo.png"

        # Handle file writing - each upload may be a file object or bytes
        for destination, upload in (
            (pdf_path, pdf_file),
            (logo_path, logo_file),
            (aff_logo_path, aff_logo_file),
        ):
            with open(destination, "wb") as f:
                f.write(_read_upload(upload))

        progress(0.2, desc="Setting up workflow...")
        state = create_state(
            pdf_path=str(pdf_path),
            text_model=text_model,
            vision_model=vision_model,
            width=int(poster_width),
            height=int(poster_height),
            url="",
            logo_path=str(logo_path),
            aff_logo_path=str(aff_logo_path)
        )

        progress(0.3, desc="Compiling workflow...")
        graph = create_workflow_graph()
        workflow = graph.compile()

        progress(0.5, desc="Processing paper...")
        final_state = workflow.invoke(state)

        progress(0.8, desc="Generating outputs...")
        if final_state.get("errors"):
            # str() each entry so non-string error objects cannot break the join.
            error_details = "; ".join(str(err) for err in final_state["errors"])
            return None, f"β Generation errors: {error_details}"

        output_dir = Path(final_state["output_dir"])
        poster_name = final_state.get("poster_name", "poster")

        progress(0.9, desc="Packaging results...")
        zip_path = job_dir / f"{poster_name}_output.zip"
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            if output_dir.exists():
                for file_path in output_dir.rglob("*"):
                    if file_path.is_file():
                        # Store paths relative to output_dir inside the archive.
                        arcname = file_path.relative_to(output_dir)
                        zipf.write(file_path, arcname)

        progress(1.0, desc="β Completed!")
        success_msg = f"""β Poster generation successful!
Poster: {poster_name}
Output: {output_dir.name}
Package: {zip_path.name}
Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"""
        return str(zip_path), success_msg
    except Exception as e:
        error_msg = f"β Error: {str(e)}\n\n{traceback.format_exc()}"
        return None, error_msg
    finally:
        # Single restoration point covering every exit path above.
        if cleanup_api_keys is not None:
            cleanup_api_keys()
# Build the single-page UI: three side-by-side panels (uploads, settings,
# results), a full-width generate button, and the click handler wiring.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a listing that had lost its indentation - in particular, confirm which
# widgets belong inside gr.Group() before merging.
with gr.Blocks(title="PosterGen", css="""
.gradio-column {
margin-left: 10px !important;
margin-right: 10px !important;
}
.gradio-column:first-child {
margin-left: 0 !important;
}
.gradio-column:last-child {
margin-right: 0 !important;
}
""") as demo:
    gr.HTML("""
<div style="text-align: center; margin-bottom: 30px;">
<h1 style="margin-bottom: 10px;">
PosterGen
</h1>
<p style="font-size: 18px; color: #666;">π¨ Generate design-aware academic posters from PDF papers</p>
</div>
""")
    with gr.Row():
        # Panel 1: the three uploads, delivered to the handler as raw bytes.
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("### π Upload Files")
            pdf_file = gr.File(label="PDF Paper", file_types=[".pdf"], type="binary")
            with gr.Row():
                logo_file = gr.File(label="Conference Logo", file_types=["image"], type="binary")
                aff_logo_file = gr.File(label="Affiliation Logo", file_types=["image"], type="binary")
        # Panel 2: credentials, model selection and poster dimensions.
        with gr.Column(scale=1, variant="panel"):
            with gr.Group():
                gr.Markdown("### π API Keys")
                gr.Markdown("β οΈ Keys are processed securely and not stored")
                with gr.Row():
                    anthropic_key = gr.Textbox(
                        label="Anthropic API Key",
                        type="password",
                        placeholder="sk-ant-...",
                        info="Required for Claude models"
                    )
                    openai_key = gr.Textbox(
                        label="OpenAI API Key",
                        type="password",
                        placeholder="sk-...",
                        info="Required for GPT models"
                    )
                with gr.Row():
                    anthropic_base_url = gr.Textbox(
                        label="Anthropic Base URL (Optional)",
                        placeholder="https://api.anthropic.com",
                        info="Set the base url for compatible API services"
                    )
                    openai_base_url = gr.Textbox(
                        label="OpenAI Base URL (Optional)",
                        placeholder="https://api.openai.com/v1",
                        info="Set the base url for compatible API services"
                    )
            gr.Markdown("### π€ Model Settings")
            with gr.Row():
                text_model = gr.Dropdown(choices=AVAILABLE_MODELS, value=AVAILABLE_MODELS[0], label="Text Model")
                vision_model = gr.Dropdown(choices=AVAILABLE_MODELS, value=AVAILABLE_MODELS[0], label="Vision Model")
            gr.Markdown("### π Dimensions")
            with gr.Row():
                poster_width = gr.Number(value=54, minimum=20, maximum=100, step=0.1, label="Width (inches)")
                poster_height = gr.Number(value=36, minimum=10, maximum=60, step=0.1, label="Height (inches)")
        # Panel 3: status text and the downloadable result package.
        with gr.Column(scale=1, variant="panel"):
            gr.Markdown("### π Status")
            status_output = gr.Textbox(label="Generation Status", placeholder="Click 'Generate Poster' to start...", lines=6)
            gr.Markdown("### π₯ Download")
            download_file = gr.File(label="Download Package")
    # Generate button spanning full width
    with gr.Row():
        generate_btn = gr.Button("π Generate Poster", variant="primary", size="lg")

    # Register generate_poster itself rather than a *args wrapper: the
    # handler's own signature carries the progress=gr.Progress() parameter,
    # which a wrapper would hide from Gradio. The input order must match
    # generate_poster's positional parameters; the outputs receive its
    # (zip path, status message) return pair.
    generate_btn.click(
        fn=generate_poster,
        inputs=[pdf_file, logo_file, aff_logo_file, text_model, vision_model, poster_width, poster_height, anthropic_key, openai_key, anthropic_base_url, openai_base_url],
        outputs=[download_file, status_output]
    )
# Start the web server only when this file is executed as a script.
# server_name="0.0.0.0" listens on all network interfaces; the port is 7860.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)