RuntimeError: Batching rule not implemented for aten::is_same_size. We could not generate a fallback. #1104

AlphaBetaGamma96 opened this issue on Jan 17, 2023

Hi All,

I've been trying to use memory_efficient_fusion to see if I can speed up the main bottleneck in my code, but I hit a RuntimeError. This issue follows on from #1011.
The code to reproduce the error is as follows.

import torch
from torch import nn

import functorch
from functorch import make_functional, vmap, jacrev, grad
from functorch.compile import memory_efficient_fusion

import time

_ = torch.manual_seed(1234)

#version info
print("PyTorch version:   ", torch.__version__)     #PyTorch version:    2.0.0.dev20230116
print("CUDA version:      ", torch.version.cuda)    #CUDA version:       11.6
print("FuncTorch version: ", functorch.__version__) #FuncTorch version:  2.0.0.dev20230116

#=============================================#

#time with torch synchronization
def sync_time() -> float:
  torch.cuda.synchronize()
  return time.perf_counter()

class model(nn.Module):

  def __init__(self, num_inputs, num_hidden):
    super(model, self).__init__()
    
    self.num_inputs = num_inputs
    self.func = nn.Tanh()
    
    self.fc1 = nn.Linear(2, num_hidden)
    self.fc2 = nn.Linear(num_hidden, num_inputs)
  
  def forward(self, x):
    """
    Takes x in [B,A] and maps it to sign/logabsdet value in Tuple([B,], [B,])
    """
    x = x.unsqueeze(-1)
    idx = len(x.shape)           #build repeat args that work both with and without vmap
    rep = [1 for _ in range(idx)]
    rep[-2] = self.num_inputs
    g = x.mean(dim=(idx-2), keepdim=True).repeat(*rep)
    f = torch.cat((x, g), dim=-1)

    h = self.func(self.fc1(f))
    
    mat = self.fc2(h)
    sgn, logabs = torch.linalg.slogdet(mat)
    return sgn, logabs

#=============================================#

B = 4096 #batch
N = 2    #input nodes
H = 64   #number of hidden nodes
device = torch.device('cuda')

x = torch.randn(B, N, device=device) #input data

net = model(N, H) #our model
net = net.to(device)

sgn, logabs = net(x)

fnet, params = make_functional(net)

def calc_logabs(params, x):
  _, logabs = fnet(params, x)
  return logabs

def calc_dlogabs_dx(params, x):
  dlogabs_dx = jacrev(func=calc_logabs, argnums=1)(params, x)
  return dlogabs_dx, dlogabs_dx #return aux

def local_kinetic_from_log_vmap(params, x):
  d2logabs_dx2, dlogabs_dx = jacrev(func=calc_dlogabs_dx, argnums=1, has_aux=True)(params, x)
  _local_kinetic = -0.5*(d2logabs_dx2.diagonal(0,-2,-1).sum() + dlogabs_dx.pow(2).sum())
  return _local_kinetic 

#memory efficient fusion here
#with torch.jit.fuser("fuser2"): #is this needed (from functorch/issues/840)
ps_elocal = vmap(grad(local_kinetic_from_log_vmap, argnums=0), in_dims=(None, 0))
ps_elocal_fusion = memory_efficient_fusion(vmap(grad(local_kinetic_from_log_vmap, argnums=0), in_dims=(None, 0)))


t1 = sync_time()

out1 = ps_elocal(params, x)

t2 = sync_time()

ps_elocal_fusion(params, x) #crashes here: aten::is_same_size no batching rule

t3 = sync_time()

#Compare the function's walltime with and without memory_efficient_fusion
print("Laplacian (standard): %4.2e (s)" % (t2-t1))
print("Laplacian (fusion):   %4.2e (s)" % (t3-t2))

The traceback is as follows:

PyTorch version:    2.0.0.dev20230116
CUDA version:       11.6
FuncTorch version:  2.0.0.dev20230116
Failed to collect metadata on function, produced code may be suboptimal.  Known situations this can occur are inference mode only compilation involving resize_ or prims (!schema.hasAnyAliasInfo() INTERNAL ASSERT FAILED); if your situation looks different please file a bug to PyTorch.
Traceback (most recent call last):
  File "~/anaconda3/envs/nightly_20230117/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1368, in aot_wrapper_dedupe
    fw_metadata, _out = run_functionalized_fw_and_collect_metadata(flat_fn)(
  File "~/anaconda3/envs/nightly_20230117/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 569, in inner
    flat_f_outs = f(*flat_f_args)
    
    ... (many more stack frames omitted)
    
  File "~/anaconda3/envs/nightly_20230117/lib/python3.10/site-packages/torch/autograd/__init__.py", line 285, in grad
    grad_outputs_ = _make_grads(t_outputs, grad_outputs_, is_grads_batched=is_grads_batched)
  File "~/anaconda3/envs/nightly_20230117/lib/python3.10/site-packages/torch/autograd/__init__.py", line 53, in _make_grads
    if not torch.is_same_size(out, first_grad):
RuntimeError: Batching rule not implemented for aten::is_same_size. We could not generate a fallback.

This code was run with the latest nightly build:

PyTorch version:    2.0.0.dev20230116
CUDA version:       11.6
FuncTorch version:  2.0.0.dev20230116
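
For what it's worth, my guess from the traceback is that torch.autograd.grad's _make_grads ends up calling aten::is_same_size on batched tensors, and that op has neither a batching rule nor a usable vmap fallback. The snippet below is my own untested minimal sketch that just calls the op directly under vmap; torch.is_same_size and vmap are the only pieces taken from the traceback above, everything else is made up for illustration.

import torch
from functorch import vmap

a = torch.randn(8, 3)
b = torch.randn(8, 3)

#my assumption: aten::is_same_size has no batching rule, and because it
#returns a bool rather than a Tensor, vmap cannot generate a fallback for it
#either, so this should raise the same RuntimeError as above
vmap(torch.is_same_size)(a, b)

If that guess is right, adding a batching rule (or a shape-based decomposition) for aten::is_same_size would presumably fix the fused path as well.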