Transformers documentation
Subclassing Trainer methods
Subclass Trainer methods to change training behavior, such as the forward pass or loss computation, without rewriting the entire loop.
Before subclassing, consider whether you need to change what Trainer computes or when it acts. For timing and conditional logic, use a Callback instead. Callbacks control when things happen (logging, evaluation, early stopping), while subclassing changes what happens (loss computation, data loading, optimization).
See the Trainer API docs for a complete list of methods you can subclass. Private methods (prefixed with `_`) like `_save_checkpoint` or `_evaluate` can also be overridden, but these may change without notice.
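To illustrate the distinction, a callback can decide *when* training stops without touching any computation. The sketch below is illustrative (the callback name and threshold are not part of the Trainer API); it flips a flag on the `TrainerControl` object whenever the logged loss falls below a threshold.

```python
from transformers import TrainerCallback, TrainerControl

class StopBelowLossCallback(TrainerCallback):
    """Stop training once the logged loss falls below a threshold.

    This is a *when* decision: no loss math or data loading changes.
    """

    def __init__(self, threshold: float = 0.1):
        self.threshold = threshold

    def on_log(self, args, state, control: TrainerControl, logs=None, **kwargs):
        if logs and logs.get("loss", float("inf")) < self.threshold:
            control.should_training_stop = True
        return control
```

Pass an instance via `Trainer(..., callbacks=[StopBelowLossCallback(0.05)])`; the training loop itself stays untouched.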
get_train_dataloader
The standard get_train_dataloader() method returns a DataLoader that yields one batch per step: the training loop trains on a batch, discards it, and loads the next.
```python
def get_train_dataloader(self):
    return self._get_dataloader(
        batch_size=self._train_batch_size,
        ...
    )
```

GRPO is an online reinforcement learning algorithm that generates completions before training on them. Generating completions every step is expensive because generation is autoregressive: a 512-token completion requires ~512 sequential forward passes, compared to one forward pass for a training step. GRPOTrainer subclasses get_train_dataloader() to batch generation across multiple steps.
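The cost asymmetry can be made concrete with a toy count of model calls during greedy decoding. This is a pure-Python stand-in, not real model code:

```python
# Toy stand-in for autoregressive generation: each new token requires one
# forward pass over the sequence so far, so a 512-token completion costs
# ~512 sequential model calls, versus 1 call for a training step.
def count_generation_forwards(prompt_len: int, max_new_tokens: int) -> int:
    tokens = list(range(prompt_len))   # placeholder prompt token ids
    forward_calls = 0
    for _ in range(max_new_tokens):
        forward_calls += 1             # one model call per generated token
        tokens.append(0)               # placeholder sampled token
    return forward_calls

print(count_generation_forwards(prompt_len=32, max_new_tokens=512))  # 512
```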
trl.GRPOTrainer.get_train_dataloader loads batches of generation prompts for multiple training steps at once by multiplying the batch size by a steps_per_generation argument. If train_batch_size=4 and steps_per_generation=8, the dataloader produces batches of 32 prompts, so one batched generation pass covers 8 training steps instead of generating fresh completions every step.
```python
def get_train_dataloader(self):
    dataloader_params = {
        "batch_size": self._train_batch_size * self.args.steps_per_generation,  # this is the only change
        ...
    }
```

compute_loss
compute_loss() returns the cross-entropy loss calculated by the model.
```python
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    ...
    outputs = model(**inputs)
    ...
    loss = outputs["loss"]  # get loss from model
    return (loss, outputs) if return_outputs else loss
```

DPO measures how strongly the policy model prefers a chosen response over a rejected one, relative to a reference model. DPOTrainer subclasses compute_loss() because the loss computation differs from standard cross-entropy in several ways:
- the model never sees labels; it only returns logits for DPO to calculate log-probs from
- chosen and rejected responses are concatenated
- a reference model calculates its own log-probs
- the loss is a function of π_chosen, π_rejected, π_ref_chosen, and π_ref_rejected
None of the above fits the standard Trainer.compute_loss() method.
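To make that last point concrete, here is a pure-Python sketch of the pairwise preference loss as a function of the four log-probs. The helper name, toy values, and beta=0.1 are illustrative, not part of DPOTrainer:

```python
import math

def dpo_pair_loss(chosen_logp, rejected_logp, ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """-log sigmoid(beta * ((pi_c - pi_ref_c) - (pi_r - pi_ref_r))) for one pair."""
    chosen_score = chosen_logp - ref_chosen_logp        # how much more the policy likes chosen than the reference does
    rejected_score = rejected_logp - ref_rejected_logp  # same for rejected
    margin = beta * (chosen_score - rejected_score)
    return math.log(1.0 + math.exp(-margin))            # equals -log(sigmoid(margin))

# Toy log-probs: the policy prefers the chosen response more than the reference does
loss = dpo_pair_loss(chosen_logp=-10.0, rejected_logp=-14.0,
                     ref_chosen_logp=-12.0, ref_rejected_logp=-13.0)
print(round(loss, 4))  # 0.5544
```

The loss shrinks as the policy's preference margin over the reference grows, which is exactly the quantity the subclassed compute_loss() computes batch-wide.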
```python
def compute_loss(
    self,
    model: PreTrainedModel | nn.Module,
    inputs: dict[str, torch.Tensor | Any],
    return_outputs=False,
    num_items_in_batch=None,
) -> torch.Tensor | tuple[torch.Tensor, dict[str, float]]:
    ...
    outputs = model(**inputs)
    logits = outputs.logits
    logps = get_logps(logits, inputs)
    chosen_logps, rejected_logps = logps.chunk(2, dim=0)  # batch is [chosen, rejected]

    with torch.no_grad():  # the frozen reference model needs no gradients
        ref_logits = self.ref_model(**inputs).logits
    ref_logps = get_logps(ref_logits, inputs)
    ref_chosen_logps, ref_rejected_logps = ref_logps.chunk(2, dim=0)  # batch is [chosen, rejected]

    chosen_scores = chosen_logps - ref_chosen_logps
    rejected_scores = rejected_logps - ref_rejected_logps
    per_sequence_loss = -F.logsigmoid(self.beta * (chosen_scores - rejected_scores))
    loss = per_sequence_loss.mean()
    return (loss, outputs) if return_outputs else loss
```

Next steps
- For more real-world examples, see how GRPOTrainer and DPOTrainer extend Trainer in TRL, or how Axolotl builds custom trainers on top of it.
- Check the Callbacks guide if you only need to control when something happens during training, such as logging metrics at the end of a training step.