From d8f24c1ad92da3ad98e664da7e2783835f96c892 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Tue, 22 Oct 2024 13:28:42 -0700
Subject: [PATCH] Update imports after QAT was moved out of prototype

Summary: https://github.com/pytorch/ao/pull/1091 moved QAT out of
prototype in torchao. This is a BC-breaking change, so torchtune also
needs to update its QAT imports.

Additionally, after https://github.com/pytorch/ao/issues/987 we decided
that QAT in torchao will use module swaps to insert fake quantize
operations, so there is no longer a need for a separate module-swap
quantizer. This commit therefore removes the `*ModuleSwapQuantizer`
option, keeping the old names as deprecated aliases for backward
compatibility.
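For downstream users updating their own imports, a minimal migration
sketch (illustrative only, not part of this patch; it mirrors the
try/except shim added in the diff below and assumes the quantizer's
default constructor arguments):

```python
try:
    # torchao 0.7+: QAT now lives outside of prototype
    from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
except ImportError:
    # torchao 0.6 and before
    from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

from torchtune.training.quantization import get_quantizer_mode

# The mapped mode string is unchanged by this patch
quantizer = Int8DynActInt4WeightQATQuantizer()
print(get_quantizer_mode(quantizer))  # "8da4w-qat"
```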
Test Plan:
`pytest -m integration_test tests/recipes/test_qat_distributed.py`
should work.
---
 torchtune/training/quantization.py | 70 ++++++++++++++++++------------
 torchtune/utils/_import_guard.py   |  2 +-
 2 files changed, 44 insertions(+), 28 deletions(-)

diff --git a/torchtune/training/quantization.py b/torchtune/training/quantization.py
index 465e987981..db93529dca 100644
--- a/torchtune/training/quantization.py
+++ b/torchtune/training/quantization.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from typing import Callable, Optional
+from warnings import warn
 
 from torchtune.utils._import_guard import _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API
 
@@ -18,22 +19,29 @@
     int8_dynamic_activation_int4_weight,
     quantize_,
 )
-from torchao.quantization.prototype.qat import (
-    disable_4w_fake_quant,
-    disable_8da4w_fake_quant,
-    enable_4w_fake_quant,
-    enable_8da4w_fake_quant,
-    Int4WeightOnlyQATQuantizer,
-    Int8DynActInt4WeightQATQuantizer,
-)
-from torchao.quantization.prototype.qat._module_swap_api import (
-    disable_4w_fake_quant_module_swap,
-    disable_8da4w_fake_quant_module_swap,
-    enable_4w_fake_quant_module_swap,
-    enable_8da4w_fake_quant_module_swap,
-    Int4WeightOnlyQATQuantizerModuleSwap,
-    Int8DynActInt4WeightQATQuantizerModuleSwap,
-)
+
+try:
+    # torchao 0.7+
+    from torchao.quantization.qat import (
+        Int4WeightOnlyQATQuantizer,
+        Int8DynActInt4WeightQATQuantizer,
+    )
+    from torchao.quantization.qat.linear import (
+        disable_4w_fake_quant,
+        disable_8da4w_fake_quant,
+        enable_4w_fake_quant,
+        enable_8da4w_fake_quant,
+    )
+except ImportError:
+    # torchao 0.6 and before
+    from torchao.quantization.prototype.qat import (
+        disable_4w_fake_quant,
+        disable_8da4w_fake_quant,
+        enable_4w_fake_quant,
+        enable_8da4w_fake_quant,
+        Int4WeightOnlyQATQuantizer,
+        Int8DynActInt4WeightQATQuantizer,
+    )
 
 
 __all__ = [
@@ -52,9 +60,9 @@
 _quantizer_mode_to_enable_fake_quant = {}
 
 
-# ========================================================
-# int8 dynamic activations + int4 weight tensor subclass |
-# ========================================================
+# ========================================
+# int8 dynamic activations + int4 weight |
+# ========================================
 
 
 class Int8DynActInt4WeightQuantizer:
@@ -106,15 +114,15 @@ def quantize(self, model):
 _quantizer_mode_to_enable_fake_quant["4w-qat"] = enable_4w_fake_quant
 
 
-# =============
-# module swap |
-# =============
+# ====================== #
+# Backward compatibility #
+# ====================== #
 
-# Note: QAT tensor subclass implementation in torchao only works
-# with FSDP2 today. For other distribution strategies like DDP and
-# FSDP1, users will need to fall back to the old module swap flow.
 # int4 weight-only
+Int4WeightOnlyQATQuantizerModuleSwap = Int4WeightOnlyQATQuantizer
+disable_4w_fake_quant_module_swap = disable_4w_fake_quant
+enable_4w_fake_quant_module_swap = enable_4w_fake_quant
 _quantizer_to_mode[Int4WeightOnlyQATQuantizerModuleSwap] = "4w-qat-module-swap"
 _quantizer_mode_to_disable_fake_quant[
     "4w-qat-module-swap"
 ] = disable_4w_fake_quant_module_swap
 _quantizer_mode_to_enable_fake_quant[
     "4w-qat-module-swap"
 ] = enable_4w_fake_quant_module_swap
@@ -124,6 +132,9 @@ def quantize(self, model):
 # int8 dynamic activations + int4 weight
+Int8DynActInt4WeightQATQuantizerModuleSwap = Int8DynActInt4WeightQATQuantizer
+disable_8da4w_fake_quant_module_swap = disable_8da4w_fake_quant
+enable_8da4w_fake_quant_module_swap = enable_8da4w_fake_quant
 _quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap"
 _quantizer_mode_to_disable_fake_quant[
     "8da4w-qat-module-swap"
 ] = disable_8da4w_fake_quant_module_swap
 _quantizer_mode_to_enable_fake_quant[
     "8da4w-qat-module-swap"
 ] = enable_8da4w_fake_quant_module_swap
@@ -142,7 +153,7 @@ def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]:
 
     Currently supported:
 
     - :class:`~torchao.quantization.quant_api.Int8DynActInt4WeightQuantizer`: "8da4w" (requires ``torch>=2.3.0``)
-    - :class:`~torchao.quantization.prototype.qat.Int8DynActInt4WeightQATQuantizer`: "8da4w-qat" (requires ``torch>=2.4.0``)
+    - :class:`~torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer`: "8da4w-qat" (requires ``torch>=2.4.0``)
 
     Args:
         quantizer (Optional[Callable]): A callable object that implements the
             `quantize` method.
 
     Returns:
         Optional[str]: The quantization mode.
     """
-    return _quantizer_to_mode.get(type(quantizer), None)
+    mode = _quantizer_to_mode.get(type(quantizer), None)
+    if mode is not None and "module-swap" in mode:
+        warn(
+            "*QuantizerModuleSwap is deprecated. Please use the version without 'ModuleSwap' instead"
+        )
+    return mode
 
 
 def _get_disable_fake_quant(quantizer_mode: str) -> Callable:
diff --git a/torchtune/utils/_import_guard.py b/torchtune/utils/_import_guard.py
index 93e7941fbc..06904d9cb6 100644
--- a/torchtune/utils/_import_guard.py
+++ b/torchtune/utils/_import_guard.py
@@ -20,7 +20,7 @@
 _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API = _is_fbcode() or (
     not _is_fbcode()
     and (
-        ("dev" not in torchao_version and torchao_version >= "0.6.0")
+        ("dev" not in torchao_version and torchao_version >= "0.7.0")
         or (
             "dev" in torchao_version
             and _nightly_version_ge(torchao_version, "2024-10-10")