RuntimeError: Argument #4: Padding size should be less than the corresponding input dimension for v2 transforms #8622
Not sure of the reason to combine v1 and v2 transforms together in the same pipeline.
The below code works (tested in Google Colab); please try this.
It works, but following the docs, it seems that the standard steps should match the documented example ("This is what a typical transform pipeline could look like:").
Below is my understanding, others can chime in as needed :) Yeah, that is a good point. In my opinion, that doc may need to be clearer about the difference between the padding operation on a PIL image and on a tensor.
If we look at other docs for padding, they use PIL images: https://pytorch.org/vision/main/auto_examples/transforms/plot_transforms_illustrations.html#sphx-glr-auto-examples-transforms-plot-transforms-illustrations-py Anyway, this is my understanding.
Root cause analysis: Padding a PIL image goes through PIL and NumPy functions, which do not check the padding size against the image dimensions: https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_pil.py#L144-L220 Padding a tensor goes through PyTorch code, which does strict checking of the padding size against the tensor dimensions. My understanding is that PyTorch performs these internal checks to prevent padding operations from indexing past the dimensions of a tensor, ensuring that all computations stay within the allocated memory bounds and avoiding errors like crashes or data corruption.
Scenario 1: Extra padding (padding size greater than the image size) works on a PIL image.
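Scenario 1 can be seen without torchvision at all, since the PIL path ultimately relies on NumPy for reflect padding. This sketch (my own, not from the thread) pads a length-4 array by 10: `np.pad` happily reflects past the original size, while `torch.nn.functional.pad` rejects it with the error from this issue's title.

```python
import numpy as np
import torch
import torch.nn.functional as F

arr = np.arange(4, dtype=np.float32)             # a "row" of size 4
big = np.pad(arr, pad_width=10, mode='reflect')  # 10 > 4: NumPy is fine with it
print(big.shape)  # (24,)

# The same padding on a tensor hits torch's strict dimension check.
t = torch.from_numpy(arr).reshape(1, 1, 1, 4)    # NCHW for 2D reflect padding
try:
    F.pad(t, (10, 10, 0, 0), mode='reflect')     # pad width dim (size 4) by 10
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True
```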
Scenario 2:
Scenario 3:
Many thanks, very clear explanations and instructions!
Thanks for the report @lxr2 , and @venkatram-dev for the help. Just to summarize: this isn't a v1 vs v2 issue. This is a difference in behavior between the PIL backend and the tensor backend (and this difference can be observed on both v1 and v2). PIL supports the padding size being larger than the image dimensions, while torchvision / pytorch doesn't. Simple reproducer:

```python
import torch
from torchvision.transforms import functional as F
from torchvision.transforms.v2 import functional as F2

t = torch.rand(3, 32, 32)
pil_img = F.to_pil_image(t)
padding = 31  # fails for 32+ on tensors

trans_img = F.pad(pil_img, padding=padding, padding_mode='reflect')
print(trans_img.size)
trans_img = F2.pad(pil_img, padding=padding, padding_mode='reflect')
print(trans_img.size)
trans_img = F.pad(t, padding=padding, padding_mode='reflect')
print(trans_img.shape)
trans_img = F2.pad(t, padding=padding, padding_mode='reflect')
print(trans_img.shape)
```

Unfortunately, this isn't something we can directly address in torchvision, because the behavior is dictated by torch's pad. Note that there are similar discussions in pytorch/pytorch#18413 but at the time, it was suggested that the existing torch behavior is expected.
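For anyone who needs large reflect padding on tensors today, one possible workaround (my own sketch, not a torchvision API; `reflect_pad_any` is a hypothetical helper name) is to apply torch's reflect padding in steps, each step smaller than the current dimension. Each edge of a reflect-padded result lies on a reflection axis of the infinite reflected extension, so reflecting again continues the same pattern.

```python
import torch
import torch.nn.functional as F

def reflect_pad_any(img, padding):
    """Reflect-pad a (C, H, W) tensor by `padding` on all sides, even when
    padding >= H or W, by chaining pads that each satisfy torch's
    "padding size must be less than input dimension" constraint.
    Assumes H > 1 and W > 1."""
    out = img
    remaining = padding
    while remaining > 0:
        h, w = out.shape[-2:]
        # Largest legal single-step pad for both spatial dims.
        step = min(remaining, h - 1, w - 1)
        out = F.pad(out.unsqueeze(0), (step, step, step, step),
                    mode='reflect').squeeze(0)
        remaining -= step
    return out

t = torch.rand(3, 8, 8)
out = reflect_pad_any(t, 20)  # padding 20 >= 8 would fail in one torch pad
print(out.shape)  # torch.Size([3, 48, 48])
```

I haven't verified that this matches PIL's output pixel-for-pixel at large paddings, but for paddings below the image size it reduces to a single ordinary reflect pad.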
🐛 Describe the bug
It seems that v2.Pad does not support cases where the padding size is greater than the image size, but v1.Pad does support this. I hope that v2.Pad will allow this in the future as well.

Versions