first commit
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
This commit is contained in:
153
examples/scheduler/run_greedy.py
Normal file
153
examples/scheduler/run_greedy.py
Normal file
@@ -0,0 +1,153 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Pre-train or fine-tune a causal language model using GreedyLR or Cosine scheduler.
|
||||
|
||||
Example usage:
|
||||
|
||||
# Pre-train with GreedyLR (default):
|
||||
python run_greedy.py
|
||||
|
||||
# Pre-train with cosine scheduler for comparison:
|
||||
python run_greedy.py --lr_scheduler_type cosine
|
||||
|
||||
# Use a different model:
|
||||
python run_greedy.py --model_name_or_path Qwen/Qwen3-0.6B
|
||||
|
||||
# Fine-tune from a pretrained checkpoint:
|
||||
python run_greedy.py --model_name_or_path meta-llama/Llama-3.2-1B --finetune
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
DataCollatorForLanguageModeling,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Pre-train/fine-tune a causal LM with GreedyLR or Cosine scheduler")
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
type=str,
|
||||
default="meta-llama/Llama-3.2-1B",
|
||||
help="Model identifier from huggingface.co/models or path to a local checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
default="wikitext",
|
||||
help="The name of the dataset to use (via datasets library)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_config_name",
|
||||
type=str,
|
||||
default="wikitext-2-raw-v1",
|
||||
help="The configuration name of the dataset",
|
||||
)
|
||||
parser.add_argument("--lr_scheduler_type", type=str, default="greedy", choices=["greedy", "cosine"])
|
||||
parser.add_argument("--learning_rate", type=float, default=2e-4)
|
||||
parser.add_argument("--max_steps", type=int, default=2000)
|
||||
parser.add_argument("--per_device_train_batch_size", type=int, default=2)
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
|
||||
parser.add_argument("--output_dir", type=str, default="./output")
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument(
|
||||
"--finetune",
|
||||
action="store_true",
|
||||
help="Fine-tune from pretrained weights instead of training from scratch",
|
||||
)
|
||||
parser.add_argument("--block_size", type=int, default=512, help="Context length for tokenization")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
set_seed(args.seed)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
if args.finetune:
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(args.model_name_or_path)
|
||||
model = AutoModelForCausalLM.from_config(config)
|
||||
|
||||
param_count = sum(p.numel() for p in model.parameters())
|
||||
logger.info(f"Model: {args.model_name_or_path} ({param_count / 1e6:.1f}M parameters)")
|
||||
|
||||
raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
|
||||
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(examples["text"], truncation=True, max_length=args.block_size)
|
||||
|
||||
tokenized_datasets = raw_datasets.map(
|
||||
tokenize_function, batched=True, remove_columns=raw_datasets["train"].column_names
|
||||
)
|
||||
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
|
||||
scheduler_kwargs = {}
|
||||
if args.lr_scheduler_type == "greedy":
|
||||
scheduler_kwargs = {"patience": 10, "factor": 0.99, "min_lr": 1e-5}
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir=args.output_dir,
|
||||
per_device_train_batch_size=args.per_device_train_batch_size,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
learning_rate=args.learning_rate,
|
||||
lr_scheduler_type=args.lr_scheduler_type,
|
||||
lr_scheduler_kwargs=scheduler_kwargs,
|
||||
max_steps=args.max_steps,
|
||||
warmup_steps=0 if args.lr_scheduler_type == "greedy" else 100,
|
||||
eval_strategy="steps",
|
||||
eval_steps=200,
|
||||
save_steps=500,
|
||||
logging_steps=10,
|
||||
bf16=True,
|
||||
report_to="tensorboard",
|
||||
seed=args.seed,
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=tokenized_datasets["train"],
|
||||
eval_dataset=tokenized_datasets.get("validation", tokenized_datasets.get("test")),
|
||||
data_collator=data_collator,
|
||||
)
|
||||
|
||||
logger.info(f"Starting training with {args.lr_scheduler_type} scheduler")
|
||||
trainer.train()
|
||||
trainer.save_model()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user