/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"

namespace tensorflow {
namespace grappler {
namespace {

class LayoutOptimizerTest : public ::testing::Test {
 protected:
  void SetUp() override {
    DeviceProperties device_properties;
    device_properties.set_type("GPU");
    device_properties.mutable_environment()->insert({"architecture", "6"});
    virtual_cluster_.reset(new VirtualCluster({{"/GPU:0", device_properties}}));
  }

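  // Builds a small NHWC Conv2D graph: an iota-filled constant input of shape
  // {8, input_size, input_size, 3}, a constant filter, and a Conv2D node that
  // may optionally be pinned to `device`.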
  Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
                      const string& padding) {
    return SimpleConv2D(s, input_size, filter_size, padding, "");
  }

  Output SimpleConv2D(tensorflow::Scope* s, int input_size, int filter_size,
                      const string& padding, const string& device) {
    int batch_size = 8;
    int input_height = input_size;
    int input_width = input_size;
    int input_depth = 3;
    int filter_count = 2;
    int stride = 1;
    TensorShape input_shape(
        {batch_size, input_height, input_width, input_depth});
    Tensor input_data(DT_FLOAT, input_shape);
    test::FillIota<float>(&input_data, 1.0f);
    Output input =
        ops::Const(s->WithOpName("Input"), Input::Initializer(input_data));

    TensorShape filter_shape(
        {filter_size, filter_size, input_depth, filter_count});
    Tensor filter_data(DT_FLOAT, filter_shape);
    test::FillIota<float>(&filter_data, 1.0f);
    Output filter =
        ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));

    Output conv = ops::Conv2D(s->WithOpName("Conv2D").WithDevice(device), input,
                              filter, {1, stride, stride, 1}, padding);
    return conv;
  }

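  // Builds a Conv2DBackpropInput whose input_sizes operand is either the
  // Const itself (const_input_size == true) or an Identity of it, so tests
  // can exercise both the constant and the non-constant code paths.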
  Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
                                   int filter_size, const string& padding) {
    return SimpleConv2DBackpropInput(s, input_size, filter_size, padding, true);
  }

  Output SimpleConv2DBackpropInput(tensorflow::Scope* s, int input_size,
                                   int filter_size, const string& padding,
                                   bool const_input_size) {
    int batch_size = 128;
    int input_height = input_size;
    int input_width = input_size;
    int input_depth = 3;
    int filter_count = 2;
    int stride = 1;
    TensorShape input_sizes_shape({4});
    Tensor input_data(DT_INT32, input_sizes_shape);
    test::FillValues<int>(&input_data,
                          {batch_size, input_height, input_width, input_depth});
    Output input_sizes =
        ops::Const(s->WithOpName("InputSizes"), Input::Initializer(input_data));

    TensorShape filter_shape(
        {filter_size, filter_size, input_depth, filter_count});
    Tensor filter_data(DT_FLOAT, filter_shape);
    test::FillIota<float>(&filter_data, 1.0f);
    Output filter =
        ops::Const(s->WithOpName("Filter"), Input::Initializer(filter_data));

    int output_height = input_height;
    int output_width = input_width;
    TensorShape output_shape(
        {batch_size, output_height, output_width, filter_count});
    Tensor output_data(DT_FLOAT, output_shape);
    test::FillIota<float>(&output_data, 1.0f);
    Output output =
        ops::Const(s->WithOpName("Output"), Input::Initializer(output_data));

    Output conv_backprop_input;
    Output input_sizes_i =
        ops::Identity(s->WithOpName("InputSizesIdentity"), input_sizes);
    if (const_input_size) {
      conv_backprop_input = ops::Conv2DBackpropInput(
          s->WithOpName("Conv2DBackpropInput"), input_sizes, filter, output,
          {1, stride, stride, 1}, padding);
    } else {
      conv_backprop_input = ops::Conv2DBackpropInput(
          s->WithOpName("Conv2DBackpropInput"), input_sizes_i, filter, output,
          {1, stride, stride, 1}, padding);
    }
    return conv_backprop_input;
  }

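  // Extracts the tensor stored in the "value" attr of a Const node.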
  Tensor GetAttrValue(const NodeDef& node) {
    Tensor tensor;
    CHECK(tensor.FromProto(node.attr().at({"value"}).tensor()));
    return tensor;
  }

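  // Builds a FusedBatchNormGrad over a {16, 8, 8, 3} NHWC input;
  // `is_training` selects which kernel variant is requested.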
  Output SimpleFusedBatchNormGrad(tensorflow::Scope* s, bool is_training) {
    int batch_size = 16;
    int input_height = 8;
    int input_width = 8;
    int input_channels = 3;
    TensorShape shape({batch_size, input_height, input_width, input_channels});
    Tensor data(DT_FLOAT, shape);
    test::FillIota<float>(&data, 1.0f);
    Output x = ops::Const(s->WithOpName("Input"), Input::Initializer(data));
    Output y_backprop =
        ops::Const(s->WithOpName("YBackprop"), Input::Initializer(data));

    TensorShape shape_vector({input_channels});
    Tensor data_vector(DT_FLOAT, shape_vector);
    test::FillIota<float>(&data_vector, 2.0f);
    Output scale =
        ops::Const(s->WithOpName("Scale"), Input::Initializer(data_vector));
    Output reserve1 =
        ops::Const(s->WithOpName("Reserve1"), Input::Initializer(data_vector));
    Output reserve2 =
        ops::Const(s->WithOpName("Reserve2"), Input::Initializer(data_vector));

    ops::FusedBatchNormGrad::Attrs attrs;
    attrs.is_training_ = is_training;
    auto output =
        ops::FusedBatchNormGrad(s->WithOpName("FusedBatchNormGrad"), y_backprop,
                                x, scale, reserve1, reserve2, attrs);
    return output.x_backprop;
  }

  std::unique_ptr<VirtualCluster> virtual_cluster_;
};

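// The tests below inspect the optimized graph by node name. Nodes added by
// the optimizer carry a "-LayoutOptimizer" suffix and are named after the
// node and port they are attached to.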
TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;

  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  string input_name = "Conv2DBackpropInput-0-LayoutOptimizer";
  auto input_sizes_node = node_map.GetNode(input_name);
  CHECK(input_sizes_node);
  auto conv2d_backprop_node = node_map.GetNode("Conv2DBackpropInput");
  CHECK(conv2d_backprop_node);
  EXPECT_EQ(input_name, conv2d_backprop_node->input(0));
  auto input_sizes = GetAttrValue(*input_sizes_node);
  Tensor input_sizes_expected(DT_INT32, {4});
  test::FillValues<int>(&input_sizes_expected, {128, 3, 7, 7});
  test::ExpectTensorEqual<int>(input_sizes_expected, input_sizes);
}

TEST_F(LayoutOptimizerTest, Conv2DBackpropInputNonConstInputSizes) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2DBackpropInput(&s, 7, 2, "SAME", false);
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;

  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv2d_backprop_node = node_map.GetNode("Conv2DBackpropInput");
  CHECK(conv2d_backprop_node);
  EXPECT_EQ(conv2d_backprop_node->input(0),
            "Conv2DBackpropInput-0-VecPermuteNHWCToNCHW-LayoutOptimizer");
  auto input_sizes_node = node_map.GetNode(
      "Conv2DBackpropInput-0-VecPermuteNHWCToNCHW-LayoutOptimizer");
  CHECK(input_sizes_node);
  EXPECT_EQ(input_sizes_node->input(0), "InputSizesIdentity");
  EXPECT_EQ(input_sizes_node->op(), "DataFormatVecPermute");
}

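// The next four tests verify that a Conv2D on the default (GPU) device gets
// an NHWC-to-NCHW transpose inserted on its input, regardless of filter size
// or padding.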
TEST_F(LayoutOptimizerTest, FilterSizeIsOne) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 2, 1, "SAME");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, FilterSizeNotOne) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 2, 2, "SAME");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, EqualSizeWithValidPadding) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 2, 2, "VALID");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, EqualSizeWithSamePadding) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 2, 2, "SAME");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, NotEqualSizeWithValidPadding) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  EXPECT_TRUE(node_map.GetNode("Conv2D-0-TransposeNHWCToNCHW-LayoutOptimizer"));
}

TEST_F(LayoutOptimizerTest, Pad) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), {1, 2, 3, 4, 5, 6, 7, 8}, {4, 2});
  auto p = ops::Pad(s.WithOpName("p"), conv, c);
  auto o = ops::Identity(s.WithOpName("o"), p);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);

  auto pad = node_map.GetNode("p");
  EXPECT_EQ(pad->input(0), "Conv2D");

  auto pad_const = node_map.GetNode("p-1-LayoutOptimizer");
  EXPECT_TRUE(pad_const);
  EXPECT_TRUE(pad_const->attr().find("value") != pad_const->attr().end());
  Tensor tensor;
  EXPECT_TRUE(
      tensor.FromProto(pad_const->mutable_attr()->at({"value"}).tensor()));
  Tensor tensor_expected(DT_INT32, {4, 2});
  test::FillValues<int>(&tensor_expected, {1, 2, 7, 8, 3, 4, 5, 6});
  test::ExpectTensorEqual<int>(tensor_expected, tensor);
}

TEST_F(LayoutOptimizerTest, Connectivity) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i1 = ops::Identity(s.WithOpName("i1"), conv);
  auto i2 = ops::Identity(s.WithOpName("i2"), i1);
  auto i3 = ops::Identity(s.WithOpName("i3"), i2);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Shuffle the graph out of topological order to test the handling of
  // multi-hop connectivity (here we say two nodes are connected if all nodes
  // between them are layout agnostic). If the graph is already in topological
  // order, the problem is easier: the layout optimizer only needs to check
  // single-hop connectivity.
  NodeMap node_map_original(&item.graph);
  auto node_i1 = node_map_original.GetNode("i1");
  auto node_i2 = node_map_original.GetNode("i2");
  node_i2->Swap(node_i1);
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map_output(&output);
  auto node_i2_output = node_map_output.GetNode("i2");
  // The layout optimizer should process i2, as it detects that i2 is
  // connected to the Conv2D node two hops away. Similarly, i1 is processed
  // because it is directly connected to the Conv2D node. The two added
  // transposes between i1 and i2 should cancel each other, so i2 ends up
  // directly connected to i1.
  EXPECT_EQ(node_i2_output->input(0), "i1");
}

TEST_F(LayoutOptimizerTest, ConnectivityBinaryOpWithInputScalarAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i1 = ops::Identity(s.WithOpName("i1"), conv);
  auto i2 = ops::Identity(s.WithOpName("i2"), i1);
  auto scalar_sub = ops::Const(s.WithOpName("scalar_sub"), 3.0f, {});
  auto sub = ops::Sub(s.WithOpName("sub"), scalar_sub, i2);
  auto i3 = ops::Identity(s.WithOpName("i3"), sub);
  auto i4 = ops::Identity(s.WithOpName("i4"), i3);
  auto i5 = ops::Identity(s.WithOpName("i5"), i4);
  auto scalar_mul = ops::Const(s.WithOpName("scalar_mul"), 3.0f, {});
  auto mul = ops::Mul(s.WithOpName("mul"), scalar_mul, i5);
  auto i6 = ops::Identity(s.WithOpName("i6"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  // Shuffle the graph out of topological order, for the same reason as in the
  // Connectivity test above.
  NodeMap node_map_original(&item.graph);
  auto node_i1 = node_map_original.GetNode("i1");
  auto node_mul = node_map_original.GetNode("mul");
  node_mul->Swap(node_i1);
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map_output(&output);
  auto mul_node = node_map_output.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "scalar_mul");
  EXPECT_EQ(mul_node->input(1), "i5");
}

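// Nodes in the fetch set must keep their original NHWC layout, since the
// caller observes their output directly.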
TEST_F(LayoutOptimizerTest, PreserveFetch) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i = ops::Identity(s.WithOpName("i"), conv);
  GrapplerItem item;
  item.fetch.push_back("Conv2D");
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}

TEST_F(LayoutOptimizerTest, EmptyDevice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
}

TEST_F(LayoutOptimizerTest, GPUDevice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv =
      SimpleConv2D(&s, 4, 2, "VALID", "/job:w/replica:0/task:0/device:gpu:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
}

TEST_F(LayoutOptimizerTest, CPUDeviceLowercase) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv =
      SimpleConv2D(&s, 4, 2, "VALID", "/job:w/replica:0/task:0/device:cpu:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}

TEST_F(LayoutOptimizerTest, CPUDeviceUppercase) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID", "/CPU:0");
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {conv});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("Conv2D");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}

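// FusedBatchNormGrad is converted to NCHW when is_training is true; the
// inference variant (is_training == false) is left in NHWC.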
TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingTrue) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x_backprop = SimpleFusedBatchNormGrad(&s, true);
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("FusedBatchNormGrad");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NCHW");
}

TEST_F(LayoutOptimizerTest, FusedBatchNormGradTrainingFalse) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto x_backprop = SimpleFusedBatchNormGrad(&s, false);
  Output fetch = ops::Identity(s.WithOpName("Fetch"), {x_backprop});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto conv_node = node_map.GetNode("FusedBatchNormGrad");
  EXPECT_EQ(conv_node->attr().at({"data_format"}).s(), "NHWC");
}

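// Split takes its split dimension as a scalar input. A constant dimension is
// rewritten in place to its NCHW position (N:0->0, H:1->2, W:2->3, C:3->1);
// a non-constant dimension gets a DataFormatDimMap node instead.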
TEST_F(LayoutOptimizerTest, SplitDimC) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 3, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
  EXPECT_EQ(split_const->op(), "Const");
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 1);
}

TEST_F(LayoutOptimizerTest, SplitDimH) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 6, 2, "SAME");
  auto c = ops::Const(s.WithOpName("c"), 1, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
  EXPECT_EQ(split_const->op(), "Const");
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 2);
}

TEST_F(LayoutOptimizerTest, SplitDimW) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 2, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
  EXPECT_EQ(split_const->op(), "Const");
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 3);
}

TEST_F(LayoutOptimizerTest, SplitDimN) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 0, {});
  auto split = ops::Split(s.WithOpName("split"), c, conv, 2);
  auto i = ops::Identity(s.WithOpName("i"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto split_const = node_map.GetNode("split-0-LayoutOptimizer");
  EXPECT_EQ(split_const->op(), "Const");
  EXPECT_EQ(split_const->attr().at({"value"}).tensor().int_val(0), 0);
}

TEST_F(LayoutOptimizerTest, SplitNonConstDim) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto c = ops::Const(s.WithOpName("c"), 0, {});
  auto i1 = ops::Identity(s.WithOpName("i1"), c);
  auto split = ops::Split(s.WithOpName("split"), i1, conv, 2);
  auto i2 = ops::Identity(s.WithOpName("i2"), split[0]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto split_node = node_map.GetNode("split");
  EXPECT_EQ(split_node->input(0), "split-0-DimMapNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(split_node->input(1), "Conv2D");
  auto map_node = node_map.GetNode("split-0-DimMapNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(map_node->op(), "DataFormatDimMap");
  EXPECT_EQ(map_node->input(0), "i1");
}

565
566TEST_F(LayoutOptimizerTest, SplitSamePortToMultipleInputsOfSameNode) {
567  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
568  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
569  auto axis = ops::Const(s.WithOpName("axis"), 3);
570  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
571  auto concat =
572      ops::Concat(s.WithOpName("concat"), {split[1], split[1], split[1]}, axis);
573  auto o = ops::Identity(s.WithOpName("o"), concat);
574  GrapplerItem item;
575  TF_CHECK_OK(s.ToGraphDef(&item.graph));
576  LayoutOptimizer optimizer;
577  GraphDef output;
578  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
579  NodeMap node_map(&output);
580  auto concat_node = node_map.GetNode("concat");
581  EXPECT_EQ(concat_node->input(0), "split:1");
582  EXPECT_EQ(concat_node->input(1), "split:1");
583  EXPECT_EQ(concat_node->input(2), "split:1");
584  EXPECT_EQ(concat_node->input(3), "concat-3-LayoutOptimizer");
585  auto concat_dim = node_map.GetNode("concat-3-LayoutOptimizer");
586  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
587}
588
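// Concat's axis input is handled like Split's dimension: constant axes are
// remapped to their NCHW positions, while non-constant axes go through a
// DataFormatDimMap node.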
TEST_F(LayoutOptimizerTest, ConcatDimH) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "SAME");
  auto axis = ops::Const(s.WithOpName("axis"), 1);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 2);
}

TEST_F(LayoutOptimizerTest, ConcatNonConst) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "SAME");
  auto axis = ops::Const(s.WithOpName("axis"), 1);
  auto i = ops::Identity(s.WithOpName("i"), axis);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, i);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-DimMapNHWCToNCHW-LayoutOptimizer");
  auto concat_dim =
      node_map.GetNode("concat-2-DimMapNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(concat_dim->op(), "DataFormatDimMap");
  EXPECT_EQ(concat_dim->input(0), "i");
}

TEST_F(LayoutOptimizerTest, ConcatDimW) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "SAME");
  auto axis = ops::Const(s.WithOpName("axis"), 2);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 3);
}

TEST_F(LayoutOptimizerTest, ConcatDimN) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto axis = ops::Const(s.WithOpName("axis"), 0);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 0);
}

TEST_F(LayoutOptimizerTest, ConcatDimC) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto axis = ops::Const(s.WithOpName("axis"), 3);
  auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
  auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
  auto o = ops::Identity(s.WithOpName("o"), concat);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto concat_node = node_map.GetNode("concat");
  EXPECT_EQ(concat_node->input(0), "split");
  EXPECT_EQ(concat_node->input(1), "split:1");
  EXPECT_EQ(concat_node->input(2), "concat-2-LayoutOptimizer");
  auto concat_dim = node_map.GetNode("concat-2-LayoutOptimizer");
  EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
}

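// The SumProcessor is currently disabled (see the TODO below), so this test
// only exercises that optimization completes on a reduction.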
TEST_F(LayoutOptimizerTest, Sum) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto reduction_indices =
      ops::Const(s.WithOpName("reduction_indices"), {0, 1, 2}, {3});
  auto sum = ops::Sum(s.WithOpName("sum"), conv, reduction_indices);
  auto o = ops::Identity(s.WithOpName("o"), sum);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  // TODO(yaozhang): enable SumProcessor with auto-tuning. Currently disabled
  // because of the worse performance in some cases.
  /*
  NodeMap node_map(&output);
  auto sum_node = node_map.GetNode("sum");
  EXPECT_EQ(sum_node->input(0), "Conv2D");
  EXPECT_EQ(sum_node->input(1), "LayoutOptimizer-sum-reduction_indices");
  auto sum_const = node_map.GetNode("LayoutOptimizer-sum-reduction_indices");
  Tensor tensor;
  EXPECT_TRUE(
      tensor.FromProto(sum_const->mutable_attr()->at({"value"}).tensor()));
  Tensor tensor_expected(DT_INT32, {3});
  test::FillValues<int>(&tensor_expected, {0, 2, 3});
  test::ExpectTensorEqual<int>(tensor_expected, tensor);
  */
}

TEST_F(LayoutOptimizerTest, MulScalarAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {});
  auto mul = ops::Mul(s.WithOpName("mul"), scalar, conv);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "scalar");
  EXPECT_EQ(mul_node->input(1), "Conv2D");
}

TEST_F(LayoutOptimizerTest, Mul4DAndScalar) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {});
  auto mul = ops::Mul(s.WithOpName("mul"), conv, scalar);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "Conv2D");
  EXPECT_EQ(mul_node->input(1), "scalar");
}

TEST_F(LayoutOptimizerTest, Mul4DAndUnknownRank) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto unknown_rank =
      ops::Placeholder(s.WithOpName("unknown"), DT_FLOAT,
                       ops::Placeholder::Shape(PartialTensorShape()));
  Output c = ops::Const(s.WithOpName("c"), 3.0f, {8, 2, 2, 2});
  Output mul = ops::Mul(s.WithOpName("mul"), conv, unknown_rank);
  auto o = ops::AddN(s.WithOpName("o"), {mul, c});
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  // Node mul should not be processed by the layout optimizer, because one of
  // its inputs is of unknown rank.
  EXPECT_EQ(mul_node->input(0),
            "Conv2D-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(mul_node->input(1), "unknown");
}

TEST_F(LayoutOptimizerTest, Mul4DAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto i = ops::Identity(s.WithOpName("i"), conv);
  auto mul = ops::Mul(s.WithOpName("mul"), conv, i);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "Conv2D");
  EXPECT_EQ(mul_node->input(1), "i");
}

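// When a binary op mixes a 4-D NCHW operand with a vector, the vector is
// reshaped to {1, C, 1, 1} so that broadcasting still lines up with the
// channel dimension.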
TEST_F(LayoutOptimizerTest, Mul4DAndVector) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2});
  auto mul = ops::Mul(s.WithOpName("mul"), conv, vector);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "Conv2D");
  EXPECT_EQ(mul_node->input(1), "mul-1-ReshapeNHWCToNCHW-LayoutOptimizer");
  auto mul_const = node_map.GetNode("mul-1-ReshapeConst-LayoutOptimizer");
  Tensor tensor;
  EXPECT_TRUE(
      tensor.FromProto(mul_const->mutable_attr()->at({"value"}).tensor()));
  Tensor tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&tensor_expected, {1, 2, 1, 1});
  test::ExpectTensorEqual<int>(tensor_expected, tensor);
}

TEST_F(LayoutOptimizerTest, MulVectorAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2});
  auto mul = ops::Mul(s.WithOpName("mul"), vector, conv);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto mul_node = node_map.GetNode("mul");
  EXPECT_EQ(mul_node->input(0), "mul-0-ReshapeNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(mul_node->input(1), "Conv2D");
  auto mul_const = node_map.GetNode("mul-0-ReshapeConst-LayoutOptimizer");
  Tensor tensor;
  EXPECT_TRUE(
      tensor.FromProto(mul_const->mutable_attr()->at({"value"}).tensor()));
  Tensor tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&tensor_expected, {1, 2, 1, 1});
  test::ExpectTensorEqual<int>(tensor_expected, tensor);
}

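// Slice's begin and size inputs are 1-D index vectors: constants are
// permuted at optimization time, while non-constants get a
// DataFormatVecPermute node.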
TEST_F(LayoutOptimizerTest, SliceConst) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto begin = ops::Const(s.WithOpName("begin"), {0, 2, 3, 1}, {4});
  auto size = ops::Const(s.WithOpName("size"), {4, 1, 2, 4}, {4});
  auto slice = ops::Slice(s.WithOpName("slice"), conv, begin, size);
  auto o = ops::Identity(s.WithOpName("o"), slice);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto slice_node = node_map.GetNode("slice");
  EXPECT_EQ(slice_node->input(0), "Conv2D");
  EXPECT_EQ(slice_node->input(1), "slice-1-LayoutOptimizer");
  EXPECT_EQ(slice_node->input(2), "slice-2-LayoutOptimizer");

  auto begin_const = node_map.GetNode("slice-1-LayoutOptimizer");
  Tensor begin_tensor;
  EXPECT_TRUE(begin_tensor.FromProto(
      begin_const->mutable_attr()->at({"value"}).tensor()));
  Tensor begin_tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&begin_tensor_expected, {0, 1, 2, 3});
  test::ExpectTensorEqual<int>(begin_tensor_expected, begin_tensor);

  auto size_const = node_map.GetNode("slice-2-LayoutOptimizer");
  Tensor size_tensor;
  EXPECT_TRUE(size_tensor.FromProto(
      size_const->mutable_attr()->at({"value"}).tensor()));
  Tensor size_tensor_expected(DT_INT32, {4});
  test::FillValues<int>(&size_tensor_expected, {4, 4, 1, 2});
  test::ExpectTensorEqual<int>(size_tensor_expected, size_tensor);
}

TEST_F(LayoutOptimizerTest, SliceNonConst) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 5, 2, "VALID");
  auto begin = ops::Const(s.WithOpName("begin"), {0, 2, 3, 1}, {4});
  auto ibegin = ops::Identity(s.WithOpName("ibegin"), begin);
  auto size = ops::Const(s.WithOpName("size"), {4, 1, 2, 4}, {4});
  auto isize = ops::Identity(s.WithOpName("isize"), size);
  auto slice = ops::Slice(s.WithOpName("slice"), conv, ibegin, isize);
  auto o = ops::Identity(s.WithOpName("o"), slice);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto slice_node = node_map.GetNode("slice");
  EXPECT_EQ(slice_node->input(0), "Conv2D");
  EXPECT_EQ(slice_node->input(1),
            "slice-1-VecPermuteNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(slice_node->input(2),
            "slice-2-VecPermuteNHWCToNCHW-LayoutOptimizer");
  auto perm1 = node_map.GetNode("slice-1-VecPermuteNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(perm1->op(), "DataFormatVecPermute");
  EXPECT_EQ(perm1->input(0), "ibegin");
  auto perm2 = node_map.GetNode("slice-2-VecPermuteNHWCToNCHW-LayoutOptimizer");
  EXPECT_EQ(perm2->op(), "DataFormatVecPermute");
  EXPECT_EQ(perm2->input(0), "isize");
}

TEST_F(LayoutOptimizerTest, DoNotApplyOptimizerTwice) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto scalar =
      ops::Const(s.WithOpName("AlreadyApplied-LayoutOptimizer"), 3.0f, {});
  auto mul = ops::Mul(s.WithOpName("mul"), scalar, scalar);
  auto o = ops::Identity(s.WithOpName("o"), mul);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  EXPECT_TRUE(errors::IsInvalidArgument(status));
}

TEST_F(LayoutOptimizerTest, ShapeNWithInputs4DAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto shapen = ops::ShapeN(s.WithOpName("shapen"), {conv, conv});
  auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto shapen_node = node_map.GetNode("shapen");
  EXPECT_EQ(shapen_node->input(0), "Conv2D");
  EXPECT_EQ(shapen_node->input(1), "Conv2D");
  auto add_node = node_map.GetNode("add");
  EXPECT_EQ(add_node->input(0),
            "shapen-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(add_node->input(1),
            "shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
  auto vec_permute1 =
      node_map.GetNode("shapen-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(vec_permute1->input(0), "shapen");
  EXPECT_EQ(vec_permute1->op(), "DataFormatVecPermute");
  auto vec_permute2 =
      node_map.GetNode("shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(vec_permute2->input(0), "shapen:1");
  EXPECT_EQ(vec_permute2->op(), "DataFormatVecPermute");
}

TEST_F(LayoutOptimizerTest, ShapeNWithInputsVectorAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), 3.0f, {7});
  auto shapen = ops::ShapeN(s.WithOpName("shapen"), {vector, conv});
  auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto shapen_node = node_map.GetNode("shapen");
  EXPECT_EQ(shapen_node->input(0), "vector");
  EXPECT_EQ(shapen_node->input(1), "Conv2D");
  auto add_node = node_map.GetNode("add");
  EXPECT_EQ(add_node->input(0), "shapen");
  EXPECT_EQ(add_node->input(1),
            "shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
  auto vec_permute =
      node_map.GetNode("shapen-0-1-VecPermuteNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(vec_permute->input(0), "shapen:1");
  EXPECT_EQ(vec_permute->op(), "DataFormatVecPermute");
}

TEST_F(LayoutOptimizerTest, ShapeNWithInputs4DAndNoNeedToTransform4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto tensor_4d = ops::Const(s.WithOpName("tensor_4d"), 3.0f, {1, 1, 1, 3});
  auto i1 = ops::Identity(s.WithOpName("i1"), tensor_4d);
  Output i2 = ops::Identity(s.WithOpName("i2"), i1);
  auto shapen = ops::ShapeN(s.WithOpName("shapen"), {conv, i2});
  auto add = ops::Add(s.WithOpName("add"), shapen[0], shapen[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto shapen_node = node_map.GetNode("shapen");
  EXPECT_EQ(shapen_node->input(0), "Conv2D");
  EXPECT_EQ(shapen_node->input(1), "i2");
}

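// The remaining tests cover control-flow and multi-output ops (Switch, Merge,
// IdentityN), where transposes must be attached to the correct output ports;
// the Complex test additionally checks that the inserted transpose carries
// the op's output dtype.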
TEST_F(LayoutOptimizerTest, Switch) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  ops::Variable ctrl(s.WithOpName("ctrl"), {}, DT_BOOL);
  auto sw = ops::Switch(s.WithOpName("switch"), conv, ctrl);
  auto i1 = ops::Identity(s.WithOpName("i1"), sw.output_true);
  auto i2 = ops::Identity(s.WithOpName("i2"), sw.output_false);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto switch_node = node_map.GetNode("switch");
  EXPECT_EQ(switch_node->input(0), "Conv2D");
  EXPECT_EQ(switch_node->input(1), "ctrl");
  auto i1_node = node_map.GetNode("i1");
  auto i2_node = node_map.GetNode("i2");
  auto trans1 = node_map.GetNode(i1_node->input(0));
  EXPECT_EQ(trans1->input(0), "switch:1");
  auto trans2 = node_map.GetNode(i2_node->input(0));
  EXPECT_EQ(trans2->input(0), "switch");
}

TEST_F(LayoutOptimizerTest, MergeBothInputsConvertible) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  Output i1 = ops::Identity(s.WithOpName("i1"), conv);
  auto merge = ops::Merge(s.WithOpName("merge"), {conv, i1});
  auto i2 = ops::Identity(s.WithOpName("i2"), merge.output);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto merge_node = node_map.GetNode("merge");
  EXPECT_EQ(merge_node->input(0), "Conv2D");
  EXPECT_EQ(merge_node->input(1), "i1");
  auto i2_node = node_map.GetNode("i2");
  EXPECT_EQ(i2_node->input(0), "merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  auto transpose =
      node_map.GetNode("merge-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(transpose->input(0), "merge");
}

TEST_F(LayoutOptimizerTest, MergeOneInputNotConvertible) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto tensor_4d = ops::Const(s.WithOpName("tensor_4d"), 3.0f, {1, 1, 1, 3});
  auto merge = ops::Merge(s.WithOpName("merge"), {tensor_4d, conv});
  auto i2 = ops::Identity(s.WithOpName("i2"), merge.output);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto merge_node = node_map.GetNode("merge");
  EXPECT_EQ(merge_node->input(0), "tensor_4d");
  EXPECT_EQ(merge_node->input(1),
            "Conv2D-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
}

TEST_F(LayoutOptimizerTest, Complex) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto comp = ops::Complex(s.WithOpName("complex"), conv, conv);
  auto i = ops::Identity(s.WithOpName("i"), comp);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto complex_node = node_map.GetNode("complex");
  EXPECT_EQ(complex_node->input(0), "Conv2D");
  EXPECT_EQ(complex_node->input(1), "Conv2D");
  auto trans =
      node_map.GetNode("complex-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(trans->attr().at("T").type(), DT_COMPLEX64);
}

TEST_F(LayoutOptimizerTest, IdentityNWithInputsVectorAnd4D) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto conv = SimpleConv2D(&s, 4, 2, "VALID");
  auto vector = ops::Const(s.WithOpName("vector"), 3.0f, {2});
  auto identity_n = ops::IdentityN(s.WithOpName("identity_n"), {vector, conv});
  auto add = ops::Add(s.WithOpName("add"), identity_n[0], identity_n[1]);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  LayoutOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
  NodeMap node_map(&output);
  auto i = node_map.GetNode("identity_n");
  EXPECT_EQ(i->input(0), "vector");
  EXPECT_EQ(i->input(1), "Conv2D");
  auto trans =
      node_map.GetNode("identity_n-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
  EXPECT_EQ(trans->input(0), "identity_n:1");
  auto add_node = node_map.GetNode("add");
  EXPECT_EQ(add_node->input(0), "identity_n");
  EXPECT_EQ(add_node->input(1),
            "identity_n-0-1-TransposeNCHWToNHWC-LayoutOptimizer");
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow