/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <setjmp.h>
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <cmath>
#include <fstream>
#include <vector>

#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/image_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"

// These are all common classes it's handy to reference with no namespace.
using tensorflow::Flag;
using tensorflow::Tensor;
using tensorflow::Status;
using tensorflow::string;
using tensorflow::int32;
using tensorflow::uint8;

// Takes a file name, loads a list of comma-separated box priors from it, one
// per line, and returns a vector of the values.
Status ReadLocationsFile(const string& file_name, std::vector<float>* result,
                         size_t* found_label_count) {
  std::ifstream file(file_name);
  if (!file) {
    return tensorflow::errors::NotFound("Box priors file ", file_name,
                                        " not found.");
  }
  result->clear();
  string line;
  while (std::getline(file, line)) {
    std::vector<float> tokens;
    CHECK(tensorflow::str_util::SplitAndParseAsFloats(line, ',', &tokens));
    for (auto number : tokens) {
      result->push_back(number);
    }
  }
  *found_label_count = result->size();
  return Status::OK();
}
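
// Note for callers: PrintTopDetections() below expects the priors file to
// contain num_boxes * 8 values in total, i.e. a (mean, std_dev) pair for each
// of the four coordinates of every box.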

// Given an image file name, read in the data, try to decode it as an image,
// resize it to the requested size, and then scale the values as desired.
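// On success, out_tensors[0] holds the normalized, resized float image and
// out_tensors[1] the original uint8 image (via the Identity op below).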
Status ReadTensorFromImageFile(const string& file_name, const int input_height,
                               const int input_width, const float input_mean,
                               const float input_std,
                               std::vector<Tensor>* out_tensors) {
  auto root = tensorflow::Scope::NewRootScope();
  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

  string input_name = "file_reader";
  string original_name = "identity";
  string output_name = "normalized";
  auto file_reader =
      tensorflow::ops::ReadFile(root.WithOpName(input_name), file_name);
  // Now try to figure out what kind of file it is and decode it.
  const int wanted_channels = 3;
  tensorflow::Output image_reader;
  if (tensorflow::StringPiece(file_name).ends_with(".png")) {
    image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,
                             DecodePng::Channels(wanted_channels));
  } else if (tensorflow::StringPiece(file_name).ends_with(".gif")) {
    image_reader = DecodeGif(root.WithOpName("gif_reader"), file_reader);
  } else {
    // Assume if it's neither a PNG nor a GIF then it must be a JPEG.
    image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,
                              DecodeJpeg::Channels(wanted_channels));
  }

  // Also return identity so that we can know the original dimensions and
  // optionally save the image out with bounding boxes overlaid.
  auto original_image = Identity(root.WithOpName(original_name), image_reader);

  // Now cast the image data to float so we can do normal math on it.
  auto float_caster = Cast(root.WithOpName("float_caster"), original_image,
                           tensorflow::DT_FLOAT);
  // The convention for image ops in TensorFlow is that all images are expected
  // to be in batches, so that they're four-dimensional arrays with indices of
  // [batch, height, width, channel]. Because we only have a single image, we
  // have to add a batch dimension of 1 to the start with ExpandDims().
  auto dims_expander = ExpandDims(root, float_caster, 0);

  // Bilinearly resize the image to fit the required dimensions.
  auto resized = ResizeBilinear(
      root, dims_expander,
      Const(root.WithOpName("size"), {input_height, input_width}));
  // Subtract the mean and divide by the scale.
  Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),
      {input_std});

  // This runs the GraphDef network definition that we've just constructed, and
  // returns the results in the output tensors.
  tensorflow::GraphDef graph;
  TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));

  std::unique_ptr<tensorflow::Session> session(
      tensorflow::NewSession(tensorflow::SessionOptions()));
  TF_RETURN_IF_ERROR(session->Create(graph));
  TF_RETURN_IF_ERROR(
      session->Run({}, {output_name, original_name}, {}, out_tensors));
  return Status::OK();
}

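// Encodes a uint8 image tensor as a PNG and writes it to file_path by running
// a small EncodePng -> WriteFile graph in a throwaway session.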
Status SaveImage(const Tensor& tensor, const string& file_path) {
  LOG(INFO) << "Saving image to " << file_path;
  CHECK(tensorflow::StringPiece(file_path).ends_with(".png"))
      << "Only saving of png files is supported.";

  auto root = tensorflow::Scope::NewRootScope();
  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

  string encoder_name = "encode";
  string output_name = "file_writer";

  tensorflow::Output image_encoder =
      EncodePng(root.WithOpName(encoder_name), tensor);
  tensorflow::ops::WriteFile file_saver = tensorflow::ops::WriteFile(
      root.WithOpName(output_name), file_path, image_encoder);

  tensorflow::GraphDef graph;
  TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));

  std::unique_ptr<tensorflow::Session> session(
      tensorflow::NewSession(tensorflow::SessionOptions()));
  TF_RETURN_IF_ERROR(session->Create(graph));
  std::vector<Tensor> outputs;
  TF_RETURN_IF_ERROR(session->Run({}, {}, {output_name}, &outputs));

  return Status::OK();
}

// Reads a model graph definition from disk, and creates a session object you
// can use to run it.
Status LoadGraph(const string& graph_file_name,
                 std::unique_ptr<tensorflow::Session>* session) {
  tensorflow::GraphDef graph_def;
  Status load_graph_status =
      ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
  if (!load_graph_status.ok()) {
    return tensorflow::errors::NotFound("Failed to load compute graph at '",
                                        graph_file_name, "'");
  }
  session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
  Status session_create_status = (*session)->Create(graph_def);
  if (!session_create_status.ok()) {
    return session_create_status;
  }
  return Status::OK();
}

// Analyzes the output of the MultiBox graph to retrieve the highest scores and
// their positions in the tensor, which correspond to individual box detections.
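// Rather than sorting scores on the host, this builds a one-op TopK graph and
// runs it in a throwaway session.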
Status GetTopDetections(const std::vector<Tensor>& outputs, int how_many_labels,
                        Tensor* indices, Tensor* scores) {
  auto root = tensorflow::Scope::NewRootScope();
  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)

  string output_name = "top_k";
  TopK(root.WithOpName(output_name), outputs[0], how_many_labels);
  // This runs the GraphDef network definition that we've just constructed, and
  // returns the results in the output tensors.
  tensorflow::GraphDef graph;
  TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));

  std::unique_ptr<tensorflow::Session> session(
      tensorflow::NewSession(tensorflow::SessionOptions()));
  TF_RETURN_IF_ERROR(session->Create(graph));
  // The TopK node returns two outputs, the scores and their original indices,
  // so we have to append :0 and :1 to specify them both.
  std::vector<Tensor> out_tensors;
  TF_RETURN_IF_ERROR(session->Run({}, {output_name + ":0", output_name + ":1"},
                                  {}, &out_tensors));
  *scores = out_tensors[0];
  *indices = out_tensors[1];
  return Status::OK();
}

// Converts an encoded location to an actual box placement with the provided
// box priors.
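// Each coordinate is decoded as encoded * std_dev + mean using that
// coordinate's prior, then clamped to [0.0, 1.0] (relative image coordinates).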
void DecodeLocation(const float* encoded_location, const float* box_priors,
                    float* decoded_location) {
  bool non_zero = false;
  for (int i = 0; i < 4; ++i) {
    const float curr_encoding = encoded_location[i];
    non_zero = non_zero || curr_encoding != 0.0f;

    const float mean = box_priors[i * 2];
    const float std_dev = box_priors[i * 2 + 1];

    float current_location = curr_encoding * std_dev + mean;

    current_location = std::max(current_location, 0.0f);
    current_location = std::min(current_location, 1.0f);
    decoded_location[i] = current_location;
  }

  if (!non_zero) {
    LOG(WARNING) << "No non-zero encodings; check log for inference errors.";
  }
}

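// The model emits a raw score per box; squash it through the logistic sigmoid
// so the reported scores fall in (0, 1).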
float DecodeScore(float encoded_score) {
  return 1.0f / (1.0f + std::exp(-encoded_score));
}
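// Draws an axis-aligned rectangle outline onto a flat, row-major HWC uint8
// image buffer. Coordinates are clamped to the image bounds, and the outline
// is written into the third channel (blue, assuming RGB channel order).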
void DrawBox(const int image_width, const int image_height, int left, int top,
             int right, int bottom, tensorflow::TTypes<uint8>::Flat* image) {
  tensorflow::TTypes<uint8>::Flat image_ref = *image;

  top = std::max(0, std::min(image_height - 1, top));
  bottom = std::max(0, std::min(image_height - 1, bottom));

  left = std::max(0, std::min(image_width - 1, left));
  right = std::max(0, std::min(image_width - 1, right));

  for (int i = 0; i < 3; ++i) {
    uint8 val = i == 2 ? 255 : 0;
    for (int x = left; x <= right; ++x) {
      image_ref((top * image_width + x) * 3 + i) = val;
      image_ref((bottom * image_width + x) * 3 + i) = val;
    }
    for (int y = top; y <= bottom; ++y) {
      image_ref((y * image_width + left) * 3 + i) = val;
      image_ref((y * image_width + right) * 3 + i) = val;
    }
  }
}

// Given the output of a model run and the name of a file containing the box
// priors, this prints the highest-scoring detections and draws their boxes
// onto the original image.
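// outputs[0] holds the raw detection scores and outputs[1] the encoded box
// locations, matching the fetch order used in main().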
Status PrintTopDetections(const std::vector<Tensor>& outputs,
                          const string& labels_file_name,
                          const int num_boxes,
                          const int num_detections,
                          const string& image_file_name,
                          Tensor* original_tensor) {
  std::vector<float> locations;
  size_t label_count;
  Status read_labels_status =
      ReadLocationsFile(labels_file_name, &locations, &label_count);
  if (!read_labels_status.ok()) {
    LOG(ERROR) << read_labels_status;
    return read_labels_status;
  }
  CHECK_EQ(label_count, num_boxes * 8);

  const int how_many_labels =
      std::min(num_detections, static_cast<int>(label_count));
  Tensor indices;
  Tensor scores;
  TF_RETURN_IF_ERROR(
      GetTopDetections(outputs, how_many_labels, &indices, &scores));

  tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();

  tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();

  const Tensor& encoded_locations = outputs[1];
  auto locations_encoded = encoded_locations.flat<float>();

  LOG(INFO) << original_tensor->DebugString();
  const int image_width = original_tensor->shape().dim_size(1);
  const int image_height = original_tensor->shape().dim_size(0);

  tensorflow::TTypes<uint8>::Flat image_flat = original_tensor->flat<uint8>();

  LOG(INFO) << "===== Top " << how_many_labels << " Detections ======";
  for (int pos = 0; pos < how_many_labels; ++pos) {
    const int label_index = indices_flat(pos);
    const float score = scores_flat(pos);

    float decoded_location[4];
    DecodeLocation(&locations_encoded(label_index * 4),
                   &locations[label_index * 8], decoded_location);

    float left = decoded_location[0] * image_width;
    float top = decoded_location[1] * image_height;
    float right = decoded_location[2] * image_width;
    float bottom = decoded_location[3] * image_height;

    LOG(INFO) << "Detection " << pos << ": "
              << "L:" << left << " "
              << "T:" << top << " "
              << "R:" << right << " "
              << "B:" << bottom << " "
              << "(" << label_index << ") score: " << DecodeScore(score);

    DrawBox(image_width, image_height, left, top, right, bottom, &image_flat);
  }

  if (!image_file_name.empty()) {
    return SaveImage(*original_tensor, image_file_name);
  }
  return Status::OK();
}

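// Example invocation (the binary name and paths are illustrative and depend on
// how and where you build this example):
//
//   detect_objects --image=data/surfers.jpg --image_out=/tmp/surfers_boxes.png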
int main(int argc, char* argv[]) {
  // These are the command-line flags the program can understand.
  // They define where the graph and input data are located, and what kind of
  // input the model expects. If you train your own model, or use something
  // other than multibox_model, you'll need to update these.
  string image =
      "tensorflow/examples/multibox_detector/data/surfers.jpg";
  string graph =
      "tensorflow/examples/multibox_detector/data/"
      "multibox_model.pb";
  string box_priors =
      "tensorflow/examples/multibox_detector/data/"
      "multibox_location_priors.txt";
  int32 input_width = 224;
  int32 input_height = 224;
  int32 input_mean = 128;
  int32 input_std = 128;
  int32 num_detections = 5;
  int32 num_boxes = 784;
  string input_layer = "ResizeBilinear";
  string output_location_layer = "output_locations/Reshape";
  string output_score_layer = "output_scores/Reshape";
  string root_dir = "";
  string image_out = "";

  std::vector<Flag> flag_list = {
      Flag("image", &image, "image to be processed"),
      Flag("image_out", &image_out,
           "location to save output image, if desired"),
      Flag("graph", &graph, "graph to be executed"),
      Flag("box_priors", &box_priors, "name of file containing box priors"),
      Flag("input_width", &input_width, "resize image to this width in pixels"),
      Flag("input_height", &input_height,
           "resize image to this height in pixels"),
      Flag("input_mean", &input_mean, "scale pixel values to this mean"),
      Flag("input_std", &input_std, "scale pixel values to this std deviation"),
      Flag("num_detections", &num_detections,
           "number of top detections to return"),
      Flag("num_boxes", &num_boxes,
           "number of boxes defined by the location file"),
      Flag("input_layer", &input_layer, "name of input layer"),
      Flag("output_location_layer", &output_location_layer,
           "name of location output layer"),
      Flag("output_score_layer", &output_score_layer,
           "name of score output layer"),
      Flag("root_dir", &root_dir,
           "interpret image and graph file names relative to this directory"),
  };

  string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    LOG(ERROR) << usage;
    return -1;
  }

  // We need to call this to set up global state for TensorFlow.
  tensorflow::port::InitMain(argv[0], &argc, &argv);
  if (argc > 1) {
    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    return -1;
  }

  // First we load and initialize the model.
  std::unique_ptr<tensorflow::Session> session;
  string graph_path = tensorflow::io::JoinPath(root_dir, graph);
  Status load_graph_status = LoadGraph(graph_path, &session);
  if (!load_graph_status.ok()) {
    LOG(ERROR) << load_graph_status;
    return -1;
  }

  // Get the image from disk as a float array of numbers, resized and
  // normalized to the specifications the main graph expects.
  std::vector<Tensor> image_tensors;
  string image_path = tensorflow::io::JoinPath(root_dir, image);

  Status read_tensor_status =
      ReadTensorFromImageFile(image_path, input_height, input_width, input_mean,
                              input_std, &image_tensors);
  if (!read_tensor_status.ok()) {
    LOG(ERROR) << read_tensor_status;
    return -1;
  }
  const Tensor& resized_tensor = image_tensors[0];

  // Actually run the image through the model.
  std::vector<Tensor> outputs;
  Status run_status =
      session->Run({{input_layer, resized_tensor}},
                   {output_score_layer, output_location_layer}, {}, &outputs);
  if (!run_status.ok()) {
    LOG(ERROR) << "Running model failed: " << run_status;
    return -1;
  }

  Status print_status = PrintTopDetections(outputs, box_priors, num_boxes,
                                           num_detections, image_out,
                                           &image_tensors[1]);

  if (!print_status.ok()) {
    LOG(ERROR) << "Running print failed: " << print_status;
    return -1;
  }
  return 0;
}