Caching

Imagine you're having a conversation with someone, and instead of remembering what they previously said, they have to start from scratch every time you respond. This would be slow and inefficient, right?

You can extend this analogy to transformer models. Autoregressive model generation can be slow because it makes a prediction one token at a time. Each new prediction is dependent on all the previous context.

To predict the 1000th token, the model requires information from the previous 999 tokens. The information is represented as matrix multiplications across the token representations.

To predict the 1001st token, you need the same information from the previous 999 tokens plus the information from the 1000th token. That is a lot of matrix multiplications for a model to recompute for every new token!

A key-value (KV) cache eliminates this inefficiency by storing the key-value (kv) pairs computed in the attention layers for previously processed tokens. The stored kv pairs are retrieved from the cache and reused for subsequent tokens, so they never have to be recomputed.
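As a rough illustration of the savings, the sketch below counts how many per-token key/value projections a single attention layer performs during generation, with and without a cache. The function names are hypothetical and the count ignores every other cost in the model.

```python
def kv_projections_without_cache(prompt_len: int, new_tokens: int) -> int:
    """Every step re-projects keys/values for the whole sequence so far."""
    total = 0
    for step in range(new_tokens):
        total += prompt_len + step  # number of tokens visible at this step
    return total


def kv_projections_with_cache(prompt_len: int, new_tokens: int) -> int:
    """The prompt is projected once; each later step projects only the newest token."""
    if new_tokens == 0:
        return 0
    return prompt_len + (new_tokens - 1)


# Predicting the 1000th and 1001st tokens from a 999-token context.
print(kv_projections_without_cache(999, 2))  # 999 + 1000 = 1999
print(kv_projections_with_cache(999, 2))     # 999 + 1 = 1000
```

The gap widens with every generated token, because the uncached count grows quadratically in the number of steps while the cached count grows linearly.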

Warning

Caching should only be used for inference. It may cause unexpected errors if it's enabled during training.

To better understand how and why caching works, let's take a closer look at the structure of the attention matrices.

Attention matrices

The scaled dot-product attention is calculated as shown below for a batch of size b, number of attention heads h, sequence length so far T, and dimension per attention head d_head.


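\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_{\text{head}}}} + \text{mask} \right) V

The query (Q), key (K), and value (V) matrices are projections of the input embeddings, each with shape (b, h, T, d_head). The mask is additive: it is 0 at positions a token may attend to and -\infty at positions it may not, so masked positions receive zero weight after the softmax.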

For causal attention, the mask prevents the model from attending to future tokens. Once a token is processed, its representation never changes with respect to future tokens, which means K_{\text{past}} and V_{\text{past}} can be cached and reused to compute the last token's representation.


\text{Attention}(q_t, [\underbrace{k_1, k_2, \dots, k_{t-1}}_{\text{cached}}, k_{t}], [\underbrace{v_1, v_2, \dots, v_{t-1}}_{\text{cached}}, v_{t}])
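The claim that a processed token's representation never changes can be checked numerically. The sketch below is a minimal single-head version that assumes NumPy in place of PyTorch, an additive causal mask (0 for visible positions, -inf for future ones), and random inputs. It compares the attention outputs for the first T tokens before and after one extra token is appended.

```python
import numpy as np


def causal_attention(q, k, v):
    """Scaled dot-product attention with an additive causal mask.

    q, k, v have shape (T, d_head); batch and head dimensions are omitted.
    """
    T, d_head = q.shape
    scores = q @ k.T / np.sqrt(d_head)
    # Additive mask: 0 on and below the diagonal, -inf above it (future tokens).
    mask = np.triu(np.full((T, T), -np.inf), k=1)
    scores = scores + mask
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


rng = np.random.default_rng(0)
T, d_head = 5, 8
q, k, v = (rng.standard_normal((T + 1, d_head)) for _ in range(3))

out_short = causal_attention(q[:T], k[:T], v[:T])  # first T tokens only
out_long = causal_attention(q, k, v)               # one extra token appended

# Rows for the first T tokens are unaffected by the appended token.
print(np.allclose(out_short, out_long[:T]))
```

Because earlier rows are unaffected, the keys and values that produced them can be stored once and reused.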

At inference time, you only need the last token's query to compute the representation x_t that predicts the next token t+1. At each step, the new key and value vectors are stored in the cache and appended to the past keys and values.


K_{\text{cache}} \leftarrow \text{concat}(K_{\text{past}}, k_t), \quad V_{\text{cache}} \leftarrow \text{concat}(V_{\text{past}}, v_t)
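The update rule above can be sketched as a decoding loop. The example below assumes NumPy, a single head, and random projection matrices (W_q, W_k, and W_v are illustrative, not taken from any model). It verifies that appending one key and value per step gives the same outputs as full causal attention over the whole sequence.

```python
import numpy as np


def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
T, d_model, d_head = 6, 16, 8
x = rng.standard_normal((T, d_model))  # token representations entering the layer
W_q, W_k, W_v = (rng.standard_normal((d_model, d_head)) for _ in range(3))

# Reference: full causal attention over all T tokens at once.
Q, K, V = x @ W_q, x @ W_k, x @ W_v
mask = np.triu(np.full((T, T), -np.inf), k=1)
full = softmax(Q @ K.T / np.sqrt(d_head) + mask) @ V

# Incremental decoding: project only the newest token and append it to the cache.
K_cache = np.empty((0, d_head))
V_cache = np.empty((0, d_head))
steps = []
for t in range(T):
    q_t, k_t, v_t = x[t] @ W_q, x[t] @ W_k, x[t] @ W_v
    K_cache = np.concatenate([K_cache, k_t[None, :]], axis=0)
    V_cache = np.concatenate([V_cache, v_t[None, :]], axis=0)
    # No mask is needed here: the cache only holds tokens up to position t.
    steps.append(softmax(q_t @ K_cache.T / np.sqrt(d_head)) @ V_cache)

incremental = np.stack(steps)
print(np.allclose(full, incremental))
```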

Attention is calculated independently in each layer of the model, and caching is done on a per-layer basis.

Refer to the table below to compare how caching improves efficiency.

| without caching | with caching |
|---|---|
| for each step, recompute all previous K and V | for each step, only compute current K and V |
| attention cost per step is quadratic with sequence length | attention cost per step is linear with sequence length (memory grows linearly, but compute/token remains low) |
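To put numbers on the per-step attention cost, the sketch below counts the query-key dot products needed for one decoding step. It is a simplified count: the helper name is hypothetical, and the uncached case assumes the full score matrix is recomputed without exploiting the causal mask.

```python
def score_entries_per_step(seq_len: int, use_cache: bool) -> int:
    """Number of query-key dot products needed for one decoding step."""
    if use_cache:
        # One query (the newest token) against seq_len cached keys.
        return seq_len
    # All seq_len queries against all seq_len keys are recomputed.
    return seq_len * seq_len


for seq_len in (10, 100, 1000):
    print(seq_len, score_entries_per_step(seq_len, False), score_entries_per_step(seq_len, True))
```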

Cache class

A basic KV cache interface takes a key and value tensor for the current token and returns the updated K and V tensors. This is internally managed by a model's forward method.

new_K, new_V = cache.update(k_t, v_t, layer_idx)
attn_output = attn_layer_idx_fn(q_t, new_K, new_V)
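To make the interface concrete, here is a minimal toy cache with the same `update(k_t, v_t, layer_idx)` shape. `ToyKVCache` is a hypothetical class written for illustration with NumPy arrays. It is not the Transformers implementation and omits details such as device placement and batching logic.

```python
import numpy as np


class ToyKVCache:
    """Hypothetical per-layer KV cache; not the Transformers `Cache` class."""

    def __init__(self, num_layers: int):
        self.keys = [None] * num_layers
        self.values = [None] * num_layers

    def update(self, k_t, v_t, layer_idx: int):
        """Append new keys/values of shape (batch, heads, new_tokens, head_dim)."""
        if self.keys[layer_idx] is None:
            self.keys[layer_idx], self.values[layer_idx] = k_t, v_t
        else:
            # Grow along the sequence dimension (second-to-last axis).
            self.keys[layer_idx] = np.concatenate([self.keys[layer_idx], k_t], axis=-2)
            self.values[layer_idx] = np.concatenate([self.values[layer_idx], v_t], axis=-2)
        return self.keys[layer_idx], self.values[layer_idx]


cache = ToyKVCache(num_layers=2)
prompt_kv = np.zeros((1, 4, 5, 8))  # 5 prompt tokens
step_kv = np.ones((1, 4, 1, 8))     # 1 newly generated token

cache.update(prompt_kv, prompt_kv, layer_idx=0)
new_K, new_V = cache.update(step_kv, step_kv, layer_idx=0)
print(new_K.shape)  # (1, 4, 6, 8)
```

Each layer keeps its own entry, so updating layer 0 leaves layer 1 untouched.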

When you use Transformers' [Cache] class, the self-attention module performs several critical steps to integrate past and present information.

  1. The attention module concatenates current kv pairs with past kv pairs stored in the cache. This creates attention weights with the shape (new_tokens_length, past_kv_length + new_tokens_length). The current and past kv pairs are combined to compute the attention scores, ensuring the model is aware of both the previous context and the current input.

  2. When the forward method is called iteratively, it's crucial that the attention mask shape matches the combined length of the past and current kv pairs. The attention mask should have the shape (batch_size, past_kv_length + new_tokens_length). This is typically handled internally in [~GenerationMixin.generate], but if you want to implement your own generation loop with [Cache], keep this in mind! The attention mask should hold the past and current token values.
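The shape bookkeeping in the two steps above can be sketched with placeholder arrays. This example assumes NumPy and arbitrary sizes, and the attention weights carry extra leading batch and head dimensions on top of the per-head shape described in step 1.

```python
import numpy as np

batch_size, num_heads, head_dim = 2, 4, 8
past_kv_length, new_tokens_length = 5, 1

q = np.zeros((batch_size, num_heads, new_tokens_length, head_dim))
past_k = np.zeros((batch_size, num_heads, past_kv_length, head_dim))
new_k = np.zeros((batch_size, num_heads, new_tokens_length, head_dim))

# Step 1: concatenate past and current keys along the sequence dimension.
k = np.concatenate([past_k, new_k], axis=-2)
attn_weights = q @ np.swapaxes(k, -1, -2)
print(attn_weights.shape)  # (2, 4, 1, 6)

# Step 2: extend the 2D attention mask by one column of ones per new token.
attention_mask = np.ones((batch_size, past_kv_length), dtype=np.int64)
attention_mask = np.concatenate(
    [attention_mask, np.ones((batch_size, new_tokens_length), dtype=np.int64)], axis=-1
)
print(attention_mask.shape)  # (2, 6)
```

The last dimension of the mask and of the attention weights must agree, which is why the mask has to grow alongside the cache in a hand-written generation loop.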

Cache storage implementation

Caches are structured as a list of layers, where each layer contains a key and value cache. The key and value caches are tensors with the shape [batch_size, num_heads, seq_len, head_dim].

Layers can be of different types (e.g. DynamicLayer, StaticLayer, StaticSlidingWindowLayer), which mostly changes how sequence length is handled and how the cache is updated.

The simplest is a DynamicLayer that grows as more tokens are processed. The sequence length dimension (seq_len) increases with each new token:

cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2)
cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2)

Other layer types like StaticLayer and StaticSlidingWindowLayer have a fixed sequence length that is set when the cache is created. This makes them compatible with torch.compile. In the case of StaticSlidingWindowLayer, existing tokens are shifted out of the cache when a new token is added.
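The difference between a fixed-size layer and a sliding-window layer can be sketched as follows. Both classes are hypothetical toys built on NumPy that store keys only (values would be handled the same way). They do not reflect how StaticLayer or StaticSlidingWindowLayer are actually implemented, and the static toy does not guard against writing past its capacity.

```python
import numpy as np


class ToyStaticLayer:
    """Hypothetical fixed-size layer: storage is pre-allocated and written in place."""

    def __init__(self, max_len: int, head_dim: int):
        self.keys = np.zeros((max_len, head_dim))
        self.length = 0  # number of filled positions

    def update(self, k_t):
        self.keys[self.length] = k_t  # the array shape never changes
        self.length += 1
        return self.keys


class ToySlidingWindowLayer:
    """Hypothetical fixed-size layer that keeps only the most recent `window` tokens."""

    def __init__(self, window: int, head_dim: int):
        self.keys = np.zeros((window, head_dim))
        self.length = 0

    def update(self, k_t):
        window = self.keys.shape[0]
        if self.length < window:
            self.keys[self.length] = k_t
        else:
            # Shift everything left by one slot, dropping the oldest token.
            self.keys[:-1] = self.keys[1:].copy()
            self.keys[-1] = k_t
        self.length += 1
        return self.keys


static = ToyStaticLayer(max_len=4, head_dim=1)
static.update(np.array([7.0]))
print(static.keys.shape)  # (4, 1)

layer = ToySlidingWindowLayer(window=3, head_dim=1)
for token_value in [1.0, 2.0, 3.0, 4.0, 5.0]:
    keys = layer.update(np.array([token_value]))
print(keys.ravel().tolist())  # [3.0, 4.0, 5.0]
```

Since neither toy ever changes the shape of its storage, the tensor sizes are known ahead of time, which is the property that makes fixed-size caches friendly to compilation.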

The example below demonstrates how to create a generation loop with [DynamicCache]. As discussed, the attention mask is a concatenation of past and current token values.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
from accelerate import Accelerator

device = Accelerator().device

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

past_key_values = DynamicCache(config=model.config)
messages = [{"role": "user", "content": "Hello, what's your name."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)

generated_ids = inputs.input_ids
max_new_tokens = 10

for _ in range(max_new_tokens):
    outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    # Greedily sample one next token
    next_token_ids = outputs.logits[:, -1:].argmax(-1)
    generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)
    # Prepare inputs for the next step: pass only the unprocessed tokens (here, the single new token)
    # and extend the attention mask by one position to cover it, as explained above
    attention_mask = inputs["attention_mask"]
    attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
    inputs = {"input_ids": next_token_ids, "attention_mask": attention_mask}

print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
"[INST] Hello, what's your name. [/INST]  Hello! My name is LLaMA,"