| | import os |
| | import dataclasses |
| | import torch |
| | import transformers |
| | from transformers import Trainer, TrainingArguments, TrainerCallback |
| | from peft import LoraConfig, get_peft_model, TaskType |
| | from huggingface_hub import HfApi, login |
| | import wandb |
| | from dotenv import load_dotenv |
| | from config import TrainConfig, ModelConfig |
| | from model import MultiModalModel |
| | from data import AudioTextDataset, DataCollator |
| |
|
| |
|
class SamplePredictionCallback(TrainerCallback):
    """Every N steps, print ground-truth vs model-predicted transcript for a few samples.

    Greedy-decodes the first ``num_samples`` training examples and logs a
    (step, index, ground truth, prediction, continuation) table to Weights &
    Biases. Best-effort: any error is printed and training continues.

    NOTE(review): ``on_log`` only fires on logging events, so in practice
    ``sample_every_n_steps`` should be a multiple of ``logging_steps`` or the
    modulo check may never hit — confirm against the training config.
    """

    def __init__(self, tokenizer, data_collator, train_dataset, sample_every_n_steps: int = 100, num_samples: int = 2, prompt: str = "Transcribe the following audio:"):
        self.tokenizer = tokenizer
        self.data_collator = data_collator
        self.train_dataset = train_dataset
        self.sample_every_n_steps = sample_every_n_steps
        self.num_samples = num_samples
        self.prompt = prompt

    def on_log(self, args, state, control, model=None, **kwargs):
        # Guard against a non-positive interval, which would otherwise raise
        # ZeroDivisionError in the modulo check below.
        if self.sample_every_n_steps <= 0:
            return
        if state.global_step == 0 or state.global_step % self.sample_every_n_steps != 0:
            return
        if model is None:
            return
        model.eval()
        device = next(model.parameters()).device
        try:
            # Always sample the same leading examples so predictions are
            # comparable across steps; wrap around for tiny datasets.
            indices = [i % len(self.train_dataset) for i in range(self.num_samples)]
            batch = self.data_collator([self.train_dataset[i] for i in indices])
            audio_values = batch["audio_values"].to(device)
            labels_batch = batch["labels"]
            continuations = batch.get("continuation", [""] * audio_values.size(0))
            prompt_ids = self.tokenizer(self.prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(device)
            prompt_ids = prompt_ids.expand(audio_values.size(0), -1)
            # Explicit None check: the previous `pad or eos` truthiness test
            # silently fell back to EOS when pad_token_id was the valid id 0.
            pad_id = self.tokenizer.pad_token_id
            if pad_id is None:
                pad_id = self.tokenizer.eos_token_id
            with torch.no_grad():
                gen_ids = model.generate(
                    input_ids=prompt_ids,
                    audio_values=audio_values,
                    max_new_tokens=120,
                    do_sample=False,
                    pad_token_id=pad_id,
                )
            prompt_len = prompt_ids.size(1)

            columns = ["Step", "Audio Index", "Ground Truth", "Prediction", "Continuation"]
            table = wandb.Table(columns=columns)

            print(f"\n[WandB] Logging sample predictions at step {state.global_step}")

            for i in range(audio_values.size(0)):
                # -100 marks positions ignored by the loss; drop before decoding.
                gt_tokens = [t for t in labels_batch[i].tolist() if t != -100]
                gt_text = self.tokenizer.decode(gt_tokens, skip_special_tokens=True).strip()
                # Strip the prompt prefix so only the generated continuation remains.
                pred_text = self.tokenizer.decode(gen_ids[i][prompt_len:], skip_special_tokens=True).strip()

                cont_ref = continuations[i] if i < len(continuations) else ""

                table.add_data(state.global_step, i, gt_text, pred_text, cont_ref)

            if wandb.run is not None:
                wandb.log({"sample_predictions": table}, step=state.global_step)
            else:
                print("Warning: WandB run not active, skipping logging.")

        except Exception as e:
            # Best-effort: a failed sample must never kill the training run.
            print(f"[SamplePredictionCallback] Error: {e}\n")
        finally:
            # Restore training mode even if generation/logging failed.
            model.train()
| |
|
| |
|
| | import shutil |
| | import glob |
| | from transformers.trainer_utils import get_last_checkpoint |
| |
|
class AggressiveDeleteCallback(TrainerCallback):
    """Wipe every on-disk checkpoint right before the Trainer writes a new one.

    Keeps disk usage minimal by removing all ``checkpoint-*`` directories in
    ``output_dir`` just ahead of each scheduled save. Between the delete and
    the subsequent save there are zero checkpoints on disk — if that save
    fails, nothing remains. That risk is deliberately accepted by the user.
    """

    def __init__(self, output_dir):
        self.output_dir = output_dir

    def on_step_end(self, args, state, control, **kwargs):
        # Only act when the Trainer is on a step-based save schedule.
        if args.save_strategy != "steps" or args.save_steps <= 0:
            return
        step = state.global_step
        # Skip step 0 and any step that is not a scheduled save point.
        if step <= 0 or step % args.save_steps != 0:
            return

        print(f"\n[AggressiveDeleteCallback] Step {state.global_step}: Deleting old checkpoints to free space before saving...")

        stale = glob.glob(os.path.join(self.output_dir, "checkpoint-*"))
        for ckpt in stale:
            try:
                shutil.rmtree(ckpt)
                print(f" Deleted {ckpt}")
            except Exception as e:
                print(f" Failed to delete {ckpt}: {e}")
| |
|
def train():
    """End-to-end training entry point.

    Loads configs, builds the multimodal (audio + text) model, optionally
    wraps the LLM half in LoRA adapters, trains with the HF ``Trainer``
    (resuming from the latest checkpoint when one exists), saves the final
    artifacts, and optionally pushes them to the Hugging Face Hub.
    """
    # Pull secrets (HF / W&B tokens) from a local .env file, if present.
    load_dotenv()

    train_config = TrainConfig()
    model_config = ModelConfig()

    wandb.init(
        project=train_config.wandb_project,
        entity=train_config.wandb_entity,
        name=train_config.wandb_run_name,
        config=dataclasses.asdict(train_config),
    )

    # Tokenizer for the text backbone; many causal LMs ship without a pad
    # token, so fall back to EOS for padding.
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_config.text_model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Feature extractor / processor for the audio encoder.
    processor = transformers.AutoProcessor.from_pretrained(model_config.audio_model_id)

    model = MultiModalModel(model_config)

    # Optionally apply LoRA to the LLM component only (q/v projections).
    if train_config.use_lora:
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=train_config.lora_r,
            lora_alpha=train_config.lora_alpha,
            lora_dropout=train_config.lora_dropout,
            target_modules=["q_proj", "v_proj"]
        )
        model.llm = get_peft_model(model.llm, peft_config)
        model.llm.print_trainable_parameters()

    train_dataset = AudioTextDataset(train_config, processor, model_config, tokenizer)
    data_collator = DataCollator(processor, tokenizer)

    training_args = TrainingArguments(
        output_dir=train_config.output_dir,
        per_device_train_batch_size=train_config.batch_size,
        gradient_accumulation_steps=train_config.accum_steps,
        learning_rate=train_config.learning_rate,
        lr_scheduler_type=train_config.lr_scheduler_type,
        num_train_epochs=train_config.num_epochs,
        max_steps=train_config.max_steps,
        bf16=train_config.use_bf16,
        gradient_checkpointing=train_config.gradient_checkpointing,
        dataloader_num_workers=train_config.dataloader_num_workers,
        dataloader_pin_memory=train_config.dataloader_pin_memory,
        logging_steps=train_config.log_steps,
        logging_first_step=True,
        logging_nan_inf_filter=True,
        # Explicit save_strategy: AggressiveDeleteCallback only fires when it
        # is "steps", so don't rely on the transformers default.
        save_strategy="steps",
        save_steps=train_config.save_steps,
        save_total_limit=train_config.save_total_limit,
        eval_strategy="no",
        # The collator consumes custom columns (audio, continuation, ...);
        # keep Trainer from stripping them.
        remove_unused_columns=False,
        report_to="wandb",
        log_level="info",
        log_level_replica="info",
    )

    sample_callback = SamplePredictionCallback(
        tokenizer=tokenizer,
        data_collator=data_collator,
        train_dataset=train_dataset,
        sample_every_n_steps=train_config.sample_pred_every_steps,
        num_samples=2,
        prompt="Transcribe the following audio:",
    )

    aggressive_delete_callback = AggressiveDeleteCallback(train_config.output_dir)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        callbacks=[sample_callback, aggressive_delete_callback],
    )

    total_steps = train_config.max_steps
    print(f"\n>>> Training: max_steps={total_steps}, batch_size={train_config.batch_size}, "
          f"grad_accum={train_config.accum_steps} (effective batch={train_config.batch_size * train_config.accum_steps})")
    print(f">>> Sample predictions (GT vs predicted transcript) every {train_config.sample_pred_every_steps} steps.\n")

    # Resume only if the output dir exists: get_last_checkpoint() lists the
    # directory and raises FileNotFoundError on a fresh run otherwise.
    last_checkpoint = None
    if os.path.isdir(train_config.output_dir):
        last_checkpoint = get_last_checkpoint(train_config.output_dir)
    if last_checkpoint is not None:
        print(f">>> Resuming from checkpoint: {last_checkpoint}")
        trainer.train(resume_from_checkpoint=last_checkpoint)
    else:
        trainer.train()

    # Persist final model + preprocessing artifacts for inference.
    trainer.save_model(train_config.output_dir)
    tokenizer.save_pretrained(train_config.output_dir)
    processor.save_pretrained(train_config.output_dir)

    if train_config.push_to_hub:
        print(f"\n>>> Pushing model to Hugging Face Hub: {train_config.hub_model_id}")
        if train_config.hub_token:
            login(token=train_config.hub_token)

        api = HfApi()

        # exist_ok=True makes repo creation idempotent across reruns; a
        # failure here is non-fatal (upload may still work on an existing repo).
        try:
            api.create_repo(repo_id=train_config.hub_model_id, private=train_config.hub_private_repo, exist_ok=True)
        except Exception as e:
            print(f"Warning: Could not create repo {train_config.hub_model_id}. Error: {e}")

        try:
            api.upload_folder(
                folder_path=train_config.output_dir,
                repo_id=train_config.hub_model_id,
                repo_type="model",
            )

            # Ship the source files needed to rebuild/run the custom model class.
            for file in ["model.py", "config.py", "data.py", "inference.py"]:
                if os.path.exists(file):
                    api.upload_file(
                        path_or_fileobj=file,
                        path_in_repo=file,
                        repo_id=train_config.hub_model_id,
                        repo_type="model",
                    )

            print(f">>> Successfully pushed to {train_config.hub_model_id}")
        except Exception as e:
            print(f"Error pushing to hub: {e}")

    # Flush and close the W&B run explicitly rather than relying on atexit.
    wandb.finish()
| |
|
# Script entry point: run training only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    train()
| |
|