diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py index 465e98798..21ff5f890 100644 --- a/torchtune/training/quantization.py +++ b/torchtune/training/quantization.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. from typing import Callable, Optional +from warnings import warn from torchtune.utils._import_guard import _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API @@ -18,22 +19,29 @@ int8_dynamic_activation_int4_weight, quantize_, ) -from torchao.quantization.prototype.qat import ( - disable_4w_fake_quant, - disable_8da4w_fake_quant, - enable_4w_fake_quant, - enable_8da4w_fake_quant, - Int4WeightOnlyQATQuantizer, - Int8DynActInt4WeightQATQuantizer, -) -from torchao.quantization.prototype.qat._module_swap_api import ( - disable_4w_fake_quant_module_swap, - disable_8da4w_fake_quant_module_swap, - enable_4w_fake_quant_module_swap, - enable_8da4w_fake_quant_module_swap, - Int4WeightOnlyQATQuantizerModuleSwap, - Int8DynActInt4WeightQATQuantizerModuleSwap, -) + +try: + # torchao 0.7+ + from torchao.quantization.qat import ( + Int4WeightOnlyQATQuantizer, + Int8DynActInt4WeightQATQuantizer, + ) + from torchao.quantization.qat.linear import ( + disable_4w_fake_quant, + disable_8da4w_fake_quant, + enable_4w_fake_quant, + enable_8da4w_fake_quant, + ) +except ImportError: + # torchao 0.6 and before + from torchao.quantization.prototype.qat import ( + disable_4w_fake_quant, + disable_8da4w_fake_quant, + enable_4w_fake_quant, + enable_8da4w_fake_quant, + Int4WeightOnlyQATQuantizer, + Int8DynActInt4WeightQATQuantizer, + ) __all__ = [ @@ -52,9 +60,9 @@ _quantizer_mode_to_enable_fake_quant = {} -# ======================================================== -# int8 dynamic activations + int4 weight tensor subclass | -# ======================================================== +# ======================================== +# int8 dynamic activations + int4 weight | +# ======================================== class Int8DynActInt4WeightQuantizer: @@ -106,15 +114,15 @@ def quantize(self, model): _quantizer_mode_to_enable_fake_quant["4w-qat"] = enable_4w_fake_quant -# ============= -# module swap | -# ============= +# ====================== # +# Backward compatibility # +# ====================== # -# Note: QAT tensor subclass implementation in torchao only works -# with FSDP2 today. For other distribution strategies like DDP and -# FSDP1, users will need to fall back to the old module swap flow. # int4 weight-only +Int4WeightOnlyQATQuantizerModuleSwap = Int4WeightOnlyQATQuantizer +disable_4w_fake_quant_module_swap = disable_4w_fake_quant +enable_4w_fake_quant_module_swap = enable_4w_fake_quant _quantizer_to_mode[Int4WeightOnlyQATQuantizerModuleSwap] = "4w-qat-module-swap" _quantizer_mode_to_disable_fake_quant[ "4w-qat-module-swap" @@ -124,6 +132,9 @@ def quantize(self, model): ] = enable_4w_fake_quant_module_swap # int8 dynamic activations + int4 weight +Int8DynActInt4WeightQATQuantizerModuleSwap = Int8DynActInt4WeightQATQuantizer +disable_8da4w_fake_quant_module_swap = disable_8da4w_fake_quant +enable_8da4w_fake_quant_module_swap = enable_8da4w_fake_quant _quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap" _quantizer_mode_to_disable_fake_quant[ "8da4w-qat-module-swap" @@ -141,16 +152,22 @@ def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]: Currently supported: - - :class:`~torchao.quantization.quant_api.Int8DynActInt4WeightQuantizer`: "8da4w" (requires ``torch>=2.3.0``) - - :class:`~torchao.quantization.prototype.qat.Int8DynActInt4WeightQATQuantizer`: "8da4w-qat" (requires ``torch>=2.4.0``) - + - :class:`~torchtune.training.quantization.Int8DynActInt4WeightQuantizer`: "8da4w" + - :class:`~torchtune.training.quantization.Int4WeightOnlyQuantizer`: "4w" + - :class:`~torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer`: "8da4w-qat" + - :class:`~torchao.quantization.qat.Int4WeightOnlyQATQuantizer`: "4w-qat" Args: quantizer (Optional[Callable]): A callable object that implements the `quantize` method. Returns: Optional[str]: The quantization mode. """ - return _quantizer_to_mode.get(type(quantizer), None) + mode = _quantizer_to_mode.get(type(quantizer), None) + if mode is not None and "module-swap" in mode: + warn( + "*QuantizerModuleSwap is deprecated. Please use the version without 'ModuleSwap' instead" + ) + return mode def _get_disable_fake_quant(quantizer_mode: str) -> Callable: