Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| from copy import deepcopy | |
| import shutil | |
| import os | |
| from datetime import datetime | |
| import time | |
| import uuid | |
| import subprocess | |
| import gradio as gr | |
| import yaml | |
| import torch.multiprocessing as mp | |
| mp.set_start_method('spawn', force=True) | |
| from mm_story_agent import MMStoryAgent | |
# Install the bundled ImageMagick policy file into the ImageMagick 6 config
# directory (presumably to relax restrictions the video pipeline runs into —
# TODO confirm). This is best effort: a failure is reported but does not stop
# the app from starting, matching the original fire-and-forget `cp`.
try:
    shutil.copy("policy.xml", "/etc/ImageMagick-6/")
except OSError as err:
    print(f"Warning: could not install ImageMagick policy.xml: {err}")

# Load the default pipeline configuration once at import time.
# NOTE(review): yaml.FullLoader is kept as-is; the file is part of the
# project. Switch to yaml.safe_load if the config only uses plain YAML types.
with open("configs/mm_story_agent.yaml", "r", encoding="utf-8") as reader:
    config = yaml.load(reader, Loader=yaml.FullLoader)

# Per-section defaults used to pre-fill the UI controls.
default_story_setting = config["story_setting"]
default_story_gen_config = config["story_gen_config"]
default_slideshow_effect = config["slideshow_effect"]
default_image_config = config["image_generation"]
default_sound_config = config["sound_generation"]
default_music_config = config["music_generation"]
def set_generating_progress_text(text):
    """Return a Gradio update that shows ``text`` as a visible ``<h3>`` heading."""
    heading = "<h3>{}</h3>".format(text)
    return gr.update(visible=True, value=heading)
def set_text_invisible():
    """Return a Gradio update that hides the component it is applied to."""
    hidden = gr.update(visible=False)
    return hidden
def deep_update(original, updates):
    """Recursively merge ``updates`` into ``original`` in place.

    Dict values in ``updates`` are merged key by key into the dict already
    stored under the same key; every other value overwrites the existing one.

    Args:
        original: Mapping to modify.
        updates: Mapping of (possibly nested) overrides.

    Returns:
        ``original``, after modification.
    """
    for key, value in updates.items():
        if isinstance(value, dict):
            existing = original.get(key)
            # A missing key, ``None`` or a scalar cannot be merged into;
            # start from an empty dict so the nested update replaces it
            # instead of crashing.
            if not isinstance(existing, dict):
                existing = {}
            original[key] = deep_update(existing, value)
        else:
            original[key] = value
    return original
def update_page(direction, page, story_data):
    """Move the story viewer one page forward or backward.

    Args:
        direction: ``'next'`` or ``'prev'``; any other value keeps the page.
        page: Index of the page currently shown.
        story_data: Sequence of page contents.

    Returns:
        A tuple ``(new_page, page_content, story_data)``. The index is always
        clamped to the valid range, so a stale index (for example one left
        over from a longer, earlier story) cannot raise. When ``story_data``
        is empty the result is ``(0, "", story_data)``.
    """
    if not story_data:
        return 0, "", story_data

    if direction == 'next':
        page += 1
    elif direction == 'prev':
        page -= 1

    # Clamp instead of trusting the incoming index.
    page = max(0, min(page, len(story_data) - 1))
    return page, story_data[page], story_data
def _remove_stale_story_dirs(root, now, max_age_days=2):
    """Delete story directories under ``root`` whose date prefix is too old.

    Directory names are expected to start with ``YYYYMMDD``. Entries that are
    not directories, or whose name does not start with a valid date, are left
    alone rather than aborting the caller.
    """
    if not root.exists():
        return
    for story_dir in root.iterdir():
        if not story_dir.is_dir():
            continue
        try:
            story_date = datetime.strptime(story_dir.name[:8], '%Y%m%d')
        except ValueError:
            # Not one of our timestamped output directories.
            continue
        if (now - story_date).days >= max_age_days:
            # ignore_errors: another session may be deleting the same tree.
            shutil.rmtree(story_dir, ignore_errors=True)


def write_story_fn(story_topic, main_role, scene,
                   num_outline, temperature,
                   current_page,
                   config,
                   progress=gr.Progress(track_tqdm=True)):
    """Generate the story script and prepare the story viewer.

    Assigns a fresh, timestamped ``story_dir`` to ``config``, removes story
    directories that are two or more days old, merges the story settings
    from the UI into ``config`` and asks ``MMStoryAgent`` to write the story.

    Args:
        story_topic, main_role, scene: Values stored under ``story_setting``.
        num_outline, temperature: Values stored under ``story_gen_config``.
        current_page: Index of the page currently shown in the viewer.
        config: Session configuration dict; modified in place.
        progress: Gradio progress tracker (tracks tqdm progress bars).

    Returns:
        A 4-tuple for the outputs ``story_data``, ``story_accordion``,
        ``story_content`` and ``video_output``: the generated pages, an
        update that reveals the accordion, the content of the current page
        (clamped to the new story's length, ``""`` if there are no pages)
        and a no-op update.
    """
    run_id = f"{time.strftime('%Y%m%d-%H%M%S')}-{uuid.uuid1().hex}"
    config["story_dir"] = f"generated_stories/{run_id}"

    _remove_stale_story_dirs(Path("generated_stories"), datetime.now())

    deep_update(config, {
        "story_setting": {
            "story_topic": story_topic,
            "main_role": main_role,
            "scene": scene,
        },
        "story_gen_config": {
            "num_outline": num_outline,
            "temperature": temperature,
        },
    })

    story_gen_agent = MMStoryAgent()
    pages = story_gen_agent.write_story(config)

    # The viewer may still point at a page index from a longer, earlier
    # story; clamp it so a shorter new story cannot raise IndexError.
    if pages:
        page_index = max(0, min(current_page, len(pages) - 1))
        page_content = pages[page_index]
    else:
        page_content = ""

    # story_data, story_accordion, story_content, video_output
    return pages, gr.update(visible=True), page_content, gr.update()
