Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
1249 lines
50 KiB
Python
1249 lines
50 KiB
Python
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
Tests for data collators.
|
|
|
|
Tests are organized by collator type, with each test class containing:
|
|
- Functionality tests (PyTorch and NumPy variants)
|
|
- Immutability tests (verifying inputs are not mutated)
|
|
"""
|
|
|
|
import copy
|
|
import os
|
|
import shutil
|
|
import tempfile
|
|
import unittest
|
|
|
|
import numpy as np
|
|
|
|
from transformers import (
|
|
BertTokenizer,
|
|
BertTokenizerFast,
|
|
DataCollatorForLanguageModeling,
|
|
DataCollatorForPermutationLanguageModeling,
|
|
DataCollatorForSeq2Seq,
|
|
DataCollatorForTokenClassification,
|
|
DataCollatorForWholeWordMask,
|
|
DataCollatorWithFlattening,
|
|
DataCollatorWithPadding,
|
|
default_data_collator,
|
|
is_torch_available,
|
|
set_seed,
|
|
)
|
|
from transformers.testing_utils import require_torch
|
|
from transformers.utils import PaddingStrategy
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
|
|
|
|
class DataCollatorTestMixin:
    """Mixin providing common setup and utility methods for data collator tests.

    ``setUp`` writes a minimal five-token ``vocab.txt`` into a fresh temporary
    directory (exposed as ``self.tmpdirname`` / ``self.vocab_file``) and
    ``tearDown`` removes that directory again.
    """

    def setUp(self):
        """Create a temporary directory containing the vocabulary file."""
        self.tmpdirname = tempfile.mkdtemp()
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]
        self.vocab_file = os.path.join(self.tmpdirname, "vocab.txt")
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

    def tearDown(self):
        """Remove the temporary directory created by ``setUp``."""
        shutil.rmtree(self.tmpdirname)

    def _check_immutability(self, collator, features):
        """Verify that collator does not mutate input data.

        A deep copy of ``features`` is taken, ``collator(features)`` is called
        (its return value is ignored), and ``features`` is then compared with
        the copy: same number of items, same keys for mapping items, and equal
        values for every key.

        Args:
            collator: Callable invoked as ``collator(features)``.
            features: List of feature items (dicts in the visible callers).
        """
        original = copy.deepcopy(features)
        collator(features)

        # `zip` silently truncates, so an item added to / removed from the list
        # would otherwise go unnoticed.
        self.assertEqual(len(original), len(features))

        for orig, feat in zip(original, features):
            # Iterating the original keys alone cannot detect keys the collator
            # added to the input, so compare the key sets explicitly.
            if hasattr(orig, "keys"):
                self.assertEqual(set(orig), set(feat))

            for key in orig:
                orig_val, feat_val = orig[key], feat[key]
                # Arrays/tensors are compared via `tolist()` because their `==`
                # is element-wise and has no single truth value.
                if isinstance(orig_val, np.ndarray):
                    self.assertEqual(orig_val.tolist(), feat_val.tolist())
                elif is_torch_available() and isinstance(orig_val, torch.Tensor):
                    self.assertEqual(orig_val.tolist(), feat_val.tolist())
                else:
                    self.assertEqual(orig_val, feat_val)
|
|
|
|
|
|
# =============================================================================
|
|
# default_data_collator tests
|
|
# =============================================================================
|
|
|
|
|
|
@require_torch
class TestDefaultDataCollator(DataCollatorTestMixin, unittest.TestCase):
    """
    Tests for default_data_collator.

    The default collator handles basic batching of dict features, converting
    lists and arrays to tensors and properly handling labels.
    """

    def test_basic_collation(self):
        """Test basic dict collation with lists."""
        features = [{"label": i, "inputs": [0, 1, 2, 3, 4, 5]} for i in range(4)]
        batch = default_data_collator(features)

        # The per-sample "label" key is exposed as "labels" in the batch.
        self.assertEqual(batch["labels"].tolist(), list(range(4)))
        self.assertEqual(batch["labels"].dtype, torch.long)
        self.assertEqual(batch["inputs"].shape, torch.Size([4, 6]))

    def test_multi_label(self):
        """Test collation with multiple labels per sample."""
        features = [{"label_ids": [0, 1, 2], "inputs": [0, 1, 2, 3]} for _ in range(4)]
        batch = default_data_collator(features)

        # "label_ids" is also exposed as "labels" in the batch.
        self.assertEqual(batch["labels"].tolist(), [[0, 1, 2]] * 4)
        self.assertEqual(batch["labels"].dtype, torch.long)

    def test_numpy_array_inputs(self):
        """Test collation when features are numpy arrays."""
        # Only the shape is asserted, so the random contents do not matter.
        features = [{"label": i, "inputs": np.random.randint(0, 10, [10])} for i in range(4)]
        batch = default_data_collator(features)

        self.assertEqual(batch["labels"].tolist(), list(range(4)))
        self.assertEqual(batch["inputs"].shape, torch.Size([4, 10]))

    def test_tensor_labels(self):
        """Test collation when labels are already tensors."""
        features = [{"label": torch.tensor(i), "inputs": [0, 1, 2]} for i in range(4)]
        batch = default_data_collator(features)

        self.assertEqual(batch["labels"].dtype, torch.long)
        self.assertEqual(batch["labels"].tolist(), list(range(4)))

    def test_dtype_inference(self):
        """Test that int labels become long, float labels become float."""
        # Classification: int -> long
        features = [{"input_ids": [0, 1, 2], "label": i} for i in range(4)]
        batch = default_data_collator(features)
        self.assertEqual(batch["labels"].dtype, torch.long)

        # Regression: float -> float
        features = [{"input_ids": [0, 1, 2], "label": float(i)} for i in range(4)]
        batch = default_data_collator(features)
        self.assertEqual(batch["labels"].dtype, torch.float)

    def test_none_labels_excluded(self):
        """Test that None labels are excluded from batch."""
        features = [{"label": None, "inputs": [0, 1, 2, 3]} for _ in range(4)]
        batch = default_data_collator(features)
        self.assertNotIn("labels", batch)

        # With label_ids
        features = [{"label_ids": None, "inputs": [0, 1, 2, 3]} for _ in range(4)]
        batch = default_data_collator(features)
        self.assertNotIn("labels", batch)

    def test_numpy_output(self):
        """Test collation with NumPy output."""
        features = [{"label": i, "inputs": [0, 1, 2, 3]} for i in range(4)]
        batch = default_data_collator(features, return_tensors="np")

        self.assertEqual(batch["labels"].tolist(), list(range(4)))
        self.assertEqual(batch["labels"].dtype, np.int64)
        self.assertEqual(batch["inputs"].shape, (4, 4))

    def test_numpy_dtype_inference(self):
        """Test dtype inference with NumPy output."""
        # Int labels
        features = [{"input_ids": [0, 1, 2], "label": i} for i in range(4)]
        batch = default_data_collator(features, return_tensors="np")
        self.assertEqual(batch["labels"].dtype, np.int64)

        # Float labels
        features = [{"input_ids": [0, 1, 2], "label": float(i)} for i in range(4)]
        batch = default_data_collator(features, return_tensors="np")
        self.assertEqual(batch["labels"].dtype, np.float32)

    def test_immutability(self):
        """Test that collation does not mutate input data."""

        def collator_pt(x):
            return default_data_collator(x, return_tensors="pt")

        def collator_np(x):
            return default_data_collator(x, return_tensors="np")

        for collator in [collator_pt, collator_np]:
            # Test with various input types. The feature lists are rebuilt for
            # every collator so each check starts from fresh inputs.
            for features in [
                [{"label": i, "inputs": [0, 1, 2, 3]} for i in range(4)],
                [{"label": float(i), "inputs": [0, 1, 2, 3]} for i in range(4)],
                [{"label": None, "inputs": [0, 1, 2, 3]} for _ in range(4)],
                [{"label_ids": [0, 1, 2], "inputs": [0, 1, 2, 3]} for _ in range(4)],
            ]:
                # subTest reports which collator/input combination failed and
                # lets the remaining combinations still run.
                with self.subTest(collator=collator.__name__, first_feature=features[0]):
                    self._check_immutability(collator, features)
|
|
|
|
|
|
# =============================================================================
|
|
# DataCollatorWithPadding tests
|
|
# =============================================================================
|
|
|
|
|
|
@require_torch
class TestDataCollatorWithPadding(DataCollatorTestMixin, unittest.TestCase):
    """
    Tests for DataCollatorWithPadding.

    This collator pads sequences to the same length within a batch, supporting
    dynamic padding, max_length, and pad_to_multiple_of options.
    """

    def test_dynamic_padding(self):
        """Test padding to longest sequence in batch."""
        tokenizer = BertTokenizer(self.vocab_file)
        features = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}]

        collator = DataCollatorWithPadding(tokenizer)
        batch = collator(features)

        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
        self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)

    def test_max_length_padding(self):
        """Test padding to specified max_length."""
        tokenizer = BertTokenizer(self.vocab_file)
        features = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}]

        collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=10)
        batch = collator(features)

        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 10]))

    def test_pad_to_multiple_of(self):
        """Test padding to multiple of specified value."""
        tokenizer = BertTokenizer(self.vocab_file)
        features = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}]

        collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
        batch = collator(features)

        # Longest sequence has 6 tokens; the next multiple of 8 is 8.
        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))

    def test_numpy_output(self):
        """Test padding with NumPy output."""
        tokenizer = BertTokenizer(self.vocab_file)
        features = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}]

        collator = DataCollatorWithPadding(tokenizer, return_tensors="np")
        batch = collator(features)

        self.assertEqual(batch["input_ids"].shape, (2, 6))
        self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)

    def test_attention_mask_generated(self):
        """Test that attention_mask is properly generated."""
        tokenizer = BertTokenizer(self.vocab_file)
        features = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}]

        collator = DataCollatorWithPadding(tokenizer)
        batch = collator(features)

        self.assertIn("attention_mask", batch)
        # 1 marks real tokens, 0 marks the padded positions.
        self.assertEqual(batch["attention_mask"][0].tolist(), [1, 1, 1, 0, 0, 0])
        self.assertEqual(batch["attention_mask"][1].tolist(), [1, 1, 1, 1, 1, 1])

    def test_immutability(self):
        """Test that padding does not mutate input data."""
        tokenizer = BertTokenizer(self.vocab_file)

        for return_tensors in ["pt", "np"]:
            # subTest identifies the failing output format in the report.
            with self.subTest(return_tensors=return_tensors):
                collator = DataCollatorWithPadding(tokenizer, return_tensors=return_tensors)
                features = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}]
                self._check_immutability(collator, features)

    def test_4d_attention_mask_preserved(self):
        """A 4D per-sample `attention_mask` is stacked into the batch unchanged.

        `tokenizer.pad`'s 1D-per-sample padding logic does not apply to higher-rank
        masks, so the collator must preserve their shape, dtype, and per-sample
        contents.
        """
        tokenizer = BertTokenizer(self.vocab_file)
        seq_len = 6
        batch_size = 3

        # Each factory builds a (1, seq_len, seq_len) mask filled with the
        # sample index, so samples stay distinguishable after batching.
        for sample_kind, sample_factory, expected_dtype in (
            ("torch", lambda i: torch.full((1, seq_len, seq_len), float(i)), torch.float32),
            ("numpy", lambda i: np.full((1, seq_len, seq_len), float(i), dtype=np.float32), torch.float32),
        ):
            with self.subTest(sample_kind=sample_kind):
                features = [
                    {"input_ids": list(range(seq_len)), "attention_mask": sample_factory(i)}
                    for i in range(batch_size)
                ]
                collator = DataCollatorWithPadding(tokenizer)
                batch = collator(features)

                self.assertEqual(batch["attention_mask"].shape, torch.Size([batch_size, 1, seq_len, seq_len]))
                self.assertEqual(batch["attention_mask"].dtype, expected_dtype)
                for i in range(batch_size):
                    self.assertTrue(torch.all(batch["attention_mask"][i] == float(i)))

    def test_4d_attention_mask_asymmetric_kv(self):
        """A 4D mask with `kv_len != q_len` (e.g. KV-cache style) is preserved."""
        tokenizer = BertTokenizer(self.vocab_file)
        q_len, kv_len, batch_size = 4, 9, 2
        features = [
            {
                "input_ids": list(range(q_len)),
                "attention_mask": torch.zeros((1, q_len, kv_len), dtype=torch.float32),
            }
            for _ in range(batch_size)
        ]
        batch = DataCollatorWithPadding(tokenizer)(features)

        self.assertEqual(batch["attention_mask"].shape, torch.Size([batch_size, 1, q_len, kv_len]))
|
|
|
|
|
|
# =============================================================================
|
|
# DataCollatorWithFlattening tests
|
|
# =============================================================================
|
|
|
|
|
|
@require_torch
class TestDataCollatorWithFlattening(DataCollatorTestMixin, unittest.TestCase):
    """
    Tests for DataCollatorWithFlattening.

    This collator concatenates multiple sequences into a single sequence,
    useful for packing multiple examples efficiently. It generates position_ids
    that reset for each original sequence.
    """

    def _get_features(self):
        """Return three fresh samples of lengths 3, 6 and 7 (16 tokens in total)."""
        return [
            {"input_ids": [10, 11, 12]},
            {"input_ids": [20, 21, 22, 23, 24, 25]},
            {"input_ids": [30, 31, 32, 33, 34, 35, 36]},
        ]

    def test_basic_flattening(self):
        """Test that sequences are concatenated with per-sequence position_ids."""
        collator = DataCollatorWithFlattening(return_tensors="pt")
        batch = collator(self._get_features())

        self.assertEqual(batch["input_ids"].shape, torch.Size([1, 16]))
        self.assertEqual(
            batch["input_ids"][0].tolist(),
            [10, 11, 12, 20, 21, 22, 23, 24, 25, 30, 31, 32, 33, 34, 35, 36],
        )
        self.assertEqual(batch["position_ids"][0].tolist(), [0, 1, 2, 0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6])

        # Should not include attention_mask or flash attn kwargs by default
        for key in ["attention_mask", "cu_seq_lens_k", "cu_seq_lens_q", "seq_idx"]:
            self.assertNotIn(key, batch)

    def test_flash_attn_kwargs(self):
        """Test flattening with Flash Attention kwargs."""
        collator = DataCollatorWithFlattening(return_tensors="pt", return_flash_attn_kwargs=True)
        batch = collator(self._get_features())

        # Cumulative boundaries of the 3/6/7-token samples; 7 is the longest.
        self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 3, 9, 16])
        self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 3, 9, 16])
        self.assertEqual(batch["max_length_k"], 7)
        self.assertEqual(batch["max_length_q"], 7)

    def test_seq_idx(self):
        """Test flattening with seq_idx for sequence identification."""
        collator = DataCollatorWithFlattening(return_tensors="pt", return_seq_idx=True)
        batch = collator(self._get_features())

        # Every token carries the index of the sample it came from.
        self.assertEqual(batch["seq_idx"][0].tolist(), [0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2])

    def test_with_labels(self):
        """Test flattening with tensor and list labels."""
        # Tensor labels
        features = [
            {"input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([10, 11, 12])},
            {"input_ids": torch.tensor([4, 5]), "labels": torch.tensor([13, 14])},
        ]
        collator = DataCollatorWithFlattening(return_tensors="pt")
        batch = collator(features)

        self.assertEqual(batch["input_ids"].shape, (1, 5))
        self.assertEqual(batch["labels"].shape, (1, 5))

        # List labels
        features = [
            {"input_ids": [1, 2, 3], "labels": [10, 11, 12]},
            {"input_ids": [4, 5], "labels": [13, 14]},
        ]
        batch = collator(features)
        self.assertEqual(batch["labels"].shape, (1, 5))

    def test_numpy_output(self):
        """Test flattening with NumPy output."""
        collator = DataCollatorWithFlattening(return_tensors="np")
        batch = collator(self._get_features())

        self.assertEqual(batch["input_ids"].shape, (1, 16))
        self.assertEqual(batch["position_ids"].shape, (1, 16))

    def test_numpy_flash_attn_kwargs(self):
        """Test flattening with Flash Attention kwargs and NumPy output."""
        collator = DataCollatorWithFlattening(return_tensors="np", return_flash_attn_kwargs=True)
        batch = collator(self._get_features())

        self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 3, 9, 16])
        self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 3, 9, 16])
        self.assertEqual(batch["max_length_k"], 7)
        self.assertEqual(batch["max_length_q"], 7)

    def test_immutability(self):
        """Test that flattening does not mutate input data."""
        for return_tensors in ["pt", "np"]:
            # subTest identifies the failing output format in the report.
            with self.subTest(return_tensors=return_tensors):
                collator = DataCollatorWithFlattening(return_tensors=return_tensors)
                self._check_immutability(collator, self._get_features())
|
|
|
|
|
|
# =============================================================================
|
|
# DataCollatorForTokenClassification tests
|
|
# =============================================================================
|
|
|
|
|
|
@require_torch
class TestDataCollatorForTokenClassification(DataCollatorTestMixin, unittest.TestCase):
    """
    Tests for DataCollatorForTokenClassification.

    This collator pads both input_ids and labels for token classification tasks,
    using -100 as the default label padding value (ignored by loss functions).
    """

    def _get_features(self):
        """Return two fresh labelled samples of lengths 3 and 6."""
        return [
            {"input_ids": [0, 1, 2], "labels": [0, 1, 2]},
            {"input_ids": [0, 1, 2, 3, 4, 5], "labels": [0, 1, 2, 3, 4, 5]},
        ]

    def test_padding(self):
        """Test that both input_ids and labels are padded correctly."""
        tokenizer = BertTokenizer(self.vocab_file)
        collator = DataCollatorForTokenClassification(tokenizer)
        batch = collator(self._get_features())

        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
        self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
        self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
        self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2, -100, -100, -100])

    def test_max_length_padding(self):
        """Test padding to max_length."""
        tokenizer = BertTokenizer(self.vocab_file)
        collator = DataCollatorForTokenClassification(tokenizer, padding="max_length", max_length=10)
        batch = collator(self._get_features())

        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 10]))
        self.assertEqual(batch["labels"].shape, torch.Size([2, 10]))

    def test_pad_to_multiple_of(self):
        """Test padding to multiple of 8."""
        tokenizer = BertTokenizer(self.vocab_file)
        collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8)
        batch = collator(self._get_features())

        # Longest sequence has 6 tokens; the next multiple of 8 is 8.
        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))
        self.assertEqual(batch["labels"].shape, torch.Size([2, 8]))

    def test_custom_label_pad_token(self):
        """Test custom label padding token."""
        tokenizer = BertTokenizer(self.vocab_file)
        collator = DataCollatorForTokenClassification(tokenizer, label_pad_token_id=-1)
        batch = collator(self._get_features())

        self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2, -1, -1, -1])

    def test_without_labels(self):
        """Test collator works without labels (inference mode)."""
        tokenizer = BertTokenizer(self.vocab_file)
        features = [{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}]

        collator = DataCollatorForTokenClassification(tokenizer)
        batch = collator(features)

        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
        self.assertNotIn("labels", batch)

    def test_with_tensor_inputs(self):
        """Test with PyTorch tensor inputs."""
        tokenizer = BertTokenizer(self.vocab_file)
        features = [
            {"input_ids": torch.tensor([0, 1, 2]), "labels": torch.tensor([0, 1, 2])},
            {"input_ids": torch.tensor([0, 1, 2, 3, 4, 5]), "labels": torch.tensor([0, 1, 2, 3, 4, 5])},
        ]

        collator = DataCollatorForTokenClassification(tokenizer)
        batch = collator(features)

        self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2, -100, -100, -100])

    def test_numpy_output(self):
        """Test with NumPy output."""
        tokenizer = BertTokenizer(self.vocab_file)
        collator = DataCollatorForTokenClassification(tokenizer, return_tensors="np")
        batch = collator(self._get_features())

        self.assertEqual(batch["input_ids"].shape, (2, 6))
        self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2, -100, -100, -100])

    def test_immutability(self):
        """Test that collation does not mutate input data."""
        tokenizer = BertTokenizer(self.vocab_file)

        for return_tensors in ["pt", "np"]:
            # subTest identifies the failing output format in the report.
            with self.subTest(return_tensors=return_tensors):
                collator = DataCollatorForTokenClassification(tokenizer, return_tensors=return_tensors)
                self._check_immutability(collator, self._get_features())
|
|
|
|
|
|
# =============================================================================
|
|
# DataCollatorForSeq2Seq tests
|
|
# =============================================================================
|
|
|
|
|
|
@require_torch
class TestDataCollatorForSeq2Seq(DataCollatorTestMixin, unittest.TestCase):
    """
    Tests for DataCollatorForSeq2Seq.

    The tests cover padding of both ``input_ids`` and ``labels``: labels are
    padded with -100 by default or with a custom ``label_pad_token_id``.
    """

    def _get_features(self):
        """Return two fresh labelled samples of lengths 3 and 6."""
        return [
            {"input_ids": list(range(3)), "labels": list(range(3))},
            {"input_ids": list(range(6)), "labels": list(range(6))},
        ]

    def test_padding(self):
        """Test basic padding behavior."""
        tokenizer = BertTokenizer(self.vocab_file)
        collator = DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.LONGEST)
        batch = collator(self._get_features())

        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
        self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
        self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
        self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2, -100, -100, -100])

    def test_with_tensor_inputs(self):
        """Test with PyTorch tensor inputs."""
        tokenizer = BertTokenizer(self.vocab_file)
        features = [
            {"input_ids": torch.tensor(list(range(3))), "labels": torch.tensor(list(range(3)))},
            {"input_ids": torch.tensor(list(range(6))), "labels": torch.tensor(list(range(6)))},
        ]

        collator = DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.LONGEST)
        batch = collator(features)

        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
        self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))

    def test_max_length_padding(self):
        """Test padding to max_length."""
        tokenizer = BertTokenizer(self.vocab_file)
        collator = DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.MAX_LENGTH, max_length=10)
        batch = collator(self._get_features())

        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 10]))
        self.assertEqual(batch["labels"].shape, torch.Size([2, 10]))

    def test_pad_to_multiple_of(self):
        """Test padding to multiple of 8."""
        tokenizer = BertTokenizer(self.vocab_file)
        collator = DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.LONGEST, pad_to_multiple_of=8)
        batch = collator(self._get_features())

        # Longest sequence has 6 tokens; the next multiple of 8 is 8.
        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))
        self.assertEqual(batch["labels"].shape, torch.Size([2, 8]))

    def test_custom_label_pad_token(self):
        """Test custom label padding token."""
        tokenizer = BertTokenizer(self.vocab_file)
        collator = DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.LONGEST, label_pad_token_id=-1)
        batch = collator(self._get_features())

        self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2, -1, -1, -1])

    def test_do_not_pad(self):
        """Test DO_NOT_PAD raises on unequal lengths, works on equal."""
        tokenizer = BertTokenizer(self.vocab_file)
        collator = DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.DO_NOT_PAD)

        # Unequal lengths should raise
        with self.assertRaises(ValueError):
            collator(self._get_features())

        # Equal lengths should work. Two distinct dicts are built on purpose:
        # `[sample] * 2` would put the same dict object in the batch twice.
        features = [{"input_ids": list(range(3)), "labels": list(range(3))} for _ in range(2)]
        batch = collator(features)
        self.assertEqual(batch["input_ids"][0].tolist(), list(range(3)))
        self.assertEqual(batch["labels"][0].tolist(), list(range(3)))

    def test_without_labels(self):
        """Test collator works without labels."""
        tokenizer = BertTokenizer(self.vocab_file)
        features = [{"input_ids": list(range(3))}, {"input_ids": list(range(6))}]

        collator = DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.LONGEST)
        batch = collator(features)

        self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
        # Labels should either not be present or be None
        self.assertTrue("labels" not in batch or batch["labels"] is None)

    def test_numpy_output(self):
        """Test with NumPy output."""
        tokenizer = BertTokenizer(self.vocab_file)
        collator = DataCollatorForSeq2Seq(tokenizer, padding=PaddingStrategy.LONGEST, return_tensors="np")
        batch = collator(self._get_features())

        self.assertEqual(batch["input_ids"].shape, (2, 6))
        self.assertEqual(batch["labels"].shape, (2, 6))

    def test_immutability(self):
        """Test that collation does not mutate input data."""
        tokenizer = BertTokenizer(self.vocab_file)

        for return_tensors in ["pt", "np"]:
            # subTest identifies the failing output format in the report.
            with self.subTest(return_tensors=return_tensors):
                collator = DataCollatorForSeq2Seq(
                    tokenizer, padding=PaddingStrategy.LONGEST, return_tensors=return_tensors
                )
                self._check_immutability(collator, self._get_features())
|
|
|
|
|
|
# =============================================================================
|
|
# DataCollatorForLanguageModeling tests
|
|
# =============================================================================
|
|
|
|
|
|
@require_torch
|
|
class TestDataCollatorForLanguageModeling(DataCollatorTestMixin, unittest.TestCase):
|
|
"""
|
|
Tests for DataCollatorForLanguageModeling.
|
|
|
|
This collator supports both Masked Language Modeling (MLM) and Causal Language
|
|
Modeling (CLM). For MLM, it randomly masks tokens; for CLM, labels are a copy of the inputs with padding set to -100.
|
|
"""
|
|
|
|
def test_clm_mode(self):
    """Test CLM mode (no masking).

    Two equal-length samples are collated with ``mlm=False``; both
    ``input_ids`` and ``labels`` must have shape ``(2, 10)``.
    """
    tokenizer = BertTokenizer(self.vocab_file)
    features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]

    collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
    batch = collator(features)

    self.assertEqual(batch["input_ids"].shape, torch.Size([2, 10]))
    self.assertEqual(batch["labels"].shape, torch.Size([2, 10]))
|
|
|
|
def test_clm_with_padding(self):
    """Test CLM mode with different length sequences.

    Samples of length 5 and 10 must be batched to the longer length.
    """
    tokenizer = BertTokenizer(self.vocab_file)
    features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]

    collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
    batch = collator(features)

    self.assertEqual(batch["input_ids"].shape, torch.Size([2, 10]))
|
|
|
|
def test_clm_pad_to_multiple_of(self):
    """Test CLM with pad_to_multiple_of."""
    tokenizer = BertTokenizer(self.vocab_file)
    features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]

    collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, pad_to_multiple_of=8)
    batch = collator(features)

    # Sequences have 10 tokens; the next multiple of 8 is 16.
    self.assertEqual(batch["input_ids"].shape, torch.Size([2, 16]))
|
|
|
|
def test_mlm_mode(self):
    """Test MLM mode with masking.

    The global seed is fixed so the random masking is repeatable.
    """
    set_seed(42)
    tokenizer = BertTokenizer(self.vocab_file)
    features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]

    collator = DataCollatorForLanguageModeling(tokenizer, mlm=True)
    batch = collator(features)

    self.assertEqual(batch["input_ids"].shape, torch.Size([2, 10]))
    self.assertEqual(batch["labels"].shape, torch.Size([2, 10]))

    # Check that masking occurred and non-masked tokens have -100 labels
    masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
    self.assertTrue(torch.any(masked_tokens))
    self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
|
|
|
|
def test_mlm_with_padding(self):
    """Test MLM mode with different-length sequences requiring padding."""
    set_seed(42)
    tokenizer = BertTokenizer(self.vocab_file)
    features = [{"input_ids": list(range(5))}, {"input_ids": list(range(10))}]

    collator = DataCollatorForLanguageModeling(tokenizer, mlm=True)
    batch = collator(features)

    self.assertEqual(batch["input_ids"].shape, torch.Size([2, 10]))
    # Masking must have occurred, and every position that does not hold the
    # mask token must carry the -100 label.
    masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
    self.assertTrue(torch.any(masked_tokens))
    self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
|
|
|
|
def test_mlm_pad_to_multiple_of(self):
    """Test MLM mode with pad_to_multiple_of."""
    set_seed(42)
    tokenizer = BertTokenizer(self.vocab_file)
    features = [{"input_ids": list(range(10))}, {"input_ids": list(range(10))}]

    collator = DataCollatorForLanguageModeling(tokenizer, mlm=True, pad_to_multiple_of=8)
    batch = collator(features)

    # Sequences have 10 tokens; the next multiple of 8 is 16.
    self.assertEqual(batch["input_ids"].shape, torch.Size([2, 16]))
    # Masking must have occurred, and every position that does not hold the
    # mask token must carry the -100 label.
    masked_tokens = batch["input_ids"] == tokenizer.mask_token_id
    self.assertTrue(torch.any(masked_tokens))
    self.assertTrue(all(x == -100 for x in batch["labels"][~masked_tokens].tolist()))
|
|
|
|
def test_with_raw_list_features(self):
    """Check that the LM collator accepts plain lists of token ids instead of dicts."""
    bert_tokenizer = BertTokenizer(self.vocab_file)
    clm_collator = DataCollatorForLanguageModeling(bert_tokenizer, mlm=False)

    # First case: both rows already have the same length.
    equal_length = [list(range(10)), list(range(10))]
    # Second case: rows of length 5 and 10, so the batch needs padding.
    mixed_length = [list(range(5)), list(range(10))]

    for raw_lists in (equal_length, mixed_length):
        collated = clm_collator(raw_lists)
        self.assertEqual(collated["input_ids"].shape, torch.Size([2, 10]))
|
|
|
|
def test_mlm_seed_reproducibility(self):
    """Check that the collator's ``seed`` argument makes masking repeatable."""
    bert_tokenizer = BertTokenizer(self.vocab_file)
    examples = [{"input_ids": list(range(1000))} for _ in range(2)]

    def collate_with_seed(seed):
        # Build a brand-new collator for every run so each starts from its seed.
        seeded_collator = DataCollatorForLanguageModeling(bert_tokenizer, seed=seed)
        return seeded_collator(examples)

    first_run = collate_with_seed(42)
    second_run = collate_with_seed(42)

    # Same seed: inputs and labels must match element for element.
    for key in ("input_ids", "labels"):
        self.assertTrue(torch.all(first_run[key] == second_run[key]))

    # A different seed must change at least one input id.
    other_seed_run = collate_with_seed(43)
    self.assertFalse(torch.all(first_run["input_ids"] == other_seed_run["input_ids"]))
