first commit
Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled

This commit is contained in:
陈赣
2026-06-05 16:53:03 +08:00
commit 06f1fd69a6
6047 changed files with 1895387 additions and 0 deletions

View File

View File

@@ -0,0 +1,3 @@
distributed_type: MULTI_GPU
num_machines: 1
num_processes: 2

View File

@@ -0,0 +1,4 @@
distributed_type: DEEPSPEED
deepspeed_config:
deepspeed_config_file: tests/trainer/distributed/scripts/ds_config_zero2.json
num_processes: 2

View File

@@ -0,0 +1,9 @@
distributed_type: DEEPSPEED
deepspeed_config:
deepspeed_config_file: tests/trainer/distributed/scripts/ds_config_zero2.json
num_processes: 2
parallelism_config:
parallelism_config_sp_size: 2
parallelism_config_sp_backend: deepspeed
parallelism_config_sp_seq_length_is_variable: true
parallelism_config_sp_attn_implementation: sdpa

View File

@@ -0,0 +1,4 @@
distributed_type: DEEPSPEED
deepspeed_config:
deepspeed_config_file: tests/trainer/distributed/scripts/ds_config_zero3.json
num_processes: 2

View File

@@ -0,0 +1,4 @@
distributed_type: FSDP
fsdp_config:
fsdp_version: 1
num_processes: 2

View File

@@ -0,0 +1,4 @@
distributed_type: FSDP
fsdp_config:
fsdp_version: 2
num_processes: 2

View File

@@ -0,0 +1,10 @@
distributed_type: FSDP
fsdp_config:
fsdp_version: 2
num_processes: 2
parallelism_config:
parallelism_config_dp_replicate_size: 1
parallelism_config_dp_shard_size: 1
parallelism_config_tp_size: 1
parallelism_config_cp_size: 2
parallelism_config_cp_comm_strategy: alltoall

View File

@@ -0,0 +1,88 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Worker script for dispatch_batches=False with a finite iterable dataset.
Verifies that training completes successfully when ``dispatch_batches``
is disabled.
Run via torchrun or accelerate launch.
"""
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import IterableDataset
from transformers import HfArgumentParser, Trainer, TrainingArguments
class RegressionModel(nn.Module):
def __init__(self, a=0, b=0):
super().__init__()
self.a = nn.Parameter(torch.tensor(a).float())
self.b = nn.Parameter(torch.tensor(b).float())
self.config = None
def forward(self, input_x, labels=None, **kwargs):
y = input_x * self.a + self.b
if labels is None:
return (y,)
loss = nn.functional.mse_loss(y, labels)
return (loss, y)
class RegressionDataset:
def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
np.random.seed(seed)
self.label_names = ["labels"] if label_names is None else label_names
self.length = length
self.x = np.random.normal(size=(length,)).astype(np.float32)
self.ys = [a * self.x + b + np.random.normal(scale=0.1, size=(length,)) for _ in self.label_names]
self.ys = [y.astype(np.float32) for y in self.ys]
def __len__(self):
return self.length
def __getitem__(self, i):
result = {name: y[i] for name, y in zip(self.label_names, self.ys)}
result["input_x"] = self.x[i]
return result
class FiniteIterableDataset(IterableDataset):
def __init__(self, a=2, b=3, length=64, seed=42, label_names=None):
self.dataset = RegressionDataset(a=a, b=b, length=length, seed=seed, label_names=label_names)
self.current_sample = 0
def __iter__(self):
while self.current_sample < len(self.dataset):
yield self.dataset[self.current_sample]
self.current_sample += 1
if __name__ == "__main__":
parser = HfArgumentParser((TrainingArguments,))
training_args = parser.parse_args_into_dataclasses()[0]
training_args.per_device_train_batch_size = 1
training_args.max_steps = 1
training_args.accelerator_config.dispatch_batches = False
train_dataset = FiniteIterableDataset(label_names=["labels", "extra"], length=1)
model = RegressionModel()
trainer = Trainer(model, training_args, train_dataset=train_dataset)
trainer.train()

View File

@@ -0,0 +1,32 @@
{
"fp16": {
"enabled": "auto"
},
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 2
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto"
}

View File

@@ -0,0 +1,35 @@
{
"fp16": {
"enabled": "auto"
},
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto"
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto"
}

View File

@@ -0,0 +1,113 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Worker script for eval/predict ordering tests.
Verifies that distributed eval/predict returns all samples in the correct order.
Run via torchrun or accelerate launch.
"""
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import EvalPrediction, HfArgumentParser, Trainer, TrainingArguments
from transformers.utils import logging
logger = logging.get_logger(__name__)
class DummyDataset(Dataset):
def __init__(self, length: int = 101):
self.length = length
def __len__(self):
return self.length
def __getitem__(self, i) -> int:
return i
class DummyDataCollator:
def __call__(self, features):
return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)}
class DummyModel(nn.Module):
def __init__(self):
super().__init__()
# Add some (unused) params otherwise DDP will complain.
self.fc = nn.Linear(120, 80)
def forward(self, input_ids, labels=None):
if labels is not None:
return torch.tensor(0.0, device=input_ids.device), input_ids
else:
return input_ids
if __name__ == "__main__":
parser = HfArgumentParser((TrainingArguments,))
training_args = parser.parse_args_into_dataclasses()[0]
for dataset_length in [49, 7]:
dataset = DummyDataset(dataset_length)
def compute_metrics(p: EvalPrediction) -> dict:
sequential = list(range(len(dataset)))
success = p.predictions.tolist() == sequential and p.label_ids.tolist() == sequential
if not success and training_args.local_process_index == 0:
logger.warning(
"Predictions and/or labels do not match expected results:\n - predictions: "
f"{p.predictions.tolist()}\n - labels: {p.label_ids.tolist()}\n - expected: {sequential}"
)
return {"success": success}
trainer = Trainer(
model=DummyModel(),
args=training_args,
data_collator=DummyDataCollator(),
eval_dataset=dataset,
compute_metrics=compute_metrics,
)
metrics = trainer.evaluate()
logger.info(metrics)
if metrics["eval_success"] is not True:
logger.error(metrics)
exit(1)
p = trainer.predict(dataset)
logger.info(p.metrics)
if p.metrics["test_success"] is not True:
logger.error(p.metrics)
exit(1)
trainer.args.eval_accumulation_steps = 2
metrics = trainer.evaluate()
logger.info(metrics)
if metrics["eval_success"] is not True:
logger.error(metrics)
exit(1)
p = trainer.predict(dataset)
logger.info(p.metrics)
if p.metrics["test_success"] is not True:
logger.error(p.metrics)
exit(1)
trainer.args.eval_accumulation_steps = None

View File

