[Bug] Parameters missing from graph when KeOpsLinearOperator is used #55

Open
m-julian opened this issue Mar 30, 2023 · 5 comments
Labels: bug (Something isn't working)

Comments

@m-julian

🐛 Bug

I have implemented a KeOps periodic kernel in cornellius-gp/gpytorch#2296; however, the raw_lengthscale parameter does not have gradients computed (see cornellius-gp/gpytorch#2296 (comment)). I have tracked the issue down to the matrix multiplication implemented in LinearOperator.matmul (see Expected Behavior). This matrix multiplication is called when doing LinearOperator.sum, as in the example below.
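
For context, here is the equivalence that LinearOperator.sum relies on, illustrated with plain dense tensors rather than the library classes (a minimal sketch, not library code):

import torch

# Summing over the last dimension is equivalent to a matrix-vector product with a
# ones vector, which is how LinearOperator.sum is implemented (see Expected Behavior),
# so the gradients have to flow through the custom Matmul autograd Function rather
# than a plain torch reduction.
A = torch.randn(4, 5, dtype=torch.float64)
ones = torch.ones(5, 1, dtype=torch.float64)
assert torch.allclose(A.sum(dim=-1), (A @ ones).squeeze(-1))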

To reproduce

**Code snippet to reproduce**

import torch
from gpytorch.kernels.keops import PeriodicKernel as KeOpsPeriodicKernel #implementation from pull request 2296
import gpytorch

torch.manual_seed(7)

M, N, D = 1000, 2000, 3
x = torch.randn(M, D).double()
y = torch.randn(N, D).double()
k = KeOpsPeriodicKernel(ard_num_dims=3).double()
k.lengthscale = torch.tensor(1.0).double()
k.period_length = torch.tensor(1.0).double()

# context manager used so that type(covar) is KeOpsLinearOperator, not LazyEvaluatedKernelTensor
with gpytorch.settings.lazily_evaluate_kernels(False):
    covar = k(x, y)
    print(type(covar))
    # Calls `LinearOperator.sum`, which subsequently calls `LinearOperator.matmul`
    # `LinearOperator.matmul` uses a custom torch.autograd.Function for matrix multiplication
    res2 = covar.sum(dim=1)  # res2 is a torch.Tensor here
    res2 = res2.sum()
    print(res2)
    g_x = torch.autograd.grad(res2, [k.raw_lengthscale, k.raw_period_length])
    print(g_x)

**Stack trace/error message**

<class 'linear_operator.operators.keops_linear_operator.KeOpsLinearOperator'>
tensor(202237.5145, dtype=torch.float64, grad_fn=<SumBackward0>)
Traceback (most recent call last):
  File "/home/julian/Desktop/test/keops_periodic_low_level/issue_keops_linear_operator.py", line 23, in <module>
    g_x = torch.autograd.grad(res2, [k.raw_lengthscale, k.raw_period_length])
  File "/home/julian/.venv/ichor/lib/python3.10/site-packages/torch/autograd/__init__.py", line 300, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

where k.raw_lengthscale is causing the issue.

Expected Behavior

Compute the gradients for both k.raw_lengthscale and k.raw_period_length. The exact place where the issue occurs is here:

return (self @ ones).squeeze(-1)

which calls LinearOperator.matmul, where the gradients are lost. I am summing across the columns in the example, but the same issue occurs when summing across the rows. Adding the following check to LinearOperator.sum

# Case: summing across columns
if dim == (self.dim() - 1):
    ones = torch.ones(self.size(-1), 1, dtype=self.dtype, device=self.device)
    from .keops_linear_operator import KeOpsLinearOperator
    if isinstance(self, KeOpsLinearOperator):
        return self.covar_mat.sum(dim=1)
    return (self @ ones).squeeze(-1)

gives gradients for both raw_lengthscale and raw_period_length, as the custom Matmul is never called:

<class 'linear_operator.operators.keops_linear_operator.KeOpsLinearOperator'>
tensor(202237.5145, dtype=torch.float64, grad_fn=<SumBackward0>)
(tensor([[70682.4975, 70796.7652, 70631.9364]], dtype=torch.float64), tensor([[ 70.2535,  19.3231, -47.2902]], dtype=torch.float64))

This is probably not the best solution; perhaps the Matmul forward/backward methods could be changed so that the gradients are computed correctly?
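
For reference, here is a minimal, self-contained sketch of the suspected mechanism. This is not the actual linear_operator Matmul (NaiveMatmul and build_covar are made-up names for illustration); it only shows that a custom torch.autograd.Function leaves out of the graph any parameter that is used solely inside forward, which is what appears to happen to raw_lengthscale via the KeOps lazy tensor:

import torch

lengthscale = torch.tensor(1.0, dtype=torch.float64, requires_grad=True)
x = torch.linspace(0, 1, 5, dtype=torch.float64)

def build_covar():
    # stands in for the KeOps lazy evaluation that closes over `lengthscale`
    return torch.exp(-((x[:, None] - x[None, :]) ** 2) / lengthscale)

class NaiveMatmul(torch.autograd.Function):
    # only the tensors passed to forward() become inputs of the autograd graph;
    # `lengthscale` is used solely inside forward, so the output's grad_fn
    # (this Function's backward node) has no edge leading to it
    @staticmethod
    def forward(ctx, rhs):
        return build_covar() @ rhs

    @staticmethod
    def backward(ctx, grad_output):
        # gradient w.r.t. rhs only; nothing is ever returned for `lengthscale`
        return build_covar().t() @ grad_output

rhs = torch.ones(5, 1, dtype=torch.float64, requires_grad=True)
res = NaiveMatmul.apply(rhs).sum()
# prints (None,); without allow_unused=True this raises the same RuntimeError as above
print(torch.autograd.grad(res, [lengthscale], allow_unused=True))

Presumably a fix would need the custom Function to either receive those hyperparameters as explicit inputs or recompute the KeOps expression with autograd enabled in backward, but I am not sure which approach fits the library best.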

As a check, the same numbers are returned if the normal (non-KeOps) periodic kernel is used:

<class 'linear_operator.operators.dense_linear_operator.DenseLinearOperator'>
tensor(202237.5145, dtype=torch.float64, grad_fn=<SumBackward0>)
(tensor([[70682.4975, 70796.7652, 70631.9364]], dtype=torch.float64), tensor([[ 70.2535,  19.3231, -47.2902]], dtype=torch.float64))

System information

linear_operator version: 0.3.0
torch version: 1.13.1+cu117


@m-julian added the bug (Something isn't working) label on Mar 30, 2023
@gpleiss
Member

gpleiss commented Mar 30, 2023

Hi @m-julian, I've found the bug (on our end) and I'm hoping to put up a PR to fix it later today or tomorrow.

@m-julian
Author

Thanks, looking forward to it!

@m-julian
Author

Hi @gpleiss, when are you planning to put up the PR? I am happy to test the fix and check that the gradients are computed correctly.

@gpleiss
Member

gpleiss commented Apr 19, 2023

Sorry, I got a bit backlogged on this PR. I'll try to have something up on Friday or this weekend.

@gpleiss
Member

gpleiss commented May 13, 2023

This bug will be fixed in the next LinearOperator release.
