
Commit c4fd5fa
added shape check in constructor of block diag op, added notes about skipped tests
SebastianAment committed Sep 8, 2022
1 parent d3e45c4 commit c4fd5fa
Showing 4 changed files with 16 additions and 6 deletions.
9 changes: 9 additions & 0 deletions linear_operator/operators/block_diag_linear_operator.py
@@ -46,6 +46,15 @@ class BlockDiagLinearOperator(BlockLinearOperator, metaclass=_MetaBlockDiagLinea
         The dimension that specifies the blocks.
     """
 
+    def __init__(self, base_linear_op, block_dim=-3):
+        super().__init__(base_linear_op, block_dim)
+        # block diagonal is restricted to have square diagonal blocks
+        if self.base_linear_op.shape[-1] != self.base_linear_op.shape[-2]:
+            raise RuntimeError(
+                "base_linear_op must be a batch of square matrices, but non-batch dimensions are "
+                f"{base_linear_op.shape[-2:]}"
+            )
+
     @property
     def num_blocks(self):
         return self.base_linear_op.size(-3)
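As an aside (not part of the commit): a minimal sketch of what the new check buys, assuming `to_linear_operator` is exported from `linear_operator.operators` and using illustrative shapes. Square blocks construct as before, while non-square blocks now fail fast in the constructor instead of erroring later in some shape-dependent method.

```python
import torch
from linear_operator.operators import to_linear_operator
from linear_operator.operators.block_diag_linear_operator import BlockDiagLinearOperator

# A batch of four square 3x3 blocks: accepted, yielding a 12x12 block-diagonal operator.
square = to_linear_operator(torch.randn(4, 3, 3))
op = BlockDiagLinearOperator(square)

# A batch of non-square 3x2 blocks: the new check raises at construction time.
nonsquare = to_linear_operator(torch.randn(4, 3, 2))
try:
    BlockDiagLinearOperator(nonsquare)
except RuntimeError as err:
    print(err)  # base_linear_op must be a batch of square matrices, ...
```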
5 changes: 2 additions & 3 deletions linear_operator/operators/block_linear_operator.py
@@ -12,7 +12,7 @@
 class BlockLinearOperator(LinearOperator):
     """
     An abstract LinearOperator class for block tensors.
-    Super classes will determine how the different blocks are layed out
+    Subclasses will determine how the different blocks are laid out
     (e.g. block diagonal, sum over blocks, etc.)
     BlockLinearOperators represent the groups of blocks as a batched Tensor.
@@ -39,7 +39,7 @@ def __init__(self, base_linear_op, block_dim=-3):
         block_dim = block_dim if block_dim < 0 else (block_dim - base_linear_op.dim())
 
         # Everything is MUCH easier to write if the last batch dimension is the block dimension
-        # I.e. blopck_dim = -3
+        # I.e. block_dim = -3
         # We'll permute the dimensions if this is not the case
         if block_dim != -3:
             positive_block_dim = base_linear_op.dim() + block_dim
@@ -48,7 +48,6 @@
                 *range(positive_block_dim + 1, base_linear_op.dim() - 2),
                 positive_block_dim,
             )
-
         super(BlockLinearOperator, self).__init__(to_linear_operator(base_linear_op))
         self.base_linear_op = base_linear_op

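To make the permutation bookkeeping above concrete, here is a standalone re-enactment in plain PyTorch with illustrative shapes. `_permute_batch` only reorders batch dimensions; that is emulated here with a full `permute` that pins the last two (matrix) dimensions in place.

```python
import torch

# Illustrative 5-d batch: dims (2, 5, 4) are batch dims, (6, 6) are the matrices.
# The block dimension is dim 1 (five blocks).
base = torch.randn(2, 5, 4, 6, 6)
block_dim = 1

# Normalize to a negative index, as in __init__ above: 1 - 5 = -4.
block_dim = block_dim if block_dim < 0 else (block_dim - base.dim())

if block_dim != -3:
    positive_block_dim = base.dim() + block_dim  # 1
    batch_perm = (
        *range(positive_block_dim),                      # batch dims before the block dim
        *range(positive_block_dim + 1, base.dim() - 2),  # batch dims after it
        positive_block_dim,                              # block dim goes last among batch dims
    )
    # _permute_batch keeps the matrix dims fixed; emulate that with a full permute.
    base = base.permute(*batch_perm, base.dim() - 2, base.dim() - 1)

print(base.shape)  # torch.Size([2, 4, 5, 6, 6]) -- the block dim now sits at -3
```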
2 changes: 1 addition & 1 deletion linear_operator/test/linear_operator_test_case.py
@@ -877,12 +877,12 @@ def test_diagonalization(self, symeig=False):
     def test_diagonalization_symeig(self):
         return self.test_diagonalization(symeig=True)
 
+    # NOTE: this is currently not executed, and fails if the underscore is removed
     def _test_triangular_linear_op_inv_quad_logdet(self):
         # now we need to test that a second cholesky isn't being called in the inv_quad_logdet
         with linear_operator.settings.max_cholesky_size(math.inf):
             linear_op = self.create_linear_op()
             rootdecomp = linear_operator.root_decomposition(linear_op)
-
             if isinstance(rootdecomp, linear_operator.operators.CholLinearOperator):
                 chol = linear_operator.root_decomposition(linear_op).root.clone()
                 linear_operator.utils.memoize.clear_cache_hook(linear_op)
@@ -67,10 +67,12 @@ def _test_solve(self, rhs, lhs=None, cholesky=False):

         self.assertFalse(linear_cg_mock.called)
 
-    def _test_inv_quad_logdet(self, reduce_inv_quad=True, cholesky=False):
+    # NOTE: this is currently not executed
+    def _test_inv_quad_logdet(self, reduce_inv_quad=True, cholesky=False, linear_op=None):
         if not self.__class__.skip_slq_tests:
             # Forward
-            linear_op = self.create_linear_op()
+            if linear_op is None:
+                linear_op = self.create_linear_op()
             evaluated = self.evaluate_linear_op(linear_op)
             flattened_evaluated = evaluated.view(-1, *linear_op.matrix_shape)

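The new `linear_op=None` keyword makes the helper injectable: a caller can pass a pre-built operator, and the `create_linear_op()` factory becomes a fallback. A sketch of that pattern with a hypothetical stand-in class (`FakeCase` and its shapes are not from the repository):

```python
import torch

class FakeCase:
    """Stand-in for a test case using the injectable-helper pattern above."""

    def create_linear_op(self):
        return torch.eye(3)  # placeholder for the suite's operator factory

    def _test_inv_quad_logdet(self, reduce_inv_quad=True, cholesky=False, linear_op=None):
        if linear_op is None:  # fall back to the factory only when nothing is injected
            linear_op = self.create_linear_op()
        return linear_op

case = FakeCase()
assert case._test_inv_quad_logdet().shape == (3, 3)           # default: factory result
custom = torch.ones(2, 2)
assert case._test_inv_quad_logdet(linear_op=custom) is custom  # injected object used as-is
```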