|
|
|
|
def test_mlm_multiworker_dataloader(self):
    """Test that the collator seed is honoured with a multi-worker DataLoader.

    Ten features are batched in pairs by a two-worker ``DataLoader`` whose
    ``collate_fn`` is a seeded ``DataCollatorForLanguageModeling``. The test
    checks three things:

    * the stacked ``input_ids`` have shape ``[5, 2, 1000]``;
    * a second loader built with the same collator seed yields identical
      ``input_ids``;
    * a loader built with a different collator seed yields different
      ``input_ids``.

    The last two checks mirror the whole-word-mask multi-worker seed test, so
    this test verifies what its name says rather than only the output shape.
    """
    tokenizer = BertTokenizer(self.vocab_file)
    features = [{"input_ids": list(range(1000))} for _ in range(10)]

    def collect_input_ids(collator_seed):
        # Only the collator seed varies between runs; the loader's own
        # generator is pinned so it cannot explain any difference.
        dataloader = torch.utils.data.DataLoader(
            features,
            batch_size=2,
            num_workers=2,
            generator=torch.Generator().manual_seed(42),
            collate_fn=DataCollatorForLanguageModeling(tokenizer, seed=collator_seed),
        )
        return torch.stack([batch["input_ids"] for batch in dataloader])

    result = collect_input_ids(42)
    self.assertEqual(result.shape, torch.Size([5, 2, 1000]))

    # Same seed -> same masking, even though the work is split across workers.
    repeated = collect_input_ids(42)
    self.assertTrue(torch.all(result == repeated))

    # Different seed -> different results
    different = collect_input_ids(43)
    self.assertFalse(torch.all(result == different))
