/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// A minimal but useful C++ example showing how to load an Imagenet-style object
// recognition TensorFlow model, prepare input images for it, run them through
// the graph, and interpret the results.
//
// It's designed to have as few dependencies and be as clear as possible, so
// it's more verbose than it could be in production code. In particular, using
// auto for the types of a lot of the returned values from TensorFlow calls can
// remove a lot of boilerplate, but I find the explicit types useful in sample
// code to make it simple to look up the classes involved.
//
// To use it, compile and then run in a working directory with the
// tensorflow/examples/label_image/data/ folder below it, and you should see
// the top five labels for the default Grace Hopper image output. You can then
// customize it to use your own models or images by changing the file names at
// the top of the main() function.
//
// The inception_v3_2016_08_28_frozen.pb file used by default is a frozen
// Inception v3 model.
//
// Note that, for GIF inputs, to reuse existing code, only single-frame ones
// are supported.
36 37#include <fstream> 38#include <utility> 39#include <vector> 40 41#include "tensorflow/cc/ops/const_op.h" 42#include "tensorflow/cc/ops/image_ops.h" 43#include "tensorflow/cc/ops/standard_ops.h" 44#include "tensorflow/core/framework/graph.pb.h" 45#include "tensorflow/core/framework/tensor.h" 46#include "tensorflow/core/graph/default_device.h" 47#include "tensorflow/core/graph/graph_def_builder.h" 48#include "tensorflow/core/lib/core/errors.h" 49#include "tensorflow/core/lib/core/stringpiece.h" 50#include "tensorflow/core/lib/core/threadpool.h" 51#include "tensorflow/core/lib/io/path.h" 52#include "tensorflow/core/lib/strings/stringprintf.h" 53#include "tensorflow/core/platform/env.h" 54#include "tensorflow/core/platform/init_main.h" 55#include "tensorflow/core/platform/logging.h" 56#include "tensorflow/core/platform/types.h" 57#include "tensorflow/core/public/session.h" 58#include "tensorflow/core/util/command_line_flags.h" 59 60// These are all common classes it's handy to reference with no namespace. 61using tensorflow::Flag; 62using tensorflow::Tensor; 63using tensorflow::Status; 64using tensorflow::string; 65using tensorflow::int32; 66 67// Takes a file name, and loads a list of labels from it, one per line, and 68// returns a vector of the strings. It pads with empty strings so the length 69// of the result is a multiple of 16, because our model expects that. 
70Status ReadLabelsFile(const string& file_name, std::vector<string>* result, 71 size_t* found_label_count) { 72 std::ifstream file(file_name); 73 if (!file) { 74 return tensorflow::errors::NotFound("Labels file ", file_name, 75 " not found."); 76 } 77 result->clear(); 78 string line; 79 while (std::getline(file, line)) { 80 result->push_back(line); 81 } 82 *found_label_count = result->size(); 83 const int padding = 16; 84 while (result->size() % padding) { 85 result->emplace_back(); 86 } 87 return Status::OK(); 88} 89 90static Status ReadEntireFile(tensorflow::Env* env, const string& filename, 91 Tensor* output) { 92 tensorflow::uint64 file_size = 0; 93 TF_RETURN_IF_ERROR(env->GetFileSize(filename, &file_size)); 94 95 string contents; 96 contents.resize(file_size); 97 98 std::unique_ptr<tensorflow::RandomAccessFile> file; 99 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file)); 100 101 tensorflow::StringPiece data; 102 TF_RETURN_IF_ERROR(file->Read(0, file_size, &data, &(contents)[0])); 103 if (data.size() != file_size) { 104 return tensorflow::errors::DataLoss("Truncated read of '", filename, 105 "' expected ", file_size, " got ", 106 data.size()); 107 } 108 output->scalar<string>()() = data.ToString(); 109 return Status::OK(); 110} 111 112// Given an image file name, read in the data, try to decode it as an image, 113// resize it to the requested size, and then scale the values as desired. 
114Status ReadTensorFromImageFile(const string& file_name, const int input_height, 115 const int input_width, const float input_mean, 116 const float input_std, 117 std::vector<Tensor>* out_tensors) { 118 auto root = tensorflow::Scope::NewRootScope(); 119 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 120 121 string input_name = "file_reader"; 122 string output_name = "normalized"; 123 124 // read file_name into a tensor named input 125 Tensor input(tensorflow::DT_STRING, tensorflow::TensorShape()); 126 TF_RETURN_IF_ERROR( 127 ReadEntireFile(tensorflow::Env::Default(), file_name, &input)); 128 129 // use a placeholder to read input data 130 auto file_reader = 131 Placeholder(root.WithOpName("input"), tensorflow::DataType::DT_STRING); 132 133 std::vector<std::pair<string, tensorflow::Tensor>> inputs = { 134 {"input", input}, 135 }; 136 137 // Now try to figure out what kind of file it is and decode it. 138 const int wanted_channels = 3; 139 tensorflow::Output image_reader; 140 if (tensorflow::StringPiece(file_name).ends_with(".png")) { 141 image_reader = DecodePng(root.WithOpName("png_reader"), file_reader, 142 DecodePng::Channels(wanted_channels)); 143 } else if (tensorflow::StringPiece(file_name).ends_with(".gif")) { 144 // gif decoder returns 4-D tensor, remove the first dim 145 image_reader = 146 Squeeze(root.WithOpName("squeeze_first_dim"), 147 DecodeGif(root.WithOpName("gif_reader"), file_reader)); 148 } else if (tensorflow::StringPiece(file_name).ends_with(".bmp")) { 149 image_reader = DecodeBmp(root.WithOpName("bmp_reader"), file_reader); 150 } else { 151 // Assume if it's neither a PNG nor a GIF then it must be a JPEG. 152 image_reader = DecodeJpeg(root.WithOpName("jpeg_reader"), file_reader, 153 DecodeJpeg::Channels(wanted_channels)); 154 } 155 // Now cast the image data to float so we can do normal math on it. 
156 auto float_caster = 157 Cast(root.WithOpName("float_caster"), image_reader, tensorflow::DT_FLOAT); 158 // The convention for image ops in TensorFlow is that all images are expected 159 // to be in batches, so that they're four-dimensional arrays with indices of 160 // [batch, height, width, channel]. Because we only have a single image, we 161 // have to add a batch dimension of 1 to the start with ExpandDims(). 162 auto dims_expander = ExpandDims(root, float_caster, 0); 163 // Bilinearly resize the image to fit the required dimensions. 164 auto resized = ResizeBilinear( 165 root, dims_expander, 166 Const(root.WithOpName("size"), {input_height, input_width})); 167 // Subtract the mean and divide by the scale. 168 Div(root.WithOpName(output_name), Sub(root, resized, {input_mean}), 169 {input_std}); 170 171 // This runs the GraphDef network definition that we've just constructed, and 172 // returns the results in the output tensor. 173 tensorflow::GraphDef graph; 174 TF_RETURN_IF_ERROR(root.ToGraphDef(&graph)); 175 176 std::unique_ptr<tensorflow::Session> session( 177 tensorflow::NewSession(tensorflow::SessionOptions())); 178 TF_RETURN_IF_ERROR(session->Create(graph)); 179 TF_RETURN_IF_ERROR(session->Run({inputs}, {output_name}, {}, out_tensors)); 180 return Status::OK(); 181} 182 183// Reads a model graph definition from disk, and creates a session object you 184// can use to run it. 
185Status LoadGraph(const string& graph_file_name, 186 std::unique_ptr<tensorflow::Session>* session) { 187 tensorflow::GraphDef graph_def; 188 Status load_graph_status = 189 ReadBinaryProto(tensorflow::Env::Default(), graph_file_name, &graph_def); 190 if (!load_graph_status.ok()) { 191 return tensorflow::errors::NotFound("Failed to load compute graph at '", 192 graph_file_name, "'"); 193 } 194 session->reset(tensorflow::NewSession(tensorflow::SessionOptions())); 195 Status session_create_status = (*session)->Create(graph_def); 196 if (!session_create_status.ok()) { 197 return session_create_status; 198 } 199 return Status::OK(); 200} 201 202// Analyzes the output of the Inception graph to retrieve the highest scores and 203// their positions in the tensor, which correspond to categories. 204Status GetTopLabels(const std::vector<Tensor>& outputs, int how_many_labels, 205 Tensor* indices, Tensor* scores) { 206 auto root = tensorflow::Scope::NewRootScope(); 207 using namespace ::tensorflow::ops; // NOLINT(build/namespaces) 208 209 string output_name = "top_k"; 210 TopK(root.WithOpName(output_name), outputs[0], how_many_labels); 211 // This runs the GraphDef network definition that we've just constructed, and 212 // returns the results in the output tensors. 213 tensorflow::GraphDef graph; 214 TF_RETURN_IF_ERROR(root.ToGraphDef(&graph)); 215 216 std::unique_ptr<tensorflow::Session> session( 217 tensorflow::NewSession(tensorflow::SessionOptions())); 218 TF_RETURN_IF_ERROR(session->Create(graph)); 219 // The TopK node returns two outputs, the scores and their original indices, 220 // so we have to append :0 and :1 to specify them both. 
221 std::vector<Tensor> out_tensors; 222 TF_RETURN_IF_ERROR(session->Run({}, {output_name + ":0", output_name + ":1"}, 223 {}, &out_tensors)); 224 *scores = out_tensors[0]; 225 *indices = out_tensors[1]; 226 return Status::OK(); 227} 228 229// Given the output of a model run, and the name of a file containing the labels 230// this prints out the top five highest-scoring values. 231Status PrintTopLabels(const std::vector<Tensor>& outputs, 232 const string& labels_file_name) { 233 std::vector<string> labels; 234 size_t label_count; 235 Status read_labels_status = 236 ReadLabelsFile(labels_file_name, &labels, &label_count); 237 if (!read_labels_status.ok()) { 238 LOG(ERROR) << read_labels_status; 239 return read_labels_status; 240 } 241 const int how_many_labels = std::min(5, static_cast<int>(label_count)); 242 Tensor indices; 243 Tensor scores; 244 TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores)); 245 tensorflow::TTypes<float>::Flat scores_flat = scores.flat<float>(); 246 tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>(); 247 for (int pos = 0; pos < how_many_labels; ++pos) { 248 const int label_index = indices_flat(pos); 249 const float score = scores_flat(pos); 250 LOG(INFO) << labels[label_index] << " (" << label_index << "): " << score; 251 } 252 return Status::OK(); 253} 254 255// This is a testing function that returns whether the top label index is the 256// one that's expected. 
// Sets *is_expected to whether the single highest-scoring label in |outputs|
// has index |expected|. Logs an error (but still returns OK) on a mismatch;
// a non-OK status is only returned if the TopK graph itself fails to run.
Status CheckTopLabel(const std::vector<Tensor>& outputs, int expected,
                     bool* is_expected) {
  *is_expected = false;
  Tensor indices;
  Tensor scores;
  // We only need the single top label for the self-test.
  const int how_many_labels = 1;
  TF_RETURN_IF_ERROR(GetTopLabels(outputs, how_many_labels, &indices, &scores));
  tensorflow::TTypes<int32>::Flat indices_flat = indices.flat<int32>();
  if (indices_flat(0) != expected) {
    LOG(ERROR) << "Expected label #" << expected << " but got #"
               << indices_flat(0);
    *is_expected = false;
  } else {
    *is_expected = true;
  }
  return Status::OK();
}

