1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/cc/ops/while_loop.h" 17#include "tensorflow/cc/client/client_session.h" 18#include "tensorflow/cc/ops/standard_ops.h" 19#include "tensorflow/core/framework/tensor_testutil.h" 20#include "tensorflow/core/graph/while_context.h" 21#include "tensorflow/core/lib/core/status_test_util.h" 22#include "tensorflow/core/platform/test.h" 23 24namespace tensorflow { 25 26namespace { 27 28class WhileLoopTest : public ::testing::Test { 29 protected: 30 WhileLoopTest() : scope_(Scope::NewRootScope()) {} 31 32 void Init(int num_inputs, DataType dtype = DT_INT32) { 33 for (int i = 0; i < num_inputs; ++i) { 34 inputs_.push_back(ops::Placeholder(scope_, dtype)); 35 } 36 } 37 38 void CreateLoop(const ops::CondGraphBuilderFn& cond, 39 const ops::BodyGraphBuilderFn& body, 40 error::Code error_code = error::OK, 41 const string& error_msg = "") { 42 Status s = 43 ops::BuildWhileLoop(scope_, inputs_, cond, body, kFrameName, &outputs_); 44 EXPECT_EQ(s.code(), error_code); 45 EXPECT_EQ(s.error_message(), error_msg); 46 } 47 48 template <typename T> 49 void Run(const std::vector<Input::Initializer>& input_values, 50 const std::vector<T>& expected_output_values) { 51 ClientSession session(scope_); 52 53 DCHECK_EQ(input_values.size(), inputs_.size()); 54 ClientSession::FeedType feeds; 55 for (int i = 0; i < inputs_.size(); ++i) { 56 feeds.emplace(inputs_[i], input_values[i]); 57 } 58 59 std::vector<Tensor> out_tensors; 60 TF_ASSERT_OK(session.Run(feeds, outputs_, &out_tensors)); 61 ASSERT_EQ(out_tensors.size(), outputs_.size()); 62 63 DCHECK_EQ(expected_output_values.size(), out_tensors.size()); 64 for (int i = 0; i < out_tensors.size(); ++i) { 65 test::ExpectTensorEqual<T>( 66 out_tensors[i], test::AsTensor<T>({expected_output_values[i]}, {})); 67 } 68 } 69 70 Scope scope_; 71 std::vector<Output> inputs_; 72 std::vector<Output> outputs_; 73 74 static const char* const kFrameName; 75}; 76 77const char* const WhileLoopTest::kFrameName = "test_loop"; 78 79Status LessThanTenCond(const Scope& s, const std::vector<Output>& inputs, 80 Output* output) { 81 *output = ops::Less(s, inputs[0], 10); 82 return s.status(); 83} 84 85Status AddOneBody(const Scope& s, const std::vector<Output>& inputs, 86 std::vector<Output>* outputs) { 87 outputs->push_back(ops::Add(s, inputs[0], 1)); 88 return s.status(); 89} 90 91TEST_F(WhileLoopTest, Basic) { 92 // Create loop: while (i < 10) i += 1 93 Init(1); 94 CreateLoop(LessThanTenCond, AddOneBody); 95 96 // Verify some output invariants 97 WhileContext* while_ctx; 98 for (int i = 0; i < outputs_.size(); ++i) { 99 Node* node = outputs_[i].node(); 100 ASSERT_TRUE(node->IsExit()) << "Output node " << i << ":\n" 101 << node->DebugString(); 102 ASSERT_TRUE(node->while_ctx() != nullptr) << i; 103 if (i == 0) { 104 while_ctx = node->while_ctx(); 105 EXPECT_EQ(while_ctx->frame_name(), kFrameName); 106 } else { 107 EXPECT_EQ(node->while_ctx(), while_ctx) << i; 108 } 109 } 110 111 // Run the loop and test we get the expected results 112 Run<int>({1}, {10}); 113 Run<int>({11}, {11}); 114} 115 116TEST_F(WhileLoopTest, WrongCondOutputType) { 117 Init(1); 118 CreateLoop( 119 [](const Scope& s, const std::vector<Output>& inputs, Output* output) { 120 *output = ops::Placeholder(s, DT_FLOAT); 121 return s.status(); 122 }, 123 AddOneBody, error::INVALID_ARGUMENT, 124 "BuildWhileLoop: 'cond' argument must return a boolean output, got " 125 "float"); 126} 127 128// TODO(skyewm): test bad cond output shape 129 130TEST_F(WhileLoopTest, NullCondOutputNode) { 131 Init(1); 132 // TODO(skyewm): improve error message 133 CreateLoop( 134 [](const Scope& s, const std::vector<Output>& inputs, Output* output) { 135 *output = {nullptr, 0}; 136 return s.status(); 137 }, 138 AddOneBody, error::INVALID_ARGUMENT, "Node is null"); 139} 140 141TEST_F(WhileLoopTest, InvalidCondOutputIndex) { 142 Init(1); 143 CreateLoop( 144 [](const Scope& s, const std::vector<Output>& inputs, Output* output) { 145 auto less = ops::Less(s, inputs[0], 10); 146 *output = {less.node(), 100}; 147 return s.status(); 148 }, 149 AddOneBody, error::OUT_OF_RANGE, 150 "Node 'cond/Less' (type: 'Less', num of outputs: 1) does not have output " 151 "100"); 152} 153 154TEST_F(WhileLoopTest, UnsetCondOutput) { 155 Init(1); 156 CreateLoop([](const Scope& s, const std::vector<Output>& inputs, 157 Output* output) { return s.status(); }, 158 AddOneBody, error::INVALID_ARGUMENT, "Node is null"); 159} 160 161// TODO(skyewm): test bad body output type 162// TODO(skyewm): test bad body output shape 163 164TEST_F(WhileLoopTest, NullBodyOutputNode) { 165 Init(1); 166 // TODO(skyewm): improve error message 167 CreateLoop(LessThanTenCond, 168 [](const Scope& s, const std::vector<Output>& inputs, 169 std::vector<Output>* outputs) { 170 outputs->push_back({nullptr, 0}); 171 return s.status(); 172 }, 173 error::INVALID_ARGUMENT, "Node is null"); 174} 175 176TEST_F(WhileLoopTest, InvalidBodyOutputIndex) { 177 Init(1); 178 CreateLoop(LessThanTenCond, 179 [](const Scope& s, const std::vector<Output>& inputs, 180 std::vector<Output>* outputs) { 181 auto add = ops::Add(s, inputs[0], 1); 182 outputs->emplace_back(add.node(), 100); 183 return s.status(); 184 }, 185 error::OUT_OF_RANGE, 186 "Node 'body/Add' (type: 'Add', num of outputs: 1) does not have " 187 "output 100"); 188} 189 190TEST_F(WhileLoopTest, UnsetBodyOutputs) { 191 Init(1); 192 CreateLoop( 193 LessThanTenCond, 194 [](const Scope& s, const std::vector<Output>& inputs, 195 std::vector<Output>* outputs) { return s.status(); }, 196 error::INVALID_ARGUMENT, 197 "BuildWhileLoop: 'body' argument expected to return 1 output(s), got 0"); 198} 199 200} // namespace 201} // namespace tensorflow 202