first commit
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
This commit is contained in:
20
examples/modular-transformers/README.md
Normal file
20
examples/modular-transformers/README.md
Normal file
@@ -0,0 +1,20 @@
|
||||
# Using the `modular_converter` linter
|
||||
|
||||
`pip install libcst` is a must!
|
||||
|
||||
# `sh examples/modular-transformers/convert_examples.sh` to get the converted outputs
|
||||
|
||||
The modular converter is a new `linter` specific to `transformers`. It allows us to unpack inheritance in python to convert a modular file like `modular_gemma.py` into a `single model single file`.
|
||||
|
||||
Examples of possible usage are available in the `examples/modular-transformers`, or `modular_gemma` for a full model usage.
|
||||
|
||||
`python utils/modular_model_converter.py --files_to_parse "/Users/arthurzucker/Work/transformers/examples/modular-transformers/modular_my_new_model2.py"`
|
||||
|
||||
## How it works
|
||||
We use the `libcst` parser to produce an AST representation of the `modular_xxx.py` file. For any imports that are made from `transformers.models.modeling_xxxx` we parse the source code of that module, and build a class dependency mapping, which allows us to unpack the modularerence dependencies.
|
||||
|
||||
The code from the `modular` file and the class dependency mapping are "merged" to produce the single model single file.
|
||||
We use ruff to automatically remove the potential duplicate imports.
|
||||
|
||||
## Why we use libcst instead of the native AST?
|
||||
AST is super powerful, but it does not keep the `docstring`, `comment` or code formatting. Thus we decided to go with `libcst`
|
||||
@@ -0,0 +1,95 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from examples/modular-transformers/modular_duplicated_method.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_duplicated_method.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
|
||||
from huggingface_hub.dataclasses import strict
|
||||
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...modeling_rope_utils import RopeParameters
|
||||
from ...utils import auto_docstring
|
||||
from ...utils.type_validators import interval
|
||||
|
||||
|
||||
@auto_docstring(checkpoint="meta-duplicated_method/DuplicatedMethod-2-7b-hf")
|
||||
@strict
|
||||
class DuplicatedMethodConfig(PreTrainedConfig):
|
||||
r"""
|
||||
```python
|
||||
>>> from transformers import DuplicatedMethodModel, DuplicatedMethodConfig
|
||||
|
||||
>>> # Initializing a DuplicatedMethod duplicated_method-7b style configuration
|
||||
>>> configuration = DuplicatedMethodConfig()
|
||||
|
||||
>>> # Initializing a model from the duplicated_method-7b style configuration
|
||||
>>> model = DuplicatedMethodModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "duplicated_method"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
# Default tensor parallel plan for base model `DuplicatedMethodModel`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
vocab_size: int = 32000
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 11008
|
||||
num_hidden_layers: int = 32
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int | None = None
|
||||
hidden_act: str = "silu"
|
||||
max_position_embeddings: int = 2048
|
||||
initializer_range: float = interval(min=0.0, max=1.0)(default=0.02)
|
||||
rms_norm_eps: float = 1e-6
|
||||
use_cache: bool = True
|
||||
pad_token_id: int | None = None
|
||||
bos_token_id: int | None = 1
|
||||
eos_token_id: int | list[int] | None = 2
|
||||
pretraining_tp: int | None = 1
|
||||
tie_word_embeddings: bool = False
|
||||
rope_parameters: RopeParameters | dict | None = None
|
||||
attention_bias: bool = False
|
||||
attention_dropout: int | float | None = 0.0
|
||||
mlp_bias: bool = False
|
||||
head_dim: int | None = None
|
||||
|
||||
def __post_init__(self, **kwargs):
|
||||
if self.head_dim is None:
|
||||
self.head_dim = self.hidden_size // self.num_attention_heads
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
super().__post_init__(**kwargs)
|
||||
|
||||
def validate_architecture(self):
|
||||
"""Part of `@strict`-powered validation. Validates the architecture of the config."""
|
||||
if self.hidden_size % self.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
|
||||
f"heads ({self.num_attention_heads})."
|
||||
)
|
||||
|
||||
@property
|
||||
def vocab_size(self): # noqa: F811 -> we need this at we cannot delete the original for now since config dataclass refactor
|
||||
return 45
|
||||
|
||||
@vocab_size.setter
|
||||
def vocab_size(self, value):
|
||||
self.vocab_size = value
|
||||
192
examples/modular-transformers/configuration_my_new_model.py
Normal file
192
examples/modular-transformers/configuration_my_new_model.py
Normal file
@@ -0,0 +1,192 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from examples/modular-transformers/modular_my_new_model.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_my_new_model.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
|
||||
from huggingface_hub.dataclasses import strict
|
||||
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...modeling_rope_utils import RopeParameters
|
||||
from ...utils import auto_docstring
|
||||
from ...utils.type_validators import interval
|
||||
|
||||
|
||||
@auto_docstring(checkpoint="meta-my_new_model/MyNewModel-2-7b-hf")
|
||||
@strict
|
||||
class MyNewModelConfig(PreTrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MyNewModelModel`]. It is used to instantiate an MyNewModel
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the MyNewModel-7B.
|
||||
e.g. [meta-my_new_model/MyNewModel-2-7b-hf](https://huggingface.co/meta-my_new_model/MyNewModel-2-7b-hf)
|
||||
|
||||
Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PreTrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 32000):
|
||||
Vocabulary size of the MyNewModel model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`MyNewModelModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details, check out [this
|
||||
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with. MyNewModel 1 supports up to 2048 tokens,
|
||||
MyNewModel 2 up to 4096, CodeLlama up to 16384.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*):
|
||||
Padding token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
Beginning of stream token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
End of stream token id.
|
||||
pretraining_tp (`int`, *optional*, defaults to 1):
|
||||
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
|
||||
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
|
||||
understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
|
||||
results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'my_new_model3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'my_new_model3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'my_new_model3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'my_new_model3'. Scaling factor applied to high frequency components of the RoPE
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
mlp_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
|
||||
head_dim (`int`, *optional*):
|
||||
The attention head dimension. If None, it will default to hidden_size // num_attention_heads
|
||||
|
||||
```python
|
||||
>>> from transformers import MyNewModelModel, MyNewModelConfig
|
||||
|
||||
>>> # Initializing a MyNewModel my_new_model-7b style configuration
|
||||
>>> configuration = MyNewModelConfig()
|
||||
|
||||
>>> # Initializing a model from the my_new_model-7b style configuration
|
||||
>>> model = MyNewModelModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```
|
||||
"""
|
||||
|
||||
model_type = "my_new_model"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
# Default tensor parallel plan for base model `MyNewModelModel`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
vocab_size: int = 32000
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 11008
|
||||
num_hidden_layers: int = 32
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int | None = None
|
||||
hidden_act: str = "silu"
|
||||
max_position_embeddings: int = 2048
|
||||
initializer_range: float = interval(min=0.0, max=1.0)(default=0.02)
|
||||
rms_norm_eps: float = 1e-6
|
||||
use_cache: bool = True
|
||||
pad_token_id: int | None = None
|
||||
bos_token_id: int | None = 1
|
||||
eos_token_id: int | list[int] | None = 2
|
||||
pretraining_tp: int | None = 1
|
||||
tie_word_embeddings: bool = False
|
||||
rope_parameters: RopeParameters | dict | None = None
|
||||
attention_bias: bool = False
|
||||
attention_dropout: int | float | None = 0.0
|
||||
|
||||
mlp_bias: bool = True
|
||||
head_dim: int | None = None
|
||||
new_param: int = 0
|
||||
|
||||
def __post_init__(self, **kwargs):
|
||||
if self.head_dim is None:
|
||||
self.head_dim = self.hidden_size // self.num_attention_heads
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
super().__post_init__(**kwargs)
|
||||
|
||||
def validate_architecture(self):
|
||||
"""Part of `@strict`-powered validation. Validates the architecture of the config."""
|
||||
if self.hidden_size % self.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
|
||||
f"heads ({self.num_attention_heads})."
|
||||
)
|
||||
93
examples/modular-transformers/configuration_my_new_model2.py
Normal file
93
examples/modular-transformers/configuration_my_new_model2.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from examples/modular-transformers/modular_my_new_model2.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_my_new_model2.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
from huggingface_hub.dataclasses import strict
|
||||
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...modeling_rope_utils import RopeParameters
|
||||
from ...utils import auto_docstring
|
||||
from ...utils.type_validators import interval
|
||||
|
||||
|
||||
@auto_docstring(checkpoint="meta-my_new_model2/MyNewModel2-2-7b-hf")
|
||||
@strict
|
||||
class MyNewModel2Config(PreTrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the Gemma-7B.
|
||||
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
|
||||
Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PreTrainedConfig`] for more information.
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 256000):
|
||||
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`GemmaModel`]
|
||||
```python
|
||||
>>> from transformers import GemmaModel, GemmaConfig
|
||||
>>> # Initializing a Gemma gemma-7b style configuration
|
||||
>>> configuration = GemmaConfig()
|
||||
>>> # Initializing a model from the gemma-7b style configuration
|
||||
>>> model = GemmaModel(configuration)
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "my_new_model2"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
# Default tensor parallel plan for base model `MyNewModel2Model`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
vocab_size: int = 32000
|
||||
hidden_size: int = 4096
|
||||
intermediate_size: int = 11008
|
||||
num_hidden_layers: int = 32
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int | None = None
|
||||
hidden_act: str = "silu"
|
||||
max_position_embeddings: int = 2048
|
||||
initializer_range: float = interval(min=0.0, max=1.0)(default=0.02)
|
||||
rms_norm_eps: float = 1e-6
|
||||
use_cache: bool = True
|
||||
pad_token_id: int | None = None
|
||||
bos_token_id: int | None = 1
|
||||
eos_token_id: int | list[int] | None = 2
|
||||
pretraining_tp: int | None = 1
|
||||
tie_word_embeddings: bool = False
|
||||
rope_parameters: RopeParameters | dict | None = None
|
||||
attention_bias: bool = False
|
||||
attention_dropout: int | float | None = 0.0
|
||||
mlp_bias: bool = False
|
||||
head_dim: int | None = None
|
||||
|
||||
def __post_init__(self, **kwargs):
|
||||
if self.head_dim is None:
|
||||
self.head_dim = self.hidden_size // self.num_attention_heads
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
|
||||
super().__post_init__(**kwargs)
|
||||
|
||||
def validate_architecture(self):
|
||||
"""Part of `@strict`-powered validation. Validates the architecture of the config."""
|
||||
if self.hidden_size % self.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
|
||||
f"heads ({self.num_attention_heads})."
|
||||
)
|
||||
72
examples/modular-transformers/configuration_new_model.py
Normal file
72
examples/modular-transformers/configuration_new_model.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from examples/modular-transformers/modular_new_model.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_new_model.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# Example where we only want to overwrite the defaults of an init
|
||||
|
||||
from huggingface_hub.dataclasses import strict
|
||||
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...utils import auto_docstring
|
||||
|
||||
|
||||
@auto_docstring(checkpoint="google/new_model-7b")
|
||||
@strict
|
||||
class NewModelConfig(PreTrainedConfig):
|
||||
r"""
|
||||
use_bidirectional_attention (`bool`, *optional*):
|
||||
If True, the model will attend to all text tokens instead of using a causal mask.
|
||||
|
||||
```python
|
||||
>>> from transformers import NewModelModel, NewModelConfig
|
||||
>>> # Initializing a NewModel new_model-7b style configuration
|
||||
>>> configuration = NewModelConfig()
|
||||
>>> # Initializing a model from the new_model-7b style configuration
|
||||
>>> model = NewModelModel(configuration)
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "new_model"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
vocab_size: int = 256030
|
||||
hidden_size: int = 64
|
||||
intermediate_size: int = 90
|
||||
num_hidden_layers: int = 28
|
||||
num_attention_heads: int = 16
|
||||
num_key_value_heads: int = 16
|
||||
head_dim: int = 256
|
||||
hidden_act: str = "gelu_pytorch_tanh"
|
||||
max_position_embeddings: int = 1500
|
||||
initializer_range: float = 0.02
|
||||
rms_norm_eps: float = 1e-6
|
||||
use_cache: bool = True
|
||||
pad_token_id: int = 0
|
||||
eos_token_id: int = 1
|
||||
bos_token_id: int = 2
|
||||
tie_word_embeddings: bool = True
|
||||
rope_parameters: dict | None = None
|
||||
attention_bias: bool = False
|
||||
attention_dropout: float = 0.0
|
||||
use_bidirectional_attention: bool = False
|
||||
hidden_activation: str | None = None
|
||||
|
||||
@property
|
||||
def num_heads(self):
|
||||
return self.num_attention_heads
|
||||
10
examples/modular-transformers/convert_examples.sh
Normal file
10
examples/modular-transformers/convert_examples.sh
Normal file
@@ -0,0 +1,10 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Iterate over each file in the current directory
|
||||
for file in examples/modular-transformers/modular_*; do
|
||||
# Check if it's a regular file
|
||||
if [ -f "$file" ]; then
|
||||
# Call the Python script with the file name as an argument
|
||||
python utils/modular_model_converter.py --files_to_parse "$file"
|
||||
fi
|
||||
done
|
||||
@@ -0,0 +1,27 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from examples/modular-transformers/modular_new_imgproc_model.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_new_imgproc_model.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
import torch
|
||||
|
||||
from ...image_processing_backends import TorchvisionBackend
|
||||
from ...image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, PILImageResampling
|
||||
from ...utils import auto_docstring
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class ImgprocModelImageProcessor(TorchvisionBackend):
|
||||
resample = PILImageResampling.BICUBIC
|
||||
image_mean = OPENAI_CLIP_MEAN
|
||||
image_std = OPENAI_CLIP_STD
|
||||
size = {"height": 384, "width": 384}
|
||||
default_to_square = True
|
||||
do_resize = True
|
||||
do_rescale = True
|
||||
do_normalize = True
|
||||
do_convert_rgb = True
|
||||
|
||||
def new_image_processing_method(self, pixel_values: torch.FloatTensor):
|
||||
return pixel_values / 2
|
||||
66
examples/modular-transformers/modeling_add_function.py
Normal file
66
examples/modular-transformers/modeling_add_function.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from examples/modular-transformers/modular_add_function.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_add_function.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# Note that zamba does not have the `apply_rotary_pos_emb` function!
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...integrations import use_kernel_func_from_hub
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
@use_kernel_func_from_hub("rotary_pos_emb")
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class TestAttention(nn.Module):
|
||||
"""
|
||||
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
||||
and "Generating Long Sequences with Sparse Transformers".
|
||||
|
||||
Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
|
||||
The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads.
|
||||
The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer
|
||||
(see fig. 2 in https://huggingface.co/papers/2405.16712).
|
||||
Additionally, replaced
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def forward(self) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
|
||||
_ = apply_rotary_pos_emb(1, 1, 1, 1)
|
||||
642
examples/modular-transformers/modeling_dummy_bert.py
Normal file
642
examples/modular-transformers/modeling_dummy_bert.py
Normal file
@@ -0,0 +1,642 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from examples/modular-transformers/modular_dummy_bert.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_dummy_bert.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ... import initialization as init
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
|
||||
from ...masking_utils import create_bidirectional_mask, create_causal_mask
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...pytorch_utils import apply_chunking_to_forward
|
||||
from ...utils import TransformersKwargs, auto_docstring
|
||||
from ...utils.generic import merge_with_config_defaults
|
||||
from ...utils.output_capturing import capture_outputs
|
||||
from .configuration_dummy_bert import DummyBertConfig
|
||||
|
||||
|
||||
class DummyBertEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
||||
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
||||
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
||||
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer(
|
||||
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
||||
)
|
||||
self.register_buffer(
|
||||
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
token_type_ids: torch.LongTensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
past_key_values_length: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
else:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
|
||||
batch_size, seq_length = input_shape
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
||||
|
||||
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
|
||||
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
|
||||
# issue #5664
|
||||
if token_type_ids is None:
|
||||
if hasattr(self, "token_type_ids"):
|
||||
# NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0])
|
||||
buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1)
|
||||
buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids)
|
||||
token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length)
|
||||
else:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
embeddings = inputs_embeds + token_type_embeddings
|
||||
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
embeddings = embeddings + position_embeddings
|
||||
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
scaling: float | None = None,
|
||||
dropout: float = 0.0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
):
|
||||
if scaling is None:
|
||||
scaling = query.size(-1) ** -0.5
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class DummyBertSelfAttention(nn.Module):
|
||||
def __init__(self, config, is_causal=False, layer_idx=None):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
||||
f"heads ({config.num_attention_heads})"
|
||||
)
|
||||
self.config = config
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
self.is_causal = is_causal
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.FloatTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
||||
|
||||
# get all proj
|
||||
query_layer = self.query(hidden_states).view(*hidden_shape).transpose(1, 2)
|
||||
key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
|
||||
value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)
|
||||
|
||||
if past_key_values is not None:
|
||||
# decoder-only dummy_bert can have a simple dynamic cache for example
|
||||
current_past_key_values = past_key_values
|
||||
if isinstance(past_key_values, EncoderDecoderCache):
|
||||
current_past_key_values = past_key_values.self_attention_cache
|
||||
|
||||
# save all key/value_layer to cache to be re-used for fast auto-regressive generation
|
||||
key_layer, value_layer = current_past_key_values.update(key_layer, value_layer, self.layer_idx)
|
||||
|
||||
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
||||
self.config._attn_implementation, eager_attention_forward
|
||||
)
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.dropout.p,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class DummyBertCrossAttention(nn.Module):
|
||||
def __init__(self, config, is_causal=False, layer_idx=None):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
||||
f"heads ({config.num_attention_heads})"
|
||||
)
|
||||
self.config = config
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
self.is_causal = is_causal
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.FloatTensor | None = None,
|
||||
attention_mask: torch.FloatTensor | None = None,
|
||||
past_key_values: EncoderDecoderCache | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor]:
|
||||
# determine input shapes
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
|
||||
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
||||
|
||||
# get query proj
|
||||
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
|
||||
if past_key_values is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
|
||||
value_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].values
|
||||
else:
|
||||
kv_shape = (*encoder_hidden_states.shape[:-1], -1, self.attention_head_size)
|
||||
key_layer = self.key(encoder_hidden_states).view(kv_shape).transpose(1, 2)
|
||||
value_layer = self.value(encoder_hidden_states).view(kv_shape).transpose(1, 2)
|
||||
|
||||
if past_key_values is not None:
|
||||
# save all states to the cache
|
||||
key_layer, value_layer = past_key_values.cross_attention_cache.update(
|
||||
key_layer, value_layer, self.layer_idx
|
||||
)
|
||||
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
|
||||
past_key_values.is_updated[self.layer_idx] = True
|
||||
|
||||
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
||||
self.config._attn_implementation, eager_attention_forward
|
||||
)
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.dropout.p,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class DummyBertSelfOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DummyBertAttention(nn.Module):
|
||||
def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False):
|
||||
super().__init__()
|
||||
self.is_cross_attention = is_cross_attention
|
||||
attention_class = DummyBertCrossAttention if is_cross_attention else DummyBertSelfAttention
|
||||
self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx)
|
||||
self.output = DummyBertSelfOutput(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.FloatTensor | None = None,
|
||||
encoder_hidden_states: torch.FloatTensor | None = None,
|
||||
encoder_attention_mask: torch.FloatTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor]:
|
||||
attention_mask = attention_mask if not self.is_cross_attention else encoder_attention_mask
|
||||
attention_output, attn_weights = self.self(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self.output(attention_output, hidden_states)
|
||||
return attention_output, attn_weights
|
||||
|
||||
|
||||
class DummyBertIntermediate(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DummyBertOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DummyBertLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config, layer_idx=None):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = DummyBertAttention(config, is_causal=config.is_decoder, layer_idx=layer_idx)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = DummyBertAttention(
|
||||
config,
|
||||
is_causal=False,
|
||||
layer_idx=layer_idx,
|
||||
is_cross_attention=True,
|
||||
)
|
||||
self.intermediate = DummyBertIntermediate(config)
|
||||
self.output = DummyBertOutput(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.FloatTensor | None = None,
|
||||
encoder_hidden_states: torch.FloatTensor | None = None,
|
||||
encoder_attention_mask: torch.FloatTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> torch.Tensor:
|
||||
self_attention_output, _ = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_output
|
||||
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
cross_attention_output, _ = self.crossattention(
|
||||
self_attention_output,
|
||||
None, # attention_mask
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = cross_attention_output
|
||||
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
return layer_output
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
|
||||
class DummyBertEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([DummyBertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.FloatTensor | None = None,
|
||||
encoder_hidden_states: torch.FloatTensor | None = None,
|
||||
encoder_attention_mask: torch.FloatTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
use_cache: bool | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor] | BaseModelOutputWithPastAndCrossAttentions:
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
hidden_states = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values if use_cache else None,
|
||||
)
|
||||
|
||||
|
||||
class DummyBertPooler(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# We "pool" the model by simply taking the hidden state corresponding
|
||||
# to the first token.
|
||||
first_token_tensor = hidden_states[:, 0]
|
||||
pooled_output = self.dense(first_token_tensor)
|
||||
pooled_output = self.activation(pooled_output)
|
||||
return pooled_output
|
||||
|
||||
|
||||
class DummyBertPredictionHeadTransform(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.transform_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.transform_act_fn = config.hidden_act
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.transform_act_fn(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DummyBertLMPredictionHead(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.transform = DummyBertPredictionHeadTransform(config)
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class DummyBertPreTrainedModel(PreTrainedModel):
|
||||
config_class = DummyBertConfig
|
||||
base_model_prefix = "dummy_bert"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
_can_record_outputs = {
|
||||
"hidden_states": DummyBertLayer,
|
||||
"attentions": DummyBertSelfAttention,
|
||||
"cross_attentions": DummyBertCrossAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
super()._init_weights(module)
|
||||
if isinstance(module, DummyBertLMPredictionHead):
|
||||
init.zeros_(module.bias)
|
||||
elif isinstance(module, DummyBertEmbeddings):
|
||||
init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
|
||||
init.zeros_(module.token_type_ids)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
||||
cross-attention is added between the self-attention layers, following the architecture described in [Attention is
|
||||
all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
||||
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
||||
|
||||
To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
|
||||
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
|
||||
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
|
||||
"""
|
||||
)
|
||||
class DummyBertModel(DummyBertPreTrainedModel):
|
||||
_no_split_modules = ["DummyBertEmbeddings", "DummyBertLayer"]
|
||||
|
||||
def __init__(self, config, add_pooling_layer=True):
|
||||
r"""
|
||||
add_pooling_layer (bool, *optional*, defaults to `True`):
|
||||
Whether to add a pooling layer
|
||||
"""
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.embeddings = DummyBertEmbeddings(config)
|
||||
self.encoder = DummyBertEncoder(config)
|
||||
|
||||
self.pooler = DummyBertPooler(config) if add_pooling_layer else None
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.word_embeddings = value
|
||||
|
||||
@merge_with_config_defaults
|
||||
@capture_outputs
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
token_type_ids: torch.Tensor | None = None,
|
||||
position_ids: torch.Tensor | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
encoder_hidden_states: torch.Tensor | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
return_dict: bool | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = (
|
||||
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
|
||||
if encoder_hidden_states is not None or self.config.is_encoder_decoder
|
||||
else DynamicCache(config=self.config)
|
||||
)
|
||||
|
||||
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
attention_mask, encoder_attention_mask = self._create_attention_masks(
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
embedding_output=embedding_output,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
position_ids=position_ids,
|
||||
**kwargs,
|
||||
)
|
||||
sequence_output = encoder_outputs.last_hidden_state
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
)
|
||||
|
||||
def _create_attention_masks(
|
||||
self,
|
||||
attention_mask,
|
||||
encoder_attention_mask,
|
||||
embedding_output,
|
||||
encoder_hidden_states,
|
||||
past_key_values,
|
||||
):
|
||||
if self.config.is_decoder:
|
||||
attention_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
inputs_embeds=embedding_output,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
else:
|
||||
attention_mask = create_bidirectional_mask(
|
||||
config=self.config,
|
||||
inputs_embeds=embedding_output,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
if encoder_attention_mask is not None:
|
||||
encoder_attention_mask = create_bidirectional_mask(
|
||||
config=self.config,
|
||||
inputs_embeds=embedding_output,
|
||||
attention_mask=encoder_attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
return attention_mask, encoder_attention_mask
|
||||
145
examples/modular-transformers/modeling_from_uppercase_model.py
Normal file
145
examples/modular-transformers/modeling_from_uppercase_model.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from examples/modular-transformers/modular_from_uppercase_model.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_from_uppercase_model.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TransformersKwargs
|
||||
from .configuration_from_uppercase_model import FromUppercaseModelTextConfig, FromUppercaseModelVisionConfig
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class FromUppercaseModelAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: FromUppercaseModelVisionConfig | FromUppercaseModelTextConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
self.is_causal = False
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
queries = self.q_proj(hidden_states)
|
||||
keys = self.k_proj(hidden_states)
|
||||
values = self.v_proj(hidden_states)
|
||||
|
||||
queries = queries.view(hidden_shape).transpose(1, 2)
|
||||
keys = keys.view(hidden_shape).transpose(1, 2)
|
||||
values = values.view(hidden_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
||||
self.config._attn_implementation, eager_attention_forward
|
||||
)
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
attention_mask,
|
||||
scaling=self.scale,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class FromUppercaseModelMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FromUppercaseModelEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: FromUppercaseModelVisionConfig | FromUppercaseModelTextConfig):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = FromUppercaseModelAttention(config)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.mlp = FromUppercaseModelMLP(config)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, _ = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
166
examples/modular-transformers/modeling_global_indexing.py
Normal file
166
examples/modular-transformers/modeling_global_indexing.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from examples/modular-transformers/modular_global_indexing.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_global_indexing.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers.modeling_utils import AttentionInterface
|
||||
|
||||
from ...cache_utils import Cache
|
||||
from ...integrations import use_kernel_func_from_hub, use_kernelized_func
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TransformersKwargs
|
||||
from .configuration_global_indexing import GlobalIndexingConfig
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
@use_kernel_func_from_hub("rotary_pos_emb")
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
def custom_flex(x, **kwargs):
|
||||
"""Dummy function."""
|
||||
return x
|
||||
|
||||
|
||||
ALL_ATTENTION_FUNCTIONS = AttentionInterface()
|
||||
# This indexing statement and associated function should be exported correctly!
|
||||
ALL_ATTENTION_FUNCTIONS["flex_attention"] = custom_flex
|
||||
|
||||
|
||||
@use_kernelized_func(apply_rotary_pos_emb)
|
||||
class GlobalIndexingAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: GlobalIndexingConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.is_causal = True
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_values is not None:
|
||||
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
|
||||
|
||||
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
||||
self.config._attn_implementation, eager_attention_forward
|
||||
)
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
364
examples/modular-transformers/modeling_multimodal2.py
Normal file
364
examples/modular-transformers/modeling_multimodal2.py
Normal file
@@ -0,0 +1,364 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from examples/modular-transformers/modular_multimodal2.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_multimodal2.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TransformersKwargs, auto_docstring, torch_int
|
||||
from ...utils.generic import merge_with_config_defaults
|
||||
from ...utils.output_capturing import capture_outputs
|
||||
from .configuration_multimodal2 import Multimodal2Config, Multimodal2TextConfig, Multimodal2VisionConfig
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Multimodal2VisionAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: Multimodal2VisionConfig | Multimodal2TextConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
self.is_causal = False
|
||||
|
||||
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
queries = self.q_proj(hidden_states)
|
||||
keys = self.k_proj(hidden_states)
|
||||
values = self.v_proj(hidden_states)
|
||||
|
||||
queries = queries.view(hidden_shape).transpose(1, 2)
|
||||
keys = keys.view(hidden_shape).transpose(1, 2)
|
||||
values = values.view(hidden_shape).transpose(1, 2)
|
||||
|
||||
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
||||
self.config._attn_implementation, eager_attention_forward
|
||||
)
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
attention_mask,
|
||||
scaling=self.scale,
|
||||
dropout=0.0 if not self.training else self.dropout,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class Multimodal2VisionMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Multimodal2VisionEncoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = Multimodal2VisionAttention(config)
|
||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.mlp = Multimodal2VisionMLP(config)
|
||||
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states, _ = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Multimodal2VisionEncoder(nn.Module):
|
||||
"""
|
||||
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
||||
[`Multimodal2VisionEncoderLayer`].
|
||||
|
||||
Args:
|
||||
config: Multimodal2VisionConfig
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList([Multimodal2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> BaseModelOutput:
|
||||
hidden_states = inputs_embeds
|
||||
for encoder_layer in self.layers:
|
||||
hidden_states = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states,
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class Multimodal2VisionPreTrainedModel(PreTrainedModel):
|
||||
config: Multimodal2Config
|
||||
base_model_prefix = "multimodal2_vision"
|
||||
input_modalities = ("image", "text")
|
||||
_no_split_modules = [
|
||||
"Multimodal2VisionTextEmbeddings",
|
||||
"Multimodal2VisionEncoderLayer",
|
||||
"Multimodal2VisionVisionEmbeddings",
|
||||
]
|
||||
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_sdpa = True
|
||||
_supports_flash_attn = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
_can_record_outputs = {
|
||||
"hidden_states": Multimodal2VisionEncoderLayer,
|
||||
"attentions": Multimodal2VisionAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
if isinstance(module, Multimodal2VisionMLP):
|
||||
pass
|
||||
|
||||
|
||||
class Multimodal2VisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: Multimodal2VisionConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches + 1
|
||||
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
||||
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
||||
"""
|
||||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
|
||||
images. This method is also adapted to support torch.jit tracing.
|
||||
|
||||
Adapted from:
|
||||
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
|
||||
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
|
||||
"""
|
||||
|
||||
num_patches = embeddings.shape[1] - 1
|
||||
position_embedding = self.position_embedding.weight.unsqueeze(0)
|
||||
num_positions = position_embedding.shape[1] - 1
|
||||
|
||||
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
|
||||
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
|
||||
return self.position_embedding(self.position_ids)
|
||||
|
||||
class_pos_embed = position_embedding[:, :1]
|
||||
patch_pos_embed = position_embedding[:, 1:]
|
||||
|
||||
dim = embeddings.shape[-1]
|
||||
|
||||
new_height = height // self.patch_size
|
||||
new_width = width // self.patch_size
|
||||
|
||||
sqrt_num_positions = torch_int(num_positions**0.5)
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed,
|
||||
size=(new_height, new_width),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
|
||||
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
|
||||
batch_size, _, height, width = pixel_values.shape
|
||||
if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
|
||||
raise ValueError(
|
||||
f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
|
||||
)
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
if interpolate_pos_encoding:
|
||||
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
||||
else:
|
||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||
return embeddings
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The vision model from MULTIMODAL2 without any head or projection on top.
|
||||
"""
|
||||
)
|
||||
class Multimodal2VisionModel(Multimodal2VisionPreTrainedModel):
|
||||
config: Multimodal2VisionConfig
|
||||
main_input_name = "pixel_values"
|
||||
input_modalities = ("image",)
|
||||
_input_embed_layer = "patch_embedding"
|
||||
_no_split_modules = ["Multimodal2VisionEncoderLayer"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
self.embeddings = Multimodal2VisionEmbeddings(config)
|
||||
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
self.encoder = Multimodal2VisionEncoder(config)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
self.post_init()
|
||||
|
||||
@merge_with_config_defaults
|
||||
@capture_outputs(tie_last_hidden_states=False)
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
interpolate_pos_encoding: bool | None = False,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> BaseModelOutputWithPooling:
|
||||
r"""
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import httpx
|
||||
>>> from io import BytesIO
|
||||
>>> from transformers import AutoProcessor, Multimodal2VisionModel
|
||||
|
||||
>>> model = Multimodal2VisionModel.from_pretrained("openai/multimodal2-vit-base-patch32")
|
||||
>>> processor = AutoProcessor.from_pretrained("openai/multimodal2-vit-base-patch32")
|
||||
|
||||
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
|
||||
>>> with httpx.stream("GET", url) as response:
|
||||
... image = Image.open(BytesIO(response.read()))
|
||||
|
||||
>>> inputs = processor(images=image, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(**inputs)
|
||||
>>> last_hidden_state = outputs.last_hidden_state
|
||||
>>> pooled_output = outputs.pooler_output # pooled CLS states
|
||||
```"""
|
||||
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
||||
hidden_states = self.pre_layrnorm(hidden_states)
|
||||
|
||||
encoder_outputs: BaseModelOutput = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs.last_hidden_state
|
||||
pooled_output = last_hidden_state[:, 0, :]
|
||||
pooled_output = self.post_layernorm(pooled_output)
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
)
|
||||
283
examples/modular-transformers/modeling_my_new_model2.py
Normal file
283
examples/modular-transformers/modeling_my_new_model2.py
Normal file
@@ -0,0 +1,283 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from examples/modular-transformers/modular_my_new_model2.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_my_new_model2.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ... import initialization as init
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache
|
||||
from ...integrations import use_kernel_func_from_hub, use_kernelized_func
|
||||
from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TransformersKwargs, auto_docstring
|
||||
from .configuration_my_new_model2 import MyNewModel2Config
|
||||
|
||||
|
||||
class MyNewModel2TextScaledWordEmbedding(nn.Embedding):
|
||||
"""
|
||||
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx)
|
||||
self.scalar_embed_scale = embed_scale
|
||||
self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
|
||||
|
||||
def forward(self, input_ids: torch.Tensor):
|
||||
return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)
|
||||
|
||||
|
||||
class MyNewModel2RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.zeros(dim))
|
||||
|
||||
def _norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
output = self._norm(x.float())
|
||||
# Llama does x.to(float16) * w whilst MyNewModel2 is (x * w).to(float16)
|
||||
# See https://github.com/huggingface/transformers/pull/29402
|
||||
output = output * (1.0 + self.weight.float())
|
||||
return output.type_as(x)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.eps}"
|
||||
|
||||
|
||||
class MyNewModel2MLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
@use_kernel_func_from_hub("rotary_pos_emb")
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
@use_kernelized_func(apply_rotary_pos_emb)
|
||||
class MyNewModel2Attention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: MyNewModel2Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.is_causal = not getattr(config, "use_bidirectional_attention", False)
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_values is not None:
|
||||
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
|
||||
|
||||
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
||||
self.config._attn_implementation, eager_attention_forward
|
||||
)
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class MyNewModel2DecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: MyNewModel2Config, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = MyNewModel2Attention(config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = MyNewModel2MLP(config)
|
||||
self.input_layernorm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
use_cache: bool | None = False,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
# Self Attention
|
||||
hidden_states, _ = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class MyNewModel2PreTrainedModel(PreTrainedModel):
|
||||
config: MyNewModel2Config
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["MyNewModel2DecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
||||
_can_compile_fullgraph = True
|
||||
_supports_attention_backend = True
|
||||
_can_record_outputs = {
|
||||
"hidden_states": MyNewModel2DecoderLayer,
|
||||
"attentions": MyNewModel2Attention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
super()._init_weights(module)
|
||||
# We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
|
||||
if "RMSNorm" in module.__class__.__name__:
|
||||
init.zeros_(module.weight)
|
||||
elif isinstance(module, MyNewModel2TextScaledWordEmbedding):
|
||||
init.constant_(module.embed_scale, module.scalar_embed_scale)
|
||||
|
||||
|
||||
class MyNewModel2ForSequenceClassification(GenericForSequenceClassification, MyNewModel2PreTrainedModel):
|
||||
pass
|
||||
516
examples/modular-transformers/modeling_new_task_model.py
Normal file
516
examples/modular-transformers/modeling_new_task_model.py
Normal file
@@ -0,0 +1,516 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from examples/modular-transformers/modular_new_task_model.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_new_task_model.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...cache_utils import Cache
|
||||
from ...configuration_utils import PreTrainedConfig
|
||||
from ...generation import GenerationMixin
|
||||
from ...masking_utils import create_masks_for_generate
|
||||
from ...modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
|
||||
from ...utils.deprecation import deprecate_kwarg
|
||||
from ..auto import AutoModel
|
||||
from .configuration_new_task_model import NewTaskModelConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for NewTaskModel outputs, with hidden states and attentions.
|
||||
"""
|
||||
)
|
||||
class NewTaskModelModelOutputWithPast(BaseModelOutputWithPast):
|
||||
r"""
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
||||
image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
|
||||
"""
|
||||
|
||||
image_hidden_states: torch.FloatTensor | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
Base class for NewTaskModel causal language model (or autoregressive) outputs.
|
||||
"""
|
||||
)
|
||||
class NewTaskModelCausalLMOutputWithPast(ModelOutput):
|
||||
r"""
|
||||
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
||||
Language modeling loss (for next-token prediction).
|
||||
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`):
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
||||
It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
|
||||
|
||||
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
||||
`past_key_values` input) to speed up sequential decoding.
|
||||
image_hidden_states (`torch.FloatTensor`, *optional*):
|
||||
A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
|
||||
image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
|
||||
"""
|
||||
|
||||
loss: torch.FloatTensor | None = None
|
||||
logits: torch.FloatTensor | None = None
|
||||
past_key_values: Cache | None = None
|
||||
hidden_states: tuple[torch.FloatTensor] | None = None
|
||||
attentions: tuple[torch.FloatTensor] | None = None
|
||||
image_hidden_states: torch.FloatTensor | None = None
|
||||
|
||||
|
||||
class NewTaskModelMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: NewTaskModelConfig):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear(image_features)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class NewTaskModelPreTrainedModel(PreTrainedModel):
|
||||
config: NewTaskModelConfig
|
||||
base_model_prefix = "model"
|
||||
input_modalities = ("image", "text")
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["NewTaskModelMultiModalProjector"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_can_compile_fullgraph = False
|
||||
_supports_flash_attn = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
|
||||
|
||||
def token_type_ids_mask_function(group_ids: torch.Tensor) -> Callable:
|
||||
"""
|
||||
This function adds the correct offsets to the `q_idx` and `kv_idx` as the torch API can only accept lengths,
|
||||
not start and end indices.
|
||||
Args:
|
||||
group_ids (`torch.Tensor`):
|
||||
A tensor of shape `(bs, len)` assigning each token to a vision group. Tokens with the same group
|
||||
come from the same input image. Text is denoted by `-1`.
|
||||
"""
|
||||
|
||||
def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
|
||||
seq_length = group_ids.shape[-1]
|
||||
|
||||
# clamp indices because with static cache they can go beyond `group_ids.shape[-1]`
|
||||
q_idx_clamped = q_idx.clamp(max=seq_length - 1)
|
||||
kv_idx_clamped = kv_idx.clamp(max=seq_length - 1)
|
||||
|
||||
# Unmask if the q and kv come from same group which is not -1 (i.e. non-text)
|
||||
q_group = group_ids[batch_idx, q_idx_clamped]
|
||||
kv_group = group_ids[batch_idx, kv_idx_clamped]
|
||||
q_group = torch.where(q_idx < seq_length, q_group, -1)
|
||||
kv_group = torch.where(kv_idx < seq_length, kv_group, -1)
|
||||
return (q_group == kv_group) & (q_group >= 0)
|
||||
|
||||
return inner_mask
|
||||
|
||||
|
||||
@deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
|
||||
def create_causal_mask_mapping(
|
||||
config: PreTrainedConfig,
|
||||
inputs_embeds: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
past_key_values: Cache | None,
|
||||
position_ids: torch.Tensor | None,
|
||||
token_type_ids: torch.Tensor | None = None,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
is_training: bool | None = False,
|
||||
is_first_iteration: bool | None = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""
|
||||
Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
|
||||
for all kinds of forward passes. NewTaskModel uses a bidirectional mask on the prompt tokens.
|
||||
|
||||
Uses `pixel_values` as an optional input to disambiguate edge cases.
|
||||
"""
|
||||
if is_training and token_type_ids is None:
|
||||
raise ValueError("`token_type_ids` is required as a model input when training")
|
||||
|
||||
mask_kwargs = {
|
||||
"config": config.get_text_config(),
|
||||
"inputs_embeds": inputs_embeds,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past_key_values,
|
||||
"position_ids": position_ids,
|
||||
}
|
||||
# Infer if prefill or decoding stage, if the flag isn't passed. This happens only when the mask is constructed
|
||||
# from `forward` call. If users run a `forward` call, we have no option to infer `is_first_iteration` because users may be
|
||||
# running generation with custom loop. Thus we need to infer it in a `non-perfect` way
|
||||
# NOTE: Determining prefill in that case requires checking data values, which is not compile-compatible.
|
||||
is_first_iteration = (
|
||||
is_first_iteration
|
||||
if is_first_iteration
|
||||
else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
|
||||
)
|
||||
|
||||
if is_first_iteration or not kwargs.get("use_cache", True):
|
||||
if token_type_ids is not None:
|
||||
# The logic bellow was originally written for Gemma3, where `token_type_ids` is reversed. Let's reverse
|
||||
# it to then use exactly the same logic.
|
||||
token_type_ids = 1 - token_type_ids
|
||||
else:
|
||||
logger.warning_once(
|
||||
"It is a prefill stage but The `token_type_ids` is not provided. We recommend "
|
||||
"passing `token_type_ids` to the model to prevent bad attention masking."
|
||||
)
|
||||
# NOTE: this branch can't be reached when training because `token_type_ids` is required as a model input.
|
||||
token_type_ids = torch.ones_like(inputs_embeds)[:, :, 0]
|
||||
|
||||
# Logic originally copied from Gemma3. It holds up for NewTaskModel as well because NewTaskModel assumes up to one image
|
||||
# per prompt AND we reverse `token_type_ids` above. Gemma3 uses a bidirectional mask for images, tagged through
|
||||
# `token_type_ids` 1s.
|
||||
if token_type_ids is not None and is_first_iteration:
|
||||
# We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
|
||||
# undo the causal masking)
|
||||
|
||||
# First find where a new image block starts: 1 if image and previous not image
|
||||
# The images cannot attend to future images, but can attend to all prev images and to itself bidirectionally
|
||||
is_image = (token_type_ids == 1).to(inputs_embeds.device)
|
||||
is_previous_image = nn.functional.pad(is_image, (1, 0), value=0)[:, :-1]
|
||||
new_image_start = is_image & ~is_previous_image
|
||||
group_ids = torch.cumsum(new_image_start.int(), dim=1) - 1
|
||||
group_ids = torch.where(is_image, group_ids, torch.full_like(token_type_ids, -1))
|
||||
mask_kwargs["or_mask_function"] = token_type_ids_mask_function(group_ids)
|
||||
|
||||
return create_masks_for_generate(**mask_kwargs)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Base NewTaskModel model which consists of a vision backbone and a language model without language modeling head.,
|
||||
"""
|
||||
)
|
||||
class NewTaskModelModel(NewTaskModelPreTrainedModel):
|
||||
# we are filtering the logits/labels so we shouldn't divide the loss based on num_items_in_batch
|
||||
accepts_loss_kwargs = False
|
||||
|
||||
def __init__(self, config: NewTaskModelConfig):
|
||||
super().__init__(config)
|
||||
self.vision_tower = AutoModel.from_config(config=config.vision_config)
|
||||
self.multi_modal_projector = NewTaskModelMultiModalProjector(config)
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
|
||||
language_model = AutoModel.from_config(config=config.text_config)
|
||||
self.language_model = language_model
|
||||
|
||||
self.text_config_dtype = self.config.get_text_config().dtype or self.dtype
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.language_model.set_input_embeddings(value)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring(
|
||||
custom_intro="Obtains image last hidden states from the vision tower and apply multimodal projection."
|
||||
)
|
||||
def get_image_features(
|
||||
self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]
|
||||
) -> tuple | BaseModelOutputWithPooling:
|
||||
image_outputs = self.vision_tower(pixel_values, **kwargs)
|
||||
selected_image_feature = image_outputs.last_hidden_state
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
image_outputs.pooler_output = image_features
|
||||
|
||||
return image_outputs
|
||||
|
||||
def get_placeholder_mask(
|
||||
self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
|
||||
):
|
||||
"""
|
||||
Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
|
||||
equal to the length of multimodal features. If the lengths are different, an error is raised.
|
||||
"""
|
||||
if input_ids is None:
|
||||
special_image_mask = inputs_embeds == self.get_input_embeddings()(
|
||||
torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = special_image_mask.all(-1)
|
||||
else:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
|
||||
n_image_tokens = special_image_mask.sum()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
|
||||
torch_compilable_check(
|
||||
inputs_embeds[special_image_mask].numel() == image_features.numel(),
|
||||
f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {n_image_features}",
|
||||
)
|
||||
return special_image_mask
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
pixel_values: torch.FloatTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
token_type_ids: torch.LongTensor | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> tuple | NewTaskModelModelOutputWithPast:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import httpx
|
||||
>>> from io import BytesIO
|
||||
>>> from transformers import AutoProcessor, NewTaskModelForConditionalGeneration
|
||||
|
||||
>>> model = NewTaskModelForConditionalGeneration.from_pretrained("google/new_task_model2-3b-mix-224")
|
||||
>>> processor = AutoProcessor.from_pretrained("google/new_task_model2-3b-mix-224")
|
||||
|
||||
>>> prompt = "Where is the cat standing?"
|
||||
>>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
|
||||
>>> with httpx.stream("GET", url) as response:
|
||||
... image = Image.open(BytesIO(response.read()))
|
||||
|
||||
>>> inputs = processor(images=image, text=prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(**inputs,)
|
||||
>>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Where is the cat standing?\nsnow"
|
||||
```"""
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
# Replace image id with PAD if the image token if OOV, to avoid index-errors
|
||||
if input_ids is not None and self.config.image_token_id >= self.vocab_size:
|
||||
special_image_mask = input_ids == self.config.image_token_id
|
||||
llm_input_ids = input_ids.clone()
|
||||
llm_input_ids[special_image_mask] = 0
|
||||
else:
|
||||
llm_input_ids = input_ids
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
||||
|
||||
if position_ids is None:
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
|
||||
position_ids = position_ids.unsqueeze(0) + 1 # NewTaskModel positions are 1-indexed
|
||||
|
||||
# Merge text and images
|
||||
if pixel_values is not None:
|
||||
image_features = self.get_image_features(pixel_values).pooler_output
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
special_image_mask = self.get_placeholder_mask(
|
||||
input_ids, inputs_embeds=inputs_embeds, image_features=image_features
|
||||
)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
|
||||
# It may already have been prepared by e.g. `generate`
|
||||
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
||||
causal_mask_mapping = create_causal_mask_mapping(
|
||||
self.config,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
past_key_values,
|
||||
position_ids,
|
||||
token_type_ids,
|
||||
pixel_values,
|
||||
is_training=self.training,
|
||||
)
|
||||
|
||||
outputs = self.language_model(
|
||||
attention_mask=causal_mask_mapping,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return NewTaskModelModelOutputWithPast(
|
||||
last_hidden_state=outputs.last_hidden_state,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
image_hidden_states=image_features if pixel_values is not None else None,
|
||||
)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The Base NewTaskModel model which consists of a vision backbone and a language model without language modeling head.,
|
||||
"""
|
||||
)
|
||||
class NewTaskModelForNewTask(NewTaskModelPreTrainedModel, GenerationMixin):
|
||||
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
|
||||
main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = NewTaskModelModel(config)
|
||||
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
||||
|
||||
self.embedding_dim = self.config.embedding_dim
|
||||
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)
|
||||
self.post_init()
|
||||
|
||||
@auto_docstring
|
||||
def get_image_features(self, pixel_values: torch.FloatTensor, **kwargs: Unpack[TransformersKwargs]):
|
||||
return self.model.get_image_features(pixel_values, **kwargs)
|
||||
|
||||
@can_return_tuple
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
token_type_ids: torch.LongTensor | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
return_dict: bool | None = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
) -> tuple | NewTaskModelCausalLMOutputWithPast:
|
||||
r"""
|
||||
Returns:
|
||||
"""
|
||||
vlm_outputs = super().forward(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
labels=labels,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
)
|
||||
last_hidden_states = vlm_outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size)
|
||||
proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim)
|
||||
|
||||
# L2 normalization
|
||||
embeddings = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
|
||||
|
||||
if attention_mask is not None:
|
||||
embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
|
||||
|
||||
return (embeddings,) + vlm_outputs
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
position_ids=None,
|
||||
pixel_values=None,
|
||||
attention_mask=None,
|
||||
token_type_ids=None,
|
||||
use_cache=True,
|
||||
logits_to_keep=None,
|
||||
labels=None,
|
||||
is_first_iteration=False,
|
||||
**kwargs,
|
||||
):
|
||||
# Overwritten -- custom `position_ids` and `pixel_values` handling
|
||||
model_inputs = super().prepare_inputs_for_generation(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
use_cache=use_cache,
|
||||
logits_to_keep=logits_to_keep,
|
||||
token_type_ids=token_type_ids,
|
||||
is_first_iteration=is_first_iteration,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# position_ids in NewTaskModel are 1-indexed
|
||||
if model_inputs.get("position_ids") is not None:
|
||||
# NOTE: we need this op out-of-place, otherwise it modifies the `model_kwargs` dict used in `generate` in-place!
|
||||
model_inputs["position_ids"] = model_inputs["position_ids"] + 1
|
||||
|
||||
# Pixel values are used only in the first iteration if available
|
||||
# In subsequent iterations, they are already merged with text and cached
|
||||
# NOTE: first iteration doesn't have to be prefill, it can be the first
|
||||
# iteration with a question and cached system prompt (continue generate from cache). NOTE: use_cache=False needs pixel_values always
|
||||
if is_first_iteration or not use_cache:
|
||||
model_inputs["pixel_values"] = pixel_values
|
||||
|
||||
return model_inputs
|
||||
|
||||
@staticmethod
|
||||
@deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
|
||||
def create_masks_for_generate(
|
||||
config: PreTrainedConfig,
|
||||
inputs_embeds: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
past_key_values: Cache | None,
|
||||
position_ids: torch.Tensor | None,
|
||||
token_type_ids: torch.Tensor | None = None,
|
||||
is_first_iteration: bool | None = False,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
# Uses the overwritten `create_masks_for_generate` with `token_type_ids` masking
|
||||
return create_causal_mask_mapping(
|
||||
config,
|
||||
inputs_embeds,
|
||||
attention_mask,
|
||||
past_key_values,
|
||||
position_ids,
|
||||
token_type_ids,
|
||||
is_first_iteration=is_first_iteration,
|
||||
**{k: v for k, v in kwargs.items() if k != "pixel_values"},
|
||||
)
|
||||
|
||||
def resize_token_embeddings(
|
||||
self, new_num_tokens: int | None = None, pad_to_multiple_of=None, mean_resizing=True
|
||||
) -> nn.Embedding:
|
||||
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
|
||||
|
||||
# Update vocab size
|
||||
self.config.text_config.vocab_size = model_embeds.num_embeddings
|
||||
self.config.vocab_size = model_embeds.num_embeddings
|
||||
self.vocab_size = model_embeds.num_embeddings
|
||||
|
||||
return model_embeds
|
||||
642
examples/modular-transformers/modeling_roberta.py
Normal file
642
examples/modular-transformers/modeling_roberta.py
Normal file
@@ -0,0 +1,642 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from examples/modular-transformers/modular_roberta.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_roberta.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ... import initialization as init
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
|
||||
from ...masking_utils import create_bidirectional_mask, create_causal_mask
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...pytorch_utils import apply_chunking_to_forward
|
||||
from ...utils import TransformersKwargs, auto_docstring
|
||||
from ...utils.generic import merge_with_config_defaults
|
||||
from ...utils.output_capturing import capture_outputs
|
||||
from .configuration_roberta import RobertaConfig
|
||||
|
||||
|
||||
class RobertaEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
||||
self.position_embeddings = nn.Embedding(
|
||||
config.max_position_embeddings, config.hidden_size, config.pad_token_id
|
||||
)
|
||||
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
||||
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
||||
self.register_buffer(
|
||||
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
||||
)
|
||||
self.register_buffer(
|
||||
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
|
||||
)
|
||||
self.pad_token_id = config.pad_token_id
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
token_type_ids: torch.LongTensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
past_key_values_length: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if input_ids is not None:
|
||||
input_shape = input_ids.size()
|
||||
else:
|
||||
input_shape = inputs_embeds.size()[:-1]
|
||||
|
||||
batch_size, seq_length = input_shape
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
|
||||
|
||||
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
|
||||
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
|
||||
# issue #5664
|
||||
if token_type_ids is None:
|
||||
if hasattr(self, "token_type_ids"):
|
||||
# NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0])
|
||||
buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1)
|
||||
buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids)
|
||||
token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length)
|
||||
else:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
embeddings = inputs_embeds + token_type_embeddings
|
||||
|
||||
position_embeddings = self.position_embeddings(position_ids)
|
||||
embeddings = embeddings + position_embeddings
|
||||
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
scaling: float | None = None,
|
||||
dropout: float = 0.0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
):
|
||||
if scaling is None:
|
||||
scaling = query.size(-1) ** -0.5
|
||||
|
||||
# Take the dot product between "query" and "key" to get the raw attention scores.
|
||||
attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class RobertaSelfAttention(nn.Module):
|
||||
def __init__(self, config, is_causal=False, layer_idx=None):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
||||
f"heads ({config.num_attention_heads})"
|
||||
)
|
||||
self.config = config
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
self.is_decoder = config.is_decoder
|
||||
self.is_causal = is_causal
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.FloatTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
||||
|
||||
# get all proj
|
||||
query_layer = self.query(hidden_states).view(*hidden_shape).transpose(1, 2)
|
||||
key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
|
||||
value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)
|
||||
|
||||
if past_key_values is not None:
|
||||
# decoder-only roberta can have a simple dynamic cache for example
|
||||
current_past_key_values = past_key_values
|
||||
if isinstance(past_key_values, EncoderDecoderCache):
|
||||
current_past_key_values = past_key_values.self_attention_cache
|
||||
|
||||
# save all key/value_layer to cache to be re-used for fast auto-regressive generation
|
||||
key_layer, value_layer = current_past_key_values.update(key_layer, value_layer, self.layer_idx)
|
||||
|
||||
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
||||
self.config._attn_implementation, eager_attention_forward
|
||||
)
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.dropout.p,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class RobertaCrossAttention(nn.Module):
|
||||
def __init__(self, config, is_causal=False, layer_idx=None):
|
||||
super().__init__()
|
||||
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
||||
raise ValueError(
|
||||
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
||||
f"heads ({config.num_attention_heads})"
|
||||
)
|
||||
self.config = config
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.scaling = self.attention_head_size**-0.5
|
||||
|
||||
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
||||
|
||||
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
self.is_causal = is_causal
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.FloatTensor | None = None,
|
||||
attention_mask: torch.FloatTensor | None = None,
|
||||
past_key_values: EncoderDecoderCache | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor]:
|
||||
# determine input shapes
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
|
||||
hidden_shape = (*input_shape, -1, self.attention_head_size)
|
||||
|
||||
# get query proj
|
||||
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
|
||||
if past_key_values is not None and is_updated:
|
||||
# reuse k,v, cross_attentions
|
||||
key_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
|
||||
value_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].values
|
||||
else:
|
||||
kv_shape = (*encoder_hidden_states.shape[:-1], -1, self.attention_head_size)
|
||||
key_layer = self.key(encoder_hidden_states).view(kv_shape).transpose(1, 2)
|
||||
value_layer = self.value(encoder_hidden_states).view(kv_shape).transpose(1, 2)
|
||||
|
||||
if past_key_values is not None:
|
||||
# save all states to the cache
|
||||
key_layer, value_layer = past_key_values.cross_attention_cache.update(
|
||||
key_layer, value_layer, self.layer_idx
|
||||
)
|
||||
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
|
||||
past_key_values.is_updated[self.layer_idx] = True
|
||||
|
||||
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
||||
self.config._attn_implementation, eager_attention_forward
|
||||
)
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.dropout.p,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class RobertaSelfOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class RobertaAttention(nn.Module):
|
||||
def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False):
|
||||
super().__init__()
|
||||
self.is_cross_attention = is_cross_attention
|
||||
attention_class = RobertaCrossAttention if is_cross_attention else RobertaSelfAttention
|
||||
self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx)
|
||||
self.output = RobertaSelfOutput(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.FloatTensor | None = None,
|
||||
encoder_hidden_states: torch.FloatTensor | None = None,
|
||||
encoder_attention_mask: torch.FloatTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor]:
|
||||
attention_mask = attention_mask if not self.is_cross_attention else encoder_attention_mask
|
||||
attention_output, attn_weights = self.self(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self.output(attention_output, hidden_states)
|
||||
return attention_output, attn_weights
|
||||
|
||||
|
||||
class RobertaIntermediate(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class RobertaOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class RobertaLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config, layer_idx=None):
|
||||
super().__init__()
|
||||
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
||||
self.seq_len_dim = 1
|
||||
self.attention = RobertaAttention(config, is_causal=config.is_decoder, layer_idx=layer_idx)
|
||||
self.is_decoder = config.is_decoder
|
||||
self.add_cross_attention = config.add_cross_attention
|
||||
if self.add_cross_attention:
|
||||
if not self.is_decoder:
|
||||
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
||||
self.crossattention = RobertaAttention(
|
||||
config,
|
||||
is_causal=False,
|
||||
layer_idx=layer_idx,
|
||||
is_cross_attention=True,
|
||||
)
|
||||
self.intermediate = RobertaIntermediate(config)
|
||||
self.output = RobertaOutput(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.FloatTensor | None = None,
|
||||
encoder_hidden_states: torch.FloatTensor | None = None,
|
||||
encoder_attention_mask: torch.FloatTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> torch.Tensor:
|
||||
self_attention_output, _ = self.attention(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = self_attention_output
|
||||
|
||||
if self.is_decoder and encoder_hidden_states is not None:
|
||||
if not hasattr(self, "crossattention"):
|
||||
raise ValueError(
|
||||
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
||||
" by setting `config.add_cross_attention=True`"
|
||||
)
|
||||
|
||||
cross_attention_output, _ = self.crossattention(
|
||||
self_attention_output,
|
||||
None, # attention_mask
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
**kwargs,
|
||||
)
|
||||
attention_output = cross_attention_output
|
||||
|
||||
layer_output = apply_chunking_to_forward(
|
||||
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
||||
)
|
||||
return layer_output
|
||||
|
||||
def feed_forward_chunk(self, attention_output):
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
return layer_output
|
||||
|
||||
|
||||
class RobertaEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer = nn.ModuleList([RobertaLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.FloatTensor | None = None,
|
||||
encoder_hidden_states: torch.FloatTensor | None = None,
|
||||
encoder_attention_mask: torch.FloatTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
use_cache: bool | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor] | BaseModelOutputWithPastAndCrossAttentions:
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
hidden_states = layer_module(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states, # as a positional argument for gradient checkpointing
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values if use_cache else None,
|
||||
)
|
||||
|
||||
|
||||
class RobertaPooler(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.activation = nn.Tanh()
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# We "pool" the model by simply taking the hidden state corresponding
|
||||
# to the first token.
|
||||
first_token_tensor = hidden_states[:, 0]
|
||||
pooled_output = self.dense(first_token_tensor)
|
||||
pooled_output = self.activation(pooled_output)
|
||||
return pooled_output
|
||||
|
||||
|
||||
class RobertaPredictionHeadTransform(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.transform_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.transform_act_fn = config.hidden_act
|
||||
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.transform_act_fn(hidden_states)
|
||||
hidden_states = self.LayerNorm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class RobertaLMPredictionHead(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.transform = RobertaPredictionHeadTransform(config)
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
# an output-only bias for each token.
|
||||
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
|
||||
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.decoder(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class RobertaPreTrainedModel(PreTrainedModel):
|
||||
config_class = RobertaConfig
|
||||
base_model_prefix = "roberta"
|
||||
supports_gradient_checkpointing = True
|
||||
_supports_flash_attn = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
_supports_attention_backend = True
|
||||
_can_record_outputs = {
|
||||
"hidden_states": RobertaLayer,
|
||||
"attentions": RobertaSelfAttention,
|
||||
"cross_attentions": RobertaCrossAttention,
|
||||
}
|
||||
|
||||
@torch.no_grad()
|
||||
def _init_weights(self, module):
|
||||
"""Initialize the weights"""
|
||||
super()._init_weights(module)
|
||||
if isinstance(module, RobertaLMPredictionHead):
|
||||
init.zeros_(module.bias)
|
||||
elif isinstance(module, RobertaEmbeddings):
|
||||
init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
|
||||
init.zeros_(module.token_type_ids)
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
||||
cross-attention is added between the self-attention layers, following the architecture described in [Attention is
|
||||
all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
||||
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
||||
|
||||
To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
|
||||
to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
|
||||
`add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
|
||||
"""
|
||||
)
|
||||
class RobertaModel(RobertaPreTrainedModel):
|
||||
_no_split_modules = ["RobertaEmbeddings", "RobertaLayer"]
|
||||
|
||||
def __init__(self, config, add_pooling_layer=True):
|
||||
r"""
|
||||
add_pooling_layer (bool, *optional*, defaults to `True`):
|
||||
Whether to add a pooling layer
|
||||
"""
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.embeddings = RobertaEmbeddings(config)
|
||||
self.encoder = RobertaEncoder(config)
|
||||
|
||||
self.pooler = RobertaPooler(config) if add_pooling_layer else None
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.word_embeddings = value
|
||||
|
||||
@merge_with_config_defaults
|
||||
@capture_outputs
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
token_type_ids: torch.Tensor | None = None,
|
||||
position_ids: torch.Tensor | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
encoder_hidden_states: torch.Tensor | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
use_cache: bool | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.config.is_decoder:
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
else:
|
||||
use_cache = False
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = (
|
||||
EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
|
||||
if encoder_hidden_states is not None or self.config.is_encoder_decoder
|
||||
else DynamicCache(config=self.config)
|
||||
)
|
||||
|
||||
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
attention_mask, encoder_attention_mask = self._create_attention_masks(
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
embedding_output=embedding_output,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
position_ids=position_ids,
|
||||
**kwargs,
|
||||
)
|
||||
sequence_output = encoder_outputs.last_hidden_state
|
||||
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
||||
|
||||
return BaseModelOutputWithPoolingAndCrossAttentions(
|
||||
last_hidden_state=sequence_output,
|
||||
pooler_output=pooled_output,
|
||||
past_key_values=encoder_outputs.past_key_values,
|
||||
)
|
||||
|
||||
def _create_attention_masks(
|
||||
self,
|
||||
attention_mask,
|
||||
encoder_attention_mask,
|
||||
embedding_output,
|
||||
encoder_hidden_states,
|
||||
past_key_values,
|
||||
):
|
||||
if self.config.is_decoder:
|
||||
attention_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
inputs_embeds=embedding_output,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
)
|
||||
else:
|
||||
attention_mask = create_bidirectional_mask(
|
||||
config=self.config,
|
||||
inputs_embeds=embedding_output,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
if encoder_attention_mask is not None:
|
||||
encoder_attention_mask = create_bidirectional_mask(
|
||||
config=self.config,
|
||||
inputs_embeds=embedding_output,
|
||||
attention_mask=encoder_attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
return attention_mask, encoder_attention_mask
|
||||
375
examples/modular-transformers/modeling_super.py
Normal file
375
examples/modular-transformers/modeling_super.py
Normal file
@@ -0,0 +1,375 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from examples/modular-transformers/modular_super.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_super.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
from collections.abc import Callable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache
|
||||
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TransformersKwargs, auto_docstring
|
||||
from ...utils.generic import maybe_autocast, merge_with_config_defaults
|
||||
from ...utils.output_capturing import capture_outputs
|
||||
from .configuration_super import SuperConfig
|
||||
|
||||
|
||||
@use_kernel_forward_from_hub("RMSNorm")
|
||||
class SuperRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps: float = 1e-6) -> None:
|
||||
"""
|
||||
SuperRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
class SuperRotaryEmbedding(nn.Module):
|
||||
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
||||
|
||||
def __init__(self, config: SuperConfig, device=None):
|
||||
super().__init__()
|
||||
self.max_seq_len_cached = config.max_position_embeddings
|
||||
self.original_max_seq_len = config.max_position_embeddings
|
||||
|
||||
self.config = config
|
||||
|
||||
self.rope_type = self.config.rope_parameters["rope_type"]
|
||||
rope_init_fn: Callable = self.compute_default_rope_parameters
|
||||
if self.rope_type != "default":
|
||||
rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
||||
inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
|
||||
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
|
||||
|
||||
@staticmethod
|
||||
def compute_default_rope_parameters(
|
||||
config: SuperConfig | None = None,
|
||||
device: Optional["torch.device"] = None,
|
||||
seq_len: int | None = None,
|
||||
) -> tuple["torch.Tensor", float]:
|
||||
"""
|
||||
Computes the inverse frequencies according to the original RoPE implementation
|
||||
Args:
|
||||
config ([`~transformers.PreTrainedConfig`]):
|
||||
The model configuration.
|
||||
device (`torch.device`):
|
||||
The device to use for initialization of the inverse frequencies.
|
||||
seq_len (`int`, *optional*):
|
||||
The current sequence length. Unused for this type of RoPE.
|
||||
Returns:
|
||||
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
|
||||
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
|
||||
"""
|
||||
base = config.rope_parameters["rope_theta"]
|
||||
dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
|
||||
|
||||
attention_factor = 1.0 # Unused in this type of RoPE
|
||||
|
||||
# Compute the inverse frequencies
|
||||
inv_freq = 1.0 / (
|
||||
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
|
||||
)
|
||||
return inv_freq, attention_factor
|
||||
|
||||
@torch.no_grad()
|
||||
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
||||
def forward(self, x, position_ids):
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
|
||||
position_ids_expanded = position_ids[:, None, :].float()
|
||||
|
||||
device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
|
||||
with maybe_autocast(device_type=device_type, enabled=False): # Force float32
|
||||
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = emb.cos() * self.attention_scaling
|
||||
sin = emb.sin() * self.attention_scaling
|
||||
|
||||
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||
|
||||
|
||||
class SuperMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
@use_kernel_func_from_hub("rotary_pos_emb")
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
@use_kernelized_func(apply_rotary_pos_emb)
|
||||
class SuperAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: SuperConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.is_causal = True
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_values is not None:
|
||||
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
|
||||
|
||||
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
||||
self.config._attn_implementation, eager_attention_forward
|
||||
)
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class SuperDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: SuperConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = SuperAttention(config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = SuperMLP(config)
|
||||
self.input_layernorm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
use_cache: bool | None = False,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
# Self Attention
|
||||
hidden_states, _ = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class SuperPreTrainedModel(PreTrainedModel):
|
||||
config: SuperConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["SuperDecoderLayer"]
|
||||
_skip_keys_device_placement = ["past_key_values"]
|
||||
_supports_flash_attn = True
|
||||
_supports_sdpa = True
|
||||
_supports_flex_attn = True
|
||||
|
||||
_can_compile_fullgraph = True
|
||||
_supports_attention_backend = True
|
||||
_can_record_outputs = {
|
||||
"hidden_states": SuperDecoderLayer,
|
||||
"attentions": SuperAttention,
|
||||
}
|
||||
|
||||
|
||||
@auto_docstring
|
||||
class SuperModel(SuperPreTrainedModel):
|
||||
def __init__(self, config: SuperConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList(
|
||||
[SuperDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
||||
)
|
||||
self.norm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = SuperRotaryEmbedding(config=config)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
@merge_with_config_defaults
|
||||
@capture_outputs
|
||||
@auto_docstring
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
return_dict: bool | None = None,
|
||||
) -> tuple | CausalLMOutputWithPast:
|
||||
out = super().forward(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
inputs_embeds,
|
||||
use_cache,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
)
|
||||
out.logits *= 2**4
|
||||
return out
|
||||
157
examples/modular-transformers/modeling_switch_function.py
Normal file
157
examples/modular-transformers/modeling_switch_function.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from examples/modular-transformers/modular_switch_function.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_switch_function.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# Note that llama and cohere have different definitions for rotate_half
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...cache_utils import Cache
|
||||
from ...integrations import use_kernel_func_from_hub, use_kernelized_func
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TransformersKwargs
|
||||
from .configuration_switch_function import SwitchFunctionConfig
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
# Split and rotate. Note that this function is different from e.g. Llama.
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
|
||||
return rot_x
|
||||
|
||||
|
||||
@use_kernel_func_from_hub("rotary_pos_emb")
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
@use_kernelized_func(apply_rotary_pos_emb)
|
||||
class SwitchFunctionAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: SwitchFunctionConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.is_causal = True
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_values is not None:
|
||||
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
|
||||
|
||||
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
||||
self.config._attn_implementation, eager_attention_forward
|
||||
)
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
1420
examples/modular-transformers/modeling_test_detr.py
Normal file
1420
examples/modular-transformers/modeling_test_detr.py
Normal file
File diff suppressed because it is too large
Load Diff
241
examples/modular-transformers/modeling_test_suffix.py
Normal file
241
examples/modular-transformers/modeling_test_suffix.py
Normal file
@@ -0,0 +1,241 @@
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
# This file was automatically generated from examples/modular-transformers/modular_test_suffix.py.
|
||||
# Do NOT edit this file manually as any edits will be overwritten by the generation of
|
||||
# the file from the modular. If any change should be done, please apply the change to the
|
||||
# modular_test_suffix.py file directly. One of our CI enforces this.
|
||||
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...activations import ACT2FN
|
||||
from ...cache_utils import Cache
|
||||
from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TransformersKwargs
|
||||
from .configuration_test_suffix import TestSuffixLlamaConfig
|
||||
|
||||
|
||||
class TestSuffixDecoderLayer(nn.module):
|
||||
pass
|
||||
|
||||
|
||||
@use_kernel_forward_from_hub("RMSNorm")
|
||||
class TestSuffixLlamaRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps: float = 1e-6) -> None:
|
||||
"""
|
||||
TestSuffixLlamaRMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
class TestSuffixLlamaMLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
||||
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
|
||||
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
@use_kernel_func_from_hub("rotary_pos_emb")
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
|
||||
"""Applies Rotary Position Embedding to the query and key tensors.
|
||||
|
||||
Args:
|
||||
q (`torch.Tensor`): The query tensor.
|
||||
k (`torch.Tensor`): The key tensor.
|
||||
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
||||
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
||||
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
||||
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
||||
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
||||
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
||||
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
||||
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
||||
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
||||
Returns:
|
||||
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
||||
"""
|
||||
cos = cos.unsqueeze(unsqueeze_dim)
|
||||
sin = sin.unsqueeze(unsqueeze_dim)
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def eager_attention_forward(
|
||||
module: nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None,
|
||||
scaling: float,
|
||||
dropout: float = 0.0,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
):
|
||||
key_states = repeat_kv(key, module.num_key_value_groups)
|
||||
value_states = repeat_kv(value, module.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
|
||||
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
@use_kernelized_func(apply_rotary_pos_emb)
|
||||
class TestSuffixLlamaAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: TestSuffixLlamaConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.is_causal = True
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
if past_key_values is not None:
|
||||
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
|
||||
|
||||
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
|
||||
self.config._attn_implementation, eager_attention_forward
|
||||
)
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output, attn_weights
|
||||
|
||||
|
||||
class TestSuffixLlamaDecoderLayer(GradientCheckpointingLayer):
|
||||
def __init__(self, config: TestSuffixLlamaConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = TestSuffixLlamaAttention(config=config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = TestSuffixLlamaMLP(config)
|
||||
self.input_layernorm = TestSuffixLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = TestSuffixLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
use_cache: bool | None = False,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
# Self Attention
|
||||
hidden_states, _ = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
15
examples/modular-transformers/modular_add_function.py
Normal file
15
examples/modular-transformers/modular_add_function.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Note that zamba does not have the `apply_rotary_pos_emb` function!
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
from transformers.models.zamba.modeling_zamba import ZambaAttention
|
||||
|
||||
|
||||
# When following ZambaAttention dependencies, the function `apply_rotary_pos_emb` is not present
|
||||
# by default as it is absent from the class definition (and the file altogether).
|
||||
# Note that this syntax should be able to add both `apply_rotary_pos_emb` as imported directly, but
|
||||
# `rotate_half` as well as a dependency from the imported function!!
|
||||
class TestAttention(ZambaAttention):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def forward(self):
|
||||
_ = apply_rotary_pos_emb(1, 1, 1, 1)
|
||||
27
examples/modular-transformers/modular_dummy_bert.py
Normal file
27
examples/modular-transformers/modular_dummy_bert.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import torch
|
||||
|
||||
from transformers.models.bert.modeling_bert import BertModel
|
||||
|
||||
from ...modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
|
||||
from ...processing_utils import Unpack
|
||||
from ...utils import TransformersKwargs
|
||||
|
||||
|
||||
class DummyBertModel(BertModel):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
token_type_ids: torch.Tensor | None = None,
|
||||
position_ids: torch.Tensor | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
encoder_hidden_states: torch.Tensor | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
past_key_values: list[torch.FloatTensor] | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
return_dict: bool | None = None,
|
||||
**kwargs: Unpack[TransformersKwargs],
|
||||
) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
|
||||
return super().forward(input_ids, **kwargs)
|
||||
11
examples/modular-transformers/modular_duplicated_method.py
Normal file
11
examples/modular-transformers/modular_duplicated_method.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
|
||||
|
||||
class DuplicatedMethodConfig(LlamaConfig):
|
||||
@property
|
||||
def vocab_size(self): # noqa: F811 -> we need this at we cannot delete the original for now since config dataclass refactor
|
||||
return 45
|
||||
|
||||
@vocab_size.setter
|
||||
def vocab_size(self, value):
|
||||
self.vocab_size = value
|
||||
@@ -0,0 +1,6 @@
|
||||
from transformers.models.clip.modeling_clip import CLIPEncoderLayer
|
||||
|
||||
|
||||
# Check if we can correctly grab dependencies with correct naming from all UPPERCASE old model
|
||||
class FromUppercaseModelEncoderLayer(CLIPEncoderLayer):
|
||||
pass
|
||||
16
examples/modular-transformers/modular_global_indexing.py
Normal file
16
examples/modular-transformers/modular_global_indexing.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from transformers.modeling_utils import AttentionInterface
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||
|
||||
|
||||
def custom_flex(x, **kwargs):
|
||||
"""Dummy function."""
|
||||
return x
|
||||
|
||||
|
||||
ALL_ATTENTION_FUNCTIONS = AttentionInterface()
|
||||
# This indexing statement and associated function should be exported correctly!
|
||||
ALL_ATTENTION_FUNCTIONS["flex_attention"] = custom_flex
|
||||
|
||||
|
||||
class GlobalIndexingAttention(LlamaAttention):
|
||||
pass
|
||||
62
examples/modular-transformers/modular_multimodal2.py
Normal file
62
examples/modular-transformers/modular_multimodal2.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Here, because clip is not consistent with the use of the "Text" and "Vision" prefixes, we cannot simply use
|
||||
```
|
||||
class Multimodal2VisionModel(CLIPVisionModel):
|
||||
pass
|
||||
```
|
||||
with the hope that all dependencies will be renamed as `Multimodal2VisionClass`. For this reason, if we want consistency and
|
||||
use the "Vision" part everywhere, we need to overwrite the intermediate classes and add the prefix everytime.
|
||||
This adds noise to the modular, but is unfortunately unavoidable.
|
||||
"""
|
||||
|
||||
from torch import nn
|
||||
|
||||
from transformers.models.clip.modeling_clip import (
|
||||
CLIPMLP,
|
||||
CLIPAttention,
|
||||
CLIPEncoder,
|
||||
CLIPEncoderLayer,
|
||||
CLIPPreTrainedModel,
|
||||
CLIPVisionModel,
|
||||
)
|
||||
|
||||
|
||||
class Multimodal2VisionAttention(CLIPAttention):
|
||||
pass
|
||||
|
||||
|
||||
class Multimodal2VisionMLP(CLIPMLP):
|
||||
pass
|
||||
|
||||
|
||||
class Multimodal2VisionEncoderLayer(CLIPEncoderLayer):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.mlp = Multimodal2VisionMLP(config)
|
||||
self.self_attn = Multimodal2VisionAttention(config)
|
||||
|
||||
|
||||
class Multimodal2VisionEncoder(CLIPEncoder):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.layers = nn.ModuleList([Multimodal2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
|
||||
class Multimodal2VisionPreTrainedModel(CLIPPreTrainedModel):
|
||||
_can_record_outputs = {
|
||||
"hidden_states": Multimodal2VisionEncoderLayer,
|
||||
"attentions": Multimodal2VisionAttention,
|
||||
}
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, Multimodal2VisionMLP):
|
||||
pass
|
||||
|
||||
|
||||
# `CLIPVisionModel` inherits from `CLIPPreTrainedModel`. We need to add the 2nd base here to add the `Vision` part
|
||||
class Multimodal2VisionModel(CLIPVisionModel, Multimodal2VisionPreTrainedModel):
|
||||
_no_split_modules = ["Multimodal2VisionEncoderLayer"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.encoder = Multimodal2VisionEncoder(config)
|
||||
124
examples/modular-transformers/modular_my_new_model.py
Normal file
124
examples/modular-transformers/modular_my_new_model.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
|
||||
|
||||
# Example where we only want to only add a new config argument and new arg doc
|
||||
class MyNewModelConfig(LlamaConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`MyNewModelModel`]. It is used to instantiate an MyNewModel
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the MyNewModel-7B.
|
||||
e.g. [meta-my_new_model/MyNewModel-2-7b-hf](https://huggingface.co/meta-my_new_model/MyNewModel-2-7b-hf)
|
||||
|
||||
Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PreTrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 32000):
|
||||
Vocabulary size of the MyNewModel model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`MyNewModelModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details, check out [this
|
||||
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with. MyNewModel 1 supports up to 2048 tokens,
|
||||
MyNewModel 2 up to 4096, CodeLlama up to 16384.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*):
|
||||
Padding token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
Beginning of stream token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
End of stream token id.
|
||||
pretraining_tp (`int`, *optional*, defaults to 1):
|
||||
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
|
||||
document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
|
||||
understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
|
||||
results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'my_new_model3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'my_new_model3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'my_new_model3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'my_new_model3'. Scaling factor applied to high frequency components of the RoPE
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
mlp_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
|
||||
head_dim (`int`, *optional*):
|
||||
The attention head dimension. If None, it will default to hidden_size // num_attention_heads
|
||||
|
||||
```python
|
||||
>>> from transformers import MyNewModelModel, MyNewModelConfig
|
||||
|
||||
>>> # Initializing a MyNewModel my_new_model-7b style configuration
|
||||
>>> configuration = MyNewModelConfig()
|
||||
|
||||
>>> # Initializing a model from the my_new_model-7b style configuration
|
||||
>>> model = MyNewModelModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```
|
||||
"""
|
||||
|
||||
mlp_bias: bool = True
|
||||
new_param: int = 0
|
||||
31
examples/modular-transformers/modular_my_new_model2.py
Normal file
31
examples/modular-transformers/modular_my_new_model2.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from transformers.models.gemma.modeling_gemma import GemmaForSequenceClassification
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
|
||||
|
||||
# Example where we only want to only modify the docstring
|
||||
class MyNewModel2Config(LlamaConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the Gemma-7B.
|
||||
e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
|
||||
Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PreTrainedConfig`] for more information.
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 256000):
|
||||
Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`GemmaModel`]
|
||||
```python
|
||||
>>> from transformers import GemmaModel, GemmaConfig
|
||||
>>> # Initializing a Gemma gemma-7b style configuration
|
||||
>>> configuration = GemmaConfig()
|
||||
>>> # Initializing a model from the gemma-7b style configuration
|
||||
>>> model = GemmaModel(configuration)
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
|
||||
# Example where alllllll the dependencies are fetched to just copy the entire class
|
||||
class MyNewModel2ForSequenceClassification(GemmaForSequenceClassification):
|
||||
pass
|
||||
@@ -0,0 +1,9 @@
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from transformers.models.blip.image_processing_blip import BlipImageProcessor
|
||||
|
||||
|
||||
class ImgprocModelImageProcessor(BlipImageProcessor):
|
||||
def new_image_processing_method(self, pixel_values: torch.FloatTensor):
|
||||
return pixel_values / 2
|
||||
31
examples/modular-transformers/modular_new_model.py
Normal file
31
examples/modular-transformers/modular_new_model.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# Example where we only want to overwrite the defaults of an init
|
||||
|
||||
from transformers.models.gemma.configuration_gemma import GemmaConfig
|
||||
|
||||
|
||||
class NewModelConfig(GemmaConfig):
|
||||
vocab_size: int = 256030
|
||||
hidden_size: int = 64
|
||||
intermediate_size: int = 90
|
||||
num_hidden_layers: int = 28
|
||||
num_attention_heads: int = 16
|
||||
num_key_value_heads: int = 16
|
||||
head_dim: int = 256
|
||||
hidden_act: str = "gelu_pytorch_tanh"
|
||||
hidden_activation: str | None = None
|
||||
max_position_embeddings: int = 1500
|
||||
initializer_range: float = 0.02
|
||||
rms_norm_eps: float = 1e-6
|
||||
use_cache: bool = True
|
||||
pad_token_id: int = 0
|
||||
eos_token_id: int = 1
|
||||
bos_token_id: int = 2
|
||||
tie_word_embeddings: bool = True
|
||||
rope_parameters: dict | None = None
|
||||
attention_bias: bool = False
|
||||
attention_dropout: float = 0.0
|
||||
use_bidirectional_attention: bool = False
|
||||
|
||||
@property
|
||||
def num_heads(self):
|
||||
return self.num_attention_heads
|
||||
78
examples/modular-transformers/modular_new_task_model.py
Normal file
78
examples/modular-transformers/modular_new_task_model.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
|
||||
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
||||
|
||||
from ...cache_utils import Cache
|
||||
|
||||
|
||||
class NewTaskModelForNewTask(PaliGemmaForConditionalGeneration):
|
||||
main_input_name: ClassVar[str] = "doc_input_ids" # transformers-related
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config=config)
|
||||
|
||||
self.embedding_dim = self.config.embedding_dim
|
||||
self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)
|
||||
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
token_type_ids: torch.LongTensor | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
labels: torch.LongTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
return_dict: bool | None = None,
|
||||
num_logits_to_keep: int = 0,
|
||||
):
|
||||
r"""
|
||||
Returns:
|
||||
"""
|
||||
vlm_outputs = super().forward(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
token_type_ids=token_type_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
labels=labels,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
num_logits_to_keep=num_logits_to_keep,
|
||||
)
|
||||
last_hidden_states = vlm_outputs.hidden_states[-1] # (batch_size, sequence_length, hidden_size)
|
||||
proj = self.custom_text_proj(last_hidden_states) # (batch_size, sequence_length, dim)
|
||||
|
||||
# L2 normalization
|
||||
embeddings = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
|
||||
|
||||
if attention_mask is not None:
|
||||
embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
|
||||
|
||||
return (embeddings,) + vlm_outputs
|
||||
|
||||
def resize_token_embeddings(
|
||||
self, new_num_tokens: int | None = None, pad_to_multiple_of=None, mean_resizing=True
|
||||
) -> nn.Embedding:
|
||||
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
|
||||
|
||||
# Update vocab size
|
||||
self.config.text_config.vocab_size = model_embeds.num_embeddings
|
||||
self.config.vocab_size = model_embeds.num_embeddings
|
||||
self.vocab_size = model_embeds.num_embeddings
|
||||
|
||||
return model_embeds
|
||||
17
examples/modular-transformers/modular_roberta.py
Normal file
17
examples/modular-transformers/modular_roberta.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from transformers.models.bert.modeling_bert import BertEmbeddings, BertModel
|
||||
|
||||
|
||||
class RobertaEmbeddings(BertEmbeddings):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.pad_token_id = config.pad_token_id
|
||||
self.position_embeddings = nn.Embedding(
|
||||
config.max_position_embeddings, config.hidden_size, config.pad_token_id
|
||||
)
|
||||
|
||||
|
||||
class RobertaModel(BertModel):
|
||||
def __init__(self, config, add_pooling_layer=True):
|
||||
super().__init__(self, config)
|
||||
35
examples/modular-transformers/modular_super.py
Normal file
35
examples/modular-transformers/modular_super.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import torch
|
||||
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
from transformers.models.llama.modeling_llama import LlamaModel
|
||||
|
||||
from ...cache_utils import Cache
|
||||
|
||||
|
||||
# example where we need some deps and some functions
|
||||
class SuperModel(LlamaModel):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values: Cache | None = None,
|
||||
inputs_embeds: torch.FloatTensor | None = None,
|
||||
use_cache: bool | None = None,
|
||||
output_attentions: bool | None = None,
|
||||
output_hidden_states: bool | None = None,
|
||||
return_dict: bool | None = None,
|
||||
) -> tuple | CausalLMOutputWithPast:
|
||||
out = super().forward(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
inputs_embeds,
|
||||
use_cache,
|
||||
output_attentions,
|
||||
output_hidden_states,
|
||||
return_dict,
|
||||
)
|
||||
out.logits *= 2**4
|
||||
return out
|
||||
10
examples/modular-transformers/modular_switch_function.py
Normal file
10
examples/modular-transformers/modular_switch_function.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# Note that llama and cohere have different definitions for rotate_half
|
||||
from transformers.models.cohere.modeling_cohere import rotate_half # noqa
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||
|
||||
|
||||
# When following LlamaAttention dependencies, we will grab the function `rotate_half` defined
|
||||
# in `modeling_llama.py`. But here we imported it explicitly from Cohere, so it should use Cohere's
|
||||
# definition instead
|
||||
class SwitchFunctionAttention(LlamaAttention):
|
||||
pass
|
||||
7
examples/modular-transformers/modular_test_detr.py
Normal file
7
examples/modular-transformers/modular_test_detr.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from transformers.models.deformable_detr.modeling_deformable_detr import DeformableDetrModel
|
||||
|
||||
|
||||
# Here, the old and new model have by essence a common "detr" suffix. Make sure everything is correctly named
|
||||
# in this case (i.e., we do not wrongly detect `Detr` as part of a suffix to remove)
|
||||
class TestDetrModel(DeformableDetrModel):
|
||||
pass
|
||||
12
examples/modular-transformers/modular_test_suffix.py
Normal file
12
examples/modular-transformers/modular_test_suffix.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
|
||||
|
||||
|
||||
class TestSuffixDecoderLayer(nn.module):
|
||||
pass
|
||||
|
||||
|
||||
# Here, we want to add "Llama" as a suffix to the base `TestModel` name for all required dependencies
|
||||
class TestSuffixLlamaDecoderLayer(LlamaDecoderLayer):
|
||||
pass
|
||||
Reference in New Issue
Block a user