diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py
index 5acd4d736..1bc520963 100644
--- a/src/brevitas/graph/gpxq.py
+++ b/src/brevitas/graph/gpxq.py
@@ -308,17 +308,44 @@ def get_quant_weights(self, i, i1, permutation_list):
         return q
 
 
-def random_projection(float_input: torch.Tensor, quantized_input: torch.Tensor, target_dim: int):
+def random_projection(
+        float_input: torch.Tensor,
+        quantized_input: torch.Tensor,
+        target_dim: int,
+        batch_size: int = 2048):
     # use random projection to reduce dimensionality
     n = quantized_input.size(1)
     dev = float_input.device
-    # create gaussian random matrix
-    R = torch.normal(mean=0.0, std=1. / math.sqrt(n), size=(target_dim, n), device=dev)
-    quantized_input = torch.transpose(quantized_input, 1, 2) @ R.T
-    float_input = torch.transpose(float_input, 1, 2) @ R.T
-    del R
+    # use batching if target_dim is greater than 8000 to avoid memory issues
+    if target_dim > 8000:
+        accumulated_batches = 0
+        first_batch = True
+        quantized_input = torch.transpose(quantized_input, 1, 2)
+        float_input = torch.transpose(float_input, 1, 2)
+        while accumulated_batches < target_dim:
+            # cur_target_dim makes sure to fully use batch_size unless we're too close to target_dim
+            cur_target_dim = min(batch_size, target_dim - accumulated_batches)
+            accumulated_batches += cur_target_dim
+            R = torch.normal(mean=0.0, std=1. / math.sqrt(n), size=(cur_target_dim, n), device=dev)
+            if first_batch:
+                quantized_input_proj = (quantized_input @ R.T).cpu()
+                float_input_proj = (float_input @ R.T).cpu()
+                first_batch = False
+            else:
+                # concatenate projected input along last dimension
+                quantized_input_proj = torch.cat(
+                    [quantized_input_proj, (quantized_input @ R.T).cpu()], dim=-1)
+                float_input_proj = torch.cat([float_input_proj, (float_input @ R.T).cpu()], dim=-1)
+    else:
+        # create gaussian random matrix
+        R = torch.normal(mean=0.0, std=1. / math.sqrt(n), size=(target_dim, n), device=dev)
+        quantized_input_proj = torch.transpose(quantized_input, 1, 2) @ R.T
+        float_input_proj = torch.transpose(float_input, 1, 2) @ R.T
+        del R
     # reshape back
-    quantized_input = torch.transpose(quantized_input, 1, 2)
-    float_input = torch.transpose(float_input, 1, 2)
+    quantized_input_proj = torch.transpose(quantized_input_proj, 1, 2)
+    float_input_proj = torch.transpose(float_input_proj, 1, 2)
+    del quantized_input
+    del float_input
-    return float_input, quantized_input
+    return float_input_proj, quantized_input_proj
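
For reference, a minimal usage sketch of the patched random_projection (not part of the diff itself; the tensor shapes and the target_dim/batch_size values below are illustrative assumptions):

import torch

from brevitas.graph.gpxq import random_projection  # assumes this patch is applied

torch.manual_seed(0)
B, n, L = 2, 4096, 8  # dim 1 (size n) is the dimension being projected
float_input = torch.randn(B, n, L)
quantized_input = float_input + 0.01 * torch.randn_like(float_input)

# target_dim <= 8000: a single (target_dim, n) gaussian matrix is drawn and the
# projected tensors stay on the input device.
f_proj, q_proj = random_projection(float_input, quantized_input, target_dim=1024)
assert f_proj.shape == (B, 1024, L)

# target_dim > 8000: R is drawn in batch_size-row chunks (here 4 chunks of 2048)
# and each chunk's projection is moved to the CPU before concatenation, so the
# full (target_dim, n) matrix never has to live in device memory. Note that the
# batched path therefore returns CPU tensors, unlike the unbatched path.
f_big, q_big = random_projection(
    float_input, quantized_input, target_dim=8192, batch_size=2048)
assert f_big.shape == (B, 8192, L) and f_big.device.type == 'cpu'

Since the entries of R are i.i.d. gaussian, drawing independent row blocks of R and concatenating their projections along the output dimension is distributionally identical to projecting with the full (target_dim, n) matrix; only the peak memory profile changes.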