1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/cc/ops/while_loop.h"
17#include "tensorflow/cc/client/client_session.h"
18#include "tensorflow/cc/ops/standard_ops.h"
19#include "tensorflow/core/framework/tensor_testutil.h"
20#include "tensorflow/core/graph/while_context.h"
21#include "tensorflow/core/lib/core/status_test_util.h"
22#include "tensorflow/core/platform/test.h"
23
24namespace tensorflow {
25
26namespace {
27
28class WhileLoopTest : public ::testing::Test {
29 protected:
30  WhileLoopTest() : scope_(Scope::NewRootScope()) {}
31
32  void Init(int num_inputs, DataType dtype = DT_INT32) {
33    for (int i = 0; i < num_inputs; ++i) {
34      inputs_.push_back(ops::Placeholder(scope_, dtype));
35    }
36  }
37
38  void CreateLoop(const ops::CondGraphBuilderFn& cond,
39                  const ops::BodyGraphBuilderFn& body,
40                  error::Code error_code = error::OK,
41                  const string& error_msg = "") {
42    Status s =
43        ops::BuildWhileLoop(scope_, inputs_, cond, body, kFrameName, &outputs_);
44    EXPECT_EQ(s.code(), error_code);
45    EXPECT_EQ(s.error_message(), error_msg);
46  }
47
48  template <typename T>
49  void Run(const std::vector<Input::Initializer>& input_values,
50           const std::vector<T>& expected_output_values) {
51    ClientSession session(scope_);
52
53    DCHECK_EQ(input_values.size(), inputs_.size());
54    ClientSession::FeedType feeds;
55    for (int i = 0; i < inputs_.size(); ++i) {
56      feeds.emplace(inputs_[i], input_values[i]);
57    }
58
59    std::vector<Tensor> out_tensors;
60    TF_ASSERT_OK(session.Run(feeds, outputs_, &out_tensors));
61    ASSERT_EQ(out_tensors.size(), outputs_.size());
62
63    DCHECK_EQ(expected_output_values.size(), out_tensors.size());
64    for (int i = 0; i < out_tensors.size(); ++i) {
65      test::ExpectTensorEqual<T>(
66          out_tensors[i], test::AsTensor<T>({expected_output_values[i]}, {}));
67    }
68  }
69
70  Scope scope_;
71  std::vector<Output> inputs_;
72  std::vector<Output> outputs_;
73
74  static const char* const kFrameName;
75};
76
77const char* const WhileLoopTest::kFrameName = "test_loop";
78
79Status LessThanTenCond(const Scope& s, const std::vector<Output>& inputs,
80                       Output* output) {
81  *output = ops::Less(s, inputs[0], 10);
82  return s.status();
83}
84
85Status AddOneBody(const Scope& s, const std::vector<Output>& inputs,
86                  std::vector<Output>* outputs) {
87  outputs->push_back(ops::Add(s, inputs[0], 1));
88  return s.status();
89}
90
91TEST_F(WhileLoopTest, Basic) {
92  // Create loop: while (i < 10) i += 1
93  Init(1);
94  CreateLoop(LessThanTenCond, AddOneBody);
95
96  // Verify some output invariants
97  WhileContext* while_ctx;
98  for (int i = 0; i < outputs_.size(); ++i) {
99    Node* node = outputs_[i].node();
100    ASSERT_TRUE(node->IsExit()) << "Output node " << i << ":\n"
101                                << node->DebugString();
102    ASSERT_TRUE(node->while_ctx() != nullptr) << i;
103    if (i == 0) {
104      while_ctx = node->while_ctx();
105      EXPECT_EQ(while_ctx->frame_name(), kFrameName);
106    } else {
107      EXPECT_EQ(node->while_ctx(), while_ctx) << i;
108    }
109  }
110
111  // Run the loop and test we get the expected results
112  Run<int>({1}, {10});
113  Run<int>({11}, {11});
114}
115
116TEST_F(WhileLoopTest, WrongCondOutputType) {
117  Init(1);
118  CreateLoop(
119      [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
120        *output = ops::Placeholder(s, DT_FLOAT);
121        return s.status();
122      },
123      AddOneBody, error::INVALID_ARGUMENT,
124      "BuildWhileLoop: 'cond' argument must return a boolean output, got "
125      "float");
126}
127
128// TODO(skyewm): test bad cond output shape
129
130TEST_F(WhileLoopTest, NullCondOutputNode) {
131  Init(1);
132  // TODO(skyewm): improve error message
133  CreateLoop(
134      [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
135        *output = {nullptr, 0};
136        return s.status();
137      },
138      AddOneBody, error::INVALID_ARGUMENT, "Node is null");
139}
140
141TEST_F(WhileLoopTest, InvalidCondOutputIndex) {
142  Init(1);
143  CreateLoop(
144      [](const Scope& s, const std::vector<Output>& inputs, Output* output) {
145        auto less = ops::Less(s, inputs[0], 10);
146        *output = {less.node(), 100};
147        return s.status();
148      },
149      AddOneBody, error::OUT_OF_RANGE,
150      "Node 'cond/Less' (type: 'Less', num of outputs: 1) does not have output "
151      "100");
152}
153
154TEST_F(WhileLoopTest, UnsetCondOutput) {
155  Init(1);
156  CreateLoop([](const Scope& s, const std::vector<Output>& inputs,
157                Output* output) { return s.status(); },
158             AddOneBody, error::INVALID_ARGUMENT, "Node is null");
159}
160
161// TODO(skyewm): test bad body output type
162// TODO(skyewm): test bad body output shape
163
164TEST_F(WhileLoopTest, NullBodyOutputNode) {
165  Init(1);
166  // TODO(skyewm): improve error message
167  CreateLoop(LessThanTenCond,
168             [](const Scope& s, const std::vector<Output>& inputs,
169                std::vector<Output>* outputs) {
170               outputs->push_back({nullptr, 0});
171               return s.status();
172             },
173             error::INVALID_ARGUMENT, "Node is null");
174}
175
176TEST_F(WhileLoopTest, InvalidBodyOutputIndex) {
177  Init(1);
178  CreateLoop(LessThanTenCond,
179             [](const Scope& s, const std::vector<Output>& inputs,
180                std::vector<Output>* outputs) {
181               auto add = ops::Add(s, inputs[0], 1);
182               outputs->emplace_back(add.node(), 100);
183               return s.status();
184             },
185             error::OUT_OF_RANGE,
186             "Node 'body/Add' (type: 'Add', num of outputs: 1) does not have "
187             "output 100");
188}
189
190TEST_F(WhileLoopTest, UnsetBodyOutputs) {
191  Init(1);
192  CreateLoop(
193      LessThanTenCond,
194      [](const Scope& s, const std::vector<Output>& inputs,
195         std::vector<Output>* outputs) { return s.status(); },
196      error::INVALID_ARGUMENT,
197      "BuildWhileLoop: 'body' argument expected to return 1 output(s), got 0");
198}
199
200}  // namespace
201}  // namespace tensorflow
202