DSDUDEd committed (verified)
Commit 1959595 · Parent(s): 29ba998

Update app.py

Files changed (1): app.py (+79 -76)
app.py CHANGED
@@ -1,29 +1,15 @@
-import gradio as gr
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    Trainer,
-    TrainingArguments,
-    DataCollatorForSeq2Seq,
-)
+import os
 from datasets import load_dataset, Dataset
-import random
-
-# -----------------------------
-# Load Base Model
-# -----------------------------
-model_name = "PerceptronAI/Isaac-0.1"
-tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer
+import evaluate
+import numpy as np
+import gradio as gr
 
 # -----------------------------
 # Load Datasets
 # -----------------------------
-print("📥 Loading datasets...")
-
 pii_ds = load_dataset("ai4privacy/pii-masking-300k")
 cnn_ds = load_dataset("abisee/cnn_dailymail", "1.0.0")
-
 try:
     docqa_ds = load_dataset("vidore/syntheticDocQA_energy_train")
 except Exception as e:
@@ -31,90 +17,107 @@ except Exception as e:
     docqa_ds = None
 
 # -----------------------------
-# Build Training Samples
+# Build Pairs from Datasets (Safe Version)
 # -----------------------------
-def make_pairs_pii(example):
-    return {"input": example["text"], "output": example["masked_text"]}
-
-def make_pairs_cnn(example):
-    return {"input": example["article"], "output": example["highlights"]}
-
-pii_pairs = pii_ds["train"].map(make_pairs_pii).select(range(1000))  # small subset
-cnn_pairs = cnn_ds["train"].map(make_pairs_cnn).select(range(1000))
-
-pairs = []
-pairs.extend(pii_pairs)
-pairs.extend(cnn_pairs)
+pairs = []
+
+def safe_map(dataset, input_keys, output_keys, name, limit=1000):
+    """
+    dataset: Hugging Face dataset split
+    input_keys: list of possible input column names
+    output_keys: list of possible output column names
+    name: dataset name (for logs)
+    limit: number of samples to select
+    """
+    available = dataset.column_names
+    chosen_in = next((k for k in input_keys if k in available), None)
+    chosen_out = next((k for k in output_keys if k in available), None)
+
+    if not chosen_in or not chosen_out:
+        print(f"⚠️ Skipping {name} (no matching columns). Available: {available}")
+        return []
+
+    print(f"✅ Using {name}: input='{chosen_in}', output='{chosen_out}'")
+
+    def make_pairs(example):
+        return {"input": example[chosen_in], "output": example[chosen_out]}
+
+    return dataset.map(make_pairs).select(range(min(limit, len(dataset))))
+
+pii_pairs = safe_map(pii_ds["train"], ["original", "text"], ["masked", "masked_text"], "PII")
+cnn_pairs = safe_map(cnn_ds["train"], ["article"], ["highlights", "summary"], "CNN/DailyMail")
 
 if docqa_ds is not None:
-    def make_pairs_docqa(example):
-        return {"input": example["question"], "output": example["answer"]}
-    docqa_pairs = docqa_ds["train"].map(make_pairs_docqa).select(range(1000))
-    pairs.extend(docqa_pairs)
+    docqa_pairs = safe_map(docqa_ds["train"], ["question"], ["answer"], "DocQA")
+else:
+    docqa_pairs = []
+
+pairs.extend(pii_pairs)
+pairs.extend(cnn_pairs)
+pairs.extend(docqa_pairs)
 
 dataset = Dataset.from_list(pairs)
 
 # -----------------------------
-# Tokenization
+# Model + Tokenizer
 # -----------------------------
-def tokenize(batch):
-    inputs = tokenizer(batch["input"], truncation=True, padding="max_length", max_length=256)
-    outputs = tokenizer(batch["output"], truncation=True, padding="max_length", max_length=256)
-    inputs["labels"] = outputs["input_ids"]
-    return inputs
+model_name = "google/flan-t5-small"  # small, fast model
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
-tokenized_dataset = dataset.map(tokenize, batched=True)
+def tokenize_function(example):
+    model_inputs = tokenizer(example["input"], max_length=512, truncation=True)
+    labels = tokenizer(example["output"], max_length=128, truncation=True)
+    model_inputs["labels"] = labels["input_ids"]
+    return model_inputs
+
+tokenized_datasets = dataset.map(tokenize_function, batched=True)
 
 # -----------------------------
 # Training
 # -----------------------------
+metric = evaluate.load("rouge")
+
+def compute_metrics(eval_pred):
+    predictions, labels = eval_pred
+    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
+    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
+    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
+    return {k: round(v * 100, 4) for k, v in result.items()}
+
 training_args = TrainingArguments(
-    output_dir="./cass2.0",
-    overwrite_output_dir=True,
+    output_dir="./results",
+    eval_strategy="no",
+    learning_rate=2e-5,
+    per_device_train_batch_size=8,
     num_train_epochs=1,
-    per_device_train_batch_size=2,
-    save_steps=100,
-    save_total_limit=2,
-    logging_steps=20,
-    learning_rate=5e-5,
-    fp16=True,
+    weight_decay=0.01,
+    logging_dir="./logs",
+    logging_steps=10,
+    save_strategy="no"
 )
 
-data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
-
 trainer = Trainer(
     model=model,
     args=training_args,
-    train_dataset=tokenized_dataset,
+    train_dataset=tokenized_datasets,
+    eval_dataset=None,
     tokenizer=tokenizer,
-    data_collator=data_collator,
+    compute_metrics=compute_metrics
 )
 
-print("🚀 Training Cass2.0...")
 trainer.train()
-print("✅ Training complete!")
 
 # -----------------------------
-# Simple Chat UI
+# Gradio App
 # -----------------------------
-from transformers import pipeline
-pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
-
-def chat(message, history):
-    prompt = "".join([f"User: {m[0]}\nCass2.0: {m[1]}\n" for m in history])
-    prompt += f"User: {message}\nCass2.0:"
-    output = pipe(prompt, max_length=256, do_sample=True, temperature=0.7)[0]["generated_text"]
-    reply = output.split("Cass2.0:")[-1].strip()
-    history.append((message, reply))
-    return history, history
-
-with gr.Blocks() as demo:
-    gr.Markdown("# 🤖 Cass2.0 — Trained AI Assistant")
-    chatbot = gr.Chatbot()
-    msg = gr.Textbox(label="Type your message")
-    clear = gr.Button("Clear")
-
-    msg.submit(chat, [msg, chatbot], [chatbot, chatbot])
-    clear.click(lambda: None, None, chatbot)
-
-demo.launch()
+def generate_response(input_text):
+    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)
+    outputs = model.generate(**inputs, max_new_tokens=128)
+    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+demo = gr.Interface(fn=generate_response, inputs="text", outputs="text", title="Cass 2.0 Model")
+
+if __name__ == "__main__":
+    demo.launch()
 
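A note on the new safe_map helper: it picks the first input and output column that actually exists in a split, keeps the original columns, and adds "input"/"output" on top. A minimal sketch of how it behaves on a toy split (the toy data and column names below are illustrative, not part of the commit):

from datasets import Dataset

# Hypothetical toy split to exercise the column fallback (names are illustrative)
toy = Dataset.from_dict({"article": ["Some long text."], "summary": ["Short."]})

toy_pairs = safe_map(toy, ["original", "article"], ["highlights", "summary"], "toy", limit=10)
# safe_map logs: ✅ Using toy: input='article', output='summary'
print(toy_pairs[0]["input"], "->", toy_pairs[0]["output"])
# Some long text. -> Short.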
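One thing to watch in the new tokenization step: tokenize_function truncates but does not pad, and the Trainer is now constructed without a data collator, so batching falls back to the default DataCollatorWithPadding, which pads input_ids but not labels. Once a batch contains labels of different lengths, collation can fail. A hedged sketch of one fix, reusing the DataCollatorForSeq2Seq that the removed version of this file already imported:

from transformers import DataCollatorForSeq2Seq

# Pads input_ids and labels to a common length per batch (labels are padded with -100)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets,
    tokenizer=tokenizer,
    data_collator=data_collator,
)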
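The commit also wires compute_metrics into the Trainer, but with eval_strategy="no" and eval_dataset=None the ROUGE metric never runs; and a plain Trainer would hand compute_metrics raw logits, which tokenizer.batch_decode cannot decode. Generation-based evaluation is what Seq2SeqTrainer's predict_with_generate is for. A minimal sketch, assuming a held-out split named eval_ds (not defined in this commit):

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

seq2seq_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    predict_with_generate=True,  # hand generated token ids, not logits, to compute_metrics
    per_device_train_batch_size=8,
    num_train_epochs=1,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=seq2seq_args,
    train_dataset=tokenized_datasets,
    eval_dataset=eval_ds,  # hypothetical held-out split, not created in this commit
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)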
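Finally, generate_response encodes its inputs on the CPU. If trainer.train() ran on a GPU, the model's weights stay on that device and model.generate would raise a device mismatch. A small hedged adjustment:

# Sketch: move the encoded inputs to whichever device training left the model on
def generate_response(input_text):
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=128)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)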