|
|
|
|
def test_missing_pad_token_error(self):
    """Check that collating uneven sequences without a pad token raises ValueError."""
    bert_tokenizer = BertTokenizer(self.vocab_file)
    # Remove the pad token so there is nothing to pad with.
    bert_tokenizer.pad_token = None
    # Lengths 5 and 10 differ, so the batch cannot be built without padding.
    uneven_examples = [{"input_ids": list(range(length))} for length in (5, 10)]

    clm_collator = DataCollatorForLanguageModeling(bert_tokenizer, mlm=False)
    # Only the collation call itself is expected to raise.
    self.assertRaises(ValueError, clm_collator, uneven_examples)
|
|
|
|
def test_numpy_output(self):
    """Check the shapes produced when the collator is asked for NumPy output."""
    bert_tokenizer = BertTokenizer(self.vocab_file)
    examples = [{"input_ids": list(range(10))} for _ in range(2)]

    np_collator = DataCollatorForLanguageModeling(bert_tokenizer, mlm=False, return_tensors="np")
    collated = np_collator(examples)

    # Both entries must come back as 2 rows of 10 tokens.
    for key in ("input_ids", "labels"):
        self.assertEqual(collated[key].shape, (2, 10))
|
|
|
|
def test_numpy_mlm(self):
    """Check that MLM collation with NumPy output masks at least one token."""
    # Seed before anything random happens so the masking below is repeatable.
    set_seed(42)
    bert_tokenizer = BertTokenizer(self.vocab_file)
    examples = [{"input_ids": list(range(10))} for _ in range(2)]

    np_collator = DataCollatorForLanguageModeling(bert_tokenizer, mlm=True, return_tensors="np")
    collated = np_collator(examples)

    is_mask = collated["input_ids"] == bert_tokenizer.mask_token_id
    self.assertTrue(np.any(is_mask))
