Subclassing Trainer methods

Subclass [Trainer] and override its methods to change specific parts of training, such as the forward pass, loss computation, or data loading, without rewriting the entire training loop.

Before subclassing, consider whether you need to change what [Trainer] computes or when and whether it acts. For timing and conditional logic, use a Callback instead. Callbacks control when things happen (logging, evaluation, early stopping) and subclassing changes what happens (loss computation, data loading, optimization).
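The difference can be shown with a toy, library-free loop. `ToyTrainer`, `ScaledLossTrainer`, and `log_callback` are hypothetical names made up for this sketch. They are not the [Trainer] or TrainerCallback API, and real callbacks receive the training arguments, state, and control objects instead. The subclass changes the value the loop computes, while the callback only observes each step.

```python
class ToyTrainer:
    """Hypothetical stand-in for Trainer: a fixed loop that delegates to overridable methods."""

    def __init__(self, data, batch_size, callbacks=()):
        self.data = data
        self.batch_size = batch_size
        self.callbacks = list(callbacks)

    def get_train_dataloader(self):
        # Yield one batch at a time; each batch is used for a single step.
        for start in range(0, len(self.data), self.batch_size):
            yield self.data[start : start + self.batch_size]

    def compute_loss(self, batch):
        # Toy "loss": the mean of the batch values.
        return sum(batch) / len(batch)

    def train(self):
        losses = []
        for step, batch in enumerate(self.get_train_dataloader()):
            loss = self.compute_loss(batch)  # subclasses change *what* is computed
            losses.append(loss)
            for callback in self.callbacks:  # callbacks react *when* a step ends
                callback(step, loss)
        return losses


class ScaledLossTrainer(ToyTrainer):
    def compute_loss(self, batch):
        # Override: change the computation, reuse the rest of the loop.
        return 2 * super().compute_loss(batch)


logged = []


def log_callback(step, loss):
    # Observe the loop without altering its result.
    logged.append((step, loss))


base = ToyTrainer([1, 2, 3, 4], batch_size=2, callbacks=[log_callback]).train()
scaled = ScaledLossTrainer([1, 2, 3, 4], batch_size=2).train()
print(base)    # [1.5, 3.5]
print(scaled)  # [3.0, 7.0]
print(logged)  # [(0, 1.5), (1, 3.5)]
```

The callback sees every step but cannot change the loss; only the override does.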

Note

See the [Trainer] API docs for a complete list of methods you can subclass. Private methods (prefixed with _) like _save_checkpoint or _evaluate can also be overridden, but these may change without notice.

get_train_dataloader

The default [~Trainer.get_train_dataloader] method returns a DataLoader that yields batches of train_batch_size examples. The training loop loads one batch, trains on it, discards it, and loads the next.

def get_train_dataloader(self):
    return self._get_dataloader(
        batch_size=self._train_batch_size,
        # ... other arguments omitted
    )

GRPO is an online reinforcement learning algorithm that generates completions before training on them. Generating completions at every step is expensive because generation is autoregressive: a 512-token completion requires ~512 sequential forward passes, compared to a single forward pass for a training step. [~trl.GRPOTrainer] overrides [~Trainer.get_train_dataloader] to batch generation across multiple steps.

[~trl.GRPOTrainer.get_train_dataloader] loads generation prompts for several training steps at once by multiplying the batch size by the steps_per_generation argument. If train_batch_size=4 and steps_per_generation=8, the dataloader produces batches of 32, so the trainer generates once every 8 steps instead of every step. That is 8x fewer generation calls, and each sequential decoding pass is shared by 32 sequences instead of 4.

def get_train_dataloader(self):
    dataloader_params = {
        "batch_size": self._train_batch_size * self.args.steps_per_generation,  # this is the only change
        # ... other parameters omitted
    }
    ...
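The arithmetic behind steps_per_generation can be checked with a library-free simulation. `simulate_generation` is a hypothetical helper, and the list comprehension stands in for one batched, autoregressive generation call. The schedule (generate once for the large batch, then hand out one train_batch_size slice per step) is a simplification: the real [~trl.GRPOTrainer] also generates several completions per prompt and handles distributed training, which this ignores.

```python
def simulate_generation(prompts, train_batch_size, steps_per_generation):
    """Return (generation calls, training steps) for a buffered generation schedule."""
    buffer_size = train_batch_size * steps_per_generation  # e.g. 4 * 8 = 32
    generation_calls, training_steps = 0, 0
    for start in range(0, len(prompts), buffer_size):
        buffer = prompts[start : start + buffer_size]
        # Stand-in for ONE batched generate() call over the whole buffer.
        completions = [prompt + " -> completion" for prompt in buffer]
        generation_calls += 1
        # Consume the buffer one train_batch_size slice per training step;
        # a real loop would train on completions[offset : offset + train_batch_size].
        for offset in range(0, len(completions), train_batch_size):
            training_steps += 1
    return generation_calls, training_steps


prompts = [f"prompt {i}" for i in range(64)]
print(simulate_generation(prompts, train_batch_size=4, steps_per_generation=1))  # (16, 16)
print(simulate_generation(prompts, train_batch_size=4, steps_per_generation=8))  # (2, 16)
```

Both runs train for 16 steps on the same 64 prompts, but the buffered schedule makes 2 generation calls instead of 16. The number of completions is unchanged; the saving comes from sharing each sequential decoding pass across more sequences.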

compute_loss

By default, [~Trainer.compute_loss] runs the model's forward pass and returns the loss the model computes itself, such as cross-entropy for a language model.

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    ...
    outputs = model(**inputs)
    ...
    loss = outputs["loss"] # get loss from model

    return (loss, outputs) if return_outputs else loss
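For a causal language model, the loss returned by the model is typically next-token cross-entropy. The library-free sketch here assumes the common Transformers conventions of shifting labels one position to the left and skipping positions labeled -100. `causal_lm_loss` is a hypothetical helper, and your model's actual loss function may differ, for example in how it normalizes across a batch.

```python
import math


def causal_lm_loss(logits, labels, ignore_index=-100):
    """Mean next-token cross-entropy for one sequence.

    logits: seq_len lists of vocab_size scores; labels: seq_len token ids.
    """
    total, count = 0.0, 0
    # Position t predicts token t + 1: drop the last logits row and the first label.
    for scores, target in zip(logits[:-1], labels[1:]):
        if target == ignore_index:  # e.g. padding or prompt tokens
            continue
        peak = max(scores)  # subtract the max for a numerically stable log-sum-exp
        log_norm = peak + math.log(sum(math.exp(s - peak) for s in scores))
        total += log_norm - scores[target]  # -log softmax(scores)[target]
        count += 1
    return total / count


# Uniform scores over a 2-token vocabulary give log(2) per predicted token.
print(causal_lm_loss([[0.0, 0.0]] * 3, [0, 1, 0]))     # 0.6931...
print(causal_lm_loss([[0.0, 0.0]] * 3, [0, -100, 1]))  # 0.6931... (one position ignored)
```

A trainer that needs a different objective, as DPO does, overrides compute_loss and replaces this computation entirely.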

DPO measures how strongly the policy model prefers a chosen response over a rejected one, relative to a reference model. [~trl.DPOTrainer] overrides [~Trainer.compute_loss] because the loss computation differs from standard cross-entropy in several ways:

  • the model never sees labels; it only returns logits, from which DPO calculates log-probs
  • chosen and rejected responses are concatenated into a single batch
  • a reference model calculates its own log-probs
  • the loss is a function of four log-probs: log π_chosen, log π_rejected, log π_ref_chosen, and log π_ref_rejected

None of this fits the default [~Trainer.compute_loss] method.
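Putting the four log-probs together gives the standard DPO objective for one chosen/rejected pair, where σ is the logistic sigmoid and β is a coefficient that scales the preference margin. The batch loss is the mean over pairs.

```latex
\mathcal{L}_{\text{DPO}} = -\log \sigma\Big(\beta\big[(\log \pi_{\text{chosen}} - \log \pi_{\text{ref\_chosen}}) - (\log \pi_{\text{rejected}} - \log \pi_{\text{ref\_rejected}})\big]\Big)
```

The two parenthesized differences correspond to chosen_scores and rejected_scores in the code that follows, and β multiplies their difference as a whole.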

from typing import Any

import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel


def compute_loss(
    self,
    model: PreTrainedModel | nn.Module,
    inputs: dict[str, torch.Tensor | Any],
    return_outputs=False,
    num_items_in_batch=None,
) -> torch.Tensor | tuple[torch.Tensor, dict[str, float]]:
    ...
    outputs = model(**inputs)
    logits = outputs.logits
    logps = get_logps(logits, inputs)  # get_logps (not shown): summed log-probs of each response
    chosen_logps, rejected_logps = logps.chunk(2, dim=0)  # batch is [chosen, rejected]
    with torch.no_grad():  # the reference model is frozen, so skip building a graph
        ref_logits = self.ref_model(**inputs).logits
    ref_logps = get_logps(ref_logits, inputs)
    ref_chosen_logps, ref_rejected_logps = ref_logps.chunk(2, dim=0)  # batch is [chosen, rejected]
    chosen_scores = chosen_logps - ref_chosen_logps
    rejected_scores = rejected_logps - ref_rejected_logps
    # beta scales the whole margin, not just the chosen term
    per_sequence_loss = -F.logsigmoid(self.beta * (chosen_scores - rejected_scores))
    loss = per_sequence_loss.mean()
    return (loss, outputs) if return_outputs else loss
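To sanity-check the loss with numbers, here is a scalar, library-free version of the same computation. `dpo_loss` is a hypothetical helper that takes summed sequence log-probs as plain floats, standing in for the tensors produced by get_logps; the input values are made up for illustration.

```python
import math


def dpo_loss(chosen_logp, rejected_logp, ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Scalar DPO loss for one chosen/rejected pair of summed sequence log-probs."""
    chosen_score = chosen_logp - ref_chosen_logp  # policy vs. reference on the chosen response
    rejected_score = rejected_logp - ref_rejected_logp  # same for the rejected response
    margin = beta * (chosen_score - rejected_score)  # beta scales the whole difference
    # -log(sigmoid(m)) == log(1 + exp(-m)), written to avoid overflow for large |m|
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


# No preference relative to the reference: loss is log(2).
print(dpo_loss(-11.0, -11.0, -11.0, -11.0))  # 0.6931...
# Policy favors the chosen response more than the reference does: margin = 0.1 * (1 - (-1)) = 0.2
print(dpo_loss(-10.0, -12.0, -11.0, -11.0))  # 0.598...
```

The loss falls below log(2) as the policy's preference for the chosen response grows relative to the reference, and rises above it when that preference flips.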

Next steps

  • For more real-world examples, see how [~trl.GRPOTrainer] and [~trl.DPOTrainer] extend [Trainer] in TRL, or how Axolotl builds custom trainers on top of it.
  • See the Callbacks guide if you only need to react to a training event, such as logging metrics at the end of a training step.