diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 9f4d5cd6..b775a3ba 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -209,16 +209,12 @@ def pad_tensor_to_multiple_of(ndarray, pad_to_dims, val=0, distr_pad=False): def calculate_matvec_accumulator_range(matrix: np.ndarray, vec_dt: DataType): - """Calculate the minimum and maximum possible result (accumulator) values - for a dot product x * A, given matrix A of dims (MW, MH), and vector (1, MW) - with datatype vec_dt. Returns (acc_min, acc_max). - """ - max_weight = abs(matrix).sum(axis=0).max() - max_input = max(abs(vec_dt.min()), abs(vec_dt.max())) - max_value = max_input * max_weight - # If either the weight and input datatypes are signed, then the minimum - # value that their accumulated product can be is -max_value. Else, it's 0. - min_value = -max_value if (matrix.min() < 0) or vec_dt.signed() else 0 + """Calculate the minimum and maximum possible result (accumulator) values for a dot product x * A, + given matrix A of dims (MW, MH), and vector (1, MW) with datatype vec_dt. Returns (acc_min, acc_max).""" + max_vectors = np.where(matrix > 0, vec_dt.max(), vec_dt.min()) + min_vectors = np.where(matrix > 0, vec_dt.min(), vec_dt.max()) + max_value = (matrix * max_vectors).sum(axis=0).max() + min_value = (matrix * min_vectors).sum(axis=0).min() return (min_value, max_value)