first commit
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
0
tests/optimization/__init__.py
Normal file
417
tests/optimization/test_greedy_lr.py
Normal file
@@ -0,0 +1,417 @@
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch


if is_torch_available():
    import torch
    from torch import nn

    from transformers.optimization import GreedyLR, StreamingAverage, get_greedy_schedule


@require_torch
class GreedyLRTest(unittest.TestCase):
    def _get_scheduler(self, **kwargs):
        model = nn.Linear(10, 10)
        defaults = {"lr": 0.1}
        defaults.update(kwargs.pop("optim_kwargs", {}))
        optimizer = torch.optim.SGD(model.parameters(), **defaults)
        scheduler_kwargs = {
            "mode": "min",
            "factor": 0.9,
            "patience": 3,
            "min_lr": 1e-6,
            "max_lr": 1.0,
            "cooldown": 0,
            "warmup": 0,
            "verbose": False,
        }
        scheduler_kwargs.update(kwargs)
        scheduler = GreedyLR(optimizer, **scheduler_kwargs)
        return optimizer, scheduler

    def test_initialization_valid_params(self):
        optimizer, scheduler = self._get_scheduler()
        self.assertEqual(scheduler.mode, "min")
        self.assertAlmostEqual(scheduler.factor, 0.9)
        self.assertEqual(scheduler.patience, 3)
        self.assertEqual(len(scheduler.min_lrs), len(optimizer.param_groups))
        self.assertEqual(len(scheduler.max_lrs), len(optimizer.param_groups))
        self.assertAlmostEqual(scheduler._last_lr[0], 0.1)

    def test_initialization_max_mode(self):
        optimizer, scheduler = self._get_scheduler(mode="max")
        self.assertEqual(scheduler.mode, "max")
        self.assertEqual(scheduler.best, float("-inf"))

    def test_initialization_invalid_factor(self):
        with self.assertRaises(ValueError):
            self._get_scheduler(factor=1.0)
        with self.assertRaises(ValueError):
            self._get_scheduler(factor=1.5)

    def test_initialization_invalid_mode(self):
        with self.assertRaises(ValueError):
            self._get_scheduler(mode="unknown")

    def test_initialization_invalid_threshold_mode(self):
        with self.assertRaises(ValueError):
            self._get_scheduler(threshold_mode="unknown")

    def test_initialization_not_optimizer(self):
        with self.assertRaises(TypeError):
            GreedyLR("not_an_optimizer")

    def test_lr_decrease_on_plateau(self):
        optimizer, scheduler = self._get_scheduler(patience=3)
        initial_lr = optimizer.param_groups[0]["lr"]

        # Establish a best metric
        scheduler.step(5.0)
        # Provide worse metrics for patience + 1 steps to trigger decrease
        for _ in range(4):
            scheduler.step(10.0)

        self.assertLess(optimizer.param_groups[0]["lr"], initial_lr)
        expected_lr = initial_lr * 0.9
        self.assertAlmostEqual(optimizer.param_groups[0]["lr"], expected_lr, places=7)

    def test_lr_increase_on_improvement(self):
        optimizer, scheduler = self._get_scheduler(patience=3)
        initial_lr = optimizer.param_groups[0]["lr"]

        # Provide continuously improving metrics for patience + 1 steps
        metric = 10.0
        for _ in range(4):
            metric *= 0.8
            scheduler.step(metric)

        self.assertGreater(optimizer.param_groups[0]["lr"], initial_lr)
        expected_lr = initial_lr / 0.9
        self.assertAlmostEqual(optimizer.param_groups[0]["lr"], expected_lr, places=7)

    def test_lr_never_below_min_lr(self):
        optimizer, scheduler = self._get_scheduler(patience=1, min_lr=0.01, factor=0.5)

        # Establish a best metric, then plateau repeatedly
        scheduler.step(1.0)
        for _ in range(50):
            scheduler.step(10.0)

        self.assertGreaterEqual(optimizer.param_groups[0]["lr"], 0.01 - 1e-10)

    def test_lr_never_above_max_lr(self):
        optimizer, scheduler = self._get_scheduler(patience=1, max_lr=0.2, factor=0.5, optim_kwargs={"lr": 0.1})

        # Provide continuously improving metrics
        metric = 10.0
        for _ in range(50):
            metric *= 0.8
            scheduler.step(metric)

        self.assertLessEqual(optimizer.param_groups[0]["lr"], 0.2 + 1e-10)

    def test_cooldown_prevents_further_reduction(self):
        optimizer, scheduler = self._get_scheduler(patience=2, cooldown=3)

        # Trigger a reduction
        scheduler.step(5.0)
        for _ in range(3):
            scheduler.step(10.0)

        lr_after_reduction = optimizer.param_groups[0]["lr"]

        # During cooldown, more bad metrics should NOT trigger another reduction
        for _ in range(3):
            scheduler.step(10.0)
            self.assertAlmostEqual(optimizer.param_groups[0]["lr"], lr_after_reduction, places=7)

    def test_warmup_prevents_further_increase(self):
        optimizer, scheduler = self._get_scheduler(patience=2, warmup=3)

        # Trigger an increase
        metric = 10.0
        for _ in range(3):
            metric *= 0.8
            scheduler.step(metric)

        lr_after_increase = optimizer.param_groups[0]["lr"]

        # During warmup, more good metrics should NOT trigger another increase
        for _ in range(3):
            metric *= 0.8
            scheduler.step(metric)
            self.assertAlmostEqual(optimizer.param_groups[0]["lr"], lr_after_increase, places=7)

    def test_smoothing_uses_streaming_average(self):
        optimizer, scheduler = self._get_scheduler(smooth=True, window_size=3, patience=10)

        self.assertIsNotNone(scheduler._streaming_avg)
        self.assertEqual(scheduler._streaming_avg.window_size, 3)

        scheduler.step(10.0)
        scheduler.step(8.0)
        scheduler.step(6.0)
        scheduler.step(4.0)

        # Window should be capped at size 3
        self.assertEqual(len(scheduler._streaming_avg.values), 3)
        # After 4 values with window 3, values are [8.0, 6.0, 4.0], avg = 6.0
        avg = scheduler._streaming_avg.sum / len(scheduler._streaming_avg.values)
        self.assertAlmostEqual(avg, 6.0, places=5)

    def test_no_smoothing_by_default(self):
        _, scheduler = self._get_scheduler()
        self.assertIsNone(scheduler._streaming_avg)

    def test_state_dict_round_trip(self):
        optimizer1, scheduler1 = self._get_scheduler(smooth=True, window_size=5)

        # Build up state
        metrics = [10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 6.0, 7.0]
        for m in metrics:
            scheduler1.step(m)

        state = scheduler1.state_dict()

        # Create a new scheduler and load state
        optimizer2, scheduler2 = self._get_scheduler(smooth=True, window_size=5)
        scheduler2.load_state_dict(state)

        self.assertEqual(scheduler2.best, scheduler1.best)
        self.assertEqual(scheduler2.num_bad_epochs, scheduler1.num_bad_epochs)
        self.assertEqual(scheduler2.num_good_epochs, scheduler1.num_good_epochs)
        self.assertEqual(scheduler2.last_epoch, scheduler1.last_epoch)
        self.assertEqual(scheduler2.cooldown_counter, scheduler1.cooldown_counter)
        self.assertEqual(scheduler2.warmup_counter, scheduler1.warmup_counter)
        self.assertAlmostEqual(optimizer2.param_groups[0]["lr"], optimizer1.param_groups[0]["lr"], places=7)

        # Both schedulers should behave identically going forward
        for m in [5.0, 4.0, 3.0]:
            scheduler1.step(m)
            scheduler2.step(m)
            self.assertAlmostEqual(optimizer1.param_groups[0]["lr"], optimizer2.param_groups[0]["lr"], places=7)

    def test_state_dict_contains_all_keys(self):
        _, scheduler = self._get_scheduler(smooth=True)
        scheduler.step(10.0)
        state = scheduler.state_dict()

        required_keys = [
            "factor",
            "min_lrs",
            "max_lrs",
            "patience",
            "verbose",
            "cooldown",
            "warmup",
            "cooldown_counter",
            "warmup_counter",
            "mode",
            "threshold",
            "threshold_mode",
            "best",
            "num_bad_epochs",
            "num_good_epochs",
            "eps",
            "last_epoch",
            "smooth",
            "window_size",
            "reset_start",
            "reset_start_original",
            "_last_lr",
            "_init_lrs",
            "_streaming_avg",
        ]
        for key in required_keys:
            self.assertIn(key, state)

    def test_load_state_dict_backward_compatibility(self):
        _, scheduler = self._get_scheduler()

        partial_state = {
            "factor": 0.8,
            "patience": 5,
            "best": 5.0,
            "num_bad_epochs": 3,
        }
        scheduler.load_state_dict(partial_state)

        self.assertAlmostEqual(scheduler.factor, 0.8)
        self.assertEqual(scheduler.patience, 5)
        self.assertAlmostEqual(scheduler.best, 5.0)
        self.assertEqual(scheduler.num_bad_epochs, 3)
        # Missing keys should retain defaults
        self.assertEqual(scheduler.cooldown_counter, 0)

    def test_factory_function(self):
        model = nn.Linear(10, 10)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        scheduler = get_greedy_schedule(optimizer, patience=5, min_lr=1e-5, factor=0.95)

        self.assertIsInstance(scheduler, GreedyLR)
        self.assertEqual(scheduler.patience, 5)
        self.assertAlmostEqual(scheduler.factor, 0.95)

    def test_factory_function_with_kwargs(self):
        model = nn.Linear(10, 10)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        scheduler = get_greedy_schedule(optimizer, mode="max", smooth=True, window_size=10)

        self.assertIsInstance(scheduler, GreedyLR)
        self.assertEqual(scheduler.mode, "max")
        self.assertTrue(scheduler.smooth)
        self.assertIsNotNone(scheduler._streaming_avg)

    def test_get_scheduler_integration(self):
        from transformers import get_scheduler

        model = nn.Linear(10, 10)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        scheduler = get_scheduler(
            "greedy",
            optimizer=optimizer,
            scheduler_specific_kwargs={"patience": 5, "factor": 0.95},
        )
        self.assertIsInstance(scheduler, GreedyLR)
        self.assertEqual(scheduler.patience, 5)

    def test_get_last_lr(self):
        optimizer, scheduler = self._get_scheduler()
        scheduler.step(10.0)
        last_lr = scheduler.get_last_lr()
        self.assertIsInstance(last_lr, list)
        self.assertEqual(len(last_lr), len(optimizer.param_groups))

    def test_reset_at_min_lr(self):
        optimizer, scheduler = self._get_scheduler(patience=1, min_lr=0.01, factor=0.5, reset_start=2)
        initial_lr = optimizer.param_groups[0]["lr"]

        # Drive LR to min_lr
        scheduler.step(1.0)
        for _ in range(100):
            scheduler.step(10.0)

        # After reset, LR should be back to initial
        self.assertAlmostEqual(optimizer.param_groups[0]["lr"], initial_lr, places=7)

    def test_max_mode_lr_decrease(self):
        optimizer, scheduler = self._get_scheduler(mode="max", patience=2)
        initial_lr = optimizer.param_groups[0]["lr"]

        # Establish best, then provide worse (lower) metrics
        scheduler.step(10.0)
        for _ in range(3):
            scheduler.step(1.0)

        self.assertLess(optimizer.param_groups[0]["lr"], initial_lr)

    def test_max_mode_lr_increase(self):
        optimizer, scheduler = self._get_scheduler(mode="max", patience=2)
        initial_lr = optimizer.param_groups[0]["lr"]

        # Provide continuously improving (higher) metrics
        metric = 1.0
        for _ in range(3):
            metric *= 1.5
            scheduler.step(metric)

        self.assertGreater(optimizer.param_groups[0]["lr"], initial_lr)

    def test_relative_threshold_mode(self):
        optimizer, scheduler = self._get_scheduler(threshold_mode="rel", threshold=0.1, patience=2)

        # Best is 10.0. With rel threshold 0.1, improvement needs current < 10.0 * 0.9 = 9.0
        scheduler.step(10.0)
        # 9.5 is not good enough (9.5 > 9.0)
        scheduler.step(9.5)
        self.assertEqual(scheduler.num_bad_epochs, 1)

    def test_multiple_param_groups(self):
        model = nn.Linear(10, 10)
        optimizer = torch.optim.SGD([{"params": model.weight, "lr": 0.1}, {"params": model.bias, "lr": 0.01}])
        scheduler = GreedyLR(optimizer, patience=2, factor=0.9, min_lr=1e-6)

        self.assertEqual(len(scheduler.min_lrs), 2)
        self.assertEqual(len(scheduler.max_lrs), 2)

        # Trigger reduction
        scheduler.step(5.0)
        for _ in range(3):
            scheduler.step(10.0)

        self.assertAlmostEqual(optimizer.param_groups[0]["lr"], 0.1 * 0.9, places=7)
        self.assertAlmostEqual(optimizer.param_groups[1]["lr"], 0.01 * 0.9, places=7)