|
|
|
|
def test_immutability(self):
    """Check, for both output formats, that collation leaves its input untouched."""
    bert_tokenizer = BertTokenizer(self.vocab_file)
    examples = [{"input_ids": list(range(10))} for _ in range(2)]

    for tensor_format in ("pt", "np"):
        clm_collator = DataCollatorForLanguageModeling(
            bert_tokenizer, mlm=False, return_tensors=tensor_format
        )
        # Each format gets its own deep copy of the examples.
        fresh_examples = copy.deepcopy(examples)
        self._check_immutability(clm_collator, fresh_examples)
|
|
|
|
# -------------------- Unit tests for internal methods --------------------
|
|
|
|
def test_calc_word_ids_and_prob_mask(self):
    """Check the word ids and probability mask derived from offsets and special-token flags.

    In the fixtures below, every position flagged in the special-tokens mask is
    expected to get word id -1 and probability-mask value 1.
    """
    # One row per sequence, six tokens each; every token is a pair of ints
    # (presumably start/end character offsets -- confirm against the collator).
    token_offsets = np.array(
        [
            [(0, 0), (0, 3), (3, 4), (5, 6), (6, 7), (8, 9)],
            [(0, 0), (0, 3), (3, 4), (5, 6), (6, 7), (0, 0)],
            [(0, 0), (0, 3), (3, 4), (0, 0), (6, 7), (0, 0)],
            [(1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7)],
            [(1, 1), (2, 2), (3, 4), (5, 6), (7, 8), (9, 10)],
            [(0, 0), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)],
        ]
    )
    special_flags = np.array(
        [
            [1, 0, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 1],
            [1, 0, 0, 1, 0, 1],
            [0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1],
        ]
    )

    want_word_ids = np.array(
        [
            [-1, 1, 1, 2, 2, 3],
            [-1, 1, 1, 2, 2, -1],
            [-1, 1, 1, -1, 2, -1],
            [1, 1, 1, 1, 1, 1],
            [1, 2, 3, 4, 5, 6],
            [-1, -1, -1, -1, -1, -1],
        ]
    )
    want_prob_mask = np.array(
        [
            [1, 0, 1, 0, 1, 0],
            [1, 0, 1, 0, 1, 1],
            [1, 0, 1, 1, 0, 1],
            [0, 1, 1, 1, 1, 1],
            [0, 0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1, 1],
        ]
    )

    got_word_ids, got_prob_mask = DataCollatorForLanguageModeling._calc_word_ids_and_prob_mask(
        token_offsets, special_flags
    )

    np.testing.assert_array_equal(got_word_ids, want_word_ids)
    np.testing.assert_array_equal(got_prob_mask, want_prob_mask)
