transformers/docs/source/en/attention_interface.md
陈赣 06f1fd69a6
first commit
2026-06-05 16:53:03 +08:00

Attention backends

All attention implementations perform the same computation: every token is compared to every other token. They differ in how that computation is carried out. Basic (eager) attention scales poorly because it materializes the full attention matrix in memory, and the size of that matrix grows quadratically with sequence length. The resulting memory traffic becomes a bottleneck that slows down inference. Optimized implementations rearrange the math to reduce memory traffic for faster, more affordable inference.

The [AttentionInterface] provides optimized attention implementations. It decouples the attention implementation from the model implementation to simplify experimentation with different functions. Add new backends easily with this consistent interface.
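
The decoupling described above comes down to a name-to-function lookup that the model performs at call time. The following is a toy sketch of that pattern in plain Python, not the library's code. `ToyAttentionRegistry`, `toy_eager`, `toy_logged`, and `toy_layer_forward` are hypothetical names introduced only for illustration.

```python
from typing import Callable, Dict, List

AttentionFn = Callable[[List[float]], List[float]]


class ToyAttentionRegistry:
    """Hypothetical name -> function registry (illustration only)."""

    def __init__(self) -> None:
        self._functions: Dict[str, AttentionFn] = {}

    def register(self, name: str, fn: AttentionFn) -> None:
        self._functions[name] = fn

    def __getitem__(self, name: str) -> AttentionFn:
        if name not in self._functions:
            raise KeyError(f"unknown attention backend: {name!r}")
        return self._functions[name]


registry = ToyAttentionRegistry()
calls: List[str] = []  # records which backend ran, for demonstration


def toy_eager(scores: List[float]) -> List[float]:
    # Stand-in for an attention computation: normalize the scores to sum to 1.
    total = sum(scores)
    return [s / total for s in scores]


def toy_logged(scores: List[float]) -> List[float]:
    # Same result as toy_eager, but records that it was called.
    calls.append("toy_logged")
    return toy_eager(scores)


registry.register("toy_eager", toy_eager)
registry.register("toy_logged", toy_logged)


def toy_layer_forward(scores: List[float], attn_implementation: str) -> List[float]:
    # The "model" code never names a concrete function; it only looks one up.
    return registry[attn_implementation](scores)


print(toy_layer_forward([1.0, 3.0], "toy_eager"))   # [0.25, 0.75]
print(toy_layer_forward([1.0, 3.0], "toy_logged"))  # [0.25, 0.75]
print(calls)                                         # ['toy_logged']
```

Because the layer only receives a name, swapping backends changes which function runs without touching the layer code.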

| attention backend | description |
|---|---|
| "flash_attention_3" | improves FlashAttention-2 by also overlapping operations and fusing forward and backward passes more tightly |
| "flash_attention_2" | tiles computations into smaller blocks and uses fast on-chip memory |
| "flex_attention" | framework for specifying custom attention patterns (sparse, block-local, sliding window) without writing low-level kernels by hand |
| "sdpa" | built-in PyTorch implementation of scaled dot product attention |
| "paged\|flash_attention_3" | paged version of FlashAttention-3 |
| "paged\|flash_attention_2" | paged version of FlashAttention-2 |
| "paged\|sdpa" | paged version of SDPA |
| "paged\|eager" | paged version of eager |
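
To make the "tiles computations into smaller blocks" idea concrete, here is a minimal pure-Python sketch for a single query vector. It shows only the arithmetic rearrangement (a running maximum and running sums, often called online softmax) and says nothing about GPU memory or the real kernels. The function names and the sample data are made up for illustration.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def naive_attention(q: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    # Materializes every score before normalizing (the "full matrix" approach).
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    denom = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / denom
            for d in range(len(values[0]))]


def tiled_attention(q: Vector, keys: List[Vector], values: List[Vector],
                    block_size: int = 2) -> Vector:
    # Processes keys/values block by block and never stores all scores at once.
    scale = 1.0 / math.sqrt(len(q))
    running_max = -math.inf
    denom = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block_size):
        block_scores = [dot(q, k) * scale for k in keys[start:start + block_size]]
        new_max = max(running_max, max(block_scores))
        # Rescale what was accumulated so far to the new maximum.
        correction = math.exp(running_max - new_max)  # 0.0 on the first block
        denom *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(block_scores, values[start:start + block_size]):
            w = math.exp(s - new_max)
            denom += w
            acc = [a + w * x for a, x in zip(acc, v)]
        running_max = new_max
    return [a / denom for a in acc]


q = [1.0, 0.5]
keys = [[0.2, 0.1], [1.0, -0.3], [0.5, 0.5], [-0.7, 0.9], [0.0, 1.2]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.3, 0.3]]

reference = naive_attention(q, keys, values)
tiled = tiled_attention(q, keys, values, block_size=2)
print(all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(reference, tiled)))
```

Both functions return the same vector up to floating-point rounding, which is the sense in which every backend "performs the same computation".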

Set an attention backend

Use the attn_implementation argument in [~PreTrainedModel.from_pretrained] to instantiate a model with a specific attention function.

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_2"
)

Switch between attention backends at runtime without reloading the model using [~PreTrainedModel.set_attn_implementation].

model.set_attn_implementation("sdpa")

Kernels

Download and load compiled compute kernels directly from the Hub at runtime with the Kernels library. This avoids packaging issues from mismatched PyTorch or CUDA versions.

Kernels automatically register to [AttentionInterface] upon detection. You don't need to install the FlashAttention package explicitly.

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", attn_implementation="kernels-community/flash-attn2"
)

SDPA context manager

PyTorch's scaled dot product attention (SDPA) selects the fastest attention function for CUDA backends automatically. It defaults to the PyTorch C++ implementation for other backends.

Force SDPA to use a specific implementation with the torch.nn.attention.sdpa_kernel context manager.

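import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B", attn_implementation="sdpa"
)
# example prompt; any tokenized input works
inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)

# only the FlashAttention kernel is allowed inside this block
# (that kernel needs a CUDA device and fp16/bf16 tensors)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    outputs = model.generate(**inputs)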

Backbone-specific attention

Multimodal models use different backbones for each modality. Optimize performance by assigning specific attention functions to each backbone. For example, some vision backbones perform better in fp32, which FlashAttention does not support.

Map vision backbones to different attention functions with a dict while the text backbone continues to use FlashAttention. Keys in the attention implementation must match sub-config names.

from transformers import AutoConfig, AutoModelForImageTextToText

attention_implementation_per_backbone = {"vision_config": "sdpa", "text_config": "flash_attention_2"}

# validate the keys against the checkpoint's sub-config names before loading the weights
config = AutoConfig.from_pretrained("facebook/chameleon-7b")
for key in attention_implementation_per_backbone:
    assert key in config.sub_configs, f"Invalid key in `attention_implementation`: {key!r}"

model = AutoModelForImageTextToText.from_pretrained(
    "facebook/chameleon-7b", attn_implementation=attention_implementation_per_backbone
)

Omit certain backbones from the dict to use the default attention function (SDPA).

model = AutoModelForImageTextToText.from_pretrained(
    "facebook/chameleon-7b", attn_implementation={"text_config": "flash_attention_2"}
)

Set the same attention function for all backbones with a single string.

model = AutoModelForImageTextToText.from_pretrained(
    "facebook/chameleon-7b", attn_implementation="eager"
)

Set the attention function globally with an empty key.

model = AutoModelForImageTextToText.from_pretrained(
    "facebook/chameleon-7b", attn_implementation={"": "eager"}
)
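
The forms of `attn_implementation` described in this section (a single string, a dict keyed by sub-config name, an omitted key, and the empty key) can be summarized with a toy resolver. This is a sketch of the rules as this guide states them, not the library's resolution code. `resolve_attention` is a hypothetical helper, and the precedence it uses when a dict contains both `""` and a specific key (the specific key wins) is an assumption.

```python
from typing import Dict, List, Union


def resolve_attention(
    spec: Union[str, Dict[str, str]],
    backbones: List[str],
    default: str = "sdpa",
) -> Dict[str, str]:
    """Hypothetical helper: return the attention function chosen for each backbone."""
    if isinstance(spec, str):
        # A single string applies to every backbone.
        return {name: spec for name in backbones}
    unknown = [key for key in spec if key != "" and key not in backbones]
    if unknown:
        raise ValueError(f"keys do not match any backbone: {unknown}")
    # Assumption: the empty key acts as the fallback for backbones not listed.
    fallback = spec.get("", default)
    return {name: spec.get(name, fallback) for name in backbones}


backbones = ["vision_config", "text_config"]
print(resolve_attention({"vision_config": "sdpa", "text_config": "flash_attention_2"}, backbones))
print(resolve_attention({"text_config": "flash_attention_2"}, backbones))
print(resolve_attention("eager", backbones))
print(resolve_attention({"": "eager"}, backbones))
```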

Create a new attention function

Customize or create new attention functions by adding them to the attention registry with [AttentionInterface.register]. Models use these functions through the attn_implementation argument.

Warning

Register a matching attention mask function when you register a custom attention function. If the custom attn_implementation name is not registered in [AttentionMaskInterface], Transformers skips mask creation and passes attention_mask=None to the attention layers. Your attention function must handle causal, padding, packing, or sliding-window constraints itself, or those constraints can be silently dropped.

This example customizes the attention function to print a statement for each layer. It keeps the mask in the original implementation by registering masking_utils.sdpa_mask as the attention mask function.

import torch
from transformers import AutoModelForCausalLM, AttentionInterface, AttentionMaskInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward
from transformers.masking_utils import sdpa_mask

def my_new_sdpa(*args, **kwargs):
    print("I just entered the attention computation")
    return sdpa_attention_forward(*args, **kwargs)

AttentionInterface.register("my_new_sdpa", my_new_sdpa)
AttentionMaskInterface.register("my_new_sdpa", sdpa_mask)  # must have the same name as the registered attention function

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="my_new_sdpa")
model(torch.ones(1, 5, dtype=int))

You can also add new arguments to the attention function. Models supporting [AttentionInterface] propagate kwargs to attention layers and the attention function. Pass arguments as kwargs in the model's forward function. Custom attention functions must follow this signature and return format.

from typing import Optional

import torch
from transformers import AutoModelForCausalLM, AttentionInterface, AttentionMaskInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward
from transformers.masking_utils import sdpa_mask

def custom_attention(
    module: torch.nn.Module,  # required arg
    query: torch.Tensor,  # required arg
    key: torch.Tensor,  # required arg
    value: torch.Tensor,  # required arg
    attention_mask: Optional[torch.Tensor],  # required arg
    a_new_kwargs = None,  # You can now add as many kwargs as you need
    another_new_kwargs = None,  # You can now add as many kwargs as you need
    **kwargs,  # You need to accept **kwargs as models will pass other args
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    ...  # do your magic!
    return attn_output, attn_weights  # attn_weights are optional here

AttentionInterface.register("custom", custom_attention)
AttentionMaskInterface.register("custom", sdpa_mask)  # to leave the existing mask untouched

model_id = "meta-llama/Llama-3.2-1B"  # example checkpoint, as used elsewhere in this guide
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="custom")
model(torch.ones(1, 5, dtype=int), a_new_kwargs=..., another_new_kwargs=...)

Check a model's modeling code to confirm what arguments and kwargs it sends to the attention function.
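
The reason a custom attention function has to accept `**kwargs` can be shown with a toy model in plain Python. Everything here is hypothetical (`toy_model_forward`, `strict_attention`, `tolerant_attention`, and the `layer_idx` and `dropout` arguments the toy model sends). It only mimics the idea that a model forwards its own arguments alongside yours.

```python
from typing import Callable, List


def toy_model_forward(hidden: List[float], attention_fn: Callable, **kwargs) -> List[float]:
    # Hypothetical model: each "layer" sends its own arguments plus every caller kwarg.
    for layer_idx in range(2):
        hidden = attention_fn(hidden, layer_idx=layer_idx, dropout=0.0, **kwargs)
    return hidden


def strict_attention(hidden: List[float], scale: float = 1.0) -> List[float]:
    # No **kwargs: any argument the model adds raises a TypeError.
    return [h * scale for h in hidden]


def tolerant_attention(hidden: List[float], scale: float = 1.0, **kwargs) -> List[float]:
    # `scale` is the custom argument; everything else the model sends lands in kwargs.
    return [h * scale for h in hidden]


print(toy_model_forward([1.0, 2.0], tolerant_attention, scale=0.5))  # [0.25, 0.5]

try:
    toy_model_forward([1.0, 2.0], strict_attention, scale=0.5)
except TypeError as err:
    print("strict_attention failed:", err)
```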

AttentionMaskInterface

[AttentionMaskInterface] is the registry the create_*_mask functions consult to convert a mask into the format the active attention backend expects. FlexAttention needs a BlockMask, SDPA needs a 4D tensor, and FlashAttention needs the base 2D padding mask. Register a custom backend, or override the formatter for an existing one, with [AttentionMaskInterface.register].

import torch
from transformers import AttentionMaskInterface
from transformers.masking_utils import sdpa_mask

def my_new_sdpa_mask(*args, **kwargs):
    print("I just entered the attention mask computation")
    return sdpa_mask(*args, **kwargs)

AttentionMaskInterface.register("my_new_sdpa_mask", my_new_sdpa_mask)

Without a registered formatter for the active attn_implementation, mask creation is skipped and attention_mask=None passes to the attention layers.

Registered functions must match this signature.

from typing import Callable, Optional

import torch
from transformers.masking_utils import causal_mask_function

def custom_attention_mask(
    batch_size: int,  # required arg
    q_length: int,  # required arg
    kv_length: int,  # required arg
    q_offset: int = 0,  # required arg
    kv_offset: int = 0,  # required arg
    mask_function: Callable = causal_mask_function,  # required arg
    attention_mask: Optional[torch.Tensor] = None,  # required arg
    **kwargs,  # additional args may be passed as kwargs; the model's config is always passed
) -> Optional[torch.Tensor]:
    ...  # build and return the mask in the format your backend expects

The mask_function argument is a Callable that mimics PyTorch's mask_mod functions. It takes 4 indices (batch_idx, head_idx, q_idx, kv_idx) and returns a boolean indicating whether that position contributes to the attention computation. This is the same primitive shape used by or_mask_function and and_mask_function in Build an attention mask.
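
As a rough illustration of the 4-index signature, the sketch below evaluates a mask function at every (query, key) position to produce a boolean grid. It uses plain Python integers and nested lists rather than tensors. `materialize` and `render` are hypothetical helpers, and the way the offsets are applied is an assumption made for this example.

```python
from typing import Callable, List

MaskFunction = Callable[[int, int, int, int], bool]


def causal(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
    # A position contributes if the key is at or before the query.
    return kv_idx <= q_idx


def materialize(
    mask_function: MaskFunction,
    q_length: int,
    kv_length: int,
    q_offset: int = 0,
    kv_offset: int = 0,
) -> List[List[bool]]:
    # Hypothetical helper: one row per query position, one column per key position.
    return [
        [mask_function(0, 0, q + q_offset, kv + kv_offset) for kv in range(kv_length)]
        for q in range(q_length)
    ]


def render(grid: List[List[bool]]) -> List[str]:
    return ["".join("x" if allowed else "." for allowed in row) for row in grid]


# Prefill: 3 queries against 3 keys.
print(render(materialize(causal, q_length=3, kv_length=3)))
# ['x..', 'xx.', 'xxx']

# Decoding one new token with 3 tokens already cached: the query sits at position 3.
print(render(materialize(causal, q_length=1, kv_length=4, q_offset=3)))
# ['xxxx']
```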

Tip

Use this workaround for torch.export if mask_function fails to create a mask.

Build an attention mask

Build attention masks with the create_*_mask functions in transformers.masking_utils. Each function reads the active attention backend from the model config, looks up the backend's mask formatter in [AttentionMaskInterface], and returns the format that backend expects. You don't need to invert, expand, or cast the mask yourself.

Pick the function that matches the attention pattern.

| function | use case |
|---|---|
| [create_causal_mask] | decoder-only models where each token attends to itself and earlier tokens |
| [create_bidirectional_mask] | encoder models, or cross-attention from a decoder to encoder states |
| [create_sliding_window_causal_mask] | decoder models with a sliding-window attention pattern |
| [create_chunked_causal_mask] | decoder models that chunk the sequence into fixed-size blocks |
| [create_bidirectional_sliding_window_mask] | encoder models with a sliding-window attention pattern |
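
The patterns in the table can be pictured as small boolean grids. The sketch below draws each one for a 5-token sequence using plain Python predicates. These predicates are illustrative stand-ins, not the library's functions, and the exact boundary conventions (`window`, `chunk`, and `radius`) are assumptions chosen for the example.

```python
from typing import Callable, List

MaskFunction = Callable[[int, int, int, int], bool]


def causal(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
    return kv_idx <= q_idx


def bidirectional(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
    return True


def sliding_window_causal(window: int) -> MaskFunction:
    # Assumption: each query sees itself and the previous `window - 1` tokens.
    def mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        return kv_idx <= q_idx and q_idx - kv_idx < window
    return mask


def chunked_causal(chunk: int) -> MaskFunction:
    # Causal attention restricted to the fixed-size block that contains the query.
    def mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        return kv_idx <= q_idx and q_idx // chunk == kv_idx // chunk
    return mask


def bidirectional_sliding_window(radius: int) -> MaskFunction:
    # Assumption: a symmetric window of `radius` tokens on each side of the query.
    def mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
        return abs(q_idx - kv_idx) <= radius
    return mask


def render(mask_function: MaskFunction, length: int = 5) -> List[str]:
    return [
        "".join("x" if mask_function(0, 0, q, kv) else "." for kv in range(length))
        for q in range(length)
    ]


print(render(causal))                           # ['x....', 'xx...', 'xxx..', 'xxxx.', 'xxxxx']
print(render(sliding_window_causal(2)))         # ['x....', 'xx...', '.xx..', '..xx.', '...xx']
print(render(chunked_causal(2)))                # ['x....', 'xx...', '..x..', '..xx.', '....x']
print(render(bidirectional_sliding_window(1)))  # ['xx...', 'xxx..', '.xxx.', '..xxx', '...xx']
```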

Warning

The legacy callable mask helpers (get_extended_attention_mask, create_extended_attention_mask_for_decoder, and invert_attention_mask) emit a deprecation warning and will be removed in a future release. Use the create_*_mask functions instead.

Call [create_causal_mask] inside a decoder forward pass. Pass the config, the input embeddings, the user-provided 2D attention_mask, and the cache. The function uses the embeddings to read the batch size, query length, dtype, and device, and uses the cache to compute the key length.

from transformers.masking_utils import create_causal_mask

attention_mask = create_causal_mask(
    config=self.config,
    inputs_embeds=inputs_embeds,
    attention_mask=attention_mask,
    past_key_values=past_key_values,
)

Call [create_bidirectional_mask] for encoder self-attention. Drop past_key_values because encoders don't cache.

from transformers.masking_utils import create_bidirectional_mask

attention_mask = create_bidirectional_mask(
    config=self.config,
    inputs_embeds=embedding_output,
    attention_mask=attention_mask,
)

For cross-attention, pass the encoder states as encoder_hidden_states so the mask uses the encoder's key and value length instead of the decoder's query length.

encoder_attention_mask = create_bidirectional_mask(
    config=self.config,
    inputs_embeds=embedding_output,
    attention_mask=encoder_attention_mask,
    encoder_hidden_states=encoder_hidden_states,
)

Add extra constraints on top of the base mask with the or_mask_function and and_mask_function arguments. Use or_mask_function to let additional positions attend, and and_mask_function to restrict the base pattern further. Both follow the 4-index mask_function signature described in AttentionMaskInterface. They take (batch_idx, head_idx, q_idx, kv_idx) and return a boolean.

Warning

or_mask_function and and_mask_function can express any attention pattern, but they're slower than the built-in patterns and are not compatible with ExecuTorch. The overhead is most noticeable on smaller models (~200M parameters), where mask creation takes a larger share of forward-pass time. Reach for them only when the standard create_*_mask functions can't express what you need.

For example, overlay a function that returns True everywhere on a causal mask to turn it into a fully bidirectional one. The union with the causal pattern lets every token attend to every other token.

mask_kwargs = {
    "config": self.config,
    "inputs_embeds": inputs_embeds,
    "attention_mask": attention_mask,
    "past_key_values": past_key_values,
    "position_ids": position_ids,
    "or_mask_function": lambda *args: torch.tensor(True, dtype=torch.bool),
}

attention_mask = create_causal_mask(**mask_kwargs)
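
The union and intersection semantics can be checked on a small grid. The sketch below is plain Python and independent of the library: `union_of`, `intersection_of`, `same_document`, and `prefix` are hypothetical helpers that only mimic how an "or" overlay widens a base pattern and an "and" overlay narrows it.

```python
from typing import Callable, List

MaskFunction = Callable[[int, int, int, int], bool]


def causal(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
    return kv_idx <= q_idx


def always_true(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
    return True


def union_of(*functions: MaskFunction) -> MaskFunction:
    # "or" overlay: a position is allowed if any function allows it.
    def mask(b: int, h: int, q: int, kv: int) -> bool:
        return any(fn(b, h, q, kv) for fn in functions)
    return mask


def intersection_of(*functions: MaskFunction) -> MaskFunction:
    # "and" overlay: a position is allowed only if every function allows it.
    def mask(b: int, h: int, q: int, kv: int) -> bool:
        return all(fn(b, h, q, kv) for fn in functions)
    return mask


def same_document(document_ids: List[int]) -> MaskFunction:
    # Packing constraint: tokens only see tokens from their own document.
    def mask(b: int, h: int, q: int, kv: int) -> bool:
        return document_ids[q] == document_ids[kv]
    return mask


def prefix(prefix_length: int) -> MaskFunction:
    # Every query may see the first `prefix_length` tokens.
    def mask(b: int, h: int, q: int, kv: int) -> bool:
        return kv < prefix_length
    return mask


def render(mask_function: MaskFunction, length: int = 4) -> List[str]:
    return [
        "".join("x" if mask_function(0, 0, q, kv) else "." for kv in range(length))
        for q in range(length)
    ]


print(render(union_of(causal, always_true)))                        # ['xxxx', 'xxxx', 'xxxx', 'xxxx']
print(render(union_of(causal, prefix(2))))                          # ['xx..', 'xx..', 'xxx.', 'xxxx']
print(render(intersection_of(causal, same_document([0, 0, 1, 1])))) # ['x...', 'xx..', '..x.', '..xx']
```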

During generation, [~GenerationMixin.generate] builds masks through [create_masks_for_generate], which dispatches to the right create_*_mask based on the model config. Override it on a model class to plug in a custom masking strategy for generation.

Bidirectional attention

Decoder-only models use causal (unidirectional) attention by default, where each token only attends to itself and previous tokens. Set is_causal=False to switch to bidirectional attention, where every token attends to every other token. This lets you use decoder-only models as text encoders, for example, to generate embeddings.

Note

This only works for causal (decoder) models. It does not turn encoder models into decoder models.

Set is_causal=False in the model config to make bidirectional attention the default for every forward pass.

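from transformers import AutoConfig, AutoModel, AutoTokenizer

config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")
config.is_causal = False

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B", config=config)
inputs = tokenizer("A sentence to embed.", return_tensors="pt")  # example input

# all forward passes now use bidirectional attention
outputs = model(**inputs)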

Pass is_causal in the forward call instead of the model config to switch between causal and bidirectional attention without loading the model twice. The kwarg overrides the config value for that call only, and the config value is restored afterward.

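from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B")
inputs = tokenizer("A sentence to embed.", return_tensors="pt")  # example input

# run with bidirectional attention
outputs = model(**inputs, is_causal=False)

# run with default causal attention
outputs = model(**inputs)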
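
To see why the choice matters for embeddings, the toy single-head self-attention below (plain Python, with queries, keys, and values all equal to the input vectors) compares the two modes. Under causal attention the first token's output cannot depend on later tokens. Under bidirectional attention it does. The `attend` function and the sample vectors are made up for illustration and are unrelated to the model's actual layers.

```python
import math
from typing import List

Vector = List[float]


def attend(x: List[Vector], is_causal: bool) -> List[Vector]:
    # Toy self-attention: each token is its own query, key, and value.
    dim = len(x[0])
    outputs = []
    for q_idx in range(len(x)):
        visible = list(range(q_idx + 1)) if is_causal else list(range(len(x)))
        scores = [
            sum(a * b for a, b in zip(x[q_idx], x[kv])) / math.sqrt(dim)
            for kv in visible
        ]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        denom = sum(weights)
        outputs.append([
            sum(w * x[kv][d] for w, kv in zip(weights, visible)) / denom
            for d in range(dim)
        ])
    return outputs


sequence_a = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
sequence_b = [[1.0, 0.0], [0.0, 1.0], [-1.0, 2.0]]  # only the last token differs

# Causal: the first token never sees the last one, so its output is unchanged.
print(attend(sequence_a, is_causal=True)[0] == attend(sequence_b, is_causal=True)[0])    # True

# Bidirectional: the first token attends to the last one, so its output changes.
print(attend(sequence_a, is_causal=False)[0] == attend(sequence_b, is_causal=False)[0])  # False
```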