Add batched random projection
fabianandresgrob committed May 31, 2024
1 parent db861eb commit 9e4388c
Showing 1 changed file with 36 additions and 9 deletions.
45 changes: 36 additions & 9 deletions src/brevitas/graph/gpxq.py
@@ -308,17 +308,44 @@ def get_quant_weights(self, i, i1, permutation_list):
         return q
 
 
-def random_projection(float_input: torch.Tensor, quantized_input: torch.Tensor, target_dim: int):
+def random_projection(
+        float_input: torch.Tensor,
+        quantized_input: torch.Tensor,
+        target_dim: int,
+        batch_size: int = 2048):
     # use random projection to reduce dimensionality
     n = quantized_input.size(1)
     dev = float_input.device
-    # create gaussian random matrix
-    R = torch.normal(mean=0.0, std=1. / math.sqrt(n), size=(target_dim, n), device=dev)
-    quantized_input = torch.transpose(quantized_input, 1, 2) @ R.T
-    float_input = torch.transpose(float_input, 1, 2) @ R.T
-    del R
+    # use batching if target_dim is greater than 8000 to avoid memory issues
+    if target_dim > 8000:
+        accumulated_batches = 0
+        first_batch = True
+        quantized_input = torch.transpose(quantized_input, 1, 2)
+        float_input = torch.transpose(float_input, 1, 2)
+        while accumulated_batches < target_dim:
+            # cur_target_dim makes sure to fully use batch_size unless we're too close to target_dim
+            cur_target_dim = min(batch_size, target_dim - accumulated_batches)
+            accumulated_batches += cur_target_dim
+            R = torch.normal(mean=0.0, std=1. / math.sqrt(n), size=(cur_target_dim, n), device=dev)
+            if first_batch:
+                quantized_input_proj = (quantized_input @ R.T).cpu()
+                float_input_proj = (float_input @ R.T).cpu()
+                first_batch = False
+            else:
+                # concatenate projected input along last dimension
+                quantized_input_proj = torch.cat(
+                    [quantized_input_proj, (quantized_input @ R.T).cpu()], dim=-1)
+                float_input_proj = torch.cat([float_input_proj, (float_input @ R.T).cpu()], dim=-1)
+    else:
+        # create gaussian random matrix
+        R = torch.normal(mean=0.0, std=1. / math.sqrt(n), size=(target_dim, n), device=dev)
+        quantized_input_proj = torch.transpose(quantized_input, 1, 2) @ R.T
+        float_input_proj = torch.transpose(float_input, 1, 2) @ R.T
+        del R
     # reshape back
-    quantized_input = torch.transpose(quantized_input, 1, 2)
-    float_input = torch.transpose(float_input, 1, 2)
+    quantized_input_proj = torch.transpose(quantized_input_proj, 1, 2)
+    float_input_proj = torch.transpose(float_input_proj, 1, 2)
+    del quantized_input
+    del float_input
 
-    return float_input, quantized_input
+    return float_input_proj, quantized_input_proj
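Below is a minimal usage sketch of the updated helper; the tensor shapes are illustrative, not taken from the commit, and the import assumes the module layout shown in the file header above. The projection matrix R is drawn from N(0, 1/n), the usual Johnson-Lindenstrauss-style scaling, so inner products along the projected axis are approximately preserved in expectation. Note the asymmetry in output device: the batched path (target_dim > 8000) accumulates and returns CPU tensors, while the single-matrix path stays on the input device.

    import torch

    from brevitas.graph.gpxq import random_projection

    # illustrative buffer: 10_000 stacked input rows of 64 features
    # for a single group; dim 1 (size n = 10_000) is what gets compressed
    float_input = torch.randn(1, 10_000, 64)
    quantized_input = torch.randn(1, 10_000, 64)

    # target_dim <= 8000 takes the single-matrix path;
    # target_dim > 8000 switches to 2048-column batches of R
    float_proj, quant_proj = random_projection(
        float_input, quantized_input, target_dim=1024)

    print(float_proj.shape)  # torch.Size([1, 1024, 64])
    print(quant_proj.shape)  # torch.Size([1, 1024, 64])

Drawing R in batch_size-row slices bounds peak device memory at roughly batch_size * n floats instead of target_dim * n, at the cost of moving each partial projection to host memory before concatenation.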
