From c7b81906880b4c840a690c5dc9c445d81ed5dd40 Mon Sep 17 00:00:00 2001 From: Tarun Karuturi Date: Tue, 24 Sep 2024 15:09:58 -0700 Subject: [PATCH 1/3] Vision model embeddings change --- torchtune/models/clip/_position_embeddings.py | 45 +++++++++++++++---- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/torchtune/models/clip/_position_embeddings.py b/torchtune/models/clip/_position_embeddings.py index 580856cd1e..cea62199fe 100644 --- a/torchtune/models/clip/_position_embeddings.py +++ b/torchtune/models/clip/_position_embeddings.py @@ -7,6 +7,7 @@ from typing import Any, Tuple import torch +import torch.nn.functional as F from torch import nn # TODO (@Felipe): add load hooks + interpolation on positional encodings, @@ -35,8 +36,7 @@ def __init__(self, embed_dim: int, tile_size: int, patch_size: int) -> None: patch_grid_size = tile_size // patch_size scale = embed_dim**-0.5 self.positional_embedding = nn.Parameter( - scale - * torch.randn((patch_grid_size**2 + 1, embed_dim)) # +1 for CLS token + scale * torch.randn((patch_grid_size**2 + 1, embed_dim)) # +1 for CLS token ) def forward(self, x: torch.Tensor, *args: Tuple[Any]) -> torch.Tensor: @@ -83,8 +83,7 @@ def __init__( # different for every token, same for every tile self.local_token_positional_embedding = nn.Parameter( - scale - * torch.randn((patch_grid_size**2 + 1, embed_dim)) # +1 for CLS token + scale * torch.randn((patch_grid_size**2 + 1, embed_dim)) # +1 for CLS token ) # different for every token, different for every tile @@ -98,6 +97,7 @@ def __init__( ) ) + self.max_num_tiles = max_num_tiles self.gate = nn.Parameter(torch.zeros(1)) def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor: @@ -121,20 +121,34 @@ def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor: # When we batch images, all are padded to the same amount of tiles. # The aspect_ratio lets us know the non padded tiles for each image. # We only add positional encoding to those. + n_tiles_h = n_tiles_h.item() + n_tiles_w = n_tiles_w.item() + n_non_padded_tiles = int(n_tiles_h * n_tiles_w) # We get only the positional encoding for non padded tiles, # i.e. n_tiles_h, n_tiles_w. - pos_embed = self.global_token_positional_embedding[ - :n_tiles_h, :n_tiles_w, :, : - ] + torch._check_is_size(n_tiles_h) + torch._check_is_size(n_tiles_w) + torch._check(n_tiles_h <= self.max_num_tiles) + torch._check(n_tiles_w <= self.max_num_tiles) + padded_embedding = F.pad( + self.global_token_positional_embedding, (0, 0, 0, 0, 0, 1, 0, 1) + ) + + pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :] # Add pos encoding to the non padded tiles. + pos_embed = pos_embed.clone() pos_embed = pos_embed.reshape( n_non_padded_tiles, self.n_tokens_per_tile, embed_dim ) pos_embed = pos_embed * self.gate.tanh() + x = F.pad(x, (0, 0, 0, 0, 0, 1, 0, 0)) + torch._check(n_non_padded_tiles < self.max_num_tiles + 1) + torch._check(n_non_padded_tiles < x.size(1)) x[batch_idx, :n_non_padded_tiles, :, :] += pos_embed + x = x[:, :n_tiles, :, :] return x @@ -176,19 +190,34 @@ def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor: torch.Tensor: The input tensor with added positional embeddings. """ bsz_and_n_imgs, n_tiles, n_tokens, embed_dim = x.shape + torch._check(n_tiles <= self.max_num_tiles) for batch_idx, (n_tiles_h, n_tiles_w) in enumerate(aspect_ratio): # When we batch images, all are padded to the same amount of tiles. # The aspect_ratio lets us know the non padded tiles for each image. 
# We only add positional encoding to those. + n_tiles_h = n_tiles_h.item() + n_tiles_w = n_tiles_w.item() + n_non_padded_tiles = int(n_tiles_h * n_tiles_w) # We get only the positional encoding for non padded tiles, # i.e. n_tiles_h, n_tiles_w. - pos_embed = self.embedding[:n_tiles_h, :n_tiles_w, :, :] + torch._check_is_size(n_tiles_h) + torch._check_is_size(n_tiles_w) + torch._check(n_tiles_h <= self.max_num_tiles) + torch._check(n_tiles_w <= self.max_num_tiles) + padded_embedding = F.pad(self.embedding, (0, 0, 0, 0, 0, 1, 0, 1)) + pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :] # Add pos encoding to the non padded tiles. + pos_embed = pos_embed.clone() pos_embed = pos_embed.reshape(n_non_padded_tiles, 1, self.embed_dim) + + x = F.pad(x, (0, 0, 0, 0, 0, 1, 0, 0)) + torch._check_is_size(n_non_padded_tiles) + torch._check(n_non_padded_tiles < x.size(1)) x[batch_idx, :n_non_padded_tiles, :, :] += pos_embed * self.gate.tanh() + x = x[:, :n_tiles, :, :] return x From 256a7d084d7bd631ba0e9c2f4639e2acc3b6e7b0 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Fri, 4 Oct 2024 16:35:38 -0700 Subject: [PATCH 2/3] AOTI export script Summary: Download artifacts from https://www.internalfb.com/manifold/explorer/executorch/tree/models/llama/llama3_2_mm_v4 * dog.jpg * tune.pth Test Plan: Reviewers: Subscribers: Tasks: Tags: --- export_flamingo.py | 573 ++++++++++++++++++ .../models/flamingo/_component_builders.py | 4 + torchtune/modules/attention.py | 27 +- torchtune/modules/kv_cache.py | 34 +- torchtune/modules/model_fusion/_fusion.py | 10 +- 5 files changed, 622 insertions(+), 26 deletions(-) create mode 100644 export_flamingo.py diff --git a/export_flamingo.py b/export_flamingo.py new file mode 100644 index 0000000000..4f5c93651a --- /dev/null +++ b/export_flamingo.py @@ -0,0 +1,573 @@ +import numpy as np +import PIL +import torch +from executorch import exir +from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass +from torchtune.data import Message, padded_collate_tiled_images_and_mask +from torchtune.data._prompt_templates import _TemplateType +from torchtune.generation._generation import sample +from torchtune.models.clip._transform import CLIPImageTransform +from torchtune.models.flamingo import ( + flamingo_decoder, + flamingo_vision_encoder, + FlamingoTransform, +) +import os, sys +import time +from functools import lru_cache, wraps +from typing import Optional + +from torchtune.models.flamingo._component_builders import ( + flamingo_decoder, + flamingo_vision_encoder, +) +from torchtune.modules.model_fusion import DeepFusionModel +from torchvision.transforms.v2 import functional as F + +max_seq_len = 8192 +in_channels = 3 +tile_size = 560 +max_num_tiles = 4 +# how many tokens per image generated by the vision encoder +tokens_per_image = 6404 +# how many images to cache in the kv cache in cross attention +kv_cache_image_num = 1 +# maximum number of tokens generated by encoder and thus stored in the kv cache in cross attention +encoder_max_seq_len = tokens_per_image * kv_cache_image_num + + +@lru_cache(maxsize=1) +def get_vision_encoder(): + return flamingo_vision_encoder( + patch_size=14, + num_heads=16, + clip_embed_dim=1280, + clip_num_layers=32, + clip_hidden_states=[3, 7, 15, 23, 30], + decoder_embed_dim=4096, + num_layers_projection=8, + tile_size=tile_size, + max_num_tiles=4, + in_channels=3, + ) + + +@lru_cache(maxsize=1) +def get_text_decoder(): + return flamingo_decoder( + vocab_size=128_256, + num_layers=32, + fusion_interval=4, + 
num_special_tokens=8, + num_heads=32, + num_kv_heads=8, + embed_dim=4096, + max_seq_len=max_seq_len, + encoder_max_seq_len=encoder_max_seq_len, + rope_base=500000.0, + intermediate_dim=14336, + ) + + +@lru_cache(maxsize=1) +def get_flamingo(llama3_2_dir): + model = DeepFusionModel( + encoder=get_vision_encoder(), + decoder=get_text_decoder(), + encoder_trainable=False, + decoder_trainable=False, + fusion_trainable=False, + ) + print("Load checkpoint") + state_dict = torch.load(os.path.join(llama3_2_dir, "tune.pth")) + + model.load_state_dict(state_dict) + model.setup_caches( + batch_size=1, + dtype=torch.float32, + encoder_max_seq_len=encoder_max_seq_len, # Hardcoded in for now + decoder_max_seq_len=max_seq_len, + ) + return model + + +def flamingo_transform( + path: str, + max_seq_len: int = 8192, + special_tokens_path: Optional[str] = None, + prompt_template: Optional[_TemplateType] = None, +) -> FlamingoTransform: + """ + Data Transforms (including Tokenizer) for Llama3 Vision. + + Args: + path (str): path to the tokenizer + max_seq_len (int): maximum sequence length for tokenizing a single list of messages, + after which the input will be truncated. + special_tokens_path (Optional[str]): Path to ``tokenizer.json`` from Hugging Face + model files that contains all registered special tokens, or a local json file + structured similarly. Default is None to use the canonical Llama3 special tokens. + prompt_template (Optional[_TemplateType]): optional specified prompt template. + If a string, it is assumed to be the dotpath of a :class:`~torchtune.data.PromptTemplateInterface` + class. If a dictionary, it is assumed to be a custom prompt template mapping role to the + prepend/append tags. + + Returns: + FlamingoTransform: Instantiation of the Llama3 tokenizer + """ + special_tokens = ( + parse_hf_tokenizer_json(special_tokens_path) + if special_tokens_path is not None + else None + ) + template = ( + _get_prompt_template(prompt_template) if prompt_template is not None else None + ) + return FlamingoTransform( + path=path, + special_tokens=special_tokens, + tile_size=560, + patch_size=14, + max_num_tiles=4, + max_seq_len=max_seq_len, + image_mean=(0.48145466, 0.4578275, 0.40821073), + image_std=(0.26862954, 0.26130258, 0.27577711), + prompt_template=template, + ) + + +@lru_cache(maxsize=1) +def get_sample_preprocess_outputs(llama3_2_dir): + image_path = os.path.join(llama3_2_dir, "dog.jpg") + tokenizer_path = os.path.join(llama3_2_dir, "tokenizer.model") + transform = flamingo_transform(tokenizer_path) + images = [PIL.Image.open(image_path)] + messages = [ + Message( + role="user", + content=[ + {"type": "image", "content": images[0]}, + {"type": "text", "content": "What's in this image?"}, + ], + eot=True, + ), + Message(role="assistant", content=""), + ] + data = transform({"messages": messages}, inference=True) + seq_len = len(data["tokens"]) + total_response_length = max_seq_len + # mask and input_pos + causal_mask = torch.tril( + torch.ones( + size=(total_response_length, total_response_length), + dtype=torch.bool, + ) + ) + input_pos = torch.arange(total_response_length) + batch = padded_collate_tiled_images_and_mask( + [data], pad_direction="left", pad_max_images=kv_cache_image_num + ) + batch["mask"] = causal_mask[None, :seq_len, :] + batch["input_pos"] = input_pos[None, :seq_len] + batch["encoder_mask"] = batch["encoder_mask"][:, :seq_len] + return batch + + +@lru_cache(maxsize=1) +def get_text_decoder_inputs(llama3_2_dir, model): + data = 
get_sample_preprocess_outputs(llama3_2_dir).copy() + # breakpoint() + image = data["encoder_input"]["images"].to(dtype=torch.float32) + embeds = model.encoder(image, data["encoder_input"]["aspect_ratio"]) + data["encoder_input"] = embeds + tokens = data.pop("tokens") + return tokens, data + + +@lru_cache(maxsize=1) +def get_vision_encoder_dynamic_shapes(): + dim = torch.export.Dim("num_tiles", min=1, max=max_num_tiles) + image_dynamic_dim = { + 0: 1, + 1: 1, + 2: dim, + 3: 3, + 4: tile_size, + 5: tile_size, + } + return image_dynamic_dim + + +@lru_cache(maxsize=1) +def get_text_decoder_dynamic_shapes(): + dim = torch.export.Dim("token_dim", min=2, max=max_seq_len) + dim_enc = torch.export.Dim("enc_dim", min=1, max=encoder_max_seq_len) + + return { + "tokens": {0: 1, 1: dim}, + "mask": {0: 1, 1: dim, 2: None}, + "encoder_input": None, + "encoder_mask": {0: 1, 1: dim, 2: None}, + "input_pos": {0: 1, 1: dim}, + } + + +def timeit(func): + @wraps(func) + def wrapper(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + print( + f"Function {func.__name__} took {end_time - start_time} seconds to execute." + ) + return result + + return wrapper + + +@timeit +def run_vision_encoder_eager(vision_encoder, image, aspect_ratio): + return vision_encoder(image, aspect_ratio) + + +def benchmark_vision_encoder(vision_encoder, image, aspect_ratio): + # warm up run + vision_encoder(image, aspect_ratio) + # time it + total_time = 0 + for _ in range(30): + start_time = time.time() + res = vision_encoder(image, aspect_ratio) + total_time += time.time() - start_time + return total_time / 30, res + + +def benchmark_all_vision_encoder(llama3_2_dir): + preprocess_outputs = get_sample_preprocess_outputs(llama3_2_dir) + image = preprocess_outputs["image"] + # Eager + aspect_ratio = preprocess_outputs["aspect_ratio"] + print("image shape:", image.shape) + for dtype in [torch.bfloat16, torch.float32]: + image = image.to(dtype=dtype) + aspect_ratio = aspect_ratio.to(dtype=dtype) + vision_encoder = get_vision_encoder().to(dtype=dtype).eval() + print( + f"-----------------------------------Eager Mode {dtype} CPU-----------------------------------" + ) + avg, eager_res = benchmark_vision_encoder(vision_encoder, image, aspect_ratio) + print(f"Averaged time: {avg}") + + # # Torch.compile + # print(f"-----------------------------------Torch.compile {dtype} CPU-----------------------------------") + # with torch.no_grad(): + # compiled_vision_encoder = torch.compile(vision_encoder, mode="reduce-overhead") + # # warm up run + # compiled_vision_encoder(image, aspect_ratio) + # # time it + # avg, compiled_res = benchmark_vision_encoder(compiled_vision_encoder, image, aspect_ratio) + # print(f"Averaged time: {avg}") + # print(f"Close to eager? {torch.allclose(eager_res, compiled_res)}") + + # # torch.export + # print(f"-----------------------------------Torch.export {dtype} CPU-----------------------------------") + image_dynamic_dim = get_vision_encoder_dynamic_shapes() + # ep = torch.export.export( + # vision_encoder, + # (image, aspect_ratio), + # dynamic_shapes=(image_dynamic_dim, None), + # ) + # avg, exported_res = benchmark_vision_encoder(ep.module(), image, aspect_ratio) + # print(f"Averaged time: {avg}") + # print(f"Close to eager? 
{torch.allclose(eager_res, exported_res)}") + + # AOTInductor + print( + f"-----------------------------------AOTInductor {dtype} CPU-----------------------------------" + ) + with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]): + so = torch._export.aot_compile( + vision_encoder, + args=(image, aspect_ratio), + options={"aot_inductor.output_path": "/tmp/vision_encoder.so"}, + dynamic_shapes=(image_dynamic_dim, None), + ) + aot_loaded = torch._export.aot_load(so, device="cpu") + avg, aoti_res = benchmark_vision_encoder(aot_loaded, image, aspect_ratio) + print(f"Averaged time: {avg}") + print(f"Close to eager? {torch.allclose(eager_res, aoti_res)}") + print( + f"-----------------------------------Eager Mode {dtype} GPU-----------------------------------" + ) + image_cuda = image.to(device="cuda") + aspect_ratio_cuda = aspect_ratio.to(device="cuda") + vision_encoder_cuda = vision_encoder.to(device="cuda") + avg, eager_res_cuda = benchmark_vision_encoder( + vision_encoder_cuda, image_cuda, aspect_ratio_cuda + ) + print(f"Averaged time: {avg}") + print(f"Close to eager? {torch.allclose(eager_res, eager_res_cuda.cpu())}") + # Torch.compile + # print("-----------------------------------Torch.compile fp32 GPU-----------------------------------") + # with torch.no_grad(): + # compiled_vision_encoder_cuda = torch.compile(vision_encoder_cuda, mode="reduce-overhead") + # # warm up run + # compiled_vision_encoder_cuda(image_cuda, aspect_ratio_cuda) + # # time it + # compiled_res = run_vision_encoder_eager(compiled_vision_encoder_cuda, image_cuda, aspect_ratio_cuda) + # print(f"Close to eager? {torch.allclose(eager_res, compiled_res)}") + + # print("-----------------------------------AOTInductor fp32 GPU-----------------------------------") + # with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]): + # so = torch._export.aot_compile( + # vision_encoder_cuda, + # args=(image_cuda, aspect_ratio_cuda), + # options={"aot_inductor.output_path": "/tmp/vision_transformer.so"}, + # dynamic_shapes=(image_dynamic_dim, None), + # ) + # aoti_res = run_vision_encoder_eager(torch._export.aot_load(so, device="cuda"), image_cuda, aspect_ratio_cuda) + + +def export_vision_encoder(llama3_2_dir): + preprocess_outputs = get_sample_preprocess_outputs(llama3_2_dir) + image = preprocess_outputs["encoder_input"]["image"] + aspect_ratio = preprocess_outputs["encoder_input"]["aspect_ratio"] + image_dynamic_dim = get_vision_encoder_dynamic_shapes() + vision_encoder = get_vision_encoder().eval() + with torch.no_grad(): + print("Start to export vision encoder") + with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]): + ep = torch.export.export( + vision_encoder, + (image, aspect_ratio), + dynamic_shapes=(image_dynamic_dim, None), + ) + print("Done exporting vision encoder") + return ep + + +def aoti_export_vision_encoder(llama3_2_dir): + preprocess_outputs = get_sample_preprocess_outputs(llama3_2_dir) + image = preprocess_outputs["encoder_input"]["image"].to(dtype=torch.float32) + aspect_ratio = preprocess_outputs["encoder_input"]["aspect_ratio"] + image_dynamic_dim = get_vision_encoder_dynamic_shapes() + + model = get_flamingo(llama3_2_dir) + + print("Start to AOTI export vision encoder") + with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]): + so = torch._export.aot_compile( + model.encoder, + args=(image, aspect_ratio), + options={ + "aot_inductor.output_path": os.path.join( + llama3_2_dir, "vision_encoder.so" + ) + }, + 
dynamic_shapes=(image_dynamic_dim, None), + ) + print("Done AOTI exporting vision encoder") + + +def aoti_export_text_decoder(llama3_2_dir): + """ + (Pdb) encoder_embed.shape + torch.Size([1, 6404, 4096]) + (Pdb) mask.shape + torch.Size([1, 17, 117]) + (Pdb) encoder_mask.shape + torch.Size([1, 17, 6404]) + (Pdb) input_pos.shape + torch.Size([1, 17]) + (Pdb) tokens.shape + torch.Size([1, 17]) + """ + model = get_flamingo(llama3_2_dir) + + with torch.no_grad(): + dynamic_shapes = get_text_decoder_dynamic_shapes() + tokens, data = get_text_decoder_inputs(llama3_2_dir, model) + + print("Start to generate aoti for text decoder") + with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]): + so = torch._export.aot_compile( + model.decoder, + (tokens,), + data, + options={ + "aot_inductor.output_path": os.path.join( + llama3_2_dir, "text_decoder.so" + ) + }, + dynamic_shapes=dynamic_shapes, + ) + + +def export_text_decoder(llama3_2_dir): + model = get_flamingo(llama3_2_dir) + + class Decoder(torch.nn.Module): + def __init__(self, model) -> None: + super().__init__() + self.decoder = model.decoder + + def forward( + self, + tokens: torch.Tensor, + *, + mask: Optional[torch.Tensor] = None, + encoder_input: Optional[torch.Tensor] = None, + encoder_mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + ): + return self.decoder( + tokens, + mask=mask, + encoder_input=encoder_input, + encoder_mask=encoder_mask, + input_pos=input_pos, + ) + + m = Decoder(model) + with torch.no_grad(): + dynamic_shapes = get_text_decoder_dynamic_shapes() + tokens, data = get_text_decoder_inputs(llama3_2_dir, model) + + print("Start to export text decoder") + with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]): + ep = torch.export.export( + m, + (tokens,), + data, + dynamic_shapes=dynamic_shapes, + ) + print(ep) + return ep + + +def validate_text_decoder(llama3_2_dir): + # aoti_export_text_decoder(llama3_2_dir) + # compare result + aoti = torch._export.aot_load( + os.path.join(llama3_2_dir, "text_decoder.so"), device="cpu" + ) + model = get_flamingo(llama3_2_dir) + eager = model.decoder + # export = export_text_decoder(llama3_2_dir) + + # inputs + tokens, data = get_text_decoder_inputs(llama3_2_dir, model) + + # run eager + eager_res = eager(tokens, **data) + aoti_res = aoti(tokens, **data) + # export_res = export.module()(tokens, **data) + + # debug + # from torch._inductor.decomposition import select_decomp_table + # inductor_decomp_gm = export.run_decompositions(select_decomp_table()).module() + + # from torch._inductor.compile_fx import _recursive_post_grad_passes + # from torch._export.utils import _detect_fake_mode_from_gm + # from torch._inductor.virtualized import V + # fake_mode = _detect_fake_mode_from_gm(inductor_decomp_gm) + # with V.set_fake_mode(fake_mode): + # _recursive_post_grad_passes(inductor_decomp_gm, is_inference=True) + + # debug_res = inductor_decomp_gm(tokens, encoder_input=encoder_input, input_pos=input_pos) + + print( + f"AOTI close to eager? {torch.allclose(eager_res, aoti_res, atol=1e-3, rtol=1e-3)}" + ) + # print(f"Export close to eager? 
{torch.allclose(eager_res, export_res)}") + + print(f"Eager result: {eager_res}") + # print(f"Export result: {export_res}") + # print(f"Debug result: {debug_res}") + print(f"AOTInductor result: {aoti_res}") + + +def validate_vision_encoder(llama3_2_dir): + preprocess_outputs = get_sample_preprocess_outputs(llama3_2_dir) + image = preprocess_outputs["encoder_input"]["image"].to(dtype=torch.float32) + aspect_ratio = preprocess_outputs["encoder_input"]["aspect_ratio"] + # eager model + vision_encoder = get_vision_encoder().eval() + # aoti export + aoti_export_vision_encoder(llama3_2_dir) + aoti = torch._export.aot_load( + os.path.join(llama3_2_dir, "vision_encoder.so"), device="cpu" + ) + # export + export = export_vision_encoder() + # results + eager_res = vision_encoder(image, aspect_ratio) + aoti_res = aoti(image, aspect_ratio) + export_res = export.module()(image, aspect_ratio) + + print( + f"AOTI close to eager? {torch.allclose(eager_res, aoti_res, atol=1e-3, rtol=1e-3)}" + ) + print(f"Export close to eager? {torch.allclose(eager_res, export_res)}") + + print(f"Eager result: {eager_res}") + print(f"Export result: {export_res}") + print(f"AOTInductor result: {aoti_res}") + + +def test_aoti(llama3_2_dir): + aoti = torch._export.aot_load( + os.path.join(llama3_2_dir, "text_decoder.so"), device="cpu" + ) + model = get_flamingo(llama3_2_dir) + eager = model.decoder + # export = export_text_decoder(llama3_2_dir) + + # inputs + tokens, data = get_text_decoder_inputs(llama3_2_dir, model) + seq_len = tokens.shape[1] + + # first run with image + print("First run with image") + eager_res = eager(tokens, **data) + aoti_res = aoti(tokens, **data) + print( + f"AOTI close to eager? {torch.allclose(eager_res, aoti_res, atol=1e-3, rtol=1e-3)}" + ) + + # second run with no image + print("Second run with no image") + tok = sample(aoti_res[:, -1]) + # adjust input + data["encoder_mask"] = data["encoder_mask"][:, -1:] + data.pop("encoder_input") + seq_len += 1 + data["input_pos"] = input_pos[None, seq_len] + causal_mask = torch.tril( + torch.ones( + size=(max_seq_len, max_seq_len), + dtype=torch.bool, + ) + ) + data["mask"] = causal_mask[None, seq_len, None, :] + # run + logits = aoti(tok, encoder_mask=data["encoder_mask"], mask=data["mask"], input_pos=data["input_pos"]) + +if __name__ == "__main__": + llama3_2_dir = str(sys.argv[1]) + # validate_vision_encoder(llama3_2_dir) + aoti_export_text_decoder(llama3_2_dir) + # test_aoti(llama3_2_dir) + # validate_text_decoder(llama3_2_dir) + # model = get_flamingo(llama3_2_dir) + + # with torch.no_grad(): + + # data = get_sample_preprocess_outputs(llama3_2_dir).copy() + # image = data["encoder_input"]["images"].to(dtype=torch.float32) + # embeds = model.encoder(image, data["encoder_input"]["aspect_ratio"]) + # data["encoder_input"] = embeds + # tokens = data.pop("tokens") + # print("Start eager run") + # model.decoder(tokens, **data) diff --git a/torchtune/models/flamingo/_component_builders.py b/torchtune/models/flamingo/_component_builders.py index 870c028626..66e55a13b3 100644 --- a/torchtune/models/flamingo/_component_builders.py +++ b/torchtune/models/flamingo/_component_builders.py @@ -248,11 +248,15 @@ def flamingo_decoder( mlp_scale=TanhGate(), ) fusion_layer = FusionLayer(layer=decoder_layer, fusion_layer=xattn_layer) + # fusion_layer.state_dict_handle.remove() + # fusion_layer.load_state_dict_handle.remove() layers.append(fusion_layer) else: layers.append(decoder_layer) tok_embeddings = FusionEmbedding(vocab_size, num_special_tokens, embed_dim) + # 
tok_embeddings.state_dict_handle.remove() + # tok_embeddings.load_state_dict_handle.remove() output_proj = nn.Linear(embed_dim, vocab_size, bias=False) return TransformerDecoder( diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index 2dfeaddc9a..5625dd7c9d 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -224,7 +224,7 @@ def forward( # y has shape [b, s_y, d] b, s_x, _ = x.shape s_y = y.shape[1] if y is not None else 0 - + # q has shape [b, s_x, num_heads * head_dim] q = self.q_proj(x) @@ -243,14 +243,20 @@ def forward( if self.q_norm is not None: q = self.q_norm(q) - if y is None: - if self.kv_cache is None: - raise ValueError( - "Must provide y input or use kv_cache to enable streaming decoding" - ) - k = self.kv_cache.k_cache - v = self.kv_cache.v_cache - else: + def true_fn(y): + # if self.kv_cache is None: + # raise ValueError( + # "Must provide y input or use kv_cache to enable streaming decoding" + # ) + if self.kv_cache is not None: + k = self.kv_cache.k_cache + v = self.kv_cache.v_cache + else: + k = torch.zeros(b, s_y, self.num_heads, self.head_dim).transpose(1, 2) + v = torch.zeros(b, s_y, self.num_heads, self.head_dim).transpose(1, 2) + return k, v + + def false_fn(y): # Update k and v shape, positional embeddings, and normalization # k has shape [b, s_y, num_kv_heads * head_dim] @@ -293,6 +299,9 @@ def forward( # Update key-value cache if self.kv_cache is not None: k, v = self.kv_cache.update(k, v) + return k, v + + k, v = torch.cond(torch.isnan(y).all(), true_fn, false_fn, (y, )) output = self._attention_call( q, diff --git a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py index 84996518ad..f784d7eb8c 100644 --- a/torchtune/modules/kv_cache.py +++ b/torchtune/modules/kv_cache.py @@ -40,14 +40,20 @@ def __init__( self.register_buffer( "v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False ) - self.size = 0 + self.register_buffer( + "cache_pos", torch.arange(0, cache_shape[2]), persistent=False + ) self.batch_size = batch_size def reset(self) -> None: """Reset the cache to zero.""" self.k_cache.zero_() self.v_cache.zero_() - self.size = 0 + self.cache_pos -= self.size + + @property + def size(self) -> int: + return self.cache_pos[0].item() def update( self, k_val: torch.Tensor, v_val: torch.Tensor @@ -80,7 +86,7 @@ def update( Tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively. Raises: - ValueError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length. + AssertionError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length. ValueError: if the batch size of the new key (or value) tensor is greater than the batch size used during cache setup. """ @@ -91,18 +97,20 @@ def update( f", but found new key tensors with batch size {k_val.shape[0]}!" ) - if (self.size + seq_len) > self.k_cache.shape[2]: - raise ValueError( - f"The current cache has been setup with a sequence length of {self.k_cache.shape[2]}" - f", but the cache has reached a sequence length of {(self.size + seq_len)}!" 
- ) - cache_pos = torch.arange(self.size, self.size + seq_len, device=k_val.device) - self.size += seq_len - + assert (self.cache_pos[0] + seq_len) <= self.k_cache.shape[2] k_out = self.k_cache v_out = self.v_cache - k_out.index_copy_(2, cache_pos, k_val) - v_out.index_copy_(2, cache_pos, v_val) + k_out[:, :, self.cache_pos[:seq_len]] = k_val + v_out[:, :, self.cache_pos[:seq_len]] = v_val + + # forward cache_pos seq_len positions along + # cache_pos starts at (0, 1, 2, 3, 4, 5, ...) + # an update of seq_len = 5 tokens brings it to + # (5, 6, 7, 8, 9, ...) + # this allows us to track the current position in the cache + # after the last update in a compile-friendly way without any dynamism + # e.g. relying on an int size tracker, or re-creating cache_pos every time + self.cache_pos.add_(seq_len) return k_out, v_out diff --git a/torchtune/modules/model_fusion/_fusion.py b/torchtune/modules/model_fusion/_fusion.py index ea1f01c383..100f44b1af 100644 --- a/torchtune/modules/model_fusion/_fusion.py +++ b/torchtune/modules/model_fusion/_fusion.py @@ -57,8 +57,8 @@ def __init__( self.fusion_first = fusion_first # Keep FusionLayer wrappings out of the state_dict - self._register_state_dict_hook(FusionLayer._state_dict_hook) - self._register_load_state_dict_pre_hook( + self.state_dict_handle = self._register_state_dict_hook(FusionLayer._state_dict_hook) + self.load_state_dict_handle = self._register_load_state_dict_pre_hook( FusionLayer._load_state_dict_hook, with_module=True ) # TODO: Switch to register_load_state_dict_pre_hook and @@ -200,8 +200,8 @@ def __init__(self, vocab_size: int, fusion_vocab_size: int, embed_dim: int) -> N # TODO: Support merging the embeddings after finetuning # Keep FusionLayer wrappings out of the state_dict - self._register_state_dict_hook(FusionEmbedding._state_dict_hook) - self._register_load_state_dict_pre_hook( + self.state_dict_handle = self._register_state_dict_hook(FusionEmbedding._state_dict_hook) + self.load_state_dict_handle = self._register_load_state_dict_pre_hook( FusionEmbedding._load_state_dict_hook, with_module=True ) # TODO: Switch to register_load_state_dict_pre_hook and @@ -444,6 +444,8 @@ def forward( encoder_embed = None if encoder_input is not None: encoder_embed = self.encoder(**encoder_input) + else: + encoder_embed = torch.tensor(torch.nan) output = self.decoder( tokens=tokens, From 361eab1f95093b4f4ece7bdc3c006e3816452083 Mon Sep 17 00:00:00 2001 From: Mengwei Liu Date: Wed, 16 Oct 2024 16:54:24 -0700 Subject: [PATCH 3/3] Update Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- export_flamingo.py | 40 ++++++++++++++++++----------- torchtune/modules/attention.py | 47 ++++++++++++++++++++-------------- torchtune/modules/kv_cache.py | 14 ++++++++++ 3 files changed, 67 insertions(+), 34 deletions(-) diff --git a/export_flamingo.py b/export_flamingo.py index 4f5c93651a..a38c0b4461 100644 --- a/export_flamingo.py +++ b/export_flamingo.py @@ -1,3 +1,8 @@ +import os, sys +import time +from functools import lru_cache, wraps +from typing import Optional + import numpy as np import PIL import torch @@ -12,10 +17,6 @@ flamingo_vision_encoder, FlamingoTransform, ) -import os, sys -import time -from functools import lru_cache, wraps -from typing import Optional from torchtune.models.flamingo._component_builders import ( flamingo_decoder, @@ -552,22 +553,31 @@ def test_aoti(llama3_2_dir): ) data["mask"] = causal_mask[None, seq_len, None, :] # run - logits = aoti(tok, encoder_mask=data["encoder_mask"], mask=data["mask"], 
input_pos=data["input_pos"]) + logits = aoti( + tok, + encoder_mask=data["encoder_mask"], + mask=data["mask"], + input_pos=data["input_pos"], + ) + if __name__ == "__main__": llama3_2_dir = str(sys.argv[1]) # validate_vision_encoder(llama3_2_dir) - aoti_export_text_decoder(llama3_2_dir) + # aoti_export_text_decoder(llama3_2_dir) # test_aoti(llama3_2_dir) # validate_text_decoder(llama3_2_dir) - # model = get_flamingo(llama3_2_dir) + model = get_flamingo(llama3_2_dir) - # with torch.no_grad(): + with torch.no_grad(): - # data = get_sample_preprocess_outputs(llama3_2_dir).copy() - # image = data["encoder_input"]["images"].to(dtype=torch.float32) - # embeds = model.encoder(image, data["encoder_input"]["aspect_ratio"]) - # data["encoder_input"] = embeds - # tokens = data.pop("tokens") - # print("Start eager run") - # model.decoder(tokens, **data) + data = get_sample_preprocess_outputs(llama3_2_dir).copy() + image = data["encoder_input"]["images"].to(dtype=torch.float32) + embeds = model.encoder(image, data["encoder_input"]["aspect_ratio"]) + data["encoder_input"] = embeds + tokens = data.pop("tokens") + print("Start eager run") + # model.decoder(tokens, **data) + torch._dynamo.config.capture_dynamic_output_shape_ops = True + torch._dynamo.config.capture_scalar_outputs = True + torch.compile(model.decoder, fullgraph=True)(tokens, **data) diff --git a/torchtune/modules/attention.py b/torchtune/modules/attention.py index 5625dd7c9d..e1b5b48413 100644 --- a/torchtune/modules/attention.py +++ b/torchtune/modules/attention.py @@ -224,7 +224,7 @@ def forward( # y has shape [b, s_y, d] b, s_x, _ = x.shape s_y = y.shape[1] if y is not None else 0 - + # q has shape [b, s_x, num_heads * head_dim] q = self.q_proj(x) @@ -243,24 +243,13 @@ def forward( if self.q_norm is not None: q = self.q_norm(q) - def true_fn(y): - # if self.kv_cache is None: - # raise ValueError( - # "Must provide y input or use kv_cache to enable streaming decoding" - # ) - if self.kv_cache is not None: - k = self.kv_cache.k_cache - v = self.kv_cache.v_cache - else: - k = torch.zeros(b, s_y, self.num_heads, self.head_dim).transpose(1, 2) - v = torch.zeros(b, s_y, self.num_heads, self.head_dim).transpose(1, 2) - return k, v - - def false_fn(y): + def calculate_kv(original_y): # Update k and v shape, positional embeddings, and normalization # k has shape [b, s_y, num_kv_heads * head_dim] # v has shape [b, s_y, num_kv_heads * head_dim] + y = original_y.clone() + k = self.k_proj(y) v = self.v_proj(y) @@ -296,12 +285,32 @@ def false_fn(y): if self.k_norm is not None: k = self.k_norm(k) - # Update key-value cache - if self.kv_cache is not None: - k, v = self.kv_cache.update(k, v) return k, v - k, v = torch.cond(torch.isnan(y).all(), true_fn, false_fn, (y, )) + def true_fn(y): + kv_cache = self.kv_cache.clone() + return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos + + def false_fn(y): + k, v = calculate_kv(y) + kv_cache = self.kv_cache.clone() + kv_cache.update(k, v) + return kv_cache.k_cache, kv_cache.v_cache, kv_cache.cache_pos + + # If kv cache is None, we expect y to be provided + if self.kv_cache is None: + assert ( + y is not None and not torch.isnan(y).all() + ), "Must provide y input or use kv_cache to enable streaming decoding" + k, v = calculate_kv(y) + else: + # Expecting the k, v returning here to be the same size of self.kv_cache + k, v, cache_pos = torch.cond(torch.isnan(y).all(), true_fn, false_fn, (y,)) + + # Update kv cache + self.kv_cache.k_cache.copy_(k) + self.kv_cache.v_cache.copy_(v) + 
self.kv_cache.cache_pos.copy_(cache_pos) output = self._attention_call( q, diff --git a/torchtune/modules/kv_cache.py b/torchtune/modules/kv_cache.py index f784d7eb8c..b615d6742b 100644 --- a/torchtune/modules/kv_cache.py +++ b/torchtune/modules/kv_cache.py @@ -114,3 +114,17 @@ def update( self.cache_pos.add_(seq_len) return k_out, v_out + + def clone(self) -> "KVCache": + """Create a clone of the KVCache.""" + clone = KVCache( + batch_size=self.batch_size, + max_seq_len=self.k_cache.shape[2], + num_heads=self.k_cache.shape[1], + head_dim=self.k_cache.shape[3], + dtype=self.k_cache.dtype, + ) + clone.k_cache.copy_(self.k_cache) + clone.v_cache.copy_(self.v_cache) + clone.cache_pos.copy_(self.cache_pos) + return clone
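
A note on the KV-cache bookkeeping introduced above, since the diffs only hint at why the integer `size` counter was replaced: keeping the write position in a registered `cache_pos` buffer (instead of a Python int) removes data-dependent Python state from `update()`, which is what makes the cache usable under torch.compile / torch.export / AOTInductor. Below is a minimal, self-contained sketch of that pattern; `TinyKVCache` and its shapes are illustrative stand-ins, not the torchtune class.

import torch
from torch import nn


class TinyKVCache(nn.Module):
    """Minimal KV cache tracking its write position with a buffer (illustrative only)."""

    def __init__(self, batch_size: int, num_heads: int, max_seq_len: int, head_dim: int,
                 dtype: torch.dtype = torch.float32) -> None:
        super().__init__()
        cache_shape = (batch_size, num_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False)
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False)
        # cache_pos starts as (0, 1, 2, ..., max_seq_len - 1); its first element is
        # always the next position to write to.
        self.register_buffer("cache_pos", torch.arange(max_seq_len), persistent=False)

    def update(self, k_val: torch.Tensor, v_val: torch.Tensor):
        # k_val / v_val have shape [batch, num_heads, seq_len, head_dim]
        seq_len = k_val.shape[2]
        # Write the new keys/values at the current positions.
        self.k_cache[:, :, self.cache_pos[:seq_len]] = k_val
        self.v_cache[:, :, self.cache_pos[:seq_len]] = v_val
        # Advance every entry by seq_len in place. No Python-side size counter means
        # no data-dependent bookkeeping between decode steps.
        self.cache_pos.add_(seq_len)
        return self.k_cache, self.v_cache


if __name__ == "__main__":
    cache = TinyKVCache(batch_size=1, num_heads=2, max_seq_len=8, head_dim=4)
    prefill = torch.randn(1, 2, 3, 4)          # 3 prompt tokens
    cache.update(prefill, prefill)
    print(cache.cache_pos[0].item())           # 3: next token is written at index 3
    step = torch.randn(1, 2, 1, 4)             # one decode step
    cache.update(step, step)
    print(cache.cache_pos[0].item())           # 4

This mirrors the patched `KVCache.update` (and the `clone()` helper copies the same three buffers), but it is only a sketch under the assumption of a [batch, heads, seq, head_dim] cache layout; consult the patched torchtune/modules/kv_cache.py for the authoritative behavior.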