first commit
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
This commit is contained in:
418
tests/test_training_mixin.py
Normal file
418
tests/test_training_mixin.py
Normal file
@@ -0,0 +1,418 @@
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Training overfit tester mixin for model tests."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import set_seed
|
||||
from transformers.testing_utils import Colors, build_cpu_memory_monitor, init_test_logger, is_training_test
|
||||
|
||||
|
||||
logger = logging.getLogger("transformers.training_test")
|
||||
|
||||
|
||||
class TrainingTesterMixin(ABC):
|
||||
"""
|
||||
Mixin for training overfit tests. Add to model test classes alongside ModelTesterMixin.
|
||||
|
||||
The model_tester (e.g., CausalLMModelTester) already provides:
|
||||
- get_config() -> tiny model config
|
||||
- prepare_config_and_inputs_for_common() -> config + input dict
|
||||
- causal_lm_class, base_model_class, etc.
|
||||
|
||||
This mixin adds training-specific tests using that infrastructure.
|
||||
"""
|
||||
|
||||
# ============================================================
|
||||
# Training hyperparameters
|
||||
# ============================================================
|
||||
training_overfit_steps: int = 300
|
||||
training_overfit_batch_size: int = 2
|
||||
training_overfit_learning_rate: float = 1e-3
|
||||
training_overfit_seq_length: int = 64
|
||||
training_overfit_log_freq: int = 10
|
||||
|
||||
# Loss reduction and grad norm reduction thresholds for passing the test (i.e 95% reduction)
|
||||
training_loss_reduction_threshold: float = 0.9
|
||||
training_grad_norm_reduction_threshold: float = 0.9
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def model_tester(self):
|
||||
"""The model tester instance (e.g., CausalLMModelTester)."""
|
||||
...
|
||||
|
||||
# ============================================================
|
||||
# Modality detection
|
||||
# ============================================================
|
||||
def _get_model_modality(self) -> str:
|
||||
"""Detect the modality of the model based on its input signature."""
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
if "input_ids" in inputs_dict:
|
||||
return "text"
|
||||
elif "pixel_values" in inputs_dict:
|
||||
return "image"
|
||||
elif "input_features" in inputs_dict or "input_values" in inputs_dict:
|
||||
return "audio"
|
||||
else:
|
||||
raise ValueError(f"Unknown modality: {inputs_dict}")
|
||||
|
||||
# ============================================================
|
||||
# Training data creation for each modality
|
||||
# ============================================================
|
||||
def _create_text_training_batch(
|
||||
self,
|
||||
batch_size: int,
|
||||
seq_length: int,
|
||||
vocab_size: int,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Create a simple text batch without needing a tokenizer."""
|
||||
# Create a deterministic sequence (not random, so model can learn it)
|
||||
pattern = list(range(1, min(20, vocab_size))) # tokens 1-19
|
||||
num_repeats = (seq_length // len(pattern)) + 1
|
||||
tokens = (pattern * num_repeats)[:seq_length]
|
||||
input_ids = torch.tensor([tokens] * batch_size, dtype=torch.long)
|
||||
return {"input_ids": input_ids, "labels": input_ids.clone()}
|
||||
|
||||
def _create_image_training_batch(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels: int,
|
||||
height: int,
|
||||
width: int,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Create fixed batch for image models using a deterministic pattern."""
|
||||
pass
|
||||
|
||||
def _create_audio_training_batch(
|
||||
self,
|
||||
batch_size: int,
|
||||
audio_length: int,
|
||||
feature_size: int | None = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Create fixed batch for audio models using a deterministic waveform."""
|
||||
pass
|
||||
|
||||
def _decode_text_tokens(self, tokens: list[int], max_display: int = 40) -> str:
|
||||
"""Decode tokens to readable string (maps token IDs to letters: 1->a, 2->b, etc.)."""
|
||||
decoded = "".join(chr(ord("a") + (t - 1) % 26) for t in tokens)
|
||||
if len(decoded) > max_display:
|
||||
return f"'{decoded[:max_display]}...'"
|
||||
return f"'{decoded}'"
|
||||
|
||||
def _get_trainable_model_class(self):
|
||||
"""Get the model class to use for training (prefers *ForCausalLM, *ForSequenceClassification, etc.)."""
|
||||
# Prefer model classes with a head (for computing loss)
|
||||
if hasattr(self.model_tester, "causal_lm_class") and self.model_tester.causal_lm_class is not None:
|
||||
return self.model_tester.causal_lm_class
|
||||
if (
|
||||
hasattr(self.model_tester, "sequence_classification_class")
|
||||
and self.model_tester.sequence_classification_class is not None
|
||||
):
|
||||
return self.model_tester.sequence_classification_class
|
||||
# Fall back to first model class
|
||||
return self.all_model_classes[0]
|
||||
|
||||
@is_training_test
|
||||
def test_training_overfit(self):
|
||||
"""Test that a tiny model can overfit on a fixed batch."""
|
||||
# Initialize logging and memory monitoring
|
||||
init_test_logger()
|
||||
memory_monitor = build_cpu_memory_monitor(logger)
|
||||
|
||||
logger.info("=" * 70)
|
||||
logger.info(f"Starting test: {self._testMethodName}")
|
||||
logger.info("=" * 70)
|
||||
|
||||
# Skip if model doesn't support training
|
||||
if not getattr(self.model_tester, "is_training", True):
|
||||
logger.info(f"{Colors.YELLOW}Skipping: Model tester not configured for training tests{Colors.RESET}")
|
||||
self.skipTest("Model tester not configured for training tests")
|
||||
|
||||
# Configuration
|
||||
logger.info(f"{Colors.BOLD}Job Configuration:{Colors.RESET}")
|
||||
logger.info(f" {Colors.CYAN}total_steps:{Colors.RESET} {self.training_overfit_steps}")
|
||||
logger.info(f" {Colors.CYAN}batch_size:{Colors.RESET} {self.training_overfit_batch_size}")
|
||||
logger.info(f" {Colors.CYAN}learning_rate:{Colors.RESET} {self.training_overfit_learning_rate}")
|
||||
logger.info(f" {Colors.CYAN}seq_length:{Colors.RESET} {self.training_overfit_seq_length}")
|
||||
logger.info(f" {Colors.CYAN}log_freq:{Colors.RESET} {self.training_overfit_log_freq}")
|
||||
logger.info(f" {Colors.CYAN}device:{Colors.RESET} cpu")
|
||||
|
||||
set_seed(42)
|
||||
|
||||
logger.info("-" * 70)
|
||||
logger.info(f"{Colors.BOLD}Building model{Colors.RESET}")
|
||||
load_start = time.perf_counter()
|
||||
|
||||
# Get tiny config from existing infrastructure
|
||||
config = self.model_tester.get_config()
|
||||
|
||||
model_class = self._get_trainable_model_class()
|
||||
model = model_class(config)
|
||||
model.train()
|
||||
|
||||
load_time = time.perf_counter() - load_start
|
||||
logger.info(f"Model loaded in {Colors.GREEN}{load_time:.3f}s{Colors.RESET}")
|
||||
|
||||
# Log model architecture
|
||||
# TODO(3outeille): make sure if there is other parameters to log
|
||||
logger.info(f"{Colors.BOLD}Model Architecture:{Colors.RESET}")
|
||||
logger.info(f" {Colors.CYAN}model_class:{Colors.RESET} {model_class.__name__}")
|
||||
if hasattr(config, "hidden_size"):
|
||||
logger.info(f" {Colors.CYAN}hidden_size:{Colors.RESET} {config.hidden_size}")
|
||||
if hasattr(config, "num_hidden_layers"):
|
||||
logger.info(f" {Colors.CYAN}num_hidden_layers:{Colors.RESET} {config.num_hidden_layers}")
|
||||
if hasattr(config, "num_attention_heads"):
|
||||
logger.info(f" {Colors.CYAN}num_attention_heads:{Colors.RESET} {config.num_attention_heads}")
|
||||
if hasattr(config, "num_key_value_heads"):
|
||||
logger.info(f" {Colors.CYAN}num_key_value_heads:{Colors.RESET} {config.num_key_value_heads}")
|
||||
if hasattr(config, "intermediate_size"):
|
||||
logger.info(f" {Colors.CYAN}intermediate_size:{Colors.RESET} {config.intermediate_size}")
|
||||
if hasattr(config, "vocab_size"):
|
||||
logger.info(f" {Colors.CYAN}vocab_size:{Colors.RESET} {config.vocab_size}")
|
||||
if hasattr(config, "num_experts"):
|
||||
logger.info(f" {Colors.CYAN}num_experts:{Colors.RESET} {config.num_experts}")
|
||||
if hasattr(config, "num_experts_per_tok"):
|
||||
logger.info(f" {Colors.CYAN}num_experts_per_tok:{Colors.RESET} {config.num_experts_per_tok}")
|
||||
|
||||
# Count parameters
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
logger.info(
|
||||
f"{Colors.CYAN}Model size:{Colors.RESET} {Colors.BRIGHT_GREEN}{total_params:,}{Colors.RESET} total parameters"
|
||||
)
|
||||
logger.info(
|
||||
f"{Colors.CYAN}Trainable parameters:{Colors.RESET} {Colors.BRIGHT_GREEN}{trainable_params:,}{Colors.RESET}"
|
||||
)
|
||||
|
||||
# Memory after model load
|
||||
mem_stats = memory_monitor.get_stats()
|
||||
logger.info(
|
||||
f"{Colors.MAGENTA}Memory after model load:{Colors.RESET} {mem_stats.rss_gib:.2f} GiB ({mem_stats.rss_pct:.1f}%)"
|
||||
)
|
||||
|
||||
logger.info("-" * 70)
|
||||
logger.info(f"{Colors.BOLD}Creating fixed batch{Colors.RESET}")
|
||||
|
||||
modality = self._get_model_modality()
|
||||
logger.info(f"{Colors.CYAN}Detected modality:{Colors.RESET} {modality}")
|
||||
_, sample_inputs = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
if modality == "text":
|
||||
# For text models, we need a tokenizer - use a simple one or create fake tokens
|
||||
batch = self._create_text_training_batch(
|
||||
batch_size=self.training_overfit_batch_size,
|
||||
seq_length=self.training_overfit_seq_length,
|
||||
vocab_size=config.vocab_size,
|
||||
)
|
||||
logger.info(f"{Colors.CYAN}Training pattern:{Colors.RESET} Repeating token sequence (1-19)")
|
||||
else:
|
||||
raise ValueError(f"Modality {modality} not supported yet for training overfit")
|
||||
|
||||
tokens_per_batch = self.training_overfit_batch_size * self.training_overfit_seq_length
|
||||
logger.info(f" {Colors.CYAN}batch_size:{Colors.RESET} {self.training_overfit_batch_size}")
|
||||
logger.info(f" {Colors.CYAN}seq_length:{Colors.RESET} {self.training_overfit_seq_length}")
|
||||
logger.info(f" {Colors.CYAN}tokens_per_batch:{Colors.RESET} {tokens_per_batch:,}")
|
||||
logger.info(f"{Colors.DIM}Using same fixed batch every step (deterministic overfitting){Colors.RESET}")
|
||||
|
||||
logger.info("-" * 70)
|
||||
logger.info(f"{Colors.BOLD}Building optimizer{Colors.RESET}")
|
||||
|
||||
optimizer = torch.optim.Adam(
|
||||
model.parameters(), lr=self.training_overfit_learning_rate, weight_decay=0.0, betas=(0.9, 0.999)
|
||||
)
|
||||
logger.info(f"{Colors.CYAN}Optimizer:{Colors.RESET} Adam")
|
||||
logger.info(f" {Colors.CYAN}learning_rate:{Colors.RESET} {self.training_overfit_learning_rate}")
|
||||
logger.info(f" {Colors.CYAN}weight_decay:{Colors.RESET} 0.0")
|
||||
logger.info(f" {Colors.CYAN}betas:{Colors.RESET} (0.9, 0.999)")
|
||||
|
||||
# Training Loop
|
||||
logger.info("-" * 70)
|
||||
logger.info("Training starts at step 1")
|
||||
|
||||
initial_loss = None
|
||||
final_loss = None
|
||||
initial_grad_norm = None
|
||||
final_grad_norm = None
|
||||
training_start = time.perf_counter()
|
||||
memory_monitor.reset_peak_stats()
|
||||
|
||||
for step in range(1, self.training_overfit_steps + 1):
|
||||
step_start = time.perf_counter()
|
||||
|
||||
optimizer.zero_grad()
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
|
||||
if initial_loss is None:
|
||||
initial_loss = loss.item()
|
||||
final_loss = loss.item()
|
||||
|
||||
loss.backward()
|
||||
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
||||
|
||||
if initial_grad_norm is None:
|
||||
initial_grad_norm = grad_norm.item()
|
||||
final_grad_norm = grad_norm.item()
|
||||
|
||||
optimizer.step()
|
||||
|
||||
step_time = time.perf_counter() - step_start
|
||||
|
||||
# Log at frequency
|
||||
if step == 1 or step % self.training_overfit_log_freq == 0 or step == self.training_overfit_steps:
|
||||
tokens_per_sec = tokens_per_batch / step_time
|
||||
mem_stats = memory_monitor.get_stats()
|
||||
logger.info(
|
||||
f"{Colors.CYAN}step:{Colors.RESET} {step} "
|
||||
f"{Colors.GREEN}loss:{Colors.RESET} {loss.item():7.4f} "
|
||||
f"{Colors.YELLOW}grad_norm:{Colors.RESET} {grad_norm.item():6.4f} "
|
||||
f"{Colors.MAGENTA}memory:{Colors.RESET} {mem_stats.rss_gib:.2f}GiB({mem_stats.rss_pct:.1f}%) "
|
||||
f"{Colors.BLUE}tok/s:{Colors.RESET} {tokens_per_sec:,.0f} "
|
||||
f"{Colors.DIM}step_time:{Colors.RESET} {step_time:.3f}s"
|
||||
)
|
||||
|
||||
training_time = time.perf_counter() - training_start
|
||||
|
||||
# Training Summary
|
||||
total_tokens = self.training_overfit_steps * tokens_per_batch
|
||||
logger.info("-" * 70)
|
||||
logger.info(f"{Colors.BOLD}Training completed{Colors.RESET}")
|
||||
logger.info(f"Total training time: {training_time:.2f}s")
|
||||
logger.info(f"Total steps: {self.training_overfit_steps}")
|
||||
logger.info(f"Total tokens seen: {total_tokens:,}")
|
||||
logger.info(f"Average tokens/sec: {total_tokens / training_time:,.0f}")
|
||||
|
||||
# Memory summary
|
||||
mem_stats = memory_monitor.get_stats()
|
||||
logger.info(f"{Colors.BOLD}Memory usage:{Colors.RESET}")
|
||||
logger.info(
|
||||
f" {Colors.CYAN}current_rss:{Colors.RESET} {mem_stats.rss_gib:.2f} GiB ({mem_stats.rss_pct:.1f}%)"
|
||||
)
|
||||
logger.info(
|
||||
f" {Colors.CYAN}peak_rss:{Colors.RESET} {mem_stats.peak_rss_gib:.2f} GiB ({mem_stats.peak_rss_pct:.1f}%)"
|
||||
)
|
||||
logger.info(
|
||||
f" {Colors.CYAN}available:{Colors.RESET} {mem_stats.available_gib:.2f} GiB / {mem_stats.total_gib:.2f} GiB"
|
||||
)
|
||||
|
||||
# Loss analysis
|
||||
loss_reduction = (initial_loss - final_loss) / initial_loss * 100
|
||||
logger.info(f"{Colors.BOLD}Loss metrics:{Colors.RESET}")
|
||||
logger.info(f" {Colors.CYAN}initial_loss:{Colors.RESET} {initial_loss:.4f}")
|
||||
logger.info(f" {Colors.CYAN}final_loss:{Colors.RESET} {final_loss:.4f}")
|
||||
logger.info(f" {Colors.CYAN}loss_reduction:{Colors.RESET} {loss_reduction:.1f}%")
|
||||
|
||||
# Grad norm analysis
|
||||
grad_norm_reduction = (initial_grad_norm - final_grad_norm) / initial_grad_norm * 100
|
||||
logger.info(f"{Colors.BOLD}Grad norm metrics:{Colors.RESET}")
|
||||
logger.info(f" {Colors.CYAN}initial_grad_norm:{Colors.RESET} {initial_grad_norm:.4f}")
|
||||
logger.info(f" {Colors.CYAN}final_grad_norm:{Colors.RESET} {final_grad_norm:.4f}")
|
||||
logger.info(f" {Colors.CYAN}grad_norm_reduction:{Colors.RESET} {grad_norm_reduction:.1f}%")
|
||||
|
||||
# Generation Test (only for text/causal LM models)
|
||||
# TODO(3outeille): handle audio and generate
|
||||
generation_matches = None
|
||||
if modality == "text" and hasattr(model, "generate"):
|
||||
logger.info("-" * 70)
|
||||
logger.info(f"{Colors.BOLD}Testing generation{Colors.RESET}")
|
||||
|
||||
model.eval()
|
||||
|
||||
# Get the expected token sequence (same pattern used in training)
|
||||
expected_tokens = batch["input_ids"][0].tolist()
|
||||
|
||||
# Use first token as prompt
|
||||
prompt_ids = torch.tensor([[expected_tokens[0]]], dtype=torch.long)
|
||||
num_tokens_to_generate = len(expected_tokens) - 1
|
||||
|
||||
logger.info(f"Prompt: {self._decode_text_tokens([expected_tokens[0]])}")
|
||||
|
||||
model_type = getattr(config, "model_type", "")
|
||||
use_cache = model_type == "recurrent_gemma"
|
||||
if use_cache:
|
||||
logger.info("Only RecurrentGemmaModel is using use_cache=True. Other models run with use_cache=False")
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(
|
||||
prompt_ids,
|
||||
max_new_tokens=num_tokens_to_generate,
|
||||
do_sample=False,
|
||||
pad_token_id=config.pad_token_id if hasattr(config, "pad_token_id") else 0,
|
||||
eos_token_id=0,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
generated_tokens = generated_ids[0].tolist()
|
||||
|
||||
# Compare generated tokens with expected tokens
|
||||
generation_matches = generated_tokens == expected_tokens
|
||||
|
||||
# TODO(3outeille): handle audio and image generation
|
||||
if generation_matches:
|
||||
logger.info(f"Expected: {Colors.GREEN}{self._decode_text_tokens(expected_tokens)}{Colors.RESET}")
|
||||
logger.info(f"Generated: {Colors.GREEN}{self._decode_text_tokens(generated_tokens)}{Colors.RESET}")
|
||||
logger.info(f"{Colors.GREEN}✓ Generation matches training sequence!{Colors.RESET}")
|
||||
else:
|
||||
logger.info(f"Expected: {Colors.GREEN}{self._decode_text_tokens(expected_tokens)}{Colors.RESET}")
|
||||
logger.info(f"Generated: {Colors.RED}{self._decode_text_tokens(generated_tokens)}{Colors.RESET}")
|
||||
# Count matching tokens
|
||||
matches = sum(1 for g, e in zip(generated_tokens, expected_tokens) if g == e)
|
||||
logger.info(
|
||||
f"{Colors.YELLOW}✗ Generation mismatch: {matches}/{len(expected_tokens)} tokens match{Colors.RESET}"
|
||||
)
|
||||
|
||||
# Assertions
|
||||
logger.info("-" * 70)
|
||||
logger.info(f"{Colors.BOLD}Running assertions{Colors.RESET}")
|
||||
|
||||
# Assert loss decreased significantly
|
||||
loss_reduction_ratio = (initial_loss - final_loss) / initial_loss
|
||||
self.assertGreater(
|
||||
loss_reduction_ratio,
|
||||
self.training_loss_reduction_threshold,
|
||||
f"Expected loss to decrease by at least {self.training_loss_reduction_threshold * 100:.0f}%, "
|
||||
f"got {loss_reduction:.1f}%",
|
||||
)
|
||||
logger.info(
|
||||
f"{Colors.GREEN}✓ Loss decreased by more than {self.training_loss_reduction_threshold * 100:.0f}%{Colors.RESET}"
|
||||
)
|
||||
|
||||
# Assert grad_norm decreased significantly
|
||||
grad_norm_reduction_ratio = (initial_grad_norm - final_grad_norm) / initial_grad_norm
|
||||
self.assertGreater(
|
||||
grad_norm_reduction_ratio,
|
||||
self.training_grad_norm_reduction_threshold,
|
||||
f"Expected grad_norm to decrease by at least {self.training_grad_norm_reduction_threshold * 100:.0f}%, "
|
||||
f"got {grad_norm_reduction:.1f}%",
|
||||
)
|
||||
logger.info(
|
||||
f"{Colors.GREEN}✓ Grad norm decreased by more than {self.training_grad_norm_reduction_threshold * 100:.0f}%{Colors.RESET}"
|
||||
)
|
||||
|
||||
# Assert generation matches (if applicable)
|
||||
if generation_matches is not None:
|
||||
self.assertTrue(generation_matches, "Expected model to generate the training sequence after overfitting")
|
||||
logger.info(f"{Colors.GREEN}✓ Generated sequence matches training sequence{Colors.RESET}")
|
||||
|
||||
logger.info("=" * 70)
|
||||
logger.info(f"Finished test: {self._testMethodName}")
|
||||
logger.info("=" * 70)
|
||||
Reference in New Issue
Block a user