int main(int argc, char* argv[]) {
  // These are the command-line flags the program can understand.
  // They define where the graph and input data is located, and what kind of
  // input the model expects. If you train your own model, or use something
  // other than inception_v3, then you'll need to update these.
  string image = "tensorflow/examples/label_image/data/grace_hopper.jpg";
  string graph =
      "tensorflow/examples/label_image/data/inception_v3_2016_08_28_frozen.pb";
  string labels =
      "tensorflow/examples/label_image/data/imagenet_slim_labels.txt";
  int32 input_width = 299;
  int32 input_height = 299;
  float input_mean = 0;
  float input_std = 255;
  string input_layer = "input";
  string output_layer = "InceptionV3/Predictions/Reshape_1";
  bool self_test = false;
  string root_dir = "";
  std::vector<Flag> flag_list = {
      Flag("image", &image, "image to be processed"),
      Flag("graph", &graph, "graph to be executed"),
      Flag("labels", &labels, "name of file containing labels"),
      Flag("input_width", &input_width, "resize image to this width in pixels"),
      Flag("input_height", &input_height,
           "resize image to this height in pixels"),
      Flag("input_mean", &input_mean, "scale pixel values to this mean"),
      Flag("input_std", &input_std, "scale pixel values to this std deviation"),
      Flag("input_layer", &input_layer, "name of input layer"),
      Flag("output_layer", &output_layer, "name of output layer"),
      Flag("self_test", &self_test, "run a self test"),
      Flag("root_dir", &root_dir,
           "interpret image and graph file names relative to this directory"),
  };
  string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
  if (!parse_result) {
    LOG(ERROR) << usage;
    return -1;
  }

  // We need to call this to set up global state for TensorFlow.
  tensorflow::port::InitMain(argv[0], &argc, &argv);
  // Flags::Parse removed everything it recognized, so any leftover argument
  // is one we don't understand.
  if (argc > 1) {
    LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
    return -1;
  }

  // First we load and initialize the model.
  std::unique_ptr<tensorflow::Session> session;
  string graph_path = tensorflow::io::JoinPath(root_dir, graph);
  Status load_graph_status = LoadGraph(graph_path, &session);
  if (!load_graph_status.ok()) {
    LOG(ERROR) << load_graph_status;
    return -1;
  }

  // Get the image from disk as a float array of numbers, resized and normalized
  // to the specifications the main graph expects.
  std::vector<Tensor> resized_tensors;
  string image_path = tensorflow::io::JoinPath(root_dir, image);
  Status read_tensor_status =
      ReadTensorFromImageFile(image_path, input_height, input_width, input_mean,
                              input_std, &resized_tensors);
  if (!read_tensor_status.ok()) {
    LOG(ERROR) << read_tensor_status;
    return -1;
  }
  const Tensor& resized_tensor = resized_tensors[0];

  // Actually run the image through the model.
  std::vector<Tensor> outputs;
  Status run_status = session->Run({{input_layer, resized_tensor}},
                                   {output_layer}, {}, &outputs);
  if (!run_status.ok()) {
    LOG(ERROR) << "Running model failed: " << run_status;
    return -1;
  }

  // This is for automated testing to make sure we get the expected result with
  // the default settings. We know that label 653 (military uniform) should be
  // the top label for the Admiral Hopper image.
  if (self_test) {
    bool expected_matches;
    Status check_status = CheckTopLabel(outputs, 653, &expected_matches);
    if (!check_status.ok()) {
      LOG(ERROR) << "Running check failed: " << check_status;
      return -1;
    }
    if (!expected_matches) {
      LOG(ERROR) << "Self-test failed!";
      return -1;
    }
  }

  // Do something interesting with the results we've generated.
  Status print_status = PrintTopLabels(outputs, labels);
  if (!print_status.ok()) {
    LOG(ERROR) << "Running print failed: " << print_status;
    return -1;
  }

  return 0;
}