1c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner
3c247826219dd2541c6aba4578a03a171375d9290Benoit SteinerLicensed under the Apache License, Version 2.0 (the "License");
4c247826219dd2541c6aba4578a03a171375d9290Benoit Steineryou may not use this file except in compliance with the License.
5c247826219dd2541c6aba4578a03a171375d9290Benoit SteinerYou may obtain a copy of the License at
6c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner
7c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner    http://www.apache.org/licenses/LICENSE-2.0
8c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner
9c247826219dd2541c6aba4578a03a171375d9290Benoit SteinerUnless required by applicable law or agreed to in writing, software
10c247826219dd2541c6aba4578a03a171375d9290Benoit Steinerdistributed under the License is distributed on an "AS IS" BASIS,
11c247826219dd2541c6aba4578a03a171375d9290Benoit SteinerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12c247826219dd2541c6aba4578a03a171375d9290Benoit SteinerSee the License for the specific language governing permissions and
13c247826219dd2541c6aba4578a03a171375d9290Benoit Steinerlimitations under the License.
14c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner==============================================================================*/
15c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner
16c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
17c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner#include "tensorflow/cc/ops/standard_ops.h"
18c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner#include "tensorflow/core/framework/node_def.pb.h"
19c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner#include "tensorflow/core/grappler/grappler_item.h"
20c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
2133d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu#include "tensorflow/core/grappler/optimizers/constant_folding.h"
22e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee#include "tensorflow/core/grappler/optimizers/model_pruner.h"
23c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner#include "tensorflow/core/grappler/utils.h"
24c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner#include "tensorflow/core/lib/core/status_test_util.h"
25c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner#include "tensorflow/core/platform/test.h"
26c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner
27c247826219dd2541c6aba4578a03a171375d9290Benoit Steinernamespace tensorflow {
28c247826219dd2541c6aba4578a03a171375d9290Benoit Steinernamespace grappler {
29c247826219dd2541c6aba4578a03a171375d9290Benoit Steinernamespace {
30c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner
3198ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlowerstring OptimizedName(const string& name) {
3298ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower  return AddPrefixToNodeName(name, kArithmeticOptimizer);
3398ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower}
3498ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower
3519f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlowervoid VerifyGraphsMatch(const GraphDef& original_graph,
3619f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower                       const GraphDef& optimized_graph, int line) {
3719f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower  EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << line;
3819f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower  for (int i = 0; i < original_graph.node_size(); ++i) {
3919f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower    const NodeDef& original = original_graph.node(i);
4019f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower    const NodeDef& optimized = optimized_graph.node(i);
4119f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower    EXPECT_EQ(original.name(), optimized.name()) << line;
4219f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower    EXPECT_EQ(original.op(), optimized.op()) << line;
4319f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower    EXPECT_EQ(original.input_size(), optimized.input_size()) << line;
4419f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower    for (int j = 0; j < original.input_size(); ++j) {
4519f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      EXPECT_EQ(original.input(j), optimized.input(j)) << line;
4619f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower    }
4719f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower  }
4819f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower}
4919f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower
50c247826219dd2541c6aba4578a03a171375d9290Benoit Steinerclass ArithmeticOptimizerTest : public ::testing::Test {};
51c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner
52c247826219dd2541c6aba4578a03a171375d9290Benoit SteinerTEST_F(ArithmeticOptimizerTest, NoOp) {
53c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner  // This trivial graph is so basic there's nothing to optimize.
54c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
55c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner  GrapplerItem item;
56c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner  CHECK(fake_input.NextItem(&item));
57c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner
58c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner  ArithmeticOptimizer optimizer;
59c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner  GraphDef output;
60de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  Status status = optimizer.Optimize(nullptr, item, &output);
61de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  TF_EXPECT_OK(status);
6219f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower  VerifyGraphsMatch(item.graph, output, __LINE__);
63c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner}
64c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner
65c247826219dd2541c6aba4578a03a171375d9290Benoit SteinerTEST_F(ArithmeticOptimizerTest, OpDedupping) {
66c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
67c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner  Output c1 = ops::Const(s.WithOpName("c1"), {3.14, 2.7}, {1, 2});
68c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner  Output c2 = ops::Const(s.WithOpName("c2"), {3.14, 2.7}, {1, 2});
6972d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  Output div = ops::Div(s.WithOpName("div"), c1, c2);
70c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner  GrapplerItem item;
71c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner  TF_CHECK_OK(s.ToGraphDef(&item.graph));
7277b60c1ac63d0f188c4108ecb64bbe40004b2b8fA. Unique TensorFlower  item.fetch = {"div"};
73c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner
74c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner  ArithmeticOptimizer optimizer;
75c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner  GraphDef output;
76c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner  Status status = optimizer.Optimize(nullptr, item, &output);
77c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner  TF_EXPECT_OK(status);
78de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  // Run the optimizer twice to make sure the rewrite is idempotent.
79de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  item.graph.Swap(&output);
80de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  status = optimizer.Optimize(nullptr, item, &output);
81de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  TF_EXPECT_OK(status);
82c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner
83c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner  EXPECT_EQ(2, output.node_size());
84c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner  const NodeDef& new_c1 = output.node(0);
85c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner  EXPECT_EQ("c1", new_c1.name());
8672d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  const NodeDef& new_div = output.node(1);
8772d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  EXPECT_EQ("div", new_div.name());
8872d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  EXPECT_EQ(2, new_div.input_size());
8972d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  EXPECT_EQ("c1", new_div.input(0));
9072d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  EXPECT_EQ("c1", new_div.input(1));
91c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner}
92c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner
93cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, OpDeduppingAssertAndCheckNumerics) {
94cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
95cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower  Output p = ops::Placeholder(s, DT_BOOL, ops::Placeholder::Shape({}));
96cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower  Output c = ops::Const(s.WithOpName("c"), {3.14, 2.7}, {1, 2});
97cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower  auto check1 = ops::CheckNumerics(s.WithOpName("check1"), c, "foo");
98cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower  auto check2 = ops::CheckNumerics(s.WithOpName("check2"), c, "foo");
99cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower  auto assert1 = ops::Assert(s.WithOpName("assert1"), p, {c});
100cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower  auto assert2 = ops::Assert(s.WithOpName("assert2"), p, {c});
10172d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  Output div = ops::Div(s.WithOpName("div").WithControlDependencies(
10272d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower                            {assert1.operation, assert2.operation}),
10372d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower                        check1, check2);
104cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower  GrapplerItem item;
105cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower  TF_CHECK_OK(s.ToGraphDef(&item.graph));
10677b60c1ac63d0f188c4108ecb64bbe40004b2b8fA. Unique TensorFlower  item.fetch = {"div"};
107cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower
108cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower  ArithmeticOptimizer optimizer;
109cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower  GraphDef output;
110cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower  Status status = optimizer.Optimize(nullptr, item, &output);
111cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower  TF_EXPECT_OK(status);
112cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower  // Run the optimizer twice to make sure the rewrite is idempotent.
113cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower  item.graph.Swap(&output);
114cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower  status = optimizer.Optimize(nullptr, item, &output);
115cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower  TF_EXPECT_OK(status);
116cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower
117cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower  EXPECT_EQ(5, output.node_size());
11872d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  const NodeDef& new_div = output.node(3);
11972d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  EXPECT_EQ(4, new_div.input_size());
12072d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  EXPECT_EQ("check1", new_div.input(0));
12172d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  EXPECT_EQ("check1", new_div.input(1));
12272d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  EXPECT_EQ("^assert1", new_div.input(2));
12372d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  EXPECT_EQ("^assert1", new_div.input(3));
124cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower}
125cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower
126f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
127f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
128f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower  Output c1 = ops::Const(s.WithOpName("c1"), {1.0f, 2.0f}, {1, 2});
129f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower  Output c2 = ops::Const(s.WithOpName("c2"), {3.0f, 4.0f}, {1, 2});
1306ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  Output mul1 = ops::Mul(s.WithOpName("mul1"), c1, c2);
1316ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  Output mul2 = ops::Mul(s.WithOpName("mul2"), c2, c1);
13272d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  Output div1 = ops::Div(s.WithOpName("div1"), mul1, mul2);
133f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower  GrapplerItem item;
134f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower  TF_CHECK_OK(s.ToGraphDef(&item.graph));
13577b60c1ac63d0f188c4108ecb64bbe40004b2b8fA. Unique TensorFlower  item.fetch = {"div"};
136f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower
137f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower  ArithmeticOptimizer optimizer;
138f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower  GraphDef output;
139f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower  Status status = optimizer.Optimize(nullptr, item, &output);
140f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower  TF_EXPECT_OK(status);
141de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  // Run the optimizer twice to make sure the rewrite is idempotent.
142de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  item.graph.Swap(&output);
143de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  status = optimizer.Optimize(nullptr, item, &output);
144de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  TF_EXPECT_OK(status);
145f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower
146f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower  EXPECT_EQ(4, output.node_size());
147f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower  const NodeDef& new_c1 = output.node(0);
148f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower  EXPECT_EQ("c1", new_c1.name());
149f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower  const NodeDef& new_c2 = output.node(1);
150f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower  EXPECT_EQ("c2", new_c2.name());
1516ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  const NodeDef& new_mul1 = output.node(2);
1526ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  EXPECT_EQ("mul1", new_mul1.name());
1536ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  EXPECT_EQ(2, new_mul1.input_size());
1546ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  EXPECT_EQ("c1", new_mul1.input(0));
1556ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  EXPECT_EQ("c2", new_mul1.input(1));
15672d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  const NodeDef& new_div1 = output.node(3);
15772d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  EXPECT_EQ("div1", new_div1.name());
15872d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  EXPECT_EQ(2, new_div1.input_size());
15972d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  EXPECT_EQ("mul1", new_div1.input(0));
16072d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  EXPECT_EQ("mul1", new_div1.input(1));
16172d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower}
16272d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower
16372d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, MulToSquare) {
16472d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
16572d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
16672d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  Output d = ops::Const(s.WithOpName("d"), {3.0f, 4.0f}, {1, 2});
16772d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  Output mul = ops::Mul(s.WithControlDependencies(d).WithOpName("mul"), c, c);
16872d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  Output id = ops::Identity(s.WithOpName("id"), mul);
16972d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  GrapplerItem item;
17072d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  TF_CHECK_OK(s.ToGraphDef(&item.graph));
17172d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower
17272d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  ArithmeticOptimizer optimizer;
17372d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  GraphDef output;
17472d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  Status status = optimizer.Optimize(nullptr, item, &output);
17572d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  TF_EXPECT_OK(status);
17672d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower
17772d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  EXPECT_EQ(5, output.node_size());
17898ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower  EXPECT_EQ("id", output.node(3).name());
17998ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower  EXPECT_EQ(OptimizedName("mul_square"), output.node(3).input(0));
18072d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  EXPECT_EQ("Square", output.node(4).op());
18198ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower  EXPECT_EQ(OptimizedName("mul_square"), output.node(4).name());
18272d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  EXPECT_EQ(2, output.node(4).input_size());
18372d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  EXPECT_EQ("c", output.node(4).input(0));
18472d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower  EXPECT_EQ("^d", output.node(4).input(1));
185f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower}
186f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower
187b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) {
188b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
189b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower  Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
190b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower  Output neg1 = ops::Neg(s.WithOpName("neg1"), c);
191b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower  Output neg2 = ops::Neg(s.WithOpName("neg2"), neg1);
192b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower  Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), neg2);
193b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower  Output recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1);
194b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower  Output id = ops::Identity(s.WithOpName("id"), recip2);
195b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower  GrapplerItem item;
196b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower  TF_CHECK_OK(s.ToGraphDef(&item.graph));
197b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower
198b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower  ArithmeticOptimizer optimizer;
199b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower  GraphDef output;
200b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower  Status status = optimizer.Optimize(nullptr, item, &output);
201b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower  TF_EXPECT_OK(status);
202b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower
203b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower  EXPECT_EQ(6, output.node_size());
204b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower  EXPECT_EQ("c", output.node(1).input(0));
205b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower  EXPECT_EQ("c", output.node(3).input(0));
206b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower  EXPECT_EQ("c", output.node(5).input(0));
207b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower}
208b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower
2096a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithChain) {
2106a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2116a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower  Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
2126a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower  Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
2136a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower  Output id1 = ops::Identity(s.WithOpName("id1"), recip1);
2146a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower  Output squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
2156a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower  Output recip2 = ops::Reciprocal(s.WithOpName("recip2"), squeeze);
2166a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower  Output id2 = ops::Identity(s.WithOpName("id2"), recip2);
2176a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower  GrapplerItem item;
2186a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower  TF_CHECK_OK(s.ToGraphDef(&item.graph));
2196a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower
2206a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower  ArithmeticOptimizer optimizer;
2216a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower  GraphDef output;
2226a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower  Status status = optimizer.Optimize(nullptr, item, &output);
2236a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower  TF_EXPECT_OK(status);
224de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  // Run the optimizer twice to make sure the rewrite is idempotent.
225de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  item.graph.Swap(&output);
226de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  status = optimizer.Optimize(nullptr, item, &output);
227de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  TF_EXPECT_OK(status);
2286a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower
2296a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower  EXPECT_EQ(6, output.node_size());
2306a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower  EXPECT_EQ("squeeze", output.node(5).input(0));
2316a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower  EXPECT_EQ("c", output.node(2).input(0));
2326a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower}
2336a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower
234f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithControlChain) {
235f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
236f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
237f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), c);
238f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  Output id1 = ops::Identity(s.WithOpName("id1"), recip1);
239f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  Output squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1);
240f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  Output recip2 = ops::Reciprocal(
241f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      s.WithOpName("recip2").WithControlDependencies(squeeze), c);
242f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  Output id2 = ops::Identity(s.WithOpName("id2"), recip2);
243f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  GrapplerItem item;
244f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  TF_CHECK_OK(s.ToGraphDef(&item.graph));
245f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower
246f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  ArithmeticOptimizer optimizer;
247f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  GraphDef output;
248f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  Status status = optimizer.Optimize(nullptr, item, &output);
249f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  TF_EXPECT_OK(status);
250f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower
251f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  // The optimizer should be a noop.
252f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  EXPECT_EQ(item.graph.node_size(), output.node_size());
253f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  for (int i = 0; i < item.graph.node_size(); ++i) {
254f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower    const NodeDef& original = item.graph.node(i);
255f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower    const NodeDef& optimized = output.node(i);
256f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower    EXPECT_EQ(original.name(), optimized.name());
257f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower    EXPECT_EQ(original.op(), optimized.op());
258f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower    EXPECT_EQ(original.input_size(), optimized.input_size());
259f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower    for (int j = 0; j < original.input_size(); ++j) {
260f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ(original.input(j), optimized.input(j));
261f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower    }
262f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  }
263f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower}
264f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower
265de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
2666ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
2676ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
2686ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  Output add = ops::Add(s.WithOpName("add"), x, x);
2696ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  Output id = ops::Identity(s.WithOpName("id"), add);
2706ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower
2716ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  GrapplerItem item;
2726ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  TF_CHECK_OK(s.ToGraphDef(&item.graph));
2736ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower
2746ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  ArithmeticOptimizer optimizer;
2756ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  GraphDef output;
2766ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  Status status = optimizer.Optimize(nullptr, item, &output);
2776ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  TF_EXPECT_OK(status);
278de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  // Run the optimizer twice to make sure the rewrite is idempotent.
279de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  item.graph.Swap(&output);
280de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  status = optimizer.Optimize(nullptr, item, &output);
281de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  TF_EXPECT_OK(status);
2826ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower
2836ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  EXPECT_EQ(5, output.node_size());
2846ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  const NodeDef& new_const = output.node(3);
28598ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower  EXPECT_EQ(OptimizedName("add_const"), new_const.name());
286f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  EXPECT_EQ("^x", new_const.input(0));
287f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  EXPECT_EQ(std::string("\0\0\0@", 4),
288f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower            new_const.attr().at("value").tensor().tensor_content());
2896ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  const NodeDef& new_mul = output.node(4);
29098ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower  EXPECT_EQ(OptimizedName("add_mul"), new_mul.name());
29198ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower  EXPECT_EQ(OptimizedName("add_const"), new_mul.input(0));
2926ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  EXPECT_EQ("x", new_mul.input(1));
2936ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  const NodeDef& new_id = output.node(2);
2946ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower  EXPECT_EQ("id", new_id.name());
29598ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower  EXPECT_EQ(OptimizedName("add_mul"), new_id.input(0));
2966ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower}
2976ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower
298f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
299f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
300f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  Output y = ops::Const(s.WithOpName("y"), {1.0f, 2.0f}, {1, 2});
301f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  Output x = ops::Const(s.WithOpName("x"), {3.0f, 4.0f}, {1, 2});
302f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  Output add = ops::Add(s.WithOpName("add").WithControlDependencies(y), x, x);
303f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  Output id = ops::Identity(s.WithOpName("id"), add);
304f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower
305f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  GrapplerItem item;
306f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  TF_CHECK_OK(s.ToGraphDef(&item.graph));
307f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower
308f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  ArithmeticOptimizer optimizer;
309f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  GraphDef output;
310f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  Status status = optimizer.Optimize(nullptr, item, &output);
311f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  TF_EXPECT_OK(status);
312f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  // Run the optimizer twice to make sure the rewrite is idempotent.
313f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  item.graph.Swap(&output);
314f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  status = optimizer.Optimize(nullptr, item, &output);
315f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  TF_EXPECT_OK(status);
316f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower
317f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  EXPECT_EQ(6, output.node_size());
318f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  const NodeDef& new_const = output.node(4);
31998ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower  EXPECT_EQ(OptimizedName("add_const"), new_const.name());
320f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  EXPECT_EQ("^x", new_const.input(0));
321f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  EXPECT_EQ(std::string("\0\0\0@", 4),
322f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower            new_const.attr().at("value").tensor().tensor_content());
323f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  const NodeDef& new_mul = output.node(5);
32498ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower  EXPECT_EQ(OptimizedName("add_mul"), new_mul.name());
32598ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower  EXPECT_EQ(OptimizedName("add_const"), new_mul.input(0));
326f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  EXPECT_EQ("x", new_mul.input(1));
327f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  EXPECT_EQ("^y", new_mul.input(2));
328f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  const NodeDef& new_id = output.node(3);
329f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  EXPECT_EQ("id", new_id.name());
33098ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower  EXPECT_EQ(OptimizedName("add_mul"), new_id.input(0));
331f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower}
332f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower
333de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
334de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  // Test case from b/69059093.
335de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
336de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  Output p = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({10, 10}));
337de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  Output add = ops::Add(s.WithOpName("Add"), p, p);
338de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  Output add1 = ops::Add(s.WithOpName("Add_1"), p, p);
339de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  Output add4 = ops::Add(s.WithOpName("Add_4"), add, add1);
340de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  Output add5 = ops::Add(s.WithOpName("Add_5"), add, add1);
341de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  Output add6 = ops::Add(s.WithOpName("Add_6"), add4, add5);
342de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  Output id = ops::Identity(s.WithOpName("id"), add6);
343de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower
344de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  GrapplerItem item;
345de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  TF_CHECK_OK(s.ToGraphDef(&item.graph));
346f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  const std::vector<string> devices{
347f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      "/device:CPU:0", "/device:GPU:0", "/device:CPU:0", "/device:GPU:1",
348f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      "/device:CPU:0", "/device:CPU:0", "/device:CPU:0",
349f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  };
350f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  for (int i = 0; i < item.graph.node_size(); ++i) {
351f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower    item.graph.mutable_node(i)->set_device(devices[i]);
352f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  }
3534ba3147461f2cd1b73029f986cf806b33d0ce290A. Unique TensorFlower  ArithmeticOptimizer optimizer;
354de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  GraphDef output;
355de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  Status status = optimizer.Optimize(nullptr, item, &output);
356de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  TF_EXPECT_OK(status);
357de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  // Run the optimizer twice to make sure the rewrite is idempotent.
358de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  item.graph.Swap(&output);
359de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  status = optimizer.Optimize(nullptr, item, &output);
360de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  TF_EXPECT_OK(status);
361de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower
362f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  EXPECT_EQ(17, output.node_size());
363f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  // The graph gets optimized to
364f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  // Mul(p,
365f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  //     Add(Add(Const(2), Const(2)),
366f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  //         Add(Const(2), Const(2))))
36798ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower  EXPECT_EQ(17, output.node_size());
368f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  for (const auto& node : output.node()) {
369f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower    if ("id" == node.name()) {
370f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ(1, node.input_size());
37198ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower      EXPECT_EQ(OptimizedName("Add_6_hoist_mul"), node.input(0));
37298ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower    } else if (OptimizedName("Add_6_hoist_mul") == node.name()) {
373f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ("Mul", node.op());
374f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ(2, node.input_size());
375f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ("Placeholder", node.input(0));
37698ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower      EXPECT_EQ(OptimizedName("Add_6_hoist_add"), node.input(1));
37798ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower    } else if (OptimizedName("Add_6_hoist_add") == node.name()) {
378f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ("Add", node.op());
379f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ(3, node.input_size());
38098ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower      EXPECT_EQ(OptimizedName("Add_4_hoist_add"), node.input(0));
38198ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower      EXPECT_EQ(OptimizedName("Add_5_hoist_add"), node.input(1));
382f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ("^Placeholder", node.input(2));
38398ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower    } else if (OptimizedName("Add_4_hoist_add") == node.name()) {
384f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ("Add", node.op());
385f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ(3, node.input_size());
38698ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower      EXPECT_EQ(OptimizedName("Add_const"), node.input(0));
38798ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower      EXPECT_EQ(OptimizedName("Add_1_const"), node.input(1));
388f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ("^Placeholder", node.input(2));
38998ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower    } else if (OptimizedName("Add_5_hoist_add") == node.name()) {
390f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ("Add", node.op());
391f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ(3, node.input_size());
39298ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower      EXPECT_EQ(OptimizedName("Add_const"), node.input(0));
39398ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower      EXPECT_EQ(OptimizedName("Add_1_const"), node.input(1));
394f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ("^Placeholder", node.input(2));
39598ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower    } else if (OptimizedName("Add_const") == node.name()) {
396f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ("Const", node.op());
397f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ(1, node.input_size());
398f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ("^Placeholder", node.input(0));
39998ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower    } else if (OptimizedName("Add_1_const") == node.name()) {
400f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ("Const", node.op());
401f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ(1, node.input_size());
402f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower      EXPECT_EQ("^Placeholder", node.input(0));
403f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower    }
404f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower  }
405de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower}
406de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower
407de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, HoistFactor) {
40819f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower  for (bool matching_shapes : {true, false}) {
40919f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower    for (bool use_addn : {true, false}) {
41019f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      tensorflow::Scope s = tensorflow::Scope::NewRootScope();
41119f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
41219f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
41319f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      Output y2 = matching_shapes
41419f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower                      ? ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2})
41519f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower                      : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
41619f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      Output mul1 = ops::Mul(s.WithOpName("mul1"), x, y1);
41719f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      Output mul2 = ops::Mul(s.WithOpName("mul2"), y2, x);
41819f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      Output id =
41919f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower          use_addn ? ops::Identity(s.WithOpName("id"),
42019f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower                                   ops::AddN(s.WithOpName("add"), {mul1, mul2}))
42119f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower                   : ops::Identity(s.WithOpName("id"),
42219f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower                                   ops::Add(s.WithOpName("add"), mul1, mul2));
42319f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower
42419f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      GrapplerItem item;
42519f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      TF_CHECK_OK(s.ToGraphDef(&item.graph));
4264ba3147461f2cd1b73029f986cf806b33d0ce290A. Unique TensorFlower      ArithmeticOptimizer optimizer;
42719f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      GraphDef output;
42819f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      Status status = optimizer.Optimize(nullptr, item, &output);
42919f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      TF_EXPECT_OK(status);
43019f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      // Run the optimizer twice to make sure the rewrite is idempotent.
43119f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      item.graph.Swap(&output);
43219f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      status = optimizer.Optimize(nullptr, item, &output);
43319f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      TF_EXPECT_OK(status);
43419f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower
43519f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      if (use_addn && !matching_shapes) {
43619f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower        VerifyGraphsMatch(item.graph, output, __LINE__);
43719f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      } else {
43819f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower        EXPECT_EQ(9, output.node_size());
43919f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower        const NodeDef& new_add = output.node(8);
44019f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower        EXPECT_EQ(OptimizedName("add_hoist_add"), new_add.name());
44119f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower        EXPECT_EQ("y1", new_add.input(0));
44219f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower        EXPECT_EQ("y2", new_add.input(1));
44319f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower        const NodeDef& new_mul = output.node(7);
44419f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower        EXPECT_EQ(OptimizedName("add_hoist_mul"), new_mul.name());
44519f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower        EXPECT_EQ("x", new_mul.input(0));
44619f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower        EXPECT_EQ(OptimizedName("add_hoist_add"), new_mul.input(1));
44719f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower        const NodeDef& new_id = output.node(6);
44819f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower        EXPECT_EQ("id", new_id.name());
44919f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower        EXPECT_EQ(OptimizedName("add_hoist_mul"), new_id.input(0));
45019f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower      }
45119f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower    }
45219f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower  }
4536ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower}
4546ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower
45546ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
45646ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
45746ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2});
45846ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2});
45946ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Output z = ops::Complex(s.WithOpName("z"), re, im);
46046ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
46146ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Output conj = ops::Conj(s.WithOpName("conj"), z);
46246ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Output transp = ops::Transpose(s.WithOpName("trans"), conj, perm);
46346ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  GrapplerItem item;
46446ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  TF_CHECK_OK(s.ToGraphDef(&item.graph));
46546ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower
46646ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  ArithmeticOptimizer optimizer;
46746ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  GraphDef output;
46846ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Status status = optimizer.Optimize(nullptr, item, &output);
46946ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  TF_EXPECT_OK(status);
470de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  // Run the optimizer twice to make sure the rewrite is idempotent.
471de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  item.graph.Swap(&output);
472de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  status = optimizer.Optimize(nullptr, item, &output);
473de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  TF_EXPECT_OK(status);
47446ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower
47546ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  EXPECT_EQ(7, output.node_size());
47698ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower  EXPECT_EQ(OptimizedName("trans_fused"), output.node(6).name());
47746ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  EXPECT_EQ("ConjugateTranspose", output.node(6).op());
47846ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  EXPECT_EQ("z", output.node(6).input(0));
47946ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  EXPECT_EQ("perm", output.node(6).input(1));
48046ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower}
48146ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower
48246ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) {
48346ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
48446ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2});
48546ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2});
48646ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Output z = ops::Complex(s.WithOpName("z"), re, im);
48746ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
48846ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Output conj = ops::Conj(s.WithOpName("conj"), z);
48946ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Output transp =
49046ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower      ops::ConjugateTranspose(s.WithOpName("conjugate_trans"), conj, perm);
49146ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  GrapplerItem item;
49246ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  TF_CHECK_OK(s.ToGraphDef(&item.graph));
49346ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower
49446ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  ArithmeticOptimizer optimizer;
49546ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  GraphDef output;
49646ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Status status = optimizer.Optimize(nullptr, item, &output);
49746ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  TF_EXPECT_OK(status);
49846ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower
49946ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  EXPECT_EQ(7, output.node_size());
50098ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower  EXPECT_EQ(OptimizedName("conjugate_trans_fused"), output.node(6).name());
50146ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  EXPECT_EQ("Transpose", output.node(6).op());
50246ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  EXPECT_EQ("z", output.node(6).input(0));
50346ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  EXPECT_EQ("perm", output.node(6).input(1));
50446ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower}
50546ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower
50646ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) {
50746ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
50846ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2});
50946ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2});
51046ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Output z = ops::Complex(s.WithOpName("z"), re, im);
51146ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
51246ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Output trans = ops::Transpose(s.WithOpName("trans"), z, perm);
51346ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Output conj = ops::Conj(s.WithOpName("conj"), trans);
51446ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  GrapplerItem item;
51546ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  TF_CHECK_OK(s.ToGraphDef(&item.graph));
51646ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower
51746ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  ArithmeticOptimizer optimizer;
51846ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  GraphDef output;
51946ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  Status status = optimizer.Optimize(nullptr, item, &output);
52046ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  TF_EXPECT_OK(status);
521de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  // Run the optimizer twice to make sure the rewrite is idempotent.
522de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  item.graph.Swap(&output);
523de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  status = optimizer.Optimize(nullptr, item, &output);
524de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower  TF_EXPECT_OK(status);
52546ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower
52646ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  EXPECT_EQ(7, output.node_size());
52798ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower  EXPECT_EQ(OptimizedName("conj_fused"), output.node(6).name());
52846ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  EXPECT_EQ("ConjugateTranspose", output.node(6).op());
52946ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  EXPECT_EQ("z", output.node(6).input(0));
53046ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower  EXPECT_EQ("perm", output.node(6).input(1));
53146ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower}
53246ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower
5339fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
5349fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  for (const string matmul_type : {"MatMul", "SparseMatMul", "BatchMatMul"}) {
5359fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
5369fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    Output a = ops::Const(s.WithOpName("a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
5379fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    Output b = ops::Const(s.WithOpName("b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
5389fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
5399fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    Output trans_a = ops::Transpose(s.WithOpName("trans_a"), a, perm);
5409fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    Output trans_b = ops::Transpose(s.WithOpName("trans_b"), b, perm);
5419fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    if (matmul_type == "MatMul") {
5429fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower      Output matmul = ops::MatMul(s.WithOpName("matmul"), trans_a, trans_b);
5439fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    } else if (matmul_type == "SparseMatMul") {
5449fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower      Output matmul =
5459fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower          ops::SparseMatMul(s.WithOpName("matmul"), trans_a, trans_b);
5469fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    } else if (matmul_type == "BatchMatMul") {
5479fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower      Output matmul =
5489fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower          ops::BatchMatMul(s.WithOpName("matmul"), trans_a, trans_b);
5499fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    }
5509fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    GrapplerItem item;
5519fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    TF_CHECK_OK(s.ToGraphDef(&item.graph));
5529fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower
5539fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    ArithmeticOptimizer optimizer;
5549fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    GraphDef output;
5559fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    Status status = optimizer.Optimize(nullptr, item, &output);
5569fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    TF_EXPECT_OK(status);
557de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower    // Run the optimizer twice to make sure the rewrite is idempotent.
558de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower    item.graph.Swap(&output);
559de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower    status = optimizer.Optimize(nullptr, item, &output);
560de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower    TF_EXPECT_OK(status);
5619fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower
5629fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    EXPECT_EQ(7, output.node_size());
56398ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower    EXPECT_EQ(OptimizedName("matmul_fused"), output.node(6).name());
5649fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    EXPECT_EQ("a", output.node(6).input(0));
5659fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    EXPECT_EQ("b", output.node(6).input(1));
5669fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    if (matmul_type == "BatchMatMul") {
5679fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower      EXPECT_TRUE(output.node(6).attr().at("adj_x").b());
5689fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower      EXPECT_TRUE(output.node(6).attr().at("adj_y").b());
5699fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    } else {
5709fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower      EXPECT_TRUE(output.node(6).attr().at("transpose_a").b());
5719fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower      EXPECT_TRUE(output.node(6).attr().at("transpose_b").b());
5729fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower    }
5739fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  }
5749fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower}
5759fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower
5769fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) {
5779fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
5789fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  Output re_a =
5799fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower      ops::Const(s.WithOpName("re_a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
5809fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  Output im_a =
5819fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower      ops::Const(s.WithOpName("im_a"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
5829fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  Output re_b =
5839fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower      ops::Const(s.WithOpName("re_b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
5849fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  Output im_b =
5859fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower      ops::Const(s.WithOpName("im_b"), {-5.0f, -6.0f, -7.0f, -8.0f}, {2, 2});
5869fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  Output a = ops::Complex(s.WithOpName("a"), re_a, im_a);
5879fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  Output b = ops::Complex(s.WithOpName("b"), re_b, im_b);
5889fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2});
5899fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  Output trans_a = ops::ConjugateTranspose(s.WithOpName("trans_a"), a, perm);
5909fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  Output trans_b = ops::ConjugateTranspose(s.WithOpName("trans_b"), b, perm);
5919fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  Output matmul = ops::BatchMatMul(s.WithOpName("matmul"), trans_a, trans_b);
5929fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  GrapplerItem item;
5939fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  TF_CHECK_OK(s.ToGraphDef(&item.graph));
5949fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower
5959fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  ArithmeticOptimizer optimizer;
5969fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  GraphDef output;
5979fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  Status status = optimizer.Optimize(nullptr, item, &output);
5989fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  TF_EXPECT_OK(status);
5999fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower
6009fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  EXPECT_EQ(11, output.node_size());
60198ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower  EXPECT_EQ(OptimizedName("matmul_fused"), output.node(10).name());
6029fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  EXPECT_EQ("a", output.node(10).input(0));
6039fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  EXPECT_EQ("b", output.node(10).input(1));
6049fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  EXPECT_TRUE(output.node(10).attr().at("adj_x").b());
6059fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower  EXPECT_TRUE(output.node(10).attr().at("adj_y").b());
6069fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower}
6079fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower
608d871fdce70acc165e652c66638943b40ffcda7a3Jingyue WuTEST_F(ArithmeticOptimizerTest, IdentityReshape) {
609d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
610d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  Output inputs =
611d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28}));
612d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  Output inputs_shape = ops::Shape(s, inputs);
613d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  // The target shape of the reshape is the concatenation of `batch_size` and
614d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  // [3,28,28].
615d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  Output batch_size = ops::Slice(s, inputs_shape, ops::Const(s, {0}, {1}),
616d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu                                 ops::Const(s, {1}, {1}));
617d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  Output target_shape = ops::Concat(
618d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu      s.WithOpName("target_shape"),
619d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu      {batch_size, ops::Const(s, {3, 28, 28}, {3})}, ops::Const(s, {0}, {}));
620d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  Output reshape = ops::Reshape(s, inputs, target_shape);
621d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);
622d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu
623d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  GrapplerItem item;
624d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  item.fetch = {"outputs"};
625d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  TF_CHECK_OK(s.ToGraphDef(&item.graph));
626d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu
627d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  GraphDef output;
6284ba3147461f2cd1b73029f986cf806b33d0ce290A. Unique TensorFlower  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
629d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu
630f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
631d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
632d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu
633d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  EXPECT_EQ(0, std::count_if(
634d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu                   output.node().begin(), output.node().end(),
635d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu                   [](const NodeDef& node) { return node.op() == "Reshape"; }));
636d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu}
637d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu
638d871fdce70acc165e652c66638943b40ffcda7a3Jingyue WuTEST_F(ArithmeticOptimizerTest, NotIdentityReshape) {
639d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  // Reshape from [-1,3,28,28] to [8,-1,28,28] is not identity, because it can
640d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  // be from [4,3,28,28] to [8,6,28,28].
641d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
642d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  Output inputs =
643d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28}));
644d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  Output reshape = ops::Reshape(s, inputs, ops::Const(s, {8, -1, 28, 28}, {4}));
645d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);
646d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu
647d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  GrapplerItem item;
648d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  item.fetch = {"outputs"};
649d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  TF_CHECK_OK(s.ToGraphDef(&item.graph));
650d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu
651d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  GraphDef output;
6524ba3147461f2cd1b73029f986cf806b33d0ce290A. Unique TensorFlower  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
653d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu
654f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
655d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
656d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu
657d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  EXPECT_EQ(1, std::count_if(
658d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu                   output.node().begin(), output.node().end(),
659d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu                   [](const NodeDef& node) { return node.op() == "Reshape"; }));
660d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu}
661d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu
662d871fdce70acc165e652c66638943b40ffcda7a3Jingyue WuTEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) {
663d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
664d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  Output inputs =
665d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3}));
666d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  Output reshape = ops::Reshape(s, inputs, ops::Const(s, {-1, -1}, {2}));
667d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  Output outputs = ops::Identity(s.WithOpName("outputs"), reshape);
668d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu
669d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  GrapplerItem item;
670d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  item.fetch = {"outputs"};
671d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  TF_CHECK_OK(s.ToGraphDef(&item.graph));
672d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu
673d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  GraphDef output;
6744ba3147461f2cd1b73029f986cf806b33d0ce290A. Unique TensorFlower  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
675d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu
676f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
677d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
678d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu
679d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu  EXPECT_EQ(1, std::count_if(
680d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu                   output.node().begin(), output.node().end(),
681d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu                   [](const NodeDef& node) { return node.op() == "Reshape"; }));
682d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu}
683d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu
684b002c8b7d28f8327bac5db2efcd7924694beefafJingyue WuTEST_F(ArithmeticOptimizerTest, CombineReshapes) {
685b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu  // Converts an NCHW_VECT_C tensor to NHWC and then flattens it to 2D. The two
686b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu  // reshapes should be combined.
687b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
688b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu  Output nchw_vect_c =
689b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu      ops::Placeholder(s.WithOpName("nchw_vect_c"), DT_INT8,
690b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu                       ops::Placeholder::Shape({8, 3, 28, 28, 4}));
691b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu  Output transpose =
692b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu      ops::Transpose(s.WithOpName("transpose"), nchw_vect_c,
693b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu                     ops::Const(s.WithOpName("perm"), {0, 2, 3, 1, 4}, {5}));
694b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu  Output nhwc = ops::Reshape(
695b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu      s.WithOpName("nhwc"), transpose,
696b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu      ops::Const(s.WithOpName("nhwc_shape"), {8, 28, 28, 12}, {4}));
697b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu  Output flatten = ops::Reshape(
698b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu      s.WithOpName("flatten"), nhwc,
699b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu      ops::Const(s.WithOpName("flatten_shape"), {8, 28 * 28 * 12}, {2}));
700b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu  Output outputs = ops::Identity(s.WithOpName("outputs"), flatten);
701b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu
702b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu  GrapplerItem item;
703b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu  item.fetch = {"outputs"};
704b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu  TF_CHECK_OK(s.ToGraphDef(&item.graph));
705b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu
706b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu  GraphDef output;
707b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
708b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu
709f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
710b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
711b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu
712b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu  EXPECT_EQ(1, std::count_if(
713b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu                   output.node().begin(), output.node().end(),
714b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu                   [](const NodeDef& node) { return node.op() == "Reshape"; }));
715b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu}
716b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu
7179d8346a1204d05b2ab16c169a6a6077167fe162aJingyue WuTEST_F(ArithmeticOptimizerTest, ReorderTransposeCast) {
7189d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/gpu:0");
7199d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  Output nhwc_uint8 =
7209d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
7219d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  Output nhwc_fp32 = ops::Cast(s, nhwc_uint8, DT_FLOAT);
7229d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  Output nchw_fp32 =
7239d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu      ops::Transpose(s, nhwc_fp32, ops::Const(s, {0, 3, 1, 2}, {4}));
7249d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_fp32);
7259d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu
7269d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  GrapplerItem item;
7279d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  item.fetch = {"outputs"};
7289d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  TF_CHECK_OK(s.ToGraphDef(&item.graph));
7299d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu
7309d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  GraphDef output;
7319d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
7329d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu
733f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
7349d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
7359d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu
7369d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  const NodeDef* transpose_node = nullptr;
7379d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  for (const NodeDef& node : output.node()) {
7389d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu    if (node.op() == "Transpose") {
7399d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu      EXPECT_EQ(transpose_node, nullptr);
7409d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu      EXPECT_EQ(DT_UINT8, node.attr().at("T").type());
7419d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu      transpose_node = &node;
7429d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu    }
7439d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  }
7449d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  EXPECT_NE(transpose_node, nullptr);
7459d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu
7469d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  for (const NodeDef& node : output.node()) {
7479d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu    if (node.op() == "Cast") {
7489d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu      EXPECT_EQ(NodeName(node.input(0)), transpose_node->name());
7499d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu    }
7509d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  }
7519d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu}
7529d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu
7539d8346a1204d05b2ab16c169a6a6077167fe162aJingyue WuTEST_F(ArithmeticOptimizerTest, NoReorderTransposeCast) {
7549d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/gpu:0");
7559d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  Output nhwc_fp32 =
7569d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 28, 28, 3}));
7579d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  Output nhwc_uint8 = ops::Cast(s, nhwc_fp32, DT_UINT8);
7589d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  Output nchw_uint8 =
7599d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu      ops::Transpose(s, nhwc_uint8, ops::Const(s, {0, 3, 1, 2}, {4}));
7609d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_uint8);
7619d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu
7629d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  GrapplerItem item;
7639d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  item.fetch = {"outputs"};
7649d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  TF_CHECK_OK(s.ToGraphDef(&item.graph));
7659d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu
7669d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  GraphDef output;
7679d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
7689d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu
769f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
7709d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
7719d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu
7729d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  int num_transposes = 0;
7739d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  for (const NodeDef& node : output.node()) {
7749d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu    if (node.op() == "Transpose") {
7759d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu      EXPECT_EQ(DT_UINT8, node.attr().at("T").type());
7769d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu      EXPECT_EQ(node.input(0), "Cast");
7779d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu      ++num_transposes;
7789d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu    }
7799d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  }
7809d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu  EXPECT_EQ(1, num_transposes);
7819d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu}
7829d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu
783e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong LeeTEST_F(ArithmeticOptimizerTest, RemoveInverseTransposes) {
784e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
785e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  Output inputs_shape =
786e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee      ops::Const(s.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
787e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  Output inputs =
788e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee      ops::RandomUniform(s.WithOpName("inputs"), inputs_shape, DT_FLOAT);
789e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  Output perm1 = ops::Const(s.WithOpName("perm1"), {0, 2, 3, 1}, {4});
790e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  Output perm2 = ops::Const(s.WithOpName("perm2"), {0, 3, 1, 2}, {4});
791e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  Output transpose1 = ops::Transpose(s.WithOpName("transpose1"), inputs, perm1);
792e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  Output transpose2 =
793e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee      ops::Transpose(s.WithOpName("transpose2"), transpose1, perm2);
794e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  Output outputs = ops::Identity(s.WithOpName("outputs"), transpose2);
795e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee
796e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  GrapplerItem item;
797e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  item.fetch = {"outputs"};
798e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  TF_CHECK_OK(s.ToGraphDef(&item.graph));
799e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee
800e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  GraphDef output;
801e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
802e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee
803f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
804e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
805e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee
806e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  std::set<string> nodes_after_optimization;
807e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  for (const NodeDef& node : output.node()) {
808e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee    nodes_after_optimization.insert(node.name());
809e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  }
810e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  EXPECT_EQ(nodes_after_optimization,
811e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee            std::set<string>({"inputs_shape", "inputs", "outputs"}));
812e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee}
813e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee
814e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue WuTEST_F(ArithmeticOptimizerTest, RemoveInverseTransposesMultipleOutputs) {
815e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
816e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu  Output inputs_shape =
817e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu      ops::Const(s.WithOpName("inputs_shape"), {8, 9, 28, 28}, {4});
818e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
819e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu                                   ops::Placeholder::Shape({8, 12, 28, 28}));
820e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu  OutputList split = ops::Split(s, ops::Const(s, 1), inputs, 3).output;
821e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu  Output perm1 = ops::Const(s, {0, 2, 3, 1}, {4});
822e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu  Output perm2 = ops::Const(s, {0, 3, 1, 2}, {4});
823e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu  Output branch0 = split[0];
824e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu  Output branch1 = ops::Transpose(s, ops::Transpose(s, split[1], perm1), perm2);
825e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu  Output branch2 = split[2];
826e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu  Output concat = ops::Concat(s, {branch0, branch1, branch2}, ops::Const(s, 1));
827e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu  Output outputs = ops::Identity(s.WithOpName("outputs"), concat);
828e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu
829e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu  GrapplerItem item;
830e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu  item.fetch = {"outputs"};
831e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu  TF_CHECK_OK(s.ToGraphDef(&item.graph));
832e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu
833e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu  GraphDef output;
834e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
835e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu
836f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
837e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
838e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu
839e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu  for (const NodeDef& node : output.node()) {
840e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu    if (node.op() == "Concat") {
841e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu      EXPECT_EQ(node.input(0), "Split");
842e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu      EXPECT_EQ(node.input(1), "Split:1");
843e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu      EXPECT_EQ(node.input(2), "Split:2");
844e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu    }
845e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu  }
846e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu}
847e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu
84827df639673ae2bfe63b82862008da9bec488f0dbJingyue WuTEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) {
84927df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
85027df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu  Output inputs =
85127df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu      ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({2, 3}));
85227df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu  Output transpose1 = ops::Transpose(s, inputs, ops::Const(s, {1, 0}));
85327df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu  Output transpose2 = ops::Transpose(s, transpose1, ops::Const(s, {1, 0}));
85427df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu  Output outputs =
85527df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu      ops::Identity(s.WithOpName("outputs").WithControlDependencies(transpose2),
85627df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu                    ops::Const(s.WithOpName("outputs_const"), 1.0f));
85727df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu
85827df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu  GrapplerItem item;
85927df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu  item.fetch = {"outputs"};
86027df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu  TF_CHECK_OK(s.ToGraphDef(&item.graph));
86127df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu  GraphDef output;
86227df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
863f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
86427df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
86527df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu
86627df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu  NodeMap node_map(&output);
86727df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu  const NodeDef* outputs_node = node_map.GetNode("outputs");
86827df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu  EXPECT_EQ(2, outputs_node->input_size());
86927df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu  EXPECT_EQ(outputs_node->input(0), "outputs_const");
87027df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu  EXPECT_EQ(outputs_node->input(1), "^Placeholder");
87127df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu}
87227df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu
873e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong LeeTEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) {
874e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
875e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  Output inputs_shape =
876e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee      ops::Const(s.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4});
877e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  Output inputs =
878e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee      ops::RandomUniform(s.WithOpName("inputs"), inputs_shape, DT_FLOAT);
879e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  Output perm = ops::Const(s.WithOpName("perm"), {1, 2, 3, 0}, {4});
880e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  Output transpose1 = ops::Transpose(s.WithOpName("transpose1"), inputs, perm);
881e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  Output transpose2 =
882e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee      ops::Transpose(s.WithOpName("transpose2"), transpose1, perm);
883e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  Output outputs = ops::Identity(s.WithOpName("outputs"), transpose2);
884e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee
885e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  GrapplerItem item;
886e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  item.fetch = {"outputs"};
887e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  TF_CHECK_OK(s.ToGraphDef(&item.graph));
888e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee
889e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  GraphDef output;
890e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
891e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee
892f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
893e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
894e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee
895e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee  EXPECT_EQ(6, output.node_size());
896e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee}
897e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee
898f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue WuTEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) {
899f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
900f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
901f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu                                   ops::Placeholder::Shape({8, 28, 28, 3}));
902f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
903f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output scaled_inputs =
904f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu      ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale);
905f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output perm_nhwc_to_nchw =
906f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu      ops::Const(s.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
907f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output inputs_nchw = ops::Transpose(s.WithOpName("inputs_nchw"),
908f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu                                      scaled_inputs, perm_nhwc_to_nchw);
909f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output weights = ops::Const(s.WithOpName("weights"),
910f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu                              Input::Initializer(127.0f, {5, 5, 3, 16}));
911f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output conv =
912f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu      ops::Conv2D(s.WithOpName("conv"), inputs_nchw, weights, {1, 1, 1, 1},
913f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu                  "VALID", ops::Conv2D::DataFormat("NCHW"));
914f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);
915f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu
916f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  GrapplerItem item;
917f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  item.fetch = {"outputs"};
918f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  TF_CHECK_OK(s.ToGraphDef(&item.graph));
919f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu
920f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  GraphDef output;
921f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
922f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu
923f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
924f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
925f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu
926f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  NodeMap node_map(&output);
927f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  // `conv` is now a folded convolution with scaled weights.
928f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
929f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  CHECK_EQ(node_map.GetNode(NodeName(folded_conv->input(1)))->op(), "Mul");
930f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  // Its input should be a transpose of `inputs`.
931f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  const NodeDef* transpose = node_map.GetNode(NodeName(folded_conv->input(0)));
932f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  CHECK_EQ(NodeName(transpose->input(0)), inputs.node()->name());
933f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu}
934f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu
935f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue WuTEST_F(ArithmeticOptimizerTest, NotFoldMulAcrossPreservedTranspose) {
936f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
937f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
938f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu                                   ops::Placeholder::Shape({8, 28, 28, 3}));
939f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
940f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output scaled_inputs =
941f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu      ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale);
942f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output perm_nhwc_to_nchw =
943f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu      ops::Const(s.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4});
944f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output inputs_nchw = ops::Transpose(s.WithOpName("inputs_nchw"),
945f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu                                      scaled_inputs, perm_nhwc_to_nchw);
946f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output weights = ops::Const(s.WithOpName("weights"),
947f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu                              Input::Initializer(127.0f, {5, 5, 3, 16}));
948f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output conv =
949f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu      ops::Conv2D(s.WithOpName("conv"), inputs_nchw, weights, {1, 1, 1, 1},
950f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu                  "VALID", ops::Conv2D::DataFormat("NCHW"));
951f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);
952f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu
953f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Tensor inputs_nchw_tensor(DT_FLOAT, {8, 3, 28, 28});
954f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  memset(const_cast<char*>(inputs_nchw_tensor.tensor_data().data()), 0,
955f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu         inputs_nchw_tensor.tensor_data().size());
956f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu
957f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  GrapplerItem item;
958f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  item.fetch = {"outputs"};
959f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  item.feed = {{"inputs_nchw", inputs_nchw_tensor}};
960f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  TF_CHECK_OK(s.ToGraphDef(&item.graph));
961f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu
962f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  GraphDef output;
963f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
964f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu
965f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
966f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
967f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu
968f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  NodeMap node_map(&output);
969f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  const NodeDef* inputs_nchw_node_def =
970f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu      node_map.GetNode(inputs_nchw.node()->name());
971f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  EXPECT_EQ(NodeName(inputs_nchw_node_def->input(0)),
972f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu            scaled_inputs.node()->name());
973f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu}
974f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu
975f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue WuTEST_F(ArithmeticOptimizerTest, FoldMulToConv) {
976f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
977f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT,
978f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu                                   ops::Placeholder::Shape({8, 28, 28, 28, 3}));
979f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {});
980f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output scaled_inputs =
981f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu      ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale);
982f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output weights = ops::Const(s.WithOpName("weights"),
983f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu                              Input::Initializer(127.0f, {5, 5, 5, 3, 16}));
984f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output conv = ops::Conv3D(s.WithOpName("conv"), scaled_inputs, weights,
985f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu                            {1, 1, 1, 1, 1}, "VALID");
986f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);
987f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu
988f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  GrapplerItem item;
989f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  item.fetch = {"outputs"};
990f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  TF_CHECK_OK(s.ToGraphDef(&item.graph));
991f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu
992f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  GraphDef output;
993f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
994f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu
995f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
996f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
997f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu
998f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  NodeMap node_map(&output);
999f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  // `conv` is now a folded convolution on `inputs` and scaled weights.
1000f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  const NodeDef* folded_conv = node_map.GetNode(conv.node()->name());
1001f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  CHECK_EQ(inputs.node()->name(), NodeName(folded_conv->input(0)));
1002f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu  CHECK_EQ(node_map.GetNode(NodeName(folded_conv->input(1)))->op(), "Mul");
1003f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu}
1004f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu
100533d55122d994d12f2a066f9ec4f0f03094a59579Jingyue WuTEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) {
100633d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  // This unit test exercises two optimizations, folding mul into conv, and
100733d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  // reordering cast and transpose.
100833d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  //
100933d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  //   Conv2D(Transpose(Mul(Cast(I), S)), W)
101033d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  //     =>
101133d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  //   Conv2D(Transpose(Cast(I)), W*S)
101233d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  //     =>
101333d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  //   Conv2D(Cast(Transpose(I)), W*S)
101433d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/gpu:0");
101533d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  Output inputs =
101633d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3}));
101733d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  Output cast = ops::Cast(s, inputs, DT_FLOAT);
101833d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  Output mul = ops::Mul(s, cast, ops::Const(s, 1.0f / 255.0f));
101933d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  Output transpose =
102033d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu      ops::Transpose(s, mul, ops::Const(s.WithOpName("perm"), {0, 3, 1, 2}));
102133d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  Output weights = ops::Const(s.WithOpName("weights"),
102233d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu                              Input::Initializer(127.0f, {5, 5, 3, 16}));
102333d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  Output conv = ops::Conv2D(s, transpose, weights, {1, 1, 1, 1}, "VALID",
102433d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu                            ops::Conv2D::DataFormat("NCHW"));
102533d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  Output outputs = ops::Identity(s.WithOpName("outputs"), conv);
102633d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu
102733d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  GrapplerItem item;
102833d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  item.fetch = {"outputs"};
102933d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  TF_CHECK_OK(s.ToGraphDef(&item.graph));
103033d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu
103133d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  GraphDef output;
103233d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
103333d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu
1034f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  // Run the optimizer twice to make sure the rewrite is idempotent.
1035f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
1036f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
1037f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower
1038f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
103933d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  TF_EXPECT_OK(
104033d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu      ConstantFolding(/*cpu_device=*/nullptr).Optimize(nullptr, item, &output));
104133d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu
1042f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
104333d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
104433d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu
104533d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  NodeMap node_map(&output);
104633d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  const NodeDef* inputs_node = CHECK_NOTNULL(node_map.GetNode("Placeholder"));
104733d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  const NodeDef* transpose_node =
104898ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower      CHECK_NOTNULL(node_map.GetNode(OptimizedName("Transpose_uint8")));
104998ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower  const NodeDef* cast_node =
1050f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower      CHECK_NOTNULL(node_map.GetNode(OptimizedName("Cast_float")));
105133d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  const NodeDef* weights_node =
105298ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower      CHECK_NOTNULL(node_map.GetNode(OptimizedName("weights_scaled_Conv2D")));
105333d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  const NodeDef* conv_node = CHECK_NOTNULL(node_map.GetNode("Conv2D"));
105433d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu
105533d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  EXPECT_EQ(output.node_size(), 7);
105633d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  EXPECT_EQ(transpose_node->input(0), inputs_node->name());
105733d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  EXPECT_EQ(cast_node->input(0), transpose_node->name());
105833d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  EXPECT_EQ(conv_node->input(0), cast_node->name());
105933d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu  EXPECT_EQ(conv_node->input(1), weights_node->name());
106033d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu}
106133d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu
106251889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) {
106351889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  // This unit test exercises optimization of folding mul into conv for
106451889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  // multiple nodes in the graph.
106551889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/gpu:0");
106651889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower
106751889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  GrapplerItem item;
106851889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  Output conv[2];
106951889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower
107051889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  for (int i = 0; i < 2; ++i) {
107151889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower    Output inputs =
107251889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower        ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 3, 28, 28}));
107351889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower    Output mul = ops::Mul(s, inputs, ops::Const(s, 1.0f / 255.0f));
107451889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower    Output weights = ops::Const(s.WithOpName("weights"),
107551889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower                                Input::Initializer(127.0f, {5, 5, 3, 16}));
107651889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower    conv[i] = ops::Conv2D(s, mul, weights, {1, 1, 1, 1}, "VALID",
107751889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower                          ops::Conv2D::DataFormat("NCHW"));
107851889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  }
107951889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  Output outputs = ops::Add(s.WithOpName("outputs"), conv[0], conv[1]);
108051889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower
108151889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  item.fetch = {"outputs"};
108251889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  TF_CHECK_OK(s.ToGraphDef(&item.graph));
108351889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower
108451889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  GraphDef output;
108551889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
108651889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower
1087f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
108851889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  TF_EXPECT_OK(
108951889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower      ConstantFolding(/*cpu_device=*/nullptr).Optimize(nullptr, item, &output));
109051889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower
1091f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
109251889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
109351889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower
109451889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  NodeMap node_map(&output);
109551889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  const NodeDef* weights_node =
109698ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower      CHECK_NOTNULL(node_map.GetNode(OptimizedName("weights_scaled_Conv2D")));
109751889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  const NodeDef* conv_node = CHECK_NOTNULL(node_map.GetNode("Conv2D"));
109851889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower
109951889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  const NodeDef* weights_node_1 =
110098ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower      CHECK_NOTNULL(node_map.GetNode(OptimizedName("weights_scaled_Conv2D_1")));
110151889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  const NodeDef* conv_node_1 = CHECK_NOTNULL(node_map.GetNode("Conv2D_1"));
110251889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  EXPECT_EQ(conv_node->input(1), weights_node->name());
110351889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower  EXPECT_EQ(conv_node_1->input(1), weights_node_1->name());
110451889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower}
110551889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower
11068ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue WuTEST_F(ArithmeticOptimizerTest, CombineBitcasts) {
11078ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
11088ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  Output inputs =
11098ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu      ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({2, 3}));
11108ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  Output bc1 = ops::Bitcast(s, inputs, DT_QINT8);
11118ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  Output bc2 = ops::Bitcast(s, bc1, DT_INT8);
11128ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  Output outputs = ops::Identity(s.WithOpName("outputs"), bc2);
11138ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu
11148ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  GrapplerItem item;
11158ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  item.fetch = {"outputs"};
11168ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  TF_CHECK_OK(s.ToGraphDef(&item.graph));
11178ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu
11188ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  GraphDef output;
11198ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
1120f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
11218ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
11228ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu
11238ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  EXPECT_EQ(1, std::count_if(
11248ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu                   output.node().begin(), output.node().end(),
11258ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu                   [](const NodeDef& node) { return node.op() == "Bitcast"; }));
11268ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu}
11278ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu
11288ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue WuTEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) {
11298ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
11308ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  Output inputs = ops::Placeholder(s, DT_INT8, ops::Placeholder::Shape({2, 3}));
11318ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  Output bc1 = ops::Bitcast(s, inputs, DT_QINT8);
11328ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  Output bc2 = ops::Bitcast(s, bc1, DT_INT8);
11338ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  Output outputs = ops::Identity(s.WithOpName("outputs"), bc2);
11348ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu
11358ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  GrapplerItem item;
11368ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  item.fetch = {"outputs"};
11378ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  TF_CHECK_OK(s.ToGraphDef(&item.graph));
11388ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  GraphDef output;
11398ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
1140f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
11418ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
11428ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu
11438ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  EXPECT_EQ(0, std::count_if(
11448ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu                   output.node().begin(), output.node().end(),
11458ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu                   [](const NodeDef& node) { return node.op() == "Bitcast"; }));
11468ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu}
11478ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu
11488ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue WuTEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
11498ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
11508ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  Output inputs = ops::Placeholder(s, DT_INT8, ops::Placeholder::Shape({2, 3}));
11518ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  Output cast = ops::Cast(s, inputs, DT_INT8);
11528ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  Output outputs = ops::Identity(s.WithOpName("outputs"), cast);
11538ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu
11548ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  GrapplerItem item;
11558ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  item.fetch = {"outputs"};
11568ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  TF_CHECK_OK(s.ToGraphDef(&item.graph));
11578ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  GraphDef output;
11588ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output));
1159f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower  item.graph.Swap(&output);
11608ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output));
11618ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu
11628ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu  EXPECT_EQ(0, std::count_if(
11638ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu                   output.node().begin(), output.node().end(),
11648ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu                   [](const NodeDef& node) { return node.op() == "Cast"; }));
11658ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu}
11668ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu
1167c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner}  // namespace
1168c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner}  // namespace grappler
1169c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner}  // namespace tensorflow
1170