@@ -0,0 +1,125 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Worker script for FSDP generation tests.
Launched via ``torchrun`` from ``test_trainer_distributed_fsdp.py``.
"""
import argparse
import functools
from collections.abc import Callable
from typing import Any
import torch
import torch.distributed
from torch.distributed._composable.fsdp import fully_shard, register_fsdp_forward_method
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
from transformers.testing_utils import backend_device_count, backend_torch_accelerator_module, torch_device
data = 4 * [
"Hello world!",
"The quick brown fox jumps over the lazy dog.",
]
def manage_process_group(func: Callable[..., Any]) -> Callable[..., Any]:
"""Manage the creation and destruction of the distributed process group for the wrapped function."""
def wrapped(*args: Any, **kwargs: Any) -> Any:
device_count = backend_device_count(torch_device)
torch.distributed.init_process_group(world_size=device_count)
try:
return func(*args, **kwargs)
finally:
torch.distributed.destroy_process_group()
return wrapped
@manage_process_group
def fsdp_generate():
torch_accelerator_module = backend_torch_accelerator_module(torch_device)
torch_accelerator_module.set_device(device := torch.device(rank := torch.distributed.get_rank()))
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)
fsdp_model = FullyShardedDataParallel(
model,
auto_wrap_policy=functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={GPT2Block}),
limit_all_gathers=True,
use_orig_params=True,
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
batch = tokenizer(data[rank], return_tensors="pt", return_attention_mask=True).to(device)
with FullyShardedDataParallel.summon_full_params(fsdp_model):
_ = fsdp_model.module.generate(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
max_length=30,
)
@manage_process_group
def fsdp2_generate():
torch_accelerator_module = backend_torch_accelerator_module(torch_device)
torch_accelerator_module.set_device(device := torch.device(rank := torch.distributed.get_rank()))
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)
mesh = init_device_mesh(device.type, (torch.distributed.get_world_size(),))
for submodule in model.modules():
if isinstance(submodule, GPT2Block):
fully_shard(submodule, mesh=mesh)
fully_shard(model, mesh=mesh)
register_fsdp_forward_method(model, "generate")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
batch = tokenizer(data[rank], return_tensors="pt", return_attention_mask=True).to(device)
_ = model.generate(
input_ids=batch["input_ids"],
attention_mask=batch["attention_mask"],
max_length=30,
)
if __name__ == "__main__":
class CLIArgs(argparse.Namespace):
fsdp: bool
fsdp2: bool
parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument("--fsdp", action="store_true")
group.add_argument("--fsdp2", action="store_true")
args = parser.parse_args(namespace=CLIArgs())
if args.fsdp:
fsdp_generate()
elif args.fsdp2:
fsdp2_generate()
else:
raise ValueError("Missing test selection")

View File

@@ -0,0 +1,114 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Worker script for loss averaging tests.
Verifies that ``average_tokens_across_devices`` produces correct loss
compared to a single-GPU baseline.
When ``--run_both_averaging_modes`` is passed, the script runs training
twice (with and without averaging) in a single process launch, saving
``<output_dir>_broken_losses.json`` and ``<output_dir>_fixed_losses.json``.
Run via torchrun or accelerate launch.
"""
import argparse
import json
import datasets
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForLanguageModeling,
HfArgumentParser,
Trainer,
TrainerCallback,
TrainingArguments,
set_seed,
)
class StoreLossCallback(TrainerCallback):
"""Simple callback to store the loss."""
def __init__(self):
self.losses = []
def on_log(self, args, state, control, logs=None, **kwargs):
if "loss" in logs:
self.losses.append(logs["loss"])
def run_distributed_training(training_args, loss_file):
set_seed(42)
model_name = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
dataset_name = "wikitext"
dataset_config = "wikitext-2-raw-v1"
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:50]")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
def tokenize_function(examples):
return tokenizer(examples["text"], max_length=128, padding="max_length", truncation=True)
tokenized_dataset = dataset.map(tokenize_function, batched=True)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)
loss_callback = StoreLossCallback()
training_args.logging_steps = 1
training_args.max_steps = 10
training_args.learning_rate = 3e-4
training_args.disable_tqdm = True
training_args.dataloader_drop_last = True
trainer = Trainer(
model,
training_args,
train_dataset=tokenized_dataset,
callbacks=[loss_callback],
data_collator=data_collator,
)
trainer.train()
with open(loss_file, "w") as f:
json.dump(loss_callback.losses, f)
if __name__ == "__main__":
# Parse our custom flag first, pass the rest to HfArgumentParser.
pre_parser = argparse.ArgumentParser(add_help=False)
pre_parser.add_argument("--run_both_averaging_modes", action="store_true")
custom_args, remaining = pre_parser.parse_known_args()
hf_parser = HfArgumentParser((TrainingArguments,))
(training_args,) = hf_parser.parse_args_into_dataclasses(remaining)
if custom_args.run_both_averaging_modes:
base_dir = training_args.output_dir
# Run without averaging ("broken")
training_args.average_tokens_across_devices = False
training_args.output_dir = base_dir + "/broken"
run_distributed_training(training_args, loss_file=base_dir + "/broken_losses.json")
# Run with averaging ("fixed")
training_args.average_tokens_across_devices = True
training_args.output_dir = base_dir + "/fixed"
run_distributed_training(training_args, loss_file=base_dir + "/fixed_losses.json")
else:
run_distributed_training(training_args, loss_file=training_args.output_dir + "_losses.json")

View File

@@ -0,0 +1,93 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dumps distributed environment info to a JSON file for verification.
This script creates a Trainer (which initializes the accelerator) and writes
each worker's env vars, TrainingArguments fields, and accelerator state to
``<output_dir>/env_rank<N>.json``.
Accepts all TrainingArguments flags (e.g. ``--deepspeed``, ``--fsdp``) so the
Trainer sets up the correct framework regardless of launcher.
Works with any launcher (torchrun, accelerate launch with DDP/FSDP/DeepSpeed).
"""
import json
import os
from transformers import AutoModelForCausalLM, HfArgumentParser, Trainer, TrainingArguments
def main():
parser = HfArgumentParser((TrainingArguments,))
(args,) = parser.parse_args_into_dataclasses()
args.disable_tqdm = True
model_name = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_name)
trainer = Trainer(model=model, args=args)
accelerator = trainer.accelerator
env_info = {
# Raw env vars set by torchrun / accelerate
"env_world_size": os.environ.get("WORLD_SIZE"),
"env_rank": os.environ.get("RANK"),
"env_local_rank": os.environ.get("LOCAL_RANK"),
"env_master_addr": os.environ.get("MASTER_ADDR"),
"env_master_port": os.environ.get("MASTER_PORT"),
# TrainingArguments-derived values
"args_local_rank": args.local_rank,
"args_world_size": args.world_size,
"args_process_index": args.process_index,
"args_local_process_index": args.local_process_index,
"args_parallel_mode": str(args.parallel_mode),
"args_n_gpu": args.n_gpu,
# Accelerator state
"accelerator_num_processes": accelerator.num_processes,
"accelerator_process_index": accelerator.process_index,
"accelerator_local_process_index": accelerator.local_process_index,
"accelerator_is_main_process": accelerator.is_main_process,
"accelerator_is_local_main_process": accelerator.is_local_main_process,
"accelerator_use_distributed": accelerator.use_distributed,
"accelerator_distributed_type": str(accelerator.distributed_type),
"accelerator_device": str(accelerator.device),
# Trainer-level flags (these gate framework-specific code paths)
"trainer_is_fsdp_enabled": trainer.is_fsdp_enabled,
"trainer_is_deepspeed_enabled": trainer.is_deepspeed_enabled,
}
# FSDP plugin info
fsdp_plugin = getattr(accelerator.state, "fsdp_plugin", None)
if fsdp_plugin is not None:
env_info["fsdp_version"] = getattr(fsdp_plugin, "fsdp_version", None)
env_info["fsdp_sharding_strategy"] = str(getattr(fsdp_plugin, "sharding_strategy", None))
env_info["fsdp_cpu_offload"] = str(getattr(fsdp_plugin, "cpu_offload", None))
env_info["fsdp_auto_wrap_policy"] = str(getattr(fsdp_plugin, "auto_wrap_policy", None))
# DeepSpeed plugin info
deepspeed_plugin = getattr(accelerator.state, "deepspeed_plugin", None)
if deepspeed_plugin is not None:
env_info["deepspeed_zero_stage"] = deepspeed_plugin.zero_stage
env_info["deepspeed_offload_optimizer_device"] = str(deepspeed_plugin.offload_optimizer_device)
env_info["deepspeed_offload_param_device"] = str(deepspeed_plugin.offload_param_device)
output_file = os.path.join(args.output_dir, f"env_rank{args.process_index}.json")
with open(output_file, "w") as f:
json.dump(env_info, f)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,136 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple causal LM script for distributed tests (FSDP, DeepSpeed).
Uses a tiny Qwen2 model with synthetic data so tests run fast
and don't require downloading real datasets.
Supports --do_train (default) and --do_eval via TrainingArguments.
32 training samples are created; with per_device_train_batch_size=4
and 2 GPUs this gives 4 steps per epoch.
"""
import json
import sys
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForLanguageModeling,
HfArgumentParser,
Trainer,
TrainingArguments,
)
DTYPE_MAP = {"fp32": torch.float32, "bf16": torch.bfloat16, "fp16": torch.float16}
def _pop_custom_arg(name):
"""Pop a custom --name value arg from sys.argv before HfArgumentParser sees it."""
if name in sys.argv:
idx = sys.argv.index(name)
value = sys.argv[idx + 1]
sys.argv.pop(idx)
sys.argv.pop(idx)
return value
return None
def main():
# Parse custom args (not TrainingArguments fields)
model_name = _pop_custom_arg("--model_name") or "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
loss_output_file = _pop_custom_arg("--loss_output_file")
eval_output_file = _pop_custom_arg("--eval_output_file")
model_dtype = _pop_custom_arg("--model_dtype")
attn_impl = _pop_custom_arg("--attn_implementation")
pad_to_multiple_of = _pop_custom_arg("--pad_to_multiple_of")
parser = HfArgumentParser((TrainingArguments,))
(training_args,) = parser.parse_args_into_dataclasses()
# Default to training if neither --do_train nor --do_eval is set
if not training_args.do_train and not training_args.do_eval:
training_args.do_train = True
# Auto-enable eval when an eval output file is requested
if eval_output_file:
training_args.do_eval = True
torch_dtype = DTYPE_MAP[model_dtype] if model_dtype else None
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
model_kwargs = {}
if torch_dtype:
model_kwargs["torch_dtype"] = torch_dtype
if attn_impl:
model_kwargs["attn_implementation"] = attn_impl
model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
model.generation_config.pad_token_id = tokenizer.pad_token_id
# Synthetic dataset — 32 samples of tokenized text
# With per_device_train_batch_size=4 and 2 GPUs this gives 4 steps per epoch.
texts = [
"The quick brown fox jumps over the lazy dog. " * 5,
"A journey of a thousand miles begins with a single step. " * 5,
"To be or not to be, that is the question. " * 5,
"All that glitters is not gold, all that wanders is not lost. " * 5,
] * 8
train_dataset = None
eval_dataset = None
if training_args.do_train:
train_dataset = [tokenizer(text, max_length=128, truncation=True, padding="max_length") for text in texts]
if training_args.do_eval:
eval_dataset = [tokenizer(text, max_length=128, truncation=True, padding="max_length") for text in texts[:8]]
collator_kwargs = {}
if pad_to_multiple_of:
collator_kwargs["pad_to_multiple_of"] = int(pad_to_multiple_of)
training_args.disable_tqdm = True
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, **collator_kwargs),
)
if training_args.do_train:
trainer.train()
if training_args.do_eval:
eval_metrics = trainer.evaluate()
if eval_output_file and training_args.process_index == 0:
with open(eval_output_file, "w") as f:
json.dump(eval_metrics, f)
# Save per-step losses for equivalence testing
if training_args.do_train and loss_output_file and training_args.process_index == 0:
losses = [log["loss"] for log in trainer.state.log_history if "loss" in log]
with open(loss_output_file, "w") as f:
json.dump(losses, f)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,4 @@
{
"image_processor_type": "ViTImageProcessor",
"size": 30
}

