1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// A minimal but useful C++ example showing how to load an Imagenet-style object
17// recognition TensorFlow model, prepare input images for it, run them through
18// the graph, and interpret the results.
19//
20// It's designed to have as few dependencies and be as clear as possible, so
21// it's more verbose than it could be in production code. In particular, using
22// auto for the types of a lot of the returned values from TensorFlow calls can
23// remove a lot of boilerplate, but I find the explicit types useful in sample
24// code to make it simple to look up the classes involved.
25//
26// To use it, compile and then run in a working directory with the
27// learning/brain/tutorials/label_image/data/ folder below it, and you should
28// see the top five labels for the example Lena image output. You can then
29// customize it to use your own models or images by changing the file names at
30// the top of the main() function.
31//
32// The googlenet_graph.pb file included by default is created from Inception.
33//
34// Note that, for GIF inputs, to reuse existing code, only single-frame ones
35// are supported.
36
37#include <fstream>
38#include <utility>
39#include <vector>
40
41#include "tensorflow/cc/ops/const_op.h"
42#include "tensorflow/cc/ops/image_ops.h"
43#include "tensorflow/cc/ops/standard_ops.h"
44#include "tensorflow/core/framework/graph.pb.h"
45#include "tensorflow/core/framework/tensor.h"
46#include "tensorflow/core/graph/default_device.h"
47#include "tensorflow/core/graph/graph_def_builder.h"
48#include "tensorflow/core/lib/core/errors.h"
49#include "tensorflow/core/lib/core/stringpiece.h"
50#include "tensorflow/core/lib/core/threadpool.h"
51#include "tensorflow/core/lib/io/path.h"
52#include "tensorflow/core/lib/strings/stringprintf.h"
53#include "tensorflow/core/platform/env.h"
54#include "tensorflow/core/platform/init_main.h"
55#include "tensorflow/core/platform/logging.h"
56#include "tensorflow/core/platform/types.h"
57#include "tensorflow/core/public/session.h"
58#include "tensorflow/core/util/command_line_flags.h"
59
60// These are all common classes it's handy to reference with no namespace.
61using tensorflow::Flag;
62using tensorflow::Tensor;
63using tensorflow::Status;
64using tensorflow::string;
65using tensorflow::int32;
66
67// Takes a file name, and loads a list of labels from it, one per line, and
68// returns a vector of the strings. It pads with empty strings so the length
69// of the result is a multiple of 16, because our model expects that.
70Status ReadLabelsFile(const string& file_name, std::vector<string>* result,
71                      size_t* found_label_count) {
72  std::ifstream file(file_name);
73  if (!file) {
74    return tensorflow::errors::NotFound("Labels file ", file_name,
75                                        " not found.");
76  }
77  result->clear();
78  string line;
79  while (std::getline(file, line)) {
80    result->push_back(line);
81  }
82  *found_label_count = result->size();
83  const int padding = 16;
84  while (result->size() % padding) {
85    result->emplace_back();
86  }
87  return Status::OK();
88}
89
90static Status ReadEntireFile(tensorflow::Env* env, const string& filename,
91                             Tensor* output) {
92  tensorflow::uint64 file_size = 0;
93  TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size));
94
95  string contents;
96  contents.resize(file_size);
97
98  std::unique_ptr<tensorflow::RandomAccessFile> file;
99  TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
100
101  tensorflow::StringPiece data;
102  TF_RETURN_IF_ERROR(file->Read(0, file_size, &data, &(contents)[0]));
103  if (data.size() != file_size) {
104    return tensorflow::errors::DataLoss("Truncated read of '", filename,
105                                        "' expected ", file_size, " got ",
106                                        data.size());
107  }
108  output->scalar<string>()() = data.ToString();
109  return Status::OK();
110}
111
112// Given an image file name, read in the data, try to decode it as an image,
113// resize it to the requested size, and then scale the values as desired.
114Status ReadTensorFromImageFile(const string& file_name, const int input_height,
115                               const int input_width, const float input_mean,
116                               const float input_std,
117                               std::vector<Tensor>* out_tensors) {
118  auto root = tensorflow::Scope::NewRootScope();
119  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
120
121  string input_name = "file_reader";
122  string output_name = "normalized";
123
124  // read file_name into a tensor named input
125  Tensor input(tensorflow::DT_STRING, tensorflow::TensorShape());
126  TF_RETURN_IF_ERROR(
127      ReadEntireFile(tensorflow::Env::Default(), file_name, &input));
128
129  // use a placeholder to read input data
130  auto file_reader =
131      Placeholder(root.WithOpName("input"), tensorflow::DataType::DT_STRING);
132
133  std::vector<std::pair<string, tensorflow::Tensor>> inputs = {
134      {"input", input},
135  };
136
137  // Now try to figure out what kind of file it is and decode it.
138  const int wanted_channels = 3;
139  tensorflow::Output image_reader;
140  if (tensorflow::StringPiece(file_name).ends_with(".png")) {
141    image_reader = DecodePng(root.WithOpName("png_reader"), file_reader,
142                             DecodePng::Channels(wanted_channels));
143  } else if (tensorflow::StringPiece(file_name).ends_with(".gif")) {
144    // gif decoder returns 4-D tensor, remove the first dim
145    image_reader =
146        Squeeze(root.WithOpName("squeeze_first_dim"),
147                DecodeGif(root.WithOpName("gif_reader"), file_reader));
148  } else if (tensorflow::StringPiece(file_name).ends_with(".bmp")) {
149    image_reader = DecodeBmp(root.WithOpName("bmp_reader"), file_reader);
150  } else {
151    // Assume if it's neither a PNG nor a GIF then it must be a JPEG.
152    image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader,
153                              DecodeJpeg::Channels(wanted_channels));
154  }
155  // Now cast the image data to float so we can do normal math on it.
156  auto float_caster =
157      Cast(root.WithOpName("float_caster"), image_reader, tensorflow::DT_FLOAT);
158  // The convention for image ops in TensorFlow is that all images are expected
159  // to be in batches, so that they're four-dimensional arrays with indices of
160  // [batch, height, width, channel]. Because we only have a single image, we
161  // have to add a batch dimension of 1 to the start with ExpandDims().
162  auto dims_expander = ExpandDims(root, float_caster, 0);
163  // Bilinearly resize the image to fit the required dimensions.
164  auto resized = ResizeBilinear(
165      root, dims_expander,
166      Const(root.WithOpName("size"), {input_height, input_width}));
167  // Subtract the mean and divide by the scale.
168  Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}),
169      {input_std});
170
171  // This runs the GraphDef network definition that we've just constructed, and
172  // returns the results in the output tensor.
173  tensorflow::GraphDef graph;
174  TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
175
176  std::unique_ptr<tensorflow::Session> session(
177      tensorflow::NewSession(tensorflow::SessionOptions()));
178  TF_RETURN_IF_ERROR(session->Create(graph));
179  TF_RETURN_IF_ERROR(session->Run({inputs}, {output_name}, {}, out_tensors));
180  return Status::OK();
181}
182
183// Reads a model graph definition from disk, and creates a session object you
184// can use to run it.
185Status LoadGraph(const string& graph_file_name,
186                 std::unique_ptr<tensorflow::Session>* session) {
187  tensorflow::GraphDef graph_def;
188  Status load_graph_status =
189      ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def);
190  if (!load_graph_status.ok()) {
191    return tensorflow::errors::NotFound("Failed to load compute graph at '",
192                                        graph_file_name, "'");
193  }
194  session->reset(tensorflow::NewSession(tensorflow::SessionOptions()));
195  Status session_create_status = (*session)->Create(graph_def);
196  if (!session_create_status.ok()) {
197    return session_create_status;
198  }
199  return Status::OK();
200}
201
202// Analyzes the output of the Inception graph to retrieve the highest scores and
203// their positions in the tensor, which correspond to categories.
204Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels,
205                    Tensor* indices, Tensor* scores) {
206  auto root = tensorflow::Scope::NewRootScope();
207  using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
208
209  string output_name = "top_k";
210  TopK(root.WithOpName(output_name), outputs[0], how_many_labels);
211  // This runs the GraphDef network definition that we've just constructed, and
212  // returns the results in the output tensors.
213  tensorflow::GraphDef graph;
214  TF_RETURN_IF_ERROR(root.ToGraphDef(&graph));
215
216  std::unique_ptr<tensorflow::Session> session(
217      tensorflow::NewSession(tensorflow::SessionOptions()));
218  TF_RETURN_IF_ERROR(session->Create(graph));
219  // The TopK node returns two outputs, the scores and their original indices,
220  // so we have to append :0 and :1 to specify them both.
221  std::vector<Tensor> out_tensors;
222  TF_RETURN_IF_ERROR(session->Run({}, {output_name + ":0", output_name + ":1"},
223                                  {}, &out_tensors));
224  *scores = out_tensors[0];
225  *indices = out_tensors[1];
226  return Status::OK();
227}
228
229// Given the output of a model run, and the name of a file containing the labels
230// this prints out the top five highest-scoring values.
231Status PrintTopLabels(const std::vector<Tensor>& outputs,
232                      const string& labels_file_name) {
233  std::vector<string> labels;
234  size_t label_count;
235  Status read_labels_status =
236      ReadLabelsFile(labels_file_name, &labels, &label_count);
237  if (!read_labels_status.ok()) {
238    LOG(ERROR) << read_labels_status;
239    return read_labels_status;
240  }
241  const int how_many_labels = std::min(5, static_cast<int>(label_count));
242  Tensor indices;
243  Tensor scores;
244  TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
245  tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>();
246  tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
247  for (int pos = 0; pos < how_many_labels; ++pos) {
248    const int label_index = indices_flat(pos);
249    const float score = scores_flat(pos);
250    LOG(INFO) << labels[label_index] << " (" << label_index << "): " << score;
251  }
252  return Status::OK();
253}
254
255// This is a testing function that returns whether the top label index is the
256// one that's expected.
257Status CheckTopLabel(const std::vector<Tensor>& outputs, int expected,
258                     bool* is_expected) {
259  *is_expected = false;
260  Tensor indices;
261  Tensor scores;
262  const int how_many_labels = 1;
263  TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
264  tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
265  if (indices_flat(0) != expected) {
266    LOG(ERROR) << "Expected label #" << expected << " but got #"
267               << indices_flat(0);
268    *is_expected = false;
269  } else {
270    *is_expected = true;
271  }
272  return Status::OK();
273}
274
// Entry point: parses flags, loads the frozen graph, preprocesses the image,
// runs inference, optionally self-tests, and prints the top labels.
// Returns 0 on success, -1 on any failure.
int main(int argc, char* argv[]) {
  // These are the command-line flags the program can understand.
  // They define where the graph and input data is located, and what kind of
  // input the model expects. If you train your own model, or use something
  // other than inception_v3, then you'll need to update these.
  string image = "tensorflow/examples/label_image/data/grace_hopper.jpg";
  string graph =
      "tensorflow/examples/label_image/data/inception_v3_2016_08_28_frozen.pb";
  string labels =
      "tensorflow/examples/label_image/data/imagenet_slim_labels.txt";
  int32 input_width = 299;
  int32 input_height = 299;
  float input_mean = 0;
  float input_std = 255;
  string input_layer = "input";
  string output_layer = "InceptionV3/Predictions/Reshape_1";
  bool self_test = false;
  string root_dir = "";
  std::vector<Flag> flag_list = {
      Flag("image", &image, "image to be processed"),
      Flag("graph", &graph, "graph to be executed"),
      Flag("labels", &labels, "name of file containing labels"),
      Flag("input_width", &input_width, "resize image to this width in pixels"),
      Flag("input_height", &input_height,
           "resize image to this height in pixels"),
      Flag("input_mean", &input_mean, "scale pixel values to this mean"),
      Flag("input_std", &input_std, "scale pixel values to this std deviation"),
      Flag("input_layer", &input_layer, "name of input layer"),
      Flag("output_layer", &output_layer, "name of output layer"),
      Flag("self_test", &self_test, "run a self test"),
      Flag("root_dir", &root_dir,
           "interpret image and graph file names relative to this directory"),
  };
  // Parse the command line into the flag variables declared above; on failure
  // print the auto-generated usage text and bail out.
  string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    LOG(ERROR) << usage;
    return -1;
  }

  // We need to call this to set up global state for TensorFlow.
  tensorflow::port::InitMain(argv[0], &argc, &argv);
  // Flags::Parse removed recognized flags, so anything left is unexpected.
  if (argc > 1) {
    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    return -1;
  }

  // First we load and initialize the model.
  std::unique_ptr<tensorflow::Session> session;
  string graph_path = tensorflow::io::JoinPath(root_dir, graph);
  Status load_graph_status = LoadGraph(graph_path, &session);
  if (!load_graph_status.ok()) {
    LOG(ERROR) << load_graph_status;
    return -1;
  }

  // Get the image from disk as a float array of numbers, resized and normalized
  // to the specifications the main graph expects.
  std::vector<Tensor> resized_tensors;
  string image_path = tensorflow::io::JoinPath(root_dir, image);
  Status read_tensor_status =
      ReadTensorFromImageFile(image_path, input_height, input_width, input_mean,
                              input_std, &resized_tensors);
  if (!read_tensor_status.ok()) {
    LOG(ERROR) << read_tensor_status;
    return -1;
  }
  // The first output of the preprocessing graph is the normalized image batch.
  const Tensor& resized_tensor = resized_tensors[0];

  // Actually run the image through the model.
  std::vector<Tensor> outputs;
  Status run_status = session->Run({{input_layer, resized_tensor}},
                                   {output_layer}, {}, &outputs);
  if (!run_status.ok()) {
    LOG(ERROR) << "Running model failed: " << run_status;
    return -1;
  }

  // This is for automated testing to make sure we get the expected result with
  // the default settings. We know that label 653 (military uniform) should be
  // the top label for the Admiral Hopper image.
  if (self_test) {
    bool expected_matches;
    Status check_status = CheckTopLabel(outputs, 653, &expected_matches);
    if (!check_status.ok()) {
      LOG(ERROR) << "Running check failed: " << check_status;
      return -1;
    }
    if (!expected_matches) {
      LOG(ERROR) << "Self-test failed!";
      return -1;
    }
  }

  // Do something interesting with the results we've generated.
  Status print_status = PrintTopLabels(outputs, labels);
  if (!print_status.ok()) {
    LOG(ERROR) << "Running print failed: " << print_status;
    return -1;
  }

  return 0;
}
378