@require_torch
class StreamingAverageTest(unittest.TestCase):
    def test_basic_average(self):
        avg = StreamingAverage(window_size=3)
        self.assertAlmostEqual(avg.streamavg(1.0), 1.0)
        self.assertAlmostEqual(avg.streamavg(2.0), 1.5)
        self.assertAlmostEqual(avg.streamavg(3.0), 2.0)
        # Window full, oldest drops
        self.assertAlmostEqual(avg.streamavg(4.0), 3.0)

    def test_state_dict_round_trip(self):
        avg1 = StreamingAverage(window_size=3)
        avg1.streamavg(1.0)
        avg1.streamavg(2.0)
        avg1.streamavg(3.0)

        state = avg1.state_dict()
        avg2 = StreamingAverage(window_size=3)
        avg2.load_state_dict(state)

        self.assertEqual(avg2.values, avg1.values)
        self.assertAlmostEqual(avg2.sum, avg1.sum)
        self.assertEqual(avg2.window_size, avg1.window_size)


@require_torch
class BackwardCompatibilityTest(unittest.TestCase):
    def test_default_lr_scheduler_type_unchanged(self):
        from transformers import TrainingArguments

        args = TrainingArguments(output_dir="./test_output")
        self.assertEqual(args.lr_scheduler_type, "linear")

    def test_existing_schedulers_still_work(self):
        from transformers import get_scheduler

        model = nn.Linear(10, 10)
        for sched_type in ["linear", "cosine", "constant", "constant_with_warmup"]:
            optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
            scheduler = get_scheduler(
                name=sched_type,
                optimizer=optimizer,
                num_warmup_steps=5,
                num_training_steps=100,
            )
            self.assertIsNotNone(scheduler)
            # Run a few steps to verify it works
            for _ in range(5):
                optimizer.step()
                scheduler.step()
            self.assertGreaterEqual(optimizer.param_groups[0]["lr"], 0.0)


