Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from PIL import Image | |
| import json | |
| from io import BytesIO | |
| import base64 | |
| import torch | |
| from tempfile import gettempdir | |
| from os import path, makedirs, remove | |
| import models | |
| import time | |
def get_safe_cache_dir(default_cache=None):
    """Return a directory to use as the Hugging Face download cache.

    The preferred directory (``~/.cache/huggingface`` unless *default_cache*
    is given) is created if needed and probed for write access by writing
    and deleting a small file. If any OS-level error occurs (for example a
    read-only home directory on a hosted environment such as Hugging Face
    Spaces), ``<system temp dir>/huggingface`` is used instead.

    Args:
        default_cache: Preferred cache directory. ``None`` (the default)
            means ``~/.cache/huggingface``.

    Returns:
        The path of the preferred directory if it is writable, otherwise
        the temp-dir fallback path.
    """
    if default_cache is None:
        default_cache = path.expanduser("~/.cache/huggingface")
    try:
        makedirs(default_cache, exist_ok=True)
        # A directory can exist and still be read-only, so actually try a write.
        test_file = path.join(default_cache, "test_write.txt")
        with open(test_file, "w") as f:
            f.write("ok")
        remove(test_file)
        return default_cache
    except OSError:
        # Preferred location unusable: fall back to the system temp directory.
        fallback = path.join(gettempdir(), "huggingface")
        try:
            # Create it so callers receive a directory that exists.
            makedirs(fallback, exist_ok=True)
        except OSError:
            # Best-effort only: this runs at import time, so never crash here;
            # the model loader will surface the problem if the dir is unusable.
            pass
        return fallback
# Use the GPU when CUDA is available, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Directory passed as cache_dir to the model wrappers in load_model.
CACHE_DIR = get_safe_cache_dir()

# Models selectable in the UI, keyed by their display label. Each entry holds
# the Hugging Face repo "id" and the wrapper "type" that load_model dispatches
# on. A TrOCR entry (id "microsoft/trocr-base-printed", type "trocr") is
# currently disabled.
AVAILABLE_MODELS = {
    "EraX (VL-2B-V1.5)": dict(id="erax-ai/EraX-VL-2B-V1.5", type="erax"),
}

# Model instances stored by load_model, keyed by display label.
_model_cache = dict()

print(f"Using device: {DEVICE}")
print(f"Cache directory: {CACHE_DIR}")
def load_model(model_key):
    """Return the model registered under *model_key*, loading it on first use.

    Loaded instances are memoised in the module-level ``_model_cache`` keyed
    by *model_key*, so later requests reuse the same instance instead of
    reloading the weights.

    Args:
        model_key: A key of ``AVAILABLE_MODELS`` (the dropdown label).

    Returns:
        A ``models.TrOCRModel`` or ``models.EraXModel`` instance.

    Raises:
        KeyError: If *model_key* is not a key of ``AVAILABLE_MODELS``.
        ValueError: If the configured model ``type`` is not supported.
    """
    print("Processing image with model:", model_key)
    model_id = AVAILABLE_MODELS[model_key]["id"]
    model_type = AVAILABLE_MODELS[model_key]["type"]
    print("Model ID:", model_id, "Type:", model_type)
    # The cache is keyed by model_key (see the store below). The lookup must
    # use the same key; checking model_id never matched, so every request
    # reloaded the model.
    if model_key in _model_cache:
        return _model_cache[model_key]
    if "trocr" == model_type:
        model = models.TrOCRModel(model_id, cache_dir=CACHE_DIR, device=DEVICE)
    elif "erax" == model_type:
        model = models.EraXModel(model_id, cache_dir=CACHE_DIR, device=DEVICE)
    else:
        raise ValueError("Unknown model")
    _model_cache[model_key] = model
    print('Load model:', model_id, ' successfully!')
    return model
# Handler for the input image
def gradio_process(image: Image.Image, model_key: str):
    """Run the selected model on an uploaded image.

    Args:
        image: The uploaded image (the UI uses ``gr.Image(type="pil")``), or
            ``None`` when nothing was uploaded.
        model_key: The dropdown selection; expected to be a key of
            ``AVAILABLE_MODELS``.

    Returns:
        ``{"error": ...}`` when the image or model selection is missing or
        invalid. Otherwise, a JSON string with the model output under
        ``"texts"`` plus the image width/height and PIL mode.
    """
    if image is None:
        return {"error": "No image provided"}
    # Report a bad selection the same way as a missing image, instead of
    # letting load_model raise KeyError (e.g. a stale dropdown default).
    if not model_key or model_key not in AVAILABLE_MODELS:
        return {"error": f"Unknown model: {model_key}"}
    print('Received image size:', image.size)

    load_start = time.perf_counter()
    model = load_model(model_key)
    print('Time taken for model loading:', time.perf_counter() - load_start)

    # Time only the inference call; loading is reported separately above.
    start = time.perf_counter()
    result = model.predict(image)
    print('Model predicted successfully!')
    print('Result:', result)
    print('Time taken for prediction:', time.perf_counter() - start)
    # ensure_ascii=False keeps non-ASCII (e.g. Vietnamese) text readable
    # in the returned string; the parsed JSON value is unchanged.
    return json.dumps({
        "texts": result,
        "image_size": {
            "width": image.width,
            "height": image.height
        },
        "mode": image.mode,
    }, indent=4, ensure_ascii=False)
# Gradio interface
demo = gr.Interface(
    fn=gradio_process,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        # Label is Vietnamese for "Choose model". Default to the first
        # registered model: the previous hard-coded default
        # ("TrOCR (Base Printed)") is not in AVAILABLE_MODELS, so it was
        # not a valid choice.
        gr.Dropdown(
            choices=list(AVAILABLE_MODELS.keys()),
            label="Chọn mô hình",
            value=next(iter(AVAILABLE_MODELS), None),
        ),
        # Disabled prompt input (EraX only):
        # gr.Textbox(label="Prompt (chỉ dùng cho EraX)", placeholder="Ảnh này có gì?")
    ],
    outputs=gr.JSON(label="Output (Text/JSON Extract)"),
    title="Image to Text/JSON Extractor",
    description="Upload an image and extract structured text using OCR."
)

if __name__ == "__main__":
    demo.launch()