1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/contrib/session_bundle/session_bundle.h"
17
18#include <string>
19#include <utility>
20#include <vector>
21
22#include "google/protobuf/any.pb.h"
23#include "tensorflow/contrib/session_bundle/manifest.pb.h"
24#include "tensorflow/core/framework/graph.pb.h"
25#include "tensorflow/core/framework/graph_def_util.h"
26#include "tensorflow/core/framework/tensor.h"
27#include "tensorflow/core/framework/tensor_shape.h"
28#include "tensorflow/core/framework/tensor_types.h"
29#include "tensorflow/core/framework/types.pb.h"
30#include "tensorflow/core/lib/core/errors.h"
31#include "tensorflow/core/lib/core/status.h"
32#include "tensorflow/core/lib/io/path.h"
33#include "tensorflow/core/lib/monitoring/counter.h"
34#include "tensorflow/core/platform/env.h"
35#include "tensorflow/core/platform/protobuf_internal.h"
36#include "tensorflow/core/platform/types.h"
37#include "tensorflow/core/protobuf/meta_graph.pb.h"
38#include "tensorflow/core/protobuf/saver.pb.h"
39#include "tensorflow/core/public/session_options.h"
40#include "tensorflow/core/util/tensor_bundle/naming.h"
41
42namespace tensorflow {
43namespace serving {
44namespace {
45
46auto* load_attempt_count = monitoring::Counter<2>::New(
47    "/tensorflow/contrib/session_bundle/load_attempt_count",
48    "The number of times a SessionBundle was requested to be loaded.",
49    "model_path", "status");
50auto* load_latency = monitoring::Counter<1>::New(
51    "/tensorflow/contrib/session_bundle/load_latency",
52    "Latency in microseconds for SessionBundles that were successfully loaded.",
53    "model_path");
54constexpr char kLoadAttemptFail[] = "fail";
55constexpr char kLoadAttemptSuccess[] = "success";
56
57// Create a session using the given options and load the graph.
58Status CreateSessionFromGraphDef(const SessionOptions& options,
59                                 const GraphDef& graph,
60                                 std::unique_ptr<Session>* session) {
61  session->reset(NewSession(options));
62  return (*session)->Create(graph);
63}
64
65Status GetMetaGraphDefFromExport(const StringPiece export_dir,
66                                 MetaGraphDef* meta_graph_def) {
67  const string meta_graph_def_path =
68      io::JoinPath(export_dir, kMetaGraphDefFilename);
69  return ReadBinaryProto(Env::Default(), meta_graph_def_path, meta_graph_def);
70}
71
72// Creates a string tensor.
73Tensor CreateStringTensor(const string& value) {
74  Tensor tensor(DT_STRING, TensorShape({}));
75  tensor.scalar<string>()() = value;
76  return tensor;
77}
78
79// Adds Assets related tensors (assets_dir and asset files) to the inputs.
80void AddAssetsTensorsToInputs(const StringPiece export_dir,
81                              const std::vector<AssetFile>& asset_files,
82                              std::vector<std::pair<string, Tensor>>* inputs) {
83  if (asset_files.empty()) {
84    return;
85  }
86  for (auto& asset : asset_files) {
87    Tensor assets_file_tensor = CreateStringTensor(
88        io::JoinPath(export_dir, kAssetsDirectory, asset.filename()));
89    inputs->push_back(
90        {asset.tensor_binding().tensor_name(), assets_file_tensor});
91  }
92}
93
94// Historically, model exporter(exporter.py) takes only saver with sharded=True,
95// and therefore always exports checkpoint in pattern file names.  In practice,
96// instead of training from scratch and export directly, we usually want to
97// restore from existing checkpoints and then export directly.  To support such
98// case, model exporter now supports reusing saver object restored from existing
99// checkpoint, that may have sharded=False - it will then export checkpoint file
100// in plain file name.  This method is to support models exported by both types
101// of saver object.  The change is backward-compatible, therefore no changes are
102// needed for existing model exports.
103//
104// Checkpoint v2 support: Variables exported using tf-exporter in the checkpoint
105// v2 format will have export.index and export.data-?????-of-????? files as
106// opposed to just an export or export-?????-of-????? file. The V2 save/restore
107// code accepts a filename prefix and assumes both prefix.index and
108// prefix.data-* are present in the filesystem. So if we see export.index
109// present in the export_dir, we know the export is in V2 format and we return
110// <export_dir>/export as this prefix.
111string GetVariablesFilename(const StringPiece export_dir) {
112  const char kVariablesFilename[] = "export";
113  const string kVariablesIndexFilename = MetaFilename("export");  // V2 ckpts
114  const char kVariablesFilenamePattern[] = "export-\?\?\?\?\?-of-\?\?\?\?\?";
115  if (Env::Default()
116          ->FileExists(io::JoinPath(export_dir, kVariablesFilename))
117          .ok() ||
118      // This works for the case of V2 because the variables filename is taken
119      // as a prefix in the save/restore abstraction, and the index and actual
120      // variables are meant to be present as prefix.index and
121      // prefix.data-?????-of-?????.
122      Env::Default()
123          ->FileExists(io::JoinPath(export_dir, kVariablesIndexFilename))
124          .ok()) {
125    return io::JoinPath(export_dir, kVariablesFilename);
126  } else {
127    return io::JoinPath(export_dir, kVariablesFilenamePattern);
128  }
129}
130
131Status RunRestoreOp(const RunOptions& run_options, const StringPiece export_dir,
132                    const std::vector<AssetFile>& asset_files,
133                    const StringPiece restore_op_name,
134                    const StringPiece variables_filename_const_op_name,
135                    Session* session) {
136  LOG(INFO) << "Running restore op for SessionBundle: " << restore_op_name
137            << ", " << variables_filename_const_op_name;
138  Tensor variables_tensor =
139      CreateStringTensor(GetVariablesFilename(export_dir));
140  std::vector<std::pair<string, Tensor>> inputs = {
141      {variables_filename_const_op_name.ToString(), variables_tensor}};
142  AddAssetsTensorsToInputs(export_dir, asset_files, &inputs);
143  RunMetadata run_metadata;
144  return session->Run(run_options, inputs, {}, {restore_op_name.ToString()},
145                      nullptr /* outputs */, &run_metadata);
146}
147
148Status RunInitOp(const RunOptions& run_options, const StringPiece export_dir,
149                 const std::vector<AssetFile>& asset_files,
150                 const StringPiece init_op_name, Session* session) {
151  LOG(INFO) << "Running init op for SessionBundle";
152  std::vector<std::pair<string, Tensor>> inputs;
153  AddAssetsTensorsToInputs(export_dir, asset_files, &inputs);
154  RunMetadata run_metadata;
155  return session->Run(run_options, inputs, {}, {init_op_name.ToString()},
156                      nullptr /* outputs */, &run_metadata);
157}
158
159Status LoadSessionBundleFromPathUsingRunOptionsInternal(
160    const SessionOptions& options, const RunOptions& run_options,
161    const StringPiece export_dir, SessionBundle* const bundle) {
162  LOG(INFO) << "Attempting to load a SessionBundle from: " << export_dir;
163  LOG(INFO) << "Using RunOptions: " << DebugStringIfAvailable(run_options);
164  TF_RETURN_IF_ERROR(
165      GetMetaGraphDefFromExport(export_dir, &(bundle->meta_graph_def)));
166
167  // Deprecated SessionBundle models may fail to load because newly added
168  // attributes are not added to the Graph in the default Session initialization
169  // flow. Add an explicit call here when first loading the graph from disk.
170  TF_RETURN_IF_ERROR(
171      AddDefaultAttrsToGraphDef(bundle->meta_graph_def.mutable_graph_def(),
172                                *OpRegistry::Global(), 0 /* node_offset */));
173
174  const auto& collection_def_map = bundle->meta_graph_def.collection_def();
175  const auto graph_it = bundle->meta_graph_def.collection_def().find(kGraphKey);
176  if (graph_it != collection_def_map.end()) {
177    const CollectionDef& graph_collection_def = graph_it->second;
178    // Use serving graph_def in MetaGraphDef collection_def.
179    if (graph_collection_def.any_list().value_size() != 1) {
180      return errors::FailedPrecondition(
181          "Expected exactly one serving GraphDef in : ", export_dir);
182    }
183    const auto& any = graph_collection_def.any_list().value(0);
184    GraphDef graph_def;
185    TF_RETURN_IF_ERROR(ParseAny(any, &graph_def, "tensorflow.GraphDef"));
186    TF_RETURN_IF_ERROR(
187        CreateSessionFromGraphDef(options, graph_def, &bundle->session));
188  } else {
189    // Fallback to use the graph_def in the MetaGraphDef.
190    const GraphDef& graph_def = bundle->meta_graph_def.graph_def();
191    TF_RETURN_IF_ERROR(
192        CreateSessionFromGraphDef(options, graph_def, &bundle->session));
193  }
194
195  std::vector<AssetFile> asset_files;
196  const auto assets_it = collection_def_map.find(kAssetsKey);
197  if (assets_it != collection_def_map.end()) {
198    const auto& any_assets = assets_it->second.any_list().value();
199    for (const auto& any_asset : any_assets) {
200      AssetFile asset_file;
201      TF_RETURN_IF_ERROR(
202          ParseAny(any_asset, &asset_file, "tensorflow.serving.AssetFile"));
203      asset_files.push_back(asset_file);
204    }
205  }
206
207  TF_RETURN_IF_ERROR(
208      RunRestoreOp(run_options, export_dir, asset_files,
209                   bundle->meta_graph_def.saver_def().restore_op_name(),
210                   bundle->meta_graph_def.saver_def().filename_tensor_name(),
211                   bundle->session.get()));
212
213  const auto init_op_it = collection_def_map.find(kInitOpKey);
214  if (init_op_it != collection_def_map.end()) {
215    if (init_op_it->second.node_list().value_size() != 1) {
216      return errors::FailedPrecondition(strings::StrCat(
217          "Expected exactly one serving init op in : ", export_dir));
218    }
219    TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, asset_files,
220                                 init_op_it->second.node_list().value(0),
221                                 bundle->session.get()));
222  }
223
224  return Status::OK();
225}
226
227}  // namespace
228
229Status LoadSessionBundleFromPath(const SessionOptions& options,
230                                 const StringPiece export_dir,
231                                 SessionBundle* const bundle) {
232  TF_RETURN_IF_ERROR(LoadSessionBundleFromPathUsingRunOptions(
233      options, RunOptions(), export_dir, bundle));
234  return Status::OK();
235}
236
237Status LoadSessionBundleFromPathUsingRunOptions(const SessionOptions& options,
238                                                const RunOptions& run_options,
239                                                const StringPiece export_dir,
240                                                SessionBundle* const bundle) {
241  const uint64 start_microseconds = Env::Default()->NowMicros();
242  const Status status = LoadSessionBundleFromPathUsingRunOptionsInternal(
243      options, run_options, export_dir, bundle);
244
245  const uint64 load_latency_microsecs = [&]() -> uint64 {
246    const uint64 end_microseconds = Env::Default()->NowMicros();
247    // Avoid clock skew.
248    if (end_microseconds < start_microseconds) return 0;
249    return end_microseconds - start_microseconds;
250  }();
251  auto log_and_count = [&](const string& status_str) {
252    LOG(INFO) << "Loading SessionBundle: " << status_str << ". Took "
253              << load_latency_microsecs << " microseconds.";
254    load_attempt_count->GetCell(export_dir.ToString(), status_str)
255        ->IncrementBy(1);
256  };
257  if (status.ok()) {
258    log_and_count(kLoadAttemptSuccess);
259  } else {
260    log_and_count(kLoadAttemptFail);
261  }
262  load_latency->GetCell(export_dir.ToString())
263      ->IncrementBy(load_latency_microsecs);
264  return status;
265}
266
267bool IsPossibleExportDirectory(const StringPiece directory) {
268  const string meta_graph_def_path =
269      io::JoinPath(directory, kMetaGraphDefFilename);
270  return Env::Default()->FileExists(meta_graph_def_path).ok();
271}
272
273}  // namespace serving
274}  // namespace tensorflow
275