if __name__ == "__main__":
    unittest.main()
231
tests/optimization/test_optimization.py
Normal file
@@ -0,0 +1,231 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import tempfile
import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch


if is_torch_available():
    import torch
    from torch import nn

    from transformers import (
        Adafactor,
        get_constant_schedule,
        get_constant_schedule_with_warmup,
        get_cosine_schedule_with_warmup,
        get_cosine_with_hard_restarts_schedule_with_warmup,
        get_inverse_sqrt_schedule,
        get_linear_schedule_with_warmup,
        get_polynomial_decay_schedule_with_warmup,
        get_scheduler,
        get_wsd_schedule,
    )


def unwrap_schedule(scheduler, num_steps=10):
    lrs = []
    for _ in range(num_steps):
        lrs.append(scheduler.get_lr()[0])
        scheduler.step()
    return lrs


def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
    lrs = []
    for step in range(num_steps):
        lrs.append(scheduler.get_lr()[0])
        scheduler.step()
        if step == num_steps // 2:
            with tempfile.TemporaryDirectory() as tmpdirname:
                file_name = os.path.join(tmpdirname, "schedule.bin")
                torch.save(scheduler.state_dict(), file_name)

                state_dict = torch.load(file_name, weights_only=False)
                scheduler.load_state_dict(state_dict)
    return lrs


@require_torch
class OptimizationTest(unittest.TestCase):
    def assertListAlmostEqual(self, list1, list2, tol):
        self.assertEqual(len(list1), len(list2))
        for a, b in zip(list1, list2):
            self.assertAlmostEqual(a, b, delta=tol)

    def test_adam_w(self):
        w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
        target = torch.tensor([0.4, 0.2, -0.5])
        criterion = nn.MSELoss()
        # No warmup, constant schedule, no gradient clipping
        optimizer = torch.optim.AdamW(params=[w], lr=2e-1, weight_decay=0.0)
        for _ in range(100):
            loss = criterion(w, target)
            loss.backward()
            optimizer.step()
            w.grad.detach_()  # No zero_grad() function on simple tensors. We do it ourselves.
            w.grad.zero_()
        self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)

    def test_adafactor(self):
        w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
        target = torch.tensor([0.4, 0.2, -0.5])
        criterion = nn.MSELoss()
        # No warmup, constant schedule, no gradient clipping
        optimizer = Adafactor(
            params=[w],
            lr=1e-2,
            eps=(1e-30, 1e-3),
            clip_threshold=1.0,
            decay_rate=-0.8,
            beta1=None,
            weight_decay=0.0,
            relative_step=False,
            scale_parameter=False,
            warmup_init=False,
        )
        for _ in range(1000):
            loss = criterion(w, target)
            loss.backward()
            optimizer.step()
            w.grad.detach_()  # No zero_grad() function on simple tensors. We do it ourselves.
            w.grad.zero_()
        self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)