def modality_assets_generation_fn(
        height, width, image_seed, sound_guidance_scale, sound_seed,
        n_candidate_per_text,
        config,
        story_data):
    """Generate the per-modality assets and return a gallery update.

    Merges the image and sound settings from the UI into ``config`` (in
    place), runs ``MMStoryAgent.generate_modality_assets`` and returns a
    Gradio update that shows the result in a single-row gallery with one
    column per item.
    """
    image_overrides = {
        "obj_cfg": {
            "height": height,
            "width": width,
        },
        "call_cfg": {
            "seed": image_seed,
        },
    }
    sound_overrides = {
        "call_cfg": {
            "guidance_scale": sound_guidance_scale,
            "seed": sound_seed,
            "n_candidate_per_text": n_candidate_per_text,
        },
    }
    deep_update(config, {
        "image_generation": image_overrides,
        "sound_generation": sound_overrides,
    })

    agent = MMStoryAgent()
    images = agent.generate_modality_assets(config, story_data)

    # image gallery
    gallery_settings = {
        "visible": True,
        "value": images,
        "columns": [len(images)],
        "rows": [1],
        "height": "auto",
    }
    return gr.update(**gallery_settings)
def compose_storytelling_video_fn(
        fade_duration, slide_duration, zoom_speed, move_ratio,
        sound_volume, music_volume, bg_speech_ratio, fps,
        config,
        story_data,
        progress=gr.Progress(track_tqdm=True)):
    """Compose the final storytelling video and return its path.

    Merges the slideshow settings from the UI into ``config`` (in place),
    runs ``MMStoryAgent.compose_storytelling_video`` and returns the path
    ``<story_dir>/output.mp4`` for the video component.
    """
    effect_overrides = {
        "fade_duration": fade_duration,
        "slide_duration": slide_duration,
        "zoom_speed": zoom_speed,
        "move_ratio": move_ratio,
        "sound_volume": sound_volume,
        "music_volume": music_volume,
        "bg_speech_ratio": bg_speech_ratio,
        "fps": fps,
    }
    deep_update(config, {"slideshow_effect": effect_overrides})

    agent = MMStoryAgent()
    agent.compose_storytelling_video(config, story_data)

    # video_output
    output_dir = Path(config["story_dir"])
    return output_dir / "output.mp4"
