remote_fused_graph_rewriter_transform_test.cc revision d5f4d9bbac520ad9eae6614fe678e9d1568435a4
1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/cc/ops/const_op.h"
17#include "tensorflow/cc/ops/image_ops.h"
18#include "tensorflow/cc/ops/nn_ops.h"
19#include "tensorflow/cc/ops/standard_ops.h"
20#include "tensorflow/core/common_runtime/function.h"
21#include "tensorflow/core/framework/tensor_shape.h"
22#include "tensorflow/core/framework/tensor_testutil.h"
23#include "tensorflow/core/graph/default_device.h"
24#include "tensorflow/core/graph/node_builder.h"
25#include "tensorflow/core/graph/testlib.h"
26#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
27#include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
28#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
29#include "tensorflow/core/lib/core/status_test_util.h"
30#include "tensorflow/core/platform/test.h"
31#include "tensorflow/core/public/session.h"
32#include "tensorflow/tools/graph_transforms/transform_utils.h"
33
34namespace tensorflow {
35namespace graph_transforms {
36
37// Declared here so we don't have to put it in a public header.
38Status FuseRemoteGraph(const GraphDef& input_graph_def,
39                       const TransformFuncContext& context,
40                       GraphDef* output_graph_def);
41
42Status PlaceRemoteGraphArguments(const GraphDef& input_graph_def,
43                                 const TransformFuncContext& context,
44                                 GraphDef* output_graph_def);
45
46namespace {
// Name under which the fixture registers its default executor build function.
constexpr char REMOTE_FUSED_GRAPH_EXECUTOR_NAME[] =
    "remote_fused_graph_executor_name";
// Passed to the transforms as the remote fused graph node name; the fixture's
// graph check counts output nodes whose names start with this string.
constexpr char REMOTE_FUSED_GRAPH_NODE_NAME[] = "remote_fused_graph_node_name";
// Names of the two test executors built by BuildRemoteFusedGraphExecutor0/1.
constexpr char REMOTE_FUSED_EXECUTOR_NAME0[] =
    "fuse_test_remote_fused_graph_executor0";
constexpr char REMOTE_FUSED_EXECUTOR_NAME1[] =
    "fuse_test_remote_fused_graph_executor1";
55
56Status BuildRemoteFusedGraphExecutor0(
57    std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
58  executor->reset(
59      new TestRemoteFusedGraphExecutor({"Mul"}, REMOTE_FUSED_EXECUTOR_NAME0));
60  return Status::OK();
61}
62
63Status BuildRemoteFusedGraphExecutor1(
64    std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
65  executor->reset(new TestRemoteFusedGraphExecutor(
66      {"Const", "Mul"}, REMOTE_FUSED_EXECUTOR_NAME1));
67  return Status::OK();
68}
69
70class FuseRemoteGraphMultipleAddOpsRewriterTest : public ::testing::Test {
71 protected:
72  void SetUp() final {
73    TF_ASSERT_OK(RemoteFusedGraphExecuteOpTestUtils::BuildMultipleAddGraph(
74        &input_graph_def_));
75    RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
76        hexagon_remote_fused_graph_executor_build(
77            REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
78            [](std::unique_ptr<IRemoteFusedGraphExecutor>* executor) -> Status {
79              return Status::OK();
80            });
81    RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
82        test_remote_fused_graph_executor_build0(REMOTE_FUSED_EXECUTOR_NAME0,
83                                                BuildRemoteFusedGraphExecutor0);
84
85    RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
86        test_remote_fused_graph_executor_build1(REMOTE_FUSED_EXECUTOR_NAME1,
87                                                BuildRemoteFusedGraphExecutor1);
88  }
89
90  void TearDown() final {}
91
92  Status Fuse() { return FuseInternal(/*only_place_args=*/false); }
93
94  Status PlaceFuseArgs() { return FuseInternal(/*only_place_args*/ true); }
95
96  Status FuseWithPlacedArgs() {
97    const std::vector<std::pair<string, Tensor>> input_tensors{
98        {"A", {DT_FLOAT, {1, 1, 1, 1}}}};
99    return RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments(
100        input_graph_def_with_fuse_args_, input_tensors, &output_graph_def_);
101  }
102
103  Status FuseInternal(bool only_place_args) {
104    TransformFuncContext context;
105    context.input_names = inputs_;
106    context.output_names = outputs_;
107
108    if (!input_types_.empty()) {
109      context.params.insert(std::pair<string, std::vector<string>>(
110          {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES,
111           {input_types_}}));
112    }
113    if (!input_shapes_.empty()) {
114      context.params.insert(std::pair<string, std::vector<string>>(
115          {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES,
116           {input_shapes_}}));
117    }
118    if (!fused_node_names_str_.empty()) {
119      context.params.insert(std::pair<string, std::vector<string>>(
120          {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES,
121           {fused_node_names_str_}}));
122    }
123
124    if (!border_inputs_str_.empty()) {
125      context.params.insert(std::pair<string, std::vector<string>>(
126          {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS,
127           {border_inputs_str_}}));
128    }
129    if (!border_outputs_str_.empty()) {
130      context.params.insert(std::pair<string, std::vector<string>>(
131          {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS,
132           {border_outputs_str_}}));
133    }
134
135    if (!fused_op_types_str_.empty()) {
136      context.params.insert(std::pair<string, std::vector<string>>(
137          {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES,
138           {fused_op_types_str_}}));
139    }
140
141    if (fuse_by_executor_) {
142      context.params.insert(std::pair<string, std::vector<string>>(
143          {RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR,
144           {"true"}}));
145    }
146
147    context.params.insert(std::pair<string, std::vector<string>>(
148        {RemoteFusedGraphExecuteUtils::
149             TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
150         {remote_fused_graph_executor_name_}}));
151    context.params.insert(std::pair<string, std::vector<string>>(
152        {RemoteFusedGraphExecuteUtils::
153             TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME,
154         {REMOTE_FUSED_GRAPH_NODE_NAME}}));
155
156    if (only_place_args) {
157      return PlaceRemoteGraphArguments(input_graph_def_, context,
158                                       &input_graph_def_with_fuse_args_);
159    } else {
160      return FuseRemoteGraph(input_graph_def_, context, &output_graph_def_);
161    }
162  }
163
164  void SetInputShapeType() {
165    input_types_ = "float";
166    input_shapes_ = "1,1,1,1";
167  }
168
169  void ReplaceOpType(const std::unordered_set<string>& op_name,
170                     const string& new_op_type) {
171    for (NodeDef& node_def : *input_graph_def_.mutable_node()) {
172      if (op_name.count(node_def.name()) > 0) {
173        node_def.set_op(new_op_type);
174      }
175    }
176  }
177
178  void CheckGraph(int expected_node_count, int expected_cluster_count) {
179    EXPECT_EQ(expected_node_count, output_graph_def_.node_size());
180
181    int cluster_count = 0;
182    for (const NodeDef& node_def : output_graph_def_.node()) {
183      const string& name = node_def.name();
184      if (StringPiece(name).starts_with(REMOTE_FUSED_GRAPH_NODE_NAME)) {
185        ++cluster_count;
186        RemoteFusedGraphExecuteInfo info;
187        string serialized_proto;
188        TF_ASSERT_OK(
189            GetNodeAttr(node_def,
190                        RemoteFusedGraphExecuteUtils::
191                            ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO,
192                        &serialized_proto));
193        info.ParseFromString(serialized_proto);
194        CHECK_EQ(remote_fused_graph_executor_name_, info.executor_name());
195      }
196    }
197    EXPECT_EQ(expected_cluster_count, cluster_count);
198  }
199
200 public:
201  const std::vector<string> inputs_{"A"};
202  const std::vector<string> outputs_{"K"};
203  GraphDef input_graph_def_;
204  string input_types_;
205  string input_shapes_;
206  GraphDef input_graph_def_with_fuse_args_;
207  GraphDef output_graph_def_;
208  string fused_node_names_str_;
209  string border_inputs_str_;
210  string border_outputs_str_;
211  string fused_op_types_str_;
212  string remote_fused_graph_executor_name_{REMOTE_FUSED_GRAPH_EXECUTOR_NAME};
213  bool fuse_by_executor_{false};
214};
215
216TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
217       FuseRemoteGraphByNodesWithShapeType_HIJ) {
218  SetInputShapeType();
219  fused_node_names_str_ = "H,I,J";
220  TF_ASSERT_OK(Fuse());
221  CheckGraph(9, 1);
222}
223
224TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
225       FuseRemoteGraphByNodesWithoutShapeType_HIJ) {
226  fused_node_names_str_ = "H,I,J";
227  TF_ASSERT_OK(Fuse());
228  CheckGraph(9, 1);
229}
230
231TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
232       FuseRemoteGraphByNodesWithShapeType_ABCDEFGHIJK) {
233  SetInputShapeType();
234  fused_node_names_str_ = "A,B,C,D,E,F,G,H,I,J,K";
235  TF_ASSERT_OK(Fuse());
236  CheckGraph(3, 1);
237}
238
239TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
240       FuseRemoteGraphByNodesWithoutShapeType_ABCDEFGHIJK) {
241  fused_node_names_str_ = "A,B,C,D,E,F,G,H,I,J,K";
242  TF_ASSERT_OK(Fuse());
243  CheckGraph(3, 1);
244}
245
246TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
247       FuseRemoteGraphByBorderWithShapeType_FCG_J) {
248  SetInputShapeType();
249  border_inputs_str_ = "F:0,C:0,G";
250  border_outputs_str_ = "J:0";
251  TF_ASSERT_OK(Fuse());
252  CheckGraph(9, 1);
253}
254
255TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
256       FuseRemoteGraphByBorderWithoutShapeType_FCG_J) {
257  border_inputs_str_ = "F:0,C:0,G";
258  border_outputs_str_ = "J:0";
259  TF_ASSERT_OK(Fuse());
260  CheckGraph(9, 1);
261}
262
263TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
264       FuseRemoteGraphByBorderWithShapeType_ABCDE_K) {
265  SetInputShapeType();
266  border_inputs_str_ = "A,B,C,D,E";
267  border_outputs_str_ = "K";
268  TF_ASSERT_OK(Fuse());
269  CheckGraph(7, 1);
270}
271
272TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
273       FuseRemoteGraphByBorderWithoutShapeType_ABCDE_K) {
274  border_inputs_str_ = "A,B,C,D,E";
275  border_outputs_str_ = "K";
276  TF_ASSERT_OK(Fuse());
277  CheckGraph(7, 1);
278}
279
280TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
281       FuseRemoteGraphByOpTypes_HIJ) {
282  ReplaceOpType({"H", "I", "J"}, "Mul");
283  fused_op_types_str_ = "Mul";
284  TF_ASSERT_OK(Fuse());
285  CheckGraph(9, 1);
286}
287
288TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
289       FuseRemoteGraphByOpTypes_FGHIJ) {
290  ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
291  fused_op_types_str_ = "Const,Mul";
292  TF_ASSERT_OK(Fuse());
293  CheckGraph(3, 1);
294}
295
296TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
297       FuseRemoteGraphByExecutor_HIJ) {
298  ReplaceOpType({"H", "I", "J"}, "Mul");
299  remote_fused_graph_executor_name_ = REMOTE_FUSED_EXECUTOR_NAME0;
300  fuse_by_executor_ = true;
301  TF_ASSERT_OK(Fuse());
302  CheckGraph(9, 1);
303}
304
305TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
306       FuseRemoteGraphByExecutor_FGHIJ) {
307  ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
308  remote_fused_graph_executor_name_ = REMOTE_FUSED_EXECUTOR_NAME1;
309  fuse_by_executor_ = true;
310  TF_ASSERT_OK(Fuse());
311  CheckGraph(3, 1);
312}
313
314TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_HIJ) {
315  fused_node_names_str_ = "H,I,J";
316  TF_ASSERT_OK(PlaceFuseArgs());
317  TF_ASSERT_OK(FuseWithPlacedArgs());
318  CheckGraph(9, 1);
319}
320
321TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_ABCDEFGHIJK) {
322  fused_node_names_str_ = "A,B,C,D,E,F,G,H,I,J,K";
323  TF_ASSERT_OK(PlaceFuseArgs());
324  TF_ASSERT_OK(FuseWithPlacedArgs());
325  CheckGraph(3, 1);
326}
327
328TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_FCG_J) {
329  border_inputs_str_ = "F:0,C:0,G";
330  border_outputs_str_ = "J:0";
331  TF_ASSERT_OK(PlaceFuseArgs());
332  TF_ASSERT_OK(FuseWithPlacedArgs());
333  CheckGraph(9, 1);
334}
335
336TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_ABCDE_K) {
337  SetInputShapeType();
338  border_inputs_str_ = "A,B,C,D,E";
339  border_outputs_str_ = "K";
340  TF_ASSERT_OK(PlaceFuseArgs());
341  TF_ASSERT_OK(FuseWithPlacedArgs());
342  CheckGraph(7, 1);
343}
344
345TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest, PlaceAndFuse_MUL_HIJ) {
346  SetInputShapeType();
347  ReplaceOpType({"H", "I", "J"}, "Mul");
348  fused_op_types_str_ = "Mul";
349
350  TF_ASSERT_OK(PlaceFuseArgs());
351  TF_ASSERT_OK(FuseWithPlacedArgs());
352  CheckGraph(9, 1);
353}
354
355TEST_F(FuseRemoteGraphMultipleAddOpsRewriterTest,
356       PlaceAndFuse_CONST_MUL_FGHIJ) {
357  SetInputShapeType();
358  ReplaceOpType({"F", "G", "H", "I", "J"}, "Mul");
359  fused_op_types_str_ = "Const,Mul";
360
361  TF_ASSERT_OK(PlaceFuseArgs());
362  TF_ASSERT_OK(FuseWithPlacedArgs());
363  CheckGraph(3, 1);
364}
365
366}  // namespace
367}  // namespace graph_transforms
368}  // namespace tensorflow
369