Commit 0169392 · 1 Parent(s): 23d2e0c
vungocthach1112 committed

Create GUI for OCR app

Files changed (4)
  1. .gitignore +45 -0
  2. app.py +91 -0
  3. models.py +94 -0
  4. requirements.txt +5 -0
.gitignore ADDED
@@ -0,0 +1,45 @@
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual Environment
+ venv/
+ env/
+ ENV/
+ .env
+ .venv
+ env.bak/
+ venv.bak/
+
+ # IDE
+ .idea/
+ .vscode/
+ *.swp
+ *.swo
+
+ # Streamlit
+ .streamlit/secrets.toml
+
+ # Logs and local files
+ *.log
+ .DS_Store
+ Thumbs.db
+ .env
app.py ADDED
@@ -0,0 +1,91 @@
+ import gradio as gr
+ from PIL import Image
+ import json
+ from io import BytesIO
+ import base64
+ import torch
+ from tempfile import gettempdir
+ from os import path, makedirs, remove
+ import models
+
+ def get_safe_cache_dir():
+     try:
+         # Try writing to ~/.cache/huggingface (if it is available)
+         default_cache = path.expanduser("~/.cache/huggingface")
+         makedirs(default_cache, exist_ok=True)
+         test_file = path.join(default_cache, "test_write.txt")
+         with open(test_file, "w") as f:
+             f.write("ok")
+         remove(test_file)
+         return default_cache
+     except Exception:
+         # If that fails (e.g. on Hugging Face Spaces), fall back to the temp directory
+         return path.join(gettempdir(), "huggingface")
+
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ CACHE_DIR = get_safe_cache_dir()
+ AVAILABLE_MODELS = {
+     "TrOCR (Base Printed)": {
+         "id": "microsoft/trocr-base-printed",
+         "type": "trocr"
+     },
+     "EraX (VL-2B-V1.5)": {
+         "id": "erax-ai/EraX-VL-2B-V1.5",
+         "type": "erax"
+     }
+ }
+ _model_cache = {}
+
+ print("Using device:", DEVICE)
+ print("Cache directory:", CACHE_DIR)
+
+ def load_model(model_key):
+     model_id = AVAILABLE_MODELS[model_key]["id"]
+     model_type = AVAILABLE_MODELS[model_key]["type"]
+
+     # Reuse a previously loaded model; cache lookups and inserts both use model_key
+     if model_key in _model_cache:
+         return _model_cache[model_key]
+
+     if "trocr" in model_type:
+         model = models.TrOCRModel(model_id, cache_dir=CACHE_DIR, device=DEVICE)
+     elif "erax" in model_type:
+         model = models.EraXModel(model_id, cache_dir=CACHE_DIR, device=DEVICE)
+     else:
+         raise ValueError(f"Unknown model type: {model_type}")
+
+     _model_cache[model_key] = model
+     print("Loaded model:", model_id, "successfully!")
+     return model
+
+ # Process the uploaded image with the selected model
+ def gradio_process(image: Image.Image, model_key: str):
+     if image is None:
+         return {"error": "No image provided"}
+
+     model = load_model(model_key)
+     result = model.predict(image)
+
+     return json.dumps({
+         "texts": result,
+         "image_size": {
+             "width": image.width,
+             "height": image.height
+         },
+         "mode": image.mode,
+     }, indent=4)
+
+ # Gradio interface
+ demo = gr.Interface(
+     fn=gradio_process,
+     inputs=[
+         gr.Image(type="pil", label="Upload Image"),
+         gr.Dropdown(choices=list(AVAILABLE_MODELS.keys()), label="Select model", value="TrOCR (Base Printed)"),
+         # gr.Textbox(label="Prompt (EraX only)", placeholder="What is in this image?")
+     ],
+     outputs=gr.JSON(label="Output (Text/JSON Extract)"),
+     title="Image to Text/JSON Extractor",
+     description="Upload an image and extract structured text using OCR."
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
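
As a quick sanity check, gradio_process can be called directly without launching the UI. The sketch below is only illustrative: it assumes app.py and models.py are importable from the working directory, and "sample.png" is a placeholder path for a local test image, not a file in this repository.

# Minimal smoke-test sketch: calls gradio_process directly with a PIL image.
# "sample.png" is a placeholder path; importing app does not launch the UI
# because demo.launch() sits behind the __main__ guard.
from PIL import Image
import app

img = Image.open("sample.png")
output = app.gradio_process(img, "TrOCR (Base Printed)")
print(output)  # JSON string with "texts", "image_size", and "mode"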
models.py ADDED
@@ -0,0 +1,94 @@
+ from transformers import pipeline, AutoTokenizer, VisionEncoderDecoderModel, AutoProcessor
+ import torch
+ from PIL import Image
+ from io import BytesIO
+ import base64
+
+ # Convert a PIL image to base64 (optional, in case you need to display or export it)
+ def pil_to_base64(image: Image.Image, format="PNG") -> str:
+     buffered = BytesIO()
+     image.save(buffered, format=format)
+     return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+ def parse_to_json(result_text):
+     """
+     If the output consists of 'key: value' lines, parse them into a dict.
+     Otherwise, wrap the raw text in a 'text' field.
+     """
+     data = {}
+     lines = [line.strip() for line in result_text.splitlines() if line.strip()]
+     for line in lines:
+         if ":" in line:
+             key, val = line.split(":", 1)
+             data[key.strip()] = val.strip()
+         else:
+             # Lines that cannot be split are collected into a shared list
+             data.setdefault("text", []).append(line)
+     # If only the 'text' list is present, join it back into a single string
+     if set(data.keys()) == {"text"}:
+         data = {"text": "\n".join(data["text"])}
+     return data
+
+ # class TrOCRModel:
+ #     def __init__(self, model_id="microsoft/trocr-base-printed", cache_dir=None, device=None):
+ #         self.model_id = model_id
+ #         self.cache_dir = cache_dir
+ #         self.device = device
+
+ #         self.processor = TrOCRProcessor.from_pretrained(self.model_id, cache_dir=self.cache_dir)
+ #         self.model = VisionEncoderDecoderModel.from_pretrained(self.model_id, cache_dir=self.cache_dir)
+ #         self.model.to(self.device)
+
+ #     def predict(self, image: Image.Image) -> str:
+ #         if image is None:
+ #             raise ValueError("No image provided")
+
+ #         image = image.convert("RGB")
+ #         pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)
+
+ #         with torch.no_grad():
+ #             generated_ids = self.model.generate(pixel_values)
+ #             generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+ #         return generated_text
+
+ class TrOCRModel:
+     def __init__(self, model_id="microsoft/trocr-base-printed", cache_dir=None, device=None):
+         # Forward cache_dir so the weights are downloaded into the writable cache chosen by app.py
+         self.pipe = pipeline("image-to-text", model=model_id, device=device, model_kwargs={"cache_dir": cache_dir})
+
+     def predict(self, image: Image.Image) -> str:
+         if image is None:
+             raise ValueError("No image provided")
+
+         image = image.convert("RGB")
+         result = self.pipe(image)
+         return result[0]['generated_text'] if result else ""
+
+ class EraXModel:
+     def __init__(self, model_id="erax-ai/EraX-VL-2B-V1.5", cache_dir=None, device=None):
+         self.pipe = pipeline("image-to-text", model=model_id, device=device, model_kwargs={"cache_dir": cache_dir})
+
+     def predict(self, image: Image.Image) -> dict:
+         if image is None:
+             raise ValueError("No image provided")
+
+         decoded_image_text = pil_to_base64(image)
+         base64_data = f"data:image;base64,{decoded_image_text}"
+         # Chat-style message with a Vietnamese prompt ("Extract the content from the provided image").
+         # Note: the image-to-text pipeline call below does not consume this message yet.
+         messages = [
+             {
+                 "role": "user",
+                 "content": [
+                     {
+                         "type": "image",
+                         "image": base64_data,
+                     },
+                     {
+                         "type": "text",
+                         "text": "Trích xuất thông tin nội dung từ hình ảnh được cung cấp."
+                     },
+                 ],
+             }
+         ]
+
+         result = self.pipe(image)[0]['generated_text']
+         return parse_to_json(result) if result else {}
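
The parse_to_json helper is what turns the EraX model's raw text into structured output. A small illustration of its two branches, using made-up sample strings:

# Illustration of parse_to_json (the sample strings are invented for the example).
from models import parse_to_json

# 'key: value' lines become dict entries...
print(parse_to_json("Name: Nguyen Van A\nDate of birth: 01/01/1990"))
# -> {'Name': 'Nguyen Van A', 'Date of birth': '01/01/1990'}

# ...while free-form lines are joined back into a single "text" field.
print(parse_to_json("HELLO\nWORLD"))
# -> {'text': 'HELLO\nWORLD'}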
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ Pillow
+ transformers
+ torch
+ torchvision
+ gradio