# Build the Gradio UI: story/asset/slideshow settings on the left, generation
# status, story viewer, image gallery and final video on the right.
# NOTE(review): the nesting below was reconstructed from a listing whose
# indentation was lost — confirm the layout against the deployed app.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.HTML("""
    <h1 style="text-align: center;">MM-StoryAgent</h1>
    <div style="padding: 10px; background-color: #fffbcc; border: 1px solid #ffe564; border-radius:4px;">
    <strong>Note: </strong>If generated images can be previewed but the video generation fails, it is due to AliYun SDK token expiration issue. Please contact <a href="mailto:lcl193798@alibaba-inc.com">lcl193798@alibaba-inc.com</a>.
    </div>
    <p style="font-size: 16px;">This is a demo for generating attractive storytelling videos based on the given story setting.</p>
    <p style="font-size: 16px;">Depending on the chapter number, the generation may take a long time. Please be patient.</p>
    """)

    # Per-session copy of the loaded configuration; from here on ``config``
    # refers to the State component, not the module-level dict.
    config = gr.State(deepcopy(config))

    with gr.Row():
        with gr.Column():
            story_topic = gr.Textbox(label="Story Topic", value=default_story_setting["story_topic"])
            main_role = gr.Textbox(label="Main Role", value=default_story_setting["main_role"])
            scene = gr.Textbox(label="Scene", value=default_story_setting["scene"])
            chapter_num = gr.Number(label="Chapter Number", value=default_story_gen_config["num_outline"])
            temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Temperature", value=default_story_gen_config["temperature"])

            with gr.Accordion("Detailed Image Configuration (Optional)", open=False):
                height = gr.Slider(label="Height", minimum=256, maximum=1024, step=32, value=default_image_config["obj_cfg"]['height'])
                width = gr.Slider(label="Width", minimum=256, maximum=1024, step=32, value=default_image_config["obj_cfg"]['width'])
                image_seed = gr.Number(label="Image Seed", value=default_image_config["call_cfg"]['seed'])

            with gr.Accordion("Detailed Sound Configuration (Optional)", open=False):
                sound_guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=7.0, step=0.5, value=default_sound_config["call_cfg"]['guidance_scale'])
                sound_seed = gr.Number(label="Sound Seed", value=default_sound_config["call_cfg"]['seed'])
                n_candidate_per_text = gr.Slider(label="Number of Candidates per Text", minimum=0, maximum=5, step=1, value=default_sound_config["call_cfg"]['n_candidate_per_text'])

            with gr.Accordion("Detailed Slideshow Effect (Optional)", open=False):
                fade_duration = gr.Slider(label="Fade Duration", minimum=0.1, maximum=1.5, step=0.1, value=default_slideshow_effect['fade_duration'])
                slide_duration = gr.Slider(label="Slide Duration", minimum=0.1, maximum=1.0, step=0.1, value=default_slideshow_effect['slide_duration'])
                zoom_speed = gr.Slider(label="Zoom Speed", minimum=0.1, maximum=2.0, step=0.1, value=default_slideshow_effect['zoom_speed'])
                move_ratio = gr.Slider(label="Move Ratio", minimum=0.8, maximum=1.0, step=0.05, value=default_slideshow_effect['move_ratio'])
                sound_volume = gr.Slider(label="Sound Volume", minimum=0.0, maximum=1.0, step=0.1, value=default_slideshow_effect['sound_volume'])
                music_volume = gr.Slider(label="Music Volume", minimum=0.0, maximum=1.0, step=0.1, value=default_slideshow_effect['music_volume'])
                bg_speech_ratio = gr.Slider(label="Background / Speech Ratio", minimum=0.0, maximum=1.0, step=0.1, value=default_slideshow_effect['bg_speech_ratio'])
                fps = gr.Slider(label="FPS", minimum=1, maximum=30, step=1, value=default_slideshow_effect['fps'])

        with gr.Column():
            story_data = gr.State([])
            story_generation_information = gr.Markdown(
                label="Story Generation Status",
                value="<h3>Generating Story Script ......</h3>",
                visible=False)
            with gr.Accordion(label="Story Content", open=False, visible=False) as story_accordion:
                with gr.Row():
                    prev_button = gr.Button("Previous Page",)
                    next_button = gr.Button("Next Page",)
                story_content = gr.Textbox(label="Page Content")
            video_generation_information = gr.Markdown(label="Generation Status", value="<h3>Generating Video ......</h3>", visible=False)
            image_gallery = gr.Gallery(label="Images", show_label=False, visible=False)
            video_generation_btn = gr.Button("Generate Video")
            video_output = gr.Video(label="Generated Story", interactive=False)

    current_page = gr.State(0)

    # update_page returns (page, content, story_data), so all three outputs
    # are listed to match the handler's return arity.
    prev_button.click(
        fn=update_page,
        inputs=[gr.State("prev"), current_page, story_data],
        outputs=[current_page, story_content, story_data],
    )
    next_button.click(
        fn=update_page,
        inputs=[gr.State("next"), current_page, story_data],
        outputs=[current_page, story_content, story_data],
    )

    # (possibly) update role description and scripts
    # Pipeline: story -> modality assets -> video. Steps that follow a work
    # stage are chained with .success() so a failed stage stops the pipeline
    # instead of running later stages on stale data and reporting
    # "Generation Finished!".
    video_generation_btn.click(
        fn=set_generating_progress_text,
        inputs=[gr.State("Generating Story ...")],
        outputs=video_generation_information
    ).then(
        fn=write_story_fn,
        inputs=[story_topic, main_role, scene,
                chapter_num, temperature,
                current_page,
                config
                ],
        outputs=[story_data, story_accordion, story_content, video_output]
    ).success(
        fn=set_generating_progress_text,
        inputs=[gr.State("Generating Modality Assets ...")],
        outputs=video_generation_information
    ).then(
        fn=modality_assets_generation_fn,
        inputs=[height, width, image_seed, sound_guidance_scale, sound_seed,
                n_candidate_per_text,
                config,
                story_data],
        outputs=[image_gallery]
    ).success(
        fn=set_generating_progress_text,
        inputs=[gr.State("Composing Video ...")],
        outputs=video_generation_information
    ).then(
        fn=compose_storytelling_video_fn,
        inputs=[fade_duration, slide_duration, zoom_speed, move_ratio,
                sound_volume, music_volume, bg_speech_ratio, fps,
                config,
                story_data],
        outputs=[video_output]
    ).success(
        # The preview gallery is hidden once the final video is available.
        fn=lambda: gr.update(visible=False),
        inputs=[],
        outputs=[image_gallery]
    ).then(
        fn=set_generating_progress_text,
        inputs=[gr.State("Generation Finished!")],
        outputs=video_generation_information
    )


if __name__ == "__main__":
    demo.launch()