Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve the compatibility with non-CUDA environments #124

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

YuHengjie
Copy link

Many thanks to the author for proposing this amazing KAN. I modified the KANLayer.py file for better compatibility in non-CUDA environments. BTW, it is my first pull request. I would appreciate it if you could accept this request. @KindXiaoming

Before:
line 126 in KANLayer.py
self.scale_base = torch.nn.Parameter(torch.FloatTensor(scale_base)).requires_grad_(sb_trainable)
After:
if torch.cuda.is_available(): self.scale_base = torch.nn.Parameter(torch.FloatTensor(scale_base).cuda()).requires_grad_(sb_trainable) else: self.scale_base = torch.nn.Parameter(torch.FloatTensor(scale_base)).requires_grad_(sb_trainable)

@@ -123,7 +123,10 @@ def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.1, scale_base=
if isinstance(scale_base, float):
self.scale_base = torch.nn.Parameter(torch.ones(size, device=device) * scale_base).requires_grad_(sb_trainable) # make scale trainable
else:
self.scale_base = torch.nn.Parameter(torch.FloatTensor(scale_base).to(device)).requires_grad_(sb_trainable)
if torch.cuda.is_available():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could potentially lead to errors for users who have CUDA device but intend to run on CPU.

@AlessandroFlati
Copy link
Contributor

I do agree with @Jim137. BTW, if you already set device in the calling function there's no need to do so. If you want it to be transparent, even if I don't agree with this approach, you should modify EVERY method that regards cuda vs cpu.

@carlguo866
Copy link

Was gonna make a pull request about this and then saw this - 100% agree!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants