Feat (examples/a2q+): new super resolution models (Xilinx#811)
* Feat (a2q+): adding to super_res example

* Updating links to pre-trained checkpoints
i-colbert authored Jan 26, 2024
1 parent 56056ba commit f05ef84
Showing 5 changed files with 98 additions and 19 deletions.
2 changes: 2 additions & 0 deletions src/brevitas_examples/super_resolution/README.md
@@ -20,10 +20,12 @@ Note that this is a difference from many academic works that train only on the Y
| [quant_espcn_x2_w8a8_base](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w8a8_base-f761e4a1.pth) | x2 | int8 | (u)int8 | 30.96 |
| [quant_espcn_x2_w8a8_a2q_32b](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w8a8_a2q_32b-85470d9b.pth) | x2 | int8 | (u)int8 | 30.79 |
| [quant_espcn_x2_w8a8_a2q_16b](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w8a8_a2q_16b-f9e1da66.pth) | x2 | int8 | (u)int8 | 30.56 |
| [quant_espcn_x2_w8a8_a2q_plus_16b](https://github.com/Xilinx/brevitas/releases/download/super_res_r2/quant_espcn_x2_w8a8_a2q_plus_16b-0ddf46f1.pth) | x2 | int8 | (u)int8 | 31.24 |
||
| [quant_espcn_x2_w4a4_base](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w4a4_base-80658e6d.pth) | x2 | int4 | (u)int4 | 30.30 |
| [quant_espcn_x2_w4a4_a2q_32b](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w4a4_a2q_32b-8702a412.pth) | x2 | int4 | (u)int4 | 30.27 |
| [quant_espcn_x2_w4a4_a2q_13b](https://github.com/Xilinx/brevitas/releases/download/super_res_r1/quant_espcn_x2_w4a4_a2q_13b-9fff234e.pth) | x2 | int4 | (u)int4 | 30.24 |
| [quant_espcn_x2_w4a4_a2q_plus_13b](https://github.com/Xilinx/brevitas/releases/download/super_res_r2/quant_espcn_x2_w4a4_a2q_plus_13b-6e6d55f0.pth) | x2 | int4 | (u)int4 | 30.95 |
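
The new A2Q+ checkpoints can be pulled through the example's model factory. A minimal sketch using `get_model_by_name` (defined in `models/__init__.py`), assuming the release assets stay at the URLs above:

```python
from brevitas_examples.super_resolution.models import get_model_by_name

# Build the 4-bit A2Q+ model and download its pre-trained checkpoint.
model = get_model_by_name('quant_espcn_x2_w4a4_a2q_plus_13b', pretrained=True)
model.eval()
```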


## Train
46 changes: 37 additions & 9 deletions src/brevitas_examples/super_resolution/models/__init__.py
@@ -7,6 +7,7 @@
from torch import hub
import torch.nn as nn

from .common import CommonIntAccumulatorAwareZeroCenterWeightQuant
from .espcn import *

model_impl = {
@@ -43,18 +44,45 @@
upscale_factor=2,
weight_bit_width=4,
act_bit_width=4,
acc_bit_width=13)}
acc_bit_width=13),
'quant_espcn_x2_w4a4_a2q_plus_13b':
partial(
quant_espcn,
upscale_factor=2,
weight_bit_width=4,
act_bit_width=4,
acc_bit_width=13,
weight_quant=CommonIntAccumulatorAwareZeroCenterWeightQuant),
'quant_espcn_x2_w8a8_a2q_plus_16b':
partial(
quant_espcn,
upscale_factor=2,
weight_bit_width=8,
act_bit_width=8,
acc_bit_width=16,
weight_quant=CommonIntAccumulatorAwareZeroCenterWeightQuant)}

root_url = 'https://github.com/Xilinx/brevitas/releases/download/super_res_r1'
root_url = 'https://github.com/Xilinx/brevitas/releases/download/'

model_url = {
'float_espcn_x2': f'{root_url}/float_espcn_x2-2f85a454.pth',
'quant_espcn_x2_w4a4_a2q_13b': f'{root_url}/quant_espcn_x2_w4a4_a2q_13b-9fff234e.pth',
'quant_espcn_x2_w4a4_a2q_32b': f'{root_url}/quant_espcn_x2_w4a4_a2q_32b-8702a412.pth',
'quant_espcn_x2_w4a4_base': f'{root_url}/quant_espcn_x2_w4a4_base-80658e6d.pth',
'quant_espcn_x2_w8a8_a2q_16b': f'{root_url}/quant_espcn_x2_w8a8_a2q_16b-f9e1da66.pth',
'quant_espcn_x2_w8a8_a2q_32b': f'{root_url}/quant_espcn_x2_w8a8_a2q_32b-85470d9b.pth',
'quant_espcn_x2_w8a8_base': f'{root_url}/quant_espcn_x2_w8a8_base-f761e4a1.pth'}
'float_espcn_x2':
f'{root_url}/super_res_r1/float_espcn_x2-2f85a454.pth',
'quant_espcn_x2_w4a4_a2q_13b':
f'{root_url}/super_res_r1/quant_espcn_x2_w4a4_a2q_13b-9fff234e.pth',
'quant_espcn_x2_w4a4_a2q_32b':
f'{root_url}/super_res_r1/quant_espcn_x2_w4a4_a2q_32b-8702a412.pth',
'quant_espcn_x2_w4a4_base':
f'{root_url}/super_res_r1/quant_espcn_x2_w4a4_base-80658e6d.pth',
'quant_espcn_x2_w8a8_a2q_16b':
f'{root_url}/super_res_r1/quant_espcn_x2_w8a8_a2q_16b-f9e1da66.pth',
'quant_espcn_x2_w8a8_a2q_32b':
f'{root_url}/super_res_r1/quant_espcn_x2_w8a8_a2q_32b-85470d9b.pth',
'quant_espcn_x2_w8a8_base':
f'{root_url}/super_res_r1/quant_espcn_x2_w8a8_base-f761e4a1.pth',
'quant_espcn_x2_w4a4_a2q_plus_13b':
f'{root_url}/super_res_r2/quant_espcn_x2_w4a4_a2q_plus_13b-6e6d55f0.pth',
'quant_espcn_x2_w8a8_a2q_plus_16b':
f'{root_url}/super_res_r2/quant_espcn_x2_w8a8_a2q_plus_16b-0ddf46f1.pth'}


def get_model_by_name(name: str, pretrained: bool = False) -> Union[FloatESPCN, QuantESPCN]:
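
For reference, a sketch of how one of the new entries might be resolved; `get_model_by_name` presumably composes `model_impl` and `model_url` along these lines (the `map_location` choice here is an assumption):

```python
from torch import hub

name = 'quant_espcn_x2_w8a8_a2q_plus_16b'
model = model_impl[name]()  # the partial() supplies the A2Q+ configuration
# Fetch the release checkpoint and load it as a plain state dict.
state_dict = hub.load_state_dict_from_url(
    model_url[name], map_location='cpu', progress=True)
model.load_state_dict(state_dict)
```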
10 changes: 8 additions & 2 deletions src/brevitas_examples/super_resolution/models/common.py
@@ -12,6 +12,7 @@
import brevitas.nn as qnn
from brevitas.nn.quant_layer import WeightQuantType
from brevitas.quant import Int8AccumulatorAwareWeightQuant
from brevitas.quant import Int8AccumulatorAwareZeroCenterWeightQuant
from brevitas.quant import Int8ActPerTensorFloat
from brevitas.quant import Int8WeightPerTensorFloat
from brevitas.quant import Uint8ActPerTensorFloat
@@ -26,9 +27,14 @@ class CommonIntWeightPerChannelQuant(Int8WeightPerTensorFloat):


class CommonIntAccumulatorAwareWeightQuant(Int8AccumulatorAwareWeightQuant):
"""A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance"""
restrict_scaling_impl = FloatRestrictValue # backwards compatibility
pre_scaling_min_val = 1e-10
scaling_min_val = 1e-10
bit_width = None


class CommonIntAccumulatorAwareZeroCenterWeightQuant(Int8AccumulatorAwareZeroCenterWeightQuant):
"""A2Q+: Improving Accumulator-Aware Weight Quantization"""
bit_width = None


class CommonIntActQuant(Int8ActPerTensorFloat):
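
To make the new quantizer concrete: a hedged sketch of wiring it into a layer, assuming (as elsewhere in the A2Q examples) that `weight_`-prefixed keyword arguments are forwarded to the weight quantizer:

```python
import brevitas.nn as qnn

# Sketch: a 3x3 convolution whose weights are constrained so that a 13-bit
# accumulator cannot overflow, using the zero-centered A2Q+ quantizer above.
conv = qnn.QuantConv2d(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    padding=1,
    input_quant=CommonIntActQuant,  # the input bit-width informs the bound
    input_bit_width=4,
    weight_quant=CommonIntAccumulatorAwareZeroCenterWeightQuant,
    weight_bit_width=4,
    weight_accumulator_bit_width=13)
```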
4 changes: 2 additions & 2 deletions src/brevitas_examples/super_resolution/models/espcn.py
@@ -164,15 +164,15 @@ def float_espcn(upscale_factor: int, num_channels: int = 3) -> FloatESPCN:


def quant_espcn(
upcsale_factor: int,
upscale_factor: int,
num_channels: int = 3,
weight_bit_width: int = 8,
act_bit_width: int = 8,
acc_bit_width: int = 32,
weight_quant: WeightQuantType = CommonIntWeightPerChannelQuant) -> QuantESPCN:
""" """
return QuantESPCN(
upscale_factor=upcsale_factor,
upscale_factor=upscale_factor,
num_channels=num_channels,
act_bit_width=act_bit_width,
acc_bit_width=acc_bit_width,
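
A short usage sketch of the factory above, mirroring the `quant_espcn_x2_w4a4_a2q_plus_13b` entry registered in `models/__init__.py`:

```python
model = quant_espcn(
    upscale_factor=2,
    num_channels=3,
    weight_bit_width=4,
    act_bit_width=4,
    acc_bit_width=13,
    weight_quant=CommonIntAccumulatorAwareZeroCenterWeightQuant)
```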
55 changes: 49 additions & 6 deletions src/brevitas_examples/super_resolution/utils/evaluate.py
@@ -5,12 +5,52 @@
from torch import Tensor
import torch.nn as nn

from brevitas.core.scaling import AccumulatorAwareParameterPreScaling
from brevitas.core.scaling import AccumulatorAwareZeroCenterParameterPreScaling
import brevitas.nn as qnn
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL

EPS = 1e-10


def _get_a2q_module(module: nn.Module):
for submod in module.modules():
if isinstance(submod, AccumulatorAwareParameterPreScaling):
return submod
return None


def _calc_a2q_acc_bit_width(
weight_max_l1_norm: Tensor, input_bit_width: Tensor, input_is_signed: bool):
"""Using the closed-form bounds on accumulator bit-width as derived in
`A2Q: Accumulator-Aware Quantization with Guaranteed Overflow Avoidance`.
This function returns the minimum accumulator bit-width that can be used
without risk of overflow."""
assert weight_max_l1_norm.numel() == 1
input_is_signed = float(input_is_signed)
weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, EPS)
alpha = torch.log2(weight_max_l1_norm) + input_bit_width - input_is_signed
phi = lambda x: torch.log2(1. + pow(2., -x))
min_bit_width = alpha + phi(alpha) + 1.
min_bit_width = torch.ceil(min_bit_width)
return min_bit_width


def _calc_a2q_plus_acc_bit_width(
weight_max_l1_norm: Tensor, input_bit_width: Tensor, input_is_signed: bool):
"""Using the closed-form bounds on accumulator bit-width as derived in `A2Q+:
Improving Accumulator-Aware Weight Quantization`. This function returns the
minimum accumulator bit-width that can be used without risk of overflow,
assuming that the floating-point weights are zero-centered."""
input_is_signed = float(input_is_signed)
assert weight_max_l1_norm.numel() == 1
weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, EPS)
input_range = pow(2., input_bit_width) - 1. # 2^N - 1.
min_bit_width = torch.log2(weight_max_l1_norm * input_range + 2.)
min_bit_width = torch.ceil(min_bit_width)
return min_bit_width
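
A quick sanity check of the two bounds with made-up numbers, an unsigned 8-bit input and a maximum per-channel L1 norm of 2.0: A2Q gives alpha = log2(2.0) + 8 = 9, so ceil(alpha + phi(alpha) + 1) = 11 bits, while the zero-centered A2Q+ bound gives ceil(log2(2.0 * (2^8 - 1) + 2)) = ceil(log2(512)) = 9 bits:

```python
# Hypothetical values, used only to exercise the two bounds above.
l1_norm = torch.tensor(2.0)    # max per-channel L1 norm of the quantized weights
bit_width = torch.tensor(8.0)  # unsigned 8-bit input
print(_calc_a2q_acc_bit_width(l1_norm, bit_width, input_is_signed=False))       # tensor(11.)
print(_calc_a2q_plus_acc_bit_width(l1_norm, bit_width, input_is_signed=False))  # tensor(9.)
```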


def _calc_min_acc_bit_width(module: QuantWBIOL) -> Tensor:
assert isinstance(module, qnn.QuantConv2d), "Error: function only supports QuantConv2d."

@@ -24,12 +64,15 @@ def _calc_min_acc_bit_width(module: QuantWBIOL) -> Tensor:
quant_weight_per_channel_l1_norm = quant_weight.norm(p=1, dim=(1, 2, 3))

# using the closed-form bounds on accumulator bit-width
weight_max_l1_norm = quant_weight_per_channel_l1_norm.max()
weight_max_l1_norm = torch.clamp_min(weight_max_l1_norm, EPS)
alpha = torch.log2(weight_max_l1_norm) + input_bit_width - input_is_signed
phi = lambda x: torch.log2(1. + pow(2., -x))
min_bit_width = alpha + phi(alpha) + 1.
min_bit_width = torch.ceil(min_bit_width)
min_bit_width = _calc_a2q_acc_bit_width(
quant_weight_per_channel_l1_norm.max(),
input_bit_width=input_bit_width,
input_is_signed=input_is_signed)
if isinstance(_get_a2q_module(module), AccumulatorAwareZeroCenterParameterPreScaling):
min_bit_width = _calc_a2q_plus_acc_bit_width(
quant_weight_per_channel_l1_norm.max(),
input_bit_width=input_bit_width,
input_is_signed=input_is_signed)
return min_bit_width


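
Downstream, this dispatch makes it possible to audit a trained model layer by layer. A hedged sketch, assuming the convolutions are `qnn.QuantConv2d` and their quantization metadata is available (e.g. after running inference with caching enabled):

```python
# Report the minimum safe accumulator bit-width for each quantized convolution.
for name, module in model.named_modules():
    if isinstance(module, qnn.QuantConv2d):
        print(f'{name}: {int(_calc_min_acc_bit_width(module).item())}-bit accumulator')
```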
