
torch MPS backend support #42

Open · emdupre opened this issue Feb 14, 2023 · 11 comments

emdupre commented Feb 14, 2023

Thanks for your great work making VM accessible!

I was looking into starting with himalaya, but it seems that you do not currently support pytorch's MPS backend for working on the M1 GPU. Is this correct?

As the MPS backend has been officially released for almost a year, it would be great to take advantage of it to accelerate himalaya models! Is this something that you would be interested in?

Thanks again,
Elizabeth

mvdoc (Collaborator) commented Feb 15, 2023

Oh cool, thanks for letting us know! I didn't know that pytorch supported the M1. Since I have an M1-based machine, I will try to implement this in himalaya :)
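
For anyone else who wants to check their local setup, PyTorch exposes two helpers for this (a minimal sketch; requires a PyTorch build with MPS support):

```python
import torch

# True if this PyTorch build was compiled with MPS support.
print(torch.backends.mps.is_built())
# True if macOS and the hardware can actually expose the MPS device.
print(torch.backends.mps.is_available())
```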

mvdoc (Collaborator) commented Feb 15, 2023

Unfortunately, it looks like MPS support in pytorch is far from complete. Many linear algebra operators are not implemented yet (pytorch/pytorch#77764). For example, calling torch.linalg.eigh raises the following error:

NotImplementedError: The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device.
If you want this op to be added in priority during the prototype phase of this feature, please comment on pytorch/pytorch#77764.
As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op.
WARNING: this will be slower than running natively on MPS.

I don't think we can support the MPS backend in himalaya until all the linalg operations are fully implemented in pytorch.
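
For reference, the CPU fallback that the error message mentions can be enabled as follows (a minimal sketch; note that the variable must be set before torch is imported, since the flag is read at import time):

```python
import os

# Must be set before importing torch; the flag is read at import time.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch

x = torch.randn(4, 4, device="mps")
# torch.linalg.eigh is not implemented on MPS, so with the fallback enabled
# this runs on the CPU (with a warning) instead of raising.
w, v = torch.linalg.eigh(x @ x.T)
```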

emdupre (Author) commented Feb 15, 2023

Thanks for getting to this so quickly, @mvdoc!

Would it make sense to try the PYTORCH_ENABLE_MPS_FALLBACK=1 flag, as suggested in the linked thread? I completely understand if you'd rather wait until full M1 linalg support is available, but it might also be nice to take advantage of what is currently available.

mvdoc (Collaborator) commented Feb 15, 2023 via email

TomDLT (Collaborator) commented Feb 15, 2023

Himalaya solvers use GPUs to speed up two kinds of expensive computations:

  • matrix inversions, through torch.linalg.eigh or torch.linalg.svd
  • matrix multiplications and other operations, through torch.matmul, torch.mean, etc.

I think both improvements are important. So even though MPS does not support torch.linalg.eigh, it could still be useful with PYTORCH_ENABLE_MPS_FALLBACK=1 to speed up matrix multiplications and other operations. In fact, some solvers do not use matrix inversions at all (e.g. KernelRidge(solver="conjugate_gradient"), WeightedKernelRidge(), or MultipleKernelRidgeCV(solver="hyper_gradient")). For these solvers, an MPS backend would likely be beneficial, as in the sketch below.
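
From the user side, that would look roughly like this (a minimal sketch; the backend name torch_mps follows the naming used later in this thread and is an assumption until the backend actually lands):

```python
import numpy as np
from himalaya.backend import set_backend
from himalaya.kernel_ridge import KernelRidge

# Assumed backend name, following the naming used later in this thread.
set_backend("torch_mps")

# MPS only supports float32, so cast inputs accordingly.
X = np.random.randn(1000, 200).astype("float32")
Y = np.random.randn(1000, 10).astype("float32")

# conjugate_gradient avoids torch.linalg.eigh / torch.linalg.svd entirely.
model = KernelRidge(alpha=1.0, solver="conjugate_gradient")
model.fit(X, Y)
```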

Also, all solvers using torch.linalg.eigh can also work with torch.linalg.svd. Do you know if MPS supports torch.linalg.svd?

emdupre (Author) commented Feb 15, 2023

> Also, all solvers using torch.linalg.eigh can also work with torch.linalg.svd. Do you know if MPS supports torch.linalg.svd?

Locally, I'm not able to confirm MPS support with either the stable or nightly build; I get:

UserWarning: The operator 'aten::linalg_svd' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications.

Looking at this older (solved) bug report, however, it seems that they were able to run torch.linalg.svd with the MPS backend. I'll try to track down the discrepancy.

EDIT: It looks like this is set here and is indeed triggered as a warning in the thread I sent. So: no, torch.linalg.svd is not currently supported on the MPS backend!
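
A quick way to reproduce this check locally (a minimal sketch):

```python
import torch

if torch.backends.mps.is_available():
    a = torch.randn(8, 8, device="mps")
    # On current builds this either raises NotImplementedError or emits the
    # CPU-fallback UserWarning quoted above, rather than running natively.
    try:
        U, S, Vh = torch.linalg.svd(a)
        print("ran without error (check for a fallback warning)")
    except NotImplementedError as err:
        print("not implemented on MPS:", err)
```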

mvdoc (Collaborator) commented Feb 15, 2023

I will experiment a bit more and see what speedup we get even if we use PYTORCH_ENABLE_MPS_FALLBACK=1 :)

mvdoc (Collaborator) commented Feb 16, 2023

Well, I'm happy I was wrong (by a lot). I ran the voxelwise tutorial that fits the banded ridge model (it's actually one banded ridge model plus one ridge model). We get a ~3x speedup by using the MPS backend. (The eigh diagonalizer kept crashing with torch_mps, so I had to switch to the svd diagonalizer; if we also compared against the svd diagonalizer on the CPU, the speedup might be slightly larger.)

There are still things to check (some tests fail with the torch_mps backend due to numerical precision), but I can see the torch_mps backend being implemented soon.

Backend numpy, eigh diag

python 06_plot_banded_ridge_model.py  15898.39s user 1106.48s system 243% cpu 1:56:27.62 total

Backend torch_mps, svd diag

python 06_plot_banded_ridge_model.py  1195.92s user 120.63s system 60% cpu 36:22.02 total
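
For context, switching the diagonalizer looks roughly like this (a sketch; the diagonalize_method parameter name and the "precomputed" kernels option are recalled from himalaya's random-search solver and should be treated as assumptions):

```python
from himalaya.backend import set_backend
from himalaya.kernel_ridge import MultipleKernelRidgeCV

set_backend("torch_mps")  # assumed backend name, as above

model = MultipleKernelRidgeCV(
    kernels="precomputed",
    solver="random_search",
    # "eigh" kept crashing on MPS, so use the "svd" diagonalizer instead.
    solver_params=dict(diagonalize_method="svd"),
)
```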

mvdoc (Collaborator) commented Feb 16, 2023

Another test: we don't get a noticeable speedup when running a simple ridge model.

Backend torch, svd diag

python 05_plot_motion_energy_model.py  236.70s user 50.63s system 177% cpu 2:41.66 total

Backend torch_mps, svd diag

python 05_plot_motion_energy_model.py  89.28s user 29.34s system 79% cpu 2:28.93 total

TomDLT (Collaborator) commented Feb 16, 2023

> We get a ~3x speedup by using the MPS backend.

Nice! Thanks for working on this.

emdupre (Author) commented Feb 16, 2023

Yes, thank you @mvdoc for working on this and @TomDLT for your insight!

If there's anything I can provide to help here, please let me know. Happy to help develop or review!
