diff --git a/linear_operator/operators/block_diag_linear_operator.py b/linear_operator/operators/block_diag_linear_operator.py
index 2d6d6319..2df9b3f5 100644
--- a/linear_operator/operators/block_diag_linear_operator.py
+++ b/linear_operator/operators/block_diag_linear_operator.py
@@ -46,6 +46,15 @@ class BlockDiagLinearOperator(BlockLinearOperator, metaclass=_MetaBlockDiagLinea
         The dimension that specifies the blocks.
     """

+    def __init__(self, base_linear_op, block_dim=-3):
+        super().__init__(base_linear_op, block_dim)
+        # block diagonal is restricted to have square diagonal blocks
+        if self.base_linear_op.shape[-1] != self.base_linear_op.shape[-2]:
+            raise RuntimeError(
+                "base_linear_op must be a batch of square matrices, but non-batch dimensions are "
+                f"{base_linear_op.shape[-2:]}"
+            )
+
     @property
     def num_blocks(self):
         return self.base_linear_op.size(-3)
diff --git a/linear_operator/operators/block_linear_operator.py b/linear_operator/operators/block_linear_operator.py
index 8edacee0..e96efff1 100644
--- a/linear_operator/operators/block_linear_operator.py
+++ b/linear_operator/operators/block_linear_operator.py
@@ -12,7 +12,7 @@ class BlockLinearOperator(LinearOperator):
     """
     An abstract LinearOperator class for block tensors.

-    Super classes will determine how the different blocks are layed out
+    Subclasses will determine how the different blocks are laid out
     (e.g. block diagonal, sum over blocks, etc.)

     BlockLinearOperators represent the groups of blocks as a batched Tensor.
@@ -39,7 +39,7 @@ def __init__(self, base_linear_op, block_dim=-3):
         block_dim = block_dim if block_dim < 0 else (block_dim - base_linear_op.dim())

         # Everything is MUCH easier to write if the last batch dimension is the block dimension
-        # I.e. blopck_dim = -3
+        # I.e. block_dim = -3
         # We'll permute the dimensions if this is not the case
         if block_dim != -3:
             positive_block_dim = base_linear_op.dim() + block_dim
@@ -48,7 +48,6 @@ def __init__(self, base_linear_op, block_dim=-3):
                 *range(positive_block_dim + 1, base_linear_op.dim() - 2),
                 positive_block_dim,
             )
-
         super(BlockLinearOperator, self).__init__(to_linear_operator(base_linear_op))
         self.base_linear_op = base_linear_op

diff --git a/linear_operator/test/linear_operator_test_case.py b/linear_operator/test/linear_operator_test_case.py
index 15a89dec..1646f594 100644
--- a/linear_operator/test/linear_operator_test_case.py
+++ b/linear_operator/test/linear_operator_test_case.py
@@ -877,12 +877,12 @@ def test_diagonalization(self, symeig=False):
     def test_diagonalization_symeig(self):
         return self.test_diagonalization(symeig=True)

+    # NOTE: this is currently not executed, and fails if the underscore is removed
     def _test_triangular_linear_op_inv_quad_logdet(self):
         # now we need to test that a second cholesky isn't being called in the inv_quad_logdet
         with linear_operator.settings.max_cholesky_size(math.inf):
             linear_op = self.create_linear_op()
             rootdecomp = linear_operator.root_decomposition(linear_op)
-
             if isinstance(rootdecomp, linear_operator.operators.CholLinearOperator):
                 chol = linear_operator.root_decomposition(linear_op).root.clone()
                 linear_operator.utils.memoize.clear_cache_hook(linear_op)
diff --git a/test/operators/test_low_rank_root_added_diag_linear_operator.py b/test/operators/test_low_rank_root_added_diag_linear_operator.py
index 446ffd4d..ceec58d5 100644
--- a/test/operators/test_low_rank_root_added_diag_linear_operator.py
+++ b/test/operators/test_low_rank_root_added_diag_linear_operator.py
@@ -67,10 +67,12 @@ def _test_solve(self, rhs, lhs=None, cholesky=False):
            self.assertFalse(linear_cg_mock.called)

-    def _test_inv_quad_logdet(self, reduce_inv_quad=True, cholesky=False):
+    # NOTE: this is currently not executed
+    def _test_inv_quad_logdet(self, reduce_inv_quad=True, cholesky=False, linear_op=None):
         if not self.__class__.skip_slq_tests:
             # Forward
-            linear_op = self.create_linear_op()
+            if linear_op is None:
+                linear_op = self.create_linear_op()
             evaluated = self.evaluate_linear_op(linear_op)
             flattened_evaluated = evaluated.view(-1, *linear_op.matrix_shape)
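
For reference, a minimal sketch (not part of the patch) of the behaviour the new square-block check in `BlockDiagLinearOperator.__init__` is intended to enforce, assuming the patch above is applied and the operator is built from a dense batch of blocks:

```python
import torch

from linear_operator.operators import BlockDiagLinearOperator

# A batch of four square 3x3 blocks is accepted and yields a 12x12 block-diagonal operator.
square_blocks = torch.randn(4, 3, 3)
op = BlockDiagLinearOperator(square_blocks)
print(op.shape)  # expected: torch.Size([12, 12])

# A batch of non-square 3x5 blocks should now fail fast at construction time
# instead of producing an ill-defined block-diagonal operator.
nonsquare_blocks = torch.randn(4, 3, 5)
try:
    BlockDiagLinearOperator(nonsquare_blocks)
except RuntimeError as err:
    print(err)  # "base_linear_op must be a batch of square matrices, ..."
```

Running the check after `super().__init__()` means the shapes are read from the already-permuted `self.base_linear_op`, so the last two dimensions are the per-block matrix dimensions for any choice of `block_dim`.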