from transformers import pipeline, AutoTokenizer, VisionEncoderDecoderModel, AutoProcessor, Qwen2VLForConditionalGeneration
from PIL import Image
from io import BytesIO
import base64
import json
import torch
from qwen_vl_utils import process_vision_info
import prompt  # local module providing the prompt templates (e.g. CCCD_BOTH_SIDE_PROMPT)
# Convert a PIL image to base64 (optional, if you need to display or export it)
def pil_to_base64(image: Image.Image, format="PNG") -> str:
    buffered = BytesIO()
    image.save(buffered, format=format)
    buffered.seek(0)
    return base64.b64encode(buffered.read()).decode("utf-8")
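
# Hedged round-trip sketch: base64_to_pil is an illustrative helper assumed
# here to complement pil_to_base64; it is not part of the original module.
def base64_to_pil(b64_string: str) -> Image.Image:
    # Decode the base64 payload back into bytes, then let PIL parse the image.
    return Image.open(BytesIO(base64.b64decode(b64_string)))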
def parse_to_json(result_text):
    """
    If the output consists of 'key: value' lines, parse them into a dict.
    Otherwise, wrap the raw text in a 'text' field.
    """
    data = {}
    lines = [line.strip() for line in result_text.splitlines() if line.strip()]
    for line in lines:
        if ":" in line:
            key, val = line.split(":", 1)
            data[key.strip()] = val.strip()
        else:
            # If the line cannot be split, collect it in a shared list
            data.setdefault("text", []).append(line)
    # If only the 'text' list exists, join it back into a single string
    if set(data.keys()) == {"text"}:
        data = {"text": "\n".join(data["text"])}
    return data
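
# Illustrative check of parse_to_json behavior; the sample strings are made
# up for demonstration and are not real model output.
def _parse_to_json_example():
    structured = parse_to_json("Name: Nguyen Van A\nDOB: 01/01/1990")
    assert structured == {"Name": "Nguyen Van A", "DOB": "01/01/1990"}
    free_form = parse_to_json("free-form OCR output")
    assert free_form == {"text": "free-form OCR output"}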
# class TrOCRModel:
#     def __init__(self, model_id="microsoft/trocr-base-printed", cache_dir=None, device=None):
#         self.model_id = model_id
#         self.cache_dir = cache_dir
#         self.device = device
#         self.processor = TrOCRProcessor.from_pretrained(self.model_id, cache_dir=self.cache_dir)
#         self.model = VisionEncoderDecoderModel.from_pretrained(self.model_id, cache_dir=self.cache_dir)
#         self.model.to(self.device)
#
#     def predict(self, image: Image.Image) -> str:
#         if image is None:
#             raise ValueError("No image provided")
#         image = image.convert("RGB")
#         pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)
#         with torch.no_grad():
#             generated_ids = self.model.generate(pixel_values)
#         generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
#         return generated_text
class TrOCRModel:
    def __init__(self, model_id="microsoft/trocr-base-printed", cache_dir=None, device=None):
        self.pipe = pipeline("image-to-text", model=model_id, device=device)

    def predict(self, image: Image.Image) -> str:
        if image is None:
            raise ValueError("No image provided")
        image = image.convert("RGB")
        result = self.pipe(image)
        return result[0]["generated_text"] if result else ""
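
# Hedged usage sketch for TrOCRModel; "sample.png" is a placeholder path, and
# the first call downloads the checkpoint from the Hugging Face Hub.
def _trocr_example() -> str:
    ocr = TrOCRModel(device="cpu")
    image = Image.open("sample.png")
    return ocr.predict(image)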
class EraXModel:
    def __init__(self, model_id="erax-ai/EraX-VL-2B-V1.5", cache_dir=None, device="auto"):
        size = {
            "shortest_edge": 56 * 56,        # detailed enough; commonly used for ViT/TrOCR-style inputs
            "longest_edge": 1280 * 28 * 28,  # caps the pixel budget for very tall or very wide images
        }
        # with open(config_json_path, 'r', encoding='utf-8') as f:
        #     self.json_template = json.dumps(json.load(f), ensure_ascii=False)
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_id,
            cache_dir=cache_dir,
            torch_dtype=torch.bfloat16,
            attn_implementation="eager",  # replace with "flash_attention_2" if your GPU is Ampere or newer
            device_map="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            size=size,
            cache_dir=cache_dir,
        )
        # Generation config
        self.generation_config = self.model.generation_config
        self.generation_config.do_sample = True
        self.generation_config.temperature = 1.0
        self.generation_config.top_k = 1
        self.generation_config.top_p = 0.9
        self.generation_config.min_p = 0.1
        self.generation_config.best_of = 5
        self.generation_config.max_new_tokens = 784
        self.generation_config.repetition_penalty = 1.06
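    # Note: with top_k=1 only the single most likely token survives filtering,
    # so do_sample=True behaves effectively greedily here; best_of is not a
    # standard transformers GenerationConfig field and is likely ignored by
    # generate().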
    def predict(self, image: Image.Image) -> str:
        if image is None:
            raise ValueError("No image provided")
        # image_path = "image.png"
        # # Read and encode the image
        # with open(image_path, "rb") as f:
        #     encoded_image = base64.b64encode(f.read())
        # decoded_image_text = encoded_image.decode('utf-8')
        # base64_data = f"data:image;base64,{decoded_image_text}"
        decoded_image_text = pil_to_base64(image)
        base64_data = f"data:image;base64,{decoded_image_text}"
        # Prepare messages
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": base64_data,
                    },
                    {
                        "type": "text",
                        "text": prompt.CCCD_BOTH_SIDE_PROMPT,
                    },
                ],
            }
        ]
        # Prepare prompt
        tokenized_text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        print("Tokenized text")
        # Parsed vision inputs are unused below; the raw PIL image is passed directly
        image_inputs, video_inputs = process_vision_info(messages)
        print("Processed vision info done")
        inputs = self.processor(
            text=[tokenized_text],
            # images=image_inputs,
            images=[image],
            # videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.model.device)
        print("Inputs prepared")
        # Inference
        print("Generating text...")
        generated_ids = self.model.generate(**inputs, generation_config=self.generation_config)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text[0]
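
# Minimal smoke-test sketch, assuming a GPU with enough memory for the 2B
# model and a local sample image; "cccd.png" is a placeholder path.
if __name__ == "__main__":
    erax = EraXModel()
    sample = Image.open("cccd.png")
    raw_text = erax.predict(sample)
    print(json.dumps(parse_to_json(raw_text), ensure_ascii=False, indent=2))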