Commit

lint
Ezra-Yu committed Aug 17, 2023
1 parent 0d3d026 commit 857cf7a
Showing 1 changed file with 75 additions and 65 deletions.
140 changes: 75 additions & 65 deletions mmpretrain/models/backbones/vpt.py
@@ -1,16 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import patch

import torch
import torch.nn as nn
from mmpretrain.registry import MODELS
from mmpretrain.models.backbones import VisionTransformer
from mmpretrain.models.backbones import ViTEVA02
from mmpretrain.models.utils import build_norm_layer
import torch
import torch.nn as nn

from mmpretrain.models.utils import resize_pos_embed
from mmpretrain.models.backbones import VisionTransformer, ViTEVA02
from mmpretrain.models.utils import build_norm_layer, resize_pos_embed
from mmpretrain.registry import MODELS


def init_prompt(prompt_init, prompt):
@@ -25,11 +21,13 @@ def init_prompt(prompt_init, prompt):
else:
nn.init.normal_(prompt, std=0.02)
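
For orientation, here is a sketch of the four prompt_init modes handled by init_prompt, reconstructed from the branches visible in this diff; the body of the 'zero' branch and the deferred 'token' behaviour are assumptions, since those lines are collapsed in this view:

import torch
import torch.nn as nn

def init_prompt_sketch(prompt_init: str, prompt: torch.Tensor) -> None:
    if prompt_init == 'uniform':
        nn.init.uniform_(prompt, -0.08, 0.08)  # branch shown in the EVA02 hunk
    elif prompt_init == 'zero':
        nn.init.zeros_(prompt)  # assumed body of the collapsed 'zero' branch
    elif prompt_init == 'token':
        # assumed: left untouched here; __init__ records
        # prompt_initialized = False and fills it from tokens later.
        pass
    else:
        nn.init.normal_(prompt, std=0.02)  # default branch shown above

prompt = torch.empty(12, 5, 768)  # (layers, prompt_length, embed_dims)
init_prompt_sketch('normal', prompt)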


@MODELS.register_module()
class PromptedViT(VisionTransformer):
'''Vision Transformer with Prompt.
A PyTorch implement of : `Visual Prompt Tuning<https://arxiv.org/abs/2203.12119>`_
"""Vision Transformer with Prompt.
A PyTorch implementation of: `Visual Prompt Tuning
<https://arxiv.org/abs/2203.12119>`_
Args:
prompt_length (int): The number of learnable prompt tokens. Defaults to 1.
@@ -48,41 +46,47 @@ class tokens with shape (B, L, C).
& prompt tensor with shape (B, C).
- ``"avg_prompt"``: The global averaged prompt tensor with
shape (B, C).
- ``"avg_prompt_clstoken"``: The global averaged cls_tocken
- ``"avg_prompt_clstoken"``: The global averaged cls_tocken
& prompt tensor with shape (B, C).
Defaults to ``"avg_all"``.
*args(list, optional): Other args for VisionTransformer.
**kwargs(dict, optional): Other args for VisionTransformer.
'''
"""

num_extra_tokens = 1 # class token
OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap', 'avg_all', 'avg_prompt', 'avg_prompt_clstoken'}
OUT_TYPES = {
'raw', 'cls_token', 'featmap', 'avg_featmap', 'avg_all', 'avg_prompt',
'avg_prompt_clstoken'
}

def __init__(self,
prompt_length: int = 1,
deep_prompt: bool = True,
out_type: str ='avg_all',
out_type: str = 'avg_all',
prompt_init: str = 'normal',
norm_cfg: dict =dict(type='LN'),
norm_cfg: dict = dict(type='LN'),
*args,
**kwargs):
super().__init__(*args, out_type=out_type, norm_cfg=norm_cfg, **kwargs)
super().__init__(*args, out_type=out_type, norm_cfg=norm_cfg, **kwargs)

self.prompt_layers = len(self.layers) if deep_prompt else 1
prompt = torch.empty(
self.prompt_layers, prompt_length, self.embed_dims)
prompt = torch.empty(self.prompt_layers, prompt_length,
self.embed_dims)
init_prompt(prompt_init, prompt)
self.prompt_initialized = False if prompt_init == 'token' else True
self.prompt = nn.Parameter(prompt, requires_grad=True)

self.prompt_length = prompt_length
self.deep_prompt = deep_prompt
self.num_extra_tokens = self.num_extra_tokens + prompt_length
self.num_extra_tokens = self.num_extra_tokens + prompt_length

if self.out_type in {'avg_featmap', 'avg_all', 'avg_prompt', 'avg_prompt_clstoken'}:
if self.out_type in {
'avg_featmap', 'avg_all', 'avg_prompt', 'avg_prompt_clstoken'
}:
self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
# freeze stages

# freeze stages
self.frozen_stages = len(self.layers)
self._freeze_stages()
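
A minimal usage sketch of the class as constructed above, assuming standard mmpretrain VisionTransformer kwargs (the arch alias and image size are assumptions; only type, prompt_length, deep_prompt and out_type come from this file):

from mmpretrain.registry import MODELS

cfg = dict(
    type='PromptedViT',
    arch='b',            # assumed VisionTransformer arch alias
    img_size=224,        # assumed input resolution
    prompt_length=5,     # five learnable prompt tokens per prompted layer
    deep_prompt=True,    # one prompt row per transformer layer
    out_type='avg_all',  # average cls token, prompts and patch tokens
)
model = MODELS.build(cfg)
# __init__ sets frozen_stages to the full depth, so only the
# prompt-related parameters remain trainable.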

@@ -107,18 +111,18 @@ def forward(self, x):

# reshape to [layers, batch, tokens, embed_dims]
prompt = self.prompt.unsqueeze(1).expand(-1, x.shape[0], -1, -1)
x = torch.cat(
[x[:, :1, :], prompt[0, :, :, :], x[:, 1:, :]],
dim=1)
x = torch.cat([x[:, :1, :], prompt[0, :, :, :], x[:, 1:, :]], dim=1)

outs = []
for i, layer in enumerate(self.layers):
x = layer(x)

if self.deep_prompt and i != len(self.layers) - 1:
x = torch.cat(
[x[:, :1, :], prompt[i, :, :, :], x[:, self.prompt_length + 1:, :]],
dim=1)
# re-insert fresh prompts: layer i + 1 consumes prompt[i + 1]
# (prompt[0] was already inserted before the loop)
x = torch.cat([
x[:, :1, :], prompt[i + 1, :, :, :],
x[:, self.prompt_length + 1:, :]
],
dim=1)

# final_norm should be False here
if i == len(self.layers) - 1 and self.final_norm:
@@ -141,13 +145,13 @@ def _format_output(self, x, hw):
# (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
if self.out_type == 'avg_featmap':
return self.ln2(x[:, self.prompt_length+1:].mean(dim=1))
return self.ln2(x[:, self.prompt_length + 1:].mean(dim=1))
if self.out_type == 'avg_all':
return self.ln2(x.mean(dim=1))
return self.ln2(x.mean(dim=1))
if self.out_type == 'avg_prompt':
return self.ln2(x[:, 1:self.prompt_length+1].mean(dim=1))
return self.ln2(x[:, 1:self.prompt_length + 1].mean(dim=1))
if self.out_type == 'avg_prompt_clstoken':
return self.ln2(x[:, :self.prompt_length+1].mean(dim=1))
return self.ln2(x[:, :self.prompt_length + 1].mean(dim=1))
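
The slicing above follows directly from the token layout built in forward() (and the same layout holds for the EVA02 variant below); a runnable sketch with batch size B, prompt length P and N patch tokens, all names illustrative:

import torch

B, P, N, C = 2, 5, 196, 768
x = torch.randn(B, 1 + P + N, C)  # [cls | P prompts | N patches]

cls_token = x[:, 0]          # (B, C)    -> 'cls_token'
prompts = x[:, 1:P + 1]      # (B, P, C)
patches = x[:, P + 1:]       # (B, N, C) -> 'featmap' (after reshape)

avg_featmap = patches.mean(dim=1)          # 'avg_featmap'
avg_all = x.mean(dim=1)                    # 'avg_all'
avg_prompt = prompts.mean(dim=1)           # 'avg_prompt'
avg_prompt_cls = x[:, :P + 1].mean(dim=1)  # 'avg_prompt_clstoken'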


def new_AttentionWithRoPE_forward_fn(self, x, patch_resolution):
@@ -163,11 +167,13 @@ def new_AttentionWithRoPE_forward_fn(self, x, patch_resolution):
if extra_token_num > 0:
q_t = q[:, :, extra_token_num:, :]
ro_q_t = self.rope(q_t, patch_resolution)
q = torch.cat((q[:, :, :extra_token_num, :], ro_q_t), -2).type_as(v)
q = torch.cat((q[:, :, :extra_token_num, :], ro_q_t),
-2).type_as(v)

k_t = k[:, :, extra_token_num:, :]
ro_k_t = self.rope(k_t, patch_resolution)
k = torch.cat((k[:, :, :extra_token_num , :], ro_k_t), -2).type_as(v)
k = torch.cat((k[:, :, :extra_token_num, :], ro_k_t),
-2).type_as(v)
else:
q = self.rope(q, patch_resolution)
k = self.rope(k, patch_resolution)
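
This split exists because prompt and class tokens carry no spatial position, so only patch-token positions are rotated; presumably extra_token_num comes from self.num_extra_tokens (1 class token plus prompt_length, as set in __init__ above, though that assignment is collapsed here). A toy version of the bypass:

import torch

def rope_with_bypass(q, rope_fn, extra_token_num, patch_resolution):
    # Rotate only the patch-token positions; pass extra tokens through.
    q_patches = rope_fn(q[:, :, extra_token_num:, :], patch_resolution)
    return torch.cat((q[:, :, :extra_token_num, :], q_patches), dim=-2)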
@@ -188,9 +194,10 @@ def new_AttentionWithRoPE_forward_fn(self, x, patch_resolution):

@MODELS.register_module()
class PromptedViTEVA02(ViTEVA02):
'''EVA02 Vision Transformer with Prompt.
A PyTorch implement of : `Visual Prompt Tuning<https://arxiv.org/abs/2203.12119>`_
"""EVA02 Vision Transformer with Prompt.
A PyTorch implementation of: `Visual Prompt Tuning
<https://arxiv.org/abs/2203.12119>`_
Args:
prompt_length (int): The number of learnable prompt tokens. Defaults to 1.
@@ -209,32 +216,36 @@ class tokens with shape (B, L, C).
& prompt tensor with shape (B, C).
- ``"avg_prompt"``: The global averaged prompt tensor with
shape (B, C).
- ``"avg_prompt_clstoken"``: The global averaged cls_tocken
- ``"avg_prompt_clstoken"``: The global averaged cls_tocken
& prompt tensor with shape (B, C).
Defaults to ``"avg_all"``.
*args(list, optional): Other args for ViTEVA02.
**kwargs(dict, optional): Other args for ViTEVA02.
'''
"""

num_extra_tokens = 1 # class token
OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap', 'avg_all', 'avg_prompt', 'avg_prompt_clstoken'}
OUT_TYPES = {
'raw', 'cls_token', 'featmap', 'avg_featmap', 'avg_all', 'avg_prompt',
'avg_prompt_clstoken'
}

# 'avg_all': avg of 'prompt', 'cls_token' and 'featmap'
# 'avg_prompt': avg of 'prompt'
# 'avg_prompt_clstoken': avg of 'cls_token' and 'prompt'
def __init__(self,
prompt_length = 1,
deep_prompt = True,
prompt_length=1,
deep_prompt=True,
out_type='avg_all',
prompt_init: str = 'normal',
norm_cfg=dict(type='LN'),
*args,
**kwargs):
super().__init__(*args, out_type=out_type, norm_cfg=norm_cfg, **kwargs)
super().__init__(*args, out_type=out_type, norm_cfg=norm_cfg, **kwargs)

self.prompt_layers = len(self.layers) if deep_prompt else 1
prompt = torch.empty(
self.prompt_layers, prompt_length, self.embed_dims)
prompt = torch.empty(self.prompt_layers, prompt_length,
self.embed_dims)
if prompt_init == 'uniform':
nn.init.uniform_(prompt, -0.08, 0.08)
elif prompt_init == 'zero':
@@ -250,16 +261,17 @@ def __init__(self,
self.prompt_length = prompt_length
self.deep_prompt = deep_prompt

if self.out_type in {'avg_featmap', 'avg_all', 'avg_prompt', 'avg_prompt_clstoken'}:
if self.out_type in {
'avg_featmap', 'avg_all', 'avg_prompt', 'avg_prompt_clstoken'
}:
self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
# freeze stages

# freeze stages
self.frozen_stages = len(self.layers)
self._freeze_stages()
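
The forward below is wrapped with unittest.mock.patch used as a decorator: while each call runs, AttentionWithRoPE.forward is replaced by the prompt-aware function defined earlier, then restored on return. A minimal sketch of the pattern (Target and double are hypothetical stand-ins):

from unittest.mock import patch

class Target:
    def forward(self, x):
        return x

def double(self, x):  # stand-in for new_AttentionWithRoPE_forward_fn
    return 2 * x

@patch.object(Target, 'forward', double)
def run(x):
    return Target().forward(x)  # uses double() while run() executes

assert run(3) == 6 and Target().forward(3) == 3  # restored afterwards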

@patch(
'mmpretrain.models.backbones.vit_eva02.AttentionWithRoPE.forward',
new_AttentionWithRoPE_forward_fn)

@patch('mmpretrain.models.backbones.vit_eva02.AttentionWithRoPE.forward',
new_AttentionWithRoPE_forward_fn)
def forward(self, x):
B = x.shape[0]
x, patch_resolution = self.patch_embed(x)
@@ -281,18 +293,18 @@ def forward(self, x):

# reshape to [layers, batch, tokens, embed_dims]
prompt = self.prompt.unsqueeze(1).expand(-1, x.shape[0], -1, -1)
x = torch.cat(
[x[:, :1, :], prompt[0, :, :, :], x[:, 1:, :]],
dim=1)
x = torch.cat([x[:, :1, :], prompt[0, :, :, :], x[:, 1:, :]], dim=1)

outs = []
for i, layer in enumerate(self.layers):
x = layer(x, patch_resolution)

if self.deep_prompt and i != len(self.layers) - 1:
x = torch.cat(
[x[:, :1, :], prompt[i, :, :, :], x[:, self.prompt_length + 1:, :]],
dim=1)
# re-insert fresh prompts: layer i + 1 consumes prompt[i + 1]
# (prompt[0] was already inserted before the loop)
x = torch.cat([
x[:, :1, :], prompt[i + 1, :, :, :],
x[:, self.prompt_length + 1:, :]
],
dim=1)

if i == len(self.layers) - 1 and self.final_norm:
x = self.ln1(x)
@@ -302,7 +314,6 @@ def forward(self, x):

return tuple(outs)


def _format_output(self, x, hw):
if self.out_type == 'raw':
return x
@@ -315,11 +326,10 @@ def _format_output(self, x, hw):
# (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
if self.out_type == 'avg_featmap':
return self.ln2(x[:, self.prompt_length:].mean(dim=1))
return self.ln2(x[:, self.prompt_length + 1:].mean(dim=1))
if self.out_type == 'avg_all':
return self.ln2(x.mean(dim=1))
return self.ln2(x.mean(dim=1))
if self.out_type == 'avg_prompt':
return self.ln2(x[:, 1:self.prompt_length+1].mean(dim=1))
return self.ln2(x[:, 1:self.prompt_length + 1].mean(dim=1))
if self.out_type == 'avg_prompt_clstoken':
return self.ln2(x[:, :self.prompt_length+1].mean(dim=1))

return self.ln2(x[:, :self.prompt_length + 1].mean(dim=1))
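
Putting it together, a hedged end-to-end sketch (the EVA02 arch alias and input size are assumptions; the class and registry come from this file):

import torch
from mmpretrain.registry import MODELS

model = MODELS.build(
    dict(
        type='PromptedViTEVA02',
        arch='tiny',   # assumed ViTEVA02 arch alias
        img_size=224,  # assumed input resolution
        prompt_length=10,
        deep_prompt=True,
        out_type='avg_all',
    ))

x = torch.randn(2, 3, 224, 224)
feats = model(x)  # tuple of (B, C) tensors for 'avg_all'
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
# Expect mainly the prompt (and possibly ln2) parameters here, since
# __init__ sets frozen_stages to the full depth and calls _freeze_stages().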
