Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
2099 lines
206 KiB
Python
2099 lines
206 KiB
Python
# Copyright 2022 The HuggingFace Inc. team.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
import inspect
|
||
import json
|
||
import os
|
||
import re
|
||
import shutil
|
||
import tempfile
|
||
import unittest
|
||
|
||
from parameterized import parameterized
|
||
|
||
from transformers import (
|
||
AddedToken,
|
||
MarkupLMTokenizerFast,
|
||
PreTrainedTokenizerBase,
|
||
is_mlx_available,
|
||
is_torch_available,
|
||
logging,
|
||
)
|
||
from transformers.models.markuplm.tokenization_markuplm import VOCAB_FILES_NAMES, MarkupLMTokenizer
|
||
from transformers.testing_utils import require_tokenizers, slow
|
||
|
||
from ...test_tokenization_common import SMALL_TRAINING_CORPUS, TokenizersExtractor, TokenizerTesterMixin
|
||
|
||
|
||
logger = logging.get_logger(__name__)
|
||
|
||
|
||
@require_tokenizers
|
||
class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||
# Checkpoint, constructor kwargs and tokenizer classes under test (presumably consumed by
# TokenizerTesterMixin; they are not read by the methods visible here).
from_pretrained_id = "microsoft/markuplm-base"
from_pretrained_kwargs = {"cls_token": "<s>"}
tokenizer_class = MarkupLMTokenizer
rust_tokenizer_class = MarkupLMTokenizerFast
test_rust_tokenizer = True
test_seq2seq = False

# Integration fixtures: `_run_integration_checks` tokenizes/encodes `input_text` and compares the
# results with the expected tokens, ids, decoded text and ids->tokens conversion below.
input_text = "Hello😊 <s>intro</s> falsé-world! 生活的真谛"
integration_expected_tokens = ['Hello', 'ðŁĺ', 'Ĭ', 'Ġ', '<s>', 'int', 'ro', '</s>', 'Ġfals', 'é', '-', 'world', '!', 'Ġç', 'Ķ', 'Ł', 'æ', '´', '»', 'çļĦ', 'çľ', 'Ł', 'è', '°', 'Ľ'] # fmt: skip
integration_expected_token_ids = [31414, 18636, 27969, 0, 2544, 1001, 2, 506, 1536, 1140, 12, 8331, 328, 48998, 37127, 20024, 2023, 44574, 49122, 4333, 36484, 7487, 3726] # fmt: skip
expected_tokens_from_ids = ['Hello', 'ðŁĺ', 'Ĭ', '<s>', 'int', 'ro', '</s>', 'f', 'als', 'é', '-', 'world', '!', 'çĶŁ', 'æ', '´', '»', 'çļĦ', 'çľ', 'Ł', 'è', '°', 'Ľ'] # fmt: skip
integration_expected_decoded_text = "Hello😊<s>intro</s>falsé-world!生活的真谛"
text_from_tokens = "Hello😊 <s>intro</s> falsé-world! 生活的真谛"
|
||
|
||
@classmethod
def setUpClass(cls):
    """Write a toy vocab file, merges file and ``tokenizer_config.json`` into ``cls.tmpdirname``.

    Also sets ``cls.tags_dict``, ``cls.special_tokens_map`` and the three fixture file paths.
    """
    super().setUpClass()

    # Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt
    vocab = ["l", "o", "w", "e", "r", "s", "t", "i", "d", "n", "\u0120", "\u0120l", "\u0120n", "\u0120lo", "\u0120low", "er", "\u0120lowest", "\u0120newer", "\u0120wider", "\u0120hello", "\u0120world", "<unk>",] # fmt: skip
    merges = ["#version: 0.2", "\u0120 l", "\u0120l o", "\u0120lo w", "e r", ""]

    cls.tags_dict = {"a": 0, "abbr": 1, "acronym": 2, "address": 3}
    cls.special_tokens_map = {"unk_token": "<unk>"}

    cls.vocab_file = os.path.join(cls.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
    cls.merges_file = os.path.join(cls.tmpdirname, VOCAB_FILES_NAMES["merges_file"])
    cls.tokenizer_config_file = os.path.join(cls.tmpdirname, "tokenizer_config.json")

    # Token -> position in `vocab`.
    vocab_tokens = {token: index for index, token in enumerate(vocab)}
    fixture_files = (
        (cls.vocab_file, json.dumps(vocab_tokens) + "\n"),
        (cls.merges_file, "\n".join(merges)),
        (cls.tokenizer_config_file, json.dumps({"tags_dict": cls.tags_dict})),
    )
    for path, contents in fixture_files:
        with open(path, "w", encoding="utf-8") as fp:
            fp.write(contents)
|
||
|
||
def _run_integration_checks(self, tokenizer, tokenizer_type):
    """Check ``tokenizer`` against this class's recorded integration fixtures.

    The checks run in this order and stop at the first mismatch:
    ``tokenize(input_text)``, ``encode(input_text, add_special_tokens=False)``,
    ``decode`` of the expected ids with ``clean_up_tokenization_spaces=False``, and
    ``convert_ids_to_tokens`` of the expected ids. Each failure message names the
    tokenizer class and ``tokenizer_type``.
    """
    name = tokenizer.__class__.__name__
    expected_ids = self.integration_expected_token_ids
    # (producer of the actual value, expected value, failure message); the producers are
    # lambdas so each tokenizer call happens only after the previous check has passed.
    checks = [
        (
            lambda: tokenizer.tokenize(self.input_text),
            self.integration_expected_tokens,
            f"Tokenized tokens don't match expected for {name} ({tokenizer_type})",
        ),
        (
            lambda: tokenizer.encode(self.input_text, add_special_tokens=False),
            expected_ids,
            f"Encoded IDs don't match expected for {name} ({tokenizer_type})",
        ),
        (
            lambda: tokenizer.decode(expected_ids, clean_up_tokenization_spaces=False),
            self.integration_expected_decoded_text,
            f"Decoded text doesn't match expected for {name} ({tokenizer_type})",
        ),
        (
            lambda: tokenizer.convert_ids_to_tokens(expected_ids),
            self.expected_tokens_from_ids,
            f"Tokens from IDs don't match expected for {name} ({tokenizer_type})",
        ),
    ]
    for compute, expected, message in checks:
        self.assertEqual(compute(), expected, message)
|
||
|
||
def get_nodes_and_xpaths(self):
    """Return one example: two text nodes that share a single xpath."""
    xpath = "/html/body/div/li[1]/div/span"
    nodes = ["hello", "world"]
    return nodes, [xpath, xpath]
|
||
|
||
def get_nodes_and_xpaths_batch(self):
    """Return a batch of two examples (two nodes, then one node) with one xpath per node."""
    first_xpath = "/html/body/div/li[1]/div/span"
    second_xpath = "/html/body/div/li[2]/div/span"
    nodes = [["hello world", "running"], ["hello my name is bob"]]
    xpaths = [[first_xpath, first_xpath], [second_xpath]]
    return nodes, xpaths
|
||
|
||
def get_question_nodes_and_xpaths(self):
    """Return a ``(question, nodes, xpaths)`` triple with a single context node."""
    question = "what's his name?"
    nodes = ["hello world"]
    xpaths = ["/html/body/div/li[1]/div/span"]
    return question, nodes, xpaths
|
||
|
||
def get_extracted_tokenizer(self, reference_tokenizer=None):
    """Rebuild a ``self.tokenizer_class`` instance from the ``tokenizer.json`` in ``self.tmpdirname``.

    The vocab, merges and added tokens come from ``TokenizersExtractor``. ``tags_dict`` is
    copied from ``reference_tokenizer`` (``self.get_tokenizer()`` when omitted), and
    ``self.from_pretrained_kwargs`` (if not ``None``) overrides the constructor kwargs.

    Returns:
        The new tokenizer, or ``None`` when ``tokenizer.json`` does not exist.

    Raises:
        ValueError: if ``reference_tokenizer`` has no ``tags_dict``.
    """
    if reference_tokenizer is None:
        reference_tokenizer = self.get_tokenizer()

    tokenizer_json_path = os.path.join(self.tmpdirname, "tokenizer.json")
    if not os.path.exists(tokenizer_json_path):
        return None

    extractor = TokenizersExtractor(tokenizer_json_path)
    vocab_ids, vocab_scores, merges, added_tokens_decoder = extractor.extract()

    # BPE / WordPiece models get `vocab_ids`; every other model type gets `vocab_scores`.
    # Start from `vocab_scores` so `vocab` is always bound, even when the tokenizer class
    # declares no `model` (this used to raise UnboundLocalError below).
    vocab = vocab_scores
    model_type = getattr(self.tokenizer_class, "model", None)
    if model_type and model_type.__name__ in ("BPE", "WordPiece"):
        vocab = vocab_ids

    init_kwargs = {
        "vocab": vocab,
        "merges": merges,
        "do_lower_case": False,
        "keep_accents": True,
        "added_tokens_decoder": dict(added_tokens_decoder.items()),
    }

    # tokenizer.json is not used as the source of tags_dict here; it is taken from the reference tokenizer.
    tags_dict = getattr(reference_tokenizer, "tags_dict", None)
    if tags_dict is None:
        raise ValueError("MarkupLMTokenizer requires a tags_dict for initialization.")
    init_kwargs["tags_dict"] = tags_dict

    if self.from_pretrained_kwargs is not None:
        init_kwargs.update(self.from_pretrained_kwargs)

    return self.tokenizer_class(**init_kwargs)
|
||
|
||
def get_question_nodes_and_xpaths_batch(self):
    """Return a batch of two ``(question, nodes, xpaths)`` examples."""
    questions = ["what's his name?", "how is he called?"]
    nodes = [["hello world", "running"], ["hello my name is bob"]]
    xpaths = [["/html/body/div/li[1]/div/span"] * 2, ["/html/body/div/li[2]/div/span"]]
    return questions, nodes, xpaths
|
||
|
||
def get_empty_nodes_and_xpaths(self):
    """Return one example whose last node is an empty string; all nodes share one xpath."""
    nodes = ["test", "empty", ""]
    xpaths = ["/html/body/div/li[1]/div/span"] * len(nodes)
    return nodes, xpaths
|
||
|
||
def get_empty_nodes_and_xpaths_batch(self):
    """Return a two-example batch in which each example ends with an empty node."""
    nodes = [["test", "empty", ""], ["one", "more", "empty", ""]]
    # Every node of example i uses the i-th xpath below.
    example_xpaths = ["/html/body/div/li[1]/div/span", "/html/body/div/li[2]/div/span"]
    xpaths = [[xpath] * len(example) for xpath, example in zip(example_xpaths, nodes)]
    return nodes, xpaths
|
||
|
||
def get_empty_question_nodes_and_xpaths(self):
    """Return an empty question plus nodes (the last one empty) that share one xpath."""
    nodes = ["test", "empty", ""]
    return "", nodes, ["/html/body/div/li[1]/div/span"] * len(nodes)
|
||
|
||
def get_empty_question_nodes_and_xpaths_batch(self):
    """Return a two-example batch: the second question is empty and every example ends with an empty node."""
    questions = ["what's his name?", ""]
    nodes = [["test", "empty", ""], ["one", "more", "empty", ""]]
    xpaths = [
        ["/html/body/div/li[1]/div/span"] * 3,
        ["/html/body/div/li[2]/div/span"] * 4,
    ]
    return questions, nodes, xpaths
|
||
|
||
@unittest.skip(reason="Chat template tests don't play well with table/layout models.")
|
||
def test_chat_template_batched(self):
|
||
pass
|
||
|
||
def get_input_output_texts(self, tokenizer):
    """Return a raw input string and the text expected after a round trip.

    ``tokenizer`` is not used here.
    """
    return "UNwant\u00e9d,running", "unwanted, running"
|
||
|
||
def convert_batch_encode_plus_format_to_encode_plus(self, batch_encode_plus_sequences):
    """Split a batched encoding mapping into one dict per example.

    The batch size is the length of the first key's value; the ``"encodings"`` entry is
    left out of the per-example dicts.
    """
    first_key = next(iter(batch_encode_plus_sequences))
    batch_size = len(batch_encode_plus_sequences[first_key])
    return [
        {key: value[index] for key, value in batch_encode_plus_sequences.items() if key != "encodings"}
        for index in range(batch_size)
    ]
|
||
|
||
def test_add_special_tokens(self):
    """A newly registered special token encodes to a single id and is removed by ``skip_special_tokens``."""
    tokenizers: list[MarkupLMTokenizer] = self.get_tokenizers(do_lower_case=False)
    for tokenizer in tokenizers:
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            special_token = "[SPECIAL_TOKEN]"
            special_token_xpath = "/html/body/div/li[1]/div/span"

            tokenizer.add_special_tokens({"cls_token": special_token})
            encoded_special_token = tokenizer.encode(
                [special_token], xpaths=[special_token_xpath], add_special_tokens=False
            )
            self.assertEqual(len(encoded_special_token), 1)

            decoded = tokenizer.decode(encoded_special_token, skip_special_tokens=True)
            # assertNotIn reports the decoded text on failure, unlike assertTrue(... not in ...).
            self.assertNotIn(special_token, decoded)
|
||
|
||
def test_add_tokens_tokenizer(self):
    """Adding regular and special tokens grows ``len(tokenizer)`` but not ``vocab_size``.

    The added tokens must encode to ids at or beyond ``vocab_size``. The added eos/pad
    tokens must map to ``eos_token_id`` / ``pad_token_id``.
    """
    tokenizers: list[MarkupLMTokenizer] = self.get_tokenizers(do_lower_case=False)
    xpath = "/html/body/div/li[1]/div/span"
    for tokenizer in tokenizers:
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            vocab_size = tokenizer.vocab_size
            all_size = len(tokenizer)

            self.assertNotEqual(vocab_size, 0)

            # We usually have added tokens from the start in tests because our vocab fixtures are
            # smaller than the original vocabs, so vocab_size == len(tokenizer) is not asserted.

            new_toks = [
                AddedToken("aaaaa", rstrip=True, lstrip=True),
                AddedToken("bbbbbb", rstrip=True, lstrip=True),
                AddedToken("cccccccccdddddddd", rstrip=True, lstrip=True),
            ]
            added_toks = tokenizer.add_tokens(new_toks)
            vocab_size_2 = tokenizer.vocab_size
            all_size_2 = len(tokenizer)

            self.assertNotEqual(vocab_size_2, 0)
            # Adding tokens must leave vocab_size unchanged; only len(tokenizer) grows.
            self.assertEqual(vocab_size, vocab_size_2)
            self.assertEqual(added_toks, len(new_toks))
            self.assertEqual(all_size_2, all_size + len(new_toks))

            nodes = "aaaaa bbbbbb low cccccccccdddddddd l".split()
            xpaths = [xpath] * len(nodes)
            tokens = tokenizer.encode(nodes, xpaths=xpaths, add_special_tokens=False)

            self.assertGreaterEqual(len(tokens), 4)
            self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
            self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)

            new_toks_2 = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"}
            added_toks_2 = tokenizer.add_special_tokens(new_toks_2)
            vocab_size_3 = tokenizer.vocab_size
            all_size_3 = len(tokenizer)

            self.assertNotEqual(vocab_size_3, 0)
            self.assertEqual(vocab_size, vocab_size_3)
            self.assertEqual(added_toks_2, len(new_toks_2))
            self.assertEqual(all_size_3, all_size_2 + len(new_toks_2))

            nodes = ">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l".split()
            xpaths = [xpath] * len(nodes)
            tokens = tokenizer.encode(
                nodes,
                xpaths=xpaths,
                add_special_tokens=False,
            )

            self.assertGreaterEqual(len(tokens), 6)
            self.assertGreater(tokens[0], tokenizer.vocab_size - 1)
            self.assertGreater(tokens[0], tokens[1])
            self.assertGreater(tokens[-2], tokenizer.vocab_size - 1)
            self.assertGreater(tokens[-2], tokens[-3])
            self.assertEqual(tokens[0], tokenizer.eos_token_id)
            self.assertEqual(tokens[-2], tokenizer.pad_token_id)
|
||
|
||
@require_tokenizers
def test_encode_decode_with_spaces(self):
    """Added non-normalized tokens must decode with or without separating spaces.

    The spacing follows ``self.space_between_special_tokens``.
    """
    tokenizers = self.get_tokenizers(do_lower_case=False)
    for tokenizer in tokenizers:
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            _, xpaths = self.get_nodes_and_xpaths()

            new_toks = [AddedToken("[ABC]", normalized=False), AddedToken("[DEF]", normalized=False)]
            tokenizer.add_tokens(new_toks)
            text = "[ABC][DEF][ABC][DEF]"
            if self.space_between_special_tokens:
                output = "[ABC] [DEF] [ABC] [DEF]"
            else:
                output = text

            nodes = text.split()
            # `text.split()` yields a single node while the fixture provides two xpaths;
            # pass exactly one xpath per node instead of relying on the tokenizer
            # tolerating a nodes/xpaths length mismatch.
            node_xpaths = xpaths[: len(nodes)]
            encoded = tokenizer.encode(nodes, xpaths=node_xpaths, add_special_tokens=False)
            decoded = tokenizer.decode(encoded, spaces_between_special_tokens=self.space_between_special_tokens)
            self.assertIn(decoded, [output, output.lower()])
|
||
|
||
@unittest.skip(reason="Not implemented")
|
||
def test_right_and_left_truncation(self):
|
||
pass
|
||
|
||
@parameterized.expand([(True,), (False,)])
def test_encode_plus_with_padding(self, use_padding_as_call_kwarg: bool):
    """Check ``encode_plus`` padding on a single example.

    ``padding=True`` (longest) and ``padding=False`` must leave the encoding untouched.
    ``padding="max_length"`` must add exactly ``padding_size`` pad positions on the
    requested side, covering input ids, special-tokens mask, token type ids and
    attention mask.

    Args:
        use_padding_as_call_kwarg: if True, pass ``padding_side`` to ``encode_plus``;
            otherwise set ``tokenizer.padding_side``.
    """
    tokenizers = self.get_tokenizers(do_lower_case=False)
    for tokenizer in tokenizers:
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            nodes, xpaths = self.get_nodes_and_xpaths()

            # check correct behaviour if no pad_token_id exists and add it eventually
            self._check_no_pad_token_padding(tokenizer, nodes)

            padding_size = 10
            padding_idx = tokenizer.pad_token_id

            encoded_sequence = tokenizer.encode_plus(nodes, xpaths=xpaths, return_special_tokens_mask=True)
            input_ids = encoded_sequence["input_ids"]
            special_tokens_mask = encoded_sequence["special_tokens_mask"]
            sequence_length = len(input_ids)

            # Test 'longest' (padding=True) and 'no_padding' (padding=False) don't do anything:
            # a lone sequence is already the longest one in its batch.
            tokenizer.padding_side = "right"
            for padding in (True, False):
                not_padded_sequence = tokenizer.encode_plus(
                    nodes,
                    xpaths=xpaths,
                    padding=padding,
                    return_special_tokens_mask=True,
                )
                not_padded_input_ids = not_padded_sequence["input_ids"]
                self.assertEqual(len(not_padded_input_ids), sequence_length)
                self.assertEqual(not_padded_input_ids, input_ids)
                self.assertEqual(not_padded_sequence["special_tokens_mask"], special_tokens_mask)

            # Test right padding
            tokenizer_kwargs_right = {
                "max_length": sequence_length + padding_size,
                "padding": "max_length",
                "return_special_tokens_mask": True,
            }
            if not use_padding_as_call_kwarg:
                tokenizer.padding_side = "right"
            else:
                tokenizer_kwargs_right["padding_side"] = "right"

            right_padded_sequence = tokenizer.encode_plus(nodes, xpaths=xpaths, **tokenizer_kwargs_right)
            right_padded_input_ids = right_padded_sequence["input_ids"]
            right_padded_special_tokens_mask = right_padded_sequence["special_tokens_mask"]

            self.assertEqual(len(right_padded_input_ids), sequence_length + padding_size)
            self.assertEqual(right_padded_input_ids, input_ids + [padding_idx] * padding_size)
            self.assertEqual(right_padded_special_tokens_mask, special_tokens_mask + [1] * padding_size)

            # Test left padding
            tokenizer_kwargs_left = {
                "max_length": sequence_length + padding_size,
                "padding": "max_length",
                "return_special_tokens_mask": True,
            }
            if not use_padding_as_call_kwarg:
                tokenizer.padding_side = "left"
            else:
                tokenizer_kwargs_left["padding_side"] = "left"

            left_padded_sequence = tokenizer.encode_plus(nodes, xpaths=xpaths, **tokenizer_kwargs_left)
            left_padded_input_ids = left_padded_sequence["input_ids"]
            left_padded_special_tokens_mask = left_padded_sequence["special_tokens_mask"]

            self.assertEqual(len(left_padded_input_ids), sequence_length + padding_size)
            self.assertEqual(left_padded_input_ids, [padding_idx] * padding_size + input_ids)
            self.assertEqual(left_padded_special_tokens_mask, [1] * padding_size + special_tokens_mask)

            # Use unittest assertions (not bare `assert`) so failures are reported consistently.
            if "token_type_ids" in tokenizer.model_input_names:
                token_type_ids = encoded_sequence["token_type_ids"]
                self.assertEqual(right_padded_sequence["token_type_ids"], token_type_ids + [0] * padding_size)
                self.assertEqual(left_padded_sequence["token_type_ids"], [0] * padding_size + token_type_ids)

            if "attention_mask" in tokenizer.model_input_names:
                attention_mask = encoded_sequence["attention_mask"]
                self.assertEqual(right_padded_sequence["attention_mask"], attention_mask + [0] * padding_size)
                self.assertEqual(left_padded_sequence["attention_mask"], [0] * padding_size + attention_mask)
|
||
|
||
def test_internal_consistency(self):
    """Per-node ``tokenize`` + ``convert_tokens_to_ids`` must equal ``encode`` without special tokens.

    Converting those ids back must give a non-empty token list and a ``str`` decode.
    """
    for tokenizer in self.get_tokenizers():
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            nodes, xpaths = self.get_nodes_and_xpaths()

            tokens = [token for node in nodes for token in tokenizer.tokenize(node)]
            ids = tokenizer.convert_tokens_to_ids(tokens)
            self.assertListEqual(ids, tokenizer.encode(nodes, xpaths=xpaths, add_special_tokens=False))

            self.assertNotEqual(len(tokenizer.convert_ids_to_tokens(ids)), 0)
            self.assertIsInstance(tokenizer.decode(ids), str)
|
||
|
||
def test_mask_output(self):
    """``token_type_ids`` must be as long as ``input_ids`` when special tokens are added.

    The check only runs when ``build_inputs_with_special_tokens`` is not defined on
    ``PreTrainedTokenizer`` itself and the model takes ``token_type_ids``.
    """
    for tokenizer in self.get_tokenizers(fast=False, do_lower_case=False):
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            nodes, xpaths = self.get_nodes_and_xpaths()

            defining_class = tokenizer.build_inputs_with_special_tokens.__qualname__.split(".")[0]
            if defining_class == "PreTrainedTokenizer" or "token_type_ids" not in tokenizer.model_input_names:
                continue

            information = tokenizer.encode_plus(nodes, xpaths=xpaths, add_special_tokens=True)
            self.assertEqual(len(information["input_ids"]), len(information["token_type_ids"]))
|
||
|
||
def test_number_of_added_tokens(self):
    """``num_special_tokens_to_add(pair=...)`` must equal the length gained by adding special tokens.

    This is checked for a single node sequence and for a question + nodes pair.
    """
    for tokenizer in self.get_tokenizers(do_lower_case=False):
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            # pair=False: single sequence; pair=True: question + nodes
            for pair in (False, True):
                if pair:
                    question, nodes, xpaths = self.get_question_nodes_and_xpaths()
                    texts = (question, nodes)
                else:
                    nodes, xpaths = self.get_nodes_and_xpaths()
                    texts = (nodes,)

                sequences = tokenizer.encode(*texts, xpaths=xpaths, add_special_tokens=False)
                attached_sequences = tokenizer.encode(*texts, xpaths=xpaths, add_special_tokens=True)

                # Method is implemented (e.g. not GPT-2)
                if len(attached_sequences) != 2:
                    self.assertEqual(
                        tokenizer.num_special_tokens_to_add(pair=pair), len(attached_sequences) - len(sequences)
                    )
