Commit

add random projection to gptq
fabianandresgrob committed Jun 5, 2024
1 parent 9e4388c commit 4ca9f18
Showing 3 changed files with 101 additions and 19 deletions.
24 changes: 13 additions & 11 deletions src/brevitas/graph/gpfq.py
@@ -39,16 +39,20 @@ def __init__(
use_random_sampling,
target_dim) -> None:

super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)
super().__init__(
layer,
name,
act_order,
len_parallel_layers,
create_weight_orig,
use_random_proj,
use_random_sampling,
target_dim)

self.float_input = None
self.quantized_input = None
self.index_computed = False
self.p = p
self.save_dir = None
self.use_random_proj = use_random_proj
self.use_random_sampling = use_random_sampling
self.target_dim = target_dim

def collect_float_input(self, module, args, output):
# this is the hook function to collect the float inputs of this layer
@@ -464,7 +468,10 @@ def __init__(
create_weight_orig,
use_quant_activations,
act_order,
return_forward_output)
return_forward_output,
use_random_proj,
use_random_sampling,
target_dim)

self.p = p

@@ -476,11 +483,6 @@ def __init__(
# speed up by collecting the float input first so we don't need to do it later
self.collect_float_first = collect_float_first

# random proj/sample and target_dim
self.use_random_proj = use_random_proj
self.use_random_sampling = use_random_sampling
self.target_dim = target_dim

def __enter__(self):
# initialize gpxq layers
self.setup_gpxq_layers()
72 changes: 66 additions & 6 deletions src/brevitas/graph/gptq.py
@@ -47,9 +47,25 @@ class GPTQ(GPxQ):
"""

def __init__(
self, layer, name, act_order, len_parallel_layers, create_weight_orig,
num_blocks) -> None:
super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)
self,
layer,
name,
act_order,
len_parallel_layers,
create_weight_orig,
num_blocks,
use_random_proj,
use_random_sampling,
target_dim) -> None:
super().__init__(
layer,
name,
act_order,
len_parallel_layers,
create_weight_orig,
use_random_proj,
use_random_sampling,
target_dim)

# Define how many columns to update in each mini-block
self.blocksize = math.ceil(self.columns / num_blocks)
@@ -106,6 +122,41 @@ def update_batch(self, module, input, current_layer):
inp_processed.append(inp)
inp_processed = torch.stack(inp_processed)

dev = inp_processed.device
n = inp_processed.shape[-1]
if self.use_random_proj:
# use batching if target_dim is greater than 8000 to avoid memory issues
if self.target_dim > 8000:
batch_size_proj = 4096
accumulated_batches = 0
first_batch = True
while accumulated_batches < self.target_dim:
# cur_target_dim uses the full batch_size_proj unless fewer columns remain before reaching target_dim
cur_target_dim = min(batch_size_proj, self.target_dim - accumulated_batches)
accumulated_batches += cur_target_dim
R = torch.normal(
mean=0.0, std=1. / math.sqrt(n), size=(n, cur_target_dim), device=dev)
if first_batch:
inp_processed_proj = inp_processed @ R
first_batch = False
else:
# concatenate projected input along last dimension
inp_processed_proj = torch.cat([inp_processed_proj, (inp_processed @ R)],
dim=-1)
# replace inp_processed with its projection and free the temporary
inp_processed = inp_processed_proj
del inp_processed_proj
else:
R = torch.normal(mean=0.0, std=1. / math.sqrt(n), size=(n, self.target_dim))
# projecting the input data
inp_processed = inp_processed @ R.to(inp_processed.device)
del R
elif self.use_random_sampling:
# randomly select target_dim indices along the last dimension
ind = torch.randint(n, (self.target_dim,))
inp_processed = inp_processed.index_select(-1, ind.to(dev))
del ind

# Hessian computation
self.H *= self.nsamples / (self.nsamples + batch_size)
self.nsamples += batch_size
@@ -414,7 +465,10 @@ def __init__(
act_order: bool = False,
accumulator_bit_width: Optional[int] = None,
a2q_layer_filter_fnc: Optional[Callable[[nn.Module], bool]] = lambda x: False,
a2q_gptq_class: Optional[A2GPTQ] = A2GPTQ) -> None:
a2q_gptq_class: Optional[A2GPTQ] = A2GPTQ,
use_random_proj: bool = False,
use_random_sampling: bool = False,
target_dim: int = 4096) -> None:
if not inplace:
model = deepcopy(model)
super().__init__(
@@ -424,7 +478,10 @@ def __init__(
create_weight_orig,
use_quant_activations,
act_order,
return_forward_output)
return_forward_output,
use_random_proj,
use_random_sampling,
target_dim)

# How many sub-blocks to use during GPTQ for each layer
self.num_blocks = num_blocks
@@ -473,4 +530,7 @@ def initialize_module_optimizer(
act_order=act_order,
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
num_blocks=self.num_blocks)
num_blocks=self.num_blocks,
use_random_proj=self.use_random_proj,
use_random_sampling=self.use_random_sampling,
target_dim=self.target_dim)
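
Note on the new branch in update_batch: it shrinks the last dimension of inp_processed (the number of collected input samples, which can be very large for long calibration runs) down to target_dim before the Hessian is accumulated, either through a Gaussian random projection or by sampling a random subset of columns. The standalone sketch below is not the committed code; random_project and random_sample are illustrative helpers that mirror the two reductions with the same std = 1/sqrt(n) scaling used in the diff.

```python
import math

import torch


def random_project(inp: torch.Tensor, target_dim: int) -> torch.Tensor:
    # Gaussian random projection of the last dimension from n down to target_dim,
    # with the same std = 1/sqrt(n) scaling as the diff. In expectation,
    # (X @ R) @ (X @ R).T == (target_dim / n) * X @ X.T, so the Hessian
    # estimate is preserved up to a constant scale factor.
    n = inp.shape[-1]
    R = torch.normal(mean=0.0, std=1.0 / math.sqrt(n), size=(n, target_dim), device=inp.device)
    return inp @ R


def random_sample(inp: torch.Tensor, target_dim: int) -> torch.Tensor:
    # Alternative reduction: keep target_dim randomly chosen columns of the last dimension.
    n = inp.shape[-1]
    ind = torch.randint(n, (target_dim,), device=inp.device)
    return inp.index_select(-1, ind)


x = torch.randn(2, 64, 8192)            # (groups, columns, collected samples)
print(random_project(x, 1024).shape)    # torch.Size([2, 64, 1024])
print(random_sample(x, 1024).shape)     # torch.Size([2, 64, 1024])
```

When target_dim is large (above 8000 in the diff), the committed code instead generates R in column blocks of 4096 and concatenates the projected pieces, so the full n x target_dim matrix never has to be materialized at once.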
24 changes: 22 additions & 2 deletions src/brevitas/graph/gpxq.py
@@ -76,7 +76,10 @@ def __init__(
create_weight_orig: bool = True,
use_quant_activations: bool = True,
act_order: bool = False,
return_forward_output: bool = False) -> None:
return_forward_output: bool = False,
use_random_proj: bool = False,
use_random_sampling: bool = False,
target_dim: int = 4096) -> None:

if not inplace:
model = deepcopy(model)
@@ -105,6 +108,11 @@ def __init__(
else:
self.model.forward = self.catch_stopfwd

# random proj/sample and target_dim
self.use_random_proj = use_random_proj
self.use_random_sampling = use_random_sampling
self.target_dim = target_dim

def _is_module_supported(self, module):
if isinstance(module, SUPPORTED_CONV_OP):
return True
@@ -201,7 +209,15 @@ def catch_stopfwd(self, *args, **kwargs):
class GPxQ(ABC):

def __init__(
self, layer, name, act_order, len_parallel_layers=1, create_weight_orig=True) -> None:
self,
layer,
name,
act_order,
len_parallel_layers=1,
create_weight_orig=True,
use_random_proj=False,
use_random_sampling=False,
target_dim=4096) -> None:
self.layer = layer
self.name = name
self.act_order = act_order
@@ -231,6 +247,10 @@ def __init__(
# Some layers require knowledge from quant inputs to compute quant weights
self.quant_input = None

self.use_random_proj = use_random_proj
self.use_random_sampling = use_random_sampling
self.target_dim = target_dim

def process_input(self, inp):
# Input is a tuple, so we take first element
inp = inp[0]
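
End to end, the three new keyword arguments default to use_random_proj=False, use_random_sampling=False, target_dim=4096, so existing callers are unchanged. Assuming the modified top-level constructor belongs to the gptq_mode context manager (the enclosing class name is not visible in this diff) and following the calibration-loop pattern from the Brevitas PTQ examples (gptq.model, gptq.num_layers, gptq.update()), enabling the projection from user code would look roughly like the sketch below; the toy model and random calibration batches are made up for illustration.

```python
import torch
import torch.nn as nn

import brevitas.nn as qnn
from brevitas.graph.gptq import gptq_mode

# Toy weight-quantized model and random calibration batches, only to exercise
# the new keyword arguments end to end.
model = nn.Sequential(
    qnn.QuantLinear(64, 32, bias=False),
    nn.ReLU(),
    qnn.QuantLinear(32, 16, bias=False),
).eval()
calib_batches = [torch.randn(32, 64) for _ in range(4)]

with torch.no_grad():
    # use_random_proj / target_dim are the arguments added by this commit; they
    # compress the collected layer inputs before the Hessian is accumulated.
    with gptq_mode(model, use_random_proj=True, target_dim=128) as gptq:
        gptq_model = gptq.model
        for _ in range(gptq.num_layers):
            for x in calib_batches:
                gptq_model(x)
            gptq.update()
```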
