first commit
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
This commit is contained in:
325
docs/source/en/generation_strategies.md
Normal file
325
docs/source/en/generation_strategies.md
Normal file
@@ -0,0 +1,325 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
|
||||
rendered properly in your Markdown viewer.
|
||||
|
||||
-->
|
||||
|
||||
# Generation strategies
|
||||
|
||||
A decoding strategy informs how a model should select the next generated token. There are many types of decoding strategies, and choosing the appropriate one has a significant impact on the quality of the generated text.
|
||||
|
||||
This guide will help you understand the different decoding strategies available in Transformers and how and when to use them.
|
||||
|
||||
## Basic decoding methods
|
||||
|
||||
These are well established decoding methods, and should be your starting point for text generation tasks.
|
||||
|
||||
### Greedy search
|
||||
|
||||
Greedy search is the default decoding strategy. It selects the next most likely token at each step. Unless specified in [`GenerationConfig`], this strategy generates a maximum of 20 new tokens.
|
||||
|
||||
Greedy search works well for tasks with relatively short outputs where creativity is not a priority. However, it breaks down when generating longer sequences because it begins to repeat itself.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from accelerate import Accelerator
|
||||
|
||||
device = Accelerator().device
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to(device)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", dtype=torch.float16).to(device)
|
||||
# explicitly set to default length because Llama2 generation length is 4096
|
||||
outputs = model.generate(**inputs, max_new_tokens=20)
|
||||
tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
'Hugging Face is an open-source company that provides a suite of tools and services for building, deploying, and maintaining natural language processing'
|
||||
```
|
||||
|
||||
### Sampling
|
||||
|
||||
Sampling, or multinomial sampling, randomly selects a token based on the probability distribution over the entire model's vocabulary (as opposed to the most likely token, as in greedy search). This means every token with a non-zero probability has a chance to be selected. Sampling strategies reduce repetition and can generate more creative and diverse outputs.
|
||||
|
||||
Enable multinomial sampling with `do_sample=True` and `num_beams=1`.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from accelerate import Accelerator
|
||||
|
||||
device = Accelerator().device
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to(device)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", dtype=torch.float16).to(device)
|
||||
# explicitly set to 100 because Llama2 generation length is 4096
|
||||
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, num_beams=1)
|
||||
tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
'Hugging Face is an open-source company 🤗\nWe are open-source and believe that open-source is the best way to build technology. Our mission is to make AI accessible to everyone, and we believe that open-source is the best way to achieve that.'
|
||||
```
|
||||
|
||||
### Beam search
|
||||
|
||||
Beam search keeps track of several generated sequences (beams) at each time step. After a certain number of steps, it selects the sequence with the highest *overall* probability. Unlike greedy search, this strategy can "look ahead" and pick a sequence with a higher probability overall even if the initial tokens have a lower probability. It is best suited for input-grounded tasks, like describing an image or speech recognition. You can also use `do_sample=True` with beam search to sample at each step, but beam search will still greedily prune out low probability sequences between steps.
|
||||
|
||||
> [!TIP]
|
||||
> Check out the [beam search visualizer](https://huggingface.co/spaces/m-ric/beam_search_visualizer) to see how beam search works.
|
||||
|
||||
Enable beam search with the `num_beams` parameter (should be greater than 1 otherwise it's equivalent to greedy search).
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from accelerate import Accelerator
|
||||
|
||||
device = Accelerator().device
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to(device)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", dtype=torch.float16).to(device)
|
||||
# explicitly set to 100 because Llama2 generation length is 4096
|
||||
outputs = model.generate(**inputs, max_new_tokens=50, num_beams=2)
|
||||
tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
"['Hugging Face is an open-source company that develops and maintains the Hugging Face platform, which is a collection of tools and libraries for building and deploying natural language processing (NLP) models. Hugging Face was founded in 2018 by Thomas Wolf']"
|
||||
```
|
||||
|
||||
## Custom generation methods
|
||||
|
||||
Custom generation methods enable specialized behavior such as:
|
||||
|
||||
- have the model continue thinking if it is uncertain;
|
||||
- roll back generation if the model gets stuck;
|
||||
- handle special tokens with custom logic;
|
||||
- use specialized KV caches;
|
||||
|
||||
We enable custom generation methods through model repositories, assuming a specific model tag and file structure (see subsection below). This feature is an extension of [custom modeling code](./models.md#custom-models) and, like such, requires setting `trust_remote_code=True`.
|
||||
|
||||
If a model repository holds a custom generation method, the easiest way to try it out is to load the model and generate with it:
|
||||
|
||||
```py
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
# `transformers-community/custom_generate_example` holds a copy of `Qwen/Qwen2.5-0.5B-Instruct`, but
|
||||
# with custom generation code -> calling `generate` uses the custom generation method!
|
||||
tokenizer = AutoTokenizer.from_pretrained("transformers-community/custom_generate_example")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"transformers-community/custom_generate_example", device_map="auto", trust_remote_code=True
|
||||
)
|
||||
|
||||
inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)
|
||||
# The custom generation method is a minimal greedy decoding implementation. It also prints a custom message at run time.
|
||||
gen_out = model.generate(**inputs)
|
||||
# you should now see its custom message, "✨ using a custom generation method ✨"
|
||||
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True))
|
||||
'The quick brown fox jumps over a lazy dog, and the dog is a type of animal. Is'
|
||||
```
|
||||
|
||||
Model repositories with custom generation methods have a special property: their generation method can be loaded from **any** model through [`~GenerationMixin.generate`]'s `custom_generate` argument. This means anyone can create and share their custom generation method to potentially work with any Transformers model, without requiring users to install additional Python packages.
|
||||
|
||||
```py
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", device_map="auto")
|
||||
|
||||
inputs = tokenizer(["The quick brown"], return_tensors="pt").to(model.device)
|
||||
# `custom_generate` replaces the original `generate` by the custom generation method defined in
|
||||
# `transformers-community/custom_generate_example`
|
||||
gen_out = model.generate(**inputs, custom_generate="transformers-community/custom_generate_example", trust_remote_code=True)
|
||||
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True)[0])
|
||||
'The quick brown fox jumps over a lazy dog, and the dog is a type of animal. Is'
|
||||
```
|
||||
|
||||
You should read the `README.md` file of the repository containing the custom generation strategy to see what the new arguments and output type differences are, if they exist. Otherwise, you can assume it works like the base [`~GenerationMixin.generate`] method.
|
||||
|
||||
> [!TIP]
|
||||
> You can find all custom generation methods by [searching for their custom tag.](https://huggingface.co/models?other=custom_generate), `custom_generate`.
|
||||
|
||||
Consider the Hub repository [transformers-community/custom_generate_example](https://huggingface.co/transformers-community/custom_generate_example) as an example. The `README.md` states that it has an additional input argument, `left_padding`, which adds a number of padding tokens before the prompt.
|
||||
|
||||
```py
|
||||
gen_out = model.generate(
|
||||
**inputs, custom_generate="transformers-community/custom_generate_example", trust_remote_code=True, left_padding=5
|
||||
)
|
||||
print(tokenizer.batch_decode(gen_out)[0])
|
||||
'<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>The quick brown fox jumps over the lazy dog.\n\nThe sentence "The quick'
|
||||
```
|
||||
|
||||
If the custom method has pinned Python requirements that your environment doesn't meet, you'll get an exception about missing requirements. For instance, [transformers-community/custom_generate_bad_requirements](https://huggingface.co/transformers-community/custom_generate_bad_requirements) has an impossible set of requirements defined in its `custom_generate/requirements.txt` file, and you'll see the error message below if you try to run it.
|
||||
|
||||
```text
|
||||
ImportError: Missing requirements in your local environment for `transformers-community/custom_generate_bad_requirements`:
|
||||
foo (installed: None)
|
||||
bar==0.0.0 (installed: None)
|
||||
torch>=99.0 (installed: 2.6.0)
|
||||
```
|
||||
|
||||
Updating your Python requirements accordingly will remove this error message.
|
||||
|
||||
### Creating a custom generation method
|
||||
|
||||
To create a new generation method, you need to create a new [**Model**](https://huggingface.co/new) repository and push a few files into it.
|
||||
|
||||
1. The model you've designed your generation method with.
|
||||
2. `custom_generate/generate.py`, which contains all the logic for your custom generation method.
|
||||
3. `custom_generate/requirements.txt`, used to optionally add new Python requirements and/or lock specific versions to correctly use your method.
|
||||
4. `README.md`, where you should add the `custom_generate` tag and document any new arguments or output type differences of your custom method here.
|
||||
|
||||
After you've added all required files, your repository should look like this
|
||||
|
||||
```text
|
||||
your_repo/
|
||||
├── README.md # include the 'custom_generate' tag
|
||||
├── config.json
|
||||
├── ...
|
||||
└── custom_generate/
|
||||
├── generate.py
|
||||
└── requirements.txt
|
||||
```
|
||||
|
||||
#### Adding the base model
|
||||
|
||||
The starting point for your custom generation method is a model repository just like any other. The model to add to this repository should be the model you've designed your method with, and it is meant to be part of a working self-contained model-generate pair. When the model in this repository is loaded, your custom generation method will override `generate`. Don't worry -- your generation method can still be loaded with any other Transformers model, as explained in the section above.
|
||||
|
||||
If you simply want to copy an existing model, you can do
|
||||
|
||||
```py
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("source/model_repo")
|
||||
model = AutoModelForCausalLM.from_pretrained("source/model_repo")
|
||||
tokenizer.save_pretrained("your/generation_method", push_to_hub=True)
|
||||
model.save_pretrained("your/generation_method", push_to_hub=True)
|
||||
```
|
||||
|
||||
#### generate.py
|
||||
|
||||
This is the core of your generation method. It *must* contain a method named `generate`, and this method *must* contain a `model` argument as its first argument. `model` is the model instance, which means you have access to all attributes and methods in the model, including the ones defined in [`GenerationMixin`] (like the base `generate` method).
|
||||
|
||||
> [!WARNING]
|
||||
> `generate.py` must be placed in a folder named `custom_generate`, and not at the root level of the repository. The file paths for this feature are hardcoded.
|
||||
|
||||
Under the hood, when the base [`~GenerationMixin.generate`] method is called with a `custom_generate` argument, it first checks its Python requirements (if any), then locates the custom `generate` method in `generate.py`, and finally calls the custom `generate`. All received arguments and `model` are forwarded to your custom `generate` method, with the exception of the arguments used to trigger the custom generation (`trust_remote_code` and `custom_generate`).
|
||||
|
||||
This means your `generate` can have a mix of original and custom arguments (as well as a different output type) as shown below.
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
def generate(model, input_ids, generation_config=None, left_padding=None, **kwargs):
|
||||
generation_config = generation_config or model.generation_config # default to the model generation config
|
||||
cur_length = input_ids.shape[1]
|
||||
max_length = generation_config.max_length or cur_length + generation_config.max_new_tokens
|
||||
|
||||
# Example of custom argument: add `left_padding` (integer) pad tokens before the prompt
|
||||
if left_padding is not None:
|
||||
if not isinstance(left_padding, int) or left_padding < 0:
|
||||
raise ValueError(f"left_padding must be an integer larger than 0, but is {left_padding}")
|
||||
|
||||
pad_token = kwargs.pop("pad_token", None) or generation_config.pad_token_id or model.config.pad_token_id
|
||||
if pad_token is None:
|
||||
raise ValueError("pad_token is not defined")
|
||||
batch_size = input_ids.shape[0]
|
||||
pad_tensor = torch.full(size=(batch_size, left_padding), fill_value=pad_token).to(input_ids.device)
|
||||
input_ids = torch.cat((pad_tensor, input_ids), dim=1)
|
||||
cur_length = input_ids.shape[1]
|
||||
|
||||
# Simple greedy decoding loop
|
||||
while cur_length < max_length:
|
||||
logits = model(input_ids).logits
|
||||
next_token_logits = logits[:, -1, :]
|
||||
next_tokens = torch.argmax(next_token_logits, dim=-1)
|
||||
input_ids = torch.cat((input_ids, next_tokens[:, None]), dim=-1)
|
||||
cur_length += 1
|
||||
|
||||
return input_ids
|
||||
```
|
||||
|
||||
Follow the recommended practices below to ensure your custom generation method works as expected.
|
||||
|
||||
- Feel free to reuse the logic for validation and input preparation in the original [`~GenerationMixin.generate`].
|
||||
- Pin the `transformers` version in the requirements if you use any private method/attribute in `model`.
|
||||
- Consider adding model validation, input validation, or even a separate test file to help users sanity-check your code in their environment.
|
||||
|
||||
Your custom `generate` method can relative import code from the `custom_generate` folder. For example, if you have a `utils.py` file, you can import it like this:
|
||||
|
||||
```py
|
||||
from .utils import some_function
|
||||
```
|
||||
|
||||
Only relative imports from the same-level `custom_generate` folder are supported. Parent/sibling folder imports are not valid. The `custom_generate` argument also works locally with any directory that contains a `custom_generate` structure. This is the recommended workflow for developing your custom generation method.
|
||||
|
||||
#### requirements.txt
|
||||
|
||||
You can optionally specify additional Python requirements in a `requirements.txt` file inside the `custom_generate` folder. These are checked at runtime and an exception will be thrown if they're missing, nudging users to update their environment accordingly.
|
||||
|
||||
#### README.md
|
||||
|
||||
The root level `README.md` in the model repository usually describes the model therein. However, since the focus of the repository is the custom generation method, we highly recommend to shift its focus towards describing the custom generation method. In addition to a description of the method, we recommend documenting any input and/or output differences to the original [`~GenerationMixin.generate`]. This way, users can focus on what's new, and rely on Transformers docs for generic implementation details.
|
||||
|
||||
For discoverability, we highly recommend you to add the `custom_generate` tag to your repository. To do so, the top of your `README.md` file should look like the example below. After you push the file, you should see the tag in your repository!
|
||||
|
||||
```text
|
||||
---
|
||||
library_name: transformers
|
||||
tags:
|
||||
- custom_generate
|
||||
---
|
||||
|
||||
(your markdown content here)
|
||||
```
|
||||
|
||||
Recommended practices:
|
||||
|
||||
- Document input and output differences in [`~GenerationMixin.generate`].
|
||||
- Add self-contained examples to enable quick experimentation.
|
||||
- Describe soft-requirements such as if the method only works well with a certain family of models.
|
||||
|
||||
### Reusing `generate`'s input preparation
|
||||
|
||||
If you're adding a new decoding loop, you might want to preserve the input preparation present in `generate` (batch expansion, attention masks, logits processors, stopping criteria, etc.). You can also pass a **callable** to `custom_generate` to reuse [`~GenerationMixin.generate`]'s full preparation pipeline while overriding only the decoding loop.
|
||||
|
||||
```py
|
||||
def custom_loop(model, input_ids, attention_mask, logits_processor, stopping_criteria, generation_config, **model_kwargs):
|
||||
next_tokens = input_ids
|
||||
while input_ids.shape[1] < stopping_criteria[0].max_length:
|
||||
logits = model(next_tokens, attention_mask=attention_mask, **model_kwargs).logits
|
||||
next_token_logits = logits_processor(input_ids, logits[:, -1, :])
|
||||
next_tokens = torch.argmax(next_token_logits, dim=-1)[:, None]
|
||||
input_ids = torch.cat((input_ids, next_tokens), dim=-1)
|
||||
attention_mask = torch.cat((attention_mask, torch.ones_like(next_tokens)), dim=-1)
|
||||
return input_ids
|
||||
|
||||
output = model.generate(
|
||||
**inputs,
|
||||
custom_generate=custom_loop,
|
||||
max_new_tokens=10,
|
||||
)
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> If you publish a `custom_generate` repository, your `generate` implementation can itself define a callable and pass it to `model.generate()`. This lets you customize the decoding loop while still benefiting from Transformers' built-in input preparation logic.
|
||||
|
||||
### Finding custom generation methods
|
||||
|
||||
You can find all custom generation methods by [searching for their custom tag.](https://huggingface.co/models?other=custom_generate), `custom_generate`. In addition to the tag, we curate two collections of `custom_generate` methods:
|
||||
|
||||
- [Custom generation methods - Community](https://huggingface.co/collections/transformers-community/custom-generation-methods-community-6888fb1da0efbc592d3a8ab6) -- a collection of powerful methods contributed by the community;
|
||||
- [Custom generation methods - Tutorials](https://huggingface.co/collections/transformers-community/custom-generation-methods-tutorials-6823589657a94940ea02cfec) -- a collection of reference implementations for methods that previously were part of `transformers`, as well as tutorials for `custom_generate`.
|
||||
|
||||
## Resources
|
||||
|
||||
Read the [How to generate text: using different decoding methods for language generation with Transformers](https://huggingface.co/blog/how-to-generate) blog post for an explanation of how common decoding strategies work.
|
||||
Reference in New Issue
Block a user