|
||
|
||
def test_padding(self, max_length=50):
    """Compare padded outputs of two tokenizers built from each ``tokenizers_list`` entry.

    Covers ``encode``, ``encode_plus`` and ``batch_encode_plus`` for single and pair
    inputs. Padding is either ``"max_length"`` or longest (``True`` on one side,
    ``"longest"`` on the other). It also covers ``pad`` applied after tokenization,
    where each side uses its own tokenizer.

    Args:
        max_length: target length for the ``"max_length"`` cases.
    """
    for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
        with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
            # NOTE(review): both tokenizers come from the same `get_tokenizer` call, so this compares two
            # identically-built instances -- confirm whether one was meant to be a different backend.
            tokenizer_r = self.get_tokenizer(pretrained_name, **kwargs)
            tokenizer_p = self.get_tokenizer(pretrained_name, **kwargs)

            self.assertEqual(tokenizer_p.pad_token_id, tokenizer_r.pad_token_id)
            pad_token_id = tokenizer_p.pad_token_id

            # Encode - Simple input
            nodes, xpaths = self.get_nodes_and_xpaths()
            input_r = tokenizer_r.encode(nodes, xpaths=xpaths, max_length=max_length, padding="max_length")
            input_p = tokenizer_p.encode(nodes, xpaths=xpaths, max_length=max_length, padding="max_length")
            self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)

            input_r = tokenizer_r.encode(nodes, xpaths=xpaths, padding="longest")
            input_p = tokenizer_p.encode(nodes, xpaths=xpaths, padding=True)
            self.assert_padded_input_match(input_r, input_p, len(input_r), pad_token_id)

            # Encode - Pair input
            question, nodes, xpaths = self.get_question_nodes_and_xpaths()
            input_r = tokenizer_r.encode(
                question, nodes, xpaths=xpaths, max_length=max_length, padding="max_length"
            )
            input_p = tokenizer_p.encode(
                question, nodes, xpaths=xpaths, max_length=max_length, padding="max_length"
            )
            self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id)
            input_r = tokenizer_r.encode(question, nodes, xpaths=xpaths, padding=True)
            input_p = tokenizer_p.encode(question, nodes, xpaths=xpaths, padding="longest")
            self.assert_padded_input_match(input_r, input_p, len(input_r), pad_token_id)

            # Encode_plus - Simple input
            nodes, xpaths = self.get_nodes_and_xpaths()
            input_r = tokenizer_r.encode_plus(nodes, xpaths=xpaths, max_length=max_length, padding="max_length")
            input_p = tokenizer_p.encode_plus(nodes, xpaths=xpaths, max_length=max_length, padding="max_length")
            self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
            self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])

            input_r = tokenizer_r.encode_plus(nodes, xpaths=xpaths, padding="longest")
            input_p = tokenizer_p.encode_plus(nodes, xpaths=xpaths, padding=True)
            self.assert_padded_input_match(
                input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
            )
            self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])

            # Encode_plus - Pair input
            question, nodes, xpaths = self.get_question_nodes_and_xpaths()
            input_r = tokenizer_r.encode_plus(
                question, nodes, xpaths=xpaths, max_length=max_length, padding="max_length"
            )
            input_p = tokenizer_p.encode_plus(
                question, nodes, xpaths=xpaths, max_length=max_length, padding="max_length"
            )
            self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)
            self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])
            input_r = tokenizer_r.encode_plus(question, nodes, xpaths=xpaths, padding="longest")
            input_p = tokenizer_p.encode_plus(question, nodes, xpaths=xpaths, padding=True)
            self.assert_padded_input_match(
                input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
            )
            self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"])

            # Batch_encode_plus - Simple input
            nodes, xpaths = self.get_nodes_and_xpaths_batch()

            input_r = tokenizer_r.batch_encode_plus(
                nodes,
                xpaths=xpaths,
                max_length=max_length,
                padding="max_length",
            )
            input_p = tokenizer_p.batch_encode_plus(
                nodes,
                xpaths=xpaths,
                max_length=max_length,
                padding="max_length",
            )
            self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)

            input_r = tokenizer_r.batch_encode_plus(
                nodes,
                xpaths=xpaths,
                max_length=max_length,
                padding="longest",
            )
            input_p = tokenizer_p.batch_encode_plus(
                nodes,
                xpaths=xpaths,
                max_length=max_length,
                padding=True,
            )
            self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)

            input_r = tokenizer_r.batch_encode_plus(nodes, xpaths=xpaths, padding="longest")
            input_p = tokenizer_p.batch_encode_plus(nodes, xpaths=xpaths, padding=True)
            self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)

            # Batch_encode_plus - Pair input
            questions, nodes, xpaths = self.get_question_nodes_and_xpaths_batch()

            input_r = tokenizer_r.batch_encode_plus(
                list(zip(questions, nodes)),
                is_pair=True,
                xpaths=xpaths,
                max_length=max_length,
                truncation=True,
                padding="max_length",
            )
            input_p = tokenizer_p.batch_encode_plus(
                list(zip(questions, nodes)),
                is_pair=True,
                xpaths=xpaths,
                max_length=max_length,
                truncation=True,
                padding="max_length",
            )
            self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)

            input_r = tokenizer_r.batch_encode_plus(
                list(zip(questions, nodes)),
                is_pair=True,
                xpaths=xpaths,
                padding=True,
            )
            input_p = tokenizer_p.batch_encode_plus(
                list(zip(questions, nodes)),
                is_pair=True,
                xpaths=xpaths,
                padding="longest",
            )
            self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)

            # Using pad on single examples after tokenization. Each side encodes and pads with
            # its own tokenizer so that `tokenizer_p.pad` is exercised as well.
            nodes, xpaths = self.get_nodes_and_xpaths()
            input_r = tokenizer_r.encode_plus(nodes, xpaths=xpaths)
            input_r = tokenizer_r.pad(input_r)

            input_p = tokenizer_p.encode_plus(nodes, xpaths=xpaths)
            input_p = tokenizer_p.pad(input_p)

            self.assert_padded_input_match(
                input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id
            )

            # Using pad on single examples after tokenization
            input_r = tokenizer_r.encode_plus(nodes, xpaths=xpaths)
            input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")

            input_p = tokenizer_p.encode_plus(nodes, xpaths=xpaths)
            input_p = tokenizer_p.pad(input_p, max_length=max_length, padding="max_length")

            self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id)

            # Using pad after tokenization
            nodes, xpaths = self.get_nodes_and_xpaths_batch()
            input_r = tokenizer_r.batch_encode_plus(nodes, xpaths=xpaths)
            input_r = tokenizer_r.pad(input_r)

            input_p = tokenizer_p.batch_encode_plus(nodes, xpaths=xpaths)
            input_p = tokenizer_p.pad(input_p)

            self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id)

            # Using pad after tokenization
            nodes, xpaths = self.get_nodes_and_xpaths_batch()
            input_r = tokenizer_r.batch_encode_plus(nodes, xpaths=xpaths)
            input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length")

            input_p = tokenizer_p.batch_encode_plus(nodes, xpaths=xpaths)
            input_p = tokenizer_p.pad(input_p, max_length=max_length, padding="max_length")

            self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id)
|
||
|
||
def test_call(self):
    """Check that ``tokenizer(...)`` wraps ``encode_plus`` / ``batch_encode_plus``.

    Covers a single node sequence, a (question, nodes) pair and a batch of node sequences.
    """
    tokenizers = self.get_tokenizers(do_lower_case=False)
    for tokenizer in tokenizers:
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            # Test not batched
            nodes, xpaths = self.get_nodes_and_xpaths()
            encoded_sequences_1 = tokenizer.encode_plus(nodes, xpaths=xpaths)
            encoded_sequences_2 = tokenizer(nodes, xpaths=xpaths)
            self.assertEqual(encoded_sequences_1, encoded_sequences_2)

            # Test not batched pairs. The question must actually be passed as the first
            # sequence; otherwise this section only repeats the single-sequence check above.
            question, nodes, xpaths = self.get_question_nodes_and_xpaths()
            encoded_sequences_1 = tokenizer.encode_plus(question, nodes, xpaths=xpaths)
            encoded_sequences_2 = tokenizer(question, nodes, xpaths=xpaths)
            self.assertEqual(encoded_sequences_1, encoded_sequences_2)

            # Test batched
            nodes, xpaths = self.get_nodes_and_xpaths_batch()
            encoded_sequences_1 = tokenizer.batch_encode_plus(nodes, is_pair=False, xpaths=xpaths)
            encoded_sequences_2 = tokenizer(nodes, xpaths=xpaths)
            self.assertEqual(encoded_sequences_1, encoded_sequences_2)
|
||
|
||
def test_batch_encode_plus_batch_sequence_length(self):
    """Verify that per-example ``encode_plus`` and ``batch_encode_plus`` agree, padded or not.

    Also verifies that passing a larger ``max_length`` changes nothing when padding is
    ``"longest"`` or disabled.
    """
    for tokenizer in self.get_tokenizers(do_lower_case=False):
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            nodes, xpaths = self.get_nodes_and_xpaths_batch()

            # Unpadded: encoding each example alone must match the batched result.
            per_example = [
                tokenizer.encode_plus(node_seq, xpaths=xpath_seq) for node_seq, xpath_seq in zip(nodes, xpaths)
            ]
            batched = tokenizer.batch_encode_plus(nodes, is_pair=False, xpaths=xpaths, padding=False)
            self.assertListEqual(per_example, self.convert_batch_encode_plus_format_to_encode_plus(batched))

            # Length of the longest unpadded example in the batch.
            longest = max(len(encoding["input_ids"]) for encoding in per_example)

            # If the tokenizer lacks a pad token, verify the expected failure and add one.
            self._check_no_pad_token_padding(tokenizer, nodes)

            # Padded: padding each example to `longest` must equal batch padding=True.
            per_example_padded = [
                tokenizer.encode_plus(node_seq, xpaths=xpath_seq, max_length=longest, padding="max_length")
                for node_seq, xpath_seq in zip(nodes, xpaths)
            ]
            batched_padded = tokenizer.batch_encode_plus(nodes, is_pair=False, xpaths=xpaths, padding=True)
            self.assertListEqual(
                per_example_padded,
                self.convert_batch_encode_plus_format_to_encode_plus(batched_padded),
            )

            # "longest" padding and no padding must both ignore a larger max_length.
            for reference_padding, bounded_padding in ((True, "longest"), (False, False)):
                reference = tokenizer.batch_encode_plus(
                    nodes, is_pair=False, xpaths=xpaths, padding=reference_padding
                )
                bounded = tokenizer.batch_encode_plus(
                    nodes, is_pair=False, xpaths=xpaths, max_length=longest + 10, padding=bounded_padding
                )
                for key in reference:
                    self.assertListEqual(reference[key], bounded[key])
|
||
|
||
@unittest.skip(reason="batch_encode_plus does not handle overflowing tokens.")
def test_batch_encode_plus_overflowing_tokens(self):
    """Intentionally disabled for this tokenizer; see the skip reason above."""
|
||
|
||
def test_batch_encode_plus_padding(self):
    """Check that ``encode_plus`` and ``batch_encode_plus`` pad identically to ``max_length``.

    The comparison runs twice over fresh tokenizers: first with the tokenizers' default
    padding side (presumably right), then with ``padding_side = "left"``.
    """
    max_length = 100
    # None keeps the default padding side; "left" forces left padding.
    for padding_side in (None, "left"):
        for tokenizer in self.get_tokenizers(do_lower_case=False):
            with self.subTest(f"{tokenizer.__class__.__name__}"):
                if padding_side is not None:
                    tokenizer.padding_side = padding_side
                nodes, xpaths = self.get_nodes_and_xpaths_batch()

                # If the tokenizer lacks a pad token, verify the expected failure and add one.
                self._check_no_pad_token_padding(tokenizer, nodes)

                expected = [
                    tokenizer.encode_plus(
                        node_seq, xpaths=xpath_seq, max_length=max_length, padding="max_length"
                    )
                    for node_seq, xpath_seq in zip(nodes, xpaths)
                ]
                batched = tokenizer.batch_encode_plus(
                    nodes, is_pair=False, xpaths=xpaths, max_length=max_length, padding="max_length"
                )
                self.assertListEqual(expected, self.convert_batch_encode_plus_format_to_encode_plus(batched))