|
|
|
|
def test_whole_word_mask_internal(self):
    """Check that a token-level mask is widened to every token sharing a masked word id.

    In the fixtures, positions with word id -1 stay unmasked in the expected output.
    """
    token_word_ids = np.array(
        [
            [-1, 1, 1, 2, 2, 3],
            [-1, 1, 1, 2, 2, -1],
            [-1, 1, 1, -1, 2, -1],
            [1, 1, 1, 1, 1, 1],
            [1, 2, 3, 4, 5, 6],
            [1, 2, 3, 4, 5, 6],
            [-1, -1, -1, -1, -1, -1],
        ]
    )

    # Initial selection: at most one token per word is flagged.
    seed_mask = np.array(
        [
            [0, 1, 0, 0, 0, 0],
            [0, 1, 0, 1, 0, 0],
            [0, 0, 0, 0, 1, 0],
            [1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0],
            [0, 1, 0, 1, 0, 1],
            [0, 0, 0, 0, 0, 0],
        ],
        dtype=bool,
    )

    # Expected result: every token of a flagged word is flagged.
    want_mask = np.array(
        [
            [0, 1, 1, 0, 0, 0],
            [0, 1, 1, 1, 1, 0],
            [0, 0, 0, 0, 1, 0],
            [1, 1, 1, 1, 1, 1],
            [0, 0, 0, 0, 0, 0],
            [0, 1, 0, 1, 0, 1],
            [0, 0, 0, 0, 0, 0],
        ],
        dtype=bool,
    )

    got_mask = DataCollatorForLanguageModeling._whole_word_mask(token_word_ids, seed_mask)

    np.testing.assert_array_equal(got_mask, want_mask)
