first commit
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
This commit is contained in:
0
tests/models/moshi/__init__.py
Normal file
0
tests/models/moshi/__init__.py
Normal file
976
tests/models/moshi/test_modeling_moshi.py
Normal file
976
tests/models/moshi/test_modeling_moshi.py
Normal file
@@ -0,0 +1,976 @@
|
||||
# Copyright 2024, The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Testing suite for the PyTorch Moshi model."""
|
||||
|
||||
import copy
|
||||
import tempfile
|
||||
import unittest
|
||||
from functools import cached_property
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from datasets import Audio, load_dataset
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import (
|
||||
MoshiConfig,
|
||||
PreTrainedConfig,
|
||||
)
|
||||
from transformers.integrations.deepspeed import (
|
||||
is_deepspeed_available,
|
||||
is_deepspeed_zero3_enabled,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
is_torch_available,
|
||||
require_torch,
|
||||
require_torch_fp16,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION,
|
||||
ModelTesterMixin,
|
||||
floats_tensor,
|
||||
ids_tensor,
|
||||
)
|
||||
from ...test_pipeline_mixin import PipelineTesterMixin
|
||||
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoFeatureExtractor,
|
||||
AutoTokenizer,
|
||||
MoshiForCausalLM,
|
||||
MoshiForConditionalGeneration,
|
||||
MoshiModel,
|
||||
)
|
||||
|
||||
|
||||
def _config_zero_init(config):
|
||||
configs_no_init = copy.deepcopy(config)
|
||||
for key in configs_no_init.__dict__:
|
||||
if "_range" in key or "_std" in key or "initializer_factor" in key or "layer_scale" in key:
|
||||
setattr(configs_no_init, key, 1e-10)
|
||||
if isinstance(getattr(configs_no_init, key, None), PreTrainedConfig):
|
||||
no_init_subconfig = _config_zero_init(getattr(configs_no_init, key))
|
||||
setattr(configs_no_init, key, no_init_subconfig)
|
||||
return configs_no_init
|
||||
|
||||
|
||||
class MoshiDecoderTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=4, # need batch_size != num_hidden_layers
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=4,
|
||||
hidden_act="silu",
|
||||
rms_norm_eps=0.001,
|
||||
ffn_dim=32,
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=100,
|
||||
pad_token_id=25,
|
||||
num_codebooks=4,
|
||||
audio_encoder_type="mimi",
|
||||
attn_implementation="eager",
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.ffn_dim = ffn_dim
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.pad_token_id = pad_token_id
|
||||
self.num_codebooks = num_codebooks
|
||||
self.audio_encoder_type = audio_encoder_type
|
||||
self.attn_implementation = attn_implementation
|
||||
|
||||
def prepare_config_and_inputs(self, batch_size=None):
|
||||
batch_size = self.batch_size if batch_size is None else batch_size
|
||||
input_ids = ids_tensor([batch_size, self.seq_length], self.vocab_size)
|
||||
config = self.get_config()
|
||||
|
||||
attention_mask = input_ids.ne(self.pad_token_id)
|
||||
|
||||
inputs_dict = {"input_ids": input_ids, "attention_mask": attention_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def get_config(self):
|
||||
config = MoshiConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
d_ff=self.intermediate_size,
|
||||
num_codebooks=self.num_codebooks,
|
||||
rms_norm_eps=self.rms_norm_eps,
|
||||
tie_word_embeddings=False,
|
||||
pad_token_id=self.pad_token_id,
|
||||
ffn_dim=self.ffn_dim,
|
||||
audio_encoder_config={"model_type": self.audio_encoder_type},
|
||||
attn_implementation=self.attn_implementation,
|
||||
)
|
||||
return config
|
||||
|
||||
def prepare_config_and_inputs_for_common(self, batch_size=None):
|
||||
config, inputs_dict = self.prepare_config_and_inputs(batch_size)
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class MoshiDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (MoshiModel, MoshiForCausalLM) if is_torch_available() else ()
|
||||
|
||||
test_resize_embeddings = True
|
||||
pipeline_model_mapping = (
|
||||
{
|
||||
"feature-extraction": MoshiModel,
|
||||
"text-generation": MoshiForCausalLM,
|
||||
}
|
||||
if is_torch_available()
|
||||
else {}
|
||||
)
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MoshiDecoderTester(self)
|
||||
self.config_tester = ConfigTester(
|
||||
self,
|
||||
config_class=MoshiConfig,
|
||||
hidden_size=16,
|
||||
audio_encoder_config={"model_type": self.model_tester.audio_encoder_type},
|
||||
)
|
||||
|
||||
@unittest.skip(reason="The MoshiModel does not have support dynamic compile yet")
|
||||
@pytest.mark.torch_compile_test
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
def _get_input_ids_and_config(self, batch_size=1):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(batch_size)
|
||||
input_ids = inputs_dict.pop("input_ids").to(torch_device)
|
||||
attention_mask = inputs_dict.pop("attention_mask").to(torch_device)
|
||||
|
||||
return config, input_ids, attention_mask, inputs_dict
|
||||
|
||||
def _get_logits_processor_kwargs(self, do_sample=False, config=None):
|
||||
logits_processor_kwargs = {}
|
||||
return logits_processor_kwargs
|
||||
|
||||
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
|
||||
def test_eager_matches_sdpa_inference(
|
||||
self, name, dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
|
||||
):
|
||||
if use_attention_mask or (not use_attention_mask and dtype == "fp32" and not output_attentions):
|
||||
self.skipTest("Test is failing, fix me :) ")
|
||||
parent_parameterized_test = getattr(ModelTesterMixin, self._testMethodName)
|
||||
parent_parameterized_test(self)
|
||||
|
||||
# Copied from tests.test_modeling_common.ModelTesterMixin.test_resize_tokens_embeddings
|
||||
def test_resize_tokens_embeddings(self):
|
||||
if not self.test_resize_embeddings:
|
||||
self.skipTest(reason="test_resize_embeddings is set to `False`")
|
||||
|
||||
(
|
||||
original_config,
|
||||
inputs_dict,
|
||||
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config = copy.deepcopy(original_config)
|
||||
if is_deepspeed_zero3_enabled():
|
||||
with deepspeed.zero.Init():
|
||||
model = model_class(config)
|
||||
else:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
|
||||
model_embed_pre_resize = model.get_input_embeddings()
|
||||
type_model_embed_pre_resize = type(model_embed_pre_resize)
|
||||
|
||||
if self.model_tester.is_training is False:
|
||||
model.eval()
|
||||
|
||||
model_vocab_size = config.get_text_config().vocab_size
|
||||
# Retrieve the embeddings and clone theme
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size)
|
||||
cloned_embeddings = model_embed.weight.clone()
|
||||
|
||||
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size + 10)
|
||||
new_model_vocab_size = model.config.get_text_config().vocab_size
|
||||
self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
|
||||
# Check that it actually resizes the embeddings matrix
|
||||
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
|
||||
# Check to make sure the type of embeddings returned post resizing is same as type of input
|
||||
type_model_embed_post_resize = type(model_embed)
|
||||
self.assertEqual(type_model_embed_pre_resize, type_model_embed_post_resize)
|
||||
# Check that added embeddings mean is close to the old embeddings mean
|
||||
if is_deepspeed_zero3_enabled():
|
||||
with deepspeed.zero.GatheredParameters(model_embed.weight, modifier_rank=None):
|
||||
old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0)
|
||||
new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0)
|
||||
else:
|
||||
old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0)
|
||||
new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0)
|
||||
torch.testing.assert_close(old_embeddings_mean, new_embeddings_mean, rtol=1e-3, atol=1e-3)
|
||||
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
if not is_deepspeed_zero3_enabled():
|
||||
# A distriputed launcher is needed for the forward pass when deepspeed is enabled
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
# Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size - 15)
|
||||
new_model_vocab_size = model.config.get_text_config().vocab_size
|
||||
self.assertEqual(new_model_vocab_size, model_vocab_size - 15)
|
||||
# Check that it actually resizes the embeddings matrix
|
||||
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15)
|
||||
|
||||
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||
# Input ids should be clamped to the maximum size of the vocabulary
|
||||
inputs_dict["input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
||||
|
||||
# make sure that decoder_input_ids are resized as well
|
||||
if not is_deepspeed_zero3_enabled():
|
||||
# A distriputed launcher is needed for the forward pass when deepspeed is enabled
|
||||
if "decoder_input_ids" in inputs_dict:
|
||||
inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1)
|
||||
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
|
||||
# Check that adding and removing tokens has not modified the first part of the embedding matrix.
|
||||
models_equal = True
|
||||
for p1, p2 in zip(cloned_embeddings, model_embed.weight):
|
||||
if p1.data.ne(p2.data).sum() > 0:
|
||||
models_equal = False
|
||||
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
del model
|
||||
if is_deepspeed_zero3_enabled():
|
||||
with deepspeed.zero.Init():
|
||||
model = model_class(config)
|
||||
else:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
|
||||
model_vocab_size = config.get_text_config().vocab_size
|
||||
model.resize_token_embeddings(model_vocab_size + 10, pad_to_multiple_of=1)
|
||||
new_model_vocab_size = model.config.get_text_config().vocab_size
|
||||
self.assertTrue(new_model_vocab_size + 10, model_vocab_size)
|
||||
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=64)
|
||||
new_model_vocab_size = model.config.get_text_config().vocab_size
|
||||
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
|
||||
|
||||
self.assertTrue(model_embed.weight.shape[0], new_model_vocab_size)
|
||||
self.assertTrue(new_model_vocab_size, model.vocab_size)
|
||||
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size + 13, pad_to_multiple_of=64)
|
||||
self.assertTrue(model_embed.weight.shape[0] // 64, 0)
|
||||
|
||||
# Check that resizing a model to a multiple of pad_to_multiple leads to a model of exactly that size
|
||||
target_dimension = 128
|
||||
model_embed = model.resize_token_embeddings(target_dimension, pad_to_multiple_of=64)
|
||||
self.assertTrue(model_embed.weight.shape[0], target_dimension)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Asking to pad the embedding matrix to a multiple of `1.3`, which is not and integer. Please make sure to pass an integer",
|
||||
):
|
||||
model.resize_token_embeddings(model_vocab_size, pad_to_multiple_of=1.3)
|
||||
|
||||
# Test when `vocab_size` is smaller than `hidden_size`.
|
||||
del model
|
||||
config.vocab_size = 4
|
||||
config.pad_token_id = 4 # Ignore copy
|
||||
if is_deepspeed_zero3_enabled():
|
||||
with deepspeed.zero.Init():
|
||||
model = model_class(config)
|
||||
else:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
|
||||
model_vocab_size = config.get_text_config().vocab_size
|
||||
# Retrieve the embeddings and clone theme
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size)
|
||||
cloned_embeddings = model_embed.weight.clone()
|
||||
|
||||
# Check that resizing the token embeddings with a larger vocab size increases the model's vocab size
|
||||
model_embed = model.resize_token_embeddings(model_vocab_size + 10)
|
||||
new_model_vocab_size = model.config.get_text_config().vocab_size
|
||||
self.assertEqual(new_model_vocab_size, model_vocab_size + 10)
|
||||
# Check that it actually resizes the embeddings matrix
|
||||
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
|
||||
# Check to make sure the type of embeddings returned post resizing is same as type of input
|
||||
type_model_embed_post_resize = type(model_embed)
|
||||
self.assertEqual(type_model_embed_pre_resize, type_model_embed_post_resize)
|
||||
# Check that added embeddings mean is close to the old embeddings mean
|
||||
if is_deepspeed_zero3_enabled():
|
||||
with deepspeed.zero.GatheredParameters(model_embed.weight, modifier_rank=None):
|
||||
old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0)
|
||||
new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0)
|
||||
else:
|
||||
old_embeddings_mean = torch.mean(model_embed.weight.data[:-10, :], axis=0)
|
||||
new_embeddings_mean = torch.mean(model_embed.weight.data[-10:, :], axis=0)
|
||||
torch.testing.assert_close(old_embeddings_mean, new_embeddings_mean, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
|
||||
def test_cpu_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
|
||||
def test_disk_offload_bin(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
|
||||
def test_disk_offload_safetensors(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Test becomes too complex with Moshi requiring multiple input modalities.")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@is_flaky(max_attempts=5, description="flaky on some models.")
|
||||
def test_save_load(self):
|
||||
super().test_save_load()
|
||||
|
||||
|
||||
class MoshiTester:
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=4, # need batch_size != num_hidden_layers
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
vocab_size=99,
|
||||
hidden_size=32,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=8,
|
||||
intermediate_size=4,
|
||||
hidden_act="silu",
|
||||
rms_norm_eps=0.001,
|
||||
ffn_dim=32,
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=100,
|
||||
pad_token_id=25,
|
||||
bos_token_id=25,
|
||||
num_codebooks=4,
|
||||
audio_encoder_type="mimi",
|
||||
attn_implementation="eager",
|
||||
depth_hidden_size=16,
|
||||
depth_num_hidden_layers=2,
|
||||
depth_max_position_embeddings=5,
|
||||
depth_num_attention_heads=8,
|
||||
depth_ffn_dim=16,
|
||||
depth_sliding_window=4,
|
||||
mimi_intermediate_size=40,
|
||||
mimi_hidden_size=32,
|
||||
mimi_num_filters=8,
|
||||
mimi_num_residual_layers=1,
|
||||
mimi_upsampling_ratios=[8, 4],
|
||||
mimi_codebook_size=64,
|
||||
mimi_vector_quantization_hidden_dimension=64,
|
||||
mimi_codebook_dim=64,
|
||||
mimi_upsample_groups=32,
|
||||
mimi_num_hidden_layers=2,
|
||||
mimi_num_attention_heads=2,
|
||||
mimi_num_key_value_heads=2,
|
||||
mimi_sliding_window=3,
|
||||
sampling_rate=800,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
self.seq_length = seq_length
|
||||
self.is_training = is_training
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.ffn_dim = ffn_dim
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.pad_token_id = pad_token_id
|
||||
self.bos_token_id = bos_token_id
|
||||
self.num_codebooks = num_codebooks
|
||||
self.attn_implementation = attn_implementation
|
||||
self.depth_hidden_size = depth_hidden_size
|
||||
self.depth_num_hidden_layers = depth_num_hidden_layers
|
||||
self.depth_max_position_embeddings = depth_max_position_embeddings
|
||||
self.depth_num_attention_heads = depth_num_attention_heads
|
||||
self.depth_ffn_dim = depth_ffn_dim
|
||||
self.depth_sliding_window = depth_sliding_window
|
||||
|
||||
self.audio_encoder_type = audio_encoder_type
|
||||
self.mimi_intermediate_size = mimi_intermediate_size
|
||||
self.mimi_hidden_size = mimi_hidden_size
|
||||
self.mimi_num_filters = mimi_num_filters
|
||||
self.mimi_num_residual_layers = mimi_num_residual_layers
|
||||
self.mimi_upsampling_ratios = mimi_upsampling_ratios
|
||||
self.mimi_codebook_size = mimi_codebook_size
|
||||
self.mimi_vector_quantization_hidden_dimension = mimi_vector_quantization_hidden_dimension
|
||||
self.mimi_codebook_dim = mimi_codebook_dim
|
||||
self.mimi_upsample_groups = mimi_upsample_groups
|
||||
self.mimi_num_hidden_layers = mimi_num_hidden_layers
|
||||
self.mimi_num_attention_heads = mimi_num_attention_heads
|
||||
self.mimi_num_key_value_heads = mimi_num_key_value_heads
|
||||
self.mimi_sliding_window = mimi_sliding_window
|
||||
self.sampling_rate = sampling_rate
|
||||
|
||||
self.num_hidden_states_types = 2
|
||||
|
||||
def prepare_config_and_inputs(self, batch_size=None):
|
||||
batch_size = self.batch_size if batch_size is None else batch_size
|
||||
|
||||
input_ids = ids_tensor([batch_size, self.seq_length], self.vocab_size)
|
||||
|
||||
moshi_audio_codes = ids_tensor([batch_size, self.num_codebooks, self.seq_length], self.mimi_codebook_size)
|
||||
user_audio_codes = ids_tensor([batch_size, self.num_codebooks, self.seq_length], self.mimi_codebook_size)
|
||||
attention_mask = input_ids.ne(self.pad_token_id)
|
||||
|
||||
config = self.get_config()
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"moshi_audio_codes": moshi_audio_codes,
|
||||
"user_audio_codes": user_audio_codes,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
def get_config(self):
|
||||
mimi_dict_config = {
|
||||
"model_type": self.audio_encoder_type,
|
||||
"audio_channels": 1,
|
||||
"hidden_size": self.mimi_hidden_size,
|
||||
"num_filters": self.mimi_num_filters,
|
||||
"num_residual_layers": self.mimi_num_residual_layers,
|
||||
"upsampling_ratios": self.mimi_upsampling_ratios,
|
||||
"codebook_size": self.mimi_codebook_size,
|
||||
"vector_quantization_hidden_dimension": self.mimi_vector_quantization_hidden_dimension,
|
||||
"upsample_groups": self.mimi_upsample_groups,
|
||||
"num_hidden_layers": self.mimi_num_hidden_layers,
|
||||
"num_attention_heads": self.mimi_num_attention_heads,
|
||||
"num_key_value_heads": self.mimi_num_key_value_heads,
|
||||
"sliding_window": self.mimi_sliding_window,
|
||||
"codebook_dim": self.mimi_codebook_dim,
|
||||
"use_cache": False,
|
||||
"sampling_rate": self.sampling_rate,
|
||||
}
|
||||
|
||||
depth_dict_config = {
|
||||
"hidden_size": self.depth_hidden_size,
|
||||
"num_hidden_layers": self.depth_num_hidden_layers,
|
||||
"max_position_embeddings": self.depth_max_position_embeddings,
|
||||
"num_attention_heads": self.depth_num_attention_heads,
|
||||
"ffn_dim": self.depth_ffn_dim,
|
||||
"sliding_window": self.depth_sliding_window,
|
||||
}
|
||||
|
||||
config = MoshiConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
d_ff=self.intermediate_size,
|
||||
num_codebooks=self.num_codebooks,
|
||||
rms_norm_eps=self.rms_norm_eps,
|
||||
tie_word_embeddings=False,
|
||||
pad_token_id=self.pad_token_id,
|
||||
bos_token_id=self.bos_token_id,
|
||||
ffn_dim=self.ffn_dim,
|
||||
audio_encoder_config=mimi_dict_config,
|
||||
depth_decoder_config=depth_dict_config,
|
||||
attn_implementation=self.attn_implementation,
|
||||
)
|
||||
return config
|
||||
|
||||
def prepare_config_and_inputs_for_common(self, batch_size=None):
|
||||
config, inputs_dict = self.prepare_config_and_inputs(batch_size)
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
@require_torch
|
||||
class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (MoshiForConditionalGeneration,) if is_torch_available() else ()
|
||||
# training is not supported yet for Moshi
|
||||
test_resize_embeddings = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = MoshiTester(self)
|
||||
|
||||
# special case for labels
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
|
||||
|
||||
if return_labels:
|
||||
inputs_dict["text_labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length),
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
return inputs_dict
|
||||
|
||||
def _get_input_ids_and_config(self, batch_size=2):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(batch_size)
|
||||
input_ids = inputs_dict.pop("input_ids").to(torch_device)
|
||||
attention_mask = inputs_dict.pop("attention_mask").to(torch_device)
|
||||
|
||||
# Make sure we only return `input_ids`.
|
||||
# Note that audio_codes will still be generated internally, so the ability to test audio codes is still there.
|
||||
# There are further tests to test that audio waveforms and codes are well generated.
|
||||
inputs_dict["return_audio_waveforms"] = False
|
||||
inputs_dict["return_audio_codes"] = False
|
||||
inputs_dict["concat_unconditional_inputs"] = False
|
||||
|
||||
return config, input_ids, attention_mask, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_generate(self, batch_size=2):
|
||||
config, filtered_inputs_dict = super().prepare_config_and_inputs_for_generate(batch_size=batch_size)
|
||||
|
||||
# Make sure we only return `input_ids`.
|
||||
# Note that audio_codes will still be generated internally, so the ability to test audio codes is still there.
|
||||
# There are further tests to test that audio waveforms and codes are well generated.
|
||||
filtered_inputs_dict["return_audio_waveforms"] = False
|
||||
filtered_inputs_dict["return_audio_codes"] = False
|
||||
filtered_inputs_dict["concat_unconditional_inputs"] = False
|
||||
|
||||
return config, filtered_inputs_dict
|
||||
|
||||
def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
|
||||
# Overwrite because the generate method actually always uses `inputs_embeds` so `use_cache` is always `True`
|
||||
super()._check_generate_outputs(
|
||||
output, config, use_cache=True, num_return_sequences=num_return_sequences, num_beams=num_beams
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Continuing from past key values is not straightforward as we're dealing with 3 inputs")
|
||||
def test_generate_continue_from_past_key_values(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION)
|
||||
@unittest.skip(reason="Unimplemented. Relies on `test_eager_matches_sdpa_generate` to check correctness.")
|
||||
def test_eager_matches_sdpa_inference(
|
||||
self, name, dtype, padding_side, use_attention_mask, output_attentions, enable_kernels
|
||||
):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="The Moshi model does not have support dynamic compile yet")
|
||||
@pytest.mark.torch_compile_test
|
||||
def test_sdpa_can_compile_dynamic(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_left_padding_compatibility(self):
|
||||
# Overwrite -- Moshi needs to prepare the audio codes, and they must be padded accordingly
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
moshi_audio_codes = inputs_dict["moshi_audio_codes"]
|
||||
user_audio_codes = inputs_dict["user_audio_codes"]
|
||||
|
||||
pad_size = (input_ids.shape[0], 32)
|
||||
padding = (
|
||||
torch.ones((pad_size[0], self.model_tester.num_codebooks, 32), dtype=input_ids.dtype, device=torch_device)
|
||||
* config.audio_vocab_size
|
||||
)
|
||||
padded_moshi_audio_codes = torch.cat((padding, moshi_audio_codes), dim=2)
|
||||
padded_user_audio_codes = torch.cat((padding, user_audio_codes), dim=2)
|
||||
|
||||
# the audio codes are randomly generated in `prepare_config_and_inputs_for_generate`, and they must match
|
||||
# their padded version for the test to be valid -- we need to pass both
|
||||
unpadded_custom_inputs = {"moshi_audio_codes": moshi_audio_codes, "user_audio_codes": user_audio_codes}
|
||||
padded_custom_inputs = {
|
||||
"moshi_audio_codes": padded_moshi_audio_codes,
|
||||
"user_audio_codes": padded_user_audio_codes,
|
||||
}
|
||||
super().test_left_padding_compatibility(
|
||||
unpadded_custom_inputs=unpadded_custom_inputs, padded_custom_inputs=padded_custom_inputs
|
||||
)
|
||||
|
||||
@slow
|
||||
@is_flaky(max_attempts=5, description="flaky on some models.")
|
||||
def test_eager_matches_sdpa_generate(self):
|
||||
"""Overwritten -- mochi has custom inputs and custom output checks"""
|
||||
|
||||
max_new_tokens = 5
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if not model_class._supports_sdpa:
|
||||
self.skipTest(f"{model_class.__name__} does not support SDPA")
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
dummy_input = inputs_dict[model_class.main_input_name]
|
||||
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
|
||||
dummy_input = dummy_input.to(torch.float16)
|
||||
|
||||
inputs_dict[model_class.main_input_name] = dummy_input
|
||||
|
||||
# make sure that all models have enough positions for generation
|
||||
if hasattr(config, "max_position_embeddings"):
|
||||
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
|
||||
|
||||
model = model_class(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
|
||||
model_sdpa = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
dtype=torch.float16,
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")
|
||||
|
||||
model_eager = model_class.from_pretrained(
|
||||
tmpdirname,
|
||||
dtype=torch.float16,
|
||||
attn_implementation="eager",
|
||||
).to(torch_device)
|
||||
|
||||
self.assertTrue(model_eager.config._attn_implementation == "eager")
|
||||
|
||||
for name, submodule in model_eager.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
raise ValueError("The eager model should not have SDPA attention layers")
|
||||
|
||||
has_sdpa = False
|
||||
for name, submodule in model_sdpa.named_modules():
|
||||
class_name = submodule.__class__.__name__
|
||||
if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name:
|
||||
has_sdpa = True
|
||||
break
|
||||
if not has_sdpa:
|
||||
raise ValueError("The SDPA model should have SDPA attention layers")
|
||||
|
||||
# Just test that a large cache works as expected
|
||||
res_eager = model_eager.generate(
|
||||
**inputs_dict,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=False,
|
||||
depth_decoder_do_sample=False,
|
||||
)
|
||||
|
||||
res_sdpa = model_sdpa.generate(
|
||||
**inputs_dict,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=False,
|
||||
depth_decoder_do_sample=False,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(res_eager.sequences, res_sdpa.sequences)
|
||||
torch.testing.assert_close(res_eager.audio_sequences, res_sdpa.audio_sequences)
|
||||
|
||||
@pytest.mark.generate
|
||||
def test_generate_without_input_ids(self):
|
||||
config, _, _, _ = self._get_input_ids_and_config()
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
output_ids_generate = model.generate(
|
||||
do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
|
||||
)
|
||||
print(output_ids_generate)
|
||||
self.assertIsNotNone(output_ids_generate)
|
||||
|
||||
@pytest.mark.xfail(reason="This architecture seems to not compute gradients for some layer.")
|
||||
def test_training_gradient_checkpointing(self):
|
||||
super().test_training_gradient_checkpointing()
|
||||
|
||||
@pytest.mark.xfail(reason="This architecture seems to not compute gradients for some layer.")
|
||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||
super().test_training_gradient_checkpointing_use_reentrant_false()
|
||||
|
||||
@pytest.mark.xfail(reason="This architecture seems to not compute gradients for some layer.")
|
||||
def test_training_gradient_checkpointing_use_reentrant_true(self):
|
||||
super().test_training_gradient_checkpointing_use_reentrant_true()
|
||||
|
||||
def test_generate_from_input_values(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, _, _ = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
input_values_length = int(
|
||||
self.model_tester.seq_length * config.sampling_rate / config.audio_encoder_config.frame_rate
|
||||
)
|
||||
|
||||
user_input_values = floats_tensor((input_ids.shape[0], 1, input_values_length))
|
||||
moshi_input_values = floats_tensor((input_ids.shape[0], 1, input_values_length))
|
||||
|
||||
user_audio_codes = model.audio_encoder.encode(user_input_values, num_quantizers=model.num_codebooks)[0]
|
||||
moshi_audio_codes = model.audio_encoder.encode(moshi_input_values, num_quantizers=model.num_codebooks)[0]
|
||||
|
||||
outputs_from_audio_codes = model.generate(
|
||||
input_ids, max_new_tokens=5, user_audio_codes=user_audio_codes, moshi_audio_codes=moshi_audio_codes
|
||||
)
|
||||
|
||||
outputs_from_audio_values = model.generate(
|
||||
input_ids, max_new_tokens=5, user_input_values=user_input_values, moshi_input_values=moshi_input_values
|
||||
)
|
||||
|
||||
self.assertTrue((outputs_from_audio_values.sequences == outputs_from_audio_codes.sequences).all())
|
||||
self.assertTrue(
|
||||
torch.allclose(outputs_from_audio_codes.audio_sequences, outputs_from_audio_values.audio_sequences)
|
||||
)
|
||||
|
||||
def test_generate_depth_decoder_kwargs(self):
|
||||
# test sampling and beam search
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, _, input_dict = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
model.generate(input_ids, max_new_tokens=5, **input_dict, depth_decoder_do_sample=True)
|
||||
|
||||
model.generate(
|
||||
input_ids, max_new_tokens=5, **input_dict, depth_decoder_do_sample=True, depth_decoder_num_beams=5
|
||||
)
|
||||
|
||||
def test_generate_from_unconditional(self):
|
||||
# test sampling and beam search
|
||||
for model_class in self.all_generative_model_classes:
|
||||
config, input_ids, _, input_dict = self._get_input_ids_and_config()
|
||||
|
||||
model = model_class(config).to(torch_device).eval()
|
||||
|
||||
# check bs>1
|
||||
model.generate(
|
||||
**model.get_unconditional_inputs(num_samples=4), max_new_tokens=5, concat_unconditional_inputs=False
|
||||
)
|
||||
|
||||
# check same results from unconditional or no inputs
|
||||
outputs_from_unconditional = model.generate(
|
||||
**model.get_unconditional_inputs(num_samples=1), max_new_tokens=5, concat_unconditional_inputs=False
|
||||
)
|
||||
outputs_from_none = model.generate(max_new_tokens=5)
|
||||
|
||||
self.assertTrue((outputs_from_unconditional.sequences == outputs_from_none.sequences).all())
|
||||
self.assertTrue(
|
||||
torch.allclose(outputs_from_unconditional.audio_sequences, outputs_from_none.audio_sequences)
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Compile not yet supported because in Moshi models")
|
||||
def test_sdpa_can_dispatch_on_flash(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
|
||||
def test_cpu_offload(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
|
||||
def test_disk_offload_bin(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Some undefined behavior encountered with test versions of this model. Skip for now.")
|
||||
def test_disk_offload_safetensors(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="Test becomes too complex with Moshi requiring multiple modalities")
|
||||
def test_generate_continue_from_inputs_embeds(self):
|
||||
pass
|
||||
|
||||
@is_flaky(max_attempts=5, description="flaky on some models.")
|
||||
def test_save_load(self):
|
||||
super().test_save_load()
|
||||
|
||||
@pytest.mark.generate
|
||||
@unittest.skip(reason="Moshi requires setting `model.generated_audio_codes` in generate() before preparing inputs")
|
||||
def test_prepare_inputs_for_generation_kwargs_forwards(self):
|
||||
# If in the future `model.generated_audio_codes` is not required, this test can be re-enabled
|
||||
super().test_prepare_inputs_for_generation_kwargs_forwards(
|
||||
last_hidden_state=torch.randn(2, 3, 32), kwargs_depth_decoder={}
|
||||
)
|
||||
|
||||
@unittest.skip(reason="Moshi has no separate base model without a head.")
|
||||
def test_model_base_model_prefix(self):
|
||||
pass
|
||||
|
||||
|
||||
def place_dict_on_device(dict_to_place, device):
|
||||
for key in dict_to_place:
|
||||
if dict_to_place[key] is not None and isinstance(dict_to_place[key], torch.Tensor):
|
||||
dict_to_place[key] = dict_to_place[key].to(device)
|
||||
return dict_to_place
|
||||
|
||||
|
||||
@require_torch
|
||||
class MoshiIntegrationTests(unittest.TestCase):
|
||||
@cached_property
|
||||
def feature_extractor(self):
|
||||
return AutoFeatureExtractor.from_pretrained("kmhf/hf-moshiko")
|
||||
|
||||
@cached_property
|
||||
def tokenizer(self):
|
||||
return AutoTokenizer.from_pretrained("kmhf/hf-moshiko")
|
||||
|
||||
def _load_datasample(self):
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
dataset = ds.cast_column("audio", Audio(sampling_rate=self.feature_extractor.sampling_rate))
|
||||
# automatic decoding with librispeech
|
||||
speech_sample = dataset.sort("id")[0]["audio"]["array"]
|
||||
return speech_sample
|
||||
|
||||
@slow
|
||||
def test_moshika_conditional_greedy(self):
|
||||
model = MoshiForConditionalGeneration.from_pretrained(
|
||||
"kmhf/hf-moshika", dtype=torch.float16, device_map="auto"
|
||||
)
|
||||
inputs = self.feature_extractor(self._load_datasample(), return_tensors="pt").to(
|
||||
device=torch_device, dtype=torch.float16
|
||||
)
|
||||
|
||||
user_audio_codes = model.audio_encoder.encode(**inputs, num_quantizers=8).audio_codes
|
||||
|
||||
input_ids = self.tokenizer.encode("<pad><pad><pad><pad><unk> Hello,<pad><unk>", return_tensors="pt").to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
moshi_audio_codes = [[[1049, 127, 1880, 972, 972, 1156, 1913, 415, 1933],
|
||||
[1700, 243, 91, 91, 91, 745, 1478, 638, 57],
|
||||
[1626, 457, 457, 457, 457, 1839, 200, 2011, 1142],
|
||||
[546, 290, 390, 390, 290, 1408, 1812, 1187, 1911],
|
||||
[306, 306, 1314, 1314, 1314, 759, 796, 854, 1466],
|
||||
[1443, 1443, 1030, 317, 347, 1178, 613, 1576, 2023],
|
||||
[1871, 428, 1433, 1433, 1978, 1405, 1755, 820, 610],
|
||||
[2008, 1744, 1511, 568, 1533, 550, 237, 1412, 1401]]]
|
||||
# fmt: on
|
||||
|
||||
moshi_audio_codes = torch.tensor(moshi_audio_codes, device=torch_device)
|
||||
user_audio_codes = user_audio_codes[:, :, : moshi_audio_codes.shape[-1]]
|
||||
|
||||
model_outputs = model.generate(
|
||||
user_audio_codes=user_audio_codes,
|
||||
moshi_audio_codes=moshi_audio_codes,
|
||||
input_ids=input_ids,
|
||||
do_sample=False,
|
||||
depth_decoder_do_sample=False,
|
||||
return_audio_codes=True,
|
||||
max_new_tokens=2,
|
||||
)
|
||||
|
||||
expected_text_token = 452
|
||||
expected_audio_tokens = [916, 1396, 1238, 579, 1105, 914, 1257, 810] # fmt: skip
|
||||
|
||||
self.assertTrue(expected_text_token == model_outputs.sequences[0, -2].item())
|
||||
self.assertTrue(expected_audio_tokens == model_outputs.audio_codes[0, :, -1].tolist())
|
||||
|
||||
@slow
|
||||
def test_moshiko_greedy_unconditional_fp16_eager(self):
|
||||
model = MoshiForConditionalGeneration.from_pretrained(
|
||||
"kmhf/hf-moshiko", dtype=torch.float16, device_map="auto"
|
||||
)
|
||||
some_expected_audio_tokens = [[1049, 127], [1700, 243], [1626, 457], [546, 290], [306, 306], [1443, 1443], [1871, 428], [2008, 1744]] # fmt: skip
|
||||
|
||||
model_outputs = model.generate(
|
||||
do_sample=False, depth_decoder_do_sample=False, return_audio_codes=True, max_new_tokens=10
|
||||
)
|
||||
|
||||
# eager equivalence is not as strict as sdpa.
|
||||
self.assertTrue(some_expected_audio_tokens == model_outputs.audio_codes[0, :, :2].tolist())
|
||||
|
||||
@slow
|
||||
def test_moshiko_greedy_unconditional_fp32(self):
|
||||
model = MoshiForConditionalGeneration.from_pretrained(
|
||||
"kmhf/hf-moshiko", dtype=torch.float32, device_map="auto"
|
||||
)
|
||||
|
||||
expected_audio_codesum = 72065
|
||||
expected_text_tokens = [3, 3, 3, 0, 11725, 261, 3, 3, 3, 3] # fmt: skip
|
||||
some_expected_audio_tokens = [[1049, 127], [1700, 243], [1626, 457], [546, 290], [306, 306], [1443, 1443], [1871, 428], [2008, 1744]] # fmt: skip
|
||||
|
||||
model_outputs = model.generate(
|
||||
do_sample=False, depth_decoder_do_sample=False, return_audio_codes=True, max_new_tokens=10
|
||||
)
|
||||
|
||||
# make sure audio encoded codes are correct
|
||||
audio_code_sums = model_outputs.audio_codes.sum().item()
|
||||
self.assertTrue(np.abs(audio_code_sums - expected_audio_codesum) <= (3e-3 * audio_code_sums))
|
||||
|
||||
self.assertTrue(expected_text_tokens == model_outputs.sequences[0, 1:].tolist())
|
||||
self.assertTrue(some_expected_audio_tokens == model_outputs.audio_codes[0, :, :2].tolist())
|
||||
|
||||
@slow
|
||||
@require_torch_fp16
|
||||
def test_moshiko_greedy_unconditional_fp16(self):
|
||||
model = MoshiForConditionalGeneration.from_pretrained(
|
||||
"kmhf/hf-moshiko", dtype=torch.float16, device_map="auto"
|
||||
)
|
||||
|
||||
expected_audio_codesum = 72065
|
||||
expected_text_tokens = [3, 3, 3, 0, 11725, 261, 3, 3, 3, 3] # fmt: skip
|
||||
some_expected_audio_tokens = [[1049, 127], [1700, 243], [1626, 457], [546, 290], [306, 306], [1443, 1443], [1871, 428], [2008, 1744]] # fmt: skip
|
||||
|
||||
model_outputs = model.generate(
|
||||
do_sample=False, depth_decoder_do_sample=False, return_audio_codes=True, max_new_tokens=10
|
||||
)
|
||||
|
||||
# make sure audio encoded codes are correct
|
||||
audio_code_sums = model_outputs.audio_codes.sum().item()
|
||||
self.assertTrue(np.abs(audio_code_sums - expected_audio_codesum) <= (3e-3 * audio_code_sums))
|
||||
|
||||
self.assertTrue(expected_text_tokens == model_outputs.sequences[0, 1:].tolist())
|
||||
self.assertTrue(some_expected_audio_tokens == model_outputs.audio_codes[0, :, :2].tolist())
|
||||
|
||||
@slow
|
||||
@require_torch_fp16
|
||||
def test_moshika_greedy_unconditional_fp16(self):
|
||||
model = MoshiForConditionalGeneration.from_pretrained(
|
||||
"kmhf/hf-moshika", dtype=torch.float16, device_map="auto"
|
||||
)
|
||||
|
||||
expected_audio_codesum = 72932
|
||||
expected_text_tokens = [3, 3, 3, 0, 667, 263, 3, 3, 0, 705] # fmt: skip
|
||||
some_expected_audio_tokens = [[1049, 127], [1700, 243], [1626, 457], [546, 290], [306, 306], [1443, 347], [1871, 428], [2008, 2008]] # fmt: skip
|
||||
|
||||
model_outputs = model.generate(
|
||||
do_sample=False, depth_decoder_do_sample=False, return_audio_codes=True, max_new_tokens=10
|
||||
)
|
||||
|
||||
# make sure audio encoded codes are correct
|
||||
audio_code_sums = model_outputs.audio_codes.sum().item()
|
||||
self.assertTrue(np.abs(audio_code_sums - expected_audio_codesum) <= 2048)
|
||||
|
||||
self.assertTrue(expected_text_tokens == model_outputs.sequences[0, 1:].tolist())
|
||||
self.assertTrue(some_expected_audio_tokens == model_outputs.audio_codes[0, :, :2].tolist())
|
||||
431
tests/models/moshi/test_tokenization_moshi.py
Normal file
431
tests/models/moshi/test_tokenization_moshi.py
Normal file
@@ -0,0 +1,431 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
from transformers import (
|
||||
SPIECE_UNDERLINE,
|
||||
AddedToken,
|
||||
AutoTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
from transformers.convert_slow_tokenizer import MoshiConverter
|
||||
from transformers.testing_utils import (
|
||||
get_tests_dir,
|
||||
nested_simplify,
|
||||
require_sentencepiece,
|
||||
require_tokenizers,
|
||||
require_torch,
|
||||
)
|
||||
|
||||
from ...test_tokenization_common import SMALL_TRAINING_CORPUS, TokenizerTesterMixin
|
||||
|
||||
|
||||
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class MoshiTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
from_pretrained_id = ["kmhf/hf-moshiko"]
|
||||
tokenizer_class = PreTrainedTokenizerFast
|
||||
|
||||
test_slow_tokenizer = False
|
||||
test_rust_tokenizer = True
|
||||
from_pretrained_kwargs = {}
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
|
||||
# We have a SentencePiece fixture for testing
|
||||
tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_object=MoshiConverter(vocab_file=SAMPLE_VOCAB).converted(),
|
||||
bos_token="<s>",
|
||||
unk_token="<unk>",
|
||||
eos_token="</s>",
|
||||
)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.save_pretrained(cls.tmpdirname)
|
||||
|
||||
def get_rust_tokenizer(cls, pretrained_name=None, **kwargs) -> PreTrainedTokenizerFast:
|
||||
pretrained_name = pretrained_name or cls.tmpdirname
|
||||
return cls.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
||||
|
||||
@unittest.skip(reason="No slow tokenizer")
|
||||
def test_added_tokens_serialization(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="PreTrainedTokenizerFast doesn't have tokenizer_file in its signature")
|
||||
def test_rust_tokenizer_signature(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(reason="No slow tokenizer")
|
||||
def test_encode_decode_with_spaces(self):
|
||||
pass
|
||||
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = PreTrainedTokenizerFast(
|
||||
tokenizer_object=MoshiConverter(vocab_file=SAMPLE_VOCAB).converted(),
|
||||
bos_token="<s>",
|
||||
unk_token="<unk>",
|
||||
eos_token="</s>",
|
||||
)
|
||||
|
||||
tokens = tokenizer.tokenize("This is a test")
|
||||
self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])
|
||||
|
||||
self.assertListEqual(
|
||||
tokenizer.convert_tokens_to_ids(tokens),
|
||||
[285, 46, 10, 170, 382],
|
||||
)
|
||||
|
||||
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
|
||||
self.assertListEqual(
|
||||
tokens,
|
||||
[
|
||||
SPIECE_UNDERLINE + "I",
|
||||
SPIECE_UNDERLINE + "was",
|
||||
SPIECE_UNDERLINE + "b",
|
||||
"or",
|
||||
"n",
|
||||
SPIECE_UNDERLINE + "in",
|
||||
SPIECE_UNDERLINE + "",
|
||||
"9",
|
||||
"2",
|
||||
"0",
|
||||
"0",
|
||||
"0",
|
||||
",",
|
||||
SPIECE_UNDERLINE + "and",
|
||||
SPIECE_UNDERLINE + "this",
|
||||
SPIECE_UNDERLINE + "is",
|
||||
SPIECE_UNDERLINE + "f",
|
||||
"al",
|
||||
"s",
|
||||
"é",
|
||||
".",
|
||||
],
|
||||
)
|
||||
ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
self.assertListEqual(
|
||||
ids,
|
||||
[8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4],
|
||||
)
|
||||
|
||||
back_tokens = tokenizer.convert_ids_to_tokens(ids)
|
||||
self.assertListEqual(
|
||||
back_tokens,
|
||||
[
|
||||
SPIECE_UNDERLINE + "I",
|
||||
SPIECE_UNDERLINE + "was",
|
||||
SPIECE_UNDERLINE + "b",
|
||||
"or",
|
||||
"n",
|
||||
SPIECE_UNDERLINE + "in",
|
||||
SPIECE_UNDERLINE + "",
|
||||
"<unk>",
|
||||
"2",
|
||||
"0",
|
||||
"0",
|
||||
"0",
|
||||
",",
|
||||
SPIECE_UNDERLINE + "and",
|
||||
SPIECE_UNDERLINE + "this",
|
||||
SPIECE_UNDERLINE + "is",
|
||||
SPIECE_UNDERLINE + "f",
|
||||
"al",
|
||||
"s",
|
||||
"<unk>",
|
||||
".",
|
||||
],
|
||||
)
|
||||
|
||||
def test_special_tokens_initialization(self):
|
||||
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
||||
added_tokens = [AddedToken("<special>", lstrip=True)]
|
||||
|
||||
tokenizer_r = self.tokenizer_class.from_pretrained(
|
||||
pretrained_name, additional_special_tokens=added_tokens, **kwargs
|
||||
)
|
||||
r_output = tokenizer_r.encode("Hey this is a <special> token")
|
||||
|
||||
special_token_id = tokenizer_r.encode("<special>", add_special_tokens=False)[0]
|
||||
|
||||
self.assertTrue(special_token_id in r_output)
|
||||
|
||||
def test_training_new_tokenizer(self):
|
||||
# This feature only exists for fast tokenizers
|
||||
if not self.test_rust_tokenizer:
|
||||
self.skipTest(reason="test_rust_tokenizer is set to False")
|
||||
|
||||
tokenizer = self.get_rust_tokenizer()
|
||||
new_tokenizer = tokenizer.train_new_from_iterator(SMALL_TRAINING_CORPUS, 100)
|
||||
|
||||
# Test we can use the new tokenizer with something not seen during training
|
||||
inputs = new_tokenizer(["This is the first sentence", "This sentence is different 🤗."])
|
||||
self.assertEqual(len(inputs["input_ids"]), 2)
|
||||
decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
|
||||
expected_result = "This is the first sentence"
|
||||
|
||||
self.assertEqual(expected_result, decoded_input)
|
||||
|
||||
# We check that the parameters of the tokenizer remained the same
|
||||
# Check we have the same number of added_tokens for both pair and non-pair inputs.
|
||||
self.assertEqual(tokenizer.num_special_tokens_to_add(False), new_tokenizer.num_special_tokens_to_add(False))
|
||||
self.assertEqual(tokenizer.num_special_tokens_to_add(True), new_tokenizer.num_special_tokens_to_add(True))
|
||||
|
||||
# Check we have the correct max_length for both pair and non-pair inputs.
|
||||
self.assertEqual(tokenizer.max_len_single_sentence, new_tokenizer.max_len_single_sentence)
|
||||
self.assertEqual(tokenizer.max_len_sentences_pair, new_tokenizer.max_len_sentences_pair)
|
||||
|
||||
# Assert the set of special tokens match as we didn't ask to change them
|
||||
self.assertSequenceEqual(
|
||||
tokenizer.all_special_tokens,
|
||||
new_tokenizer.all_special_tokens,
|
||||
)
|
||||
|
||||
self.assertDictEqual(tokenizer.special_tokens_map, new_tokenizer.special_tokens_map)
|
||||
|
||||
def test_training_new_tokenizer_with_special_tokens_change(self):
|
||||
# This feature only exists for fast tokenizers
|
||||
if not self.test_rust_tokenizer:
|
||||
self.skipTest(reason="test_rust_tokenizer is set to False")
|
||||
|
||||
tokenizer = self.get_rust_tokenizer()
|
||||
# Test with a special tokens map
|
||||
class_signature = inspect.signature(tokenizer.__class__)
|
||||
if "cls_token" in class_signature.parameters:
|
||||
new_tokenizer = tokenizer.train_new_from_iterator(
|
||||
SMALL_TRAINING_CORPUS, 100, special_tokens_map={tokenizer.cls_token: "<cls>"}
|
||||
)
|
||||
cls_id = new_tokenizer.get_vocab()["<cls>"]
|
||||
self.assertEqual(new_tokenizer.cls_token, "<cls>")
|
||||
self.assertEqual(new_tokenizer.cls_token_id, cls_id)
|
||||
|
||||
# Create a new mapping from the special tokens defined in the original tokenizer
|
||||
special_tokens_list = []
|
||||
special_tokens_map = {}
|
||||
for token in special_tokens_list:
|
||||
# Get the private one to avoid unnecessary warnings.
|
||||
if getattr(tokenizer, token) is not None:
|
||||
special_token = getattr(tokenizer, token)
|
||||
special_tokens_map[special_token] = f"{special_token}a"
|
||||
|
||||
# Train new tokenizer
|
||||
new_tokenizer = tokenizer.train_new_from_iterator(
|
||||
SMALL_TRAINING_CORPUS, 100, special_tokens_map=special_tokens_map
|
||||
)
|
||||
|
||||
# Check the changes
|
||||
for token in special_tokens_list:
|
||||
# Get the private one to avoid unnecessary warnings.
|
||||
if getattr(tokenizer, token) is None:
|
||||
continue
|
||||
special_token = getattr(tokenizer, token)
|
||||
if special_token in special_tokens_map:
|
||||
new_special_token = getattr(new_tokenizer, token)
|
||||
self.assertEqual(special_tokens_map[special_token], new_special_token)
|
||||
|
||||
new_id = new_tokenizer.get_vocab()[new_special_token]
|
||||
self.assertEqual(getattr(new_tokenizer, f"{token}_id"), new_id)
|
||||
|
||||
# Check if the AddedToken / string format has been kept
|
||||
for special_token in tokenizer.all_special_tokens:
|
||||
if isinstance(special_token, AddedToken) and special_token.content not in special_tokens_map:
|
||||
# The special token must appear identically in the list of the new tokenizer.
|
||||
self.assertTrue(
|
||||
special_token in new_tokenizer.all_special_tokens,
|
||||
f"'{special_token}' should be in {new_tokenizer.all_special_tokens}",
|
||||
)
|
||||
elif isinstance(special_token, AddedToken):
|
||||
# The special token must appear in the list of the new tokenizer as an object of type AddedToken with
|
||||
# the same parameters as the old AddedToken except the content that the user has requested to change.
|
||||
special_token_str = special_token.content
|
||||
new_special_token_str = special_tokens_map[special_token_str]
|
||||
|
||||
find = False
|
||||
for candidate in new_tokenizer.all_special_tokens:
|
||||
if (
|
||||
isinstance(candidate, AddedToken)
|
||||
and candidate.content == new_special_token_str
|
||||
and candidate.lstrip == special_token.lstrip
|
||||
and candidate.rstrip == special_token.rstrip
|
||||
and candidate.normalized == special_token.normalized
|
||||
and candidate.single_word == special_token.single_word
|
||||
):
|
||||
find = True
|
||||
break
|
||||
special_token.content = new_special_token_str
|
||||
self.assertTrue(
|
||||
find,
|
||||
f"'{special_token.__repr__()}' should appear as an `AddedToken` in the all_special_tokens = "
|
||||
f"{[k for k in new_tokenizer.all_special_tokens if str(k) == new_special_token_str]} but it is missing"
|
||||
", this means that the new tokenizers did not keep the `rstrip`, `lstrip`, `normalized` etc attributes.",
|
||||
)
|
||||
elif special_token not in special_tokens_map:
|
||||
# The special token must appear identically in the list of the new tokenizer.
|
||||
self.assertTrue(
|
||||
special_token in new_tokenizer.all_special_tokens,
|
||||
f"'{special_token.__repr__()}' should be in {new_tokenizer.all_special_tokens}",
|
||||
)
|
||||
|
||||
else:
|
||||
# The special token must appear in the list of the new tokenizer as an object of type string.
|
||||
self.assertTrue(special_tokens_map[special_token] in new_tokenizer.all_special_tokens)
|
||||
|
||||
# Test we can use the new tokenizer with something not seen during training
|
||||
inputs = new_tokenizer(["This is the first sentence", "This sentence is different 🤗."])
|
||||
self.assertEqual(len(inputs["input_ids"]), 2)
|
||||
decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
|
||||
expected_result = "This is the first sentence"
|
||||
|
||||
self.assertEqual(expected_result, decoded_input)
|
||||
|
||||
def test_alignment_methods(self):
|
||||
# TODO: @ArthurZucker - alignment is broken
|
||||
pass
|
||||
|
||||
def test_added_tokens_do_lower_case(self):
|
||||
# TODO: @ArthurZucker
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class MoshiIntegrationTest(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
checkpoint_name = "kmhf/hf-moshiko"
|
||||
cls.rust_tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
|
||||
return cls
|
||||
|
||||
@require_torch
|
||||
def integration_tests(self):
|
||||
inputs = self.tokenizer(
|
||||
["The following string should be properly encoded: Hello.", "But ird and ปี ird ด"],
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
long_attention_mask = [1] * 21
|
||||
|
||||
# fmt: off
|
||||
self.assertEqual(
|
||||
nested_simplify(inputs),
|
||||
{
|
||||
"input_ids": [
|
||||
[287, 547, 2359, 457, 297, 3708, 11488, 279, 11725, 263],
|
||||
[588, 478, 1442, 267, 260, 228, 188, 159, 228, 188, 185, 260, 260, 478, 1442, 260, 260, 260, 228, 188, 152],
|
||||
],
|
||||
"attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], long_attention_mask],
|
||||
},
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
def test_fast_special_tokens(self):
|
||||
fast_tokenizer = self.rust_tokenizer
|
||||
|
||||
fast_tokenizer.add_eos_token = False
|
||||
fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||
assert fast == [318, 1145, 694]
|
||||
|
||||
fast_tokenizer.add_eos_token = True
|
||||
fast = fast_tokenizer.encode("A sample test", add_special_tokens=True)
|
||||
assert fast == [318, 1145, 694]
|
||||
|
||||
self.rust_tokenizer.add_eos_token = False
|
||||
|
||||
def test_simple_encode_decode(self):
|
||||
rust_tokenizer = self.rust_tokenizer
|
||||
|
||||
self.assertEqual(rust_tokenizer.encode("This is a test"), [353, 275, 272, 694])
|
||||
self.assertEqual(rust_tokenizer.decode([353, 275, 272, 694], skip_special_tokens=True), "This is a test")
|
||||
|
||||
# bytefallback showcase
|
||||
bytefallback_tokens = [260, 235, 152, 163, 234, 184, 191, 13340, 235, 160, 163, 236, 180, 159, 234, 156, 179] # fmt: skip
|
||||
self.assertEqual(rust_tokenizer.encode("生活的真谛是"), bytefallback_tokens)
|
||||
self.assertEqual(
|
||||
rust_tokenizer.decode(bytefallback_tokens, skip_special_tokens=True),
|
||||
"生活的真谛是",
|
||||
)
|
||||
|
||||
# Inner spaces showcase
|
||||
self.assertEqual(rust_tokenizer.encode("Hi Hello"), [2769, 260, 11725])
|
||||
self.assertEqual(rust_tokenizer.decode([2769, 260, 11725], skip_special_tokens=True), "Hi Hello")
|
||||
|
||||
self.assertEqual(rust_tokenizer.encode("Hi Hello"), [2769, 260, 260, 11725])
|
||||
self.assertEqual(rust_tokenizer.decode([2769, 260, 260, 11725], skip_special_tokens=True), "Hi Hello")
|
||||
|
||||
# TODO: @ArthurZucker
|
||||
# self.assertEqual(rust_tokenizer.encode(""), [])
|
||||
|
||||
# self.assertEqual(rust_tokenizer.encode(" "), [260, 260])
|
||||
|
||||
# self.assertEqual(rust_tokenizer.encode(" "), [260, 260, 260])
|
||||
|
||||
# self.assertEqual(rust_tokenizer.encode(" Hello"), [260, 11725])
|
||||
|
||||
# self.assertEqual(rust_tokenizer.encode("<s>"), [607, 266, 578])
|
||||
|
||||
def test_no_differences_decode(self):
|
||||
rust_tokenizer = self.rust_tokenizer
|
||||
|
||||
self.assertEqual(rust_tokenizer.decode([869]), "levels")
|
||||
|
||||
self.assertEqual(rust_tokenizer.decode([30112, 869]), "unanswered levels")
|
||||
|
||||
|
||||
@require_sentencepiece
|
||||
@require_tokenizers
|
||||
class CommonSpmIntegrationTests(unittest.TestCase):
|
||||
"""
|
||||
A class that regroups important test to make sure that we properly handle the special tokens.
|
||||
"""
|
||||
|
||||
def test_edge_case_tabulation(self):
|
||||
fast_tokenizer = AutoTokenizer.from_pretrained("kmhf/hf-moshiko")
|
||||
input_text = "Hey<eos>. \t\t \n\nyou é @#😈 🤗! , 1234 15 5,61"
|
||||
EXPECTED_IDS = [11510, 934, 4451, 266, 578, 263, 260, 13, 13, 260, 14, 14, 5209, 260, 260, 1202, 260, 527, 1322, 244, 163, 156, 140, 260, 260, 244, 163, 168, 155, 430, 1047, 261, 260, 265, 270, 278, 281, 260, 265, 280, 260, 280, 261, 285, 265] # fmt: skip
|
||||
EXPECTED_TOKENS = ['▁Hey', '<', 'eo', 's', '>', '.', '▁', '<0x09>', '<0x09>', '▁', '<0x0A>', '<0x0A>', 'you', '▁', '▁', 'é', '▁', '▁@', '#', '<0xF0>', '<0x9F>', '<0x98>', '<0x88>', '▁', '▁', '<0xF0>', '<0x9F>', '<0xA4>', '<0x97>', '!', '▁▁▁▁▁▁▁', ',', '▁', '1', '2', '3', '4', '▁', '1', '5', '▁', '5', ',', '6', '1'] # fmt: skip
|
||||
|
||||
tokens = fast_tokenizer.tokenize(input_text)
|
||||
with self.subTest("test fast edge case fast"):
|
||||
self.assertEqual(tokens, EXPECTED_TOKENS)
|
||||
|
||||
input_ids = fast_tokenizer.encode(input_text)
|
||||
with self.subTest("test fast edge case fast"):
|
||||
self.assertEqual(input_ids, EXPECTED_IDS)
|
||||
|
||||
text = fast_tokenizer.decode(EXPECTED_IDS)
|
||||
with self.subTest("test fast edge case fast"):
|
||||
self.assertEqual(text, "Hey<eos>. \t\t \n\nyou é @#😈 🤗! , 1234 15 5,61")
|
||||
|
||||
input_text = "\t\t\t\t \n\n61"
|
||||
EXPECTED_IDS = [260, 13, 13, 13, 13, 260, 14, 14, 285, 265]
|
||||
EXPECTED_TOKENS = ["▁", "<0x09>", "<0x09>", "<0x09>", "<0x09>", "▁", "<0x0A>", "<0x0A>", "6", "1"]
|
||||
|
||||
tokens = fast_tokenizer.tokenize(input_text)
|
||||
with self.subTest("test fast edge case fast"):
|
||||
self.assertEqual(tokens, EXPECTED_TOKENS)
|
||||
|
||||
input_ids = fast_tokenizer.encode(input_text)
|
||||
with self.subTest("test fast edge case fast"):
|
||||
self.assertEqual(input_ids, EXPECTED_IDS)
|
||||
|
||||
text = fast_tokenizer.decode(EXPECTED_IDS)
|
||||
with self.subTest("test fast edge case fast"):
|
||||
self.assertEqual(text, "\t\t\t\t \n\n61")
|
||||
Reference in New Issue
Block a user