|
||
|
||
def test_padding_to_multiple_of(self):
    """Check ``pad_to_multiple_of`` with padding, without padding, and combined with truncation.

    Also checks that truncating to a ``max_length`` that is not a multiple of
    ``pad_to_multiple_of`` raises ``ValueError``.
    """
    tokenizers = self.get_tokenizers()
    for tokenizer in tokenizers:
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            if tokenizer.pad_token is None:
                # skipTest raises, so the rest of this subTest block does not run.
                self.skipTest(reason="No padding token.")

            nodes, xpaths = self.get_nodes_and_xpaths()

            normal_tokens = tokenizer(nodes, xpaths=xpaths, padding=True, pad_to_multiple_of=8)
            for key, value in normal_tokens.items():
                self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8")

            # Without `padding`, `pad_to_multiple_of` alone must not pad the output.
            # NOTE(review): relies on the unpadded example length not already being a multiple of 8.
            normal_tokens = tokenizer(nodes, xpaths=xpaths, pad_to_multiple_of=8)
            for key, value in normal_tokens.items():
                self.assertNotEqual(
                    len(value) % 8, 0, f"BatchEncoding.{key} should not be padded to a multiple of 8"
                )

            # Should also work with truncation
            normal_tokens = tokenizer(nodes, xpaths=xpaths, padding=True, truncation=True, pad_to_multiple_of=8)
            for key, value in normal_tokens.items():
                self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8")

            # truncation to something which is not a multiple of pad_to_multiple_of raises an error
            self.assertRaises(
                ValueError,
                tokenizer.__call__,
                nodes,
                xpaths=xpaths,
                padding=True,
                truncation=True,
                max_length=12,
                pad_to_multiple_of=8,
            )
|
||
|
||
def test_special_tokens_mask_input_pairs(self):
    """Check that ``special_tokens_mask`` flags exactly the added special tokens.

    NOTE(review): despite the name, only a single (nodes, xpaths) sequence is encoded here,
    not a question/nodes pair.
    """
    for tokenizer in self.get_tokenizers(do_lower_case=False):
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            nodes, xpaths = self.get_nodes_and_xpaths()
            plain_ids = tokenizer.encode(nodes, xpaths=xpaths, add_special_tokens=False)
            encoding = tokenizer.encode_plus(
                nodes,
                xpaths=xpaths,
                add_special_tokens=True,
                return_special_tokens_mask=True,
            )
            ids_with_special = encoding["input_ids"]
            mask = encoding["special_tokens_mask"]
            self.assertEqual(len(mask), len(ids_with_special))

            # Dropping every masked position must recover the encoding without special tokens.
            kept_ids = [token_id for token_id, is_special in zip(ids_with_special, mask) if not is_special]
            self.assertEqual(plain_ids, kept_ids)
|
||
|
||
def test_special_tokens_mask(self):
    """Check ``special_tokens_mask`` for a single sequence encoded with special tokens."""
    for tokenizer in self.get_tokenizers(do_lower_case=False):
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            nodes, xpaths = self.get_nodes_and_xpaths()
            # Reference encoding of the single input without any special tokens.
            ids_without_special = tokenizer.encode(nodes, xpaths=xpaths, add_special_tokens=False)
            encoding = tokenizer.encode_plus(
                nodes, xpaths=xpaths, add_special_tokens=True, return_special_tokens_mask=True
            )
            ids_with_special, mask = encoding["input_ids"], encoding["special_tokens_mask"]
            self.assertEqual(len(mask), len(ids_with_special))

            # Keeping only unmasked positions must reproduce the reference encoding.
            non_special_ids = [token_id for token_id, flag in zip(ids_with_special, mask) if not flag]
            self.assertEqual(ids_without_special, non_special_ids)
|
||
|
||
def test_save_and_load_tokenizer(self):
    """Check that a ``save_pretrained`` / ``from_pretrained`` round trip preserves encodings and vocab."""
    # Safety check on the model_max_length default so we are sure the test works.
    tokenizers = self.get_tokenizers()
    for tokenizer in tokenizers:
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            self.assertNotEqual(tokenizer.model_max_length, 42)

    # Now let's start the test
    tokenizers = self.get_tokenizers()
    for tokenizer in tokenizers:
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            nodes, xpaths = self.get_nodes_and_xpaths()
            # The context manager removes the directory even when an assertion below fails.
            with tempfile.TemporaryDirectory() as tmpdirname:
                before_tokens = tokenizer.encode(nodes, xpaths=xpaths, add_special_tokens=False)
                before_vocab = tokenizer.get_vocab()
                tokenizer.save_pretrained(tmpdirname)

                after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname)
                after_tokens = after_tokenizer.encode(nodes, xpaths=xpaths, add_special_tokens=False)
                after_vocab = after_tokenizer.get_vocab()
                self.assertListEqual(before_tokens, after_tokens)
                self.assertDictEqual(before_vocab, after_vocab)
|
||
|
||
def test_right_and_left_padding(self):
    """Check right/left padding to ``max_length``, and that no-op padding leaves ids unchanged.

    Uses ``self.assertEqual`` instead of bare ``assert``. Bare asserts are stripped under
    ``python -O`` and report no values on failure.
    """
    tokenizers = self.get_tokenizers(do_lower_case=False)
    for tokenizer in tokenizers:
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            nodes, xpaths = self.get_nodes_and_xpaths()
            sequence = "Sequence"
            padding_size = 10

            # check correct behaviour if no pad_token_id exists and add it eventually
            # NOTE(review): other tests pass `nodes` here; this one passes a plain string — confirm intended.
            self._check_no_pad_token_padding(tokenizer, sequence)

            padding_idx = tokenizer.pad_token_id

            # RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
            tokenizer.padding_side = "right"
            encoded_sequence = tokenizer.encode(nodes, xpaths=xpaths)
            sequence_length = len(encoded_sequence)
            padded_sequence = tokenizer.encode(
                nodes, xpaths=xpaths, max_length=sequence_length + padding_size, padding="max_length"
            )
            self.assertEqual(sequence_length + padding_size, len(padded_sequence))
            self.assertEqual(encoded_sequence + [padding_idx] * padding_size, padded_sequence)

            # LEFT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True
            tokenizer.padding_side = "left"
            encoded_sequence = tokenizer.encode(nodes, xpaths=xpaths)
            sequence_length = len(encoded_sequence)
            padded_sequence = tokenizer.encode(
                nodes, xpaths=xpaths, max_length=sequence_length + padding_size, padding="max_length"
            )
            self.assertEqual(sequence_length + padding_size, len(padded_sequence))
            self.assertEqual([padding_idx] * padding_size + encoded_sequence, padded_sequence)

            # RIGHT & LEFT PADDING - Check that nothing is done for 'longest' and 'no_padding'
            encoded_sequence = tokenizer.encode(nodes, xpaths=xpaths)
            sequence_length = len(encoded_sequence)

            tokenizer.padding_side = "right"
            padded_sequence_right = tokenizer.encode(nodes, xpaths=xpaths, padding=True)
            self.assertEqual(sequence_length, len(padded_sequence_right))
            self.assertEqual(encoded_sequence, padded_sequence_right)

            tokenizer.padding_side = "left"
            padded_sequence_left = tokenizer.encode(nodes, xpaths=xpaths, padding="longest")
            self.assertEqual(sequence_length, len(padded_sequence_left))
            self.assertEqual(encoded_sequence, padded_sequence_left)

            tokenizer.padding_side = "right"
            padded_sequence_right = tokenizer.encode(nodes, xpaths=xpaths)
            self.assertEqual(sequence_length, len(padded_sequence_right))
            self.assertEqual(encoded_sequence, padded_sequence_right)

            tokenizer.padding_side = "left"
            padded_sequence_left = tokenizer.encode(nodes, xpaths=xpaths, padding=False)
            self.assertEqual(sequence_length, len(padded_sequence_left))
            self.assertEqual(encoded_sequence, padded_sequence_left)
|
||
|
||
def test_token_type_ids(self):
    """Check ``token_type_ids`` lengths and values for a single sequence and a question/nodes pair."""

    def check_lengths(encoding):
        # token_type_ids must line up with both input_ids and attention_mask.
        self.assertEqual(len(encoding["token_type_ids"]), len(encoding["input_ids"]))
        self.assertEqual(len(encoding["token_type_ids"]), len(encoding["attention_mask"]))

    for tokenizer in self.get_tokenizers():
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            # Case 1: a single node sequence only ever uses segment id 0.
            nodes, xpaths = self.get_nodes_and_xpaths()
            single = tokenizer(nodes, xpaths=xpaths, return_token_type_ids=True)
            check_lengths(single)
            self.assertIn(0, single["token_type_ids"])
            self.assertNotIn(1, single["token_type_ids"])

            # Case 2: a question followed by nodes.
            question, nodes, xpaths = self.get_question_nodes_and_xpaths()
            paired = tokenizer(question, nodes, xpaths, return_token_type_ids=True)
            check_lengths(paired)
            self.assertIn(0, paired["token_type_ids"])
|
||
|
||
def test_offsets_mapping(self):
    """Check offset mappings and special-token counts for single and paired inputs."""
    for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
        with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
            tokenizer_r = self.get_tokenizer(pretrained_name, **kwargs)

            nodes = ["a", "wonderful", "test"]
            node_xpaths = ["html/body"] * len(nodes)
            # (positional inputs to encode_plus, whether they form a pair)
            cases = (
                ((nodes,), False),
                (("what's his name", nodes), True),
            )
            for inputs, is_pair in cases:
                encoding = tokenizer_r.encode_plus(
                    *inputs,
                    xpaths=node_xpaths,
                    return_special_tokens_mask=True,
                    return_offsets_mapping=True,
                    add_special_tokens=True,
                )
                expected_special = tokenizer_r.num_special_tokens_to_add(is_pair)
                offsets = encoding["offset_mapping"]

                # Exactly one offset per token.
                self.assertEqual(len(offsets), len(encoding["input_ids"]))
                # Only the added special tokens are flagged in the mask.
                self.assertEqual(sum(encoding["special_tokens_mask"]), expected_special)
|
||
|
||
def test_embedded_special_tokens(self):
    """Check that the rust tokenizer and ``get_tokenizer`` produce identical encodings and tokens."""
    for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
        with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
            rust_tokenizer = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
            other_tokenizer = self.get_tokenizer(pretrained_name, **kwargs)
            nodes, xpaths = self.get_nodes_and_xpaths()
            rust_encoding = rust_tokenizer(nodes, xpaths=xpaths, add_special_tokens=True)
            other_encoding = other_tokenizer(nodes, xpaths=xpaths, add_special_tokens=True)

            # Every field produced by the second tokenizer must match the rust one.
            for key in other_encoding:
                self.assertEqual(rust_encoding[key], other_encoding[key])

            if "token_type_ids" in rust_encoding:
                self.assertEqual(sum(rust_encoding["token_type_ids"]), sum(other_encoding["token_type_ids"]))

            # The ids must also map back to the same token strings.
            self.assertSequenceEqual(
                rust_tokenizer.convert_ids_to_tokens(rust_encoding["input_ids"]),
                other_tokenizer.convert_ids_to_tokens(other_encoding["input_ids"]),
            )
