Commit ce9937b
Merge pull request #31 from h2oai/detached
Update dbresnet_50 detection model
smg478 authored Mar 19, 2024
2 parents 2965b9d + a201a55 commit ce9937b
Showing 2 changed files with 83 additions and 47 deletions.
doctr/models/detection/_utils/export_det_onnx.py (79 additions, 46 deletions):
@@ -1,62 +1,95 @@
+# Import required libraries
 import time
 
 import torch
 import numpy as np
+import torch.onnx
+
+import onnxruntime
 from doctr.models import ocr_predictor
+from openvino.runtime import Core
 
-model = ocr_predictor(det_arch ='db_resnet50_rotation', pretrained=True)
-model.det_predictor.model = model.det_predictor.model.eval()
-
-input = torch.randn(1, 3, 1024, 1024)
-input2 = torch.randn(1, 3, 1536, 1536)
-start = time.time()
-pred = model.det_predictor.model(input)
-print("pytorch time", time.time() - start)
-torch.onnx.export(model.det_predictor.model,
-                  input,
-                  'db_resnet50_rotation.onnx',
-                  export_params = True,
+start_load_time = time.time()
+device = torch.device('cpu')
+model = ocr_predictor(det_arch='db_resnet50', pretrained=True).det_predictor.model
+model.to(device).eval()
+model_load_time = time.time() - start_load_time
+print(f"PyTorch Model Load Time: {model_load_time} seconds")
+
+
+# Define a function for PyTorch inference and benchmarking
+def pytorch_inference(model, input_tensor):
+    with torch.no_grad():
+        return model(input_tensor).detach().cpu().numpy()
+
+
+# Define a function to benchmark ONNX inference and verify accuracy
+def benchmark_onnx_inference_and_verify(model_path, input_tensor, pytorch_output):
+    session = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"])
+    start_time = time.time()
+    onnx_output = session.run(None, {"input": input_tensor.numpy()})
+    inference_time = time.time() - start_time
+
+    # Verify accuracy
+    np.testing.assert_allclose(pytorch_output, onnx_output[0], rtol=1e-3, atol=1e-5)
+    print("ONNX Runtime verification passed")
+    return inference_time
+
+
+# Define a function to benchmark OpenVINO inference and verify accuracy
+def benchmark_openvino_inference_and_verify(model_path, input_tensor, pytorch_output):
+    ie = Core()
+    model_onnx = ie.read_model(model=model_path)
+    compiled_model = ie.compile_model(model=model_onnx, device_name="CPU")
+    output_layer = compiled_model.output(0)
+
+    start_time = time.time()
+    openvino_output = compiled_model([input_tensor.numpy()])[output_layer]
+    inference_time = time.time() - start_time
+
+    # Verify accuracy
+    np.testing.assert_allclose(pytorch_output, openvino_output, rtol=1e-3, atol=1e-5)
+    print("OpenVINO Runtime verification passed")
+    return inference_time
+
+
+torch.onnx.export(model,
+                  torch.randn(1, 3, 1024, 1024),
+                  'db_resnet50.onnx',
+                  export_params=True,
                   opset_version=11,
                   do_constant_folding=True,
-                  input_names = ["input"],
-                  output_names = ["output"],
-                  dynamic_axes = {"input":{0:"batch_size", 2:"x_axis", 3:"y_axis"},
-                                  "output":{0:"batch_size", 2:"x_axis", 3:"y_axis"}})
-
-import onnxruntime
-
-ort_session = onnxruntime.InferenceSession('db_resnet50_rotation.onnx', providers = ["CPUExecutionProvider"])
+                  input_names=["input"],
+                  output_names=["output"],
+                  dynamic_axes={"input": {0: "batch_size", 2: "height", 3: "width"},
+                                "output": {0: "batch_size", 2: "height", 3: "width"}})
 
-ort_inputs = {"input":input.numpy()}
-start = time.time()
-ort_outs = ort_session.run(None, ort_inputs)
-print("onnx time", time.time() - start)
-print(np.testing.assert_allclose(pred.detach().cpu().numpy(), ort_outs[0], rtol=1e-3, atol=1e-5))
+# Example input tensors
+input_tensor_1024 = torch.randn(1, 3, 1024, 1024)
+input_tensor_1536 = torch.randn(1, 3, 1536, 1536)
 
-ort_inputs = {"input":input2.numpy()}
-start = time.time()
-ort_outs = ort_session.run(None, ort_inputs)
-print("onnx time", time.time() - start)
-start = time.time()
-pred = model.det_predictor.model(input2)
-print("pytorch time", time.time() - start)
-print(np.testing.assert_allclose(pred.detach().cpu().numpy(), ort_outs[0], rtol=1e-3, atol=1e-5))
+# Perform PyTorch inference and capture output
+start = time.time()
+pytorch_output_1024 = pytorch_inference(model, input_tensor_1024)
+print(f"PyTorch Inference Time (1024x1024): {time.time() - start}")
 
+start = time.time()
+pytorch_output_1536 = pytorch_inference(model, input_tensor_1536)
+print(f"PyTorch Inference Time (1536x1536): {time.time() - start}")
 
-from openvino.runtime import Core
 
-ie = Core()
-model_onnx = ie.read_model(model='db_resnet50_rotation.onnx')
-start = time.time()
-compiled_model_onnx = ie.compile_model(model=model_onnx, device_name="CPU")
+# Benchmark and verify ONNX Runtime
+time_onnx_1024 = benchmark_onnx_inference_and_verify('db_resnet50.onnx', input_tensor_1024,
+                                                     pytorch_output_1024)
+print(f"ONNX Runtime Inference Time (1024x1024): {time_onnx_1024}")
 
-output_layer_onnx = compiled_model_onnx.output(0)
-print("model compilation time", time.time() - start)
-start = time.time()
-# Run inference on the input image.
-print(input2.numpy().shape, input2.dtype)
-res_onnx = compiled_model_onnx([input2.numpy()])[output_layer_onnx]
-print(res_onnx.shape)
-print("openvino time", time.time() - start)
-print(np.testing.assert_allclose(pred.detach().cpu().numpy(), res_onnx, rtol=1e-3, atol=1e-5))
+time_onnx_1536 = benchmark_onnx_inference_and_verify('db_resnet50.onnx', input_tensor_1536,
+                                                     pytorch_output_1536)
+print(f"ONNX Runtime Inference Time (1536x1536): {time_onnx_1536}")
+
+# Benchmark and verify OpenVINO
+time_openvino_1024 = benchmark_openvino_inference_and_verify('db_resnet50.onnx', input_tensor_1024,
+                                                             pytorch_output_1024)
+print(f"OpenVINO Inference Time (1024x1024): {time_openvino_1024}")
+
+time_openvino_1536 = benchmark_openvino_inference_and_verify('db_resnet50.onnx', input_tensor_1536,
+                                                             pytorch_output_1536)
+print(f"OpenVINO Inference Time (1536x1536): {time_openvino_1536}")
Second changed file (4 additions, 1 deletion):

@@ -29,7 +29,8 @@
         "input_shape": (3, 1024, 1024),
         "mean": (0.798, 0.785, 0.772),
         "std": (0.264, 0.2749, 0.287),
-        "url": "https://doctr-static.mindee.com/models?id=v0.3.1/db_resnet50-ac60cadc.pt&src=0",
+        # "url": "https://doctr-static.mindee.com/models?id=v0.3.1/db_resnet50-ac60cadc.pt&src=0",
+        "url": "https://doctr-static.mindee.com/models?id=v0.7.0/db_resnet50-79bd7d70.pt&src=0",  # New URL for v0.7.0
     },
     'db_resnet50_onnx': {
         'input_shape': (3, 1024, 1024),
@@ -300,6 +301,8 @@ def _dbnet(
     # Load pretrained parameters
     if pretrained:
         load_pretrained_params(model, default_cfgs[arch]["url"])
+        print(f"Loaded pretrained parameters for {arch}")
+        print(f"Pretrained parameters loaded from {default_cfgs[arch]['url']}")
 
     return model
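
A quick way to exercise the updated checkpoint URL end to end is to instantiate the architecture with pretrained weights. A minimal sketch, not part of this commit, assuming db_resnet50 is exposed under doctr.models.detection as in upstream doctr:

from doctr.models.detection import db_resnet50

# Not part of this commit: instantiating with pretrained=True routes through
# load_pretrained_params with the updated default_cfgs URL, so the two new
# print statements above should fire once the v0.7.0 checkpoint downloads.
model = db_resnet50(pretrained=True).eval()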
