
unpack_tuple() is no longer correct with timm v1.0.3 #34

Open
yukw777 opened this issue May 15, 2024 · 2 comments
yukw777 commented May 15, 2024

timm v1.0.3 was released a couple of hours ago (https://github.com/huggingface/pytorch-image-models/releases/tag/v1.0.3), and it reworks the API for forward_intermediates() so that it now returns a list instead of a tuple. As a result, when I run scripts/generate.py with all the default settings and a simple question, "Is the coffee cup empty?", I get the following error:

Traceback (most recent call last):
  File "/home/peter/repos/prismatic-vlms/scripts/generate.py", line 133, in <module>
    generate()
  File "/home/peter/repos/prismatic-vlms/venv/lib/python3.10/site-packages/draccus/argparsing.py", line 203, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/home/peter/repos/prismatic-vlms/scripts/generate.py", line 116, in generate
    generated_text = vlm.generate(
  File "/home/peter/repos/prismatic-vlms/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/peter/repos/prismatic-vlms/prismatic/models/vlms/prismatic.py", line 553, in generate
    generated_ids = super().generate(
  File "/home/peter/repos/prismatic-vlms/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/peter/repos/prismatic-vlms/venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 1576, in generate
    result = self._greedy_search(
  File "/home/peter/repos/prismatic-vlms/venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 2494, in _greedy_search
    outputs = self(
  File "/home/peter/repos/prismatic-vlms/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/peter/repos/prismatic-vlms/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/peter/repos/prismatic-vlms/prismatic/models/vlms/prismatic.py", line 311, in forward
    patch_features = self.vision_backbone({k: pixel_values[k][multimodal_indices] for k in pixel_values})
  File "/home/peter/repos/prismatic-vlms/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/peter/repos/prismatic-vlms/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/peter/repos/prismatic-vlms/prismatic/models/backbones/vision/dinosiglip_vit.py", line 147, in forward
    return torch.cat([dino_patches, siglip_patches], dim=2)
TypeError: expected Tensor as element 0 in argument 0, but got list
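For context, the root cause is that the existing wrapper only unwraps tuples, so the list returned by the new forward_intermediates() slips through unchanged and later lands inside torch.cat(). A minimal stdlib-only sketch of that behavior (unpack_tuple copied from base_vision.py; the two lambdas stand in for the old and new timm return types):

```python
from typing import Any, Callable, Tuple

def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    # Wrapper as it exists in base_vision.py: only tuples are unwrapped.
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        return result[0] if isinstance(result, tuple) else result
    return wrapper

old_forward_intermediates = lambda: ("patch_features",)  # timm < 1.0: tuple
new_forward_intermediates = lambda: ["patch_features"]   # timm >= 1.0.3: list

assert unpack_tuple(old_forward_intermediates)() == "patch_features"    # unwrapped
assert unpack_tuple(new_forward_intermediates)() == ["patch_features"]  # list passes through untouched
```

The un-unwrapped list is what eventually reaches torch.cat() as "element 0 in argument 0".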

The following diff fixes the issue:

diff --git a/prismatic/models/backbones/vision/base_vision.py b/prismatic/models/backbones/vision/base_vision.py
index e9ccade..cf67351 100644
--- a/prismatic/models/backbones/vision/base_vision.py
+++ b/prismatic/models/backbones/vision/base_vision.py
@@ -11,7 +11,7 @@ Transformer model for feature extraction.
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union, Sequence

 import timm
 import torch
@@ -27,7 +27,7 @@ from torchvision.transforms import Compose, Resize
 def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
     def wrapper(*args: Any, **kwargs: Any) -> Any:
         result = fn(*args, **kwargs)
-        return result[0] if isinstance(result, tuple) else result
+        return result[0] if isinstance(result, Sequence) else result

     return wrapper

I'm happy to submit a PR for this, but seeing that this is related to monkey patching for FSDP support, I wanted to discuss how to properly fix it before moving forward.

siddk (Collaborator) commented May 16, 2024

I just pushed a commit to pin timm==0.9.10 for the time being to make sure this doesn't break things for other folks.

I'd love it if you could push a PR, and maybe add a test that verifies the different timm versions return the same output? Based on your PR, I can then test FSDP functionality and make sure everything checks out!

yukw777 (Author) commented May 17, 2024

Great! A few questions for you:

  1. Should we drop support for timm < 1.0.0 now that timm has reached 1.0.0? It would significantly lessen the ongoing maintenance effort by depending on a (supposedly) stable API.
  2. If we do decide to keep supporting timm < 1.0.0, it'd be a good idea to write regression tests, but how do you guys usually write tests? I haven't been able to find an example in the repo.
  3. Does my quick fix look good to you? I may also want to rename the function to unpack_seq(), since it would support both pre-1.0 and post-1.0 timm. If we do decide to drop support for timm < 1.0.0, I may instead just check for list (and rename the function to unpack_list()) and bubble up the error rather than swallowing it.
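A hypothetical sketch of the stricter variant floated in question 3 (the name unpack_list and its exact error message are illustrative, not from the repo): check for list only, and raise on anything unexpected instead of silently passing it along:

```python
from typing import Any, Callable

def unpack_list(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Unwrap the single-element list returned by timm >= 1.0 forward_intermediates()."""
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        result = fn(*args, **kwargs)
        if not isinstance(result, list):
            # Bubble the error up instead of eating it.
            raise TypeError(f"expected list from {fn!r}, got {type(result).__name__}")
        return result[0]
    return wrapper

assert unpack_list(lambda: ["features"])() == "features"
```

With this variant, a tuple from a pre-1.0 timm would fail loudly with a TypeError at the wrapper rather than deep inside torch.cat().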

jmercat referenced this issue in jmercat/openvla Jul 1, 2024