Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Edits to generic add, BlockDiagLinearOperator's matmul, and documentation #10

Merged
merged 2 commits into from
Sep 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,9 @@ A LinearOperator implementation to wrap the numerical nuts and bolts of GPyTorch

[![Run Test Suite](https://github.com/cornellius-gp/linear_operator/actions/workflows/run_test_suite.yml/badge.svg)](https://github.com/cornellius-gp/linear_operator/actions/workflows/run_test_suite.yml)
[![Documentation Status](https://readthedocs.org/projects/linear-operator/badge/?version=latest)](https://linear-operator.readthedocs.io/en/latest/?badge=latest)

## Development
To run unit tests:
```
python -m unittest discover
```
2 changes: 1 addition & 1 deletion linear_operator/operators/_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2557,7 +2557,7 @@ def __add__(self, other: Union[torch.Tensor, LinearOperator, float]) -> LinearOp
from .zero_linear_operator import ZeroLinearOperator

if isinstance(other, ZeroLinearOperator):
return self
return deepcopy(self)
elif isinstance(other, DiagLinearOperator):
return AddedDiagLinearOperator(self, other)
elif isinstance(other, RootLinearOperator):
Expand Down
13 changes: 11 additions & 2 deletions linear_operator/operators/block_diag_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ class BlockDiagLinearOperator(BlockLinearOperator, metaclass=_MetaBlockDiagLinea
The dimension that specifies the blocks.
"""

def __init__(self, base_linear_op, block_dim=-3):
super().__init__(base_linear_op, block_dim)
# block diagonal is restricted to have square diagonal blocks
if self.base_linear_op.shape[-1] != self.base_linear_op.shape[-2]:
raise RuntimeError(
"base_linear_op must be a batch of square matrices, but non-batch dimensions are "
f"{base_linear_op.shape[-2:]}"
)

@property
def num_blocks(self):
return self.base_linear_op.size(-3)
Expand Down Expand Up @@ -139,8 +148,8 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True)
def matmul(self, other):
from .diag_linear_operator import DiagLinearOperator

# this is trivial if we multiply two BlockDiagLinearOperator
if isinstance(other, BlockDiagLinearOperator):
# this is trivial if we multiply two BlockDiagLinearOperator with matching block sizes
if isinstance(other, BlockDiagLinearOperator) and self.base_linear_op.shape == other.base_linear_op.shape:
return BlockDiagLinearOperator(self.base_linear_op @ other.base_linear_op)
# special case if we have a DiagLinearOperator
if isinstance(other, DiagLinearOperator):
Expand Down
5 changes: 2 additions & 3 deletions linear_operator/operators/block_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class BlockLinearOperator(LinearOperator):
"""
An abstract LinearOperator class for block tensors.
Super classes will determine how the different blocks are layed out
Subclasses will determine how the different blocks are layed out
(e.g. block diagonal, sum over blocks, etc.)

BlockLinearOperators represent the groups of blocks as a batched Tensor.
Expand All @@ -39,7 +39,7 @@ def __init__(self, base_linear_op, block_dim=-3):
block_dim = block_dim if block_dim < 0 else (block_dim - base_linear_op.dim())

# Everything is MUCH easier to write if the last batch dimension is the block dimension
# I.e. blopck_dim = -3
# I.e. block_dim = -3
# We'll permute the dimensions if this is not the case
if block_dim != -3:
positive_block_dim = base_linear_op.dim() + block_dim
Expand All @@ -48,7 +48,6 @@ def __init__(self, base_linear_op, block_dim=-3):
*range(positive_block_dim + 1, base_linear_op.dim() - 2),
positive_block_dim,
)

super(BlockLinearOperator, self).__init__(to_linear_operator(base_linear_op))
self.base_linear_op = base_linear_op

Expand Down
2 changes: 1 addition & 1 deletion linear_operator/operators/cat_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def cat(inputs, dim=0, output_device=None):

class CatLinearOperator(LinearOperator):
r"""
A `LinearOperator` that represents the concatenation of other lazy tensors.
A `LinearOperator` that represents the concatenation of other linear operators.
Each LinearOperator must have the same shape except in the concatenating
dimension.

Expand Down
10 changes: 5 additions & 5 deletions linear_operator/operators/constant_mul_linear_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ConstantMulLinearOperator(LinearOperator):

.. note::

To element-wise multiply two lazy tensors, see :class:`linear_operator.lazy.MulLinearOperator`
To element-wise multiply two lazy tensors, see :class:`linear_operator.operators.MulLinearOperator`

Args:
base_linear_op (LinearOperator) or (b x n x m)): The base_lazy tensor
Expand All @@ -38,18 +38,18 @@ class ConstantMulLinearOperator(LinearOperator):

Example::

>>> base_base_linear_op = linear_operator.lazy.ToeplitzLinearOperator([1, 2, 3])
>>> base_base_linear_op = linear_operator.operators.ToeplitzLinearOperator([1, 2, 3])
>>> constant = torch.tensor(1.2)
>>> new_base_linear_op = linear_operator.lazy.ConstantMulLinearOperator(base_base_linear_op, constant)
>>> new_base_linear_op = linear_operator.operators.ConstantMulLinearOperator(base_base_linear_op, constant)
>>> new_base_linear_op.to_dense()
>>> # Returns:
>>> # [[ 1.2, 2.4, 3.6 ]
>>> # [ 2.4, 1.2, 2.4 ]
>>> # [ 3.6, 2.4, 1.2 ]]
>>>
>>> base_base_linear_op = linear_operator.lazy.ToeplitzLinearOperator([[1, 2, 3], [2, 3, 4]])
>>> base_base_linear_op = linear_operator.operators.ToeplitzLinearOperator([[1, 2, 3], [2, 3, 4]])
>>> constant = torch.tensor([1.2, 0.5])
>>> new_base_linear_op = linear_operator.lazy.ConstantMulLinearOperator(base_base_linear_op, constant)
>>> new_base_linear_op = linear_operator.operators.ConstantMulLinearOperator(base_base_linear_op, constant)
>>> new_base_linear_op.to_dense()
>>> # Returns:
>>> # [[[ 1.2, 2.4, 3.6 ]
Expand Down
6 changes: 3 additions & 3 deletions linear_operator/test/linear_operator_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,19 +877,19 @@ def test_diagonalization(self, symeig=False):
def test_diagonalization_symeig(self):
return self.test_diagonalization(symeig=True)

# NOTE: this is currently not executed, and fails if the underscore is removed
def _test_triangular_linear_op_inv_quad_logdet(self):
# now we need to test that a second cholesky isn't being called in the inv_quad_logdet
with linear_operator.settings.max_cholesky_size(math.inf):
linear_op = self.create_linear_op()
rootdecomp = linear_operator.root_decomposition(linear_op)

if isinstance(rootdecomp, linear_operator.lazy.CholLinearOperator):
Balandat marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(rootdecomp, linear_operator.operators.CholLinearOperator):
chol = linear_operator.root_decomposition(linear_op).root.clone()
linear_operator.utils.memoize.clear_cache_hook(linear_op)
linear_operator.utils.memoize.add_to_cache(
linear_op,
"root_decomposition",
linear_operator.lazy.RootLinearOperator(chol),
linear_operator.operators.RootLinearOperator(chol),
)

_wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/utils/contour_integral_quad.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def contour_integral_quad(
Performs :math:`\mathbf K^{1/2} \mathbf b` or :math:`\mathbf K^{-1/2} \mathbf b`
using contour integral quadrature.

:param linear_operator.lazy.LinearOperator linear_op: LinearOperator representing :math:`\mathbf K`
:param linear_operator.operators.LinearOperator linear_op: LinearOperator representing :math:`\mathbf K`
:param torch.Tensor rhs: Right hand side tensor :math:`\mathbf b`
:param bool inverse: (default False) whether to compute :math:`\mathbf K^{1/2} \mathbf b` (if False)
or `\mathbf K^{-1/2} \mathbf b` (if True)
Expand Down
2 changes: 1 addition & 1 deletion linear_operator/utils/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def apply_permutation(
Broadcasting rules apply.

:param matrix: :math:`\mathbf K`
:type matrix: ~linear_operator.lazy.LinearOperator or ~torch.Tensor (... x n x n)
:type matrix: ~linear_operator.operators.LinearOperator or ~torch.Tensor (... x n x n)
:param left_permutation: vector representing :math:`\boldsymbol{\Pi}_\text{left}`
:type left_permutation: ~torch.Tensor, optional (... x <= n)
:param right_permutation: vector representing :math:`\boldsymbol{\Pi}_\text{right}`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,12 @@ def _test_solve(self, rhs, lhs=None, cholesky=False):

self.assertFalse(linear_cg_mock.called)

def _test_inv_quad_logdet(self, reduce_inv_quad=True, cholesky=False):
# NOTE: this is currently not executed
def _test_inv_quad_logdet(self, reduce_inv_quad=True, cholesky=False, linear_op=None):
if not self.__class__.skip_slq_tests:
# Forward
linear_op = self.create_linear_op()
if linear_op is None:
linear_op = self.create_linear_op()
evaluated = self.evaluate_linear_op(linear_op)
flattened_evaluated = evaluated.view(-1, *linear_op.matrix_shape)

Expand Down