Some checks failed
Self-hosted runner (nightly-past-ci-caller) / Get number (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.11 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.10 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.9 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.8 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.7 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.6 (push) Has been cancelled
Self-hosted runner (nightly-past-ci-caller) / TensorFlow 2.5 (push) Has been cancelled
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Has been cancelled
Build documentation / build (push) Has been cancelled
Build documentation / build_other_lang (push) Has been cancelled
CodeQL Security Analysis / CodeQL Analysis (push) Has been cancelled
New model PR merged notification / Notify new model (push) Has been cancelled
PR CI / pr-ci (push) Has been cancelled
Slow tests on important models (on Push - A10) / Get all modified files (push) Has been cancelled
Secret Leaks / trufflehog (push) Has been cancelled
Update Transformers metadata / build_and_package (push) Has been cancelled
Slow tests on important models (on Push - A10) / Model CI (push) Has been cancelled
Check Tiny Models / Check tiny models (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Model CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Pipeline CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Example CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / DeepSpeed CI (push) Has been cancelled
Self-hosted runner (Intel Gaudi3 scheduled CI caller) / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI - Flash Attn / Setup (push) Has been cancelled
Nvidia CI - Flash Attn / Model CI (push) Has been cancelled
Nvidia CI / Setup (push) Has been cancelled
Nvidia CI / Model CI (push) Has been cancelled
Nvidia CI / Torch pipeline CI (push) Has been cancelled
Nvidia CI / Example CI (push) Has been cancelled
Nvidia CI / Trainer/FSDP CI (push) Has been cancelled
Nvidia CI / DeepSpeed CI (push) Has been cancelled
Nvidia CI / Quantization CI (push) Has been cancelled
Nvidia CI / Kernels CI (push) Has been cancelled
Doctests / Setup (push) Has been cancelled
Doctests / Call doctest jobs (push) Has been cancelled
Doctests / Send results to webhook (push) Has been cancelled
Extras Smoke Test / Get supported Python versions (push) Has been cancelled
Extras Smoke Test / Test extras on Python ${{ matrix.python-version }} (push) Has been cancelled
Extras Smoke Test / Check Slack token availability (push) Has been cancelled
Extras Smoke Test / Notify failures to Slack (push) Has been cancelled
Self-hosted runner (AMD scheduled CI caller) / Trigger Scheduled AMD CI (push) Has been cancelled
Stale Bot / Close Stale Issues (push) Has been cancelled
643 lines
27 KiB
Python
643 lines
27 KiB
Python
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
import gc
|
|
import unittest
|
|
from contextlib import ExitStack, contextmanager
|
|
from unittest.mock import patch
|
|
|
|
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, MetalConfig, OPTForCausalLM
|
|
from transformers.quantizers.quantizer_metal import MetalHfQuantizer
|
|
from transformers.testing_utils import (
|
|
require_torch,
|
|
slow,
|
|
torch_device,
|
|
)
|
|
from transformers.utils import is_torch_available
|
|
|
|
|
|
if is_torch_available():
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
@contextmanager
def _patch_mps_available(available: bool = True):
    """Temporarily force ``torch.backends.mps.is_available`` to return ``available``.

    Args:
        available: Value the patched ``is_available`` reports while the context is active.

    The real function is restored when the context exits.
    """
    with patch("torch.backends.mps.is_available", return_value=available):
        yield
|
|
|
|
|
|
@contextmanager
def _patch_no_mps():
    """Convenience wrapper: make ``torch.backends.mps.is_available`` report ``False``."""
    with _patch_mps_available(False):
        yield
|
|
|
|
|
|
@contextmanager
def _patch_has_mps():
    """Convenience wrapper: simulate a machine with an MPS device.

    Both ``torch.backends.mps.is_available`` and the ``is_kernels_available`` name used by the Metal
    quantizer module are patched to return ``True`` for the duration of the context.
    """
    kernels_check = "transformers.quantizers.quantizer_metal.is_kernels_available"
    with _patch_mps_available(True), patch(kernels_check, return_value=True):
        yield
|
|
|
|
|
|
@require_torch
class MetalConfigTest(unittest.TestCase):
    """Unit tests for ``MetalConfig`` (no device / model needed)."""

    def test_default_values(self):
        """A bare ``MetalConfig()`` is 4-bit, group size 64, nothing excluded, no dequantize."""
        config = MetalConfig()
        self.assertEqual(config.bits, 4)
        self.assertEqual(config.group_size, 64)
        self.assertIsNone(config.modules_to_not_convert)
        self.assertFalse(config.dequantize)
        self.assertEqual(config.quant_method, "metal")

    def test_custom_values(self):
        """Explicit constructor arguments are stored on the config unchanged."""
        config = MetalConfig(bits=8, group_size=32, modules_to_not_convert=["lm_head"], dequantize=True)
        self.assertEqual(config.bits, 8)
        self.assertEqual(config.group_size, 32)
        self.assertEqual(config.modules_to_not_convert, ["lm_head"])
        self.assertTrue(config.dequantize)

    def test_invalid_bits_raises(self):
        """Bit widths other than 2, 4 and 8 are rejected with ``ValueError``."""
        for bad_bits in (0, 1, 3, 5, 6, 7, 16):
            # subTest keeps iterating after a failure so every offending width is reported.
            with self.subTest(bits=bad_bits):
                with self.assertRaises(ValueError, msg=f"bits={bad_bits} should raise"):
                    MetalConfig(bits=bad_bits)

    def test_valid_bits(self):
        """Bit widths 2, 4 and 8 are accepted and stored."""
        for bits in (2, 4, 8):
            with self.subTest(bits=bits):
                config = MetalConfig(bits=bits)
                self.assertEqual(config.bits, bits)

    def test_invalid_group_size_raises(self):
        """Zero and negative group sizes are rejected with ``ValueError``."""
        for bad_group_size in (0, -1):
            with self.subTest(group_size=bad_group_size):
                with self.assertRaises(ValueError):
                    MetalConfig(group_size=bad_group_size)

    def test_to_dict(self):
        """``to_dict`` exposes the quantization method and every configured field."""
        config = MetalConfig(bits=4, group_size=64, modules_to_not_convert=["lm_head"])
        d = config.to_dict()
        self.assertEqual(d["quant_method"], "metal")
        self.assertEqual(d["bits"], 4)
        self.assertEqual(d["group_size"], 64)
        self.assertEqual(d["modules_to_not_convert"], ["lm_head"])

    def test_from_dict(self):
        """``from_dict`` rebuilds a config from a plain dictionary."""
        d = {"quant_method": "metal", "bits": 8, "group_size": 32, "modules_to_not_convert": None}
        config = MetalConfig.from_dict(d)
        self.assertEqual(config.bits, 8)
        self.assertEqual(config.group_size, 32)

    def test_to_dict_from_dict(self):
        """A ``to_dict`` -> ``from_dict`` round trip preserves the configured fields."""
        original = MetalConfig(bits=2, group_size=128, modules_to_not_convert=["lm_head"])
        restored = MetalConfig.from_dict(original.to_dict())
        self.assertEqual(original.bits, restored.bits)
        self.assertEqual(original.group_size, restored.group_size)
        self.assertEqual(original.modules_to_not_convert, restored.modules_to_not_convert)

    def test_get_loading_attributes(self):
        """``dequantize`` is reported among the loading attributes."""
        config = MetalConfig(dequantize=True)
        attrs = config.get_loading_attributes()
        self.assertIn("dequantize", attrs)
        self.assertTrue(attrs["dequantize"])
|
|
|
|
|
|
@require_torch
class MetalQuantizerEnvironmentTest(unittest.TestCase):
    """Validate ``MetalHfQuantizer.validate_environment`` under various conditions."""

    def test_no_mps_prequantized_triggers_dequantize(self):
        """Pre-quantized model on non-MPS machine should auto-enable dequantize."""
        with _patch_no_mps():
            quantizer = MetalHfQuantizer(MetalConfig())
            quantizer.pre_quantized = True
            quantizer.validate_environment()
        self.assertTrue(quantizer.quantization_config.dequantize)

    def test_no_mps_not_prequantized_raises(self):
        """Quantize-on-the-fly on non-MPS machine should raise."""
        with _patch_no_mps():
            quantizer = MetalHfQuantizer(MetalConfig())
            quantizer.pre_quantized = False
            with self.assertRaises(RuntimeError):
                quantizer.validate_environment()

    def test_dequantize_flag_skips_mps_check(self):
        """When dequantize=True, no MPS check should occur."""
        with _patch_no_mps():
            quantizer = MetalHfQuantizer(MetalConfig(dequantize=True))
            quantizer.pre_quantized = True
            # Passing means this call does not raise even though MPS is reported unavailable.
            quantizer.validate_environment()

    def test_missing_kernels_raises(self):
        """Missing ``kernels`` package should raise ImportError."""
        kernels_check = "transformers.quantizers.quantizer_metal.is_kernels_available"
        with _patch_mps_available(True), patch(kernels_check, return_value=False):
            quantizer = MetalHfQuantizer(MetalConfig())
            quantizer.pre_quantized = False
            with self.assertRaises(ImportError):
                quantizer.validate_environment()

    def test_cpu_in_device_map_not_prequantized_raises(self):
        """Quantize-on-the-fly with CPU in device_map should raise."""
        with _patch_has_mps():
            quantizer = MetalHfQuantizer(MetalConfig())
            quantizer.pre_quantized = False
            with self.assertRaises(ValueError):
                quantizer.validate_environment(device_map={"model": "cpu"})

    def test_disk_in_device_map_not_prequantized_raises(self):
        """Quantize-on-the-fly with disk in device_map should raise."""
        with _patch_has_mps():
            quantizer = MetalHfQuantizer(MetalConfig())
            quantizer.pre_quantized = False
            with self.assertRaises(ValueError):
                quantizer.validate_environment(device_map={"model": "disk"})

    def test_update_device_map_defaults_to_mps(self):
        """A ``None`` device map is replaced by one placing the whole model on ``mps``."""
        quantizer = MetalHfQuantizer(MetalConfig())
        result = quantizer.update_device_map(None)
        self.assertEqual(result, {"": "mps"})

    def test_is_serializable(self):
        """The quantizer reports itself as serializable."""
        quantizer = MetalHfQuantizer(MetalConfig())
        self.assertTrue(quantizer.is_serializable())

    def test_is_not_trainable(self):
        """The quantizer reports itself as not trainable."""
        quantizer = MetalHfQuantizer(MetalConfig())
        self.assertFalse(quantizer.is_trainable)
|
|
|
|
|
|
@require_torch
class AffineQuantizeDequantizeTest(unittest.TestCase):
    """Test the low-level ``_affine_quantize_tensor`` / ``_affine_dequantize_tensor`` functions."""

    def setUp(self):
        # The round-trip tests compare an error against a fixed threshold on random weights;
        # seeding keeps the draw reproducible instead of depending on global RNG state.
        torch.manual_seed(0)

    def _roundtrip(self, bits, group_size, N=64, K=256, dtype=None):
        """Quantize then dequantize a random ``(N, K)`` weight, checking shapes and dtypes on the way.

        Args:
            bits: Quantization bit width passed to both functions.
            group_size: Group size passed to both functions.
            N: Number of rows of the random weight.
            K: Number of columns of the random weight.
            dtype: dtype of the random weight; ``None`` means ``torch.float32``.

        Returns:
            ``(original, dequantized)``, both converted to float32.
        """
        from transformers.integrations.metal_quantization import _affine_dequantize_tensor, _affine_quantize_tensor

        # The default is resolved here rather than in the signature: ``torch`` is only imported when
        # it is available, and a ``torch.float32`` default would be evaluated at class-definition time.
        if dtype is None:
            dtype = torch.float32

        weight = torch.randn(N, K, dtype=dtype)
        w_packed, scales, biases = _affine_quantize_tensor(weight, group_size, bits)

        self.assertEqual(w_packed.dtype, torch.uint32)
        # ``32 // bits`` quantized values are packed into each uint32 along K.
        self.assertEqual(w_packed.shape, (N, K // (32 // bits)))
        self.assertEqual(scales.shape, (N, K // group_size))
        self.assertEqual(biases.shape, (N, K // group_size))

        w_deq = _affine_dequantize_tensor(w_packed, scales, biases, group_size, bits)
        self.assertEqual(w_deq.shape, (N, K))

        return weight.float(), w_deq.float()

    def test_roundtrip_4bit_gs64(self):
        """4-bit, group size 64: max absolute round-trip error stays below 0.30."""
        orig, deq = self._roundtrip(bits=4, group_size=64)
        max_err = (orig - deq).abs().max().item()
        self.assertLess(max_err, 0.30, "4-bit gs=64 round-trip error too large")

    def test_roundtrip_4bit_gs128(self):
        """4-bit, group size 128: max absolute round-trip error stays below 0.5."""
        orig, deq = self._roundtrip(bits=4, group_size=128)
        max_err = (orig - deq).abs().max().item()
        self.assertLess(max_err, 0.5, "4-bit gs=128 round-trip error too large")

    def test_roundtrip_8bit_gs64(self):
        """8-bit, group size 64: max absolute round-trip error stays below 0.02."""
        orig, deq = self._roundtrip(bits=8, group_size=64)
        max_err = (orig - deq).abs().max().item()
        self.assertLess(max_err, 0.02, "8-bit gs=64 round-trip error too large")

    def test_roundtrip_2bit_gs64(self):
        """2-bit, group size 64: max absolute round-trip error stays below 1.50."""
        orig, deq = self._roundtrip(bits=2, group_size=64)
        max_err = (orig - deq).abs().max().item()
        self.assertLess(max_err, 1.50, "2-bit gs=64 round-trip error too large")

    def test_quantize_shapes_2bit(self):
        """2-bit packing stores 16 values per uint32 and one scale per group of 64."""
        from transformers.integrations.metal_quantization import _affine_quantize_tensor

        N, K = 32, 128
        weight = torch.randn(N, K)
        w_packed, scales, biases = _affine_quantize_tensor(weight, group_size=64, bits=2)
        elems_per_int = 32 // 2
        self.assertEqual(w_packed.shape, (N, K // elems_per_int))
        self.assertEqual(scales.shape, (N, K // 64))

    def test_quantize_preserves_device(self):
        """Packed weight, scales and biases stay on the input tensor's device (CPU here)."""
        from transformers.integrations.metal_quantization import _affine_quantize_tensor

        weight = torch.randn(32, 128, device="cpu")
        w_packed, scales, biases = _affine_quantize_tensor(weight, group_size=64, bits=4)
        self.assertEqual(w_packed.device.type, "cpu")
        self.assertEqual(scales.device.type, "cpu")
        self.assertEqual(biases.device.type, "cpu")

    def test_dequantize_returns_correct_dtype(self):
        """Regression: dequantize should always return float32 (caller casts to target dtype)."""
        from transformers.integrations.metal_quantization import _affine_dequantize_tensor, _affine_quantize_tensor

        weight = torch.randn(32, 128, dtype=torch.bfloat16)
        w_packed, scales, biases = _affine_quantize_tensor(weight, group_size=64, bits=4)
        w_deq = _affine_dequantize_tensor(w_packed, scales, biases, group_size=64, bits=4)
        self.assertEqual(w_deq.dtype, torch.float32)
|
|
|
|
|
|
@require_torch
class MetalLinearTest(unittest.TestCase):
    """Test the ``MetalLinear`` nn.Module directly (CPU, no kernel calls)."""

    def test_prequantized_weight_shape(self):
        """Pre-quantized mode: weight should be uint32 with packed K dimension."""
        from transformers.integrations.metal_quantization import MetalLinear

        layer = MetalLinear(in_features=256, out_features=128, bits=4, group_size=64)
        elems_per_int = 32 // 4
        self.assertEqual(layer.weight.shape, (128, 256 // elems_per_int))
        self.assertEqual(layer.weight.dtype, torch.uint32)
        self.assertEqual(layer.scales.shape, (128, 256 // 64))
        self.assertEqual(layer.qbiases.shape, (128, 256 // 64))

    def test_quantize_on_the_fly_weight_shape(self):
        """Quantize-on-the-fly mode (dtype=None): weight should be full-shape float."""
        from transformers.integrations.metal_quantization import MetalLinear

        layer = MetalLinear(in_features=256, out_features=128, bits=4, group_size=64, dtype=None)
        self.assertEqual(layer.weight.shape, (128, 256))
        self.assertNotEqual(layer.weight.dtype, torch.uint32)

    def test_no_bias_by_default(self):
        """Without ``bias=True`` the layer has no bias."""
        from transformers.integrations.metal_quantization import MetalLinear

        layer = MetalLinear(in_features=128, out_features=64, bits=4, group_size=64)
        self.assertIsNone(layer.bias)

    def test_with_bias(self):
        """With ``bias=True`` the layer has a bias of shape ``(out_features,)``."""
        from transformers.integrations.metal_quantization import MetalLinear

        layer = MetalLinear(in_features=128, out_features=64, bias=True, bits=4, group_size=64)
        self.assertIsNotNone(layer.bias)
        self.assertEqual(layer.bias.shape, (64,))

    def test_forward_fallback_when_not_uint32(self):
        """When weight is not uint32, forward should use standard nn.functional.linear (no kernel needed)."""
        from transformers.integrations.metal_quantization import MetalLinear

        layer = MetalLinear(in_features=128, out_features=64, bits=4, group_size=64, dtype=None)
        layer.weight = nn.Parameter(torch.randn(64, 128))
        x = torch.randn(2, 5, 128)
        out = layer(x)
        self.assertEqual(out.shape, (2, 5, 64))

    def test_forward_fallback_with_bias(self):
        """Same fallback forward, with an explicit float bias attached."""
        from transformers.integrations.metal_quantization import MetalLinear

        layer = MetalLinear(in_features=128, out_features=64, bias=True, bits=4, group_size=64, dtype=None)
        layer.weight = nn.Parameter(torch.randn(64, 128))
        layer.bias = nn.Parameter(torch.randn(64))
        x = torch.randn(1, 10, 128)
        out = layer(x)
        self.assertEqual(out.shape, (1, 10, 64))

    def test_prequantized_shapes_8bit(self):
        """8-bit packing stores 4 values per uint32 along the input dimension."""
        from transformers.integrations.metal_quantization import MetalLinear

        layer = MetalLinear(in_features=256, out_features=128, bits=8, group_size=64)
        elems_per_int = 32 // 8  # 4
        self.assertEqual(layer.weight.shape, (128, 256 // elems_per_int))

    def test_prequantized_shapes_2bit(self):
        """2-bit packing stores 16 values per uint32 along the input dimension."""
        from transformers.integrations.metal_quantization import MetalLinear

        layer = MetalLinear(in_features=256, out_features=128, bits=2, group_size=64)
        elems_per_int = 32 // 2  # 16
        self.assertEqual(layer.weight.shape, (128, 256 // elems_per_int))
|
|
|
|
|
|
@require_torch
class ReplaceWithMetalLinearTest(unittest.TestCase):
    """Test module replacement logic."""

    def _make_small_model(self):
        """Build the tiny random OPT model on the meta device (no weights are allocated)."""
        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
        with torch.device("meta"):
            model = OPTForCausalLM(config)
        return model

    def _metal_layers(self, model):
        """Return every ``MetalLinear`` in ``model``, failing the test if there is none."""
        from transformers.integrations.metal_quantization import MetalLinear

        layers = [m for m in model.modules() if isinstance(m, MetalLinear)]
        # Without this guard the dtype checks below would pass vacuously on an unconverted model.
        self.assertGreater(len(layers), 0, "Expected at least one MetalLinear module after replacement")
        return layers

    def test_all_linears_replaced(self):
        """Every ``nn.Linear`` present before replacement is a ``MetalLinear`` afterwards."""
        from transformers.integrations.metal_quantization import MetalLinear, replace_with_metal_linear

        model = self._make_small_model()
        nb_linears = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
        self.assertGreater(nb_linears, 0)

        config = MetalConfig(bits=4, group_size=64)
        replace_with_metal_linear(model, quantization_config=config, pre_quantized=True)

        nb_metal = sum(1 for m in model.modules() if isinstance(m, MetalLinear))
        self.assertEqual(nb_linears, nb_metal)

    def test_modules_to_not_convert(self):
        """``lm_head`` is left untouched when listed in ``modules_to_not_convert``."""
        from transformers.integrations.metal_quantization import MetalLinear, replace_with_metal_linear

        model = self._make_small_model()
        config = MetalConfig(bits=4, group_size=64)
        replace_with_metal_linear(
            model, modules_to_not_convert=["lm_head"], quantization_config=config, pre_quantized=True
        )
        self.assertNotIsInstance(model.lm_head, MetalLinear)

        # Both counts are taken after replacement; exactly one ``nn.Linear`` instance (``lm_head``)
        # is expected not to be a ``MetalLinear``.
        nb_metal = sum(1 for m in model.modules() if isinstance(m, MetalLinear))
        nb_linears = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
        self.assertEqual(nb_metal, nb_linears - 1)

    def test_dequantize_skips_replacement(self):
        """With ``dequantize=True`` no module is swapped for a ``MetalLinear``."""
        from transformers.integrations.metal_quantization import MetalLinear, replace_with_metal_linear

        model = self._make_small_model()
        config = MetalConfig(bits=4, group_size=64, dequantize=True)
        replace_with_metal_linear(model, quantization_config=config, pre_quantized=True)

        nb_metal = sum(1 for m in model.modules() if isinstance(m, MetalLinear))
        self.assertEqual(nb_metal, 0, "No modules should be replaced when dequantize=True")

    def test_prequantized_dtype_is_uint32(self):
        """Pre-quantized replacement creates layers whose weight dtype is uint32."""
        from transformers.integrations.metal_quantization import replace_with_metal_linear

        model = self._make_small_model()
        config = MetalConfig(bits=4, group_size=64)
        replace_with_metal_linear(model, quantization_config=config, pre_quantized=True)

        for layer in self._metal_layers(model):
            self.assertEqual(layer.weight.dtype, torch.uint32)

    def test_quantize_on_the_fly_dtype_is_not_uint32(self):
        """Quantize-on-the-fly replacement creates layers whose weight dtype is not uint32."""
        from transformers.integrations.metal_quantization import replace_with_metal_linear

        model = self._make_small_model()
        config = MetalConfig(bits=4, group_size=64)
        replace_with_metal_linear(model, quantization_config=config, pre_quantized=False)

        for layer in self._metal_layers(model):
            self.assertNotEqual(layer.weight.dtype, torch.uint32)
|
|
|
|
|
|
@require_torch
class MetalConversionOpsTest(unittest.TestCase):
    """Test the ``MetalQuantize`` and ``MetalDequantize`` weight conversion operations."""

    def _make_quantizer(self, bits=4, group_size=64):
        """Return a ``MetalHfQuantizer`` for the given settings with ``pre_quantized`` set to ``False``."""
        config = MetalConfig(bits=bits, group_size=group_size)
        quantizer = MetalHfQuantizer(config)
        quantizer.pre_quantized = False
        return quantizer

    def _dequantize(self, dq_quantizer, q_result):
        """Run ``MetalDequantize`` on the ``layer.*`` tensors produced by ``MetalQuantize``.

        Args:
            dq_quantizer: Quantizer handed to ``MetalDequantize``.
            q_result: Output of ``MetalQuantize.convert`` for a ``layer.weight`` input.

        Returns:
            The dictionary returned by ``MetalDequantize.convert``.
        """
        from transformers.integrations.metal_quantization import MetalDequantize

        dq_op = MetalDequantize(dq_quantizer)
        return dq_op.convert(
            {
                "weight$": [q_result["layer.weight"]],
                "scales": [q_result["layer.scales"]],
                "qbiases": [q_result["layer.qbiases"]],
            },
            full_layer_name="layer.weight",
        )

    def test_metal_quantize_produces_correct_keys(self):
        """Quantizing ``<prefix>.weight`` yields packed weight, ``scales`` and ``qbiases`` entries."""
        from transformers.integrations.metal_quantization import MetalQuantize

        op = MetalQuantize(self._make_quantizer())
        weight = torch.randn(64, 256)
        result = op.convert({"model.layer.weight": weight})
        self.assertIn("model.layer.weight", result)
        self.assertIn("model.layer.scales", result)
        self.assertIn("model.layer.qbiases", result)
        self.assertEqual(result["model.layer.weight"].dtype, torch.uint32)

    def test_metal_quantize_preserves_original_dtype(self):
        """``scales`` and ``qbiases`` keep the dtype of the weight that was quantized."""
        from transformers.integrations.metal_quantization import MetalQuantize

        op = MetalQuantize(self._make_quantizer())
        for dtype in (torch.float32, torch.float16, torch.bfloat16):
            with self.subTest(dtype=dtype):
                weight = torch.randn(64, 256, dtype=dtype)
                result = op.convert({"layer.weight": weight})
                self.assertEqual(result["layer.scales"].dtype, dtype, f"scales dtype mismatch for input {dtype}")
                self.assertEqual(result["layer.qbiases"].dtype, dtype, f"qbiases dtype mismatch for input {dtype}")

    def test_metal_dequantize_returns_target_dtype(self):
        """MetalDequantize should return a tensor in the same dtype as the scales."""
        from transformers.integrations.metal_quantization import MetalQuantize

        quantizer = self._make_quantizer()

        for dtype in (torch.float16, torch.bfloat16):
            with self.subTest(dtype=dtype):
                weight = torch.randn(64, 256, dtype=dtype)
                q_result = MetalQuantize(quantizer).convert({"layer.weight": weight})

                dq_quantizer = self._make_quantizer()
                dq_quantizer.pre_quantized = True
                dq_quantizer.quantization_config.dequantize = True

                dq_result = self._dequantize(dq_quantizer, q_result)
                self.assertEqual(
                    dq_result["layer.weight"].dtype, dtype, f"dequantized dtype should match scales ({dtype})"
                )

    def test_quantize_then_dequantize_roundtrip(self):
        """4-bit quantize -> dequantize keeps the max absolute error below 0.5."""
        from transformers.integrations.metal_quantization import MetalQuantize

        q_op = MetalQuantize(self._make_quantizer(bits=4, group_size=64))
        weight = torch.randn(64, 256)
        q_result = q_op.convert({"layer.weight": weight})

        dq_result = self._dequantize(self._make_quantizer(bits=4, group_size=64), q_result)
        w_deq = dq_result["layer.weight"].float()
        max_err = (weight - w_deq).abs().max().item()
        self.assertLess(max_err, 0.5, "Quantize -> Dequantize round-trip error too large")
|
|
|
|
|
|
@require_torch
class MetalWeightConversionsTest(unittest.TestCase):
    """Test ``MetalHfQuantizer.get_weight_conversions`` for each pre_quantized / dequantize combination."""

    def test_get_weight_conversions_empty_when_not_dequantize(self):
        """Pre-quantized without ``dequantize``: no conversions are returned."""
        quantizer = MetalHfQuantizer(MetalConfig())
        quantizer.pre_quantized = True
        self.assertEqual(quantizer.get_weight_conversions(), [])

    def test_get_weight_conversions_has_entry_when_dequantize(self):
        """Pre-quantized with ``dequantize=True``: exactly one conversion is returned."""
        quantizer = MetalHfQuantizer(MetalConfig(dequantize=True))
        quantizer.pre_quantized = True
        conversions = quantizer.get_weight_conversions()
        self.assertEqual(len(conversions), 1)

    def test_get_weight_conversions_empty_when_not_prequantized(self):
        """Not pre-quantized: no conversions are returned even with ``dequantize=True``."""
        quantizer = MetalHfQuantizer(MetalConfig(dequantize=True))
        quantizer.pre_quantized = False
        self.assertEqual(quantizer.get_weight_conversions(), [])
|
|
|
|
|
|
@require_torch
class MetalModelConversionTest(unittest.TestCase):
    """Test that a model is correctly converted on the meta device."""

    model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"

    def setUp(self):
        gc.collect()

    def tearDown(self):
        gc.collect()

    def _build_meta_model(self):
        """Instantiate the tiny OPT model on the meta device (no weights are allocated)."""
        config = AutoConfig.from_pretrained(self.model_id)
        with torch.device("meta"):
            return OPTForCausalLM(config)

    def _metal_layer_names(self, model):
        """Return the names of all ``MetalLinear`` modules, failing the test if there is none."""
        from transformers.integrations.metal_quantization import MetalLinear

        names = [name for name, module in model.named_modules() if isinstance(module, MetalLinear)]
        # Without this guard the per-layer checks would pass vacuously on an unconverted model.
        self.assertGreater(len(names), 0, "Expected at least one MetalLinear module after replacement")
        return names

    def test_quantized_model_conversion(self):
        """Every ``nn.Linear`` counted before replacement is a ``MetalLinear`` afterwards."""
        from transformers.integrations.metal_quantization import MetalLinear, replace_with_metal_linear

        quantization_config = MetalConfig(bits=4, group_size=64)
        model = self._build_meta_model()

        nb_linears = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
        model = replace_with_metal_linear(model, quantization_config=quantization_config, pre_quantized=True)
        nb_metal = sum(1 for m in model.modules() if isinstance(m, MetalLinear))
        self.assertEqual(nb_linears, nb_metal)

    def test_quantized_model_conversion_with_exclusion(self):
        """Converted layers plus the excluded ``out_proj`` layers account for every original linear."""
        from transformers.integrations.metal_quantization import MetalLinear, replace_with_metal_linear

        quantization_config = MetalConfig(bits=4, group_size=64)
        model = self._build_meta_model()

        nb_linears = sum(1 for m in model.modules() if isinstance(m, nn.Linear))
        model = replace_with_metal_linear(
            model, modules_to_not_convert=["out_proj"], quantization_config=quantization_config, pre_quantized=True
        )
        nb_metal = sum(1 for m in model.modules() if isinstance(m, MetalLinear))
        nb_excluded = sum(1 for name, m in model.named_modules() if "out_proj" in name and isinstance(m, nn.Linear))
        self.assertEqual(nb_metal + nb_excluded, nb_linears)

    def test_param_needs_quantization(self):
        """On-the-fly mode: only ``weight`` needs quantization, not ``scales`` / ``qbiases``."""
        from transformers.integrations.metal_quantization import replace_with_metal_linear

        quantization_config = MetalConfig(bits=4, group_size=64)
        model = self._build_meta_model()
        replace_with_metal_linear(model, quantization_config=quantization_config, pre_quantized=False)

        quantizer = MetalHfQuantizer(quantization_config)
        quantizer.pre_quantized = False

        for name in self._metal_layer_names(model):
            self.assertTrue(quantizer.param_needs_quantization(model, f"{name}.weight"))
            self.assertFalse(quantizer.param_needs_quantization(model, f"{name}.scales"))
            self.assertFalse(quantizer.param_needs_quantization(model, f"{name}.qbiases"))

    def test_param_needs_quantization_prequantized_is_false(self):
        """Pre-quantized mode: ``weight`` is never reported as needing quantization."""
        from transformers.integrations.metal_quantization import replace_with_metal_linear

        quantization_config = MetalConfig(bits=4, group_size=64)
        model = self._build_meta_model()
        replace_with_metal_linear(model, quantization_config=quantization_config, pre_quantized=True)

        quantizer = MetalHfQuantizer(quantization_config)
        quantizer.pre_quantized = True

        for name in self._metal_layer_names(model):
            self.assertFalse(
                quantizer.param_needs_quantization(model, f"{name}.weight"),
                "Pre-quantized weights should not be re-quantized",
            )
|
|
|
|
|
|
@slow
@require_torch
class MetalSlowIntegrationTest(unittest.TestCase):
    """Slow tests that actually load a model with Metal quantization.

    MPS availability is patched off while the checkpoint is loaded, so these tests do not require
    an MPS device.
    """

    model_id = "medmekk/Llama-3.2-1B-Instruct-metal"

    def setUp(self):
        gc.collect()

    def tearDown(self):
        gc.collect()

    def test_load_prequantized_dequantize_on_cpu(self):
        """Load a quantized checkpoint with dequantize=True on CPU and check no uint32 parameter remains."""
        with _patch_no_mps():
            config = MetalConfig(dequantize=True)
            model = AutoModelForCausalLM.from_pretrained(self.model_id, quantization_config=config, device_map="cpu")
        self.assertIsNotNone(model)
        for param in model.parameters():
            self.assertNotEqual(param.dtype, torch.uint32, "All weights should be dequantized")

    def test_quantized_model(self):
        """Load the checkpoint with MPS patched off and compare greedy generation to a fixed string."""
        with _patch_no_mps():
            config = MetalConfig(bits=4, group_size=64)
            model = AutoModelForCausalLM.from_pretrained(
                self.model_id, quantization_config=config, device_map=torch_device
            )
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.assertIsNotNone(model)

        # Named ``prompt`` rather than ``input`` so the builtin is not shadowed.
        prompt = "Hello, how are you?"
        EXPECTED_OUTPUT = "Hello, how are you? I'm doing well, thanks for asking. I"
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(torch_device)
        # Greedy decoding (do_sample=False) so the output can be compared against a fixed string.
        output = model.generate(input_ids, max_new_tokens=10, do_sample=False)
        self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
|