@require_torch
class ScheduleInitTest(unittest.TestCase):
    m = nn.Linear(50, 50) if is_torch_available() else None
    optimizer = torch.optim.AdamW(m.parameters(), lr=10.0) if is_torch_available() else None
    num_steps = 10

    def assertListAlmostEqual(self, list1, list2, tol, msg=None):
        self.assertEqual(len(list1), len(list2))
        for a, b in zip(list1, list2):
            self.assertAlmostEqual(a, b, delta=tol, msg=msg)

    def test_schedulers(self):
        common_kwargs = {"num_warmup_steps": 2, "num_training_steps": 10}
        # schedulers dict format
        # function: (sched_args_dict, expected_learning_rates)
        scheds = {
            get_constant_schedule: ({}, [10.0] * self.num_steps),
            get_constant_schedule_with_warmup: (
                {"num_warmup_steps": 4},
                [0.0, 2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0],
            ),
            get_linear_schedule_with_warmup: (
                {**common_kwargs},
                [0.0, 5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25],
            ),
            get_cosine_schedule_with_warmup: (
                {**common_kwargs},
                [0.0, 5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38],
            ),
            get_cosine_with_hard_restarts_schedule_with_warmup: (
                {**common_kwargs, "num_cycles": 2},
                [0.0, 5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46],
            ),
            get_polynomial_decay_schedule_with_warmup: (
                {**common_kwargs, "power": 2.0, "lr_end": 1e-7},
                [0.0, 5.0, 10.0, 7.656, 5.625, 3.906, 2.5, 1.406, 0.625, 0.156],
            ),
            get_inverse_sqrt_schedule: (
                {"num_warmup_steps": 2},
                [0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714],
            ),
            get_wsd_schedule: (
                {**common_kwargs, "num_decay_steps": 2, "min_lr_ratio": 0.0},
                [0.0, 5.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 5.0],
            ),
        }

        for scheduler_func, data in scheds.items():
            kwargs, expected_learning_rates = data

            scheduler = scheduler_func(self.optimizer, **kwargs)
            self.assertEqual(len([scheduler.get_lr()[0]]), 1)
            lrs_1 = unwrap_schedule(scheduler, self.num_steps)
            self.assertListAlmostEqual(
                lrs_1,
                expected_learning_rates,
                tol=1e-2,
                msg=f"failed for {scheduler_func} in normal scheduler",
            )

            scheduler = scheduler_func(self.optimizer, **kwargs)
            if scheduler_func.__name__ != "get_constant_schedule":
                LambdaScheduleWrapper.wrap_scheduler(scheduler)  # wrap to test picklability of the schedule
            lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
            self.assertListEqual(lrs_1, lrs_2, msg=f"failed for {scheduler_func} in save and reload")

    def test_get_scheduler(self):
        test_params = [
            {
                "name": "warmup_stable_decay",
                "optimizer": self.optimizer,
                "num_warmup_steps": 2,
                "num_training_steps": 10,
                "scheduler_specific_kwargs": {
                    "num_decay_steps": 2,
                    "warmup_type": "linear",
                    "decay_type": "linear",
                },
            },
            {
                "name": "warmup_stable_decay",
                "optimizer": self.optimizer,
                "num_warmup_steps": 2,
                "num_training_steps": 10,
                "scheduler_specific_kwargs": {
                    "num_decay_steps": 2,
                    "warmup_type": "cosine",
                    "decay_type": "cosine",
                },
            },
            {
                "name": "warmup_stable_decay",
                "optimizer": self.optimizer,
                "num_warmup_steps": 2,
                "num_training_steps": 10,
                "scheduler_specific_kwargs": {
                    "num_decay_steps": 2,
                    "warmup_type": "1-sqrt",
                    "decay_type": "1-sqrt",
                },
            },
            {"name": "cosine", "optimizer": self.optimizer, "num_warmup_steps": 2, "num_training_steps": 10},
        ]

        for param in test_params:
            self.assertTrue(get_scheduler(**param), msg=f"failed for {param['name']} in get_scheduler")


class LambdaScheduleWrapper:
    """See https://github.com/huggingface/transformers/issues/21689"""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    @classmethod
    def wrap_scheduler(cls, scheduler):
        scheduler.lr_lambdas = list(map(cls, scheduler.lr_lambdas))