1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/contrib/session_bundle/bundle_shim.h" 17 18#include "google/protobuf/any.pb.h" 19#include "tensorflow/cc/saved_model/signature_constants.h" 20#include "tensorflow/cc/saved_model/tag_constants.h" 21#include "tensorflow/contrib/session_bundle/test_util.h" 22#include "tensorflow/core/example/example.pb.h" 23#include "tensorflow/core/example/feature.pb.h" 24#include "tensorflow/core/framework/tensor_testutil.h" 25#include "tensorflow/core/lib/core/status_test_util.h" 26#include "tensorflow/core/lib/io/path.h" 27#include "tensorflow/core/protobuf/meta_graph.pb.h" 28 29namespace tensorflow { 30namespace serving { 31namespace internal { 32namespace { 33 34constexpr char kSessionBundlePath[] = 35 "session_bundle/testdata/half_plus_two/00000123"; 36constexpr char kSavedModelBundlePath[] = 37 "cc/saved_model/testdata/half_plus_two/00000123"; 38 39string MakeSerializedExample(float x) { 40 tensorflow::Example example; 41 auto* feature_map = example.mutable_features()->mutable_feature(); 42 (*feature_map)["x"].mutable_float_list()->add_value(x); 43 return example.SerializeAsString(); 44} 45 46void ValidateHalfPlusTwo(const SavedModelBundle& saved_model_bundle, 47 const string& input_tensor_name, 48 const string& output_tensor_name) { 49 // Validate the half plus two behavior. 50 std::vector<string> serialized_examples; 51 for (float x : {0, 1, 2, 3}) { 52 serialized_examples.push_back(MakeSerializedExample(x)); 53 } 54 Tensor input = test::AsTensor<string>(serialized_examples, TensorShape({4})); 55 56 std::vector<Tensor> outputs; 57 TF_ASSERT_OK(saved_model_bundle.session->Run( 58 {{input_tensor_name, input}}, {output_tensor_name}, {}, &outputs)); 59 ASSERT_EQ(outputs.size(), 1); 60 test::ExpectTensorEqual<float>( 61 outputs[0], test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1}))); 62} 63 64void LoadAndValidateSavedModelBundle(const string& export_dir, 65 const std::unordered_set<string>& tags, 66 const string& signature_def_key) { 67 SessionOptions session_options; 68 RunOptions run_options; 69 SavedModelBundle saved_model_bundle; 70 TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle( 71 session_options, run_options, export_dir, tags, &saved_model_bundle)); 72 const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def; 73 const auto& signature_def_map = meta_graph_def.signature_def(); 74 75 const auto& regression_entry = signature_def_map.find(signature_def_key); 76 ASSERT_FALSE(regression_entry == signature_def_map.end()); 77 SignatureDef regression_signature_def = regression_entry->second; 78 79 EXPECT_EQ(1, regression_signature_def.inputs_size()); 80 ASSERT_FALSE(regression_signature_def.inputs().find(kRegressInputs) == 81 regression_signature_def.inputs().end()); 82 TensorInfo input_tensor_info = 83 regression_signature_def.inputs().find(kRegressInputs)->second; 84 EXPECT_EQ(1, regression_signature_def.outputs_size()); 85 // Ensure the TensorInfo has dtype populated. 86 EXPECT_EQ(DT_STRING, input_tensor_info.dtype()); 87 88 ASSERT_FALSE(regression_signature_def.outputs().find(kRegressOutputs) == 89 regression_signature_def.outputs().end()); 90 TensorInfo output_tensor_info = 91 regression_signature_def.outputs().find(kRegressOutputs)->second; 92 // Ensure the TensorInfo has dtype populated. 93 EXPECT_EQ(DT_FLOAT, output_tensor_info.dtype()); 94 ValidateHalfPlusTwo(saved_model_bundle, input_tensor_info.name(), 95 output_tensor_info.name()); 96} 97 98// Helper function to validate that the SignatureDef found in the MetaGraphDef 99// with the provided key has the expected string representation. 100void ValidateSignatureDef(const MetaGraphDef& meta_graph_def, const string& key, 101 const string& expected_string_signature_def) { 102 tensorflow::SignatureDef expected_signature; 103 CHECK(protobuf::TextFormat::ParseFromString(expected_string_signature_def, 104 &expected_signature)); 105 auto iter = meta_graph_def.signature_def().find(key); 106 ASSERT_TRUE(iter != meta_graph_def.signature_def().end()); 107 EXPECT_EQ(expected_signature.DebugString(), iter->second.DebugString()); 108} 109 110// Checks that the input map in a signature def is populated correctly. 111TEST(BundleShimTest, AddInputToSignatureDef) { 112 SignatureDef signature_def; 113 const string tensor_name = "foo_tensor"; 114 const string map_key = "foo_key"; 115 116 // Build a map of tensor-name to dtype, for the unit-test. 117 std::unordered_map<string, DataType> tensor_name_to_dtype; 118 tensor_name_to_dtype[tensor_name] = tensorflow::DT_STRING; 119 120 AddInputToSignatureDef(tensor_name, tensor_name_to_dtype, map_key, 121 &signature_def); 122 EXPECT_EQ(1, signature_def.inputs_size()); 123 EXPECT_EQ(tensor_name, signature_def.inputs().find(map_key)->second.name()); 124} 125 126// Checks that the output map in a signature def is populated correctly. 127TEST(BundleShimTest, AddOutputToSignatureDef) { 128 SignatureDef signature_def; 129 const string tensor_name = "foo_tensor"; 130 const string map_key = "foo_key"; 131 132 // Build a map of tensor-name to dtype, for the unit-test. 133 std::unordered_map<string, DataType> tensor_name_to_dtype; 134 tensor_name_to_dtype[tensor_name] = tensorflow::DT_STRING; 135 136 AddOutputToSignatureDef(tensor_name, tensor_name_to_dtype, map_key, 137 &signature_def); 138 EXPECT_EQ(1, signature_def.outputs_size()); 139 EXPECT_EQ(tensor_name, signature_def.outputs().find(map_key)->second.name()); 140} 141 142// Checks that no signature defs are added if the default signature is missing. 143TEST(BundleShimTest, DefaultSignatureMissing) { 144 MetaGraphDef meta_graph_def; 145 // Signatures signatures; 146 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def)); 147 EXPECT_EQ(0, meta_graph_def.signature_def_size()); 148} 149 150// Checks that no signature defs are added if the default signature is empty. 151TEST(BundleShimTest, DefaultSignatureEmpty) { 152 Signatures signatures; 153 signatures.mutable_default_signature(); 154 155 MetaGraphDef meta_graph_def; 156 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 157 .mutable_any_list() 158 ->add_value() 159 ->PackFrom(signatures); 160 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def)); 161 EXPECT_EQ(0, meta_graph_def.signature_def_size()); 162} 163 164// Checks the conversion to signature def for a regression default signature. 165TEST(BundleShimTest, DefaultSignatureRegression) { 166 Signatures signatures; 167 RegressionSignature* regression_signature = 168 signatures.mutable_default_signature()->mutable_regression_signature(); 169 regression_signature->mutable_input()->set_tensor_name("foo-input"); 170 regression_signature->mutable_output()->set_tensor_name("foo-output"); 171 MetaGraphDef meta_graph_def; 172 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 173 .mutable_any_list() 174 ->add_value() 175 ->PackFrom(signatures); 176 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def)); 177 EXPECT_EQ(1, meta_graph_def.signature_def_size()); 178 const auto actual_signature_def = 179 meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey); 180 EXPECT_EQ("foo-input", actual_signature_def->second.inputs() 181 .find(kRegressInputs) 182 ->second.name()); 183 EXPECT_EQ("foo-output", actual_signature_def->second.outputs() 184 .find(kRegressOutputs) 185 ->second.name()); 186 EXPECT_EQ(kRegressMethodName, actual_signature_def->second.method_name()); 187} 188 189// Checks the conversion to signature def for a classification default 190// signature. 191TEST(BundleShimTest, DefaultSignatureClassification) { 192 Signatures signatures; 193 ClassificationSignature* classification_signature = 194 signatures.mutable_default_signature() 195 ->mutable_classification_signature(); 196 classification_signature->mutable_input()->set_tensor_name("foo-input"); 197 classification_signature->mutable_classes()->set_tensor_name("foo-classes"); 198 classification_signature->mutable_scores()->set_tensor_name("foo-scores"); 199 MetaGraphDef meta_graph_def; 200 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 201 .mutable_any_list() 202 ->add_value() 203 ->PackFrom(signatures); 204 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def)); 205 EXPECT_EQ(1, meta_graph_def.signature_def_size()); 206 const auto actual_signature_def = 207 meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey); 208 EXPECT_EQ("foo-input", actual_signature_def->second.inputs() 209 .find(kClassifyInputs) 210 ->second.name()); 211 EXPECT_EQ("foo-classes", actual_signature_def->second.outputs() 212 .find(kClassifyOutputClasses) 213 ->second.name()); 214 EXPECT_EQ("foo-scores", actual_signature_def->second.outputs() 215 .find(kClassifyOutputScores) 216 ->second.name()); 217 EXPECT_EQ(kClassifyMethodName, actual_signature_def->second.method_name()); 218} 219 220// Checks that generic default signatures are not up converted. 221TEST(BundleShimTest, DefaultSignatureGeneric) { 222 TensorBinding input_binding; 223 input_binding.set_tensor_name("foo-input"); 224 225 TensorBinding output_binding; 226 output_binding.set_tensor_name("foo-output"); 227 228 Signatures signatures; 229 GenericSignature* generic_signature = 230 signatures.mutable_default_signature()->mutable_generic_signature(); 231 generic_signature->mutable_map()->insert({kPredictInputs, input_binding}); 232 generic_signature->mutable_map()->insert({kPredictOutputs, output_binding}); 233 234 MetaGraphDef meta_graph_def; 235 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 236 .mutable_any_list() 237 ->add_value() 238 ->PackFrom(signatures); 239 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def)); 240 EXPECT_EQ(0, meta_graph_def.signature_def_size()); 241} 242 243TEST(BundleShimTest, NamedRegressionSignatures) { 244 Signatures signatures; 245 246 RegressionSignature* foo_regression_signature = 247 (*signatures.mutable_named_signatures())["foo"] 248 .mutable_regression_signature(); 249 foo_regression_signature->mutable_input()->set_tensor_name("foo-input"); 250 foo_regression_signature->mutable_output()->set_tensor_name("foo-output"); 251 252 RegressionSignature* bar_regression_signature = 253 (*signatures.mutable_named_signatures())["bar"] 254 .mutable_regression_signature(); 255 bar_regression_signature->mutable_input()->set_tensor_name("bar-input"); 256 bar_regression_signature->mutable_output()->set_tensor_name("bar-output"); 257 258 MetaGraphDef meta_graph_def; 259 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 260 .mutable_any_list() 261 ->add_value() 262 ->PackFrom(signatures); 263 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def)); 264 ASSERT_EQ(2, meta_graph_def.signature_def_size()); 265 266 ValidateSignatureDef(meta_graph_def, "foo", 267 "inputs { " 268 " key: \"inputs\" " 269 " value { " 270 "name: \"foo-input\" " 271 " } " 272 "} " 273 "outputs { " 274 " key: \"outputs\" " 275 " value { " 276 " name: \"foo-output\" " 277 " } " 278 "} " 279 "method_name: \"tensorflow/serving/regress\" "); 280 ValidateSignatureDef(meta_graph_def, "bar", 281 "inputs { " 282 " key: \"inputs\" " 283 " value { " 284 "name: \"bar-input\" " 285 " } " 286 "} " 287 "outputs { " 288 " key: \"outputs\" " 289 " value { " 290 " name: \"bar-output\" " 291 " } " 292 "} " 293 "method_name: \"tensorflow/serving/regress\" "); 294} 295 296TEST(BundleShimTest, NamedClassificationSignatures) { 297 Signatures signatures; 298 299 ClassificationSignature* foo_classification_signature = 300 (*signatures.mutable_named_signatures())["foo"] 301 .mutable_classification_signature(); 302 foo_classification_signature->mutable_input()->set_tensor_name("foo-input"); 303 foo_classification_signature->mutable_classes()->set_tensor_name( 304 "foo-classes"); 305 306 ClassificationSignature* bar_classification_signature = 307 (*signatures.mutable_named_signatures())["bar"] 308 .mutable_classification_signature(); 309 bar_classification_signature->mutable_input()->set_tensor_name("bar-input"); 310 bar_classification_signature->mutable_scores()->set_tensor_name("bar-scores"); 311 312 MetaGraphDef meta_graph_def; 313 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 314 .mutable_any_list() 315 ->add_value() 316 ->PackFrom(signatures); 317 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def)); 318 ASSERT_EQ(2, meta_graph_def.signature_def_size()); 319 320 ValidateSignatureDef(meta_graph_def, "foo", 321 "inputs { " 322 " key: \"inputs\" " 323 " value { " 324 "name: \"foo-input\" " 325 " } " 326 "} " 327 "outputs { " 328 " key: \"classes\" " 329 " value { " 330 " name: \"foo-classes\" " 331 " } " 332 "} " 333 "method_name: \"tensorflow/serving/classify\" "); 334 ValidateSignatureDef(meta_graph_def, "bar", 335 "inputs { " 336 " key: \"inputs\" " 337 " value { " 338 "name: \"bar-input\" " 339 " } " 340 "} " 341 "outputs { " 342 " key: \"scores\" " 343 " value { " 344 " name: \"bar-scores\" " 345 " } " 346 "} " 347 "method_name: \"tensorflow/serving/classify\" "); 348} 349 350// Checks the Predict SignatureDef created when the named signatures have 351// `inputs` and `outputs`. 352TEST(BundleShimTest, NamedSignatureGenericInputsAndOutputs) { 353 TensorBinding input_binding; 354 input_binding.set_tensor_name("foo-input"); 355 356 TensorBinding output_binding; 357 output_binding.set_tensor_name("foo-output"); 358 359 Signatures signatures; 360 GenericSignature* input_generic_signature = 361 (*signatures.mutable_named_signatures())[kPredictInputs] 362 .mutable_generic_signature(); 363 input_generic_signature->mutable_map()->insert({"foo-input", input_binding}); 364 365 GenericSignature* output_generic_signature = 366 (*signatures.mutable_named_signatures())[kPredictOutputs] 367 .mutable_generic_signature(); 368 output_generic_signature->mutable_map()->insert( 369 {"foo-output", output_binding}); 370 371 MetaGraphDef meta_graph_def; 372 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 373 .mutable_any_list() 374 ->add_value() 375 ->PackFrom(signatures); 376 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def)); 377 EXPECT_EQ(1, meta_graph_def.signature_def_size()); 378 const auto actual_signature_def = 379 meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey); 380 ASSERT_FALSE(actual_signature_def == meta_graph_def.signature_def().end()); 381 ASSERT_FALSE(actual_signature_def->second.inputs().find("foo-input") == 382 actual_signature_def->second.inputs().end()); 383 EXPECT_EQ( 384 "foo-input", 385 actual_signature_def->second.inputs().find("foo-input")->second.name()); 386 ASSERT_FALSE(actual_signature_def->second.outputs().find("foo-output") == 387 actual_signature_def->second.outputs().end()); 388 EXPECT_EQ( 389 "foo-output", 390 actual_signature_def->second.outputs().find("foo-output")->second.name()); 391 EXPECT_EQ(kPredictMethodName, actual_signature_def->second.method_name()); 392} 393 394// Checks that a signature def is not added if the named signatures is generic 395// but does not have `inputs` and `outputs`. 396TEST(BundleShimTest, NamedSignatureGenericNoInputsOrOutputs) { 397 TensorBinding input_binding; 398 input_binding.set_tensor_name("foo-input"); 399 400 TensorBinding output_binding; 401 output_binding.set_tensor_name("foo-output"); 402 403 Signatures signatures; 404 GenericSignature* generic_signature = 405 (*signatures.mutable_named_signatures())["unknown"] 406 .mutable_generic_signature(); 407 generic_signature->mutable_map()->insert({kPredictInputs, input_binding}); 408 generic_signature->mutable_map()->insert({kPredictOutputs, output_binding}); 409 410 MetaGraphDef meta_graph_def; 411 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 412 .mutable_any_list() 413 ->add_value() 414 ->PackFrom(signatures); 415 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def)); 416 EXPECT_EQ(0, meta_graph_def.signature_def_size()); 417} 418 419// Checks that a signature def is not added when the named signatures have only 420// one of `inputs` and `outputs`. 421TEST(BundleShimTest, NamedSignatureGenericOnlyInput) { 422 TensorBinding input_binding; 423 input_binding.set_tensor_name("foo-input"); 424 425 Signatures signatures; 426 GenericSignature* input_generic_signature = 427 (*signatures.mutable_named_signatures())[kPredictInputs] 428 .mutable_generic_signature(); 429 input_generic_signature->mutable_map()->insert({"foo-input", input_binding}); 430 431 MetaGraphDef meta_graph_def; 432 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 433 .mutable_any_list() 434 ->add_value() 435 ->PackFrom(signatures); 436 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def)); 437 EXPECT_EQ(0, meta_graph_def.signature_def_size()); 438} 439 440// Tests up-conversion of Signatures to SignatureDefs when both `default` and 441// `named` signatures are present. 442TEST(BundleShimTest, DefaultAndNamedSignatureWithPredict) { 443 Signatures signatures; 444 445 // Build a generic signature corresponding to `inputs` and add it to the 446 // Signatures to up-convert. 447 TensorBinding input_binding; 448 input_binding.set_tensor_name("foo-input"); 449 GenericSignature* input_generic_signature = 450 (*signatures.mutable_named_signatures())[kPredictInputs] 451 .mutable_generic_signature(); 452 input_generic_signature->mutable_map()->insert({"foo-input", input_binding}); 453 454 // Build a generic signature corresponding to `outputs` and add it to the 455 // Signatures to up-convert. 456 TensorBinding output_binding; 457 output_binding.set_tensor_name("foo-output"); 458 GenericSignature* output_generic_signature = 459 (*signatures.mutable_named_signatures())[kPredictOutputs] 460 .mutable_generic_signature(); 461 output_generic_signature->mutable_map()->insert( 462 {"foo-output", output_binding}); 463 464 // Build a regression signature and set it as the default signature. 465 RegressionSignature* inputs_regression_signature = 466 (*signatures.mutable_default_signature()).mutable_regression_signature(); 467 inputs_regression_signature->mutable_input()->set_tensor_name("bar-input"); 468 469 // Up-convert the available signatures to SignatureDefs. 470 MetaGraphDef meta_graph_def; 471 (*meta_graph_def.mutable_collection_def())[kSignaturesKey] 472 .mutable_any_list() 473 ->add_value() 474 ->PackFrom(signatures); 475 TF_EXPECT_OK(ConvertSignaturesToSignatureDefs(&meta_graph_def)); 476 EXPECT_EQ(2, meta_graph_def.signature_def_size()); 477 478 // Verify that the default regression signature is converted to a 479 // SignatureDef that corresponds to the kDefaultServingSignatureDefKey. 480 const auto actual_signature_def_regress = 481 meta_graph_def.signature_def().find(kDefaultServingSignatureDefKey); 482 ASSERT_FALSE(actual_signature_def_regress == 483 meta_graph_def.signature_def().end()); 484 ASSERT_FALSE( 485 actual_signature_def_regress->second.inputs().find(kRegressInputs) == 486 actual_signature_def_regress->second.inputs().end()); 487 488 // Verify that the `Predict` SignatureDef is created under a different key. 489 const auto actual_signature_def_predict = meta_graph_def.signature_def().find( 490 strings::StrCat(kDefaultServingSignatureDefKey, "_from_named")); 491 ASSERT_FALSE(actual_signature_def_predict == 492 meta_graph_def.signature_def().end()); 493 ASSERT_FALSE( 494 actual_signature_def_predict->second.inputs().find("foo-input") == 495 actual_signature_def_predict->second.inputs().end()); 496 EXPECT_EQ("foo-input", actual_signature_def_predict->second.inputs() 497 .find("foo-input") 498 ->second.name()); 499 ASSERT_FALSE( 500 actual_signature_def_predict->second.outputs().find("foo-output") == 501 actual_signature_def_predict->second.outputs().end()); 502 EXPECT_EQ("foo-output", actual_signature_def_predict->second.outputs() 503 .find("foo-output") 504 ->second.name()); 505 EXPECT_EQ(kPredictMethodName, 506 actual_signature_def_predict->second.method_name()); 507} 508 509// Checks a basic up conversion for half plus two for SessionBundle. 510TEST(BundleShimTest, BasicExportSessionBundle) { 511 const std::unordered_set<string> tags = {"tag"}; 512 const string session_bundle_export_dir = 513 test_util::TestSrcDirPath(kSessionBundlePath); 514 LoadAndValidateSavedModelBundle(session_bundle_export_dir, tags, 515 kDefaultServingSignatureDefKey); 516 517 // Verify that the named signature is also present. 518 SessionOptions session_options; 519 RunOptions run_options; 520 SavedModelBundle saved_model_bundle; 521 TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle(session_options, run_options, 522 session_bundle_export_dir, 523 tags, &saved_model_bundle)); 524 const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def; 525 const auto& signature_def_map = meta_graph_def.signature_def(); 526 bool found_named_signature = false; 527 for (const auto& entry : signature_def_map) { 528 const string& key = entry.first; 529 const SignatureDef& signature_def = entry.second; 530 531 // We're looking for the key that is *not* kDefaultServingSignatureDefKey. 532 if (key == kDefaultServingSignatureDefKey) { 533 continue; 534 } 535 found_named_signature = true; 536 537 EXPECT_EQ(1, signature_def.inputs_size()); 538 const auto it_inputs_x = signature_def.inputs().find("x"); 539 EXPECT_FALSE(it_inputs_x == signature_def.inputs().end()); 540 // Ensure the TensorInfo has name and dtype populated. 541 const TensorInfo& tensor_info_x = it_inputs_x->second; 542 EXPECT_EQ("x:0", tensor_info_x.name()); 543 EXPECT_EQ(DT_FLOAT, tensor_info_x.dtype()); 544 545 EXPECT_EQ(1, signature_def.outputs_size()); 546 const auto it_outputs_y = signature_def.outputs().find("y"); 547 EXPECT_FALSE(it_outputs_y == signature_def.outputs().end()); 548 // Ensure the TensorInfo has name and dtype populated. 549 const TensorInfo& tensor_info_y = it_outputs_y->second; 550 EXPECT_EQ("y:0", tensor_info_y.name()); 551 EXPECT_EQ(DT_FLOAT, tensor_info_y.dtype()); 552 } 553 EXPECT_TRUE(found_named_signature); 554} 555 556// Checks a basic load for half plus two for SavedModelBundle. 557TEST(BundleShimTest, BasicExportSavedModel) { 558 const string saved_model_bundle_export_dir = 559 io::JoinPath(testing::TensorFlowSrcRoot(), kSavedModelBundlePath); 560 LoadAndValidateSavedModelBundle(saved_model_bundle_export_dir, 561 {kSavedModelTagServe}, "regress_x_to_y"); 562} 563 564// Checks a basic load fails with an invalid export path. 565TEST(BundleShimTest, InvalidPath) { 566 const string invalid_export_dir = testing::TensorFlowSrcRoot(); 567 SessionOptions session_options; 568 RunOptions run_options; 569 SavedModelBundle saved_model_bundle; 570 Status status = LoadSessionBundleOrSavedModelBundle( 571 session_options, run_options, invalid_export_dir, {kSavedModelTagServe}, 572 &saved_model_bundle); 573 EXPECT_EQ(error::Code::NOT_FOUND, status.code()); 574} 575 576// Checks that if loading a session bundle fails, the error is propagated to 577// LoadSessionBundleOrSavedModelBundle(). 578TEST(BundleShimTest, LoadSessionBundleError) { 579 const string session_bundle_export_dir = 580 test_util::TestSrcDirPath(kSessionBundlePath); 581 SessionOptions session_options; 582 RunOptions run_options; 583 // Invalid threadpool index to use for session-run calls. 584 run_options.set_inter_op_thread_pool(100); 585 SavedModelBundle saved_model_bundle; 586 EXPECT_FALSE(LoadSessionBundleOrSavedModelBundle(session_options, run_options, 587 session_bundle_export_dir, 588 {"tag"}, &saved_model_bundle) 589 .ok()); 590} 591 592} // namespace 593} // namespace internal 594} // namespace serving 595} // namespace tensorflow 596