first commit
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
This commit is contained in:
853
tests/trainer/test_trainer_optimizers.py
Normal file
@@ -0,0 +1,853 @@
# Copyright 2018 the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Trainer optimizer and LR scheduler tests: custom optimizers, LR scheduler kwargs, cosine-with-min-lr,
reduce-on-plateau, Adafactor, bitsandbytes (RMSProp, AdEMAMix), LOMO, GrokAdamW, schedule-free,
GaLore, Apollo, Stable AdamW, Liger kernel, optimizer choice resolution, factory pattern detection,
and model parameter inspection.
"""

import tempfile

import numpy as np
from parameterized import parameterized

from transformers import (
    GPT2Config,
    GPT2LMHeadModel,
    LlamaConfig,
    LlamaForCausalLM,
    Trainer,
    TrainingArguments,
    is_torch_available,
)
from transformers.testing_utils import (
    TestCasePlus,
    require_apollo_torch,
    require_bitsandbytes,
    require_galore_torch,
    require_grokadamw,
    require_lomo,
    require_schedulefree,
    require_torch,
    require_torch_accelerator,
    require_torch_optimi,
)
from transformers.trainer_utils import check_target_module_exists

from .trainer_test_utils import (
    BasicTextGenerationModel,
    RegressionDataset,
    RegressionModel,
    RepeatDataset,
    TorchTracemalloc,
    TrainerIntegrationCommon,
    TstLayer,
    bytes2megabytes,
    get_regression_trainer,
)


if is_torch_available():
    import torch
    from torch import nn

_ATTN_MLP_TARGET_MODULES = [r".*attn.*", r".*mlp.*"]


@require_torch
class TrainerOptimizerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
    def setUp(self):
        super().setUp()
        args = TrainingArguments("..")
        self.n_epochs = args.num_train_epochs
        self.batch_size = args.train_batch_size

    # ---------------------------------------------------------------------------
    # Helpers
    # ---------------------------------------------------------------------------

    def _get_llama_and_dataset(self):
        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
        model = LlamaForCausalLM(config)
        train_dataset = RepeatDataset(torch.randint(0, 100, (128,)))
        return model, train_dataset

    def _get_gpt2_and_dataset(self):
        config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
        model = GPT2LMHeadModel(config)
        train_dataset = RepeatDataset(torch.randint(0, 100, (128,)))
        return model, train_dataset

    def _train_with_llama(self, optim, optim_target_modules=None, **extra_kwargs):
        """Smoke-test: tiny Llama + RepeatDataset with the given optimizer."""
        tiny_llama, train_dataset = self._get_llama_and_dataset()
        kwargs = {"learning_rate": 1e-9, "logging_steps": 5, "optim": optim}
        if optim_target_modules is not None:
            kwargs["optim_target_modules"] = optim_target_modules
        kwargs.update(extra_kwargs)
        args = TrainingArguments(self.get_auto_remove_tmp_dir(), **kwargs)
        trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
        trainer.train()
        return trainer

    def _check_lr_display_without_scheduler(self, optim, optim_target_modules):
        """Verify that LR is correctly reported without an LR scheduler."""
        tiny_llama, train_dataset = self._get_llama_and_dataset()
        learning_rate = 1e-9
        args = TrainingArguments(
            self.get_auto_remove_tmp_dir(),
            learning_rate=learning_rate,
            logging_steps=5,
            optim=optim,
            optim_target_modules=optim_target_modules,
        )
        trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
        trainer.create_optimizer_and_scheduler(num_training_steps=10)
        self.assertEqual(trainer.get_learning_rates(), [learning_rate, learning_rate])

    def _check_lr_display_with_scheduler(self, optim, optim_target_modules, num_train_epochs=2):
        """Verify warmup + cosine LR schedule: increases then decreases."""
        tiny_llama, train_dataset = self._get_llama_and_dataset()
        learning_rate = 2e-4
        num_warmup_steps = 5
        args = TrainingArguments(
            self.get_auto_remove_tmp_dir(),
            num_train_epochs=num_train_epochs,
            learning_rate=learning_rate,
            warmup_steps=num_warmup_steps,
            lr_scheduler_type="cosine",
            logging_steps=1,
            optim=optim,
            optim_target_modules=optim_target_modules,
        )
        trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
        trainer.train()
        logs = trainer.state.log_history[1:-1]

        self.assertTrue(logs[num_warmup_steps - 1]["learning_rate"] == learning_rate)
        self.assertTrue(np.allclose(logs[-1]["learning_rate"], 0, atol=5e-6))

        increasing_lrs = [
            logs[i]["learning_rate"] < logs[i + 1]["learning_rate"]
            for i in range(len(logs))
            if i < num_warmup_steps - 1
        ]
        decreasing_lrs = [
            logs[i]["learning_rate"] > logs[i + 1]["learning_rate"]
            for i in range(len(logs) - 1)
            if i >= num_warmup_steps - 1
        ]

        self.assertTrue(all(increasing_lrs))
        self.assertTrue(all(decreasing_lrs))
        self.assertTrue(len(decreasing_lrs) > len(increasing_lrs))

    # ---------------------------------------------------------------------------
    # Adafactor optimizer test
    # ---------------------------------------------------------------------------

    def test_adafactor_lr_none(self):
        # test the special case where lr=None, since Trainer always requires an lr_scheduler

        from transformers.optimization import Adafactor, AdafactorSchedule

        train_dataset = RegressionDataset()
        with tempfile.TemporaryDirectory() as tmp_dir:
            args = TrainingArguments(tmp_dir)
            model = RegressionModel()
            optimizer = Adafactor(
                model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None
            )
            lr_scheduler = AdafactorSchedule(optimizer)
            trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler))
            trainer.train()

            # Train a default model to compare against
            default_trainer = get_regression_trainer(learning_rate=0.1, output_dir=tmp_dir)
            default_trainer.train()

            self.assertFalse(torch.allclose(trainer.model.a, default_trainer.model.a))
            self.assertFalse(torch.allclose(trainer.model.b, default_trainer.model.b))
            self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)

    # ---------------------------------------------------------------------------
    # BNB optimizer tests
    # ---------------------------------------------------------------------------

    @parameterized.expand(["rmsprop_bnb", "ademamix", "ademamix_8bit", "rmsprop_bnb_8bit", "rmsprop_bnb_32bit"])
    @require_bitsandbytes
    def test_bnb_optim(self, optim):
        tiny_gpt2, train_dataset = self._get_gpt2_and_dataset()
        args = TrainingArguments(
            self.get_auto_remove_tmp_dir(),
            learning_rate=1e-9,
            logging_steps=5,
            logging_nan_inf_filter=False,
            optim=optim,
        )
        Trainer(tiny_gpt2, args, train_dataset=train_dataset).train()

    @require_bitsandbytes
    def test_bnb_8bit_optimizer_skip_embedding(self):
        model = BasicTextGenerationModel(8, 4)
        with tempfile.TemporaryDirectory() as tmp_dir:
            for name_optim in ["rmsprop_bnb_8bit", "adamw_8bit"]:
                args = TrainingArguments(
                    output_dir=tmp_dir,
                    optim=name_optim,
                )
                trainer = Trainer(model=model, args=args)
                optimizer = trainer.create_optimizer()
                modules = optimizer.mng.module_weight_config_triple
                self.assertNotEqual(len(modules), 0)
                module, name, config = modules[0]
                self.assertIsInstance(module, torch.nn.Embedding)
                self.assertEqual(name, "weight")
                self.assertDictEqual(config, {"optim_bits": 32})

    # ---------------------------------------------------------------------------
    # LOMO tests
    # ---------------------------------------------------------------------------

    @require_lomo
    @require_torch_accelerator
    def test_lomo(self):
        tiny_llama, train_dataset = self._get_llama_and_dataset()
        previous_params = {n: p.clone() for n, p in tiny_llama.named_parameters()}

        args = TrainingArguments(
            self.get_auto_remove_tmp_dir(), learning_rate=1e-2, logging_steps=5, optim="lomo", max_steps=20
        )
        Trainer(tiny_llama, args, train_dataset=train_dataset).train()

        for name, param in tiny_llama.named_parameters():
            self.assertFalse(torch.allclose(param, previous_params[name].to(param.device), rtol=1e-12, atol=1e-12))

    @require_lomo
    @require_torch_accelerator
    def test_adalomo(self):
        self._train_with_llama("adalomo")

    # ---------------------------------------------------------------------------
    # GrokAdamW test
    # ---------------------------------------------------------------------------

    @require_grokadamw
    @require_torch_accelerator
    def test_grokadamw(self):
        self._train_with_llama("grokadamw", learning_rate=2e-5, max_steps=20)

    # ---------------------------------------------------------------------------
    # Schedule-free tests
    # ---------------------------------------------------------------------------

    @parameterized.expand([("schedule_free_adamw",), ("schedule_free_radam",)])
    @require_schedulefree
    @require_torch_accelerator
    def test_schedulefree(self, optim):
        self._train_with_llama(optim, lr_scheduler_type="constant")

    # ---------------------------------------------------------------------------
    # GaLore tests
    # ---------------------------------------------------------------------------

    def test_galore_matched_modules(self):
        regex_patterns = [r".*.attn.*", r".*.mlp.*"]

        module_names = [
            "model.transformer.h.0.ln_1",
            "model.transformer.h.0.attn.q_proj",
            "model.lm_head",
            "model.transformer.h.0.mlp.up_proj",
        ]
        expected_values = [False, True, False, True]

        for expected_value, module_name in zip(expected_values, module_names):
            is_module_matched, is_regex = check_target_module_exists(regex_patterns, module_name, return_is_regex=True)
            self.assertTrue(is_module_matched == expected_value)
            if is_module_matched:
                self.assertTrue(is_regex)

        exact_patterns = ["q_proj", "up_proj"]

        module_names = [
            "model.transformer.h.0.ln_1",
            "model.transformer.h.0.attn.q_proj",
            "model.lm_head",
            "model.transformer.h.0.mlp.up_proj",
        ]
        expected_values = [False, True, False, True]

        for expected_value, module_name in zip(expected_values, module_names):
            is_module_matched, is_regex = check_target_module_exists(exact_patterns, module_name, return_is_regex=True)
            self.assertTrue(is_module_matched == expected_value)
            if is_module_matched:
                self.assertFalse(is_regex)

        simple_regex = r".*.attn.*"

        module_names = [
            "model.transformer.h.0.ln_1",
            "model.transformer.h.0.attn.q_proj",
            "model.lm_head",
            "model.transformer.h.0.mlp.up_proj",
        ]
        expected_values = [False, True, False, False]

        for expected_value, module_name in zip(expected_values, module_names):
            is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name, return_is_regex=True)
            self.assertTrue(is_module_matched == expected_value)
            if is_module_matched:
                self.assertTrue(is_regex)

        simple_regex = "model.transformer.h.0.attn.q_proj"

        module_names = [
            "model.transformer.h.0.ln_1",
            "model.transformer.h.0.attn.q_proj",
            "model.lm_head",
            "model.transformer.h.0.mlp.up_proj",
        ]
        expected_values = [False, True, False, False]

        for expected_value, module_name in zip(expected_values, module_names):
            is_module_matched, is_regex = check_target_module_exists(simple_regex, module_name, return_is_regex=True)
            self.assertTrue(is_module_matched == expected_value)
            if is_module_matched:
                self.assertFalse(is_regex)

        target_modules = ["attn", "mlp"]

        module_names = [
            "model.transformer.h.0.ln_1",
            "model.transformer.h.0.attn.q_proj",
            "model.lm_head",
            "model.transformer.h.0.mlp.up_proj",
        ]
        expected_values = [False, True, False, True]

        for expected_value, module_name in zip(expected_values, module_names):
            is_module_matched, is_regex = check_target_module_exists(target_modules, module_name, return_is_regex=True)
            self.assertTrue(is_module_matched == expected_value)
            if is_module_matched:
                self.assertFalse(is_regex)

    @parameterized.expand([("galore_adamw",), ("galore_adamw_layerwise",), ("galore_adamw_8bit",)])
    @require_galore_torch
    @require_torch_accelerator
    def test_galore(self, optim):
        self._train_with_llama(optim, optim_target_modules=_ATTN_MLP_TARGET_MODULES)

    @require_galore_torch
    @require_torch_accelerator
    def test_galore_extra_args(self):
        self._train_with_llama(
            "galore_adamw",
            optim_target_modules=_ATTN_MLP_TARGET_MODULES,
            optim_args="rank=64, update_proj_gap=100, scale=0.10",
        )

    @require_galore_torch
    @require_torch_accelerator
    def test_galore_layerwise_with_scheduler(self):
        self._train_with_llama(
            "galore_adamw_layerwise",
            optim_target_modules=_ATTN_MLP_TARGET_MODULES,
            lr_scheduler_type="cosine",
        )

    @parameterized.expand(
        [
            (_ATTN_MLP_TARGET_MODULES,),
            (["q_proj", "k_proj", "v_proj"],),
            ("all-linear",),
        ]
    )
    @require_galore_torch
    @require_torch_accelerator
    def test_galore_adafactor(self, optim_target_modules):
        upper_bound_pm = 700
        lower_bound_pm = 650
        tiny_llama, train_dataset = self._get_llama_and_dataset()

        with tempfile.TemporaryDirectory() as tmpdir, TorchTracemalloc() as tracemalloc:
            args = TrainingArguments(
                tmpdir,
                learning_rate=1e-9,
                logging_steps=5,
                optim="galore_adafactor",
                optim_target_modules=optim_target_modules,
            )
            Trainer(tiny_llama, args, train_dataset=train_dataset).train()

        galore_peak_memory = tracemalloc.peaked + bytes2megabytes(tracemalloc.begin)
        self.assertTrue(galore_peak_memory < upper_bound_pm)
        self.assertTrue(lower_bound_pm < galore_peak_memory)

    @require_galore_torch
    @require_torch_accelerator
    def test_galore_lr_display_without_scheduler(self):
        self._check_lr_display_without_scheduler("galore_adamw", _ATTN_MLP_TARGET_MODULES)

    @require_galore_torch
    @require_torch_accelerator
    def test_galore_lr_display_with_scheduler(self):
        self._check_lr_display_with_scheduler("galore_adamw", _ATTN_MLP_TARGET_MODULES)

    # ---------------------------------------------------------------------------
    # Apollo tests
    # ---------------------------------------------------------------------------

    @parameterized.expand([("apollo_adamw",), ("apollo_adamw_layerwise",)])
    @require_apollo_torch
    @require_torch_accelerator
    def test_apollo(self, optim):
        self._train_with_llama(optim, optim_target_modules=_ATTN_MLP_TARGET_MODULES)

    @require_apollo_torch
    @require_torch_accelerator
    def test_apollo_extra_args(self):
        self._train_with_llama(
            "apollo_adamw",
            optim_target_modules=_ATTN_MLP_TARGET_MODULES,
            optim_args="proj=random,scale_type=tensor,rank=1,update_proj_gap=100,scale=128.0",
        )

    @require_apollo_torch
    @require_torch_accelerator
    def test_apollo_layerwise_with_scheduler(self):
        self._train_with_llama(
            "apollo_adamw_layerwise",
            optim_target_modules=_ATTN_MLP_TARGET_MODULES,
            lr_scheduler_type="cosine",
        )

    @require_apollo_torch
    @require_torch_accelerator
    def test_apollo_lr_display_without_scheduler(self):
        self._check_lr_display_without_scheduler("apollo_adamw", _ATTN_MLP_TARGET_MODULES)

    @require_apollo_torch
    @require_torch_accelerator
    def test_apollo_lr_display_with_scheduler(self):
        self._check_lr_display_with_scheduler("apollo_adamw", _ATTN_MLP_TARGET_MODULES, num_train_epochs=10)

    # ---------------------------------------------------------------------------
    # Stable AdamW tests
    # ---------------------------------------------------------------------------

    @require_torch_optimi
    @require_torch_accelerator
    def test_stable_adamw(self):
        self._train_with_llama("stable_adamw", optim_target_modules=_ATTN_MLP_TARGET_MODULES)

    @require_torch_optimi
    @require_torch_accelerator
    def test_stable_adamw_extra_args(self):
        self._train_with_llama(
            "stable_adamw",
            optim_target_modules=_ATTN_MLP_TARGET_MODULES,
            optim_args="decouple_lr=True,max_lr=1e-3,kahan_sum=True",
        )

    @require_torch_optimi
    @require_torch_accelerator
    def test_stable_adamw_trainer_adamw_args(self):
        tiny_llama, train_dataset = self._get_llama_and_dataset()
        args = TrainingArguments(
            self.get_auto_remove_tmp_dir(),
            learning_rate=1e-9,
            logging_steps=5,
            weight_decay=0.001,
            adam_beta1=0.89,
            adam_beta2=0.98,
            adam_epsilon=1e-8,
            optim="stable_adamw",
            optim_target_modules=_ATTN_MLP_TARGET_MODULES,
        )
        trainer = Trainer(tiny_llama, args, train_dataset=train_dataset)
        trainer.create_optimizer_and_scheduler(num_training_steps=10)

        # check StableAdamW optimizer is created with the correct parameters
        self.assertEqual(trainer.optimizer.defaults["beta1"], args.adam_beta1)
        self.assertEqual(trainer.optimizer.defaults["beta2"], args.adam_beta2)
        self.assertEqual(trainer.optimizer.defaults["eps"], args.adam_epsilon)
        self.assertEqual(trainer.optimizer.defaults["weight_decay"], args.weight_decay)

    @require_torch_optimi
    @require_torch_accelerator
    def test_stable_adamw_lr_display_without_scheduler(self):
        self._check_lr_display_without_scheduler("stable_adamw", _ATTN_MLP_TARGET_MODULES)

    @require_torch_optimi
    @require_torch_accelerator
    def test_stable_adamw_lr_display_with_scheduler(self):
        self._check_lr_display_with_scheduler("stable_adamw", _ATTN_MLP_TARGET_MODULES, num_train_epochs=10)

    # ---------------------------------------------------------------------------
    # Misc optimizer tests
    # ---------------------------------------------------------------------------

    def test_optimizer_factory_pattern(self):
        """Test that is_optimizer_factory correctly identifies factory classes vs optimizer classes."""
        from transformers.trainer_optimizer import is_optimizer_factory

        # Create a mock optimizer class
        class MockComplexOptimizer(torch.optim.Optimizer):
            def __init__(self, params, lr=1e-3):
                defaults = {"lr": lr}
                super().__init__(params, defaults)

            def step(self, closure=None):
                pass

        # Create a factory class (simulates Muon/Dion pattern)
        class MockOptimizerFactory:
            def __call__(self, opt_model, **optimizer_kwargs):
                all_params = list(opt_model.parameters())
                return MockComplexOptimizer(all_params, **optimizer_kwargs)

        # Verify is_optimizer_factory correctly identifies factories vs optimizer classes
        self.assertFalse(is_optimizer_factory(MockComplexOptimizer))  # Optimizer class should return False
        self.assertTrue(is_optimizer_factory(MockOptimizerFactory))  # Factory class should return True

    # ---------------------------------------------------------------------------
    # Optimizer group and learning rate inspection tests
    # ---------------------------------------------------------------------------

    def test_get_optimizer_group(self):
        model = nn.Sequential(nn.Linear(128, 64))
        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer = Trainer(model=model, args=TrainingArguments(output_dir=tmp_dir))
            # ValueError is raised if optimizer is None
            with self.assertRaises(ValueError):
                trainer.get_optimizer_group()
            trainer.create_optimizer()
            # Get groups
            num_groups = len(trainer.get_optimizer_group())
            self.assertEqual(num_groups, 2)
            # Get group of parameter
            param = next(model.parameters())
            group = trainer.get_optimizer_group(param)
            self.assertIn(param, group["params"])
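

# Illustrative sketch, not part of the test suite: the matching behaviour that
# test_galore_matched_modules expects from check_target_module_exists can be
# approximated as below. The helper name `match_target_module` and its
# precedence rules are assumptions inferred from the assertions in that test;
# the real implementation lives in transformers.trainer_utils and may differ.

```python
import re


def match_target_module(targets, key):
    """Return (matched, is_regex) for a module name `key` (hypothetical helper)."""
    if isinstance(targets, str):
        # A single string is treated as a regex over the full module name;
        # it only counts as "regex" when it is not the literal name itself.
        return bool(re.fullmatch(targets, key)), targets != key
    if key in targets:
        return True, False  # exact module name listed
    if any(t in key for t in targets):
        return True, False  # plain substring such as "attn" or "q_proj"
    if any(re.fullmatch(t, key) for t in targets):
        return True, True  # at least one regex pattern matches
    return False, False


names = [
    "model.transformer.h.0.ln_1",
    "model.transformer.h.0.attn.q_proj",
    "model.lm_head",
    "model.transformer.h.0.mlp.up_proj",
]
regex_hits = [match_target_module([r".*.attn.*", r".*.mlp.*"], n) for n in names]
substring_hits = [match_target_module(["attn", "mlp"], n) for n in names]
```

# With these inputs the attention and MLP projections match in both lists, and
# only the regex list reports is_regex=True for them.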


# ---------------------------------------------------------------------------
# Custom optimizer and LR scheduler tests
# ---------------------------------------------------------------------------


class TrainerOptimizerTest(TestCasePlus):
    def test_get_optimizer_group(self):
        model = nn.Sequential(nn.Linear(128, 64))
        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer = Trainer(model=model, args=TrainingArguments(output_dir=tmp_dir))
            # ValueError is raised if optimizer is None
            with self.assertRaises(ValueError):
                trainer.get_optimizer_group()
            trainer.create_optimizer()
            # Get groups
            num_groups = len(trainer.get_optimizer_group())
            self.assertEqual(num_groups, 2)
            # Get group of parameter
            param = next(model.parameters())
            group = trainer.get_optimizer_group(param)
            self.assertIn(param, group["params"])

    def test_optimizer_factory_pattern(self):
        """Test that is_optimizer_factory correctly identifies factory classes vs optimizer classes."""
        from transformers.trainer_optimizer import is_optimizer_factory

        # Create a mock optimizer class
        class MockComplexOptimizer(torch.optim.Optimizer):
            def __init__(self, params, lr=1e-3):
                defaults = {"lr": lr}
                super().__init__(params, defaults)

            def step(self, closure=None):
                pass

        # Create a factory class (simulates Muon/Dion pattern)
        class MockOptimizerFactory:
            def __call__(self, opt_model, **optimizer_kwargs):
                all_params = list(opt_model.parameters())
                return MockComplexOptimizer(all_params, **optimizer_kwargs)

        # Verify is_optimizer_factory correctly identifies factories vs optimizer classes
        self.assertFalse(is_optimizer_factory(MockComplexOptimizer))  # Optimizer class should return False
        self.assertTrue(is_optimizer_factory(MockOptimizerFactory))  # Factory class should return True

    def test_custom_optimizer(self):
        train_dataset = RegressionDataset()
        with tempfile.TemporaryDirectory() as tmp_dir:
            args = TrainingArguments(tmp_dir)
            model = RegressionModel()
            optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
            lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: 1.0)
            trainer = Trainer(model, args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler))
            trainer.train()

            # Train a default model to compare against
            default_trainer = get_regression_trainer(learning_rate=0.1, output_dir=tmp_dir)
            default_trainer.train()

            self.assertFalse(torch.allclose(trainer.model.a, default_trainer.model.a))
            self.assertFalse(torch.allclose(trainer.model.b, default_trainer.model.b))
            self.assertEqual(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 1.0)

    # ---------------------------------------------------------------------------
    # Weight decay parameter groups
    # ---------------------------------------------------------------------------

    def test_no_wd_param_group(self):
        model = nn.Sequential(TstLayer(128), nn.ModuleList([TstLayer(128), TstLayer(128)]))
        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer = Trainer(model=model, args=TrainingArguments(output_dir=tmp_dir))
            trainer.create_optimizer_and_scheduler(10)
            wd_names = ['0.linear1.weight', '0.linear2.weight', '1.0.linear1.weight', '1.0.linear2.weight', '1.1.linear1.weight', '1.1.linear2.weight']  # fmt: skip
            wd_params = [p for n, p in model.named_parameters() if n in wd_names]
            no_wd_params = [p for n, p in model.named_parameters() if n not in wd_names]
            self.assertListEqual(trainer.optimizer.param_groups[0]["params"], wd_params)
            self.assertListEqual(trainer.optimizer.param_groups[1]["params"], no_wd_params)
|
||||
|
||||
|
||||
@require_torch
class TrainerLRTest(TestCasePlus):
    def test_get_learning_rates(self):
        model = nn.Sequential(nn.Linear(128, 64))
        with tempfile.TemporaryDirectory() as tmp_dir:
            trainer = Trainer(model=model, args=TrainingArguments(output_dir=tmp_dir))
            with self.assertRaises(ValueError):
                trainer.get_learning_rates()
            trainer.create_optimizer()
            self.assertEqual(trainer.get_learning_rates(), [5e-05, 5e-05])

    def test_lr_scheduler_kwargs(self):
        from transformers import get_polynomial_decay_schedule_with_warmup

        # test scheduler kwargs passed via TrainingArguments
        train_dataset = RegressionDataset()
        model = RegressionModel()
        num_steps, num_warmup_steps = 10, 2
        extra_kwargs = {"power": 5.0, "lr_end": 1e-5}  # Non-default arguments
        with tempfile.TemporaryDirectory() as tmp_dir:
            args = TrainingArguments(
                tmp_dir,
                lr_scheduler_type="polynomial",
                lr_scheduler_kwargs=extra_kwargs,
                learning_rate=0.2,
                warmup_steps=num_warmup_steps,
            )
            trainer = Trainer(model, args, train_dataset=train_dataset)
            trainer.create_optimizer_and_scheduler(num_training_steps=num_steps)

            # Checking that the scheduler was created
            self.assertIsNotNone(trainer.lr_scheduler)

            # Checking that the correct args were passed
            sched1 = trainer.lr_scheduler
            sched2 = get_polynomial_decay_schedule_with_warmup(
                trainer.optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_steps, **extra_kwargs
            )
            self.assertEqual(sched1.lr_lambdas[0].args, sched2.lr_lambdas[0].args)
            self.assertEqual(sched1.lr_lambdas[0].keywords, sched2.lr_lambdas[0].keywords)

    def test_cosine_with_min_lr_scheduler(self):
        train_dataset = RegressionDataset()
        model = RegressionModel()
        num_steps, num_warmup_steps = 10, 2
        extra_kwargs = {"min_lr": 1e-5}  # Non-default arguments
        with tempfile.TemporaryDirectory() as tmp_dir:
            args = TrainingArguments(
                tmp_dir,
                lr_scheduler_type="cosine_with_min_lr",
                lr_scheduler_kwargs=extra_kwargs,
                learning_rate=0.2,
                warmup_steps=num_warmup_steps,
            )
            trainer = Trainer(model, args, train_dataset=train_dataset)
            trainer.create_optimizer_and_scheduler(num_training_steps=num_steps)

            # Checking that the scheduler was created
            self.assertIsNotNone(trainer.lr_scheduler)

            # Check the last learning rate
            for _ in range(num_steps):
                trainer.lr_scheduler.step()
            self.assertEqual(trainer.lr_scheduler.get_last_lr()[0], 1e-5)

    def test_cosine_with_min_lr_schedule_with_warmup_lr_rate(self):
        train_dataset = RegressionDataset()
        model = RegressionModel()
        num_steps, num_warmup_steps = 10, 2
        extra_kwargs = {"min_lr": 1e-5}  # Non-default arguments
        args = TrainingArguments(
            "./regression",
            lr_scheduler_type="cosine_warmup_with_min_lr",
            lr_scheduler_kwargs=extra_kwargs,
            learning_rate=0.2,
            warmup_steps=num_warmup_steps,
        )
        trainer = Trainer(model, args, train_dataset=train_dataset)
        trainer.create_optimizer_and_scheduler(num_training_steps=num_steps)

        # Checking that the scheduler was created
        self.assertIsNotNone(trainer.lr_scheduler)

        # Check the last learning rate
        step_lrs = []
        for _ in range(num_steps):
            step_lrs.append(trainer.optimizer.param_groups[0]["lr"])
            trainer.lr_scheduler.step()
        self.assertEqual(step_lrs[0], 0.1)
        self.assertEqual(step_lrs[1], 0.2)
        self.assertEqual(step_lrs[-1], 1e-05)

    def test_reduce_lr_on_plateau_args(self):
        # test passed arguments for a custom ReduceLROnPlateau scheduler
        train_dataset = RegressionDataset(length=64)
        eval_dataset = RegressionDataset(length=64)
        with tempfile.TemporaryDirectory() as tmp_dir:
            args = TrainingArguments(
                tmp_dir,
                eval_strategy="epoch",
                metric_for_best_model="eval_loss",
            )
            model = RegressionModel()
            optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
            lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.2, patience=5, cooldown=2)
            trainer = Trainer(
                model,
                args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                optimizers=(optimizer, lr_scheduler),
            )
            trainer.train()

            self.assertIsInstance(trainer.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
            self.assertEqual(trainer.lr_scheduler.factor, 0.2)
            self.assertEqual(trainer.lr_scheduler.patience, 5)
            self.assertEqual(trainer.lr_scheduler.cooldown, 2)

    def test_reduce_lr_on_plateau(self):
        # test the ReduceLROnPlateau scheduler

        class TrainerWithLRLogs(Trainer):
            def log(self, logs):
                # the LR is computed after metrics and does not exist for the first epoch
                if hasattr(self.lr_scheduler, "_last_lr"):
                    logs["learning_rate"] = self.lr_scheduler._last_lr[0]
                super().log(logs)

        train_dataset = RegressionDataset(length=64)
        eval_dataset = RegressionDataset(length=64)

        with tempfile.TemporaryDirectory() as tmp_dir:
            args = TrainingArguments(
                tmp_dir,
                lr_scheduler_type="reduce_lr_on_plateau",
                eval_strategy="epoch",
                metric_for_best_model="eval_loss",
                num_train_epochs=10,
                learning_rate=0.2,
            )
            model = RegressionModel()
            trainer = TrainerWithLRLogs(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
            trainer.train()

            self.assertIsInstance(trainer.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
            patience = trainer.lr_scheduler.patience

            logs = trainer.state.log_history[1:]
            best_loss = logs[0]["eval_loss"]
            bad_epochs = 0
            for i, log in enumerate(logs[:-1]):  # Compare learning rate to next epoch's
                loss = log["eval_loss"]
                just_decreased = False
                if loss > best_loss:
                    bad_epochs += 1
                    if bad_epochs > patience:
                        self.assertLess(logs[i + 1]["learning_rate"], log["learning_rate"])
                        just_decreased = True
                        bad_epochs = 0
                else:
                    best_loss = loss
                    bad_epochs = 0
                if not just_decreased:
                    self.assertEqual(logs[i + 1]["learning_rate"], log["learning_rate"])

    def test_greedy_lr_args(self):
        # test passed arguments for a custom GreedyLR scheduler
        from transformers.optimization import GreedyLR

        train_dataset = RegressionDataset(length=64)
        eval_dataset = RegressionDataset(length=64)
        with tempfile.TemporaryDirectory() as tmp_dir:
            args = TrainingArguments(
                tmp_dir,
                eval_strategy="epoch",
                metric_for_best_model="eval_loss",
            )
            model = RegressionModel()
            optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
            lr_scheduler = GreedyLR(optimizer, factor=0.8, patience=5, cooldown=2)
            trainer = Trainer(
                model,
                args,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                optimizers=(optimizer, lr_scheduler),
            )
            trainer.train()

            self.assertIsInstance(trainer.lr_scheduler, GreedyLR)
            self.assertEqual(trainer.lr_scheduler.factor, 0.8)
            self.assertEqual(trainer.lr_scheduler.patience, 5)
            self.assertEqual(trainer.lr_scheduler.cooldown, 2)

    def test_greedy_lr(self):
        # test the GreedyLR scheduler
        from transformers.optimization import GreedyLR

        class TrainerWithLRLogs(Trainer):
            def log(self, logs):
                if hasattr(self.lr_scheduler, "_last_lr"):
                    logs["learning_rate"] = self.lr_scheduler._last_lr[0]
                super().log(logs)

        train_dataset = RegressionDataset(length=64)
        eval_dataset = RegressionDataset(length=64)

        with tempfile.TemporaryDirectory() as tmp_dir:
            args = TrainingArguments(
                tmp_dir,
                lr_scheduler_type="greedy",
                lr_scheduler_kwargs={"patience": 1, "factor": 0.5},
                eval_strategy="epoch",
                metric_for_best_model="eval_loss",
                num_train_epochs=10,
                learning_rate=0.2,
            )
            model = RegressionModel()
            trainer = TrainerWithLRLogs(model, args, train_dataset=train_dataset, eval_dataset=eval_dataset)
            trainer.train()

            self.assertIsInstance(trainer.lr_scheduler, GreedyLR)
            # Verify LR was adjusted at least once during training
            logs = trainer.state.log_history[1:]
            lr_values = [log["learning_rate"] for log in logs if "learning_rate" in log]
            self.assertTrue(len(set(lr_values)) > 1, "GreedyLR should have adjusted the LR at least once")