|
||
|
||
def test_compare_add_special_tokens(self):
    """Check that ``add_special_tokens=True`` adds exactly ``num_special_tokens_to_add(pair=False)`` tokens.

    The check is applied to ``tokenize``, ``encode``, ``encode_plus`` and ``batch_encode_plus``.
    """
    for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
        with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
            tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
            n_special = tokenizer_r.num_special_tokens_to_add(pair=False)
            nodes, xpaths = self.get_nodes_and_xpaths()

            # tokenize(): operates on the nodes joined into a single string.
            joined = " ".join(nodes)
            plain = tokenizer_r.tokenize(joined, add_special_tokens=False)
            special = tokenizer_r.tokenize(joined, add_special_tokens=True)
            self.assertEqual(len(plain), len(special) - n_special)

            # encode()
            plain = tokenizer_r.encode(nodes, xpaths=xpaths, add_special_tokens=False)
            special = tokenizer_r.encode(nodes, xpaths=xpaths, add_special_tokens=True)
            self.assertEqual(len(plain), len(special) - n_special)

            # encode_plus(): every returned field grows by the same amount.
            plain = tokenizer_r.encode_plus(nodes, xpaths=xpaths, add_special_tokens=False)
            special = tokenizer_r.encode_plus(nodes, xpaths=xpaths, add_special_tokens=True)
            for key in plain:
                self.assertEqual(len(plain[key]), len(special[key]) - n_special)

            # batch_encode_plus(): checked example by example for every field.
            nodes, xpaths = self.get_nodes_and_xpaths_batch()
            plain = tokenizer_r.batch_encode_plus(nodes, xpaths=xpaths, add_special_tokens=False)
            special = tokenizer_r.batch_encode_plus(nodes, xpaths=xpaths, add_special_tokens=True)
            for key in plain:
                for example_plain, example_special in zip(plain[key], special[key]):
                    self.assertEqual(len(example_plain), len(example_special) - n_special)
|
||
|
||
@slow
def test_markuplm_truncation_integration_test(self):
    """Check truncation with explicit ``max_length`` values and via ``model_max_length``."""
    nodes, xpaths = self.get_nodes_and_xpaths()

    tokenizer = MarkupLMTokenizer.from_pretrained("microsoft/markuplm-base", model_max_length=512)

    for i in range(12, 512):
        new_encoded_inputs = tokenizer.encode(nodes, xpaths=xpaths, max_length=i, truncation=True)

        # Ensure that the input IDs are less than the max length defined.
        self.assertLessEqual(len(new_encoded_inputs), i)

    tokenizer.model_max_length = 20
    new_encoded_inputs = tokenizer.encode(nodes, xpaths=xpaths, truncation=True)
    # Reference: the same truncation requested through an explicit max_length. Comparing two
    # identical calls would only check determinism, not the model_max_length fallback.
    explicit_encoded_inputs = tokenizer.encode(nodes, xpaths=xpaths, max_length=20, truncation=True)

    # Ensure that the input IDs are still truncated when no max_length is specified
    self.assertListEqual(new_encoded_inputs, explicit_encoded_inputs)
    self.assertLessEqual(len(new_encoded_inputs), 20)
|
||
|
||
def test_sequence_ids(self):
    """Check that ``sequence_ids()`` tags a single sequence with 0 and the two halves of a pair with 0 and 1.

    Only fast tokenizers are exercised.
    """
    tokenizers = self.get_tokenizers()
    for tokenizer in tokenizers:
        if not tokenizer.is_fast:
            continue
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            seq_0 = "Test this method."
            seq_1 = ["With", "these", "inputs."]
            # One xpath per node of the sequence being encoded. Sizing the single-sequence xpaths
            # from seq_1 only worked because both sequences happen to have three entries.
            nodes_0 = seq_0.split()
            xpaths_0 = ["html/body"] * len(nodes_0)
            xpaths_1 = ["html/body"] * len(seq_1)

            # We want to have sequence 0 and sequence 1 are tagged
            # respectively with 0 and 1 token_ids
            # (regardless of whether the model use token type ids)
            # We use this assumption in the QA pipeline among other place
            output = tokenizer(nodes_0, xpaths=xpaths_0)
            self.assertIn(0, output.sequence_ids())

            output = tokenizer(seq_0, seq_1, xpaths=xpaths_1)
            self.assertIn(0, output.sequence_ids())
            self.assertIn(1, output.sequence_ids())

            if tokenizer.num_special_tokens_to_add(pair=True):
                self.assertIn(None, output.sequence_ids())
|
||
|
||
def test_special_tokens_initialization(self):
    """Check that a token passed via ``additional_special_tokens`` shows up when encoding text containing it."""
    for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
        with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
            added_tokens = [AddedToken("<special>", lstrip=True)]

            tokenizer_r = self.rust_tokenizer_class.from_pretrained(
                pretrained_name, additional_special_tokens=added_tokens, **kwargs
            )
            nodes = "Hey this is a <special> token".split()
            xpaths = ["html/body"] * len(nodes)
            r_output = tokenizer_r.encode(nodes, xpaths=xpaths)

            special_token_id = tokenizer_r.encode(["<special>"], xpaths=["html/body"], add_special_tokens=False)[0]

            # assertIn reports both operands on failure, unlike assertTrue(x in y).
            self.assertIn(special_token_id, r_output)
|
||
|
||
def test_training_new_tokenizer(self):
    """Retrain the fast tokenizer on a small corpus and check it stays usable and consistent.

    The retrained tokenizer must encode unseen input, and must keep the original's special-token
    counts, maximum lengths and special tokens. Skipped unless ``test_rust_tokenizer`` is set.
    """
    # Retraining is only available for fast tokenizers.
    if not self.test_rust_tokenizer:
        self.skipTest(reason="test_rust_tokenizer is set to False")

    tokenizer = self.get_tokenizer()
    new_tokenizer = tokenizer.train_new_from_iterator(SMALL_TRAINING_CORPUS, 100)

    # The retrained tokenizer must handle input it did not see during training.
    inputs = new_tokenizer(
        [["this", "is", "the"], ["how", "are", "you"]],
        xpaths=[["html/body"] * 3, ["html/body"] * 3],
    )
    self.assertEqual(len(inputs["input_ids"]), 2)
    decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)

    # The generic expectation "this is the" does not hold for this RoBERTa-based tokenizer:
    # the decoded nodes come back concatenated.
    expected_result = "thisisthe"
    normalizer = tokenizer.backend_tokenizer.normalizer
    if normalizer is not None:
        expected_result = normalizer.normalize_str(expected_result)
    self.assertEqual(expected_result, decoded_input)

    # Retraining must not change how many special tokens are added for single and pair inputs...
    for is_pair in (False, True):
        self.assertEqual(
            tokenizer.num_special_tokens_to_add(is_pair), new_tokenizer.num_special_tokens_to_add(is_pair)
        )

    # ...nor the maximum lengths for single and pair inputs.
    for attribute in ("max_len_single_sentence", "max_len_sentences_pair"):
        self.assertEqual(getattr(tokenizer, attribute), getattr(new_tokenizer, attribute))

    # No change to the special tokens was requested, so they must match exactly.
    self.assertSequenceEqual(tokenizer.all_special_tokens, new_tokenizer.all_special_tokens)
    self.assertDictEqual(tokenizer.special_tokens_map, new_tokenizer.special_tokens_map)
|
||
|
||
def test_training_new_tokenizer_with_special_tokens_change(self):
    """Retrain the fast tokenizer while renaming its special tokens and check the renaming.

    Every special token ``tok`` set on the original tokenizer is mapped to ``f"{tok}a"`` through
    ``special_tokens_map``. The retrained tokenizer must:

    - expose the renamed tokens with matching ids;
    - keep unchanged tokens as they were;
    - still encode and decode unseen input.

    Skipped unless ``test_rust_tokenizer`` is set.
    """
    # This feature only exists for fast tokenizers
    if not self.test_rust_tokenizer:
        self.skipTest(reason="test_rust_tokenizer is set to False")

    tokenizer = self.get_tokenizer()
    # Test with a special tokens map
    class_signature = inspect.signature(tokenizer.__class__)
    if "cls_token" in class_signature.parameters:
        new_tokenizer = tokenizer.train_new_from_iterator(
            SMALL_TRAINING_CORPUS, 100, special_tokens_map={tokenizer.cls_token: "<cls>"}
        )
        cls_id = new_tokenizer.get_vocab()["<cls>"]
        self.assertEqual(new_tokenizer.cls_token, "<cls>")
        self.assertEqual(new_tokenizer.cls_token_id, cls_id)

    # Create a new mapping from the special tokens defined in the original tokenizer
    special_tokens_list = PreTrainedTokenizerBase.SPECIAL_TOKENS_ATTRIBUTES.copy()
    special_tokens_map = {}
    for token in special_tokens_list:
        # Only special-token attributes that are set on the original tokenizer get renamed.
        if getattr(tokenizer, token) is not None:
            special_token = getattr(tokenizer, token)
            special_tokens_map[special_token] = f"{special_token}a"

    # Train new tokenizer
    new_tokenizer = tokenizer.train_new_from_iterator(
        SMALL_TRAINING_CORPUS, 100, special_tokens_map=special_tokens_map
    )

    # Check the changes
    for token in special_tokens_list:
        # Attributes unset on the original tokenizer were not part of the renaming.
        if getattr(tokenizer, token) is None:
            continue
        special_token = getattr(tokenizer, token)
        if special_token in special_tokens_map:
            new_special_token = getattr(new_tokenizer, token)
            self.assertEqual(special_tokens_map[special_token], new_special_token)

            new_id = new_tokenizer.get_vocab()[new_special_token]
            self.assertEqual(getattr(new_tokenizer, f"{token}_id"), new_id)

    # Check if the AddedToken / string format has been kept
    for special_token in tokenizer.all_special_tokens:
        if isinstance(special_token, AddedToken) and special_token.content not in special_tokens_map:
            # The special token must appear identically in the list of the new tokenizer.
            self.assertIn(
                special_token,
                new_tokenizer.all_special_tokens,
                f"'{special_token}' should be in {new_tokenizer.all_special_tokens}",
            )
        elif isinstance(special_token, AddedToken):
            # The special token must appear in the list of the new tokenizer as an object of type AddedToken with
            # the same parameters as the old AddedToken except the content that the user has requested to change.
            special_token_str = special_token.content
            new_special_token_str = special_tokens_map[special_token_str]

            find = False
            for candidate in new_tokenizer.all_special_tokens:
                if (
                    isinstance(candidate, AddedToken)
                    and candidate.content == new_special_token_str
                    and candidate.lstrip == special_token.lstrip
                    and candidate.rstrip == special_token.rstrip
                    and candidate.normalized == special_token.normalized
                    and candidate.single_word == special_token.single_word
                ):
                    find = True
                    break
            self.assertTrue(
                find,
                f"'{new_special_token_str}' doesn't appear in the list "
                f"'{new_tokenizer.all_special_tokens}' as an AddedToken with the same parameters as "
                f"'{special_token}' in the list {tokenizer.all_special_tokens}",
            )
        elif special_token not in special_tokens_map:
            # The special token must appear identically in the list of the new tokenizer.
            # The failure message lists the same collection that is actually checked.
            self.assertIn(
                special_token,
                new_tokenizer.all_special_tokens,
                f"'{special_token}' should be in {new_tokenizer.all_special_tokens}",
            )
        else:
            # The special token must appear in the list of the new tokenizer as an object of type string.
            self.assertIn(special_tokens_map[special_token], new_tokenizer.all_special_tokens)

    # Test we can use the new tokenizer with something not seen during training
    nodes = [["this", "is"], ["hello", "🤗"]]
    xpaths = [["html/body"] * 2, ["html/body"] * 2]
    inputs = new_tokenizer(nodes, xpaths=xpaths)
    self.assertEqual(len(inputs["input_ids"]), 2)
    decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
    # The decoded nodes come back concatenated without separating spaces.
    expected_result = "thisis"

    if tokenizer.backend_tokenizer.normalizer is not None:
        expected_result = tokenizer.backend_tokenizer.normalizer.normalize_str(expected_result)
    self.assertEqual(expected_result, decoded_input)
