Update imports after QAT was moved out of prototype
Summary: pytorch/ao#1091 moved QAT out of
prototype in torchao. This is a BC-breaking change, so torchtune
also needs to update its QAT imports. Additionally, after
pytorch/ao#987 we decided that QAT in
torchao will use module swaps to insert fake quantizes, so there
is no longer a need for a separate module swap quantizer; this
commit removes the `*ModuleSwapQuantizer` option.

Test Plan:
`pytest -m integration_test tests/recipes/test_qat_distributed.py` should work
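
For reference, a minimal sketch of the new (non-prototype) import path on torchao 0.7+; the toy model and groupsize are illustrative only, not part of this commit:

```python
import torch.nn as nn

from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Toy model for illustration; real recipes use a torchtune TransformerDecoder.
model = nn.Sequential(nn.Linear(512, 512))

# groupsize=256 is the torchao default, shown explicitly here.
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)

# prepare() swaps in fake-quantized linears (module swaps, per pytorch/ao#987)
model = quantizer.prepare(model)
# ... fine-tune the prepared model ...
# convert() replaces fake quantizes with actual quantized ops
model = quantizer.convert(model)
```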
andrewor14 committed Oct 25, 2024
1 parent e030626 commit 1d05e1a
Showing 1 changed file with 46 additions and 28 deletions.
74 changes: 46 additions & 28 deletions torchtune/training/quantization.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional
from warnings import warn

from torchtune.utils._import_guard import _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API

@@ -18,22 +19,29 @@
int8_dynamic_activation_int4_weight,
quantize_,
)
from torchao.quantization.prototype.qat import (
disable_4w_fake_quant,
disable_8da4w_fake_quant,
enable_4w_fake_quant,
enable_8da4w_fake_quant,
Int4WeightOnlyQATQuantizer,
Int8DynActInt4WeightQATQuantizer,
)
from torchao.quantization.prototype.qat._module_swap_api import (
disable_4w_fake_quant_module_swap,
disable_8da4w_fake_quant_module_swap,
enable_4w_fake_quant_module_swap,
enable_8da4w_fake_quant_module_swap,
Int4WeightOnlyQATQuantizerModuleSwap,
Int8DynActInt4WeightQATQuantizerModuleSwap,
)

try:
# torchao 0.7+
from torchao.quantization.qat import (
Int4WeightOnlyQATQuantizer,
Int8DynActInt4WeightQATQuantizer,
)
from torchao.quantization.qat.linear import (
disable_4w_fake_quant,
disable_8da4w_fake_quant,
enable_4w_fake_quant,
enable_8da4w_fake_quant,
)
except ImportError:
# torchao 0.6 and before
from torchao.quantization.prototype.qat import (
disable_4w_fake_quant,
disable_8da4w_fake_quant,
enable_4w_fake_quant,
enable_8da4w_fake_quant,
Int4WeightOnlyQATQuantizer,
Int8DynActInt4WeightQATQuantizer,
)


__all__ = [
@@ -52,9 +60,9 @@
_quantizer_mode_to_enable_fake_quant = {}


# ========================================================
# int8 dynamic activations + int4 weight tensor subclass |
# ========================================================
# ========================================
# int8 dynamic activations + int4 weight |
# ========================================


class Int8DynActInt4WeightQuantizer:
@@ -106,15 +114,15 @@ def quantize(self, model):
_quantizer_mode_to_enable_fake_quant["4w-qat"] = enable_4w_fake_quant


# =============
# module swap |
# =============
# ====================== #
# Backward compatibility #
# ====================== #

# Note: QAT tensor subclass implementation in torchao only works
# with FSDP2 today. For other distribution strategies like DDP and
# FSDP1, users will need to fall back to the old module swap flow.

# int4 weight-only
Int4WeightOnlyQATQuantizerModuleSwap = Int4WeightOnlyQATQuantizer
disable_4w_fake_quant_module_swap = disable_4w_fake_quant
enable_4w_fake_quant_module_swap = enable_4w_fake_quant
_quantizer_to_mode[Int4WeightOnlyQATQuantizerModuleSwap] = "4w-qat-module-swap"
_quantizer_mode_to_disable_fake_quant[
"4w-qat-module-swap"
@@ -124,6 +132,9 @@ def quantize(self, model):
] = enable_4w_fake_quant_module_swap

# int8 dynamic activations + int4 weight
Int8DynActInt4WeightQATQuantizerModuleSwap = Int8DynActInt4WeightQATQuantizer
disable_8da4w_fake_quant_module_swap = disable_8da4w_fake_quant
enable_8da4w_fake_quant_module_swap = enable_8da4w_fake_quant
_quantizer_to_mode[Int8DynActInt4WeightQATQuantizerModuleSwap] = "8da4w-qat-module-swap"
_quantizer_mode_to_disable_fake_quant[
"8da4w-qat-module-swap"
@@ -141,16 +152,23 @@ def get_quantizer_mode(quantizer: Optional[Callable]) -> Optional[str]:
Currently supported:
- :class:`~torchao.quantization.quant_api.Int8DynActInt4WeightQuantizer`: "8da4w" (requires ``torch>=2.3.0``)
- :class:`~torchao.quantization.prototype.qat.Int8DynActInt4WeightQATQuantizer`: "8da4w-qat" (requires ``torch>=2.4.0``)
- :class:`~torchtune.training.quantization.Int8DynActInt4WeightQuantizer`: "8da4w"
- :class:`~torchtune.training.quantization.Int4WeightOnlyQuantizer`: "4w"
- :class:`~torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer`: "8da4w-qat"
- :class:`~torchao.quantization.qat.Int4WeightOnlyQATQuantizer`: "4w-qat"
Args:
quantizer (Optional[Callable]): A callable object that implements the `quantize` method.
Returns:
Optional[str]: The quantization mode.
"""
return _quantizer_to_mode.get(type(quantizer), None)
mode = _quantizer_to_mode.get(type(quantizer), None)
if mode is not None and "module-swap" in mode:
warn(
"*QuantizerModuleSwap is deprecated. Please use the version without 'ModuleSwap' instead"
)
return mode


def _get_disable_fake_quant(quantizer_mode: str) -> Callable:
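
As a usage note, a short sketch of how the backward-compatibility aliases behave after this change (the groupsize is illustrative):

```python
from torchtune.training.quantization import (
    Int8DynActInt4WeightQATQuantizerModuleSwap,  # now a plain alias
    get_quantizer_mode,
)

# The old name still constructs a working quantizer...
quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap(groupsize=256)

# ...but resolving its mode emits the deprecation warning added above.
mode = get_quantizer_mode(quantizer)  # "8da4w-qat-module-swap"
```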
