diff --git a/Makefile b/Makefile
index cefd685..49adedd 100644
--- a/Makefile
+++ b/Makefile
@@ -100,12 +100,14 @@ examples:
 	    //coral/examples:two_models_two_tpus_threaded \
 	    //coral/examples:model_pipelining \
 	    //coral/examples:classify_image \
+	    //coral/examples:detect_image \
 	    //coral/examples:backprop_last_layer
 	mkdir -p $(EXAMPLES_OUT_DIR)
 	cp -f $(BAZEL_OUT_DIR)/coral/examples/two_models_one_tpu \
 	      $(BAZEL_OUT_DIR)/coral/examples/two_models_two_tpus_threaded \
 	      $(BAZEL_OUT_DIR)/coral/examples/model_pipelining \
 	      $(BAZEL_OUT_DIR)/coral/examples/classify_image \
+	      $(BAZEL_OUT_DIR)/coral/examples/detect_image \
 	      $(BAZEL_OUT_DIR)/coral/examples/backprop_last_layer \
 	      $(EXAMPLES_OUT_DIR)
diff --git a/coral/examples/BUILD b/coral/examples/BUILD
index c959b3a..d95e9a6 100644
--- a/coral/examples/BUILD
+++ b/coral/examples/BUILD
@@ -75,6 +75,22 @@ cc_binary(
     ],
 )
 
+cc_binary(
+    name = "detect_image",
+    srcs = ["detect_image.cc"],
+    deps = [
+        ":file_utils",
+        "//coral:tflite_utils",
+        "//coral/detection:adapter",
+        "@com_google_absl//absl/flags:flag",
+        "@com_google_absl//absl/flags:parse",
+        "@glog",
+        "@libedgetpu//tflite/public:oss_edgetpu_direct_all",  # buildcleaner: keep
+        "@org_tensorflow//tensorflow/lite:framework",
+        "@org_tensorflow//tensorflow/lite/c:common",
+    ],
+)
+
 cc_binary(
     name = "backprop_last_layer",
     srcs = ["backprop_last_layer.cc"],
diff --git a/coral/examples/detect_image.cc b/coral/examples/detect_image.cc
new file mode 100644
index 0000000..aea0a33
--- /dev/null
+++ b/coral/examples/detect_image.cc
@@ -0,0 +1,121 @@
+/* Copyright 2019-2021 Google LLC
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// An example that detects objects in an image.
+// The input image size must match the input size of the model, and the image
+// must be stored as a raw RGB pixel array.
+// On Linux, with the imagemagick package installed, you can resize and
+// convert an existing image to such a pixel array with:
+//   convert kite_and_cold.jpg -resize 300x300! kite_and_cold-300x300.rgb
+#include <cmath>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "absl/flags/flag.h"
+#include "absl/flags/parse.h"
+#include "coral/detection/adapter.h"
+#include "coral/examples/file_utils.h"
+#include "coral/tflite_utils.h"
+#include "glog/logging.h"
+#include "tensorflow/lite/interpreter.h"
+
+ABSL_FLAG(std::string, model_path,
+          "ssd_mobilenet_v2_coco_quant_postprocess_edgetpu.tflite",
+          "Path to the tflite model.");
+ABSL_FLAG(std::string, image_path, "cat.rgb",
+          "Path to the image in which to detect objects. The input image "
+          "size must match the input size of the model, and the image must "
+          "be stored as a raw RGB pixel array.");
+ABSL_FLAG(std::string, labels_path, "coco_labels.txt",
+          "Path to the coco labels.");
+ABSL_FLAG(float, input_mean, 128, "Mean value for input normalization.");
+ABSL_FLAG(float, input_std, 128, "STD value for input normalization.");
+ABSL_FLAG(float, threshold, 0.2f, "Score threshold for detected objects.");
+ABSL_FLAG(int, top_k, 10, "The maximum number of detections to return.");
+
+int main(int argc, char* argv[]) {
+  absl::ParseCommandLine(argc, argv);
+
+  // Load the model.
+  const auto model = coral::LoadModelOrDie(absl::GetFlag(FLAGS_model_path));
+  auto edgetpu_context = coral::ContainsEdgeTpuCustomOp(*model)
+                             ? coral::GetEdgeTpuContextOrDie()
+                             : nullptr;
+  auto interpreter =
+      coral::MakeEdgeTpuInterpreterOrDie(*model, edgetpu_context.get());
+  CHECK_EQ(interpreter->AllocateTensors(), kTfLiteOk);
+
+  // Check whether the input data needs to be preprocessed.
+  // Image data must go through two transforms before running inference:
+  //   1. normalization: f = (v - mean) / std
+  //   2. quantization:  q = f / scale + zero_point
+  // Preprocessing combines the two steps:
+  //   q = (v - mean) / (std * scale) + zero_point
+  // When std * scale equals 1 and mean - zero_point equals 0, the image data
+  // does not need any preprocessing. In practice, it is probably okay to skip
+  // preprocessing for better efficiency when the normalization and
+  // quantization parameters approximate, but do not exactly meet, the above
+  // conditions.
+  CHECK_EQ(interpreter->inputs().size(), 1UL);
+  const auto* input_tensor = interpreter->input_tensor(0);
+  CHECK_EQ(input_tensor->type, kTfLiteUInt8)
+      << "Only uint8 input type is supported.";
+  const float scale = input_tensor->params.scale;
+  const float zero_point = input_tensor->params.zero_point;
+  const float mean = absl::GetFlag(FLAGS_input_mean);
+  const float std = absl::GetFlag(FLAGS_input_std);
+  auto input = coral::MutableTensorData<uint8_t>(*input_tensor);
+  const int input_size = 300;
+  std::cout << "Expecting " << input_size << "x" << input_size << " input."
+            << std::endl;
+  if (std::abs(scale * std - 1) < 1e-5 && std::abs(mean - zero_point) < 1e-5) {
+    // Read the image directly into the input tensor, as no preprocessing is
+    // needed.
+    std::cout << "Input data does not require preprocessing." << std::endl;
+    coral::ReadFileToOrDie(absl::GetFlag(FLAGS_image_path),
+                           reinterpret_cast<char*>(input.data()),
+                           input.size());
+  } else {
+    std::cout << "Input data requires preprocessing." << std::endl;
+    std::vector<uint8_t> image_data(input.size());
+    coral::ReadFileToOrDie(absl::GetFlag(FLAGS_image_path),
+                           reinterpret_cast<char*>(image_data.data()),
+                           input.size());
+    for (size_t i = 0; i < input.size(); ++i) {
+      const float tmp = (image_data[i] - mean) / (std * scale) + zero_point;
+      if (tmp > 255) {
+        input[i] = 255;
+      } else if (tmp < 0) {
+        input[i] = 0;
+      } else {
+        input[i] = static_cast<uint8_t>(tmp);
+      }
+    }
+  }
+
+  CHECK_EQ(interpreter->Invoke(), kTfLiteOk);
+
+  // Read the label file.
+  auto labels = coral::ReadLabelFile(absl::GetFlag(FLAGS_labels_path));
+
+  const float threshold = absl::GetFlag(FLAGS_threshold);
+  const int top_k = absl::GetFlag(FLAGS_top_k);
+  for (const coral::Object& result :
+       coral::GetDetectionResults(*interpreter, threshold, top_k)) {
+    std::cout << "---------------------------" << std::endl;
+    std::cout << labels[result.id] << std::endl;
+    std::cout << "Position: "
+              << "x=" << result.bbox.xmin * input_size
+              << ",y=" << result.bbox.ymin * input_size
+              << ",width=" << (result.bbox.xmax - result.bbox.xmin) * input_size
+              << ",height=" << (result.bbox.ymax - result.bbox.ymin) * input_size
+              << std::endl;
+    std::cout << "Score: " << result.score << std::endl;
+  }
+
+  return 0;
+}
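
For reference, one possible way to build and run the new example, sketched under the assumption that the default model, the COCO labels file, and a 300x300 raw RGB image (for instance the kite_and_cold-300x300.rgb file produced by the convert command shown in the source comment) sit in the working directory; the file names below are simply the flag defaults and that example image:

    bazel build //coral/examples:detect_image
    ./bazel-bin/coral/examples/detect_image \
        --model_path=ssd_mobilenet_v2_coco_quant_postprocess_edgetpu.tflite \
        --labels_path=coco_labels.txt \
        --image_path=kite_and_cold-300x300.rgb

Alternatively, the examples Makefile target updated above builds detect_image alongside the other examples and copies the binary into $(EXAMPLES_OUT_DIR).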