
deepcopy in add, size check in BlockDiag matmul, doc edits
SebastianAment committed Sep 7, 2022
1 parent 2e66c24 commit b5b65d7
Showing 7 changed files with 16 additions and 13 deletions.
linear_operator/operators/_linear_operator.py (2 changes: 1 addition & 1 deletion)
@@ -2557,7 +2557,7 @@ def __add__(self, other: Union[torch.Tensor, LinearOperator, float]) -> LinearOp
         from .zero_linear_operator import ZeroLinearOperator

         if isinstance(other, ZeroLinearOperator):
-            return self
+            return deepcopy(self)
         elif isinstance(other, DiagLinearOperator):
             return AddedDiagLinearOperator(self, other)
         elif isinstance(other, RootLinearOperator):
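The effect of the deepcopy change, as a minimal sketch (hypothetical usage; assumes ZeroLinearOperator takes its sizes as positional arguments): adding a ZeroLinearOperator now returns an independent copy of the left operand rather than the operand itself.

    import torch
    from linear_operator import to_linear_operator
    from linear_operator.operators import ZeroLinearOperator

    op = to_linear_operator(torch.eye(3))
    zero = ZeroLinearOperator(3, 3)

    result = op + zero
    # Before this commit `result` aliased `op`, so mutating one could
    # silently affect the other; now it is a deep copy.
    assert result is not op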
linear_operator/operators/block_diag_linear_operator.py (7 changes: 5 additions & 2 deletions)
@@ -139,8 +139,11 @@ def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True)
     def matmul(self, other):
         from .diag_linear_operator import DiagLinearOperator

-        # this is trivial if we multiply two BlockDiagLinearOperator
-        if isinstance(other, BlockDiagLinearOperator):
+        # this is trivial if we multiply two BlockDiagLinearOperator with matching block sizes
+        if (
+            isinstance(other, BlockDiagLinearOperator)
+            and self.base_linear_op.shape[-1] == other.base_linear_op.shape[0]
+        ):
             return BlockDiagLinearOperator(self.base_linear_op @ other.base_linear_op)
         # special case if we have a DiagLinearOperator
         if isinstance(other, DiagLinearOperator):
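To illustrate the fast path this check guards (a sketch with illustrative shapes; the to_linear_operator wrapping is an assumption): when both operands are BlockDiagLinearOperators whose base blocks are compatible, the product is formed block-wise and stays block-diagonal.

    import torch
    from linear_operator import to_linear_operator
    from linear_operator.operators import BlockDiagLinearOperator

    # Each operator wraps 2 blocks of size 3x3, representing a 6x6 matrix.
    A = BlockDiagLinearOperator(to_linear_operator(torch.randn(2, 3, 3)))
    B = BlockDiagLinearOperator(to_linear_operator(torch.randn(2, 3, 3)))

    # Matching block sizes: blocks are multiplied pairwise, and the result
    # remains a BlockDiagLinearOperator instead of densifying to 6x6.
    C = A @ B
    print(C.shape)  # torch.Size([6, 6])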
linear_operator/operators/cat_linear_operator.py (2 changes: 1 addition & 1 deletion)
@@ -33,7 +33,7 @@ def cat(inputs, dim=0, output_device=None):

 class CatLinearOperator(LinearOperator):
     r"""
-    A `LinearOperator` that represents the concatenation of other lazy tensors.
+    A `LinearOperator` that represents the concatenation of other linear operators.
     Each LinearOperator must have the same shape except in the concatenating
     dimension.
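For context, a minimal use of the cat helper named in the hunk header above (shapes are illustrative):

    import torch
    from linear_operator import to_linear_operator
    from linear_operator.operators.cat_linear_operator import cat

    # Inputs must share every dimension except the concatenating one.
    top = to_linear_operator(torch.randn(2, 4))
    bottom = to_linear_operator(torch.randn(2, 4))

    stacked = cat([top, bottom], dim=0)  # lazy row-wise concatenation
    print(stacked.shape)  # torch.Size([4, 4])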
linear_operator/operators/constant_mul_linear_operator.py (10 changes: 5 additions & 5 deletions)
@@ -21,7 +21,7 @@ class ConstantMulLinearOperator(LinearOperator):
     .. note::

-        To element-wise multiply two lazy tensors, see :class:`linear_operator.lazy.MulLinearOperator`
+        To element-wise multiply two lazy tensors, see :class:`linear_operator.operators.MulLinearOperator`

     Args:
         base_linear_op (LinearOperator) or (b x n x m)): The base_lazy tensor
@@ -38,18 +38,18 @@ class ConstantMulLinearOperator(LinearOperator):
     Example::

-        >>> base_base_linear_op = linear_operator.lazy.ToeplitzLinearOperator([1, 2, 3])
+        >>> base_base_linear_op = linear_operator.operators.ToeplitzLinearOperator([1, 2, 3])
         >>> constant = torch.tensor(1.2)
-        >>> new_base_linear_op = linear_operator.lazy.ConstantMulLinearOperator(base_base_linear_op, constant)
+        >>> new_base_linear_op = linear_operator.operators.ConstantMulLinearOperator(base_base_linear_op, constant)
         >>> new_base_linear_op.to_dense()
         >>> # Returns:
         >>> # [[ 1.2, 2.4, 3.6 ]
         >>> #  [ 2.4, 1.2, 2.4 ]
         >>> #  [ 3.6, 2.4, 1.2 ]]
         >>>
-        >>> base_base_linear_op = linear_operator.lazy.ToeplitzLinearOperator([[1, 2, 3], [2, 3, 4]])
+        >>> base_base_linear_op = linear_operator.operators.ToeplitzLinearOperator([[1, 2, 3], [2, 3, 4]])
         >>> constant = torch.tensor([1.2, 0.5])
-        >>> new_base_linear_op = linear_operator.lazy.ConstantMulLinearOperator(base_base_linear_op, constant)
+        >>> new_base_linear_op = linear_operator.operators.ConstantMulLinearOperator(base_base_linear_op, constant)
         >>> new_base_linear_op.to_dense()
         >>> # Returns:
         >>> # [[[ 1.2, 2.4, 3.6 ]
linear_operator/test/linear_operator_test_case.py (4 changes: 2 additions & 2 deletions)
@@ -883,13 +883,13 @@ def _test_triangular_linear_op_inv_quad_logdet(self):
         linear_op = self.create_linear_op()
         rootdecomp = linear_operator.root_decomposition(linear_op)

-        if isinstance(rootdecomp, linear_operator.lazy.CholLinearOperator):
+        if isinstance(rootdecomp, linear_operator.operators.CholLinearOperator):
             chol = linear_operator.root_decomposition(linear_op).root.clone()
             linear_operator.utils.memoize.clear_cache_hook(linear_op)
             linear_operator.utils.memoize.add_to_cache(
                 linear_op,
                 "root_decomposition",
-                linear_operator.lazy.RootLinearOperator(chol),
+                linear_operator.operators.RootLinearOperator(chol),
             )

         _wrapped_cholesky = MagicMock(wraps=torch.linalg.cholesky_ex)
linear_operator/utils/contour_integral_quad.py (2 changes: 1 addition & 1 deletion)
@@ -23,7 +23,7 @@ def contour_integral_quad(
     Performs :math:`\mathbf K^{1/2} \mathbf b` or :math:`\mathbf K^{-1/2} \mathbf b`
     using contour integral quadrature.

-    :param linear_operator.lazy.LinearOperator linear_op: LinearOperator representing :math:`\mathbf K`
+    :param linear_operator.operators.LinearOperator linear_op: LinearOperator representing :math:`\mathbf K`
     :param torch.Tensor rhs: Right hand side tensor :math:`\mathbf b`
     :param bool inverse: (default False) whether to compute :math:`\mathbf K^{1/2} \mathbf b` (if False)
         or :math:`\mathbf K^{-1/2} \mathbf b` (if True)
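A sketch of the documented call (the positive-definite test matrix is made up, and the structure of the return value is not shown in this hunk, so it is left unexposed here):

    import torch
    from linear_operator import to_linear_operator
    from linear_operator.utils.contour_integral_quad import contour_integral_quad

    # K must be symmetric positive definite for K^{1/2} and K^{-1/2}
    # to be well defined.
    A = torch.randn(20, 20)
    K = to_linear_operator(A @ A.T + 20 * torch.eye(20))
    b = torch.randn(20, 1)

    # inverse=False targets K^{1/2} b; inverse=True targets K^{-1/2} b.
    quad_terms = contour_integral_quad(K, b, inverse=True)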
linear_operator/utils/permutation.py (2 changes: 1 addition & 1 deletion)
@@ -30,7 +30,7 @@ def apply_permutation(
     Broadcasting rules apply.

     :param matrix: :math:`\mathbf K`
-    :type matrix: ~linear_operator.lazy.LinearOperator or ~torch.Tensor (... x n x n)
+    :type matrix: ~linear_operator.operators.LinearOperator or ~torch.Tensor (... x n x n)
     :param left_permutation: vector representing :math:`\boldsymbol{\Pi}_\text{left}`
     :type left_permutation: ~torch.Tensor, optional (... x <= n)
     :param right_permutation: vector representing :math:`\boldsymbol{\Pi}_\text{right}`
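A small sketch of the documented parameters (assuming, per the docstring symbols, that the result is Pi_left K Pi_right^T, i.e. row reindexing by left_permutation and column reindexing by right_permutation):

    import torch
    from linear_operator.utils.permutation import apply_permutation

    K = torch.arange(16.0).reshape(4, 4)
    left = torch.tensor([3, 2, 1, 0])   # Pi_left: row permutation
    right = torch.tensor([1, 0, 3, 2])  # Pi_right: column permutation

    permuted = apply_permutation(K, left, right)
    # Expected to match plain tensor indexing on a dense K:
    print(torch.equal(permuted, K[left][:, right]))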