|
||
|
||
def test_batch_encode_dynamic_overflowing(self):
    """
    Check that overflowing encodings come back as rectangular PyTorch tensors.

    When calling batch_encode with multiple sequences, each sequence can produce a different
    number of overflowing encodings:
    [
        Sequence 1: [Encoding 1, Encoding 2],
        Sequence 2: [Encoding 1],
        Sequence 3: [Encoding 1, Encoding 2, ... Encoding N]
    ]
    These must be padded so that they can be represented as a tensor. Non-xpath fields are
    expected to be rank 2 and xpath fields rank 3.
    """
    for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
        tokenizer = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)

        with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name}, {tokenizer.__class__.__name__})"):
            returned_tensor = "pt"

            def check_shapes(tokens, sequence_length=None):
                # "overflow_to_sample_mapping" is excluded from the shape checks.
                for key in tokens.keys():
                    if "overflow_to_sample_mapping" in key:
                        continue
                    shape = tokens[key].shape
                    is_xpath_field = "xpath" in key
                    self.assertEqual(len(shape), 3 if is_xpath_field else 2)
                    if sequence_length is not None:
                        # Xpath fields carry an extra trailing dimension after the sequence axis.
                        self.assertEqual(shape[-2] if is_xpath_field else shape[-1], sequence_length)

            # Single example
            nodes, xpaths = self.get_nodes_and_xpaths()
            single = tokenizer.encode_plus(
                nodes,
                xpaths=xpaths,
                max_length=1,
                padding=True,
                truncation=True,
                return_tensors=returned_tensor,
                return_overflowing_tokens=True,
            )
            check_shapes(single)

            # Batch of examples
            # For these 2 examples, 3 training examples will be created
            nodes, xpaths = self.get_nodes_and_xpaths_batch()
            batch = tokenizer.batch_encode_plus(
                nodes,
                xpaths=xpaths,
                max_length=6,
                padding=True,
                truncation="only_first",
                return_tensors=returned_tensor,
                return_overflowing_tokens=True,
            )
            check_shapes(batch, sequence_length=6)
|
||
|
||
@unittest.skip(reason="TO DO: overwrite this very extensive test.")
def test_alignment_methods(self):
    """Disabled until a MarkupLM-specific version is written; see the skip reason above."""
|
||
|
||
def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20, min_length=5):
    """Build a node sequence from vocabulary entries that round-trip through ``tokenizer``.

    A candidate id is kept when its decoded text contains only ASCII letters and spaces, and
    re-encoding that text (split on spaces into nodes) gives back exactly that id. The kept ids
    are capped at ``max_length``. If any remain, the list is doubled until it holds at least
    ``min_length`` ids.

    Args:
        tokenizer: object providing ``len()``, ``decode(ids, clean_up_tokenization_spaces=...)``
            and ``encode(nodes, xpaths=..., add_special_tokens=...)``.
        with_prefix_space: prepend a space to the decoded text before splitting it into nodes.
        max_length: maximum number of candidate ids to keep (``None`` disables the cap).
        min_length: minimum number of ids to reach by repetition (``None`` disables it).

    Returns:
        ``(nodes, xpaths, output_ids)``:
        - the decoded text split on single spaces;
        - one ``"html/body"`` xpath per node;
        - the ids from re-encoding the nodes without special tokens.
    """
    toks = [(i, tokenizer.decode([i], clean_up_tokenization_spaces=False)) for i in range(len(tokenizer))]
    toks = [t for t in toks if re.match(r"^[ a-zA-Z]+$", t[1])]

    def _round_trips(tok):
        # Size the xpaths by the number of nodes the text is split into, not by its character count.
        token_id, text = tok
        text_nodes = text.split(" ")
        encoded = tokenizer.encode(text_nodes, xpaths=["html/body"] * len(text_nodes), add_special_tokens=False)
        return [token_id] == encoded

    toks = [t for t in toks if _round_trips(t)]
    if max_length is not None and len(toks) > max_length:
        toks = toks[:max_length]
    if min_length is not None and len(toks) < min_length and len(toks) > 0:
        while len(toks) < min_length:
            toks = toks + toks
    toks_ids = [t[0] for t in toks]

    # Ensure consistency
    output_txt = tokenizer.decode(toks_ids, clean_up_tokenization_spaces=False)
    # No extra space is inserted between the first token and the rest: an extra blank would make
    # the node list and the text disagree (["a", "b"] vs "a b").
    if with_prefix_space:
        output_txt = " " + output_txt
    nodes = output_txt.split(" ")
    xpaths = ["html/body"] * len(nodes)
    output_ids = tokenizer.encode(nodes, xpaths=xpaths, add_special_tokens=False)
    return nodes, xpaths, output_ids
|
||
|
||
@unittest.skip(reason="This test is failing for fast")
def test_maximum_encoding_length_pair_input(self):
    """Check truncation, padding and overflow handling for (question, nodes) pair inputs.

    For every tokenizer returned by ``get_tokenizers`` this checks that:

    * a pair longer than ``model_max_length`` is truncated to exactly that length with every
      padding strategy. This holds for ``input_ids``, ``xpath_tags_seq`` and
      ``xpath_subs_seq``, and for single and batched inputs alike.
    * without truncation the full length is kept and exactly one
      "Token indices sequence length is longer ..." warning is logged.
    * with ``stride`` and ``return_overflowing_tokens=True``, the truncated ids, overflowing ids
      and xpath tag sequences match the expected slices. This is checked for the
      ``longest_first``/``True``, ``only_first`` and ``only_second`` strategies.
      - Fast tokenizers return the overflow as a second row.
      - Slow tokenizers return it under ``overflowing_*`` keys, and raise ``ValueError``
        for ``longest_first``.

    Slow part fixed, fast part not (see the failing assertion marked below), hence the skip.
    """
    fields = ("input_ids", "xpath_tags_seq", "xpath_subs_seq")
    length_warning = (
        "Token indices sequence length is longer than the specified maximum sequence length"
        " for this model"
    )
    longest_first_error = (
        "Not possible to return overflowing tokens for pair of sequences with the "
        "`longest_first`. Please select another truncation strategy than `longest_first`, "
        "for instance `only_second` or `only_first`."
    )

    tokenizers = self.get_tokenizers(do_lower_case=False, model_max_length=100)
    for tokenizer in tokenizers:
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            # Build a sequence from our model's vocabulary
            stride = 2
            seq_0, xpaths_0, ids = self.get_clean_sequence(tokenizer, max_length=20)
            if len(ids) <= 2 + stride:
                # seq_0 is a list of nodes, so repeat the nodes together with their xpaths.
                # The previous `(seq_0 + " ") * (2 + stride)` raised TypeError (list + str).
                seq_0 = seq_0 * (2 + stride)
                xpaths_0 = xpaths_0 * (2 + stride)
            question_0 = " ".join(seq_0)

            seq0_tokens = tokenizer(seq_0, xpaths=xpaths_0, add_special_tokens=False)
            self.assertGreater(len(seq0_tokens["input_ids"]), 2 + stride)

            question_1 = "This is another sentence to be encoded."
            seq_1 = ["hello", "world"]
            xpaths_1 = ["html/body"] * len(seq_1)
            seq1_tokens = tokenizer(seq_1, xpaths=xpaths_1, add_special_tokens=False)
            if abs(len(seq0_tokens["input_ids"]) - len(seq1_tokens["input_ids"])) <= 2:
                # Lengths too close: double the second sequence so one is clearly longer.
                seq1_tokens_input_ids = seq1_tokens["input_ids"] + seq1_tokens["input_ids"]
                seq_1 = tokenizer.decode(seq1_tokens_input_ids, clean_up_tokenization_spaces=False).split(" ")
                xpaths_1 = ["html/body"] * len(seq_1)
                seq1_tokens = tokenizer(seq_1, xpaths=xpaths_1, add_special_tokens=False)
            self.assertGreater(len(seq1_tokens["input_ids"]), 2 + stride)

            seq0_ids = seq0_tokens["input_ids"]
            seq1_ids = seq1_tokens["input_ids"]
            seq1_xpath_tags = seq1_tokens["xpath_tags_seq"]
            # Compare token counts. The encodings are dict-like (indexed by field name), so
            # len(seq0_tokens) counts their fields. That count is the same for both, which made
            # every "longest" choice below silently pick the second sequence.
            first_is_longest = len(seq0_ids) > len(seq1_ids)
            smallest = seq1_ids if first_is_longest else seq0_ids

            # We are not using the special tokens - a bit too hard to test all the tokenizers with this
            # TODO try this again later
            sequence = tokenizer(question_0, seq_1, xpaths=xpaths_1, add_special_tokens=False)
            truncated_length = len(sequence["input_ids"]) - 2

            # Test with max model input length
            model_max_length = tokenizer.model_max_length
            self.assertEqual(model_max_length, 100)
            seq_2 = seq_0 * model_max_length
            question_2 = " ".join(seq_2)
            xpaths_2 = xpaths_0 * model_max_length
            # GreaterEqual: seq_2 has exactly model_max_length nodes when seq_0 has a single node.
            self.assertGreaterEqual(len(seq_2), model_max_length)

            total_length1 = len(seq1_ids)
            sequence2 = tokenizer(question_2, seq_1, xpaths=xpaths_1, add_special_tokens=False)
            total_length2 = len(sequence2["input_ids"])
            self.assertLess(total_length1, model_max_length, "Issue with the testing sequence, please update it.")
            self.assertGreater(total_length2, model_max_length, "Issue with the testing sequence, please update it.")

            # Simple
            padding_strategies = (
                [False, True, "longest"] if tokenizer.pad_token and tokenizer.pad_token_id >= 0 else [False]
            )
            for padding_state in padding_strategies:
                with self.subTest(f"{tokenizer.__class__.__name__} Padding: {padding_state}"):
                    for truncation_state in [True, "longest_first", "only_first"]:
                        with self.subTest(f"{tokenizer.__class__.__name__} Truncation: {truncation_state}"):
                            output = tokenizer(
                                question_2,
                                seq_1,
                                xpaths=xpaths_1,
                                padding=padding_state,
                                truncation=truncation_state,
                            )
                            for field in fields:
                                self.assertEqual(len(output[field]), model_max_length)

                            output = tokenizer(
                                [question_2],
                                [seq_1],
                                xpaths=[xpaths_1],
                                padding=padding_state,
                                truncation=truncation_state,
                            )
                            for field in fields:
                                self.assertEqual(len(output[field][0]), model_max_length)

                    # Simple
                    output = tokenizer(question_1, seq_2, xpaths=xpaths_2, padding=padding_state, truncation="only_second")
                    for field in fields:
                        self.assertEqual(len(output[field]), model_max_length)

                    output = tokenizer(
                        [question_1], [seq_2], xpaths=[xpaths_2], padding=padding_state, truncation="only_second"
                    )
                    for field in fields:
                        self.assertEqual(len(output[field][0]), model_max_length)

                    # Simple with no truncation
                    # Reset warnings
                    tokenizer.deprecation_warnings = {}
                    with self.assertLogs("transformers", level="WARNING") as cm:
                        output = tokenizer(question_1, seq_2, xpaths=xpaths_2, padding=padding_state, truncation=False)
                        for field in fields:
                            self.assertNotEqual(len(output[field]), model_max_length)
                    self.assertEqual(len(cm.records), 1)
                    self.assertTrue(cm.records[0].message.startswith(length_warning))

                    tokenizer.deprecation_warnings = {}
                    with self.assertLogs("transformers", level="WARNING") as cm:
                        output = tokenizer(
                            [question_1], [seq_2], xpaths=[xpaths_2], padding=padding_state, truncation=False
                        )
                        for field in fields:
                            self.assertNotEqual(len(output[field][0]), model_max_length)
                    self.assertEqual(len(cm.records), 1)
                    self.assertTrue(cm.records[0].message.startswith(length_warning))

            # Check the order of Sequence of input ids, overflowing tokens and xpath_tags_seq sequence with truncation
            truncated_first_sequence = seq0_ids[:-2] + seq1_ids
            truncated_second_sequence = seq0_ids + seq1_ids[:-2]
            truncated_longest_sequence = truncated_first_sequence if first_is_longest else truncated_second_sequence

            overflow_first_sequence = seq0_ids[-(2 + stride) :] + seq1_ids
            overflow_second_sequence = seq0_ids + seq1_ids[-(2 + stride) :]
            overflow_longest_sequence = overflow_first_sequence if first_is_longest else overflow_second_sequence

            # NOTE(review): every position of the first (question) sequence is expected to carry the
            # xpath tag sequence [5] * 50. This is presumably the test tokenizer's pad tag id repeated
            # to the maximum xpath depth; confirm against the tokenizer fixture.
            question_xpath_tags = [5] * 50
            xpath_tags_seq_first_sequence = [question_xpath_tags] * (len(seq0_ids) - 2) + seq1_xpath_tags
            overflowing_token_xpath_tags_seq_first_sequence_slow = [question_xpath_tags] * (2 + stride)
            overflowing_token_xpath_tags_seq_first_sequence_fast = [question_xpath_tags] * (2 + stride) + seq1_xpath_tags

            xpath_tags_seq_second_sequence = [question_xpath_tags] * len(seq0_ids) + seq1_xpath_tags[:-2]
            overflowing_token_xpath_tags_seq_second_sequence_slow = seq1_xpath_tags[-(2 + stride) :]
            overflowing_token_xpath_tags_seq_second_sequence_fast = (
                [question_xpath_tags] * len(seq0_ids) + seq1_xpath_tags[-(2 + stride) :]
            )

            xpath_tags_seq_longest_sequence = (
                xpath_tags_seq_first_sequence if first_is_longest else xpath_tags_seq_second_sequence
            )
            overflowing_token_xpath_tags_seq_longest_sequence_fast = (
                overflowing_token_xpath_tags_seq_first_sequence_fast
                if first_is_longest
                else overflowing_token_xpath_tags_seq_second_sequence_fast
            )

            # Arguments shared by every overflow call below.
            overflow_kwargs = {
                "xpaths": xpaths_1,
                "max_length": truncated_length,
                "add_special_tokens": False,
                "stride": stride,
                "return_overflowing_tokens": True,
            }
            is_fast = isinstance(tokenizer, MarkupLMTokenizerFast)

            # Overflowing tokens are handled quite differently in slow and fast tokenizers.
            # `True` behaves like `longest_first` here, so both are checked the same way.
            for truncation_state in ("longest_first", True):
                if is_fast:
                    information = tokenizer(question_0, seq_1, truncation=truncation_state, **overflow_kwargs)
                    truncated_sequence = information["input_ids"][0]
                    overflowing_tokens = information["input_ids"][1]
                    xpath_tags_seq = information["xpath_tags_seq"][0]
                    overflowing_xpath_tags_seq = information["xpath_tags_seq"][1]
                    self.assertEqual(len(information["input_ids"]), 2)

                    self.assertEqual(len(truncated_sequence), truncated_length)
                    self.assertEqual(truncated_sequence, truncated_longest_sequence)

                    self.assertEqual(len(overflowing_tokens), 2 + stride + len(smallest))
                    self.assertEqual(overflowing_tokens, overflow_longest_sequence)
                    self.assertEqual(xpath_tags_seq, xpath_tags_seq_longest_sequence)

                    self.assertEqual(len(overflowing_xpath_tags_seq), 2 + stride + len(smallest))
                    self.assertEqual(
                        overflowing_xpath_tags_seq, overflowing_token_xpath_tags_seq_longest_sequence_fast
                    )
                else:
                    # No overflowing tokens when using 'longest' in python tokenizers
                    with self.assertRaises(ValueError) as context:
                        tokenizer(question_0, seq_1, truncation=truncation_state, **overflow_kwargs)
                    self.assertTrue(context.exception.args[0].startswith(longest_first_error))

            information_first_truncated = tokenizer(question_0, seq_1, truncation="only_first", **overflow_kwargs)
            if is_fast:
                truncated_sequence = information_first_truncated["input_ids"][0]
                overflowing_tokens = information_first_truncated["input_ids"][1]
                xpath_tags_seq = information_first_truncated["xpath_tags_seq"][0]
                overflowing_xpath_tags_seq = information_first_truncated["xpath_tags_seq"][1]
                self.assertEqual(len(information_first_truncated["input_ids"]), 2)

                self.assertEqual(len(truncated_sequence), truncated_length)
                self.assertEqual(truncated_sequence, truncated_first_sequence)

                self.assertEqual(len(overflowing_tokens), 2 + stride + len(seq1_ids))
                self.assertEqual(overflowing_tokens, overflow_first_sequence)
                self.assertEqual(xpath_tags_seq, xpath_tags_seq_first_sequence)
                # ISSUE HAPPENS HERE: this assertion fails for the fast tokenizer (see the skip reason).
                self.assertEqual(overflowing_xpath_tags_seq, overflowing_token_xpath_tags_seq_first_sequence_fast)
            else:
                truncated_sequence = information_first_truncated["input_ids"]
                overflowing_tokens = information_first_truncated["overflowing_tokens"]
                overflowing_xpath_tags_seq = information_first_truncated["overflowing_xpath_tags_seq"]
                xpath_tags_seq = information_first_truncated["xpath_tags_seq"]

                self.assertEqual(len(truncated_sequence), truncated_length)
                self.assertEqual(truncated_sequence, truncated_first_sequence)

                self.assertEqual(len(overflowing_tokens), 2 + stride)
                self.assertEqual(overflowing_tokens, seq0_ids[-(2 + stride) :])
                self.assertEqual(xpath_tags_seq, xpath_tags_seq_first_sequence)
                self.assertEqual(overflowing_xpath_tags_seq, overflowing_token_xpath_tags_seq_first_sequence_slow)

            information_second_truncated = tokenizer(question_0, seq_1, truncation="only_second", **overflow_kwargs)
            if is_fast:
                truncated_sequence = information_second_truncated["input_ids"][0]
                overflowing_tokens = information_second_truncated["input_ids"][1]
                xpath_tags_seq = information_second_truncated["xpath_tags_seq"][0]
                overflowing_xpath_tags_seq = information_second_truncated["xpath_tags_seq"][1]

                self.assertEqual(len(information_second_truncated["input_ids"]), 2)

                self.assertEqual(len(truncated_sequence), truncated_length)
                self.assertEqual(truncated_sequence, truncated_second_sequence)

                self.assertEqual(len(overflowing_tokens), 2 + stride + len(seq0_ids))
                self.assertEqual(overflowing_tokens, overflow_second_sequence)
                self.assertEqual(xpath_tags_seq, xpath_tags_seq_second_sequence)
                self.assertEqual(overflowing_xpath_tags_seq, overflowing_token_xpath_tags_seq_second_sequence_fast)
            else:
                truncated_sequence = information_second_truncated["input_ids"]
                overflowing_tokens = information_second_truncated["overflowing_tokens"]
                xpath_tags_seq = information_second_truncated["xpath_tags_seq"]
                overflowing_xpath_tags_seq = information_second_truncated["overflowing_xpath_tags_seq"]

                self.assertEqual(len(truncated_sequence), truncated_length)
                self.assertEqual(truncated_sequence, truncated_second_sequence)

                self.assertEqual(len(overflowing_tokens), 2 + stride)
                self.assertEqual(overflowing_tokens, seq1_ids[-(2 + stride) :])
                self.assertEqual(xpath_tags_seq, xpath_tags_seq_second_sequence)
                self.assertEqual(overflowing_xpath_tags_seq, overflowing_token_xpath_tags_seq_second_sequence_slow)
