Skip to content

Commit

Permalink
update app.
Browse files (browse the repository at this point in the history)
  • Loading branch information
dxli94 committed Sep 19, 2022
1 parent e425dab commit 9566d88
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 14 deletions.
12 changes: 6 additions & 6 deletions app/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,10 @@ def app():
with torch.no_grad():
image_features = feature_extractor.extract_features(
sample, mode="image"
).image_features[:, 0]
).image_embeds_proj[:, 0]
text_features = feature_extractor.extract_features(
sample, mode="text"
).text_features[:, 0]
).text_embeds_proj[:, 0]
sims = (image_features @ text_features.t())[
0
] / feature_extractor.temp
Expand Down Expand Up @@ -173,10 +173,10 @@ def app():
# with torch.no_grad():
# image_features = feature_extractor.extract_features(
# sample, mode="image"
# ).image_features[:, 0]
# ).image_embeds_proj[:, 0]
# text_features = feature_extractor.extract_features(
# sample, mode="text"
# ).text_features[:, 0]
# ).text_embeds_proj[:, 0]

# st.write(image_features.shape)
# st.write(text_features.shape)
Expand Down Expand Up @@ -206,8 +206,8 @@ def app():
with torch.no_grad():
clip_features = model.extract_features(sample)

image_features = clip_features.image_features
text_features = clip_features.text_features
image_features = clip_features.image_embeds_proj
text_features = clip_features.text_embeds_proj

image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
Expand Down
2 changes: 1 addition & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
app = MultiPage()

app.add_page("Image Description Generation", caption.app)
# app.add_page("Multimodal Search", ms.app)
app.add_page("Multimodal Search", ms.app)
app.add_page("Visual Question Answering", vqa.app)
app.add_page("Image Text Matching", itm.app)
app.add_page("Text Localization", tl.app)
Expand Down
19 changes: 12 additions & 7 deletions app/multimodal_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,17 @@
allow_output_mutation=True,
)
def load_feat():
path2feat = torch.load(
os.path.join(
os.path.dirname(__file__),
"resources/path2feat_coco_train2014.pth",
)
)
from lavis.common.utils import download_url

dirname = os.path.join(os.path.dirname(__file__), "assets")
filename = "path2feat_coco_train2014.pth"
filepath = os.path.join(dirname, filename)
url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/path2feat_coco_train2014.pth"

if not os.path.exists(filepath):
download_url(url=url, root=dirname, filename="path2feat_coco_train2014.pth")

path2feat = torch.load(filepath)
paths = sorted(path2feat.keys())

all_img_feats = torch.stack([path2feat[k] for k in paths], dim=0).to(device)
Expand Down Expand Up @@ -98,7 +103,7 @@ def app():
with torch.no_grad():
text_feature = feature_extractor.extract_features(
sample, mode="text"
).text_features[0, 0]
).text_embeds_proj[0, 0]

path2feat, paths, all_img_feats = load_feat()
all_img_feats.to(device)
Expand Down

0 comments on commit 9566d88

Please sign in to comment.