|
|
|
|
|
|
# =============================================================================
|
|
# DataCollatorForWholeWordMask tests
|
|
# =============================================================================
|
|
|
|
|
|
@require_torch
class TestDataCollatorForWholeWordMask(DataCollatorTestMixin, unittest.TestCase):
    """
    Test suite for DataCollatorForWholeWordMask.

    The collator under test builds on MLM: once one token of a word is masked,
    the remaining tokens of that word are masked too (whole word masking).
    """

    def _get_tokenizer_and_features(self):
        """Return a fast tokenizer extended with eight ``token_<i>`` entries and two encodings of them."""
        fast_tokenizer = BertTokenizerFast(self.vocab_file)
        words = [f"token_{i}" for i in range(8)]
        fast_tokenizer.add_tokens(words)
        sentence = " ".join(words)
        encodings = [fast_tokenizer(sentence, return_offsets_mapping=True) for _ in range(2)]
        return fast_tokenizer, encodings

    def test_basic(self):
        """Check the batch shapes produced by whole word masking with torch output."""
        fast_tokenizer, encodings = self._get_tokenizer_and_features()
        wwm_collator = DataCollatorForWholeWordMask(fast_tokenizer, return_tensors="pt")
        collated = wwm_collator(encodings)

        for key in ("input_ids", "labels"):
            self.assertEqual(collated[key].shape, (2, 10))

    def test_with_numpy_inputs(self):
        """Check collation when the encodings hold NumPy arrays."""
        fast_tokenizer, _ = self._get_tokenizer_and_features()
        sentence = " ".join([f"token_{i}" for i in range(8)])
        np_encodings = []
        for _ in range(2):
            encoding = fast_tokenizer(sentence, return_offsets_mapping=True)
            np_encodings.append(encoding.convert_to_tensors("np"))

        wwm_collator = DataCollatorForWholeWordMask(fast_tokenizer, return_tensors="pt")
        collated = wwm_collator(np_encodings)

        self.assertEqual(collated["input_ids"].shape, torch.Size((2, 10)))

    def test_with_tensor_inputs(self):
        """Check collation when the encodings hold PyTorch tensors."""
        fast_tokenizer, _ = self._get_tokenizer_and_features()
        sentence = " ".join([f"token_{i}" for i in range(8)])
        pt_encodings = []
        for _ in range(2):
            encoding = fast_tokenizer(sentence, return_offsets_mapping=True)
            pt_encodings.append(encoding.convert_to_tensors("pt"))

        wwm_collator = DataCollatorForWholeWordMask(fast_tokenizer, return_tensors="pt")
        collated = wwm_collator(pt_encodings)

        self.assertEqual(collated["input_ids"].shape, torch.Size((2, 10)))

    def test_seed_reproducibility(self):
        """Check that the collator's ``seed`` argument makes masking repeatable."""
        fast_tokenizer = BertTokenizerFast(self.vocab_file)
        words = [f"token_{i}" for i in range(998)]
        fast_tokenizer.add_tokens(words)
        sentence = " ".join(words)
        encodings = [fast_tokenizer(sentence, return_offsets_mapping=True) for _ in range(2)]

        def collate_with_seed(seed):
            # Build a brand-new collator for every run so each starts from its seed.
            seeded_collator = DataCollatorForWholeWordMask(fast_tokenizer, seed=seed, return_tensors="np")
            return seeded_collator(encodings)

        first_run = collate_with_seed(42)
        second_run = collate_with_seed(42)
        np.testing.assert_array_equal(first_run["input_ids"], second_run["input_ids"])

        # A different seed must change at least one input id.
        other_seed_run = collate_with_seed(43)
        self.assertFalse(np.all(first_run["input_ids"] == other_seed_run["input_ids"]))

    def test_seed_multiworker_dataloader(self):
        """Check that the collator seed gives repeatable batches with a two-worker DataLoader."""
        fast_tokenizer = BertTokenizerFast(self.vocab_file)
        words = [f"token_{i}" for i in range(998)]
        fast_tokenizer.add_tokens(words)
        sentence = " ".join(words)
        encodings = [fast_tokenizer(sentence, return_offsets_mapping=True) for _ in range(10)]

        def stacked_input_ids(collator_seed, **extra_loader_kwargs):
            # Run one full pass over a fresh two-worker loader and stack its batches.
            loader = torch.utils.data.DataLoader(
                encodings,
                batch_size=2,
                num_workers=2,
                collate_fn=DataCollatorForWholeWordMask(fast_tokenizer, seed=collator_seed),
                **extra_loader_kwargs,
            )
            return torch.stack([collated["input_ids"] for collated in loader])

        # Only the first loader is given an explicit generator.
        reference = stacked_input_ids(42, generator=torch.Generator().manual_seed(42))

        same_seed = stacked_input_ids(42)
        self.assertTrue(torch.all(reference == same_seed))

        # A different collator seed must change at least one input id.
        other_seed = stacked_input_ids(43)
        self.assertFalse(torch.all(reference == other_seed))

    def test_numpy_output(self):
        """Check the batch shapes produced with NumPy output."""
        fast_tokenizer, encodings = self._get_tokenizer_and_features()
        wwm_collator = DataCollatorForWholeWordMask(fast_tokenizer, return_tensors="np")
        collated = wwm_collator(encodings)

        for key in ("input_ids", "labels"):
            self.assertEqual(collated[key].shape, (2, 10))

    def test_immutability(self):
        """Check, for both output formats, that collation leaves its input untouched."""
        fast_tokenizer, encodings = self._get_tokenizer_and_features()
        # Convert each encoding into a plain dict before handing copies to the checker.
        plain_features = [dict(encoding) for encoding in encodings]

        for tensor_format in ("pt", "np"):
            wwm_collator = DataCollatorForWholeWordMask(fast_tokenizer, return_tensors=tensor_format)
            self._check_immutability(wwm_collator, copy.deepcopy(plain_features))
