Question about implementing external conv2d in assignment 4 #33

Answered by cblmemo
cblmemo asked this question in Q&A

Update: inspired by hbsun's answer, I tried this implementation:

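import torch
import tvm

@tvm.register_func("env.conv2d", override=True)
def torch_conv2d(
    x: tvm.nd.NDArray,
    w: tvm.nd.NDArray,
    b: tvm.nd.NDArray,
    o: tvm.nd.NDArray
):
    # Zero-copy views: each torch tensor shares memory with its TVM array
    x_torch = torch.from_dlpack(x)
    w_torch = torch.from_dlpack(w)
    b_torch = torch.from_dlpack(b)
    o_torch = torch.from_dlpack(o)

    # conv2d allocates a fresh tensor for its result
    out_temp = torch.nn.functional.conv2d(x_torch, w_torch, b_torch)
    # Write the result into o's memory (same effect as o_torch.copy_(out_temp));
    # a plain `o_torch = out_temp` would only rebind the local name
    torch.add(out_temp, 0, out=o_torch)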

and it works as expected. So the problem seems to be that passing out=o_torch actually writes into the memory of o, whereas o_torch = foo only rebinds the name o_torch to the object foo, and the memory of o_torch will not chang…
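A minimal sketch of the two behaviours described above, using NumPy arrays as a stand-in for the torch tensors so the snippet runs without TVM or a GPU (that substitution, and the helper names `fill_by_rebinding` and `fill_in_place`, are hypothetical choices for illustration). A tensor from `torch.from_dlpack` shares memory with the TVM array the same way a NumPy argument shares memory with the caller's array, so the same rule applies:

```python
import numpy as np

def fill_by_rebinding(o: np.ndarray) -> None:
    # Rebinding: the local name `o` now refers to a brand-new array.
    # The caller's buffer is never touched.
    o = np.full(o.shape, 7.0, dtype=o.dtype)

def fill_in_place(o: np.ndarray) -> None:
    result = np.full(o.shape, 7.0, dtype=o.dtype)
    # In-place: `out=` makes the ufunc write into the existing buffer,
    # mirroring torch.add(out_temp, 0, out=o_torch).
    np.add(result, 0, out=o)
    # `o[...] = result` would have the same effect.

a = np.zeros(3, dtype=np.float32)
fill_by_rebinding(a)
print(a)  # [0. 0. 0.]

b = np.zeros(3, dtype=np.float32)
fill_in_place(b)
print(b)  # [7. 7. 7.]
```

Only the second helper changes what the caller sees, which matches why the first version of `env.conv2d` left `o` unchanged.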

Answer selected by cblmemo