-
Notifications
You must be signed in to change notification settings - Fork 22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Dilated Sliding Window mask_mod #12
base: main
Are you sure you want to change the base?
Conversation
|
||
def dilated_sliding_window(b, h, q_idx, kv_idx): | ||
diff = q_idx - kv_idx | ||
in_window = (diff >= 0) & (diff < window_size * dilation) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm from the paper its not clear to me that its always causal
what about torch.abs(diff) < window_size ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One other nit I think that its clearer if we keep the window_size and dilation separate
e.g. to recreate the paper (if we didnt have the and_causal mask)
we would set window_size = 8 and dilation = 2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what about torch.abs(diff) < window_size ?
I thought it would be good to make this implementation consistent with attn_gym/masks/sliding_window.py
.
However, seems reasonable to follow the non-causal way the paper described. I will update the generate_dilated_sliding_window()
function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One other nit I think that its clearer if we keep the window_size and dilation separate
e.g. to recreate the paper (if we didnt have the and_causal mask)
we would set window_size = 8 and dilation = 2
Maybe, I missed something. Can you please explain what does it mean by "if we keep the window_size and dilation separate"?
Did you mean setting window_size = 8 and dilation = 2?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ohh I just meant that the dilation factor doesnt have any impact on the absolute size of the window.
window_size * dilation
-> window_size
So the "potential" size of the window is 16 elements (8 forward, 8 backward ) but a dilation factor knocks out half and we end up up with 4 on both sides. We dont extend the window so as to capture more elements
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for opening up the PR! Left two comments/ questions let me know what you think
Hi @drisspg, thanks for the review! |
Summary
Visualization