DonWare commited on
Commit
cefd326
·
verified ·
1 Parent(s): bdfaca3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -91
app.py CHANGED
@@ -1,91 +1,36 @@
1
- import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
- from langdetect import detect, DetectorFactory
4
- import torch
5
-
6
- # Fijar semilla para consistencia en langdetect
7
- DetectorFactory.seed = 0
8
-
9
- # Usar todos los núcleos CPU disponibles
10
- torch.set_num_threads(torch.get_num_threads())
11
-
12
- # Modelo NLLB-200 pequeño (distilled 600M)
13
- MODEL_NAME = "facebook/nllb-200-distilled-600M"
14
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
15
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
16
-
17
- # Idiomas disponibles (subset de los soportados por NLLB-200)
18
- LANGUAGES = {
19
- "Inglés": "eng_Latn",
20
- "Español": "spa_Latn",
21
- "Francés": "fra_Latn",
22
- "Alemán": "deu_Latn",
23
- "Italiano": "ita_Latn",
24
- "Portugués": "por_Latn",
25
- "Holandés": "nld_Latn",
26
- "Ruso": "rus_Cyrl",
27
- "Japonés": "jpn_Jpan",
28
- "Chino": "zho_Hans",
29
- }
30
-
31
- # Mapeo ISO 639-1 a NLLB
32
- ISO_TO_NLLB = {
33
- "en": "eng_Latn",
34
- "es": "spa_Latn",
35
- "fr": "fra_Latn",
36
- "de": "deu_Latn",
37
- "it": "ita_Latn",
38
- "pt": "por_Latn",
39
- "nl": "nld_Latn",
40
- "ru": "rus_Cyrl",
41
- "ja": "jpn_Jpan",
42
- "zh": "zho_Hans",
43
- }
44
-
45
- def translate_text(text, target_lang):
46
- if not text.strip():
47
- return "⚠️ No hay texto para traducir."
48
-
49
- # Detectar idioma de entrada
50
- try:
51
- src_lang = detect(text)
52
- src_lang_code = ISO_TO_NLLB.get(src_lang, "eng_Latn")
53
- except Exception as e:
54
- print(f"Error al detectar idioma: {e}")
55
- src_lang_code = "eng_Latn"
56
-
57
- tgt_lang_code = LANGUAGES[target_lang]
58
-
59
- # Si ya está en el idioma deseado
60
- if src_lang_code == tgt_lang_code:
61
- return text
62
-
63
- try:
64
- # Obtener el token ID del idioma destino
65
- tgt_lang_id = tokenizer.convert_tokens_to_ids(f"<{tgt_lang_code}>")
66
-
67
- # Tokenizar y generar traducción
68
- inputs = tokenizer(text, return_tensors="pt", truncation=True)
69
- generated_tokens = model.generate(
70
- **inputs,
71
- forced_bos_token_id=tgt_lang_id,
72
- max_length=512
73
- )
74
- translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
75
- return translated_text
76
-
77
- except Exception as e:
78
- print(f"Error durante la traducción: {e}")
79
- return "❌ Error en la traducción. Intenta con otro texto o idioma."
80
-
81
- # Interfaz Gradio
82
- with gr.Blocks() as demo:
83
- gr.Markdown("## Traductor Profesional - NLLB-200 Distilled 600M (CPU)")
84
-
85
- text_input = gr.Textbox(label="Ingresa el texto", lines=10)
86
- text_target = gr.Dropdown(list(LANGUAGES.keys()), label="Idioma de destino")
87
- text_output = gr.Textbox(label="Traducción", lines=10)
88
-
89
- gr.Button("Traducir").click(translate_text, inputs=[text_input, text_target], outputs=text_output)
90
-
91
- demo.launch()
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
4
+
5
+ app = FastAPI(title="Multilingual Translation API", description="API REST con M2M100 para traducción texto a texto 🌍", version="1.0")
6
+
7
+ # 🔤 Cargar modelo y tokenizer al iniciar
8
+ model_name = "facebook/m2m100_418M"
9
+ tokenizer = M2M100Tokenizer.from_pretrained(model_name)
10
+ model = M2M100ForConditionalGeneration.from_pretrained(model_name)
11
+
12
+ # 📩 Modelo del body de la solicitud
13
+ class TranslationRequest(BaseModel):
14
+ text: str
15
+ source_lang: str
16
+ target_lang: str
17
+
18
+ @app.post("/translate")
19
+ async def translate_text(req: TranslationRequest):
20
+ tokenizer.src_lang = req.source_lang
21
+ encoded = tokenizer(req.text, return_tensors="pt")
22
+ generated_tokens = model.generate(**encoded, forced_bos_token_id=tokenizer.get_lang_id(req.target_lang))
23
+ translated_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
24
+ return {
25
+ "source": req.text,
26
+ "translation": translated_text,
27
+ "source_lang": req.source_lang,
28
+ "target_lang": req.target_lang
29
+ }
30
+
31
+ @app.get("/")
32
+ async def root():
33
+ return {
34
+ "message": "🌍 Bienvenido a la API de traducción M2M100",
35
+ "usage": "POST /translate con {text, source_lang, target_lang}"
36
+ }