I have built a simple multi-head attention model and would like to obtain the Grad-CAM map of the `layer_norm` layer. Targeting layers up to the qkv linear layer (`qkv_project`) produces a normal CAM, but after `out_linear` the CAM suddenly becomes all zeros. `out_linear.weight.grad` is populated normally; its values are very small, but some of them are positive. Below is the forward code of my model:
```python
def forward(self, input):
    # Project once, then split into query/key/value, each (batch, tokens, embed_dim)
    query, key, value = self.qkv_project(input).chunk(3, dim=-1)
    batch_size = query.size(0)
    # Reshape to (batch, num_heads, tokens, head_dim)
    query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
    heads = []
    for i in range(self.num_heads):
        # Per-head linear layers applied to head i: (batch, tokens, head_dim)
        q = self.q_linear_reference[i](query[:, i, :, :])
        k = self.k_linear_reference[i](key[:, i, :, :])
        v = self.v_linear_reference[i](value[:, i, :, :])
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = torch.softmax(scores, dim=-1)
        # Weight the values by the softmax attention probabilities, not the raw scores
        head = torch.matmul(attn_weights, v)
        heads.append(head)
    heads = torch.cat(heads, dim=-1)  # (batch, tokens, embed_dim)
    attn_output = self.out_linear(heads)
    attn_output = self.layer_norm(attn_output)
    return attn_output
```
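Here is a self-contained sketch for reproducing this. Several parts are assumptions not given above: the `__init__` (layer sizes, `nn.ModuleList` for the per-head layers), a hypothetical mean-pooled classifier `head` used as the Grad-CAM target, and a hand-rolled Grad-CAM on `layer_norm` that treats tokens as the spatial axis and features as channels. It uses plain PyTorch hooks rather than any particular Grad-CAM library.

```python
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    # Hypothetical __init__: attribute names match forward() above, but the sizes,
    # the ModuleList layout and the classifier `head` are assumptions.
    def __init__(self, embed_dim=32, num_heads=4, num_classes=3):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv_project = nn.Linear(embed_dim, 3 * embed_dim)
        self.q_linear_reference = self._per_head()
        self.k_linear_reference = self._per_head()
        self.v_linear_reference = self._per_head()
        self.out_linear = nn.Linear(embed_dim, embed_dim)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)  # hypothetical Grad-CAM target

    def _per_head(self):
        return nn.ModuleList(
            [nn.Linear(self.head_dim, self.head_dim) for _ in range(self.num_heads)]
        )

    def _split(self, t):
        # (batch, tokens, embed_dim) -> (batch, num_heads, tokens, head_dim)
        return t.view(t.size(0), -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        query, key, value = (self._split(t) for t in self.qkv_project(x).chunk(3, dim=-1))
        heads = []
        for i in range(self.num_heads):
            q = self.q_linear_reference[i](query[:, i])
            k = self.k_linear_reference[i](key[:, i])
            v = self.v_linear_reference[i](value[:, i])
            scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5
            heads.append(torch.softmax(scores, dim=-1) @ v)
        tokens = self.layer_norm(self.out_linear(torch.cat(heads, dim=-1)))
        return self.head(tokens.mean(dim=1))  # (batch, num_classes)


torch.manual_seed(0)
model = MultiHeadAttention()
store = {}


def save_activation(module, inputs, output):
    store["act"] = output.detach()  # A: (batch, tokens, features)
    # Capture dScore/dA when backward reaches this tensor
    output.register_hook(lambda grad: store.update(grad=grad.detach()))


handle = model.layer_norm.register_forward_hook(save_activation)
logits = model(torch.randn(2, 7, 32))
model.zero_grad()
logits[:, 0].sum().backward()  # target: class 0 for every sample
handle.remove()

A, G = store["act"], store["grad"]
weights = G.mean(dim=1, keepdim=True)        # one weight per feature: (batch, 1, features)
cam = torch.relu((weights * A).sum(dim=-1))  # one value per token: (batch, tokens)
print(cam.shape)  # torch.Size([2, 7])
```

One property worth checking with this sketch: with LayerNorm's default affine parameters (weight 1, bias 0), every token's output has zero mean across features, so if the Grad-CAM weights are nearly constant across features, the weighted sum is nearly zero for every token.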
Thanks in advance.