rahul7star committed
Commit 67d08a5 (verified)
1 Parent(s): 7088fb3

Update app.py

Files changed (1)
  1. app.py  +15 -18
app.py CHANGED
@@ -140,9 +140,7 @@ def train_lora_stream(base_model, dataset_src, csv_name, text_cols, output_dir,
     model_obj = pipe["model"]
     tokenizer = pipe["tokenizer"]
 
-    # Ensure model is in train mode
     model_obj.train()
-
     target_modules = find_target_modules(model_obj)
     lcfg = LoraConfig(r=r, lora_alpha=alpha, target_modules=target_modules, lora_dropout=0.0)
     lora_module = get_peft_model(model_obj, lcfg)
@@ -167,32 +165,31 @@ def train_lora_stream(base_model, dataset_src, csv_name, text_cols, output_dir,
     short_text = str(texts.get(text_cols[0], "") or "")
     long_text = str(texts.get(text_cols[1], "") or "")
 
-    # Encode both short and long as supervised pairs
-    inputs = tokenizer(short_text, return_tensors="pt", truncation=True, padding=True, max_length=1024)
-    labels = tokenizer(long_text, return_tensors="pt", truncation=True, padding=True, max_length=1024)
+    # --- FIX: Tokenize as text pair to align sequence lengths ---
+    enc = tokenizer(
+        short_text,
+        text_pair=long_text,
+        return_tensors="pt",
+        padding="max_length",
+        truncation=True,
+        max_length=512,  # enforce same length for both
+    )
 
-    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
-    labels = labels["input_ids"].to(DEVICE)
+    enc = {k: v.to(accelerator.device) for k, v in enc.items()}
+    enc["labels"] = enc["input_ids"].clone()
 
-    # Forward pass with labels — ensures gradient flow
-    outputs = lora_module(**inputs, labels=labels)
+    # --- Forward pass ---
+    outputs = lora_module(**enc)
 
-    # Handle loss properly
     forward_loss = getattr(outputs, "loss", None)
     if forward_loss is None:
-        # Fallback to MSE loss between logits and labels
         logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
-        forward_loss = torch.nn.functional.mse_loss(
-            logits.float(), torch.nn.functional.one_hot(labels, num_classes=logits.size(-1)).float()
+        forward_loss = torch.nn.functional.cross_entropy(
+            logits.view(-1, logits.size(-1)), enc["labels"].view(-1), ignore_index=tokenizer.pad_token_id
         )
 
-    # Ensure loss requires grad
-    if not forward_loss.requires_grad:
-        forward_loss = forward_loss.clone().detach().requires_grad_(True)
-
     logs.append(f"[DEBUG] Step {step_counter}, forward_loss: {forward_loss.item():.6f}")
 
-    # Backprop
     optimizer.zero_grad()
     accelerator.backward(forward_loss)
     optimizer.step()
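
For context, here is a minimal, self-contained sketch of the supervised short-to-long training step that the new hunk performs. It is not taken from app.py: the checkpoint name "sshleifer/tiny-gpt2", the hard-coded text pair, the plain AdamW optimizer, and loss.backward() stand in for the app's pipeline, CSV dataset, and accelerator. It also sets a pad token explicitly, since the diff's ignore_index=tokenizer.pad_token_id fallback assumes one exists, which is not true for many causal-LM tokenizers.

# Minimal sketch of the short-to-long supervised step (assumptions noted above).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

model_id = "sshleifer/tiny-gpt2"  # placeholder; app.py uses the user-selected base model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# GPT-style tokenizers often ship without a pad token; padding (and the
# pad-based loss masking below) needs one to be defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# "c_attn" is the attention projection in GPT-2; app.py discovers target
# modules dynamically via find_target_modules instead.
lora = get_peft_model(model, LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"]))
lora.train()
optimizer = torch.optim.AdamW(lora.parameters(), lr=1e-4)

short_text = "A cat."
long_text = "A small grey cat sleeping on a sunny windowsill."

# Encode prompt and target together so inputs and labels share one sequence.
enc = tokenizer(
    short_text,
    text_pair=long_text,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=64,
)
enc["labels"] = enc["input_ids"].clone()
# Mask padding positions so they do not contribute to the loss.
enc["labels"][enc["attention_mask"] == 0] = -100

outputs = lora(**enc)  # causal LMs return .loss when labels are supplied
loss = outputs.loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"step loss: {loss.item():.4f}")

This runs as a quick smoke test on CPU; swapping in the real base model, the accelerator-prepared optimizer, and accelerator.backward(loss) should not change the shape of the step.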