Skip to content

Commit

Permalink
Adding test for quantized ESPCN
Browse files Browse the repository at this point in the history
  • Loading branch information
i-colbert committed Aug 1, 2023
1 parent 93b8f09 commit 44e1b9a
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions tests/transformation/test_subpixel_to_deconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
import pytest

import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as nph
from onnx import TensorProto
from onnx.checker import check_model
from pkgutil import get_data
Expand Down Expand Up @@ -62,6 +64,36 @@ def test_subpixel_to_deconv_float_espcn():
assert np.isclose(expected, produced, atol=1e-4).all(), "Error: expected output does not match the produced output."


def test_subpixel_to_deconv_quant_espcn():
# get raw quantized model with reference input
raw_i = get_data("qonnx.data", "onnx/bsd300x3-espcn/test_data/input_0.pb")
raw_m = get_data("qonnx.data", "onnx/bsd300x3-espcn/quant_model.onnx")
# create model from the onnx file and infer the shapes
model = ModelWrapper(raw_m)
model = model.transform(InferShapes())
iname = model.graph.input[0].name
oname = model.graph.output[0].name
ishape = model.get_tensor_shape(iname)
# load the reference input tensor
input_tensor = onnx.load_tensor_from_string(raw_i)
input_tensor = nph.to_array(input_tensor)
assert list(input_tensor.shape) == ishape, "Error: reference input doesn't match loaded model."
input_dict = {iname: input_tensor}
# get the output from the sub-pixel convolution model
output_subpixel_conv = oxe.execute_onnx(model, input_dict)[oname]
# translate the sub-pixel convolution to the deconvolution
new_model = model.transform(SubPixelToDeconvolution())
new_model = new_model.transform(InferShapes())
# check that there are no DepthToSpace ops left
op_types = list(map(lambda x: x.op_type, new_model.graph.node))
assert "DepthToSpace" not in op_types, "Error: the DepthToSpace nodes would be removed."
# get the output from the deconvolution model
output_deconv = oxe.execute_onnx(new_model, input_dict)[oname]
assert np.isclose(
output_deconv, output_subpixel_conv, atol=1 / 255.0, rtol=1 / 255.0
).all(), "Error: expected output does not match the produced output."


def create_subpixel_conv_model(
in_channels: int, out_channels: int, input_dim: int, kernel_size: int, upscale_factor: int, bias: bool = False
):
Expand Down

0 comments on commit 44e1b9a

Please sign in to comment.