Finetune TinySAM on custom dataset #22

Riley-livingston commented Mar 26, 2024

Hello, I'm trying to fine-tune the mask decoder of TinySAM on a custom dataset while freezing the weights of the image_encoder and prompt_encoder. I'm hitting an issue in my training loop: Sam.forward() requires a multimask_output argument, but MaskDecoder.forward() doesn't accept a multimask_output keyword.

I'm not an ML engineer, so I don't know much about the underlying code. If anyone with more knowledge has insight into how I can resolve this, I'd appreciate it. Thanks!

Here is how I'm freezing the image encoder and prompt encoder to keep their original weights:

# Freeze everything except the mask decoder
for name, param in sam_model.named_parameters():
    if name.startswith("image_encoder") or name.startswith("prompt_encoder"):
        param.requires_grad_(False)
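
As a quick sanity check (my own addition, assuming the standard SAM submodule names), this confirms that only the mask decoder stays trainable:

# Count trainable vs. frozen parameters after the freeze
trainable = sum(p.numel() for p in sam_model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in sam_model.parameters() if not p.requires_grad)
print(f"trainable params: {trainable:,}, frozen params: {frozen:,}")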

I am also providing bounding box prompts as input. Here is my custom Dataset class:

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.transforms.functional import resize, to_tensor

class SAMDataset(Dataset):
    """
    Dataset class for the SAM model, serving images with their associated bounding boxes and masks.
    """
    def __init__(self, dataset, bbox_mapping, sam_model, device='cuda'):
        self.dataset = dataset
        self.bbox_mapping = bbox_mapping
        self.sam_model = sam_model
        self.device = device
        self.target_size = (1024, 1024)  # Adjusted to the expected input size of the model

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Assuming dataset[idx] returns a dict with 'image' and 'label' keys
        pil_image = self.dataset[idx]['image']
        pil_mask = self.dataset[idx]['label']

        image_tensor = to_tensor(np.array(pil_image)).to(self.device)
        mask_tensor = to_tensor(np.array(pil_mask)).to(self.device)

        # Resize image and mask to target size
        image_tensor = resize(image_tensor, self.target_size)
        mask_tensor = resize(mask_tensor, self.target_size)

        # Fetch bounding boxes directly without padding
        bboxes = self.bbox_mapping.get(idx + 1, [])  # Adjust index if necessary
        bboxes_tensor = torch.tensor(bboxes, dtype=torch.float, device=self.device)

        return {
            'image': image_tensor,
            'bboxes': bboxes_tensor,
            'mask': mask_tensor
        }
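
I instantiate it like this (raw_train_split and train_bbox_mapping are placeholder names for my own data loading, shown only for context):

# Hypothetical instantiation; the dataset and bbox mapping come from my own pipeline
train_dataset = SAMDataset(dataset=raw_train_split,
                           bbox_mapping=train_bbox_mapping,
                           sam_model=sam_model)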
        
### Create a DataLoader instance for the training dataset

from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset, shuffle=True, drop_last=False)
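
Inspecting one batch confirms the tensor shapes (with the default batch_size=1):

# Pull one batch from the loader and print each tensor's shape
batch = next(iter(train_dataloader))
for key, value in batch.items():
    print(key, value.shape)

which prints: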

image torch.Size([1, 3, 1024, 1024])
bboxes torch.Size([1, 1, 4])
mask torch.Size([1, 1, 1024, 1024])
### Training Loop

from statistics import mean

from tqdm import tqdm

num_epochs = 1
device = "cuda"
sam_model.to(device)
sam_model.train()

# seg_loss and optimizer are defined earlier in my notebook

for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
        # Prepare the batched_input according to TinySAM's expected input format
        batched_input = [{
            'image': batch['image'].squeeze(0).to(device),
            'bboxes': batch['bboxes'].squeeze(0).to(device)
        }]
        # forward pass
        outputs_list = sam_model(batched_input, multimask_output=True)

        # outputs_list is assumed to be a list of dicts; aggregate the predicted
        # masks for loss computation (adapt this to the actual output structure)
        predicted_masks = torch.stack([output['pred_mask'] for output in outputs_list]).squeeze(0)
        ground_truth_masks = batch["mask"].float().squeeze(1).to(device)

        loss = seg_loss(predicted_masks, ground_truth_masks)

        # backward pass (compute gradients of parameters)
        optimizer.zero_grad()
        loss.backward()

        # optimize
        optimizer.step()
        epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')
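
For completeness, seg_loss above is my segmentation loss. A minimal stand-in (my assumption; it treats the predicted masks as raw logits at the same resolution as the ground-truth masks) would be:

import torch.nn as nn

# Placeholder segmentation loss; swap in Dice/focal/etc. as the task requires
seg_loss = nn.BCEWithLogitsLoss()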


Error when I don't provide multimask_output:

TypeError                                 Traceback (most recent call last)
<ipython-input-108-f41ebba752d9> in <cell line: 12>()
     21 
     22         # forward pass
---> 23         outputs_list = sam_model(batched_input)
     24 
     25         # Assuming the outputs_list is a list of dictionaries, and you need to aggregate the masks for loss computation

2 frames
/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    113     def decorate_context(*args, **kwargs):
    114         with ctx_factory():
--> 115             return func(*args, **kwargs)
    116 
    117     return decorate_context

TypeError: Sam.forward() missing 1 required positional argument: 'multimask_output'

Error when I do provide the multimask_output argument:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-42-9d874c2eda3d> in <cell line: 12>()
     19     }]
     20         # forward pass
---> 21         outputs_list = sam_model(batched_input, multimask_output = True)
     22 
     23         # Assuming the outputs_list is a list of dictionaries, and you need to aggregate the masks for loss computation

5 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518                 or _global_backward_pre_hooks or _global_backward_hooks
   1519                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520             return forward_call(*args, **kwargs)
   1521 
   1522         try:

TypeError: MaskDecoder.forward() got an unexpected keyword argument 'multimask_output'
 
@xinghaochen (Owner) commented:

Hi, simply removing multimask_output everywhere in the code should work (1492efb). You can pull the latest code and try again.

You can refer to issue #9 for more details.
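
With the updated code, the forward pass in the training loop above would just drop the argument (assuming the rest of the loop stays the same):

# After commit 1492efb, Sam.forward() no longer takes multimask_output
outputs_list = sam_model(batched_input)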
