1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/contrib/session_bundle/bundle_shim.h"
17
18#include "tensorflow/cc/saved_model/loader.h"
19#include "tensorflow/cc/saved_model/signature_constants.h"
20#include "tensorflow/contrib/session_bundle/manifest.pb.h"
21#include "tensorflow/contrib/session_bundle/session_bundle.h"
22#include "tensorflow/contrib/session_bundle/signature.h"
23#include "tensorflow/core/graph/graph_constructor.h"
24#include "tensorflow/core/lib/core/errors.h"
25#include "tensorflow/core/lib/core/status.h"
26#include "tensorflow/core/lib/core/stringpiece.h"
27#include "tensorflow/core/protobuf/meta_graph.pb.h"
28#include "tensorflow/core/public/session.h"
29#include "tensorflow/core/public/session_options.h"
30
31namespace tensorflow {
32namespace serving {
33namespace {
34///////////////////////////////////////////////////////////////////////////////
35// Helper functions to check Signature type.
36
37bool IsClassificationSignature(const Signature& signature) {
38  return signature.type_case() == Signature::kClassificationSignature;
39}
40
41bool IsRegressionSignature(const Signature& signature) {
42  return signature.type_case() == Signature::kRegressionSignature;
43}
44
45///////////////////////////////////////////////////////////////////////////////
46// Helper functions to build `Classification`, `Regression` and `Predict`
47// SignatureDefs.
48
49SignatureDef BuildRegressionSignatureDef(
50    const RegressionSignature& regression_signature,
51    const std::unordered_map<string, DataType>& tensor_name_to_dtype) {
52  SignatureDef signature_def;
53  signature_def.set_method_name(kRegressMethodName);
54  internal::AddInputToSignatureDef(regression_signature.input().tensor_name(),
55                                   tensor_name_to_dtype, kRegressInputs,
56                                   &signature_def);
57  internal::AddOutputToSignatureDef(regression_signature.output().tensor_name(),
58                                    tensor_name_to_dtype, kRegressOutputs,
59                                    &signature_def);
60  return signature_def;
61}
62
63SignatureDef BuildClassificationSignatureDef(
64    const ClassificationSignature& classification_signature,
65    const std::unordered_map<string, DataType>& tensor_name_to_dtype) {
66  SignatureDef signature_def;
67  signature_def.set_method_name(kClassifyMethodName);
68  internal::AddInputToSignatureDef(
69      classification_signature.input().tensor_name(), tensor_name_to_dtype,
70      kClassifyInputs, &signature_def);
71  internal::AddOutputToSignatureDef(
72      classification_signature.classes().tensor_name(), tensor_name_to_dtype,
73      kClassifyOutputClasses, &signature_def);
74  internal::AddOutputToSignatureDef(
75      classification_signature.scores().tensor_name(), tensor_name_to_dtype,
76      kClassifyOutputScores, &signature_def);
77  return signature_def;
78}
79
80Status MaybeBuildPredictSignatureDef(
81    const std::unordered_map<string, DataType>& tensor_name_to_dtype,
82    MetaGraphDef* meta_graph_def) {
83  Signature input_signature, output_signature;
84  // Ensure that named signatures corresponding to `inputs` and `outputs` keys
85  // exist.
86  if (!GetNamedSignature(kPredictInputs, *meta_graph_def, &input_signature)
87           .ok() ||
88      !GetNamedSignature(kPredictOutputs, *meta_graph_def, &output_signature)
89           .ok()) {
90    return Status(error::Code::INVALID_ARGUMENT,
91                  "Named signatures can only be up-converted if entries "
92                  "corresponding to both `inputs` and `outputs` exist.");
93  }
94  // Ensure the `inputs` and `outputs` named signatures are generic signatures.
95  if (input_signature.type_case() != Signature::TypeCase::kGenericSignature ||
96      output_signature.type_case() != Signature::TypeCase::kGenericSignature) {
97    return Status(error::Code::INVALID_ARGUMENT,
98                  "Named signatures corresponding to `inputs` and `outputs` "
99                  "can only be up-converted if they are GenericSignatures.");
100  }
101  SignatureDef signature_def;
102  signature_def.set_method_name(kPredictMethodName);
103  // Add map entries from the `inputs` generic signature to the input map in the
104  // signature def.
105  for (const auto& map_entry : input_signature.generic_signature().map()) {
106    internal::AddInputToSignatureDef(map_entry.second.tensor_name(),
107                                     tensor_name_to_dtype, map_entry.first,
108                                     &signature_def);
109  }
110  // Add map entries from the `outputs` generic signature to the output map in
111  // the signature def.
112  for (const auto& map_entry : output_signature.generic_signature().map()) {
113    internal::AddOutputToSignatureDef(map_entry.second.tensor_name(),
114                                      tensor_name_to_dtype, map_entry.first,
115                                      &signature_def);
116  }
117  // Add the constructed signature def to the signature def map of the meta
118  // graph def. Use the default key if it isn't already in use.
119  const bool already_has_default_signature =
120      meta_graph_def->signature_def().find(kDefaultServingSignatureDefKey) !=
121      meta_graph_def->signature_def().end();
122  const string signature_def_key =
123      already_has_default_signature
124          ? strings::StrCat(kDefaultServingSignatureDefKey, "_from_named")
125          : kDefaultServingSignatureDefKey;
126  (*meta_graph_def->mutable_signature_def())[signature_def_key] = signature_def;
127  return Status::OK();
128}
129
130Status LoadSavedModelFromLegacySessionBundlePath(
131    const SessionOptions& session_options, const RunOptions& run_options,
132    const StringPiece session_bundle_export_dir,
133    SavedModelBundle* saved_model_bundle) {
134  if (session_bundle_export_dir.empty()) {
135    return Status(error::Code::NOT_FOUND, "Export directory path is empty.");
136  }
137  if (!IsPossibleExportDirectory(session_bundle_export_dir)) {
138    return Status(
139        error::Code::NOT_FOUND,
140        "Export directory does not contain a valid SessionBundle export.");
141  }
142
143  // Build the session-bundle.
144  SessionBundle session_bundle;
145  TF_RETURN_IF_ERROR(LoadSessionBundleFromPathUsingRunOptions(
146      session_options, run_options, session_bundle_export_dir,
147      &session_bundle));
148
149  // Convert the session-bundle to a saved-model-bundle.
150  return internal::ConvertSessionBundleToSavedModelBundle(session_bundle,
151                                                          saved_model_bundle);
152}
153
154///////////////////////////////////////////////////////////////////////////////
155// Helper functions to convert `Default` and `Named` signatures to
156// SignatureDefs.
157
158// Up-conversion of default signatures is supported for classification and
159// regression.
160Status ConvertDefaultSignatureToSignatureDef(
161    const Signatures& signatures,
162    const std::unordered_map<string, DataType>& tensor_name_to_dtype,
163    MetaGraphDef* meta_graph_def) {
164  if (!signatures.has_default_signature()) {
165    return Status::OK();
166  }
167  const bool already_has_default_signature =
168      meta_graph_def->signature_def().find(kDefaultServingSignatureDefKey) !=
169      meta_graph_def->signature_def().end();
170  if (already_has_default_signature) {
171    return Status(error::Code::ALREADY_EXISTS,
172                  strings::StrCat(
173                      "Default signature cannot be up-converted since ",
174                      kDefaultServingSignatureDefKey, " key already exists."));
175  }
176  const Signature& signature = signatures.default_signature();
177  if (IsRegressionSignature(signature)) {
178    (*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] =
179        BuildRegressionSignatureDef(signature.regression_signature(),
180                                    tensor_name_to_dtype);
181  } else if (IsClassificationSignature(signature)) {
182    (*meta_graph_def->mutable_signature_def())[kDefaultServingSignatureDefKey] =
183        BuildClassificationSignatureDef(signature.classification_signature(),
184                                        tensor_name_to_dtype);
185  } else {
186    LOG(WARNING) << "Default signature up-conversion to SignatureDef is only "
187                    "supported for `Classification` and `Regression`. Could "
188                    "not up-convert signature: "
189                 << signature.DebugString()
190                 << ". (If using SessionRun with the SessionBundle export "
191                    "format please ignore this warning.)";
192  }
193  return Status::OK();
194}
195
196Status ConvertNamedSignaturesToSignatureDef(
197    const Signatures& signatures,
198    const std::unordered_map<string, DataType>& tensor_name_to_dtype,
199    MetaGraphDef* meta_graph_def) {
200  if (signatures.named_signatures().empty()) {
201    return Status::OK();
202  }
203  // Check for a Predict signature for up-conversion.
204  Status predict_signature_def_status =
205      MaybeBuildPredictSignatureDef(tensor_name_to_dtype, meta_graph_def);
206  for (const auto& it_named_signature : signatures.named_signatures()) {
207    const string key = it_named_signature.first;
208    // If a Predict SignatureDef was successfully constructed, skip the entries
209    // corresponding to `inputs` and `outputs`.
210    if (predict_signature_def_status.ok()) {
211      if (key == kPredictInputs || key == kPredictOutputs) {
212        continue;
213      }
214    }
215    const Signature signature = it_named_signature.second;
216    if (IsRegressionSignature(signature)) {
217      (*meta_graph_def->mutable_signature_def())[key] =
218          BuildRegressionSignatureDef(signature.regression_signature(),
219                                      tensor_name_to_dtype);
220    } else if (IsClassificationSignature(signature)) {
221      (*meta_graph_def->mutable_signature_def())[key] =
222          BuildClassificationSignatureDef(signature.classification_signature(),
223                                          tensor_name_to_dtype);
224    } else {
225      LOG(WARNING)
226          << "Named signature up-conversion to SignatureDef is only supported "
227             "for `Classification`, `Regression` or if two `GenericSignatures` "
228             "signatures  called `inputs` and `outputs` exist, corresponding "
229             "to the `Prediction` API. Could not up-convert signature: "
230          << signature.DebugString();
231    }
232  }
233  return Status::OK();
234}
235
236}  // namespace
237
238namespace internal {
239///////////////////////////////////////////////////////////////////////////////
240// Helper functions to populate SignatureDef fields.
241
242// Adds an entry to the `inputs` map of the supplied SignatureDef.
243void AddInputToSignatureDef(
244    const string& tensor_name,
245    const std::unordered_map<string, DataType>& tensor_name_to_dtype,
246    const string& input_key, SignatureDef* signature_def) {
247  if (tensor_name.empty()) {
248    LOG(WARNING) << "Tensor name not provided. Cannot add TensorInfo to "
249                    "SignatureDef inputs.";
250    return;
251  }
252  // Extract the tensor-name in case the supplied string is a tensor-reference.
253  // Example: Extract "x" from "x:0".
254  std::size_t pos = tensor_name.find(":");
255  const string key =
256      (pos != std::string::npos) ? tensor_name.substr(0, pos) : tensor_name;
257  const auto it_tensor_info = tensor_name_to_dtype.find(key);
258  TensorInfo tensor_info;
259  tensor_info.set_name(tensor_name);
260  if (it_tensor_info != tensor_name_to_dtype.end()) {
261    tensor_info.set_dtype(it_tensor_info->second);
262  } else {
263    LOG(WARNING)
264        << "No dtype found for tensor with name: " << tensor_name << ". "
265        << "Building TensorInfo with only name for SignatureDef inputs. "
266        << "Downstream functionality including validation may be "
267        << "impacted.";
268  }
269  (*signature_def->mutable_inputs())[input_key] = tensor_info;
270}
271
272// Adds an entry to the `outputs` map of the supplied SignatureDef.
273void AddOutputToSignatureDef(
274    const string& tensor_name,
275    const std::unordered_map<string, DataType>& tensor_name_to_dtype,
276    const string& output_key, SignatureDef* signature_def) {
277  if (tensor_name.empty()) {
278    LOG(WARNING) << "Tensor name not provided. Cannot add TensorInfo to "
279                    "SignatureDef outputs.";
280    return;
281  }
282  // Extract the tensor-name in case the supplied string is a tensor-reference.
283  // Example: Extract "x" from "x:0".
284  std::size_t pos = tensor_name.find(":");
285  const string key =
286      (pos != std::string::npos) ? tensor_name.substr(0, pos) : tensor_name;
287  const auto it_tensor_info = tensor_name_to_dtype.find(key);
288  TensorInfo tensor_info;
289  tensor_info.set_name(tensor_name);
290  if (it_tensor_info != tensor_name_to_dtype.end()) {
291    tensor_info.set_dtype(it_tensor_info->second);
292  } else {
293    LOG(WARNING)
294        << "No dtype found for tensor with name: " << tensor_name << ". "
295        << "Building TensorInfo with only name for SignatureDef outputs."
296        << " Downstream functionality including validation may be "
297        << "impacted.";
298  }
299  (*signature_def->mutable_outputs())[output_key] = tensor_info;
300}
301
302// Builds a map from tensor name to the corresponding datatype, by parsing the
303// MetaGraphDef.
304Status BuildTensorNameToDtypeMap(
305    const MetaGraphDef& meta_graph_def,
306    std::unordered_map<string, DataType>* tensor_name_to_dtype) {
307  GraphConstructorOptions opts;
308  Graph graph(OpRegistry::Global());
309  TF_RETURN_IF_ERROR(
310      ConvertGraphDefToGraph(opts, meta_graph_def.graph_def(), &graph));
311  for (Node* node : graph.nodes()) {
312    for (auto dt : node->output_types()) {
313      tensor_name_to_dtype->insert(std::make_pair(node->name(), dt));
314    }
315  }
316  return Status::OK();
317}
318
319// Converts SessionBundle signatures to SavedModel signature-defs.
320Status ConvertSignaturesToSignatureDefs(MetaGraphDef* meta_graph_def) {
321  Signatures signatures;
322  GetSignatures(*meta_graph_def, &signatures).IgnoreError();
323
324  // Build a map of tensor-names to the corresponding tensor-info with `name`
325  // and `dtype` fields.
326  std::unordered_map<string, DataType> tensor_name_to_dtype;
327  TF_RETURN_IF_ERROR(
328      BuildTensorNameToDtypeMap(*meta_graph_def, &tensor_name_to_dtype));
329
330  TF_RETURN_IF_ERROR(ConvertDefaultSignatureToSignatureDef(
331      signatures, tensor_name_to_dtype, meta_graph_def));
332  TF_RETURN_IF_ERROR(ConvertNamedSignaturesToSignatureDef(
333      signatures, tensor_name_to_dtype, meta_graph_def));
334  return Status::OK();
335}
336
337// Converts a SessionBundle to a SavedModelBundle.
338Status ConvertSessionBundleToSavedModelBundle(
339    SessionBundle& session_bundle, SavedModelBundle* saved_model_bundle) {
340  // Transfer ownership of the session from old to new.
341  saved_model_bundle->session = std::move(session_bundle.session);
342
343  // Copy the meta graph def from the SessionBundle to the SavedModelBundle.
344  saved_model_bundle->meta_graph_def = session_bundle.meta_graph_def;
345
346  // Convert signatures from session-bundle to signature-defs in
347  // saved-model-bundle.
348  return internal::ConvertSignaturesToSignatureDefs(
349      &saved_model_bundle->meta_graph_def);
350}
351
352}  // namespace internal
353
354Status LoadSessionBundleOrSavedModelBundle(
355    const SessionOptions& session_options, const RunOptions& run_options,
356    const string& export_dir,
357    const std::unordered_set<string>& saved_model_tags,
358    SavedModelBundle* saved_model_bundle) {
359  if (MaybeSavedModelDirectory(export_dir)) {
360    LOG(INFO)
361        << "Attempting to load native SavedModelBundle in bundle-shim from: "
362        << export_dir;
363    return LoadSavedModel(session_options, run_options, export_dir,
364                          saved_model_tags, saved_model_bundle);
365  } else if (IsPossibleExportDirectory(export_dir)) {
366    LOG(ERROR) << "Found possible SessionBundle in export directory. "
367                  "SessionBundle is deprecated. Use SavedModel instead.";
368    LOG(INFO) << "Attempting to up-convert SessionBundle to SavedModelBundle "
369                 "in bundle-shim from: "
370              << export_dir;
371    return LoadSavedModelFromLegacySessionBundlePath(
372        session_options, run_options, export_dir, saved_model_bundle);
373  }
374  return Status(
375      error::Code::NOT_FOUND,
376      strings::StrCat(
377          "Specified file path does not appear to contain a:\n"
378          "- Session bundle (should have a file called `export.meta`)\n"
379          "- or, SavedModel bundle (should have a file called "
380          "`saved_model.pb`)\n"
381          "Specified file path: ",
382          export_dir));
383}
384
385}  // namespace serving
386}  // namespace tensorflow
387