
[RFC] Early fusion multimodal models #1904

Open

RdoubleA wants to merge 1 commit into main
Conversation

Contributor

@RdoubleA commented Oct 25, 2024

TODO: fix tests, update DeepFusion to support multiple encoders, add docstrings

Context

This is a focused RFC based on @pbontrager's excellent original RFC on multimodal fusion models, #1283. Since that RFC, we have already landed the Deep Fusion model components. This PR discusses and implements the EarlyFusionModel component, along with testing and some lint updates.

Early fusion is simply a decoder with one or more extra encoders whose outputs are merged with the decoder's token embeddings. The challenge lies in how we merge the embeddings and pass them into the decoder.

Design

There is one design consideration I am seeking feedback on: the EarlyFusionModel's usage of self.decoder.tok_embeddings. It accesses the decoder's token embedding table outside of the decoder forward because we need to merge the image encoder's (and any other modality encoder's) output embeddings with the text embeddings, in this case by writing them into the placeholder token positions along the sequence dimension:

# Embed the text tokens, then write each encoder's output embeddings into the
# positions holding that encoder's placeholder token.
embeds = self.tok_embeddings(tokens)
bsz, seq_len, embed_dim = embeds.shape
for encoder, inp in (encoder_input or {}).items():
    encoder_embeds = self.encoders[encoder](**inp)  # [bsz, num_encoder_tokens, embed_dim]
    # Broadcast the [bsz, seq_len] placeholder mask across the embedding dim
    encoder_mask = (
        (tokens == self.encoder_tokens[encoder])
        .unsqueeze(-1)
        .expand(bsz, seq_len, embed_dim)
    )
    # Masked assignment expects one flat value per True entry in the mask
    embeds[encoder_mask] = encoder_embeds.reshape(-1)

output = self.decoder(embeds, mask, input_pos)
return output

Now, instead of token ids, we are passing the merged embeddings directly into the decoder. But since we already applied the decoder's text-only tok_embeddings, we need to skip that module when passing in the merged embeddings for the final decoder output. There are two ways we can do this.

State dict surgery

As in the current code changes and suggested by the original RFC, we can manually set self.decoder.tok_embeddings = nn.Identity() so that it becomes a no-op when you forward pass with merged embeddings (see the sketch after the list below).

  • This will require additional state dict hooks to make sure checkpoint saving and loading still work despite the module change
  • If a user wants to use the decoder outside of the EarlyFusionModel in the same script, they will need to restore the original tok_embeddings module in place of nn.Identity
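
For concreteness, here is a minimal sketch of what the module swap plus state dict hooks could look like. The hook registration uses PyTorch's (private) module state dict hook APIs, and the key-remapping details are an assumption for illustration, not the PR's final implementation:

import torch.nn as nn


class EarlyFusionModel(nn.Module):
    # Sketch only: hoist the decoder's token embedding table and keep checkpoint
    # keys looking like a plain decoder ("decoder.tok_embeddings.*").
    def __init__(self, decoder: nn.Module, encoders: nn.ModuleDict, encoder_tokens: dict):
        super().__init__()
        self.decoder = decoder
        self.encoders = encoders
        self.encoder_tokens = encoder_tokens

        # Hoist the embedding table, then make it a no-op inside the decoder
        self.tok_embeddings = decoder.tok_embeddings
        self.decoder.tok_embeddings = nn.Identity()

        # On save: rename tok_embeddings.* -> decoder.tok_embeddings.*
        self._register_state_dict_hook(self._save_hook)
        # On load: rename decoder.tok_embeddings.* -> tok_embeddings.*
        self._register_load_state_dict_pre_hook(self._load_hook, with_module=True)

    @staticmethod
    def _save_hook(module, state_dict, prefix, local_metadata):
        old = f"{prefix}tok_embeddings."
        new = f"{prefix}decoder.tok_embeddings."
        for key in list(state_dict.keys()):
            if key.startswith(old):
                state_dict[new + key[len(old):]] = state_dict.pop(key)
        return state_dict

    @staticmethod
    def _load_hook(module, state_dict, prefix, *args):
        old = f"{prefix}decoder.tok_embeddings."
        new = f"{prefix}tok_embeddings."
        for key in list(state_dict.keys()):
            if key.startswith(old):
                state_dict[new + key[len(old):]] = state_dict.pop(key)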

Additional input_embeds kwarg

We could add a new keyword argument in TransformerDecoder forward for input embeddings. If this is passed in, we automatically skip the token embeddings:

h = self.tok_embeddings(tokens) if input_embeds is None else input_embeds

This way we don't need any state dict hooks or module surgery on the decoder. However, we are polluting the decoder forward with more arguments.
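
A simplified sketch of what this option could look like; the toy decoder below stands in for TransformerDecoder, whose real forward takes additional arguments such as mask and input_pos:

from typing import Optional

import torch
from torch import nn


class ToyDecoder(nn.Module):
    # Simplified stand-in for TransformerDecoder to illustrate the proposed kwarg
    def __init__(self, vocab_size: int, embed_dim: int, num_layers: int = 2):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.layers = nn.ModuleList(nn.Linear(embed_dim, embed_dim) for _ in range(num_layers))
        self.norm = nn.LayerNorm(embed_dim)
        self.output = nn.Linear(embed_dim, vocab_size)

    def forward(
        self,
        tokens: Optional[torch.Tensor] = None,
        *,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Skip the embedding table when pre-merged embeddings are passed in
        h = self.tok_embeddings(tokens) if input_embeds is None else input_embeds
        for layer in self.layers:
            h = layer(h)
        return self.output(self.norm(h))


# Text-only usage: pass token ids as usual
#   logits = decoder(tokens)
# Early fusion usage: pass merged embeddings and skip tok_embeddings
#   logits = decoder(input_embeds=merged_embeds)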

pytorch-bot commented Oct 25, 2024

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1904

❌ 3 New Failures, 4 Cancelled Jobs as of commit e37a3e1 with merge base 74139c9.

@facebook-github-bot added the CLA Signed label Oct 25, 2024
@joecummings added the rfc (Request for comments) label Oct 25, 2024
Contributor

@pbontrager left a comment

Thanks for putting this up Rafi! I left some comments on the implementation, but I'll leave the state dict discussion to others as we've already chatted on this.

def __init__(
    self,
    decoder: TransformerDecoder,
    encoders: nn.ModuleDict,
Contributor

nit: I think it would be nice if we allowed all of the encoder params to be a list/dict or single value input to make single encoder builders look much cleaner. Then we can package them as an iterable in the init.
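
One rough sketch of what that normalization could look like (the _to_dict helper name is hypothetical, not from the PR):

from typing import Any, Dict


def _to_dict(value: Any, key: str = "encoder") -> Dict[str, Any]:
    # Hypothetical helper: wrap a single non-dict value into a one-element dict
    # so the rest of __init__ can always iterate over named encoders
    return value if isinstance(value, dict) else {key: value}


# e.g. inside __init__:
#   self.encoders = nn.ModuleDict(_to_dict(encoders))
#   self.encoder_tokens = _to_dict(encoder_tokens)
#   self.encoders_trainable = _to_dict(encoders_trainable)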

    encoders: nn.ModuleDict,
    encoder_tokens: Dict[str, int],
    decoder_trainable: bool,
    encoders_trainable: Dict[str, bool],
Contributor

An error should be thrown if the different input dicts don't have the same keys.
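
For example, a minimal check at the top of __init__ could look like this (a sketch, using the parameter names from the signature above):

if not (set(encoders.keys()) == set(encoder_tokens.keys()) == set(encoders_trainable.keys())):
    raise ValueError(
        "encoders, encoder_tokens, and encoders_trainable must all have the same keys, got "
        f"{sorted(encoders.keys())}, {sorted(encoder_tokens.keys())}, {sorted(encoders_trainable.keys())}"
    )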

if decoder_trainable:
    trainable_params |= {
        f"decoder.{n}" for n, p in self.decoder.named_parameters()
    }
Contributor

This is missing the logic and the parameter to make fusion modules trainable/untrainable.
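
A sketch of the kind of handling being asked for, mirroring how the DeepFusion components gate fusion parameters; the fusion_trainable flag and the helper imports are assumptions for illustration, not the PR's code:

# Assumed helpers, following the existing fusion/PEFT utilities in torchtune
from torchtune.modules.model_fusion import get_fusion_params
from torchtune.modules.peft import set_trainable_params

trainable_params = set()
if decoder_trainable:
    trainable_params |= {f"decoder.{n}" for n, p in self.decoder.named_parameters()}
for name, trainable in encoders_trainable.items():
    if trainable:
        trainable_params |= {
            f"encoders.{name}.{n}" for n, p in self.encoders[name].named_parameters()
        }
# New: include or exclude fusion-module params based on a fusion_trainable flag
if fusion_trainable:
    trainable_params |= set(get_fusion_params(self))
else:
    trainable_params -= set(get_fusion_params(self))
set_trainable_params(self, trainable_params)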

been expanded to the number of tokens encoded for the given media. For example, if an image is tiled/patched
and tokenized to 100 tokens, we assume the text sequence already has 100 "image" tokens as placeholders.
"""
embeds = self.tok_embeddings(tokens)
Contributor

You can't do this because the encoder tokens won't be in the tok_embeddings table. You need to first filter those out as in here https://www.internalfb.com/intern/paste/P1666298928/

Contributor Author

Yeah, let's talk about this offline, because in the reference code I was using, the encoder tokens are part of the embedding table.
