add cdpm model #261

Closed · wants to merge 9 commits

Changes from all commits
200 changes: 200 additions & 0 deletions generative/networks/layers/RPE.py
@@ -0,0 +1,200 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
# from monai.networks.blocks import MLPBlock
# from monai.networks.layers import Act
Collaborator comment on lines +16 to +17: please remove unused imports
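One possible trimmed import block matching this request (a sketch, not part of the PR; it assumes torch.nn.functional and the commented-out MONAI imports are unused, and that numpy is only needed if the np.ndindex suggestion further down is adopted):

import torch
import torch.nn as nn

# import numpy as np  # only needed if the np.ndindex suggestion below is adopted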



class RPENet(nn.Module):
"""
    Network that produces slice relative position encodings from pairwise slice distances and the diffusion-time
    embedding, adapted from Wu et al. (https://arxiv.org/abs/2107.14222) and the official implementation at
    https://github.com/microsoft/Cream/blob/6fb89a2f93d6d97d2c7df51d600fe8be37ff0db4/iRPE/DeiT-with-iRPE/rpe_vision_transformer.py.
Args:
Collaborator comment: please leave a space between the preamble and the Args for all docstrings

        channels: number of channels of the input.
num_heads: number of heads in the attention model.
time_embed_dim: number of channels of the time embedding.
"""
def __init__(
self,
channels: int,
num_heads: int,
time_embed_dim: int,
    ) -> None:
super().__init__()
self.embed_distances = nn.Linear(3, channels)
self.embed_diffusion_time = nn.Linear(time_embed_dim, channels)
self.silu = nn.SiLU()
self.out = nn.Linear(channels, channels)
self.out.weight.data *= 0.
self.out.bias.data *= 0.
self.channels = channels
self.num_heads = num_heads

def forward(self, temb: torch.Tensor, relative_distances: torch.Tensor) -> torch.Tensor:
distance_embs = torch.stack(
[torch.log(1+(relative_distances).clamp(min=0)),
torch.log(1+(-relative_distances).clamp(min=0)),
(relative_distances == 0).float()],
dim=-1
) # BxTxTx3
B, T, _ = relative_distances.shape
C = self.channels
emb = self.embed_diffusion_time(temb).view(B, T, 1, C) \
+ self.embed_distances(distance_embs) # B x T x T x C
return self.out(self.silu(emb)).view(*relative_distances.shape, self.num_heads, self.channels//self.num_heads)
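For intuition, the three distance features computed above behave as follows for a few signed slice offsets (a small illustrative snippet, not part of the PR; the commented values are what these expressions evaluate to):

import torch

d = torch.tensor([-2.0, 0.0, 3.0])                # signed slice offsets
feats = torch.stack(
    [torch.log(1 + d.clamp(min=0)),               # log of positive offsets
     torch.log(1 + (-d).clamp(min=0)),            # log of negative offsets (sign flipped)
     (d == 0).float()],                           # indicator for zero offset
    dim=-1,
)
# feats ≈ [[0.0000, 1.0986, 0.0000],
#          [0.0000, 0.0000, 1.0000],
#          [1.3863, 0.0000, 0.0000]]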


class RPE(nn.Module):
"""
    Slice relative position encoding module, adapted from Wu et al. (https://arxiv.org/abs/2107.14222) and the
    official implementation at
    https://github.com/microsoft/Cream/blob/6fb89a2f93d6d97d2c7df51d600fe8be37ff0db4/iRPE/DeiT-with-iRPE/rpe_vision_transformer.py.
Args:
        channels: number of channels of the input.
num_heads: number of heads in the attention model.
time_embed_dim: number of channels of the time embedding.
"""
def __init__(
self,
channels: int,
num_heads: int,
time_embed_dim: int,
    ) -> None:
super().__init__()
self.num_heads = num_heads
self.head_dim = channels // self.num_heads
self.rpe_net = RPENet(channels, num_heads, time_embed_dim)

    def get_R(self, pairwise_distances, temb) -> torch.Tensor:
return self.rpe_net(temb, pairwise_distances)

    def forward(self, x, pairwise_distances, temb, mode) -> torch.Tensor:
if mode == "qk":
return self.forward_qk(x, pairwise_distances, temb)
elif mode == "v":
return self.forward_v(x, pairwise_distances, temb)
else:
raise ValueError(f"Unexpected RPE attention mode: {mode}")

    def forward_qk(self, qk, pairwise_distances, temb) -> torch.Tensor:
        # qk is either q or k and has shape BxDxHxTx(C/H)
        # Output shape should be BxDxHxTxT
R = self.get_R(pairwise_distances, temb)
return torch.einsum( # See Eq. 16 in https://arxiv.org/pdf/2107.14222.pdf
"bdhtf,btshf->bdhts", qk, R # BxDxHxTxT
)

    def forward_v(self, attn, pairwise_distances, temb) -> torch.Tensor:
        # attn has shape BxDxHxTxT
        # Output shape should be BxDxHxTx(C/H)
        R = self.get_R(pairwise_distances, temb)
        return torch.einsum(  # See Eq. 16 in https://arxiv.org/pdf/2107.14222.pdf
            "bdhts,btshf->bdhtf", attn, R  # BxDxHxTx(C/H)
        )

    def forward_safe_qk(self, x, pairwise_distances, temb) -> torch.Tensor:
R = self.get_R(pairwise_distances, temb)
B, T, _, H, F = R.shape
D = x.shape[1]
res = x.new_zeros(B, D, H, T, T) # attn shape
for b in range(B):
for d in range(D):
for h in range(H):
for i in range(T):
for j in range(T):
res[b, d, h, i, j] = x[b, d, h, i].dot(R[b, i, j, h])
Member comment on lines +112 to +117, with a suggested change:
        for b, d, h, i, j in np.ndindex(B, D, H, T, T):
            res[b, d, h, i, j] = x[b, d, h, i].dot(R[b, i, j, h])

return res
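A quick way to exercise this reference implementation (a hypothetical snippet, not part of the PR; shapes and hyperparameters are made up) is to check that the vectorised forward_qk matches the explicit loops in forward_safe_qk:

import torch

rpe = RPE(channels=64, num_heads=8, time_embed_dim=128)
torch.nn.init.normal_(rpe.rpe_net.out.weight)  # undo the zero init so the check is not vacuous
B, D, H, T, F = 2, 4, 8, 6, 8                  # F = channels // num_heads
qk = torch.randn(B, D, H, T, F)                # a query or key tensor
temb = torch.randn(B, T, 128)                  # per-slice time embedding
frame_indices = torch.arange(T).repeat(B, 1)   # BxT slice positions
dists = frame_indices.unsqueeze(-1) - frame_indices.unsqueeze(-2)  # BxTxT
assert torch.allclose(rpe.forward_qk(qk, dists, temb), rpe.forward_safe_qk(qk, dists, temb), atol=1e-5)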


class RPEAttention(nn.Module):
"""
    Attention with slice relative position encoding, adapted from Wu et al. (https://arxiv.org/abs/2107.14222) and
    the official implementation at
    https://github.com/microsoft/Cream/blob/6fb89a2f93d6d97d2c7df51d600fe8be37ff0db4/iRPE/DeiT-with-iRPE/rpe_vision_transformer.py.
Args:
        channels: number of channels of the input.
        num_heads: number of heads in the attention model.
        time_embed_dim: number of channels of the time embedding.
        use_rpe_q: whether to apply relative position encoding to the queries.
        use_rpe_k: whether to apply relative position encoding to the keys.
        use_rpe_v: whether to apply relative position encoding to the values.
"""
def __init__(
self,
channels: int,
num_heads: int,
time_embed_dim: int = None,
use_rpe_q: bool = True,
use_rpe_k: bool = True,
use_rpe_v: bool = True,
    ) -> None:
super().__init__()
self.num_heads = num_heads
head_dim = channels // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(channels, channels * 3)
self.proj_out = nn.Linear(channels, channels)
        self.norm = nn.GroupNorm(32, channels)  # separate channels into 32 groups

        def make_rpe_func() -> RPE:
return RPE(
channels=channels, num_heads=num_heads,
time_embed_dim=time_embed_dim,
)
self.rpe_q = make_rpe_func() if use_rpe_q else None
self.rpe_k = make_rpe_func() if use_rpe_k else None
self.rpe_v = make_rpe_func() if use_rpe_v else None

    def forward(self, x, temb, frame_indices, attn_mask=None) -> torch.Tensor:
B, D, C, T = x.shape
x = x.reshape(B*D, C, T)
x = self.norm(x)
x = x.view(B, D, C, T)
x = torch.einsum("BDCT -> BDTC", x) # just a permutation
qkv = self.qkv(x).reshape(B, D, T, 3, self.num_heads, C // self.num_heads)
qkv = torch.einsum("BDTtHF -> tBDHTF", qkv)
q, k, v = qkv[0], qkv[1], qkv[2]
# q, k, v shapes: BxDxHxTx(C/H)
q *= self.scale
attn = (q @ k.transpose(-2, -1)) # BxDxHxTxT
if self.rpe_q is not None or self.rpe_k is not None or self.rpe_v is not None:
pairwise_distances = (frame_indices.unsqueeze(-1) - frame_indices.unsqueeze(-2)) # BxTxT
# relative position on keys
if self.rpe_k is not None:
attn += self.rpe_k(q, pairwise_distances, temb=temb, mode="qk")
# relative position on queries
if self.rpe_q is not None:
attn += self.rpe_q(k * self.scale, pairwise_distances, temb=temb, mode="qk").transpose(-1, -2)

        # softmax where all elements with mask==0 can attend to each other and all with mask==1
        # can attend to each other (but elements with mask==0 can't attend to elements with mask==1)
        def softmax(w, attn_mask) -> torch.Tensor:
if attn_mask is not None:
allowed_interactions = attn_mask.view(B, 1, T) * attn_mask.view(B, T, 1)
allowed_interactions += (1-attn_mask.view(B, 1, T)) * (1-attn_mask.view(B, T, 1))
inf_mask = (1-allowed_interactions)
inf_mask[inf_mask == 1] = torch.inf
w = w - inf_mask.view(B, 1, 1, T, T) # BxDxHxTxT
return torch.softmax(w.float(), dim=-1).type(w.dtype)

attn = softmax(attn, attn_mask)
out = attn @ v
# relative position on values
if self.rpe_v is not None:
out += self.rpe_v(attn, pairwise_distances, temb=temb, mode="v")
out = torch.einsum("BDHTF -> BDTHF", out).reshape(B, D, T, C)
out = self.proj_out(out)
x = x + out
x = torch.einsum("BDTC -> BDCT", x)
return x
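A minimal usage sketch for the attention block above (hypothetical shapes and hyperparameters, not taken from the PR; the import path assumes the __init__.py issue flagged below is resolved, otherwise the module can be imported directly):

import torch
from generative.networks.layers.RPE import RPEAttention  # assumed import path

B, D, C, T = 2, 16, 64, 8                      # batch, spatial positions, channels, temporal slices
attn_block = RPEAttention(channels=C, num_heads=8, time_embed_dim=128)
x = torch.randn(B, D, C, T)                    # input is expected as BxDxCxT
temb = torch.randn(B, T, 128)                  # per-slice diffusion-time embedding
frame_indices = torch.arange(T).repeat(B, 1)   # BxT slice indices
out = attn_block(x, temb, frame_indices)       # output keeps the BxDxCxT layout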
1 change: 1 addition & 0 deletions generative/networks/layers/__init__.py
@@ -10,3 +10,4 @@
# limitations under the License.

from .vector_quantizer import EMAQuantizer, VectorQuantizer
from .vector_quantizer import EMAQuantizer, VectorQuantizer
Collaborator comment: I think you mean to import from .RPE here, not duplicate the vector_quantizer import
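One possible fix along those lines (a sketch; the exact names to export from .RPE are not confirmed by the PR):

from .RPE import RPE, RPEAttention, RPENet
from .vector_quantizer import EMAQuantizer, VectorQuantizer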