View File

@@ -0,0 +1,87 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Worker script for dataloader worker seed divergence tests.
Verifies that dataloader workers get different random seeds across GPUs,
so that each rank sees different random augmentations.
Run via torchrun or accelerate launch.
"""
import random
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import HfArgumentParser, Trainer, TrainingArguments, set_seed
from transformers.testing_utils import torch_device
def gather_from_all_gpus(tensor, world_size):
gather_list = [torch.zeros_like(tensor) for _ in range(world_size)]
dist.all_gather(gather_list, tensor)
return gather_list
class DummyDataset(Dataset):
def __init__(self):
self.length = 64
def __len__(self):
return self.length
def __getitem__(self, i) -> int:
x = random.random()
y = np.random.random()
z = torch.rand([]).item()
return {"x": torch.tensor([x, y, z])}
class DummyModel(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(3, 1)
def forward(self, x):
local_tensor = torch.tensor(x, device=torch_device)
gathered = gather_from_all_gpus(local_tensor, dist.get_world_size())
assert not all(torch.allclose(t, gathered[0]) for t in gathered[1:])
y = self.fc(x)
return (y.mean(), y)
def run_distributed_training(training_args):
set_seed(42)
model = DummyModel()
dataset = DummyDataset()
training_args.max_steps = 3
# dataloader_num_workers must be > 0 to enable worker_init_fn
training_args.dataloader_num_workers = 2
trainer = Trainer(
model,
training_args,
train_dataset=dataset,
)
trainer.train()
if __name__ == "__main__":
parser = HfArgumentParser((TrainingArguments,))
training_args = parser.parse_args_into_dataclasses()[0]
run_distributed_training(training_args)

View File

@@ -0,0 +1,180 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Shared constants, helpers, and reusable test logic for distributed trainer tests.
This module provides:
- Path constants for test scripts and accelerate configs.
- ``TrainerDistributedCommon``, an abstract base class that contains reusable
test scenarios (training, mixed-precision, gradient accumulation, checkpoint
resume, evaluation). Framework-specific test files (DDP, FSDP, DeepSpeed)
subclass it and wire each scenario to parameterized test methods.
"""
import json
import os
from abc import ABC, abstractmethod
from transformers import is_torch_available
from transformers.testing_utils import execute_subprocess_async
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import get_last_checkpoint
if is_torch_available():
import torch
# ---------------------------------------------------------------------------
# Path constants
# ---------------------------------------------------------------------------
DISTRIBUTED_DIR = os.path.dirname(__file__)
CONFIGS_DIR = os.path.join(DISTRIBUTED_DIR, "accelerate_configs")
SCRIPTS_DIR = os.path.join(DISTRIBUTED_DIR, "scripts")
TRAIN_SCRIPT = os.path.join(SCRIPTS_DIR, "train.py")
class TrainerDistributedCommon(ABC):
"""Reusable test scenarios shared across DDP, FSDP, and DeepSpeed.
Subclasses must:
1. Implement ``get_accelerate_cmd`` to build the launch command.
2. Define the following test methods (parameterized as needed)::
test_training → self.check_training(dtype, ...)
test_training_mixed_precision → self.check_mixed_precision(dtype, ...)
test_training_with_gradient_accumulation → self.check_gradient_accumulation(...)
test_training_and_can_resume_normally → self.check_resume(...)
test_eval → self.check_eval(...)
These test methods can't be defined here as ``@abstractmethod`` because
``@parameterized.expand`` removes the original method name from the
subclass, which would cause ABC to raise ``TypeError`` at instantiation.
"""
@abstractmethod
def get_accelerate_cmd(self, script, config_file, launch_args=None, script_args=None, **kwargs):
"""Build the full ``accelerate launch`` command list.
Args:
script: Path to the Python script to run.
config_file: Path to the accelerate YAML config (always required).
launch_args: Extra flags inserted *before* the script
(e.g. ``--fsdp_sharding_strategy``, ``--offload_optimizer_device``).
script_args: Extra flags appended *after* the script
(e.g. ``--output_dir``, ``--bf16``).
**kwargs: Framework-specific overrides (e.g. ``num_processes``).
"""
...
# -------------------------------------------------------------------
# Helpers
# -------------------------------------------------------------------
def _get_default_script_args(self, output_dir, num_epochs=1, logging_steps=5, save_steps=None):
"""Build the baseline CLI arguments shared by all training runs."""
args = [
"--output_dir",
output_dir,
"--num_train_epochs",
str(num_epochs),
"--logging_steps",
str(logging_steps),
"--per_device_train_batch_size",
"4",
"--learning_rate",
"5e-5",
]
if save_steps is not None:
args += ["--save_steps", str(save_steps)]
else:
args += ["--save_strategy", "no"]
return args
def _train_and_get_log_history(self, cmd, output_dir):
"""Run a training command and return the log history from the last checkpoint."""
execute_subprocess_async(cmd, env=self.get_env())
checkpoint = get_last_checkpoint(output_dir)
state_file = os.path.join(checkpoint, "trainer_state.json")
return TrainerState.load_from_json(state_file).log_history
# -------------------------------------------------------------------
# Reusable test scenarios — called from subclass test methods
# -------------------------------------------------------------------
def check_training(self, dtype="bf16", **cmd_kwargs):
"""Verify that training completes with the model loaded in *dtype* (no mixed precision)."""
output_dir = self.get_auto_remove_tmp_dir()
args = self._get_default_script_args(output_dir) + ["--model_dtype", dtype]
execute_subprocess_async(
self.get_accelerate_cmd(TRAIN_SCRIPT, script_args=args, **cmd_kwargs),
env=self.get_env(),
)
def check_mixed_precision(self, dtype="bf16", **cmd_kwargs):
"""Verify mixed-precision training: model in fp32, compute in *dtype*."""
output_dir = self.get_auto_remove_tmp_dir()
args = self._get_default_script_args(output_dir) + ["--model_dtype", "fp32", f"--{dtype}"]
# fp16 requires a non-fused optimizer to avoid nan losses on small models
if dtype == "fp16":
args += ["--optim", "adamw_torch"]
execute_subprocess_async(
self.get_accelerate_cmd(TRAIN_SCRIPT, script_args=args, **cmd_kwargs),
env=self.get_env(),
)
def check_gradient_accumulation(self, **cmd_kwargs):
"""Verify that training with gradient accumulation completes without error."""
output_dir = self.get_auto_remove_tmp_dir()
args = self._get_default_script_args(output_dir) + ["--bf16", "--gradient_accumulation_steps", "2"]
execute_subprocess_async(
self.get_accelerate_cmd(TRAIN_SCRIPT, script_args=args, **cmd_kwargs),
env=self.get_env(),
)
def check_resume(self, **cmd_kwargs):
"""Verify that training can resume from a checkpoint with consistent learning rates."""
output_dir = self.get_auto_remove_tmp_dir()
args = self._get_default_script_args(output_dir, num_epochs=2, logging_steps=2, save_steps=2) + ["--bf16"]
original_logs = self._train_and_get_log_history(
self.get_accelerate_cmd(TRAIN_SCRIPT, script_args=args, **cmd_kwargs),
output_dir,
)
checkpoint = os.path.join(output_dir, "checkpoint-2")
self.assertTrue(os.path.isdir(checkpoint), f"Checkpoint dir not found: {checkpoint}")
resume_args = args + ["--resume_from_checkpoint", checkpoint]
resumed_logs = self._train_and_get_log_history(
self.get_accelerate_cmd(TRAIN_SCRIPT, script_args=resume_args, **cmd_kwargs),
output_dir,
)
for original, resumed in zip(original_logs, resumed_logs):
if "learning_rate" in original:
self.assertAlmostEqual(original["learning_rate"], resumed["learning_rate"], delta=1e-5)
def check_eval(self, **cmd_kwargs):
"""Verify that evaluation produces a finite eval loss."""
output_dir = self.get_auto_remove_tmp_dir()
eval_output = os.path.join(output_dir, "eval_metrics.json")
args = self._get_default_script_args(output_dir) + ["--do_eval", "--eval_output_file", eval_output]
execute_subprocess_async(
self.get_accelerate_cmd(TRAIN_SCRIPT, script_args=args, **cmd_kwargs),
env=self.get_env(),
)
with open(eval_output) as f:
eval_metrics = json.load(f)
self.assertIn("eval_loss", eval_metrics)
self.assertTrue(torch.isfinite(torch.tensor(eval_metrics["eval_loss"])))

View File

@@ -0,0 +1,297 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
DDP-specific distributed trainer tests.
"""
import json
import os
import re
from parameterized import parameterized
from transformers.testing_utils import (
CaptureStderr,
TestCasePlus,
backend_device_count,
execute_subprocess_async,
get_torch_dist_unique_port,
require_torch_multi_accelerator,
slow,
torch_device,
)
from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device
from .test_trainer_distributed import CONFIGS_DIR, SCRIPTS_DIR, TrainerDistributedCommon
DDP_CONFIG_FILE = os.path.join(CONFIGS_DIR, "ddp.yaml")
dtypes = []
if is_torch_bf16_available_on_device(torch_device):
dtypes += ["bf16"]
if is_torch_fp16_available_on_device(torch_device):
dtypes += ["fp16"]
pure_dtype_params = ["fp32"] + dtypes
mixed_precision_params = list(dtypes)
def _parameterized_custom_name_func(func, param_num, param):
param_based_name = parameterized.to_safe_name("_".join(str(x) for x in param.args))
return f"{func.__name__}_{param_based_name}"
class DDPCommandsMixin:
"""Provides ``get_torchrun_cmd`` and ``get_accelerate_cmd`` for DDP."""
def get_torchrun_cmd(self, script, script_args=None, num_processes=None):
if num_processes is None:
num_processes = backend_device_count(torch_device)
port = get_torch_dist_unique_port()
cmd = [
"torchrun",
f"--nproc_per_node={num_processes}",
"--nnodes=1",
f"--master_port={port}",
script,
]
if script_args:
cmd.extend(script_args)
return cmd
def get_accelerate_cmd(
self, script, config_file, launch_args=None, script_args=None, num_processes=None, **kwargs
):
if num_processes is None:
num_processes = backend_device_count(torch_device)
port = get_torch_dist_unique_port()
cmd = [
"accelerate",
"launch",
"--config_file",
config_file,
"--num_processes",
str(num_processes),
"--main_process_port",
str(port),
]
if launch_args:
cmd.extend(launch_args)
cmd.append(script)
if script_args:
cmd.extend(script_args)
return cmd
@slow
@require_torch_multi_accelerator
class TestTrainerDistributedDDP(DDPCommandsMixin, TestCasePlus):
# -----------------------------------------------------------------------
# accelerate launch tests
# -----------------------------------------------------------------------
def test_eval_order(self):
output_dir = self.get_auto_remove_tmp_dir()
script = os.path.join(SCRIPTS_DIR, "eval_ddp.py")
cmd = self.get_accelerate_cmd(
script,
DDP_CONFIG_FILE,
script_args=["--output_dir", output_dir],
)
execute_subprocess_async(cmd, env=self.get_env())
def test_loss_averaging(self):
device_count = backend_device_count(torch_device)
min_bs = 2
output_dir = self.get_auto_remove_tmp_dir()
script = os.path.join(SCRIPTS_DIR, "loss_averaging.py")
# Launch 1: single-GPU baseline
cmd = self.get_accelerate_cmd(
script,
DDP_CONFIG_FILE,
script_args=[
"--output_dir",
f"{output_dir}/base",
"--per_device_train_batch_size",
str(min_bs * device_count),
"--average_tokens_across_devices",
"True",
],
num_processes=1,
)
execute_subprocess_async(cmd, env=self.get_env())
# Launch 2: multi-GPU with both averaging modes in one process
cmd = self.get_accelerate_cmd(
script,
DDP_CONFIG_FILE,
script_args=[
"--output_dir",
f"{output_dir}/multi",
"--per_device_train_batch_size",
str(min_bs),
"--run_both_averaging_modes",
],
num_processes=device_count,
)
execute_subprocess_async(cmd, env=self.get_env())
with open(f"{output_dir}/base_losses.json") as f:
base_loss = json.load(f)
with open(f"{output_dir}/multi/broken_losses.json") as f:
broken_loss = json.load(f)
with open(f"{output_dir}/multi/fixed_losses.json") as f:
fixed_loss = json.load(f)
broken_diff = [abs(base_loss[i] - broken_loss[i]) for i in range(len(base_loss))]
fixed_diff = [abs(base_loss[i] - fixed_loss[i]) for i in range(len(base_loss))]
sum_base = sum(base_loss)
sum_broken = sum(broken_loss)
relative_broken = abs(sum_base - sum_broken) / max(sum_base, sum_broken)
self.assertGreater(max(broken_diff), 0.5)
self.assertLess(max(fixed_diff), 0.005)
self.assertLess(relative_broken, 0.1)
def test_worker_seed(self):
output_dir = self.get_auto_remove_tmp_dir()
script = os.path.join(SCRIPTS_DIR, "worker_seed.py")
cmd = self.get_accelerate_cmd(
script,
DDP_CONFIG_FILE,
script_args=["--output_dir", output_dir],
)
execute_subprocess_async(cmd, env=self.get_env())
# -----------------------------------------------------------------------
# torchrun vs accelerate env parity
# -----------------------------------------------------------------------
def test_torchrun_accelerate_env_parity(self):
"""Verify torchrun and accelerate launch produce the same distributed environment for DDP."""
script = os.path.join(SCRIPTS_DIR, "torchrun_env_check.py")
num_processes = backend_device_count(torch_device)
torchrun_dir = self.get_auto_remove_tmp_dir()
cmd = self.get_torchrun_cmd(script, script_args=["--output_dir", torchrun_dir], num_processes=num_processes)
execute_subprocess_async(cmd, env=self.get_env())
accelerate_dir = self.get_auto_remove_tmp_dir()
cmd = self.get_accelerate_cmd(
script, DDP_CONFIG_FILE, script_args=["--output_dir", accelerate_dir], num_processes=num_processes
)
execute_subprocess_async(cmd, env=self.get_env())
for rank in range(num_processes):
with open(os.path.join(torchrun_dir, f"env_rank{rank}.json")) as f:
tr = json.load(f)
with open(os.path.join(accelerate_dir, f"env_rank{rank}.json")) as f:
ac = json.load(f)
for info in (tr, ac):
# Rank consistency: env vars, TrainingArguments, and accelerator all agree
self.assertEqual(info["env_world_size"], str(num_processes))
self.assertEqual(info["env_rank"], str(rank))
self.assertEqual(info["env_local_rank"], str(rank))
self.assertEqual(info["args_process_index"], rank)
self.assertEqual(info["args_local_process_index"], rank)
self.assertIn(info["args_local_rank"], (rank, -1)) # may be -1 before framework consumes it
self.assertEqual(info["accelerator_process_index"], rank)
self.assertEqual(info["accelerator_local_process_index"], rank)
self.assertIsNotNone(info["env_master_addr"])
self.assertIsNotNone(info["env_master_port"])
# World size and parallel mode
self.assertEqual(info["args_world_size"], num_processes)
self.assertEqual(info["args_n_gpu"], 1)
self.assertEqual(info["args_parallel_mode"], "ParallelMode.DISTRIBUTED")
self.assertEqual(info["accelerator_num_processes"], num_processes)
self.assertTrue(info["accelerator_use_distributed"])
self.assertEqual(info["accelerator_is_main_process"], rank == 0)
self.assertEqual(info["accelerator_is_local_main_process"], rank == 0)
# DDP: distributed type is MULTI_GPU
self.assertEqual(info["accelerator_distributed_type"], "DistributedType.MULTI_GPU")
# Each rank on its own device
self.assertIn(f"{torch_device}:{rank}", info["accelerator_device"])
# DDP should not activate FSDP or DeepSpeed
self.assertFalse(info["trainer_is_fsdp_enabled"])
self.assertFalse(info["trainer_is_deepspeed_enabled"])
self.assertNotIn("fsdp_version", info)
self.assertNotIn("deepspeed_zero_stage", info)
@parameterized.expand(
[
("base", "--log_level info", 1),
("low", "--log_level debug --log_level_replica debug", 2),
("high", "--log_level error --log_level_replica debug", 1),
("mixed", "--log_level error --log_level_replica error", 0),
]
)
def test_log_level_replica(self, _name, extra_args_str, expected_matches):
"""Test that log_level_replica controls logging on non-main processes."""
output_dir = self.get_auto_remove_tmp_dir()
script = os.path.join(SCRIPTS_DIR, "train.py")
script_args = [
"--output_dir",
output_dir,
"--num_train_epochs",
"1",
"--per_device_train_batch_size",
"4",
"--logging_strategy",
"no",
]
if extra_args_str:
script_args.extend(extra_args_str.split())
cmd = self.get_accelerate_cmd(script, DDP_CONFIG_FILE, script_args=script_args, num_processes=2)
log_info_string = "Running training"
with CaptureStderr() as cl:
execute_subprocess_async(cmd, env=self.get_env())
n_matches = len(re.findall(log_info_string, cl.err))
self.assertEqual(n_matches, expected_matches)
# ---------------------------------------------------------------------------
# DDP training integration tests (using train.py)
# ---------------------------------------------------------------------------
@slow
@require_torch_multi_accelerator
class TestTrainerDistributedDDPCommon(DDPCommandsMixin, TrainerDistributedCommon, TestCasePlus):
"""
Distributed DDP training tests using ``accelerate launch`` with the shared
train.py script. Mirrors the test structure used in FSDP and DeepSpeed.
"""
@parameterized.expand(pure_dtype_params, name_func=_parameterized_custom_name_func)
def test_training(self, dtype):
self.check_training(dtype, config_file=DDP_CONFIG_FILE)
@parameterized.expand(mixed_precision_params, name_func=_parameterized_custom_name_func)
def test_training_mixed_precision(self, dtype):
self.check_mixed_precision(dtype, config_file=DDP_CONFIG_FILE)
def test_training_with_gradient_accumulation(self):
self.check_gradient_accumulation(config_file=DDP_CONFIG_FILE)
def test_training_and_can_resume_normally(self):
self.check_resume(config_file=DDP_CONFIG_FILE)
def test_eval(self):
self.check_eval(config_file=DDP_CONFIG_FILE)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,668 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP-specific distributed trainer tests.
"""
import itertools
import json
import os
import unittest
from functools import partial
from pathlib import Path
from unittest.mock import patch
from parameterized import parameterized
from tests.trainer.trainer_test_utils import TrainerIntegrationCommon, get_regression_trainer # noqa
from transformers import HfArgumentParser, PreTrainedConfig, TrainingArguments, is_torch_available
from transformers.testing_utils import (
TestCasePlus,
backend_device_count,
execute_subprocess_async,
get_torch_dist_unique_port,
mockenv_context,
require_torch,
require_torch_accelerator,
require_torch_multi_accelerator,
slow,
torch_device,
)
from transformers.trainer_utils import set_seed
from transformers.utils import (
is_torch_bf16_available_on_device,
is_torch_fp16_available_on_device,
)
from .test_trainer_distributed import CONFIGS_DIR, SCRIPTS_DIR, TRAIN_SCRIPT, TrainerDistributedCommon
if is_torch_available():
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.trainer import FSDP_MODEL_NAME
# Base accelerate configs (version only — model-specific settings via launch args)
FSDP_CONFIG_FILE = os.path.join(CONFIGS_DIR, "fsdp.yaml")
FSDP2_CONFIG_FILE = os.path.join(CONFIGS_DIR, "fsdp2.yaml")
FSDP2_CP_CONFIG_FILE = os.path.join(CONFIGS_DIR, "fsdp2_cp.yaml")
FSDP_GENERATE_SCRIPT = os.path.join(SCRIPTS_DIR, "fsdp_generate.py")
FSDP_CONFIGS = {
"fsdp1": FSDP_CONFIG_FILE,
"fsdp2": FSDP2_CONFIG_FILE,
}
# Launch args shared by all training tests
TRAIN_LAUNCH_ARGS = [
"--fsdp_auto_wrap_policy",
"TRANSFORMER_BASED_WRAP",
]
dtypes = []
if is_torch_bf16_available_on_device(torch_device):
dtypes += ["bf16"]
if is_torch_fp16_available_on_device(torch_device):
dtypes += ["fp16"]
sharding_strategies = ["full_shard", "shard_grad_op"] # zero3 and zero2
fsdp_versions = ["fsdp1", "fsdp2"]
config_params = list(itertools.product(sharding_strategies, dtypes))
# Mixed precision: model loaded in fp32, training with --bf16/--fp16
mixed_precision_params = list(itertools.product(sharding_strategies, dtypes, fsdp_versions))
# Pure dtype: model loaded in target dtype, no mixed precision flags
pure_dtype_params = list(itertools.product(["fp32"] + dtypes, fsdp_versions))
resume_params = [
("FULL_STATE_DICT", "fsdp1"), # FULL_STATE_DICT only supported for fsdp1
("SHARDED_STATE_DICT", "fsdp1"),
("SHARDED_STATE_DICT", "fsdp2"),
]
set_seed(42)
if is_torch_available():
# hack to restore original logging level pre #21700
get_regression_trainer = partial(get_regression_trainer, log_level="info")
if is_torch_available():
class _BaseModel(PreTrainedModel):
base_model_prefix = "base"
config_class = PreTrainedConfig
def __init__(self, config):
super().__init__(config)
self.linear = nn.Linear(5, 5)
self.linear_2 = nn.Linear(5, 5)
self.post_init()
def forward(self, x):
return self.linear_2(self.linear(x))
@require_torch
class InitializeMissingKeysTest(unittest.TestCase):
"""Tests for FSDP non-rank-0 weight initialization: params should be moved from meta to CPU
and marked as initialized without being re-initialized."""
def _clear_init_flags(self, model):
for module in model.modules():
if hasattr(module, "_is_hf_initialized"):
delattr(module, "_is_hf_initialized")
for param in model.parameters():
if hasattr(param, "_is_hf_initialized"):
delattr(param, "_is_hf_initialized")
for buffer in model.buffers():
if hasattr(buffer, "_is_hf_initialized"):
delattr(buffer, "_is_hf_initialized")
def test_move_missing_keys_fsdp_non_rank0_moves_meta_to_cpu(self):
"""FSDP non-rank-0 path should move all params from meta to CPU."""
with torch.device("meta"):
model = _BaseModel(PreTrainedConfig())
for param in model.parameters():
self.assertEqual(param.device, torch.device("meta"))
with (
patch("transformers.modeling_utils.is_fsdp_enabled", return_value=True),
patch("transformers.modeling_utils.is_local_dist_rank_0", return_value=False),
):
model._move_missing_keys_from_meta_to_device(
missing_keys=set(), device_map=None, device_mesh=None, hf_quantizer=None
)
for name, param in model.named_parameters():
self.assertEqual(param.device, torch.device("cpu"), f"param {name} should be on CPU after FSDP move")
def test_fsdp_non_rank0_end_to_end_no_reinit(self):
"""End-to-end: move from meta + _initialize_missing_keys should mark all params initialized
without changing their values."""
with torch.device("meta"):
model = _BaseModel(PreTrainedConfig())
with (
patch("transformers.modeling_utils.is_fsdp_enabled", return_value=True),
patch("transformers.modeling_utils.is_local_dist_rank_0", return_value=False),
):
model._move_missing_keys_from_meta_to_device(
missing_keys=set(), device_map=None, device_mesh=None, hf_quantizer=None
)
pre_init_values = {name: param.clone() for name, param in model.named_parameters()}
self._clear_init_flags(model)
model._initialize_missing_keys(is_quantized=False)
for name, param in model.named_parameters():
self.assertTrue(getattr(param, "_is_hf_initialized", False), f"param {name} not marked initialized")
torch.testing.assert_close(param, pre_init_values[name], msg=f"param {name} was re-initialized")
self.assertTrue(getattr(model, "_is_hf_initialized", False))
def _parameterized_custom_name_func(func, param_num, param):
# customize the test name generator function as we want both params to appear in the sub-test
# name, as by default it shows only the first param
param_based_name = parameterized.to_safe_name("_".join(str(x) for x in param.args))
return f"{func.__name__}_{param_based_name}"
# ---------------------------------------------------------------------------
# Command mixins
# ---------------------------------------------------------------------------
class FSDPCommandsMixin:
"""Provides ``get_torchrun_cmd`` and ``get_accelerate_cmd`` for FSDP."""
def get_torchrun_cmd(self, script, script_args=None, num_processes=None):
if num_processes is None:
num_processes = backend_device_count(torch_device)
port = get_torch_dist_unique_port()
cmd = [
"torchrun",
f"--nproc_per_node={num_processes}",
"--nnodes=1",
f"--master_port={port}",
script,
]
if script_args:
cmd.extend(script_args)
return cmd
def get_accelerate_cmd(
self, script, config_file, launch_args=None, script_args=None, num_processes=None, **kwargs
):
if num_processes is None:
num_processes = backend_device_count(torch_device)
port = get_torch_dist_unique_port()
cmd = [
"accelerate",
"launch",
"--config_file",
config_file,
"--num_processes",
str(num_processes),
"--main_process_port",
str(port),
]
if launch_args:
cmd.extend(launch_args)
cmd.append(script)
if script_args:
cmd.extend(script_args)
return cmd
# ---------------------------------------------------------------------------
# Config parsing tests (no distributed training runs)
# ---------------------------------------------------------------------------
@require_torch_accelerator
class TestFSDPConfig(TestCasePlus):
def setUp(self):
super().setUp()
master_port = get_torch_dist_unique_port()
self.dist_env_1_gpu = {
"MASTER_ADDR": "localhost",
"MASTER_PORT": str(master_port),
"RANK": "0",
"LOCAL_RANK": "0",
"WORLD_SIZE": "1",
}
self.accelerate_fsdp_config = {
"fsdp_activation_checkpointing": False,
"fsdp_auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"fsdp_backward_prefetch": "BACKWARD_PRE",
"fsdp_cpu_ram_efficient_loading": True,
"fsdp_forward_prefetch": False,
"fsdp_offload_params": False,
"fsdp_reshard_after_forward": "FULL_SHARD",
"fsdp_state_dict_type": "FULL_STATE_DICT",
"fsdp_sync_module_states": True,
"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
"fsdp_use_orig_params": True,
"fsdp_version": 1,
}
self.fsdp_config = {
"backward_prefetch": "BACKWARD_PRE",
"forward_prefetch": "false",
"limit_all_gathers": "false",
"use_orig_params": "true",
"sync_module_states": "true",
"cpu_ram_efficient_loading": "true",
"activation_checkpointing": "false",
"min_num_params": 1,
}
@parameterized.expand(config_params, name_func=_parameterized_custom_name_func)
def test_accelerate_fsdp_config(self, sharding_strategy, dtype):
output_dir = self.get_auto_remove_tmp_dir()
# Snapshot before trainer construction — `_process_fsdp_args` strips the
# `fsdp_` prefix in place.
expected = dict(self.accelerate_fsdp_config)
kwargs = {
"output_dir": output_dir,
"train_len": 128,
"save_steps": 5,
"learning_rate": 0.1,
"fsdp": f"{sharding_strategy} offload auto_wrap",
"fsdp_config": self.accelerate_fsdp_config,
}
kwargs[dtype] = True
with mockenv_context(**self.dist_env_1_gpu):
trainer = get_regression_trainer(**kwargs)
self.assertIs(trainer.args.fsdp, True)
self.assertTrue(trainer.args.fsdp_config.get("cpu_offload"))
for k, v in expected.items():
assert k.startswith("fsdp_")
# `transformer_layer_cls_to_wrap` is normalized from str → list during parsing.
if k == "fsdp_transformer_layer_cls_to_wrap" and isinstance(v, str):
v = [v]
self.assertEqual(trainer.args.fsdp_config[k[5:]], v)
def test_torchrun_fsdp_config(self):
"""Verify that --fsdp + --fsdp_config (torchrun-style) are parsed correctly."""
output_dir = self.get_auto_remove_tmp_dir()
fsdp_config = {"fsdp_transformer_layer_cls_to_wrap": "Qwen2DecoderLayer"}
kwargs = {
"output_dir": output_dir,
"train_len": 128,
"save_steps": 5,
"learning_rate": 0.1,
"fsdp": "full_shard auto_wrap",
"fsdp_config": fsdp_config,
"bf16": True,
}
with mockenv_context(**self.dist_env_1_gpu):
trainer = get_regression_trainer(**kwargs)
self.assertIs(trainer.args.fsdp, True)
# fsdp_ prefix is stripped and value is normalized to a list during parsing
self.assertIn("Qwen2DecoderLayer", trainer.args.fsdp_config["transformer_layer_cls_to_wrap"])
def test_fsdp_cli_parsing(self):
"""`--fsdp` (bare) → True; legacy `--fsdp full_shard` still parses; absent → None."""
parser = HfArgumentParser(TrainingArguments)
base = ["--output_dir", "/tmp/x"]
args, _ = parser.parse_known_args([*base, "--fsdp"])
self.assertIs(args.fsdp, True)
args, _ = parser.parse_known_args([*base, "--fsdp", "full_shard"])
self.assertEqual(args.fsdp, "full_shard")
args, _ = parser.parse_known_args(base)
self.assertIsNone(args.fsdp)
# Bare `--fsdp` should resolve to a fully enabled FSDP setup through `_process_fsdp_args`.
with mockenv_context(**self.dist_env_1_gpu):
trainer_args = TrainingArguments(output_dir="/tmp/x", fsdp=True)
self.assertIs(trainer_args.fsdp, True)
self.assertIsNotNone(trainer_args.fsdp_plugin_args)
@parameterized.expand(config_params, name_func=_parameterized_custom_name_func)
def test_fsdp_config(self, sharding_strategy, dtype):
output_dir = self.get_auto_remove_tmp_dir()
kwargs = {
"output_dir": output_dir,
"train_len": 128,
"save_steps": 5,
"learning_rate": 0.1,
"fsdp": f"{sharding_strategy} offload auto_wrap",
"fsdp_config": self.fsdp_config,
}
kwargs[dtype] = True
with mockenv_context(**self.dist_env_1_gpu):
trainer = get_regression_trainer(**kwargs)
self.assertIs(trainer.args.fsdp, True)
self.assertTrue(trainer.args.fsdp_config.get("cpu_offload"))
for k, v in self.fsdp_config.items():
self.assertEqual(trainer.args.fsdp_config[k], v)
# ---------------------------------------------------------------------------
# FSDP distributed tests
# ---------------------------------------------------------------------------
@require_torch_multi_accelerator
class TestTrainerDistributedFSDP(FSDPCommandsMixin, TestCasePlus):
def _run_env_check(self, cmd, num_processes):
"""Run the env check script and return per-rank results."""
execute_subprocess_async(cmd, env=self.get_env())
# output_dir is always the last script_arg value
output_dir = cmd[cmd.index("--output_dir") + 1]
results = []
for rank in range(num_processes):
with open(os.path.join(output_dir, f"env_rank{rank}.json")) as f:
results.append(json.load(f))
return results
def test_torchrun_accelerate_fsdp1_env_parity(self):
"""Verify torchrun+--fsdp and accelerate launch produce the same FSDP1 env."""
script = os.path.join(SCRIPTS_DIR, "torchrun_env_check.py")
num_processes = backend_device_count(torch_device)
torchrun_dir = self.get_auto_remove_tmp_dir()
torchrun_results = self._run_env_check(
self.get_torchrun_cmd(
script,
script_args=[
"--output_dir",
torchrun_dir,
"--fsdp",
"full_shard",
"--fsdp_config",
'{"fsdp_version": 1}',
],
num_processes=num_processes,
),
num_processes,
)
accel_dir = self.get_auto_remove_tmp_dir()
accel_results = self._run_env_check(
self.get_accelerate_cmd(
script, FSDP_CONFIG_FILE, script_args=["--output_dir", accel_dir], num_processes=num_processes
),
num_processes,
)
self._check_parity(torchrun_results, accel_results, num_processes, expected_fsdp_version=1)
def test_torchrun_accelerate_fsdp2_env_parity(self):
"""Verify torchrun+--fsdp and accelerate launch produce the same FSDP2 env."""
script = os.path.join(SCRIPTS_DIR, "torchrun_env_check.py")
num_processes = backend_device_count(torch_device)
torchrun_dir = self.get_auto_remove_tmp_dir()
torchrun_results = self._run_env_check(
self.get_torchrun_cmd(
script,
script_args=[
"--output_dir",
torchrun_dir,
"--fsdp",
"full_shard",
"--fsdp_config",
'{"fsdp_version": 2}',
],
num_processes=num_processes,
),
num_processes,
)
accel_dir = self.get_auto_remove_tmp_dir()
accel_results = self._run_env_check(
self.get_accelerate_cmd(
script, FSDP2_CONFIG_FILE, script_args=["--output_dir", accel_dir], num_processes=num_processes
),
num_processes,
)
self._check_parity(torchrun_results, accel_results, num_processes, expected_fsdp_version=2)
def _check_parity(self, torchrun_results, accel_results, num_processes, expected_fsdp_version):
for rank in range(num_processes):
tr, ac = torchrun_results[rank], accel_results[rank]
# Both should agree on distributed env
self.assertEqual(tr["args_world_size"], ac["args_world_size"])
self.assertEqual(tr["args_process_index"], ac["args_process_index"])
self.assertEqual(tr["args_parallel_mode"], ac["args_parallel_mode"])
self.assertEqual(tr["accelerator_num_processes"], ac["accelerator_num_processes"])
self.assertEqual(tr["accelerator_use_distributed"], ac["accelerator_use_distributed"])
for info in (tr, ac):
# Rank consistency across all layers
self.assertEqual(info["env_world_size"], str(num_processes))
self.assertEqual(info["env_rank"], str(rank))
self.assertEqual(info["args_process_index"], rank)
self.assertEqual(info["args_local_process_index"], rank)
self.assertEqual(info["accelerator_process_index"], rank)
self.assertEqual(info["accelerator_local_process_index"], rank)
self.assertEqual(info["args_n_gpu"], 1)
self.assertEqual(info["accelerator_is_main_process"], rank == 0)
self.assertEqual(info["accelerator_is_local_main_process"], rank == 0)
self.assertIn(f"{torch_device}:{rank}", info["accelerator_device"])
# Both should have FSDP enabled with the correct version
self.assertEqual(info["accelerator_distributed_type"], "DistributedType.FSDP")
self.assertTrue(info["trainer_is_fsdp_enabled"])
self.assertFalse(info["trainer_is_deepspeed_enabled"])
self.assertEqual(info["fsdp_version"], expected_fsdp_version)
self.assertNotIn("deepspeed_zero_stage", info)
# ---------------------------------------------------------------------------
# All distributed FSDP training tests
# ---------------------------------------------------------------------------
@slow
@require_torch_multi_accelerator
class TestTrainerDistributedFSDPCommon(
FSDPCommandsMixin, TrainerDistributedCommon, TestCasePlus, TrainerIntegrationCommon
):
# -------------------------------------------------------------------
# FSDP training — accelerate (parameterized over fsdp version)
# -------------------------------------------------------------------
# Pure dtype training: model loaded in target dtype, no mixed precision
@parameterized.expand(pure_dtype_params, name_func=_parameterized_custom_name_func)
def test_training(self, dtype, fsdp_version):
self.check_training(dtype, config_file=FSDP_CONFIGS[fsdp_version])
# Mixed precision: model loaded in fp32, training with --bf16/--fp16
@parameterized.expand(mixed_precision_params, name_func=_parameterized_custom_name_func)
def test_training_mixed_precision(self, sharding_strategy, dtype, fsdp_version):
if fsdp_version == "fsdp2":
reshard = "true" if sharding_strategy == "full_shard" else "false"
else:
reshard = sharding_strategy.upper()
launch_args = list(TRAIN_LAUNCH_ARGS) + ["--fsdp_reshard_after_forward", reshard]
self.check_mixed_precision(dtype, config_file=FSDP_CONFIGS[fsdp_version], launch_args=launch_args)
@parameterized.expand(["true", "false"], name_func=_parameterized_custom_name_func)
def test_fsdp2_cpu_ram_efficient_loading(self, cpu_ram_efficient_loading):
launch_args = list(TRAIN_LAUNCH_ARGS) + [
"--fsdp_cpu_ram_efficient_loading",
cpu_ram_efficient_loading,
]
self.check_training("bf16", config_file=FSDP2_CONFIG_FILE, launch_args=launch_args)
@parameterized.expand(fsdp_versions, name_func=_parameterized_custom_name_func)
def test_training_with_gradient_accumulation(self, fsdp_version):
self.check_gradient_accumulation(config_file=FSDP_CONFIGS[fsdp_version])
@parameterized.expand(fsdp_versions, name_func=_parameterized_custom_name_func)
def test_basic_run_with_cpu_offload(self, fsdp_version):
output_dir = self.get_auto_remove_tmp_dir()
args = self._get_default_script_args(output_dir) + ["--bf16", "--max_steps", "10"]
launch_args = list(TRAIN_LAUNCH_ARGS) + ["--fsdp_offload_params", "true"]
execute_subprocess_async(
self.get_accelerate_cmd(
TRAIN_SCRIPT, script_args=args, config_file=FSDP_CONFIGS[fsdp_version], launch_args=launch_args
),
env=self.get_env(),
)
@parameterized.expand(resume_params, name_func=_parameterized_custom_name_func)
def test_training_and_can_resume_normally(self, state_dict_type, fsdp_version):
output_dir = self.get_auto_remove_tmp_dir()
args = self._get_default_script_args(output_dir, num_epochs=2, logging_steps=2, save_steps=2)
launch_args = list(TRAIN_LAUNCH_ARGS) + ["--fsdp_state_dict_type", state_dict_type]
cmd_kwargs = {"config_file": FSDP_CONFIGS[fsdp_version], "launch_args": launch_args}
logs = self._train_and_get_log_history(
self.get_accelerate_cmd(TRAIN_SCRIPT, script_args=args, **cmd_kwargs),
output_dir,
)
# resume from ckpt
checkpoint = os.path.join(output_dir, "checkpoint-2")
resume_args = args + ["--resume_from_checkpoint", checkpoint]
is_fsdp_ckpt = os.path.isdir(checkpoint) and (
# this checks the FSDP state dict when `SHARDED_STATE_DICT` is used
any(
FSDP_MODEL_NAME in folder_name
for folder_name in os.listdir(checkpoint)
if os.path.isdir(os.path.join(checkpoint, folder_name))
)
# this checks the FSDP state dict when `FULL_STATE_DICT` is used
or os.path.isfile(os.path.join(checkpoint, f"{FSDP_MODEL_NAME}.bin"))
)
self.assertTrue(is_fsdp_ckpt)
logs_resume = self._train_and_get_log_history(
self.get_accelerate_cmd(TRAIN_SCRIPT, script_args=resume_args, **cmd_kwargs),
output_dir,
)
for log, log1 in zip(logs, logs_resume):
if "learning_rate" in log:
self.assertAlmostEqual(log["learning_rate"], log1["learning_rate"], delta=1e-5)
# -------------------------------------------------------------------
# Context parallel tests
# -------------------------------------------------------------------
def test_cp_equivalence(self):
"""Test that CP produces the same losses as without CP."""
# CP doesn't work with Qwen2 (DTensor mixing error), so we use Llama here.
launch_args = list(TRAIN_LAUNCH_ARGS) + ["--fsdp_state_dict_type", "SHARDED_STATE_DICT"]
cp_script_args = [
"--model_name",
"hf-internal-testing/tiny-random-LlamaForCausalLM",
"--max_steps",
"10",
"--per_device_train_batch_size",
"1",
"--seed",
"42",
"--logging_steps",
"1",
"--save_strategy",
"no",
"--model_dtype",
"fp32",
"--attn_implementation",
"sdpa",
"--pad_to_multiple_of",
"4",
]
# Step 1: Run with CP enabled (cp_size=2)
cp_yes_output_dir = Path(self.get_auto_remove_tmp_dir()).resolve()
cp_yes_losses_path = cp_yes_output_dir / "cp_yes_losses.json"
cmd = self.get_accelerate_cmd(
TRAIN_SCRIPT,
config_file=FSDP2_CP_CONFIG_FILE,
launch_args=launch_args,
script_args=["--output_dir", str(cp_yes_output_dir), "--loss_output_file", str(cp_yes_losses_path)]
+ cp_script_args,
)
execute_subprocess_async(cmd, env=self.get_env())
# Step 2: Run without CP (FSDP with num_processes=1, no parallelism_config)
cp_no_output_dir = Path(self.get_auto_remove_tmp_dir()).resolve()
cp_no_losses_path = cp_no_output_dir / "cp_no_losses.json"
cmd = self.get_accelerate_cmd(
TRAIN_SCRIPT,
config_file=FSDP2_CONFIG_FILE,
launch_args=launch_args,
script_args=[
"--output_dir",
str(cp_no_output_dir),
"--loss_output_file",
str(cp_no_losses_path),
]
+ cp_script_args,
num_processes=1,
)
execute_subprocess_async(cmd, env=self.get_env())
# Compare losses
with open(cp_yes_losses_path) as f:
cp_yes_losses = json.load(f)
with open(cp_no_losses_path) as f:
cp_no_losses = json.load(f)
assert len(cp_yes_losses) == len(cp_no_losses), (
f"Different number of losses: CP has {len(cp_yes_losses)}, no-CP has {len(cp_no_losses)}"
)
cp_yes_losses_tensor = torch.tensor(cp_yes_losses)
cp_no_losses_tensor = torch.tensor(cp_no_losses)
torch.testing.assert_close(
cp_yes_losses_tensor,
cp_no_losses_tensor,
rtol=2e-2,
atol=2e-2,
msg=f"CP losses {cp_yes_losses} do not match non-CP losses {cp_no_losses}",
)
# -------------------------------------------------------------------
# FSDP eval tests
# -------------------------------------------------------------------
def test_eval(self):
self.check_eval(config_file=FSDP_CONFIG_FILE)
# -------------------------------------------------------------------
# FSDP generation tests (moved from tests/generation/test_fsdp.py)
# -------------------------------------------------------------------
def test_fsdp_generate(self):
cmd = self.get_accelerate_cmd(
FSDP_GENERATE_SCRIPT,
config_file=FSDP_CONFIG_FILE,
script_args=["--fsdp"],
)
execute_subprocess_async(cmd, env=self.get_env())
def test_fsdp2_generate(self):
cmd = self.get_accelerate_cmd(
FSDP_GENERATE_SCRIPT,
config_file=FSDP2_CONFIG_FILE,
script_args=["--fsdp2"],
)
execute_subprocess_async(cmd, env=self.get_env())