1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/contrib/session_bundle/bundle_shim.h"
17
18#include "google/protobuf/any.pb.h"
19#include "tensorflow/cc/saved_model/signature_constants.h"
20#include "tensorflow/cc/saved_model/tag_constants.h"
21#include "tensorflow/contrib/session_bundle/test_util.h"
22#include "tensorflow/core/example/example.pb.h"
23#include "tensorflow/core/example/feature.pb.h"
24#include "tensorflow/core/framework/tensor_testutil.h"
25#include "tensorflow/core/lib/core/status_test_util.h"
26#include "tensorflow/core/lib/io/path.h"
27#include "tensorflow/core/protobuf/meta_graph.pb.h"
28
29namespace tensorflow {
30namespace serving {
31namespace internal {
32namespace {
33
// Relative path of the half_plus_two SessionBundle export used by the tests
// below; resolved through test_util::TestSrcDirPath().
constexpr char kSessionBundlePath[] =
    "session_bundle/testdata/half_plus_two/00000123";
// Relative path of the half_plus_two SavedModel export used by the tests
// below; joined onto testing::TensorFlowSrcRoot().
constexpr char kSavedModelBundlePath[] =
    "cc/saved_model/testdata/half_plus_two/00000123";
38
39string MakeSerializedExample(float x) {
40  tensorflow::Example example;
41  auto* feature_map = example.mutable_features()->mutable_feature();
42  (*feature_map)["x"].mutable_float_list()->add_value(x);
43  return example.SerializeAsString();
44}
45
46void ValidateHalfPlusTwo(const SavedModelBundle& saved_model_bundle,
47                         const string& input_tensor_name,
48                         const string& output_tensor_name) {
49  // Validate the half plus two behavior.
50  std::vector<string> serialized_examples;
51  for (float x : {0, 1, 2, 3}) {
52    serialized_examples.push_back(MakeSerializedExample(x));
53  }
54  Tensor input = test::AsTensor<string>(serialized_examples, TensorShape({4}));
55
56  std::vector<Tensor> outputs;
57  TF_ASSERT_OK(saved_model_bundle.session->Run(
58      {{input_tensor_name, input}}, {output_tensor_name}, {}, &outputs));
59  ASSERT_EQ(outputs.size(), 1);
60  test::ExpectTensorEqual<float>(
61      outputs[0], test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1})));
62}
63
64void LoadAndValidateSavedModelBundle(const string& export_dir,
65                                     const std::unordered_set<string>& tags,
66                                     const string& signature_def_key) {
67  SessionOptions session_options;
68  RunOptions run_options;
69  SavedModelBundle saved_model_bundle;
70  TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle(
71      session_options, run_options, export_dir, tags, &saved_model_bundle));
72  const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def;
73  const auto& signature_def_map = meta_graph_def.signature_def();
74
75  const auto& regression_entry = signature_def_map.find(signature_def_key);
76  ASSERT_FALSE(regression_entry == signature_def_map.end());
77  SignatureDef regression_signature_def = regression_entry->second;
78
79  EXPECT_EQ(1, regression_signature_def.inputs_size());
80  ASSERT_FALSE(regression_signature_def.inputs().find(kRegressInputs) ==
81               regression_signature_def.inputs().end());
82  TensorInfo input_tensor_info =
83      regression_signature_def.inputs().find(kRegressInputs)->second;
84  EXPECT_EQ(1, regression_signature_def.outputs_size());
85  // Ensure the TensorInfo has dtype populated.
86  EXPECT_EQ(DT_STRING, input_tensor_info.dtype());
87
88  ASSERT_FALSE(regression_signature_def.outputs().find(kRegressOutputs) ==
89               regression_signature_def.outputs().end());
90  TensorInfo output_tensor_info =
91      regression_signature_def.outputs().find(kRegressOutputs)->second;
92  // Ensure the TensorInfo has dtype populated.
93  EXPECT_EQ(DT_FLOAT, output_tensor_info.dtype());
94  ValidateHalfPlusTwo(saved_model_bundle, input_tensor_info.name(),
95                      output_tensor_info.name());
96}
97
98// Helper function to validate that the SignatureDef found in the MetaGraphDef
99// with the provided key has the expected string representation.
100void ValidateSignatureDef(const MetaGraphDef& meta_graph_def, const string& key,
101                          const string& expected_string_signature_def) {
102  tensorflow::SignatureDef expected_signature;
103  CHECK(protobuf::TextFormat::ParseFromString(expected_string_signature_def,
104                                              &expected_signature));
105  auto iter = meta_graph_def.signature_def().find(key);
106  ASSERT_TRUE(iter != meta_graph_def.signature_def().end());
107  EXPECT_EQ(expected_signature.DebugString(), iter->second.DebugString());
108}
109
110// Checks that the input map in a signature def is populated correctly.
111TEST(BundleShimTest, AddInputToSignatureDef) {
112  SignatureDef signature_def;
113  const string tensor_name = "foo_tensor";
114  const string map_key = "foo_key";
115
116  // Build a map of tensor-name to dtype, for the unit-test.
117  std::unordered_map<string, DataType> tensor_name_to_dtype;
118  tensor_name_to_dtype[tensor_name] = tensorflow::DT_STRING;
119
120  AddInputToSignatureDef(tensor_name, tensor_name_to_dtype, map_key,
121                         &signature_def);
122  EXPECT_EQ(1, signature_def.inputs_size());
123  EXPECT_EQ(tensor_name, signature_def.inputs().find(map_key)->second.name());
124}
125
126// Checks that the output map in a signature def is populated correctly.
127TEST(BundleShimTest, AddOutputToSignatureDef) {
128  SignatureDef signature_def;
129  const string tensor_name = "foo_tensor";
130  const string map_key = "foo_key";
131
132  // Build a map of tensor-name to dtype, for the unit-test.
133  std::unordered_map<string, DataType> tensor_name_to_dtype;
134  tensor_name_to_dtype[tensor_name] = tensorflow::DT_STRING;
135
136  AddOutputToSignatureDef(tensor_name, tensor_name_to_dtype, map_key,
137                          &signature_def);
138  EXPECT_EQ(1, signature_def.outputs_size());
139  EXPECT_EQ(tensor_name, signature_def.outputs().find(map_key)->second.name());
140}
141
142// Checks that no signature defs are added if the default signature is missing.
143TEST(BundleShimTest, DefaultSignatureMissing) {
144  MetaGraphDef meta_graph_def;
145  // Signatures signatures;
146  TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
147  EXPECT_EQ(0, meta_graph_def.signature_def_size());
148}
149
150// Checks that no signature defs are added if the default signature is empty.
151TEST(BundleShimTest, DefaultSignatureEmpty) {
152  Signatures signatures;
153  signatures.mutable_default_signature();
154
155  MetaGraphDef meta_graph_def;
156  (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
157      .mutable_any_list()
158      ->add_value()
159      ->PackFrom(signatures);
160  TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
161  EXPECT_EQ(0, meta_graph_def.signature_def_size());
162}
163
164// Checks the conversion to signature def for a regression default signature.
165TEST(BundleShimTest, DefaultSignatureRegression) {
166  Signatures signatures;
167  RegressionSignature* regression_signature =
168      signatures.mutable_default_signature()->mutable_regression_signature();
169  regression_signature->mutable_input()->set_tensor_name("foo-input");
170  regression_signature->mutable_output()->set_tensor_name("foo-output");
171  MetaGraphDef meta_graph_def;
172  (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
173      .mutable_any_list()
174      ->add_value()
175      ->PackFrom(signatures);
176  TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
177  EXPECT_EQ(1, meta_graph_def.signature_def_size());
178  const auto actual_signature_def =
179      meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey);
180  EXPECT_EQ("foo-input", actual_signature_def->second.inputs()
181                             .find(kRegressInputs)
182                             ->second.name());
183  EXPECT_EQ("foo-output", actual_signature_def->second.outputs()
184                              .find(kRegressOutputs)
185                              ->second.name());
186  EXPECT_EQ(kRegressMethodName, actual_signature_def->second.method_name());
187}
188
189// Checks the conversion to signature def for a classification default
190// signature.
191TEST(BundleShimTest, DefaultSignatureClassification) {
192  Signatures signatures;
193  ClassificationSignature* classification_signature =
194      signatures.mutable_default_signature()
195          ->mutable_classification_signature();
196  classification_signature->mutable_input()->set_tensor_name("foo-input");
197  classification_signature->mutable_classes()->set_tensor_name("foo-classes");
198  classification_signature->mutable_scores()->set_tensor_name("foo-scores");
199  MetaGraphDef meta_graph_def;
200  (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
201      .mutable_any_list()
202      ->add_value()
203      ->PackFrom(signatures);
204  TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
205  EXPECT_EQ(1, meta_graph_def.signature_def_size());
206  const auto actual_signature_def =
207      meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey);
208  EXPECT_EQ("foo-input", actual_signature_def->second.inputs()
209                             .find(kClassifyInputs)
210                             ->second.name());
211  EXPECT_EQ("foo-classes", actual_signature_def->second.outputs()
212                               .find(kClassifyOutputClasses)
213                               ->second.name());
214  EXPECT_EQ("foo-scores", actual_signature_def->second.outputs()
215                              .find(kClassifyOutputScores)
216                              ->second.name());
217  EXPECT_EQ(kClassifyMethodName, actual_signature_def->second.method_name());
218}
219
220// Checks that generic default signatures are not up converted.
221TEST(BundleShimTest, DefaultSignatureGeneric) {
222  TensorBinding input_binding;
223  input_binding.set_tensor_name("foo-input");
224
225  TensorBinding output_binding;
226  output_binding.set_tensor_name("foo-output");
227
228  Signatures signatures;
229  GenericSignature* generic_signature =
230      signatures.mutable_default_signature()->mutable_generic_signature();
231  generic_signature->mutable_map()->insert({kPredictInputs, input_binding});
232  generic_signature->mutable_map()->insert({kPredictOutputs, output_binding});
233
234  MetaGraphDef meta_graph_def;
235  (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
236      .mutable_any_list()
237      ->add_value()
238      ->PackFrom(signatures);
239  TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
240  EXPECT_EQ(0, meta_graph_def.signature_def_size());
241}
242
243TEST(BundleShimTest, NamedRegressionSignatures) {
244  Signatures signatures;
245
246  RegressionSignature* foo_regression_signature =
247      (*signatures.mutable_named_signatures())["foo"]
248          .mutable_regression_signature();
249  foo_regression_signature->mutable_input()->set_tensor_name("foo-input");
250  foo_regression_signature->mutable_output()->set_tensor_name("foo-output");
251
252  RegressionSignature* bar_regression_signature =
253      (*signatures.mutable_named_signatures())["bar"]
254          .mutable_regression_signature();
255  bar_regression_signature->mutable_input()->set_tensor_name("bar-input");
256  bar_regression_signature->mutable_output()->set_tensor_name("bar-output");
257
258  MetaGraphDef meta_graph_def;
259  (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
260      .mutable_any_list()
261      ->add_value()
262      ->PackFrom(signatures);
263  TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
264  ASSERT_EQ(2, meta_graph_def.signature_def_size());
265
266  ValidateSignatureDef(meta_graph_def, "foo",
267                       "inputs { "
268                       "  key: \"inputs\" "
269                       "  value { "
270                       "name: \"foo-input\" "
271                       "  } "
272                       "} "
273                       "outputs { "
274                       "  key: \"outputs\" "
275                       "  value { "
276                       "    name: \"foo-output\" "
277                       "  } "
278                       "} "
279                       "method_name: \"tensorflow/serving/regress\" ");
280  ValidateSignatureDef(meta_graph_def, "bar",
281                       "inputs { "
282                       "  key: \"inputs\" "
283                       "  value { "
284                       "name: \"bar-input\" "
285                       "  } "
286                       "} "
287                       "outputs { "
288                       "  key: \"outputs\" "
289                       "  value { "
290                       "    name: \"bar-output\" "
291                       "  } "
292                       "} "
293                       "method_name: \"tensorflow/serving/regress\" ");
294}
295
296TEST(BundleShimTest, NamedClassificationSignatures) {
297  Signatures signatures;
298
299  ClassificationSignature* foo_classification_signature =
300      (*signatures.mutable_named_signatures())["foo"]
301          .mutable_classification_signature();
302  foo_classification_signature->mutable_input()->set_tensor_name("foo-input");
303  foo_classification_signature->mutable_classes()->set_tensor_name(
304      "foo-classes");
305
306  ClassificationSignature* bar_classification_signature =
307      (*signatures.mutable_named_signatures())["bar"]
308          .mutable_classification_signature();
309  bar_classification_signature->mutable_input()->set_tensor_name("bar-input");
310  bar_classification_signature->mutable_scores()->set_tensor_name("bar-scores");
311
312  MetaGraphDef meta_graph_def;
313  (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
314      .mutable_any_list()
315      ->add_value()
316      ->PackFrom(signatures);
317  TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
318  ASSERT_EQ(2, meta_graph_def.signature_def_size());
319
320  ValidateSignatureDef(meta_graph_def, "foo",
321                       "inputs { "
322                       "  key: \"inputs\" "
323                       "  value { "
324                       "name: \"foo-input\" "
325                       "  } "
326                       "} "
327                       "outputs { "
328                       "  key: \"classes\" "
329                       "  value { "
330                       "    name: \"foo-classes\" "
331                       "  } "
332                       "} "
333                       "method_name: \"tensorflow/serving/classify\" ");
334  ValidateSignatureDef(meta_graph_def, "bar",
335                       "inputs { "
336                       "  key: \"inputs\" "
337                       "  value { "
338                       "name: \"bar-input\" "
339                       "  } "
340                       "} "
341                       "outputs { "
342                       "  key: \"scores\" "
343                       "  value { "
344                       "    name: \"bar-scores\" "
345                       "  } "
346                       "} "
347                       "method_name: \"tensorflow/serving/classify\" ");
348}
349
350// Checks the Predict SignatureDef created when the named signatures have
351// `inputs` and `outputs`.
352TEST(BundleShimTest, NamedSignatureGenericInputsAndOutputs) {
353  TensorBinding input_binding;
354  input_binding.set_tensor_name("foo-input");
355
356  TensorBinding output_binding;
357  output_binding.set_tensor_name("foo-output");
358
359  Signatures signatures;
360  GenericSignature* input_generic_signature =
361      (*signatures.mutable_named_signatures())[kPredictInputs]
362          .mutable_generic_signature();
363  input_generic_signature->mutable_map()->insert({"foo-input", input_binding});
364
365  GenericSignature* output_generic_signature =
366      (*signatures.mutable_named_signatures())[kPredictOutputs]
367          .mutable_generic_signature();
368  output_generic_signature->mutable_map()->insert(
369      {"foo-output", output_binding});
370
371  MetaGraphDef meta_graph_def;
372  (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
373      .mutable_any_list()
374      ->add_value()
375      ->PackFrom(signatures);
376  TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
377  EXPECT_EQ(1, meta_graph_def.signature_def_size());
378  const auto actual_signature_def =
379      meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey);
380  ASSERT_FALSE(actual_signature_def == meta_graph_def.signature_def().end());
381  ASSERT_FALSE(actual_signature_def->second.inputs().find("foo-input") ==
382               actual_signature_def->second.inputs().end());
383  EXPECT_EQ(
384      "foo-input",
385      actual_signature_def->second.inputs().find("foo-input")->second.name());
386  ASSERT_FALSE(actual_signature_def->second.outputs().find("foo-output") ==
387               actual_signature_def->second.outputs().end());
388  EXPECT_EQ(
389      "foo-output",
390      actual_signature_def->second.outputs().find("foo-output")->second.name());
391  EXPECT_EQ(kPredictMethodName, actual_signature_def->second.method_name());
392}
393
394// Checks that a signature def is not added if the named signatures is generic
395// but does not have `inputs` and `outputs`.
396TEST(BundleShimTest, NamedSignatureGenericNoInputsOrOutputs) {
397  TensorBinding input_binding;
398  input_binding.set_tensor_name("foo-input");
399
400  TensorBinding output_binding;
401  output_binding.set_tensor_name("foo-output");
402
403  Signatures signatures;
404  GenericSignature* generic_signature =
405      (*signatures.mutable_named_signatures())["unknown"]
406          .mutable_generic_signature();
407  generic_signature->mutable_map()->insert({kPredictInputs, input_binding});
408  generic_signature->mutable_map()->insert({kPredictOutputs, output_binding});
409
410  MetaGraphDef meta_graph_def;
411  (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
412      .mutable_any_list()
413      ->add_value()
414      ->PackFrom(signatures);
415  TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
416  EXPECT_EQ(0, meta_graph_def.signature_def_size());
417}
418
419// Checks that a signature def is not added when the named signatures have only
420// one of `inputs` and `outputs`.
421TEST(BundleShimTest, NamedSignatureGenericOnlyInput) {
422  TensorBinding input_binding;
423  input_binding.set_tensor_name("foo-input");
424
425  Signatures signatures;
426  GenericSignature* input_generic_signature =
427      (*signatures.mutable_named_signatures())[kPredictInputs]
428          .mutable_generic_signature();
429  input_generic_signature->mutable_map()->insert({"foo-input", input_binding});
430
431  MetaGraphDef meta_graph_def;
432  (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
433      .mutable_any_list()
434      ->add_value()
435      ->PackFrom(signatures);
436  TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
437  EXPECT_EQ(0, meta_graph_def.signature_def_size());
438}
439
440// Tests up-conversion of Signatures to SignatureDefs when both `default` and
441// `named` signatures are present.
442TEST(BundleShimTest, DefaultAndNamedSignatureWithPredict) {
443  Signatures signatures;
444
445  // Build a generic signature corresponding to `inputs` and add it to the
446  // Signatures to up-convert.
447  TensorBinding input_binding;
448  input_binding.set_tensor_name("foo-input");
449  GenericSignature* input_generic_signature =
450      (*signatures.mutable_named_signatures())[kPredictInputs]
451          .mutable_generic_signature();
452  input_generic_signature->mutable_map()->insert({"foo-input", input_binding});
453
454  // Build a generic signature corresponding to `outputs` and add it to the
455  // Signatures to up-convert.
456  TensorBinding output_binding;
457  output_binding.set_tensor_name("foo-output");
458  GenericSignature* output_generic_signature =
459      (*signatures.mutable_named_signatures())[kPredictOutputs]
460          .mutable_generic_signature();
461  output_generic_signature->mutable_map()->insert(
462      {"foo-output", output_binding});
463
464  // Build a regression signature and set it as the default signature.
465  RegressionSignature* inputs_regression_signature =
466      (*signatures.mutable_default_signature()).mutable_regression_signature();
467  inputs_regression_signature->mutable_input()->set_tensor_name("bar-input");
468
469  // Up-convert the available signatures to SignatureDefs.
470  MetaGraphDef meta_graph_def;
471  (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
472      .mutable_any_list()
473      ->add_value()
474      ->PackFrom(signatures);
475  TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def));
476  EXPECT_EQ(2, meta_graph_def.signature_def_size());
477
478  // Verify that the default regression signature is converted to a
479  // SignatureDef that corresponds to the kDefaultServingSignatureDefKey.
480  const auto actual_signature_def_regress =
481      meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey);
482  ASSERT_FALSE(actual_signature_def_regress ==
483               meta_graph_def.signature_def().end());
484  ASSERT_FALSE(
485      actual_signature_def_regress->second.inputs().find(kRegressInputs) ==
486      actual_signature_def_regress->second.inputs().end());
487
488  // Verify that the `Predict` SignatureDef is created under a different key.
489  const auto actual_signature_def_predict = meta_graph_def.signature_def().find(
490      strings::StrCat(kDefaultServingSignatureDefKey, "_from_named"));
491  ASSERT_FALSE(actual_signature_def_predict ==
492               meta_graph_def.signature_def().end());
493  ASSERT_FALSE(
494      actual_signature_def_predict->second.inputs().find("foo-input") ==
495      actual_signature_def_predict->second.inputs().end());
496  EXPECT_EQ("foo-input", actual_signature_def_predict->second.inputs()
497                             .find("foo-input")
498                             ->second.name());
499  ASSERT_FALSE(
500      actual_signature_def_predict->second.outputs().find("foo-output") ==
501      actual_signature_def_predict->second.outputs().end());
502  EXPECT_EQ("foo-output", actual_signature_def_predict->second.outputs()
503                              .find("foo-output")
504                              ->second.name());
505  EXPECT_EQ(kPredictMethodName,
506            actual_signature_def_predict->second.method_name());
507}
508
509// Checks a basic up conversion for half plus two for SessionBundle.
510TEST(BundleShimTest, BasicExportSessionBundle) {
511  const std::unordered_set<string> tags = {"tag"};
512  const string session_bundle_export_dir =
513      test_util::TestSrcDirPath(kSessionBundlePath);
514  LoadAndValidateSavedModelBundle(session_bundle_export_dir, tags,
515                                  kDefaultServingSignatureDefKey);
516
517  // Verify that the named signature is also present.
518  SessionOptions session_options;
519  RunOptions run_options;
520  SavedModelBundle saved_model_bundle;
521  TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle(session_options, run_options,
522                                                   session_bundle_export_dir,
523                                                   tags, &saved_model_bundle));
524  const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def;
525  const auto& signature_def_map = meta_graph_def.signature_def();
526  bool found_named_signature = false;
527  for (const auto& entry : signature_def_map) {
528    const string& key = entry.first;
529    const SignatureDef& signature_def = entry.second;
530
531    // We're looking for the key that is *not* kDefaultServingSignatureDefKey.
532    if (key == kDefaultServingSignatureDefKey) {
533      continue;
534    }
535    found_named_signature = true;
536
537    EXPECT_EQ(1, signature_def.inputs_size());
538    const auto it_inputs_x = signature_def.inputs().find("x");
539    EXPECT_FALSE(it_inputs_x == signature_def.inputs().end());
540    // Ensure the TensorInfo has name and dtype populated.
541    const TensorInfo& tensor_info_x = it_inputs_x->second;
542    EXPECT_EQ("x:0", tensor_info_x.name());
543    EXPECT_EQ(DT_FLOAT, tensor_info_x.dtype());
544
545    EXPECT_EQ(1, signature_def.outputs_size());
546    const auto it_outputs_y = signature_def.outputs().find("y");
547    EXPECT_FALSE(it_outputs_y == signature_def.outputs().end());
548    // Ensure the TensorInfo has name and dtype populated.
549    const TensorInfo& tensor_info_y = it_outputs_y->second;
550    EXPECT_EQ("y:0", tensor_info_y.name());
551    EXPECT_EQ(DT_FLOAT, tensor_info_y.dtype());
552  }
553  EXPECT_TRUE(found_named_signature);
554}
555
556// Checks a basic load for half plus two for SavedModelBundle.
557TEST(BundleShimTest, BasicExportSavedModel) {
558  const string saved_model_bundle_export_dir =
559      io::JoinPath(testing::TensorFlowSrcRoot(), kSavedModelBundlePath);
560  LoadAndValidateSavedModelBundle(saved_model_bundle_export_dir,
561                                  {kSavedModelTagServe}, "regress_x_to_y");
562}
563
564// Checks a basic load fails with an invalid export path.
565TEST(BundleShimTest, InvalidPath) {
566  const string invalid_export_dir = testing::TensorFlowSrcRoot();
567  SessionOptions session_options;
568  RunOptions run_options;
569  SavedModelBundle saved_model_bundle;
570  Status status = LoadSessionBundleOrSavedModelBundle(
571      session_options, run_options, invalid_export_dir, {kSavedModelTagServe},
572      &saved_model_bundle);
573  EXPECT_EQ(error::Code::NOT_FOUND, status.code());
574}
575
576// Checks that if loading a session bundle fails, the error is propagated to
577// LoadSessionBundleOrSavedModelBundle().
578TEST(BundleShimTest, LoadSessionBundleError) {
579  const string session_bundle_export_dir =
580      test_util::TestSrcDirPath(kSessionBundlePath);
581  SessionOptions session_options;
582  RunOptions run_options;
583  // Invalid threadpool index to use for session-run calls.
584  run_options.set_inter_op_thread_pool(100);
585  SavedModelBundle saved_model_bundle;
586  EXPECT_FALSE(LoadSessionBundleOrSavedModelBundle(session_options, run_options,
587                                                   session_bundle_export_dir,
588                                                   {"tag"}, &saved_model_bundle)
589                   .ok());
590}
591
592}  // namespace
593}  // namespace internal
594}  // namespace serving
595}  // namespace tensorflow
596