|
||
|
||
def test_maximum_encoding_length_single_input(self):
    """Check truncation, padding, the length warning and overflow for a single node sequence.

    A clean node sequence is repeated ``model_max_length`` times so that it exceeds the model
    limit. The test then checks that:

    * truncation yields exactly ``model_max_length`` entries for ``input_ids``,
      ``xpath_tags_seq`` and ``xpath_subs_seq``, with every available padding strategy;
    * disabling truncation keeps the full length and logs exactly one warning;
    * ``stride`` + ``return_overflowing_tokens`` split ``input_ids`` and ``xpath_tags_seq`` into
      the expected truncated and overflowing slices. Fast tokenizers return the overflow as a
      second row; slow tokenizers return it under ``overflowing_*`` keys.
    """
    fields = ("input_ids", "xpath_tags_seq", "xpath_subs_seq")
    too_short_msg = "Issue with the testing sequence, please update it it's too short"
    warning_prefix = (
        "Token indices sequence length is longer than the specified maximum sequence length"
        " for this model"
    )

    for tokenizer in self.get_tokenizers(do_lower_case=False, model_max_length=100):
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            nodes, node_xpaths, _ = self.get_clean_sequence(tokenizer, max_length=20)

            reference = tokenizer(nodes, xpaths=node_xpaths, add_special_tokens=False)
            reference_length = len(reference["input_ids"])
            self.assertGreater(reference_length, 4, too_short_msg)

            # Test with max model input length: repeat the nodes so that they overflow it.
            max_len = tokenizer.model_max_length
            self.assertEqual(max_len, 100)
            long_nodes = nodes * max_len
            long_xpaths = node_xpaths * max_len
            long_encoding = tokenizer(long_nodes, xpaths=long_xpaths, add_special_tokens=False)
            self.assertGreater(len(long_encoding["input_ids"]), max_len, too_short_msg)

            has_pad = tokenizer.pad_token and tokenizer.pad_token_id >= 0
            for padding_state in [False, True, "longest"] if has_pad else [False]:
                with self.subTest(f"Padding: {padding_state}"):
                    for truncation_state in [True, "longest_first", "only_first"]:
                        with self.subTest(f"Truncation: {truncation_state}"):
                            encoded = tokenizer(
                                long_nodes, xpaths=long_xpaths, padding=padding_state, truncation=truncation_state
                            )
                            for field in fields:
                                self.assertEqual(len(encoded[field]), max_len)

                            encoded = tokenizer(
                                [long_nodes], xpaths=[long_xpaths], padding=padding_state, truncation=truncation_state
                            )
                            for field in fields:
                                self.assertEqual(len(encoded[field][0]), max_len)

                    # No truncation: the full length is kept and one warning is logged.
                    # deprecation_warnings is reset first (presumably so an already-emitted
                    # warning is logged again).
                    tokenizer.deprecation_warnings = {}
                    with self.assertLogs("transformers", level="WARNING") as cm:
                        encoded = tokenizer(long_nodes, xpaths=long_xpaths, padding=padding_state, truncation=False)
                        for field in fields:
                            self.assertNotEqual(len(encoded[field]), max_len)
                    self.assertEqual(len(cm.records), 1)
                    self.assertTrue(cm.records[0].message.startswith(warning_prefix))

                    tokenizer.deprecation_warnings = {}
                    with self.assertLogs("transformers", level="WARNING") as cm:
                        encoded = tokenizer(
                            [long_nodes], xpaths=[long_xpaths], padding=padding_state, truncation=False
                        )
                        for field in fields:
                            self.assertNotEqual(len(encoded[field][0]), max_len)
                    self.assertEqual(len(cm.records), 1)
                    self.assertTrue(cm.records[0].message.startswith(warning_prefix))

            # Check the order of input ids, overflowing tokens and xpath_tags_seq with truncation.
            stride = 2
            overflow_span = 2 + stride
            information = tokenizer(
                nodes,
                xpaths=node_xpaths,
                max_length=reference_length - 2,
                add_special_tokens=False,
                stride=stride,
                truncation=True,
                return_overflowing_tokens=True,
            )

            # Overflowing tokens are handled quite differently in slow and fast tokenizers
            if isinstance(tokenizer, MarkupLMTokenizerFast):
                truncated_sequence, overflowing_tokens = information["input_ids"][0], information["input_ids"][1]
                xpath_tags_seq, overflowing_xpath_tags_seq = (
                    information["xpath_tags_seq"][0],
                    information["xpath_tags_seq"][1],
                )
                self.assertEqual(len(information["input_ids"]), 2)
            else:
                truncated_sequence = information["input_ids"]
                overflowing_tokens = information["overflowing_tokens"]
                xpath_tags_seq = information["xpath_tags_seq"]
                overflowing_xpath_tags_seq = information["overflowing_xpath_tags_seq"]

            self.assertEqual(len(truncated_sequence), reference_length - 2)
            self.assertEqual(truncated_sequence, reference["input_ids"][:-2])

            self.assertEqual(len(overflowing_tokens), overflow_span)
            self.assertEqual(overflowing_tokens, reference["input_ids"][-overflow_span:])

            self.assertEqual(xpath_tags_seq, reference["xpath_tags_seq"][:-2])
            self.assertEqual(overflowing_xpath_tags_seq, reference["xpath_tags_seq"][-overflow_span:])
