Commit
Feat (GPFQ): offload float input to disc
fabianandresgrob committed Jun 18, 2024
1 parent 2d4e359 commit 170bdac
Showing 1 changed file with 39 additions and 0 deletions.
src/brevitas/graph/gpfq.py: 39 additions & 0 deletions
@@ -2,8 +2,11 @@
# SPDX-License-Identifier: BSD-3-Clause

from copy import deepcopy
from tempfile import TemporaryDirectory
from typing import Callable, List, Optional

from accelerate.utils.offload import offload_state_dict
from accelerate.utils.offload import OffloadedWeightsLoader
import numpy as np
import torch
import torch.nn as nn
@@ -35,6 +38,7 @@ def __init__(
len_parallel_layers,
create_weight_orig,
p,
collect_float_first,
use_random_proj,
use_random_sampling,
target_dim) -> None:
@@ -54,6 +58,8 @@ def __init__(
self.index_computed = False
self.p = p

self.collect_float_first = collect_float_first

def collect_float_input(self, module, args, output):
# this is the hook function to collect the float inputs of this layer
inp = self.process_input(args)
@@ -118,6 +124,14 @@ def collect_float_input(self, module, args, output):
else:
self.float_input = torch.cat([self.float_input, inp_processed], dim=1)

def offload_float_input(self, tmp_dir):
# create tmp directory for this layer
self.save_dir = tmp_dir + '/' + self.name
# method expects dict
offload_state_dict(self.save_dir, state_dict={'float_input': self.float_input.detach()})
# then delete float_input to save memory
del self.float_input

def update_batch(self, module, input, current_layer):
if self.disable_pre_forward_hook:
return input
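
For context, the new offload_float_input method round-trips a tensor through accelerate's disk-offload utilities: offload_state_dict writes each entry of the dict to a memory-mapped data file plus an index in the target folder, and OffloadedWeightsLoader maps keys back to the tensors stored there. A minimal standalone sketch of that round trip (the tensor name and shape are illustrative, not taken from the commit):

from tempfile import TemporaryDirectory

import torch
from accelerate.utils.offload import offload_state_dict
from accelerate.utils.offload import OffloadedWeightsLoader

tmp_dir = TemporaryDirectory()
float_input = torch.randn(4, 128)

# offload_state_dict expects a dict of name -> tensor and writes it under tmp_dir
offload_state_dict(tmp_dir.name, state_dict={'float_input': float_input})

# OffloadedWeightsLoader reads the tensor back from the folder by key
restored = OffloadedWeightsLoader(save_folder=tmp_dir.name)['float_input']
assert torch.equal(restored, float_input)

tmp_dir.cleanup()
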
@@ -205,6 +219,12 @@ def single_layer_update(self):
weight = self.layer.weight.data
dev = weight.device
dtype = weight.dtype

# load float input from disc if needed
if self.collect_float_first:
# load float_input from disc
self.float_input = OffloadedWeightsLoader(save_folder=self.save_dir)['float_input']

if isinstance(self.layer, SUPPORTED_CONV_OP):
if isinstance(
self.layer,
@@ -280,6 +300,7 @@ def __init__(
create_weight_orig,
accumulator_bit_width,
p,
collect_float_first,
use_random_proj,
use_random_sampling,
target_dim) -> None:
@@ -291,6 +312,7 @@ def __init__(
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=p,
collect_float_first=collect_float_first,
use_random_proj=use_random_proj,
use_random_sampling=use_random_sampling,
target_dim=target_dim)
@@ -328,6 +350,12 @@ def single_layer_update(self):
weight = self.layer.weight.data
dev = weight.device
dtype = weight.dtype

# load float input from disc if needed
if self.collect_float_first:
# load float_input from disc
self.float_input = OffloadedWeightsLoader(save_folder=self.save_dir)['float_input']

if isinstance(self.layer, SUPPORTED_CONV_OP):
if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
weight = weight.transpose(1, 0) # This performs a view
@@ -505,13 +533,22 @@ def __enter__(self):
return self.setup_gpxq_hooks()

def __exit__(self, type, value, traceback):
if self.collect_float_first:
self.tmp_dir.cleanup()
self.exit()

def finalize_float_collection(self):
# remove the hooks we attached during the float collection
for name, hook in self.float_collection_hooks.items():
hook.remove()

# create temp dir
self.tmp_dir = TemporaryDirectory()

# save all float activations to disc and delete them in the layers
for name, layer in self.gpxq_layers.items():
layer.offload_float_input(tmp_dir=self.tmp_dir.name)

# Re-enable quantization. If activation quantization is disabled,
# we also disable bias quantization
self.disable_quant_inference.enable_param_quantization(self.model, is_training=False)
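
Together with the __exit__ change above, the intended flow appears to be two-phase: run the calibration data once with quantization disabled so the collect_float_input hooks gather float activations, call finalize_float_collection() to push them into a TemporaryDirectory and re-enable quantization, then run the usual GPFQ loop, during which single_layer_update reloads each layer's float input from disk. A hypothetical usage sketch under that reading; model and calibration_loader are placeholders, and only collect_float_first, finalize_float_collection and the standard gpfq_mode driver loop are drawn from Brevitas and this commit:

from brevitas.graph.gpfq import gpfq_mode

# model: a quantized nn.Module prepared for PTQ; calibration_loader yields (inputs, labels)
with gpfq_mode(model, collect_float_first=True) as gpfq:
    gpfq_model = gpfq.model
    # Phase 1: float-collection pass, quantization disabled, hooks record float inputs
    for inputs, _ in calibration_loader:
        gpfq_model(inputs)
    # Offload the collected float inputs to disk and re-enable quantization
    gpfq.finalize_float_collection()
    # Phase 2: standard GPFQ loop; single_layer_update reloads float inputs from disk
    for _ in range(gpfq.num_layers):
        for inputs, _ in calibration_loader:
            gpfq_model(inputs)
        gpfq.update()
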
@@ -571,6 +608,7 @@ def initialize_module_optimizer(
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p,
collect_float_first=self.collect_float_first,
accumulator_bit_width=self.accumulator_bit_width,
use_random_proj=self.use_random_proj,
use_random_sampling=self.use_random_sampling,
@@ -582,6 +620,7 @@ def initialize_module_optimizer(
len_parallel_layers=len_parallel_layers,
create_weight_orig=create_weight_orig,
p=self.p,
collect_float_first=self.collect_float_first,
use_random_proj=self.use_random_proj,
use_random_sampling=self.use_random_sampling,
target_dim=self.target_dim)
