first commit
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
This commit is contained in:
65
docs/source/en/grad_accumulation.md
Normal file
65
docs/source/en/grad_accumulation.md
Normal file
@@ -0,0 +1,65 @@
|
||||
<!---Copyright 2026 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
-->
|
||||
|
||||
# Gradient accumulation
|
||||
|
||||
Large batches produce large activations that exhaust GPU memory. Gradient accumulation lets you train with a larger effective batch size by spreading gradient computation across multiple mini-batches.
|
||||
|
||||
Gradients accumulate across *n* mini-batches before the optimizer updates the weights. For example, with a per-device batch size of 8 and 4 accumulation steps, the effective batch size is 32.
|
||||
|
||||
```text
|
||||
Step 1: mini-batch 1 → forward → backward → grads = G₁
|
||||
Step 2: mini-batch 2 → forward → backward → grads = G₁ + G₂
|
||||
Step 3: mini-batch 3 → forward → backward → grads = G₁ + G₂ + G₃
|
||||
Step 4: mini-batch 4 → forward → backward → grads = G₁ + G₂ + G₃ + G₄
|
||||
→ optimizer.step() ← same update as if batch_size × 4
|
||||
→ zero_grad()
|
||||
```
|
||||
|
||||
Use gradient accumulation only when a larger batch doesn't fit in memory. It doesn't improve throughput over training with a true large batch.
|
||||
|
||||
Accumulate gradients for `gradient_accumulation_steps` across `per_device_train_batch_size`.
|
||||
|
||||
```py
|
||||
from transformers import TrainingArguments
|
||||
|
||||
args = TrainingArguments(
|
||||
...,
|
||||
per_device_train_batch_size=8,
|
||||
gradient_accumulation_steps=4,
|
||||
)
|
||||
```
|
||||
|
||||
## Loss scaling
|
||||
|
||||
For a [custom loss function](./trainer_recipes#custom-loss-function), include `num_items_in_batch` so [`Trainer`] divides the loss by the number of prediction targets across all mini-batches. This normalizes by tokens rather than a fixed step count with `gradient_accumulation_steps`. Otherwise, [`Trainer`] divides loss by `gradient_accumulation_steps`.
|
||||
|
||||
```py
|
||||
import torch.nn.functional as F
|
||||
|
||||
def compute_loss(outputs, labels, num_items_in_batch=None):
|
||||
logits = outputs["logits"]
|
||||
loss = F.cross_entropy(logits, labels, reduction="sum")
|
||||
return loss / num_items_in_batch
|
||||
```
|
||||
|
||||
For causal LM models, `num_items_in_batch` counts the *shifted* labels. The loss shifts labels so the prediction at position `i` targets the token at position `i + 1`, which leaves position 0 of every sequence without a target. [`Trainer`] excludes those positions and counts over `labels[..., 1:]`, so the denominator matches the number of prediction targets the loss uses. When a data collator supplies `shift_labels` directly, such as a padding-free collator, [`Trainer`] counts over that tensor instead. Other loss types, like masked LM and classification, count the full label tensor.
|
||||
|
||||
## Next steps
|
||||
|
||||
- Read the [GPU memory usage](./model_memory_anatomy) doc to understand what is driving memory usage on the GPU during training.
|
||||
- See the [Gradient checkpointing](./grad_checkpointing) guide to learn how to reduce activation memory by recomputing activations instead of caching them.
|
||||
- See the [Mixed precision training](./mixed_precision_training) guide to learn how to use lower precision data types to reduce memory and speed up training.
|
||||
- Read the [Gradient Accumulation Fix](https://unsloth.ai/blog/gradient) blog post to learn how gradient accumulation is computed.
|
||||
Reference in New Issue
Block a user