|
||
|
||
@unittest.skip(reason="MarkupLM tokenizer requires xpaths besides sequences.")
|
||
def test_pretokenized_inputs(self):
|
||
pass
|
||
|
||
@unittest.skip(reason="MarkupLM tokenizer always expects pretokenized inputs.")
|
||
def test_compare_pretokenized_inputs(self):
|
||
pass
|
||
|
||
@unittest.skip(reason="MarkupLM fast tokenizer does not support prepare_for_model")
|
||
def test_compare_prepare_for_model(self):
|
||
pass
|
||
|
||
@slow
def test_only_label_first_subword(self):
    """Check node-label propagation with and without ``only_label_first_subword``.

    The slow tokenizer is tested first, then the fast one; both are loaded from
    ``microsoft/markuplm-base``. By default only one position per labelled node keeps its label
    and the rest are -100. With ``only_label_first_subword=False``, index 3 also receives label 1
    (presumably the second sub-token of "niels"). The outer positions stay -100 (presumably the
    special tokens).
    """
    nodes = ["hello", "niels"]
    xpaths = ["/html/body/div/li[1]/div/span"] * len(nodes)
    node_labels = [0, 1]

    label_cases = (
        ({}, [-100, 0, 1, -100, -100]),
        ({"only_label_first_subword": False}, [-100, 0, 1, 1, -100]),
    )
    # slow tokenizer first, then fast tokenizer
    for tokenizer_class in (MarkupLMTokenizer, MarkupLMTokenizerFast):
        for extra_kwargs, expected_labels in label_cases:
            tokenizer = tokenizer_class.from_pretrained("microsoft/markuplm-base", **extra_kwargs)
            encoding = tokenizer(nodes, xpaths=xpaths, node_labels=node_labels)
            self.assertListEqual(encoding.labels, expected_labels)
|
||
|
||
def test_markuplm_integration_test(self):
|
||
tokenizer_p = MarkupLMTokenizer.from_pretrained("microsoft/markuplm-base")
|
||
tokenizer_r = MarkupLMTokenizerFast.from_pretrained("microsoft/markuplm-base")
|
||
|
||
# There are 3 cases:
|
||
# CASE 1: document image classification (training + inference), document image token classification (inference),
|
||
# in which case only nodes and normalized bounding xpaths are provided to the tokenizer
|
||
# CASE 2: document image token classification (training),
|
||
# in which case one also provides word labels to the tokenizer
|
||
# CASE 3: document image visual question answering (inference),
|
||
# in which case one also provides a question to the tokenizer
|
||
|
||
# We need to test all 3 cases both on batched and non-batched inputs.
|
||
|
||
# CASE 1: not batched
|
||
nodes, xpaths = self.get_nodes_and_xpaths()
|
||
|
||
expected_results = {'input_ids': [0, 42891, 8331, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'xpath_tags_seq': [[216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 
216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 
216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216]], 'xpath_subs_seq': [[1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 1, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 1, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001]], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} # fmt: skip
|
||
|
||
encoding_p = tokenizer_p(nodes, xpaths=xpaths, padding="max_length", max_length=20)
|
||
encoding_r = tokenizer_r(nodes, xpaths=xpaths, padding="max_length", max_length=20)
|
||
self.assertDictEqual(dict(encoding_p), expected_results)
|
||
self.assertDictEqual(dict(encoding_r), expected_results)
|
||
|
||
# CASE 1: batched
|
||
nodes, xpaths = self.get_nodes_and_xpaths_batch()
|
||
|
||
expected_results = {'input_ids': [[0, 42891, 232, 12364, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 42891, 127, 766, 16, 22401, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'xpath_tags_seq': [[[216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 
216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 
216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216]], [[216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 
25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 
216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 
216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216]]], 'xpath_subs_seq': [[[1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 1, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 1, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 1, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001]], [[1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 2, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 2, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 2, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 2, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 2, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001]]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]} # fmt: skip
|
||
|
||
encoding_p = tokenizer_p(nodes, xpaths=xpaths, padding="max_length", max_length=20)
|
||
encoding_r = tokenizer_r(nodes, xpaths=xpaths, padding="max_length", max_length=20)
|
||
self.assertDictEqual(dict(encoding_p), expected_results)
|
||
self.assertDictEqual(dict(encoding_r), expected_results)
|
||
|
||
# CASE 2: not batched
|
||
nodes, xpaths = self.get_nodes_and_xpaths()
|
||
node_labels = [1, 2, 3]
|
||
|
||
expected_results = {'input_ids': [0, 42891, 8331, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'xpath_tags_seq': [[216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 
216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 
216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216]], 'xpath_subs_seq': [[1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 1, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 1, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001]], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'labels': [-100, 1, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], 'attention_mask': [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} # fmt: skip
|
||
|
||
encoding_p = tokenizer_p(nodes, xpaths=xpaths, node_labels=node_labels, padding="max_length", max_length=20)
|
||
encoding_r = tokenizer_r(nodes, xpaths=xpaths, node_labels=node_labels, padding="max_length", max_length=20)
|
||
self.assertDictEqual(dict(encoding_p), expected_results)
|
||
self.assertDictEqual(dict(encoding_r), expected_results)
|
||
|
||
# CASE 2: batched
|
||
nodes, xpaths = self.get_nodes_and_xpaths_batch()
|
||
node_labels = [[1, 2, 3], [2, 46, 17, 22, 3]]
|
||
|
||
expected_results = {'input_ids': [[0, 42891, 232, 12364, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 42891, 127, 766, 16, 22401, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'xpath_tags_seq': [[[216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 
216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 
216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216]], [[216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 
25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 
216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 
216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216]]], 'xpath_subs_seq': [[[1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 1, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 1, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 1, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001]], [[1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 2, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 2, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 2, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 2, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 2, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001]]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'labels': [[-100, 1, -100, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], [-100, 2, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100]], 'attention_mask': [[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]} # fmt: skip
|
||
|
||
encoding_p = tokenizer_p(nodes, xpaths=xpaths, node_labels=node_labels, padding="max_length", max_length=20)
|
||
encoding_r = tokenizer_r(nodes, xpaths=xpaths, node_labels=node_labels, padding="max_length", max_length=20)
|
||
self.assertDictEqual(dict(encoding_p), expected_results)
|
||
self.assertDictEqual(dict(encoding_r), expected_results)
|
||
|
||
# CASE 3: not batched
|
||
question, nodes, xpaths = self.get_question_nodes_and_xpaths()
|
||
|
||
expected_results = {'input_ids': [0, 12196, 18, 39, 766, 116, 2, 42891, 232, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'xpath_tags_seq': [[216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 
216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 
216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216]], 'xpath_subs_seq': [[1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 1, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 1, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001]], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} # fmt: skip
|
||
|
||
encoding_p = tokenizer_p(question, nodes, xpaths, padding="max_length", max_length=20)
|
||
encoding_r = tokenizer_r(question, nodes, xpaths, padding="max_length", max_length=20)
|
||
self.assertDictEqual(dict(encoding_p), expected_results)
|
||
self.assertDictEqual(dict(encoding_r), expected_results)
|
||
|
||
# CASE 3: batched
|
||
questions, nodes, xpaths = self.get_question_nodes_and_xpaths_batch()
|
||
|
||
expected_results = {'input_ids': [[0, 12196, 18, 39, 766, 116, 2, 42891, 232, 12364, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 9178, 16, 37, 373, 116, 2, 42891, 127, 766, 16, 22401, 2, 1, 1, 1, 1, 1, 1, 1]], 'xpath_tags_seq': [[[216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 
120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 
216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216]], [[216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 
216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 
216, 216, 216, 216], [109, 25, 50, 120, 50, 178, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 
216, 216, 216, 216, 216, 216, 216], [216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216, 216]]], 'xpath_subs_seq': [[[1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 1, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 1, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 1, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], 
[1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001]], [[1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 2, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 2, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 2, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 2, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [0, 0, 0, 2, 0, 0, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001], [1001, 1001, 1001, 1001, 1001, 1001, 1001, 
1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001, 1001]]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]]} # fmt: skip
|
||
|
||
encoding_p = tokenizer_p(questions, nodes, xpaths, padding="max_length", max_length=20)
|
||
encoding_r = tokenizer_r(questions, nodes, xpaths, padding="max_length", max_length=20)
|
||
self.assertDictEqual(dict(encoding_p), expected_results)
|
||
self.assertDictEqual(dict(encoding_r), expected_results)
|
||
|
||
@unittest.skip(reason="Doesn't support returning Numpy arrays")
|
||
def test_np_encode_plus_sent_to_model(self):
|
||
pass
|
||
|
||
@unittest.skip(reason="Chat is not supported")
|
||
def test_chat_template(self):
|
||
pass
|
||
|
||
@unittest.skip(reason="The model tested fails `Hub -> Fast == Hub -> Slow`, nothing much we can do")
|
||
def test_added_tokens_serialization(self):
|
||
pass
|
||
|
||
@unittest.skip("Chat is not supported")
|
||
def test_chat_template_return_assistant_tokens_mask(self):
|
||
pass
|
||
|
||
@unittest.skip("Chat is not supported")
|
||
def test_chat_template_return_assistant_tokens_mask_truncated(self):
|
||
pass
|
||
|
||
def test_empty_input_string(self):
    """Check the ``input_ids`` dtype returned for the "empty" input fixtures.

    For each tensor framework that is available ("pt" and "np" when torch is
    installed, "mlx" when MLX is installed) and each tokenizer from
    ``self.get_tokenizers()``, encode four fixtures and assert that
    ``output.input_ids.dtype`` equals the framework's expected integer dtype:

    * nodes + xpaths
    * question + nodes + xpaths
    * batched nodes + xpaths (with padding)
    * batched question + nodes + xpaths (with padding)

    Skips the test when none of the frameworks is available.
    """
    tokenizer_return_type = []
    output_tensor_type = []

    if is_torch_available():
        import numpy as np
        import torch

        tokenizer_return_type.append("pt")
        output_tensor_type.append(torch.int64)
        # The NumPy check is registered only inside the torch branch.
        tokenizer_return_type.append("np")
        output_tensor_type.append(np.int64)

    if is_mlx_available():
        import mlx.core as mx

        tokenizer_return_type.append("mlx")
        output_tensor_type.append(mx.int32)

    if not tokenizer_return_type:
        self.skipTest(reason="No expected framework from PT, or MLX found")

    tokenizers = self.get_tokenizers()
    for tokenizer in tokenizers:
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            # Nodes only (single example).
            nodes, xpaths = self.get_empty_nodes_and_xpaths()
            for return_type, target_type in zip(tokenizer_return_type, output_tensor_type):
                output = tokenizer(nodes, xpaths=xpaths, return_tensors=return_type)
                self.assertEqual(output.input_ids.dtype, target_type)

            # Question + nodes (single example). The question is passed first
            # and the nodes second, the same (questions, nodes, xpaths) order
            # used elsewhere in this file; otherwise this section would only
            # repeat the nodes-only check above.
            question, nodes, xpaths = self.get_empty_question_nodes_and_xpaths()
            for return_type, target_type in zip(tokenizer_return_type, output_tensor_type):
                output = tokenizer(question, nodes, xpaths=xpaths, return_tensors=return_type)
                self.assertEqual(output.input_ids.dtype, target_type)

            # Nodes only (batch); padding so the batch can be stacked into a tensor.
            nodes, xpaths = self.get_empty_nodes_and_xpaths_batch()
            for return_type, target_type in zip(tokenizer_return_type, output_tensor_type):
                output = tokenizer(nodes, xpaths=xpaths, padding=True, return_tensors=return_type)
                self.assertEqual(output.input_ids.dtype, target_type)

            # Question + nodes (batch); the questions are passed first here as well.
            question, nodes, xpaths = self.get_empty_question_nodes_and_xpaths_batch()
            for return_type, target_type in zip(tokenizer_return_type, output_tensor_type):
                output = tokenizer(question, nodes, xpaths=xpaths, padding=True, return_tensors=return_type)
                self.assertEqual(output.input_ids.dtype, target_type)
|