Update app.py
app.py CHANGED
@@ -140,9 +140,7 @@ def train_lora_stream(base_model, dataset_src, csv_name, text_cols, output_dir,
     model_obj = pipe["model"]
     tokenizer = pipe["tokenizer"]

-    # Ensure model is in train mode
     model_obj.train()
-
     target_modules = find_target_modules(model_obj)
     lcfg = LoraConfig(r=r, lora_alpha=alpha, target_modules=target_modules, lora_dropout=0.0)
     lora_module = get_peft_model(model_obj, lcfg)
@@ -167,32 +165,31 @@ def train_lora_stream(base_model, dataset_src, csv_name, text_cols, output_dir,
     short_text = str(texts.get(text_cols[0], "") or "")
     long_text = str(texts.get(text_cols[1], "") or "")

-    #
-
-
+    # --- FIX: Tokenize as text pair to align sequence lengths ---
+    enc = tokenizer(
+        short_text,
+        text_pair=long_text,
+        return_tensors="pt",
+        padding="max_length",
+        truncation=True,
+        max_length=512, # enforce same length for both
+    )

-
-    labels =
+    enc = {k: v.to(accelerator.device) for k, v in enc.items()}
+    enc["labels"] = enc["input_ids"].clone()

-    # Forward pass
-    outputs = lora_module(**
+    # --- Forward pass ---
+    outputs = lora_module(**enc)

-    # Handle loss properly
     forward_loss = getattr(outputs, "loss", None)
     if forward_loss is None:
-        # Fallback to MSE loss between logits and labels
         logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
-        forward_loss = torch.nn.functional.
-            logits.
+        forward_loss = torch.nn.functional.cross_entropy(
+            logits.view(-1, logits.size(-1)), enc["labels"].view(-1), ignore_index=tokenizer.pad_token_id
         )

-        # Ensure loss requires grad
-        if not forward_loss.requires_grad:
-            forward_loss = forward_loss.clone().detach().requires_grad_(True)
-
     logs.append(f"[DEBUG] Step {step_counter}, forward_loss: {forward_loss.item():.6f}")

-    # Backprop
     optimizer.zero_grad()
     accelerator.backward(forward_loss)
     optimizer.step()
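
For context, the setup in the first hunk is the standard PEFT recipe: put the base model in train mode, choose which modules to adapt, and wrap the model with get_peft_model so only the LoRA adapters are trainable. Below is a minimal, self-contained sketch of that pattern; "gpt2", the hard-coded target_modules, and the r/lora_alpha values are illustrative stand-ins for the app's pipe["model"], its find_target_modules helper, and its r/alpha arguments.

# Sketch of the LoRA wrapping used in the first hunk (not the app's exact code).
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model_obj = AutoModelForCausalLM.from_pretrained("gpt2")  # stand-in for pipe["model"]
model_obj.train()

target_modules = ["c_attn"]  # stand-in for find_target_modules(model_obj); GPT-2's fused attention projection
lcfg = LoraConfig(r=8, lora_alpha=16, target_modules=target_modules, lora_dropout=0.0)
lora_module = get_peft_model(model_obj, lcfg)

lora_module.print_trainable_parameters()  # only the adapter weights require grad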
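
The heart of the fix in the second hunk is tokenizing the two CSV columns as a single text pair, so that input_ids and the labels copied from them always share one padded length. A self-contained sketch of just that step, assuming a GPT-2 tokenizer (which needs a pad token assigned) and placeholder strings in place of the CSV values:

# Sketch of the pair tokenization from the diff; "gpt2" and the example strings are placeholders.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

short_text = "A short headline."
long_text = "A much longer passage that the headline is paired with."

# text_pair= packs both strings into ONE sequence, so padding/truncation to
# max_length gives input_ids and labels identical shapes.
enc = tokenizer(
    short_text,
    text_pair=long_text,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=512,
)
enc["labels"] = enc["input_ids"].clone()

print(enc["input_ids"].shape, enc["labels"].shape)  # both: torch.Size([1, 512])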
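
When the wrapped model's forward pass returns no loss, the new code falls back to token-level cross-entropy: logits of shape (batch, seq, vocab) are flattened against flattened labels, and padding positions are excluded via ignore_index=tokenizer.pad_token_id. A toy illustration of that computation with random tensors (vocab_size, pad_id, and the shapes are made up):

# Toy version of the cross-entropy fallback in the diff; all numbers are illustrative.
import torch
import torch.nn.functional as F

vocab_size, seq_len, pad_id = 100, 8, 0

logits = torch.randn(1, seq_len, vocab_size, requires_grad=True)          # (batch, seq, vocab)
labels = torch.tensor([[5, 17, 42, 3, pad_id, pad_id, pad_id, pad_id]])   # padded label ids

# Flatten to (batch*seq, vocab) vs (batch*seq,) as in the diff; ignore_index
# drops the padding positions from the average.
loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)),
    labels.view(-1),
    ignore_index=pad_id,
)
loss.backward()
print(float(loss))

Note that the removed requires_grad_ workaround would not have helped here: clone().detach() severs the loss from the computation graph, so backward() would produce no gradients for the LoRA weights.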