
Do you have a plan to add SAM as a visual encoder? #10

Open
StarCycle opened this issue Mar 27, 2024 · 5 comments

Comments

@StarCycle

SAM can be used together with SigLIP/CLIP.

For example, Vary uses SAM+CLIP, and DeepSeek-VL uses SigLIP+SAM.

Would you like to try them with this codebase?

@siddk
Collaborator

siddk commented Apr 1, 2024

We absolutely can; just to confirm, is this the model you'd want us to try adding: https://huggingface.co/timm/samvit_base_patch16.sa1b?

CC @ashwin-balakrishna96 to add to our internal run list!

@StarCycle
Author

Hello @siddk @ashwin-balakrishna96

Yes! Please try the SAM-base!
Here are some notes from my colleague's experience:

  • If you want to concatenate the SAM output with the SigLIP output, you may need to add two convolution layers after SAM-base to change the output shape from [64, 64, 256] to [256, 1024]; see the sketch after this list. You can check this figure or the Vary paper.
    [attached figure]
  • It's possible to use SAM-base as the only visual encoder, but then you need a pretraining stage to align SAM-base to the LLM embedding space using a small language model (e.g., OPT-125M). You may need multiple epochs in this phase.
    [attached figure]
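
For concreteness, here is a minimal sketch of that fusion step, assuming SAM-base features of shape [B, 256, 64, 64] and 256 SigLIP patch tokens of width 1024; the module name and the channel-wise concatenation are illustrative, not code from this repo, and the exact fusion in Vary may differ:

import torch
from torch import nn

class FusedSAMSigLIP(nn.Module):
    """Illustrative Vary-style fusion of SAM-base and SigLIP features."""

    def __init__(self, sam_dim=256, siglip_dim=1024):
        super().__init__()
        # Two stride-2 convs: [B, 256, 64, 64] -> [B, 512, 32, 32] -> [B, 1024, 16, 16]
        self.downsample = nn.Sequential(
            nn.Conv2d(sam_dim, 512, kernel_size=3, stride=2, padding=1, bias=False),
            nn.Conv2d(512, siglip_dim, kernel_size=3, stride=2, padding=1, bias=False),
        )

    def forward(self, sam_feats, siglip_tokens):
        # sam_feats: [B, 256, 64, 64]; siglip_tokens: [B, 256, 1024]
        x = self.downsample(sam_feats)         # [B, 1024, 16, 16]
        x = x.flatten(2).transpose(1, 2)       # [B, 256, 1024] -- the shape mentioned above
        # Concatenate along the feature dimension -> [B, 256, 2048]
        return torch.cat([x, siglip_tokens], dim=-1)

# Shape check with dummy tensors
fused = FusedSAMSigLIP()
print(fused(torch.randn(1, 256, 64, 64), torch.randn(1, 256, 1024)).shape)  # [1, 256, 2048]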

Best,
StarCycle

@StarCycle
Author

StarCycle commented Apr 2, 2024

I guess the training pipeline would be the same for DINOv2.

  • We cannot just use DINOv2 as the visual encoder, since it is not aligned with the LLM embedding space.
  • We cannot align DINOv2 with contrastive learning like CLIP, which may cause DINOv2 to forget many things...
  • But we can combine DINOv2 with a small language model and let them predict the next token. Such generative pretraining may be very helpful (a rough sketch follows below). Please note that in another of your works, Voltron-robotics, you took exactly the same approach! Why not try it again ヾ(^▽^*))
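
As a rough illustration of that generative-alignment idea (not the exact Voltron recipe), here is a hedged sketch that projects DINOv2 patch tokens into OPT-125M's embedding space and trains with a next-token loss on captions; the timm/HuggingFace model IDs are real, but the projector and loss wiring are my assumptions:

import timm
import torch
from torch import nn
from transformers import AutoModelForCausalLM

class VisionLMAligner(nn.Module):
    """Illustrative generative pretraining: DINOv2 tokens condition a small LM."""

    def __init__(self):
        super().__init__()
        self.vision = timm.create_model('vit_base_patch14_dinov2.lvd142m', pretrained=True, num_classes=0)
        self.lm = AutoModelForCausalLM.from_pretrained('facebook/opt-125m')
        # Linear projector that maps vision tokens into the LM embedding space
        self.projector = nn.Linear(self.vision.num_features, self.lm.config.hidden_size)

    def forward(self, pixels, input_ids, labels):
        # DINOv2 patch tokens projected to the LM width: [B, N, hidden_size]
        vis_tokens = self.projector(self.vision.forward_features(pixels))
        txt_embeds = self.lm.get_input_embeddings()(input_ids)          # [B, T, hidden_size]
        inputs_embeds = torch.cat([vis_tokens, txt_embeds], dim=1)
        # Next-token loss only on the caption tokens; the image prefix is masked with -100
        prefix = torch.full(vis_tokens.shape[:2], -100, dtype=labels.dtype, device=labels.device)
        return self.lm(inputs_embeds=inputs_embeds, labels=torch.cat([prefix, labels], dim=1)).loss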

Please let me know if you find anything interesting!

Best,
StarCycle

@StarCycle
Author

You can just start with this:

import timm
import torch
from torch import nn
from urllib.request import urlopen
from PIL import Image


class DownSampledSAMVit(nn.Module):
    """SAM ViT backbone followed by stride-2 convolutions that shrink the 64x64 feature map."""

    def __init__(self, name, downsample_channels=(512, 1024)):
        super().__init__()
        self.SAMViT = timm.create_model(
            name,
            pretrained=True,
            num_classes=0,  # remove the classifier nn.Linear
        )
        data_config = timm.data.resolve_model_data_config(self.SAMViT)
        self.transforms = timm.data.create_transform(**data_config, is_training=False)

        # The SAM neck ends in a LayerNorm2d over 256 channels.
        in_channels = self.SAMViT.neck[-1].weight.shape[0]
        downsamples = []
        for out_channels in downsample_channels:
            downsamples.append(
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    bias=False,
                )
            )
            in_channels = out_channels
        self.downsamples = nn.Sequential(*downsamples)

    def forward(self, rgb):
        # forward_features returns the neck output: [B, 256, 64, 64] for samvit_base_patch16
        out = self.SAMViT.forward_features(rgb)
        # Two stride-2 convs: [B, 256, 64, 64] -> [B, 512, 32, 32] -> [B, 1024, 16, 16]
        out = self.downsamples(out)
        return out


img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = DownSampledSAMVit('samvit_base_patch16.sa1b').cuda().eval()
transforms = model.transforms
with torch.no_grad():
    output = model(transforms(img).unsqueeze(0).cuda())  # [1, 1024, 16, 16]

@ashwin-balakrishna96

@StarCycle thanks a bunch for the suggestion. We can try integrating the SAM baseline in a week or so, but if you have cycles and would be interested in opening up a PR to integrate it in the meanwhile (especially because it seems like you've already been thinking about how the code should look), we would also be very happy to review it and integrate it into Prismatic :)

jmercat pushed a commit to sagadre/prismatic-vlms that referenced this issue Jul 16, 2024
Adds 2.7B Phi-2 LLM with non-LLaMa tokenizer handling, recommended
"Input: / Output:" formatting.