import base64
import copy
import mimetypes
from pathlib import Path
from typing import Any, Dict, List, Optional

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoProcessor

MODEL_ID = "baidu/ERNIE-4.5-VL-28B-A3B-Thinking"

torch.backends.cuda.matmul.allow_tf32 = True

processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
model.add_image_preprocess(processor)
model.eval()

DEVICE = next(model.parameters()).device


def encode_image_to_data_uri(image_path: str) -> str:
    """
    Convert an image file path to a base64 data URI string.

    Args:
        image_path (str): Path to the local image file.

    Returns:
        str: Base64 data URI representation of the image.
    """
    mime_type, _ = mimetypes.guess_type(image_path)
    mime_type = mime_type or "image/png"
    data = Path(image_path).read_bytes()
    encoded = base64.b64encode(data).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"


def _move_batch_to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
    """
    Move tensors (and lists of tensors) within a batch to the specified device.

    Args:
        batch (Dict[str, Any]): Processor output batch.
        device (torch.device): Target device.

    Returns:
        Dict[str, Any]: Batch with tensors moved to the device.
    """
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.to(device)
        elif isinstance(value, list):
            batch[key] = [
                v.to(device) if isinstance(v, torch.Tensor) else v for v in value
            ]
    return batch


def _build_display_content(
    text: Optional[str], image_path: Optional[str]
) -> List[Dict[str, Any]]:
    """
    Create the Chatbot-friendly content list for a message.

    Args:
        text (Optional[str]): User text.
        image_path (Optional[str]): Path to attached image.

    Returns:
        List[Dict[str, Any]]: Content list for Chatbot display.
    """
    content: List[Dict[str, Any]] = []
    if text:
        content.append({"type": "text", "text": text})
    if image_path:
        content.append({"type": "image", "image": image_path})
    return content


def _build_model_content(
    text: Optional[str], image_path: Optional[str]
) -> List[Dict[str, Any]]:
    """
    Create the model-friendly content list for a message.

    Args:
        text (Optional[str]): User text.
        image_path (Optional[str]): Path to attached image.

    Returns:
        List[Dict[str, Any]]: Content list for model consumption.
    """
    content: List[Dict[str, Any]] = []
    if text:
        content.append({"type": "text", "text": text})
    if image_path:
        data_uri = encode_image_to_data_uri(image_path)
        content.append({"type": "image_url", "image_url": {"url": data_uri}})
    return content


def handle_user_message(
    user_text: str,
    user_image: Optional[str],
    chat_history: List[Dict[str, Any]],
    convo_state: List[Dict[str, Any]],
) -> tuple[str, Optional[str], List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Append the user's message (text/image) to both the display chat history
    and the model state.

    Args:
        user_text (str): Text entered by the user.
        user_image (Optional[str]): Optional path to uploaded image.
        chat_history (List[Dict[str, Any]]): Current Chatbot history.
        convo_state (List[Dict[str, Any]]): Current model conversation state.

    Returns:
        tuple: Cleared text, cleared image, updated chat history,
            updated conversation state.
""" if not user_text and not user_image: raise gr.Error("Please provide a question or attach an image before sending.") display_content = _build_display_content(user_text, user_image) model_content = _build_model_content(user_text, user_image) chat_history = chat_history + [{"role": "user", "content": display_content}] convo_state = convo_state + [{"role": "user", "content": model_content}] return "", None, chat_history, convo_state @spaces.GPU(duration=90) def generate_model_reply( messages: List[Dict[str, Any]], temperature: float, top_p: float, max_new_tokens: int, ) -> str: """ Run the ERNIE-4.5-VL model to obtain a response. Args: messages (List[Dict[str, Any]]): Conversation in model format. temperature (float): Sampling temperature. top_p (float): Nucleus sampling probability. max_new_tokens (int): Maximum tokens to generate. Returns: str: Generated assistant response text. """ convo = copy.deepcopy(messages) text_prompt = processor.tokenizer.apply_chat_template( convo, tokenize=False, add_generation_prompt=True ) image_inputs, video_inputs = processor.process_vision_info(convo) batch = processor( text=[text_prompt], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt", ) batch = _move_batch_to_device(batch, DEVICE) generated_ids = model.generate( inputs=batch["input_ids"], **batch, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, do_sample=True, use_cache=True, pad_token_id=processor.tokenizer.pad_token_id, ) prompt_length = batch["input_ids"].shape[-1] response_ids = generated_ids[0][prompt_length:] response_text = processor.decode(response_ids, skip_special_tokens=True).strip() if not response_text: response_text = "[No response returned by the model.]" return response_text def generate_bot_response( chat_history: List[Dict[str, Any]], convo_state: List[Dict[str, Any]], temperature: float, top_p: float, max_new_tokens: int, ) -> tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: """ Generate the assistant reply and append it to chat history/state. Args: chat_history (List[Dict[str, Any]]): Current display history. convo_state (List[Dict[str, Any]]): Current model conversation state. temperature (float): Sampling temperature. top_p (float): Nucleus sampling probability. max_new_tokens (int): Generation token cap. Returns: tuple: Updated chat history and conversation state. """ if not convo_state or convo_state[-1]["role"] != "user": return chat_history, convo_state response = generate_model_reply(convo_state, temperature, top_p, max_new_tokens) assistant_message = {"role": "assistant", "content": [{"type": "text", "text": response}]} chat_history = chat_history + [assistant_message] convo_state = convo_state + [assistant_message] return chat_history, convo_state def clear_conversation() -> tuple[List[Dict[str, Any]], List[Dict[str, Any]], str, Optional[str]]: """ Reset the chat history, model state, and input widgets. Returns: tuple: Cleared chat history, cleared model state, empty text, empty image. """ return [], [], "", None with gr.Blocks(fill_height=True) as demo: gr.Markdown( """ # ERNIE-4.5-VL Vision-Language Chat [Built with anycoder](https://huggingface.co/spaces/akhaliq/anycoder) Ask the model questions about images or pure text. Attach an image to explore grounded visual reasoning. 
""" ) chatbot = gr.Chatbot(type="messages", label="ERNIE-4.5-VL", height=520) conversation_state = gr.State([]) with gr.Row(): user_text = gr.Textbox( label="Your message", placeholder="Ask a question or describe what you want to know...", lines=2, ) user_image = gr.Image( label="Optional image", type="filepath", height=200, tool=None, ) with gr.Row(): send_button = gr.Button("Send", variant="primary") clear_button = gr.Button("Clear Conversation") with gr.Accordion("Generation Settings", open=False): temperature = gr.Slider( minimum=0.1, maximum=1.5, value=0.7, step=0.05, label="Temperature", info="Higher values produce more diverse outputs.", ) top_p = gr.Slider( minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)", ) max_tokens = gr.Slider( minimum=64, maximum=1024, value=512, step=32, label="Max new tokens", ) send_event = send_button.click( handle_user_message, inputs=[user_text, user_image, chatbot, conversation_state], outputs=[user_text, user_image, chatbot, conversation_state], show_progress=False, ) send_event.then( generate_bot_response, inputs=[chatbot, conversation_state, temperature, top_p, max_tokens], outputs=[chatbot, conversation_state], ) user_text.submit( lambda *args: None, None, None, _js="() => document.querySelector('button.primary').click()", ) clear_button.click( clear_conversation, inputs=None, outputs=[chatbot, conversation_state, user_text, user_image], ) demo.queue().launch()