1c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner 3c247826219dd2541c6aba4578a03a171375d9290Benoit SteinerLicensed under the Apache License, Version 2.0 (the "License"); 4c247826219dd2541c6aba4578a03a171375d9290Benoit Steineryou may not use this file except in compliance with the License. 5c247826219dd2541c6aba4578a03a171375d9290Benoit SteinerYou may obtain a copy of the License at 6c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner 7c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner http://www.apache.org/licenses/LICENSE-2.0 8c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner 9c247826219dd2541c6aba4578a03a171375d9290Benoit SteinerUnless required by applicable law or agreed to in writing, software 10c247826219dd2541c6aba4578a03a171375d9290Benoit Steinerdistributed under the License is distributed on an "AS IS" BASIS, 11c247826219dd2541c6aba4578a03a171375d9290Benoit SteinerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12c247826219dd2541c6aba4578a03a171375d9290Benoit SteinerSee the License for the specific language governing permissions and 13c247826219dd2541c6aba4578a03a171375d9290Benoit Steinerlimitations under the License. 14c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner==============================================================================*/ 15c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner 16c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h" 17c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner#include "tensorflow/cc/ops/standard_ops.h" 18c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner#include "tensorflow/core/framework/node_def.pb.h" 19c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner#include "tensorflow/core/grappler/grappler_item.h" 20c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" 2133d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu#include "tensorflow/core/grappler/optimizers/constant_folding.h" 22e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee#include "tensorflow/core/grappler/optimizers/model_pruner.h" 23c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner#include "tensorflow/core/grappler/utils.h" 24c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner#include "tensorflow/core/lib/core/status_test_util.h" 25c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner#include "tensorflow/core/platform/test.h" 26c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner 27c247826219dd2541c6aba4578a03a171375d9290Benoit Steinernamespace tensorflow { 28c247826219dd2541c6aba4578a03a171375d9290Benoit Steinernamespace grappler { 29c247826219dd2541c6aba4578a03a171375d9290Benoit Steinernamespace { 30c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner 3198ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlowerstring OptimizedName(const string& name) { 3298ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower return AddPrefixToNodeName(name, kArithmeticOptimizer); 3398ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower} 3498ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower 3519f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlowervoid VerifyGraphsMatch(const GraphDef& original_graph, 3619f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower const GraphDef& optimized_graph, int line) { 3719f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << line; 3819f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower for (int i = 0; i < original_graph.node_size(); ++i) { 3919f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower const NodeDef& original = original_graph.node(i); 4019f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower const NodeDef& optimized = optimized_graph.node(i); 4119f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower EXPECT_EQ(original.name(), optimized.name()) << line; 4219f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower EXPECT_EQ(original.op(), optimized.op()) << line; 4319f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower EXPECT_EQ(original.input_size(), optimized.input_size()) << line; 4419f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower for (int j = 0; j < original.input_size(); ++j) { 4519f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower EXPECT_EQ(original.input(j), optimized.input(j)) << line; 4619f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower } 4719f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower } 4819f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower} 4919f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower 50c247826219dd2541c6aba4578a03a171375d9290Benoit Steinerclass ArithmeticOptimizerTest : public ::testing::Test {}; 51c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner 52c247826219dd2541c6aba4578a03a171375d9290Benoit SteinerTEST_F(ArithmeticOptimizerTest, NoOp) { 53c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner // This trivial graph is so basic there's nothing to optimize. 54c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); 55c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner GrapplerItem item; 56c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner CHECK(fake_input.NextItem(&item)); 57c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner 58c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner ArithmeticOptimizer optimizer; 59c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner GraphDef output; 60de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower Status status = optimizer.Optimize(nullptr, item, &output); 61de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower TF_EXPECT_OK(status); 6219f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower VerifyGraphsMatch(item.graph, output, __LINE__); 63c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner} 64c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner 65c247826219dd2541c6aba4578a03a171375d9290Benoit SteinerTEST_F(ArithmeticOptimizerTest, OpDedupping) { 66c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 67c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner Output c1 = ops::Const(s.WithOpName("c1"), {3.14, 2.7}, {1, 2}); 68c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner Output c2 = ops::Const(s.WithOpName("c2"), {3.14, 2.7}, {1, 2}); 6972d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower Output div = ops::Div(s.WithOpName("div"), c1, c2); 70c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner GrapplerItem item; 71c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner TF_CHECK_OK(s.ToGraphDef(&item.graph)); 7277b60c1ac63d0f188c4108ecb64bbe40004b2b8fA. Unique TensorFlower item.fetch = {"div"}; 73c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner 74c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner ArithmeticOptimizer optimizer; 75c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner GraphDef output; 76c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner Status status = optimizer.Optimize(nullptr, item, &output); 77c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner TF_EXPECT_OK(status); 78de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower // Run the optimizer twice to make sure the rewrite is idempotent. 79de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower item.graph.Swap(&output); 80de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower status = optimizer.Optimize(nullptr, item, &output); 81de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower TF_EXPECT_OK(status); 82c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner 83c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner EXPECT_EQ(2, output.node_size()); 84c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner const NodeDef& new_c1 = output.node(0); 85c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner EXPECT_EQ("c1", new_c1.name()); 8672d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower const NodeDef& new_div = output.node(1); 8772d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower EXPECT_EQ("div", new_div.name()); 8872d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower EXPECT_EQ(2, new_div.input_size()); 8972d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower EXPECT_EQ("c1", new_div.input(0)); 9072d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower EXPECT_EQ("c1", new_div.input(1)); 91c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner} 92c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner 93cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, OpDeduppingAssertAndCheckNumerics) { 94cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 95cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower Output p = ops::Placeholder(s, DT_BOOL, ops::Placeholder::Shape({})); 96cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower Output c = ops::Const(s.WithOpName("c"), {3.14, 2.7}, {1, 2}); 97cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower auto check1 = ops::CheckNumerics(s.WithOpName("check1"), c, "foo"); 98cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower auto check2 = ops::CheckNumerics(s.WithOpName("check2"), c, "foo"); 99cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower auto assert1 = ops::Assert(s.WithOpName("assert1"), p, {c}); 100cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower auto assert2 = ops::Assert(s.WithOpName("assert2"), p, {c}); 10172d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower Output div = ops::Div(s.WithOpName("div").WithControlDependencies( 10272d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower {assert1.operation, assert2.operation}), 10372d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower check1, check2); 104cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower GrapplerItem item; 105cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower TF_CHECK_OK(s.ToGraphDef(&item.graph)); 10677b60c1ac63d0f188c4108ecb64bbe40004b2b8fA. Unique TensorFlower item.fetch = {"div"}; 107cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower 108cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower ArithmeticOptimizer optimizer; 109cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower GraphDef output; 110cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower Status status = optimizer.Optimize(nullptr, item, &output); 111cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower TF_EXPECT_OK(status); 112cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower // Run the optimizer twice to make sure the rewrite is idempotent. 113cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower item.graph.Swap(&output); 114cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower status = optimizer.Optimize(nullptr, item, &output); 115cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower TF_EXPECT_OK(status); 116cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower 117cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower EXPECT_EQ(5, output.node_size()); 11872d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower const NodeDef& new_div = output.node(3); 11972d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower EXPECT_EQ(4, new_div.input_size()); 12072d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower EXPECT_EQ("check1", new_div.input(0)); 12172d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower EXPECT_EQ("check1", new_div.input(1)); 12272d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower EXPECT_EQ("^assert1", new_div.input(2)); 12372d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower EXPECT_EQ("^assert1", new_div.input(3)); 124cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower} 125cd8ced7a2d48574908d2c9b7127960078cf41690A. Unique TensorFlower 126f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, OpDedupCommutative) { 127f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 128f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower Output c1 = ops::Const(s.WithOpName("c1"), {1.0f, 2.0f}, {1, 2}); 129f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower Output c2 = ops::Const(s.WithOpName("c2"), {3.0f, 4.0f}, {1, 2}); 1306ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower Output mul1 = ops::Mul(s.WithOpName("mul1"), c1, c2); 1316ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower Output mul2 = ops::Mul(s.WithOpName("mul2"), c2, c1); 13272d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower Output div1 = ops::Div(s.WithOpName("div1"), mul1, mul2); 133f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower GrapplerItem item; 134f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower TF_CHECK_OK(s.ToGraphDef(&item.graph)); 13577b60c1ac63d0f188c4108ecb64bbe40004b2b8fA. Unique TensorFlower item.fetch = {"div"}; 136f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower 137f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower ArithmeticOptimizer optimizer; 138f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower GraphDef output; 139f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower Status status = optimizer.Optimize(nullptr, item, &output); 140f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower TF_EXPECT_OK(status); 141de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower // Run the optimizer twice to make sure the rewrite is idempotent. 142de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower item.graph.Swap(&output); 143de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower status = optimizer.Optimize(nullptr, item, &output); 144de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower TF_EXPECT_OK(status); 145f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower 146f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower EXPECT_EQ(4, output.node_size()); 147f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower const NodeDef& new_c1 = output.node(0); 148f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower EXPECT_EQ("c1", new_c1.name()); 149f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower const NodeDef& new_c2 = output.node(1); 150f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower EXPECT_EQ("c2", new_c2.name()); 1516ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower const NodeDef& new_mul1 = output.node(2); 1526ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower EXPECT_EQ("mul1", new_mul1.name()); 1536ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower EXPECT_EQ(2, new_mul1.input_size()); 1546ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower EXPECT_EQ("c1", new_mul1.input(0)); 1556ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower EXPECT_EQ("c2", new_mul1.input(1)); 15672d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower const NodeDef& new_div1 = output.node(3); 15772d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower EXPECT_EQ("div1", new_div1.name()); 15872d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower EXPECT_EQ(2, new_div1.input_size()); 15972d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower EXPECT_EQ("mul1", new_div1.input(0)); 16072d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower EXPECT_EQ("mul1", new_div1.input(1)); 16172d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower} 16272d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower 16372d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, MulToSquare) { 16472d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 16572d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); 16672d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower Output d = ops::Const(s.WithOpName("d"), {3.0f, 4.0f}, {1, 2}); 16772d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower Output mul = ops::Mul(s.WithControlDependencies(d).WithOpName("mul"), c, c); 16872d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower Output id = ops::Identity(s.WithOpName("id"), mul); 16972d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower GrapplerItem item; 17072d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower TF_CHECK_OK(s.ToGraphDef(&item.graph)); 17172d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower 17272d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower ArithmeticOptimizer optimizer; 17372d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower GraphDef output; 17472d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower Status status = optimizer.Optimize(nullptr, item, &output); 17572d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower TF_EXPECT_OK(status); 17672d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower 17772d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower EXPECT_EQ(5, output.node_size()); 17898ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ("id", output.node(3).name()); 17998ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("mul_square"), output.node(3).input(0)); 18072d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower EXPECT_EQ("Square", output.node(4).op()); 18198ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("mul_square"), output.node(4).name()); 18272d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower EXPECT_EQ(2, output.node(4).input_size()); 18372d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower EXPECT_EQ("c", output.node(4).input(0)); 18472d72194c1d06e66f7893915a804932b56bef5dbA. Unique TensorFlower EXPECT_EQ("^d", output.node(4).input(1)); 185f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower} 186f17f389d88d2441302825e3afa5209fb3426002bA. Unique TensorFlower 187b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsReal) { 188b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 189b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); 190b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower Output neg1 = ops::Neg(s.WithOpName("neg1"), c); 191b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower Output neg2 = ops::Neg(s.WithOpName("neg2"), neg1); 192b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), neg2); 193b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower Output recip2 = ops::Reciprocal(s.WithOpName("recip2"), recip1); 194b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower Output id = ops::Identity(s.WithOpName("id"), recip2); 195b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower GrapplerItem item; 196b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower TF_CHECK_OK(s.ToGraphDef(&item.graph)); 197b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower 198b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower ArithmeticOptimizer optimizer; 199b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower GraphDef output; 200b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower Status status = optimizer.Optimize(nullptr, item, &output); 201b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower TF_EXPECT_OK(status); 202b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower 203b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower EXPECT_EQ(6, output.node_size()); 204b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower EXPECT_EQ("c", output.node(1).input(0)); 205b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower EXPECT_EQ("c", output.node(3).input(0)); 206b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower EXPECT_EQ("c", output.node(5).input(0)); 207b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower} 208b46c196e9d8fa58821e3e269babe1df58d5db050A. Unique TensorFlower 2096a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithChain) { 2106a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 2116a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); 2126a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), c); 2136a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower Output id1 = ops::Identity(s.WithOpName("id1"), recip1); 2146a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower Output squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1); 2156a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower Output recip2 = ops::Reciprocal(s.WithOpName("recip2"), squeeze); 2166a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower Output id2 = ops::Identity(s.WithOpName("id2"), recip2); 2176a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower GrapplerItem item; 2186a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower TF_CHECK_OK(s.ToGraphDef(&item.graph)); 2196a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower 2206a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower ArithmeticOptimizer optimizer; 2216a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower GraphDef output; 2226a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower Status status = optimizer.Optimize(nullptr, item, &output); 2236a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower TF_EXPECT_OK(status); 224de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower // Run the optimizer twice to make sure the rewrite is idempotent. 225de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower item.graph.Swap(&output); 226de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower status = optimizer.Optimize(nullptr, item, &output); 227de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower TF_EXPECT_OK(status); 2286a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower 2296a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower EXPECT_EQ(6, output.node_size()); 2306a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower EXPECT_EQ("squeeze", output.node(5).input(0)); 2316a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower EXPECT_EQ("c", output.node(2).input(0)); 2326a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower} 2336a9af6e58e8a903a5a837882d92d8079b88338dcA. Unique TensorFlower 234f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, SimplifyInvolutionsWithControlChain) { 235f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 236f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); 237f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower Output recip1 = ops::Reciprocal(s.WithOpName("recip1"), c); 238f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower Output id1 = ops::Identity(s.WithOpName("id1"), recip1); 239f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower Output squeeze = ops::Squeeze(s.WithOpName("squeeze"), id1); 240f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower Output recip2 = ops::Reciprocal( 241f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower s.WithOpName("recip2").WithControlDependencies(squeeze), c); 242f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower Output id2 = ops::Identity(s.WithOpName("id2"), recip2); 243f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower GrapplerItem item; 244f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower TF_CHECK_OK(s.ToGraphDef(&item.graph)); 245f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower 246f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower ArithmeticOptimizer optimizer; 247f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower GraphDef output; 248f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower Status status = optimizer.Optimize(nullptr, item, &output); 249f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower TF_EXPECT_OK(status); 250f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower 251f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower // The optimizer should be a noop. 252f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ(item.graph.node_size(), output.node_size()); 253f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower for (int i = 0; i < item.graph.node_size(); ++i) { 254f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower const NodeDef& original = item.graph.node(i); 255f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower const NodeDef& optimized = output.node(i); 256f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ(original.name(), optimized.name()); 257f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ(original.op(), optimized.op()); 258f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ(original.input_size(), optimized.input_size()); 259f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower for (int j = 0; j < original.input_size(); ++j) { 260f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ(original.input(j), optimized.input(j)); 261f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower } 262f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower } 263f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower} 264f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower 265de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) { 2666ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 2676ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); 2686ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower Output add = ops::Add(s.WithOpName("add"), x, x); 2696ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower Output id = ops::Identity(s.WithOpName("id"), add); 2706ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower 2716ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower GrapplerItem item; 2726ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower TF_CHECK_OK(s.ToGraphDef(&item.graph)); 2736ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower 2746ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower ArithmeticOptimizer optimizer; 2756ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower GraphDef output; 2766ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower Status status = optimizer.Optimize(nullptr, item, &output); 2776ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower TF_EXPECT_OK(status); 278de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower // Run the optimizer twice to make sure the rewrite is idempotent. 279de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower item.graph.Swap(&output); 280de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower status = optimizer.Optimize(nullptr, item, &output); 281de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower TF_EXPECT_OK(status); 2826ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower 2836ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower EXPECT_EQ(5, output.node_size()); 2846ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower const NodeDef& new_const = output.node(3); 28598ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("add_const"), new_const.name()); 286f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ("^x", new_const.input(0)); 287f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ(std::string("\0\0\0@", 4), 288f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower new_const.attr().at("value").tensor().tensor_content()); 2896ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower const NodeDef& new_mul = output.node(4); 29098ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("add_mul"), new_mul.name()); 29198ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("add_const"), new_mul.input(0)); 2926ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower EXPECT_EQ("x", new_mul.input(1)); 2936ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower const NodeDef& new_id = output.node(2); 2946ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower EXPECT_EQ("id", new_id.name()); 29598ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("add_mul"), new_id.input(0)); 2966ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower} 2976ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower 298f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) { 299f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 300f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower Output y = ops::Const(s.WithOpName("y"), {1.0f, 2.0f}, {1, 2}); 301f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower Output x = ops::Const(s.WithOpName("x"), {3.0f, 4.0f}, {1, 2}); 302f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower Output add = ops::Add(s.WithOpName("add").WithControlDependencies(y), x, x); 303f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower Output id = ops::Identity(s.WithOpName("id"), add); 304f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower 305f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower GrapplerItem item; 306f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower TF_CHECK_OK(s.ToGraphDef(&item.graph)); 307f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower 308f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower ArithmeticOptimizer optimizer; 309f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower GraphDef output; 310f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower Status status = optimizer.Optimize(nullptr, item, &output); 311f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower TF_EXPECT_OK(status); 312f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower // Run the optimizer twice to make sure the rewrite is idempotent. 313f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower item.graph.Swap(&output); 314f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower status = optimizer.Optimize(nullptr, item, &output); 315f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower TF_EXPECT_OK(status); 316f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower 317f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ(6, output.node_size()); 318f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower const NodeDef& new_const = output.node(4); 31998ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("add_const"), new_const.name()); 320f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ("^x", new_const.input(0)); 321f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ(std::string("\0\0\0@", 4), 322f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower new_const.attr().at("value").tensor().tensor_content()); 323f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower const NodeDef& new_mul = output.node(5); 32498ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("add_mul"), new_mul.name()); 32598ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("add_const"), new_mul.input(0)); 326f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ("x", new_mul.input(1)); 327f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ("^y", new_mul.input(2)); 328f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower const NodeDef& new_id = output.node(3); 329f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ("id", new_id.name()); 33098ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("add_mul"), new_id.input(0)); 331f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower} 332f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower 333de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) { 334de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower // Test case from b/69059093. 335de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 336de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower Output p = ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({10, 10})); 337de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower Output add = ops::Add(s.WithOpName("Add"), p, p); 338de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower Output add1 = ops::Add(s.WithOpName("Add_1"), p, p); 339de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower Output add4 = ops::Add(s.WithOpName("Add_4"), add, add1); 340de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower Output add5 = ops::Add(s.WithOpName("Add_5"), add, add1); 341de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower Output add6 = ops::Add(s.WithOpName("Add_6"), add4, add5); 342de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower Output id = ops::Identity(s.WithOpName("id"), add6); 343de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower 344de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower GrapplerItem item; 345de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower TF_CHECK_OK(s.ToGraphDef(&item.graph)); 346f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower const std::vector<string> devices{ 347f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower "/device:CPU:0", "/device:GPU:0", "/device:CPU:0", "/device:GPU:1", 348f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower "/device:CPU:0", "/device:CPU:0", "/device:CPU:0", 349f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower }; 350f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower for (int i = 0; i < item.graph.node_size(); ++i) { 351f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower item.graph.mutable_node(i)->set_device(devices[i]); 352f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower } 3534ba3147461f2cd1b73029f986cf806b33d0ce290A. Unique TensorFlower ArithmeticOptimizer optimizer; 354de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower GraphDef output; 355de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower Status status = optimizer.Optimize(nullptr, item, &output); 356de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower TF_EXPECT_OK(status); 357de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower // Run the optimizer twice to make sure the rewrite is idempotent. 358de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower item.graph.Swap(&output); 359de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower status = optimizer.Optimize(nullptr, item, &output); 360de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower TF_EXPECT_OK(status); 361de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower 362f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ(17, output.node_size()); 363f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower // The graph gets optimized to 364f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower // Mul(p, 365f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower // Add(Add(Const(2), Const(2)), 366f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower // Add(Const(2), Const(2)))) 36798ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(17, output.node_size()); 368f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower for (const auto& node : output.node()) { 369f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower if ("id" == node.name()) { 370f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ(1, node.input_size()); 37198ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("Add_6_hoist_mul"), node.input(0)); 37298ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower } else if (OptimizedName("Add_6_hoist_mul") == node.name()) { 373f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ("Mul", node.op()); 374f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ(2, node.input_size()); 375f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ("Placeholder", node.input(0)); 37698ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("Add_6_hoist_add"), node.input(1)); 37798ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower } else if (OptimizedName("Add_6_hoist_add") == node.name()) { 378f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ("Add", node.op()); 379f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ(3, node.input_size()); 38098ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("Add_4_hoist_add"), node.input(0)); 38198ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("Add_5_hoist_add"), node.input(1)); 382f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ("^Placeholder", node.input(2)); 38398ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower } else if (OptimizedName("Add_4_hoist_add") == node.name()) { 384f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ("Add", node.op()); 385f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ(3, node.input_size()); 38698ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("Add_const"), node.input(0)); 38798ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("Add_1_const"), node.input(1)); 388f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ("^Placeholder", node.input(2)); 38998ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower } else if (OptimizedName("Add_5_hoist_add") == node.name()) { 390f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ("Add", node.op()); 391f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ(3, node.input_size()); 39298ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("Add_const"), node.input(0)); 39398ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("Add_1_const"), node.input(1)); 394f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ("^Placeholder", node.input(2)); 39598ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower } else if (OptimizedName("Add_const") == node.name()) { 396f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ("Const", node.op()); 397f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ(1, node.input_size()); 398f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ("^Placeholder", node.input(0)); 39998ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower } else if (OptimizedName("Add_1_const") == node.name()) { 400f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ("Const", node.op()); 401f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ(1, node.input_size()); 402f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower EXPECT_EQ("^Placeholder", node.input(0)); 403f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower } 404f5669d905a28893c71ff44245da6ed5e13d55d1cA. Unique TensorFlower } 405de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower} 406de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower 407de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, HoistFactor) { 40819f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower for (bool matching_shapes : {true, false}) { 40919f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower for (bool use_addn : {true, false}) { 41019f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 41119f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); 41219f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2}); 41319f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower Output y2 = matching_shapes 41419f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower ? ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2}) 41519f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1}); 41619f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower Output mul1 = ops::Mul(s.WithOpName("mul1"), x, y1); 41719f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower Output mul2 = ops::Mul(s.WithOpName("mul2"), y2, x); 41819f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower Output id = 41919f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower use_addn ? ops::Identity(s.WithOpName("id"), 42019f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower ops::AddN(s.WithOpName("add"), {mul1, mul2})) 42119f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower : ops::Identity(s.WithOpName("id"), 42219f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower ops::Add(s.WithOpName("add"), mul1, mul2)); 42319f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower 42419f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower GrapplerItem item; 42519f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower TF_CHECK_OK(s.ToGraphDef(&item.graph)); 4264ba3147461f2cd1b73029f986cf806b33d0ce290A. Unique TensorFlower ArithmeticOptimizer optimizer; 42719f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower GraphDef output; 42819f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower Status status = optimizer.Optimize(nullptr, item, &output); 42919f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower TF_EXPECT_OK(status); 43019f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower // Run the optimizer twice to make sure the rewrite is idempotent. 43119f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower item.graph.Swap(&output); 43219f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower status = optimizer.Optimize(nullptr, item, &output); 43319f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower TF_EXPECT_OK(status); 43419f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower 43519f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower if (use_addn && !matching_shapes) { 43619f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower VerifyGraphsMatch(item.graph, output, __LINE__); 43719f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower } else { 43819f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower EXPECT_EQ(9, output.node_size()); 43919f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower const NodeDef& new_add = output.node(8); 44019f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower EXPECT_EQ(OptimizedName("add_hoist_add"), new_add.name()); 44119f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower EXPECT_EQ("y1", new_add.input(0)); 44219f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower EXPECT_EQ("y2", new_add.input(1)); 44319f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower const NodeDef& new_mul = output.node(7); 44419f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower EXPECT_EQ(OptimizedName("add_hoist_mul"), new_mul.name()); 44519f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower EXPECT_EQ("x", new_mul.input(0)); 44619f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower EXPECT_EQ(OptimizedName("add_hoist_add"), new_mul.input(1)); 44719f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower const NodeDef& new_id = output.node(6); 44819f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower EXPECT_EQ("id", new_id.name()); 44919f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower EXPECT_EQ(OptimizedName("add_hoist_mul"), new_id.input(0)); 45019f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower } 45119f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower } 45219f62f62e5dab41b62b60ac66e7d07c09d55e17aA. Unique TensorFlower } 4536ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower} 4546ace5e0494d8142dc67ca0714893afc716125917A. Unique TensorFlower 45546ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) { 45646ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 45746ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2}); 45846ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2}); 45946ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Output z = ops::Complex(s.WithOpName("z"), re, im); 46046ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2}); 46146ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Output conj = ops::Conj(s.WithOpName("conj"), z); 46246ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Output transp = ops::Transpose(s.WithOpName("trans"), conj, perm); 46346ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower GrapplerItem item; 46446ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower TF_CHECK_OK(s.ToGraphDef(&item.graph)); 46546ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower 46646ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower ArithmeticOptimizer optimizer; 46746ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower GraphDef output; 46846ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Status status = optimizer.Optimize(nullptr, item, &output); 46946ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower TF_EXPECT_OK(status); 470de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower // Run the optimizer twice to make sure the rewrite is idempotent. 471de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower item.graph.Swap(&output); 472de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower status = optimizer.Optimize(nullptr, item, &output); 473de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower TF_EXPECT_OK(status); 47446ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower 47546ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower EXPECT_EQ(7, output.node_size()); 47698ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("trans_fused"), output.node(6).name()); 47746ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower EXPECT_EQ("ConjugateTranspose", output.node(6).op()); 47846ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower EXPECT_EQ("z", output.node(6).input(0)); 47946ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower EXPECT_EQ("perm", output.node(6).input(1)); 48046ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower} 48146ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower 48246ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, FuseConjAndConjugateTranspose) { 48346ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 48446ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2}); 48546ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2}); 48646ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Output z = ops::Complex(s.WithOpName("z"), re, im); 48746ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2}); 48846ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Output conj = ops::Conj(s.WithOpName("conj"), z); 48946ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Output transp = 49046ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower ops::ConjugateTranspose(s.WithOpName("conjugate_trans"), conj, perm); 49146ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower GrapplerItem item; 49246ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower TF_CHECK_OK(s.ToGraphDef(&item.graph)); 49346ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower 49446ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower ArithmeticOptimizer optimizer; 49546ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower GraphDef output; 49646ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Status status = optimizer.Optimize(nullptr, item, &output); 49746ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower TF_EXPECT_OK(status); 49846ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower 49946ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower EXPECT_EQ(7, output.node_size()); 50098ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("conjugate_trans_fused"), output.node(6).name()); 50146ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower EXPECT_EQ("Transpose", output.node(6).op()); 50246ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower EXPECT_EQ("z", output.node(6).input(0)); 50346ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower EXPECT_EQ("perm", output.node(6).input(1)); 50446ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower} 50546ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower 50646ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, FuseTransposeAndConj) { 50746ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 50846ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Output re = ops::Const(s.WithOpName("re"), {1.0, 2.0, 3.0, 4.0}, {2, 2}); 50946ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Output im = ops::Const(s.WithOpName("im"), {5.0, 6.0, 7.0, 8.0}, {2, 2}); 51046ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Output z = ops::Complex(s.WithOpName("z"), re, im); 51146ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2}); 51246ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Output trans = ops::Transpose(s.WithOpName("trans"), z, perm); 51346ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Output conj = ops::Conj(s.WithOpName("conj"), trans); 51446ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower GrapplerItem item; 51546ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower TF_CHECK_OK(s.ToGraphDef(&item.graph)); 51646ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower 51746ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower ArithmeticOptimizer optimizer; 51846ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower GraphDef output; 51946ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower Status status = optimizer.Optimize(nullptr, item, &output); 52046ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower TF_EXPECT_OK(status); 521de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower // Run the optimizer twice to make sure the rewrite is idempotent. 522de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower item.graph.Swap(&output); 523de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower status = optimizer.Optimize(nullptr, item, &output); 524de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower TF_EXPECT_OK(status); 52546ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower 52646ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower EXPECT_EQ(7, output.node_size()); 52798ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("conj_fused"), output.node(6).name()); 52846ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower EXPECT_EQ("ConjugateTranspose", output.node(6).op()); 52946ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower EXPECT_EQ("z", output.node(6).input(0)); 53046ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower EXPECT_EQ("perm", output.node(6).input(1)); 53146ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower} 53246ffa99df62b3ecdab65f9bbf202921205d59e68A. Unique TensorFlower 5339fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) { 5349fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower for (const string matmul_type : {"MatMul", "SparseMatMul", "BatchMatMul"}) { 5359fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 5369fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Output a = ops::Const(s.WithOpName("a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}); 5379fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Output b = ops::Const(s.WithOpName("b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2}); 5389fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2}); 5399fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Output trans_a = ops::Transpose(s.WithOpName("trans_a"), a, perm); 5409fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Output trans_b = ops::Transpose(s.WithOpName("trans_b"), b, perm); 5419fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower if (matmul_type == "MatMul") { 5429fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Output matmul = ops::MatMul(s.WithOpName("matmul"), trans_a, trans_b); 5439fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower } else if (matmul_type == "SparseMatMul") { 5449fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Output matmul = 5459fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower ops::SparseMatMul(s.WithOpName("matmul"), trans_a, trans_b); 5469fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower } else if (matmul_type == "BatchMatMul") { 5479fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Output matmul = 5489fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower ops::BatchMatMul(s.WithOpName("matmul"), trans_a, trans_b); 5499fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower } 5509fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower GrapplerItem item; 5519fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower TF_CHECK_OK(s.ToGraphDef(&item.graph)); 5529fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower 5539fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower ArithmeticOptimizer optimizer; 5549fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower GraphDef output; 5559fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Status status = optimizer.Optimize(nullptr, item, &output); 5569fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower TF_EXPECT_OK(status); 557de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower // Run the optimizer twice to make sure the rewrite is idempotent. 558de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower item.graph.Swap(&output); 559de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower status = optimizer.Optimize(nullptr, item, &output); 560de5d8eb503234a2bcce5141b564337feb26928efA. Unique TensorFlower TF_EXPECT_OK(status); 5619fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower 5629fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower EXPECT_EQ(7, output.node_size()); 56398ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("matmul_fused"), output.node(6).name()); 5649fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower EXPECT_EQ("a", output.node(6).input(0)); 5659fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower EXPECT_EQ("b", output.node(6).input(1)); 5669fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower if (matmul_type == "BatchMatMul") { 5679fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower EXPECT_TRUE(output.node(6).attr().at("adj_x").b()); 5689fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower EXPECT_TRUE(output.node(6).attr().at("adj_y").b()); 5699fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower } else { 5709fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower EXPECT_TRUE(output.node(6).attr().at("transpose_a").b()); 5719fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower EXPECT_TRUE(output.node(6).attr().at("transpose_b").b()); 5729fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower } 5739fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower } 5749fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower} 5759fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower 5769fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, FoldConjugateTransposeIntoBatchMatMul) { 5779fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 5789fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Output re_a = 5799fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower ops::Const(s.WithOpName("re_a"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}); 5809fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Output im_a = 5819fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower ops::Const(s.WithOpName("im_a"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2}); 5829fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Output re_b = 5839fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower ops::Const(s.WithOpName("re_b"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2}); 5849fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Output im_b = 5859fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower ops::Const(s.WithOpName("im_b"), {-5.0f, -6.0f, -7.0f, -8.0f}, {2, 2}); 5869fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Output a = ops::Complex(s.WithOpName("a"), re_a, im_a); 5879fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Output b = ops::Complex(s.WithOpName("b"), re_b, im_b); 5889fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Output perm = ops::Const(s.WithOpName("perm"), {1, 0}, {2}); 5899fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Output trans_a = ops::ConjugateTranspose(s.WithOpName("trans_a"), a, perm); 5909fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Output trans_b = ops::ConjugateTranspose(s.WithOpName("trans_b"), b, perm); 5919fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Output matmul = ops::BatchMatMul(s.WithOpName("matmul"), trans_a, trans_b); 5929fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower GrapplerItem item; 5939fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower TF_CHECK_OK(s.ToGraphDef(&item.graph)); 5949fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower 5959fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower ArithmeticOptimizer optimizer; 5969fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower GraphDef output; 5979fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower Status status = optimizer.Optimize(nullptr, item, &output); 5989fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower TF_EXPECT_OK(status); 5999fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower 6009fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower EXPECT_EQ(11, output.node_size()); 60198ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower EXPECT_EQ(OptimizedName("matmul_fused"), output.node(10).name()); 6029fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower EXPECT_EQ("a", output.node(10).input(0)); 6039fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower EXPECT_EQ("b", output.node(10).input(1)); 6049fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower EXPECT_TRUE(output.node(10).attr().at("adj_x").b()); 6059fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower EXPECT_TRUE(output.node(10).attr().at("adj_y").b()); 6069fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower} 6079fd88af1a86249f1959fdd0279e4f5263bef678dA. Unique TensorFlower 608d871fdce70acc165e652c66638943b40ffcda7a3Jingyue WuTEST_F(ArithmeticOptimizerTest, IdentityReshape) { 609d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 610d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu Output inputs = 611d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28})); 612d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu Output inputs_shape = ops::Shape(s, inputs); 613d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu // The target shape of the reshape is the concatenation of `batch_size` and 614d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu // [3,28,28]. 615d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu Output batch_size = ops::Slice(s, inputs_shape, ops::Const(s, {0}, {1}), 616d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu ops::Const(s, {1}, {1})); 617d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu Output target_shape = ops::Concat( 618d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu s.WithOpName("target_shape"), 619d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu {batch_size, ops::Const(s, {3, 28, 28}, {3})}, ops::Const(s, {0}, {})); 620d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu Output reshape = ops::Reshape(s, inputs, target_shape); 621d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu Output outputs = ops::Identity(s.WithOpName("outputs"), reshape); 622d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu 623d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu GrapplerItem item; 624d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu item.fetch = {"outputs"}; 625d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu TF_CHECK_OK(s.ToGraphDef(&item.graph)); 626d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu 627d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu GraphDef output; 6284ba3147461f2cd1b73029f986cf806b33d0ce290A. Unique TensorFlower TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); 629d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu 630f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 631d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); 632d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu 633d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu EXPECT_EQ(0, std::count_if( 634d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu output.node().begin(), output.node().end(), 635d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu [](const NodeDef& node) { return node.op() == "Reshape"; })); 636d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu} 637d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu 638d871fdce70acc165e652c66638943b40ffcda7a3Jingyue WuTEST_F(ArithmeticOptimizerTest, NotIdentityReshape) { 639d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu // Reshape from [-1,3,28,28] to [8,-1,28,28] is not identity, because it can 640d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu // be from [4,3,28,28] to [8,6,28,28]. 641d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 642d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu Output inputs = 643d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({-1, 3, 28, 28})); 644d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu Output reshape = ops::Reshape(s, inputs, ops::Const(s, {8, -1, 28, 28}, {4})); 645d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu Output outputs = ops::Identity(s.WithOpName("outputs"), reshape); 646d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu 647d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu GrapplerItem item; 648d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu item.fetch = {"outputs"}; 649d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu TF_CHECK_OK(s.ToGraphDef(&item.graph)); 650d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu 651d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu GraphDef output; 6524ba3147461f2cd1b73029f986cf806b33d0ce290A. Unique TensorFlower TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); 653d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu 654f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 655d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); 656d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu 657d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu EXPECT_EQ(1, std::count_if( 658d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu output.node().begin(), output.node().end(), 659d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu [](const NodeDef& node) { return node.op() == "Reshape"; })); 660d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu} 661d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu 662d871fdce70acc165e652c66638943b40ffcda7a3Jingyue WuTEST_F(ArithmeticOptimizerTest, NotIdentityReshapeTooManyUnknownDimSizes) { 663d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 664d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu Output inputs = 665d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({4, 3})); 666d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu Output reshape = ops::Reshape(s, inputs, ops::Const(s, {-1, -1}, {2})); 667d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu Output outputs = ops::Identity(s.WithOpName("outputs"), reshape); 668d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu 669d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu GrapplerItem item; 670d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu item.fetch = {"outputs"}; 671d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu TF_CHECK_OK(s.ToGraphDef(&item.graph)); 672d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu 673d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu GraphDef output; 6744ba3147461f2cd1b73029f986cf806b33d0ce290A. Unique TensorFlower TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); 675d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu 676f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 677d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); 678d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu 679d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu EXPECT_EQ(1, std::count_if( 680d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu output.node().begin(), output.node().end(), 681d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu [](const NodeDef& node) { return node.op() == "Reshape"; })); 682d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu} 683d871fdce70acc165e652c66638943b40ffcda7a3Jingyue Wu 684b002c8b7d28f8327bac5db2efcd7924694beefafJingyue WuTEST_F(ArithmeticOptimizerTest, CombineReshapes) { 685b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu // Converts an NCHW_VECT_C tensor to NHWC and then flattens it to 2D. The two 686b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu // reshapes should be combined. 687b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 688b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu Output nchw_vect_c = 689b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu ops::Placeholder(s.WithOpName("nchw_vect_c"), DT_INT8, 690b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu ops::Placeholder::Shape({8, 3, 28, 28, 4})); 691b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu Output transpose = 692b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu ops::Transpose(s.WithOpName("transpose"), nchw_vect_c, 693b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu ops::Const(s.WithOpName("perm"), {0, 2, 3, 1, 4}, {5})); 694b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu Output nhwc = ops::Reshape( 695b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu s.WithOpName("nhwc"), transpose, 696b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu ops::Const(s.WithOpName("nhwc_shape"), {8, 28, 28, 12}, {4})); 697b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu Output flatten = ops::Reshape( 698b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu s.WithOpName("flatten"), nhwc, 699b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu ops::Const(s.WithOpName("flatten_shape"), {8, 28 * 28 * 12}, {2})); 700b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu Output outputs = ops::Identity(s.WithOpName("outputs"), flatten); 701b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu 702b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu GrapplerItem item; 703b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu item.fetch = {"outputs"}; 704b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu TF_CHECK_OK(s.ToGraphDef(&item.graph)); 705b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu 706b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu GraphDef output; 707b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); 708b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu 709f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 710b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); 711b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu 712b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu EXPECT_EQ(1, std::count_if( 713b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu output.node().begin(), output.node().end(), 714b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu [](const NodeDef& node) { return node.op() == "Reshape"; })); 715b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu} 716b002c8b7d28f8327bac5db2efcd7924694beefafJingyue Wu 7179d8346a1204d05b2ab16c169a6a6077167fe162aJingyue WuTEST_F(ArithmeticOptimizerTest, ReorderTransposeCast) { 7189d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/gpu:0"); 7199d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu Output nhwc_uint8 = 7209d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3})); 7219d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu Output nhwc_fp32 = ops::Cast(s, nhwc_uint8, DT_FLOAT); 7229d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu Output nchw_fp32 = 7239d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu ops::Transpose(s, nhwc_fp32, ops::Const(s, {0, 3, 1, 2}, {4})); 7249d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_fp32); 7259d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu 7269d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu GrapplerItem item; 7279d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu item.fetch = {"outputs"}; 7289d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu TF_CHECK_OK(s.ToGraphDef(&item.graph)); 7299d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu 7309d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu GraphDef output; 7319d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); 7329d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu 733f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 7349d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); 7359d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu 7369d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu const NodeDef* transpose_node = nullptr; 7379d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu for (const NodeDef& node : output.node()) { 7389d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu if (node.op() == "Transpose") { 7399d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu EXPECT_EQ(transpose_node, nullptr); 7409d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu EXPECT_EQ(DT_UINT8, node.attr().at("T").type()); 7419d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu transpose_node = &node; 7429d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu } 7439d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu } 7449d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu EXPECT_NE(transpose_node, nullptr); 7459d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu 7469d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu for (const NodeDef& node : output.node()) { 7479d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu if (node.op() == "Cast") { 7489d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu EXPECT_EQ(NodeName(node.input(0)), transpose_node->name()); 7499d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu } 7509d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu } 7519d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu} 7529d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu 7539d8346a1204d05b2ab16c169a6a6077167fe162aJingyue WuTEST_F(ArithmeticOptimizerTest, NoReorderTransposeCast) { 7549d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/gpu:0"); 7559d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu Output nhwc_fp32 = 7569d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 28, 28, 3})); 7579d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu Output nhwc_uint8 = ops::Cast(s, nhwc_fp32, DT_UINT8); 7589d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu Output nchw_uint8 = 7599d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu ops::Transpose(s, nhwc_uint8, ops::Const(s, {0, 3, 1, 2}, {4})); 7609d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu Output outputs = ops::Identity(s.WithOpName("outputs"), nchw_uint8); 7619d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu 7629d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu GrapplerItem item; 7639d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu item.fetch = {"outputs"}; 7649d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu TF_CHECK_OK(s.ToGraphDef(&item.graph)); 7659d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu 7669d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu GraphDef output; 7679d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); 7689d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu 769f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 7709d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); 7719d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu 7729d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu int num_transposes = 0; 7739d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu for (const NodeDef& node : output.node()) { 7749d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu if (node.op() == "Transpose") { 7759d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu EXPECT_EQ(DT_UINT8, node.attr().at("T").type()); 7769d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu EXPECT_EQ(node.input(0), "Cast"); 7779d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu ++num_transposes; 7789d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu } 7799d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu } 7809d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu EXPECT_EQ(1, num_transposes); 7819d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu} 7829d8346a1204d05b2ab16c169a6a6077167fe162aJingyue Wu 783e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong LeeTEST_F(ArithmeticOptimizerTest, RemoveInverseTransposes) { 784e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 785e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee Output inputs_shape = 786e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee ops::Const(s.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4}); 787e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee Output inputs = 788e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee ops::RandomUniform(s.WithOpName("inputs"), inputs_shape, DT_FLOAT); 789e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee Output perm1 = ops::Const(s.WithOpName("perm1"), {0, 2, 3, 1}, {4}); 790e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee Output perm2 = ops::Const(s.WithOpName("perm2"), {0, 3, 1, 2}, {4}); 791e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee Output transpose1 = ops::Transpose(s.WithOpName("transpose1"), inputs, perm1); 792e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee Output transpose2 = 793e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee ops::Transpose(s.WithOpName("transpose2"), transpose1, perm2); 794e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee Output outputs = ops::Identity(s.WithOpName("outputs"), transpose2); 795e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee 796e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee GrapplerItem item; 797e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee item.fetch = {"outputs"}; 798e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee TF_CHECK_OK(s.ToGraphDef(&item.graph)); 799e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee 800e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee GraphDef output; 801e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); 802e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee 803f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 804e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); 805e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee 806e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee std::set<string> nodes_after_optimization; 807e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee for (const NodeDef& node : output.node()) { 808e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee nodes_after_optimization.insert(node.name()); 809e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee } 810e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee EXPECT_EQ(nodes_after_optimization, 811e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee std::set<string>({"inputs_shape", "inputs", "outputs"})); 812e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee} 813e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee 814e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue WuTEST_F(ArithmeticOptimizerTest, RemoveInverseTransposesMultipleOutputs) { 815e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 816e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu Output inputs_shape = 817e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu ops::Const(s.WithOpName("inputs_shape"), {8, 9, 28, 28}, {4}); 818e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT, 819e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu ops::Placeholder::Shape({8, 12, 28, 28})); 820e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu OutputList split = ops::Split(s, ops::Const(s, 1), inputs, 3).output; 821e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu Output perm1 = ops::Const(s, {0, 2, 3, 1}, {4}); 822e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu Output perm2 = ops::Const(s, {0, 3, 1, 2}, {4}); 823e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu Output branch0 = split[0]; 824e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu Output branch1 = ops::Transpose(s, ops::Transpose(s, split[1], perm1), perm2); 825e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu Output branch2 = split[2]; 826e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu Output concat = ops::Concat(s, {branch0, branch1, branch2}, ops::Const(s, 1)); 827e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu Output outputs = ops::Identity(s.WithOpName("outputs"), concat); 828e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu 829e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu GrapplerItem item; 830e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu item.fetch = {"outputs"}; 831e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu TF_CHECK_OK(s.ToGraphDef(&item.graph)); 832e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu 833e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu GraphDef output; 834e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); 835e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu 836f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 837e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); 838e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu 839e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu for (const NodeDef& node : output.node()) { 840e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu if (node.op() == "Concat") { 841e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu EXPECT_EQ(node.input(0), "Split"); 842e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu EXPECT_EQ(node.input(1), "Split:1"); 843e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu EXPECT_EQ(node.input(2), "Split:2"); 844e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu } 845e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu } 846e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu} 847e11b9fd32eb5b8f1eb9b8a30dbb08fc1f83fc1ddJingyue Wu 84827df639673ae2bfe63b82862008da9bec488f0dbJingyue WuTEST_F(ArithmeticOptimizerTest, RemoveTransposesWithControlDependency) { 84927df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 85027df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu Output inputs = 85127df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({2, 3})); 85227df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu Output transpose1 = ops::Transpose(s, inputs, ops::Const(s, {1, 0})); 85327df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu Output transpose2 = ops::Transpose(s, transpose1, ops::Const(s, {1, 0})); 85427df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu Output outputs = 85527df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu ops::Identity(s.WithOpName("outputs").WithControlDependencies(transpose2), 85627df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu ops::Const(s.WithOpName("outputs_const"), 1.0f)); 85727df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu 85827df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu GrapplerItem item; 85927df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu item.fetch = {"outputs"}; 86027df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu TF_CHECK_OK(s.ToGraphDef(&item.graph)); 86127df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu GraphDef output; 86227df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); 863f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 86427df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); 86527df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu 86627df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu NodeMap node_map(&output); 86727df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu const NodeDef* outputs_node = node_map.GetNode("outputs"); 86827df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu EXPECT_EQ(2, outputs_node->input_size()); 86927df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu EXPECT_EQ(outputs_node->input(0), "outputs_const"); 87027df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu EXPECT_EQ(outputs_node->input(1), "^Placeholder"); 87127df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu} 87227df639673ae2bfe63b82862008da9bec488f0dbJingyue Wu 873e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong LeeTEST_F(ArithmeticOptimizerTest, NotRemoveTransposes) { 874e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 875e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee Output inputs_shape = 876e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee ops::Const(s.WithOpName("inputs_shape"), {8, 3, 28, 28}, {4}); 877e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee Output inputs = 878e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee ops::RandomUniform(s.WithOpName("inputs"), inputs_shape, DT_FLOAT); 879e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee Output perm = ops::Const(s.WithOpName("perm"), {1, 2, 3, 0}, {4}); 880e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee Output transpose1 = ops::Transpose(s.WithOpName("transpose1"), inputs, perm); 881e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee Output transpose2 = 882e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee ops::Transpose(s.WithOpName("transpose2"), transpose1, perm); 883e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee Output outputs = ops::Identity(s.WithOpName("outputs"), transpose2); 884e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee 885e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee GrapplerItem item; 886e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee item.fetch = {"outputs"}; 887e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee TF_CHECK_OK(s.ToGraphDef(&item.graph)); 888e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee 889e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee GraphDef output; 890e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); 891e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee 892f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 893e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); 894e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee 895e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee EXPECT_EQ(6, output.node_size()); 896e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee} 897e4134ea1c920b3256c37004fd245a1f43f0254d7HyoukJoong Lee 898f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue WuTEST_F(ArithmeticOptimizerTest, FoldMulToTransposeConv) { 899f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 900f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT, 901f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu ops::Placeholder::Shape({8, 28, 28, 3})); 902f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {}); 903f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output scaled_inputs = 904f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale); 905f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output perm_nhwc_to_nchw = 906f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu ops::Const(s.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4}); 907f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output inputs_nchw = ops::Transpose(s.WithOpName("inputs_nchw"), 908f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu scaled_inputs, perm_nhwc_to_nchw); 909f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output weights = ops::Const(s.WithOpName("weights"), 910f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Input::Initializer(127.0f, {5, 5, 3, 16})); 911f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output conv = 912f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu ops::Conv2D(s.WithOpName("conv"), inputs_nchw, weights, {1, 1, 1, 1}, 913f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu "VALID", ops::Conv2D::DataFormat("NCHW")); 914f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output outputs = ops::Identity(s.WithOpName("outputs"), conv); 915f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu 916f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu GrapplerItem item; 917f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu item.fetch = {"outputs"}; 918f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu TF_CHECK_OK(s.ToGraphDef(&item.graph)); 919f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu 920f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu GraphDef output; 921f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); 922f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu 923f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 924f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); 925f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu 926f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu NodeMap node_map(&output); 927f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu // `conv` is now a folded convolution with scaled weights. 928f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu const NodeDef* folded_conv = node_map.GetNode(conv.node()->name()); 929f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu CHECK_EQ(node_map.GetNode(NodeName(folded_conv->input(1)))->op(), "Mul"); 930f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu // Its input should be a transpose of `inputs`. 931f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu const NodeDef* transpose = node_map.GetNode(NodeName(folded_conv->input(0))); 932f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu CHECK_EQ(NodeName(transpose->input(0)), inputs.node()->name()); 933f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu} 934f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu 935f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue WuTEST_F(ArithmeticOptimizerTest, NotFoldMulAcrossPreservedTranspose) { 936f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 937f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT, 938f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu ops::Placeholder::Shape({8, 28, 28, 3})); 939f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {}); 940f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output scaled_inputs = 941f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale); 942f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output perm_nhwc_to_nchw = 943f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu ops::Const(s.WithOpName("perm_nhwc_to_nchw"), {0, 3, 1, 2}, {4}); 944f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output inputs_nchw = ops::Transpose(s.WithOpName("inputs_nchw"), 945f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu scaled_inputs, perm_nhwc_to_nchw); 946f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output weights = ops::Const(s.WithOpName("weights"), 947f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Input::Initializer(127.0f, {5, 5, 3, 16})); 948f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output conv = 949f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu ops::Conv2D(s.WithOpName("conv"), inputs_nchw, weights, {1, 1, 1, 1}, 950f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu "VALID", ops::Conv2D::DataFormat("NCHW")); 951f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output outputs = ops::Identity(s.WithOpName("outputs"), conv); 952f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu 953f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Tensor inputs_nchw_tensor(DT_FLOAT, {8, 3, 28, 28}); 954f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu memset(const_cast<char*>(inputs_nchw_tensor.tensor_data().data()), 0, 955f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu inputs_nchw_tensor.tensor_data().size()); 956f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu 957f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu GrapplerItem item; 958f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu item.fetch = {"outputs"}; 959f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu item.feed = {{"inputs_nchw", inputs_nchw_tensor}}; 960f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu TF_CHECK_OK(s.ToGraphDef(&item.graph)); 961f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu 962f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu GraphDef output; 963f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); 964f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu 965f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 966f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); 967f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu 968f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu NodeMap node_map(&output); 969f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu const NodeDef* inputs_nchw_node_def = 970f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu node_map.GetNode(inputs_nchw.node()->name()); 971f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu EXPECT_EQ(NodeName(inputs_nchw_node_def->input(0)), 972f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu scaled_inputs.node()->name()); 973f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu} 974f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu 975f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue WuTEST_F(ArithmeticOptimizerTest, FoldMulToConv) { 976f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 977f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output inputs = ops::Placeholder(s.WithOpName("inputs"), DT_FLOAT, 978f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu ops::Placeholder::Shape({8, 28, 28, 28, 3})); 979f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output scale = ops::Const(s.WithOpName("scale"), 1.0f / 255.0f, {}); 980f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output scaled_inputs = 981f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu ops::Multiply(s.WithOpName("scaled_inputs"), inputs, scale); 982f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output weights = ops::Const(s.WithOpName("weights"), 983f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Input::Initializer(127.0f, {5, 5, 5, 3, 16})); 984f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output conv = ops::Conv3D(s.WithOpName("conv"), scaled_inputs, weights, 985f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu {1, 1, 1, 1, 1}, "VALID"); 986f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu Output outputs = ops::Identity(s.WithOpName("outputs"), conv); 987f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu 988f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu GrapplerItem item; 989f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu item.fetch = {"outputs"}; 990f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu TF_CHECK_OK(s.ToGraphDef(&item.graph)); 991f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu 992f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu GraphDef output; 993f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); 994f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu 995f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 996f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); 997f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu 998f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu NodeMap node_map(&output); 999f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu // `conv` is now a folded convolution on `inputs` and scaled weights. 1000f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu const NodeDef* folded_conv = node_map.GetNode(conv.node()->name()); 1001f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu CHECK_EQ(inputs.node()->name(), NodeName(folded_conv->input(0))); 1002f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu CHECK_EQ(node_map.GetNode(NodeName(folded_conv->input(1)))->op(), "Mul"); 1003f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu} 1004f08c961c97c1ec6bb5ee7982b4cc14ba01f3f938Jingyue Wu 100533d55122d994d12f2a066f9ec4f0f03094a59579Jingyue WuTEST_F(ArithmeticOptimizerTest, OptimizeCastMulTransposeConv) { 100633d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu // This unit test exercises two optimizations, folding mul into conv, and 100733d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu // reordering cast and transpose. 100833d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu // 100933d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu // Conv2D(Transpose(Mul(Cast(I), S)), W) 101033d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu // => 101133d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu // Conv2D(Transpose(Cast(I)), W*S) 101233d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu // => 101333d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu // Conv2D(Cast(Transpose(I)), W*S) 101433d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/gpu:0"); 101533d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu Output inputs = 101633d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({8, 28, 28, 3})); 101733d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu Output cast = ops::Cast(s, inputs, DT_FLOAT); 101833d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu Output mul = ops::Mul(s, cast, ops::Const(s, 1.0f / 255.0f)); 101933d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu Output transpose = 102033d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu ops::Transpose(s, mul, ops::Const(s.WithOpName("perm"), {0, 3, 1, 2})); 102133d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu Output weights = ops::Const(s.WithOpName("weights"), 102233d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu Input::Initializer(127.0f, {5, 5, 3, 16})); 102333d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu Output conv = ops::Conv2D(s, transpose, weights, {1, 1, 1, 1}, "VALID", 102433d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu ops::Conv2D::DataFormat("NCHW")); 102533d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu Output outputs = ops::Identity(s.WithOpName("outputs"), conv); 102633d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu 102733d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu GrapplerItem item; 102833d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu item.fetch = {"outputs"}; 102933d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu TF_CHECK_OK(s.ToGraphDef(&item.graph)); 103033d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu 103133d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu GraphDef output; 103233d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); 103333d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu 1034f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower // Run the optimizer twice to make sure the rewrite is idempotent. 1035f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 1036f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); 1037f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower 1038f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 103933d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu TF_EXPECT_OK( 104033d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu ConstantFolding(/*cpu_device=*/nullptr).Optimize(nullptr, item, &output)); 104133d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu 1042f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 104333d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); 104433d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu 104533d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu NodeMap node_map(&output); 104633d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu const NodeDef* inputs_node = CHECK_NOTNULL(node_map.GetNode("Placeholder")); 104733d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu const NodeDef* transpose_node = 104898ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower CHECK_NOTNULL(node_map.GetNode(OptimizedName("Transpose_uint8"))); 104998ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower const NodeDef* cast_node = 1050f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower CHECK_NOTNULL(node_map.GetNode(OptimizedName("Cast_float"))); 105133d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu const NodeDef* weights_node = 105298ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower CHECK_NOTNULL(node_map.GetNode(OptimizedName("weights_scaled_Conv2D"))); 105333d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu const NodeDef* conv_node = CHECK_NOTNULL(node_map.GetNode("Conv2D")); 105433d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu 105533d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu EXPECT_EQ(output.node_size(), 7); 105633d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu EXPECT_EQ(transpose_node->input(0), inputs_node->name()); 105733d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu EXPECT_EQ(cast_node->input(0), transpose_node->name()); 105833d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu EXPECT_EQ(conv_node->input(0), cast_node->name()); 105933d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu EXPECT_EQ(conv_node->input(1), weights_node->name()); 106033d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu} 106133d55122d994d12f2a066f9ec4f0f03094a59579Jingyue Wu 106251889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlowerTEST_F(ArithmeticOptimizerTest, OptimizeMultipleMulTransposeConv) { 106351889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower // This unit test exercises optimization of folding mul into conv for 106451889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower // multiple nodes in the graph. 106551889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice("/gpu:0"); 106651889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower 106751889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower GrapplerItem item; 106851889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower Output conv[2]; 106951889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower 107051889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower for (int i = 0; i < 2; ++i) { 107151889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower Output inputs = 107251889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower ops::Placeholder(s, DT_FLOAT, ops::Placeholder::Shape({8, 3, 28, 28})); 107351889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower Output mul = ops::Mul(s, inputs, ops::Const(s, 1.0f / 255.0f)); 107451889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower Output weights = ops::Const(s.WithOpName("weights"), 107551889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower Input::Initializer(127.0f, {5, 5, 3, 16})); 107651889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower conv[i] = ops::Conv2D(s, mul, weights, {1, 1, 1, 1}, "VALID", 107751889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower ops::Conv2D::DataFormat("NCHW")); 107851889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower } 107951889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower Output outputs = ops::Add(s.WithOpName("outputs"), conv[0], conv[1]); 108051889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower 108151889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower item.fetch = {"outputs"}; 108251889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower TF_CHECK_OK(s.ToGraphDef(&item.graph)); 108351889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower 108451889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower GraphDef output; 108551889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); 108651889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower 1087f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 108851889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower TF_EXPECT_OK( 108951889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower ConstantFolding(/*cpu_device=*/nullptr).Optimize(nullptr, item, &output)); 109051889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower 1091f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 109251889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); 109351889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower 109451889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower NodeMap node_map(&output); 109551889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower const NodeDef* weights_node = 109698ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower CHECK_NOTNULL(node_map.GetNode(OptimizedName("weights_scaled_Conv2D"))); 109751889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower const NodeDef* conv_node = CHECK_NOTNULL(node_map.GetNode("Conv2D")); 109851889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower 109951889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower const NodeDef* weights_node_1 = 110098ac3f5b7b3942eb0ede7cae1b1afab717b3090aA. Unique TensorFlower CHECK_NOTNULL(node_map.GetNode(OptimizedName("weights_scaled_Conv2D_1"))); 110151889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower const NodeDef* conv_node_1 = CHECK_NOTNULL(node_map.GetNode("Conv2D_1")); 110251889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower EXPECT_EQ(conv_node->input(1), weights_node->name()); 110351889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower EXPECT_EQ(conv_node_1->input(1), weights_node_1->name()); 110451889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower} 110551889acee1a266478b578afad3fbe7b3a90fc17aA. Unique TensorFlower 11068ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue WuTEST_F(ArithmeticOptimizerTest, CombineBitcasts) { 11078ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 11088ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu Output inputs = 11098ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu ops::Placeholder(s, DT_UINT8, ops::Placeholder::Shape({2, 3})); 11108ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu Output bc1 = ops::Bitcast(s, inputs, DT_QINT8); 11118ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu Output bc2 = ops::Bitcast(s, bc1, DT_INT8); 11128ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu Output outputs = ops::Identity(s.WithOpName("outputs"), bc2); 11138ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu 11148ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu GrapplerItem item; 11158ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu item.fetch = {"outputs"}; 11168ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu TF_CHECK_OK(s.ToGraphDef(&item.graph)); 11178ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu 11188ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu GraphDef output; 11198ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); 1120f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 11218ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); 11228ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu 11238ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu EXPECT_EQ(1, std::count_if( 11248ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu output.node().begin(), output.node().end(), 11258ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu [](const NodeDef& node) { return node.op() == "Bitcast"; })); 11268ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu} 11278ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu 11288ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue WuTEST_F(ArithmeticOptimizerTest, CombineAndRemoveBitcasts) { 11298ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 11308ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu Output inputs = ops::Placeholder(s, DT_INT8, ops::Placeholder::Shape({2, 3})); 11318ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu Output bc1 = ops::Bitcast(s, inputs, DT_QINT8); 11328ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu Output bc2 = ops::Bitcast(s, bc1, DT_INT8); 11338ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu Output outputs = ops::Identity(s.WithOpName("outputs"), bc2); 11348ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu 11358ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu GrapplerItem item; 11368ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu item.fetch = {"outputs"}; 11378ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu TF_CHECK_OK(s.ToGraphDef(&item.graph)); 11388ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu GraphDef output; 11398ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); 1140f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 11418ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); 11428ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu 11438ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu EXPECT_EQ(0, std::count_if( 11448ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu output.node().begin(), output.node().end(), 11458ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu [](const NodeDef& node) { return node.op() == "Bitcast"; })); 11468ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu} 11478ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu 11488ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue WuTEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) { 11498ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu tensorflow::Scope s = tensorflow::Scope::NewRootScope(); 11508ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu Output inputs = ops::Placeholder(s, DT_INT8, ops::Placeholder::Shape({2, 3})); 11518ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu Output cast = ops::Cast(s, inputs, DT_INT8); 11528ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu Output outputs = ops::Identity(s.WithOpName("outputs"), cast); 11538ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu 11548ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu GrapplerItem item; 11558ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu item.fetch = {"outputs"}; 11568ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu TF_CHECK_OK(s.ToGraphDef(&item.graph)); 11578ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu GraphDef output; 11588ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu TF_EXPECT_OK(ArithmeticOptimizer().Optimize(nullptr, item, &output)); 1159f448189df9b62b6dd141ce14224dbfc0d8f0d11bA. Unique TensorFlower item.graph.Swap(&output); 11608ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu TF_EXPECT_OK(ModelPruner().Optimize(nullptr, item, &output)); 11618ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu 11628ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu EXPECT_EQ(0, std::count_if( 11638ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu output.node().begin(), output.node().end(), 11648ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu [](const NodeDef& node) { return node.op() == "Cast"; })); 11658ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu} 11668ff5070392bd0066930d11e3e39d21d3fa84bb2eJingyue Wu 1167c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner} // namespace 1168c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner} // namespace grappler 1169c247826219dd2541c6aba4578a03a171375d9290Benoit Steiner} // namespace tensorflow 1170