|
|
|
|
|
|
# =============================================================================
|
|
# DataCollatorForPermutationLanguageModeling tests
|
|
# =============================================================================
|
|
|
|
|
|
@require_torch
class TestDataCollatorForPermutationLanguageModeling(DataCollatorTestMixin, unittest.TestCase):
    """
    Test suite for DataCollatorForPermutationLanguageModeling.

    The collator under test performs XLNet-style permutation language modeling
    and produces ``perm_mask`` and ``target_mapping`` entries for every batch.
    """

    def test_basic(self):
        """Check the shape of every entry of a basic PLM batch."""
        bert_tokenizer = BertTokenizer(self.vocab_file)
        examples = [{"input_ids": list(range(10))} for _ in range(2)]

        plm_collator = DataCollatorForPermutationLanguageModeling(bert_tokenizer)
        collated = plm_collator(examples)

        # Checked in this order; the two PLM-specific entries are square per row.
        expected_shapes = {
            "input_ids": torch.Size([2, 10]),
            "perm_mask": torch.Size([2, 10, 10]),
            "target_mapping": torch.Size([2, 10, 10]),
            "labels": torch.Size([2, 10]),
        }
        for key, shape in expected_shapes.items():
            self.assertEqual(collated[key].shape, shape)

    def test_with_padding(self):
        """Check PLM collation when the two sequences have different lengths."""
        bert_tokenizer = BertTokenizer(self.vocab_file)
        # Lengths 4 and 10: the batch is expected to come out 10 wide.
        examples = [{"input_ids": list(range(length))} for length in (4, 10)]

        plm_collator = DataCollatorForPermutationLanguageModeling(bert_tokenizer)
        collated = plm_collator(examples)

        self.assertEqual(collated["input_ids"].shape, torch.Size([2, 10]))

    def test_odd_sequence_error(self):
        """Check that a sequence of odd length makes the collator raise ValueError."""
        bert_tokenizer = BertTokenizer(self.vocab_file)
        plm_collator = DataCollatorForPermutationLanguageModeling(bert_tokenizer)

        odd_length_examples = [{"input_ids": list(range(5))}]
        # Only the collation call itself is expected to raise.
        self.assertRaises(ValueError, plm_collator, odd_length_examples)

    def test_numpy_output(self):
        """Check the batch shapes produced with NumPy output."""
        bert_tokenizer = BertTokenizer(self.vocab_file)
        examples = [{"input_ids": list(range(10))} for _ in range(2)]

        plm_collator = DataCollatorForPermutationLanguageModeling(bert_tokenizer, return_tensors="np")
        collated = plm_collator(examples)

        expected_shapes = {
            "input_ids": (2, 10),
            "perm_mask": (2, 10, 10),
            "target_mapping": (2, 10, 10),
        }
        for key, shape in expected_shapes.items():
            self.assertEqual(collated[key].shape, shape)

    def test_immutability(self):
        """Check, for both output formats, that collation leaves its input untouched."""
        bert_tokenizer = BertTokenizer(self.vocab_file)
        examples = [{"input_ids": list(range(10))} for _ in range(2)]

        for tensor_format in ("pt", "np"):
            plm_collator = DataCollatorForPermutationLanguageModeling(
                bert_tokenizer, return_tensors=tensor_format
            )
            # Each format gets its own deep copy of the examples.
            self._check_immutability(plm_collator, copy.deepcopy(examples))
|
|
|
|
|
|
# =============================================================================
|
|
# Next Sentence Prediction tests
|
|
# =============================================================================
|
|
|
|
|
|
@require_torch
class TestNextSentencePrediction(DataCollatorTestMixin, unittest.TestCase):
    """
    Test suite for Next Sentence Prediction (NSP) through DataCollatorForLanguageModeling.

    NSP belongs to BERT pretraining: the model predicts whether two sentences
    follow each other in the original text.
    """

    def _get_features(self):
        """Return two freshly built five-token features with ``next_sentence_label`` 0 and 1."""
        features = []
        for label in range(2):
            features.append(
                {
                    "input_ids": [0, 1, 2, 3, 4],
                    "token_type_ids": [0, 1, 2, 3, 4],
                    "next_sentence_label": label,
                }
            )
        return features

    def test_nsp(self):
        """Check that every entry, including the NSP label, is collated with the right shape."""
        bert_tokenizer = BertTokenizer(self.vocab_file)
        lm_collator = DataCollatorForLanguageModeling(bert_tokenizer)
        collated = lm_collator(self._get_features())

        # Per-token entries are [2, 5]; the label is one value per example.
        expected_shapes = {
            "input_ids": torch.Size([2, 5]),
            "token_type_ids": torch.Size([2, 5]),
            "labels": torch.Size([2, 5]),
            "next_sentence_label": torch.Size([2]),
        }
        for key, shape in expected_shapes.items():
            self.assertEqual(collated[key].shape, shape)

    def test_nsp_with_padding(self):
        """Check NSP collation when ``pad_to_multiple_of`` is set."""
        bert_tokenizer = BertTokenizer(self.vocab_file)
        lm_collator = DataCollatorForLanguageModeling(bert_tokenizer, pad_to_multiple_of=8)
        collated = lm_collator(self._get_features())

        # Five tokens per row with a multiple of 8 requested: the expected width is 8.
        self.assertEqual(collated["input_ids"].shape, torch.Size([2, 8]))
        self.assertEqual(collated["next_sentence_label"].shape, torch.Size([2]))

    def test_numpy_output(self):
        """Check NSP batch shapes with NumPy output."""
        bert_tokenizer = BertTokenizer(self.vocab_file)
        lm_collator = DataCollatorForLanguageModeling(bert_tokenizer, return_tensors="np")
        collated = lm_collator(self._get_features())

        self.assertEqual(collated["input_ids"].shape, (2, 5))
        self.assertEqual(collated["next_sentence_label"].shape, (2,))

    def test_immutability(self):
        """Check, for both output formats, that NSP collation leaves its input untouched."""
        bert_tokenizer = BertTokenizer(self.vocab_file)

        for tensor_format in ("pt", "np"):
            lm_collator = DataCollatorForLanguageModeling(bert_tokenizer, return_tensors=tensor_format)
            # _get_features builds new objects on every call, so no copy is needed.
            self._check_immutability(lm_collator, self._get_features())
|
|
|
|
|
|
# =============================================================================
|
|
# Sentence Order Prediction tests
|
|
# =============================================================================
|
|
|
|
|
|
@require_torch
class TestSentenceOrderPrediction(DataCollatorTestMixin, unittest.TestCase):
    """
    Test suite for Sentence Order Prediction (SOP) through DataCollatorForLanguageModeling.

    SOP belongs to ALBERT pretraining: the model predicts whether two sentences
    appear in the correct order.
    """

    def _get_features(self):
        """Return two freshly built five-token features with ``sentence_order_label`` 0 and 1."""
        features = []
        for label in range(2):
            features.append(
                {
                    "input_ids": [0, 1, 2, 3, 4],
                    "token_type_ids": [0, 1, 2, 3, 4],
                    "sentence_order_label": label,
                }
            )
        return features

    def test_sop(self):
        """Check that every entry, including the SOP label, is collated with the right shape."""
        bert_tokenizer = BertTokenizer(self.vocab_file)
        lm_collator = DataCollatorForLanguageModeling(bert_tokenizer)
        collated = lm_collator(self._get_features())

        # Per-token entries are [2, 5]; the label is one value per example.
        expected_shapes = {
            "input_ids": torch.Size([2, 5]),
            "token_type_ids": torch.Size([2, 5]),
            "labels": torch.Size([2, 5]),
            "sentence_order_label": torch.Size([2]),
        }
        for key, shape in expected_shapes.items():
            self.assertEqual(collated[key].shape, shape)

    def test_sop_with_tensor_inputs(self):
        """Check SOP collation when the per-token fields are PyTorch tensors."""
        bert_tokenizer = BertTokenizer(self.vocab_file)
        tensor_features = []
        for label in range(2):
            tensor_features.append(
                {
                    "input_ids": torch.tensor([0, 1, 2, 3, 4]),
                    "token_type_ids": torch.tensor([0, 1, 2, 3, 4]),
                    "sentence_order_label": label,
                }
            )

        lm_collator = DataCollatorForLanguageModeling(bert_tokenizer)
        collated = lm_collator(tensor_features)

        self.assertEqual(collated["sentence_order_label"].shape, torch.Size([2]))

    def test_sop_with_padding(self):
        """Check SOP collation of tensor inputs when ``pad_to_multiple_of`` is set."""
        bert_tokenizer = BertTokenizer(self.vocab_file)
        tensor_features = []
        for label in range(2):
            tensor_features.append(
                {
                    "input_ids": torch.tensor([0, 1, 2, 3, 4]),
                    "token_type_ids": torch.tensor([0, 1, 2, 3, 4]),
                    "sentence_order_label": label,
                }
            )

        lm_collator = DataCollatorForLanguageModeling(bert_tokenizer, pad_to_multiple_of=8)
        collated = lm_collator(tensor_features)

        # Five tokens per row with a multiple of 8 requested: the expected width is 8.
        self.assertEqual(collated["input_ids"].shape, torch.Size([2, 8]))
        self.assertEqual(collated["sentence_order_label"].shape, torch.Size([2]))

    def test_numpy_output(self):
        """Check SOP batch shapes with NumPy output."""
        bert_tokenizer = BertTokenizer(self.vocab_file)
        lm_collator = DataCollatorForLanguageModeling(bert_tokenizer, return_tensors="np")
        collated = lm_collator(self._get_features())

        self.assertEqual(collated["input_ids"].shape, (2, 5))
        self.assertEqual(collated["sentence_order_label"].shape, (2,))

    def test_immutability(self):
        """Check, for both output formats, that SOP collation leaves its input untouched."""
        bert_tokenizer = BertTokenizer(self.vocab_file)

        for tensor_format in ("pt", "np"):
            lm_collator = DataCollatorForLanguageModeling(bert_tokenizer, return_tensors=tensor_format)
            # _get_features builds new objects on every call, so no copy is needed.
            self._check_immutability(lm_collator, self._get_features())
|