1227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower
3227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License");
4227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFloweryou may not use this file except in compliance with the License.
5227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerYou may obtain a copy of the License at
6227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower
7227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower    http://www.apache.org/licenses/LICENSE-2.0
8227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower
9227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software
10227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS,
11227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerSee the License for the specific language governing permissions and
13227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerlimitations under the License.
14227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower==============================================================================*/
15227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower
16227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
17227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower
187b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower#include <algorithm>
19f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower#include <queue>
20227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower#include <utility>
21227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower
227b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower#include "tensorflow/core/common_runtime/shape_refiner.h"
23f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower#include "tensorflow/core/framework/node_def_util.h"
24e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving#include "tensorflow/core/framework/tensor.pb.h"
25e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving#include "tensorflow/core/framework/tensor_shape.pb.h"
267b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower#include "tensorflow/core/graph/algorithm.h"
27bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower#include "tensorflow/core/graph/node_builder.h"
283b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower#include "tensorflow/core/public/session.h"
293b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower#include "tensorflow/core/public/session_options.h"
303b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower
31227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowernamespace tensorflow {
32bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlowernamespace {
33bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlowerconst Node* FindNodeByName(const string& name, const Graph& graph) {
34bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  for (const Node* node : graph.nodes()) {
35bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    CHECK_NOTNULL(node);
36bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    if (node->name() == name) {
37bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      return node;
38bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    }
39bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
40bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  return nullptr;
41bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower}
42bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
43bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlowerstd::unordered_set<string> BuildNodeSetFromNodeNamesAndPorts(
44bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const std::vector<string>& node_names_and_ports) {
45bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  std::unordered_set<string> retval;
46bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  for (const string& node_name_and_port : node_names_and_ports) {
47bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const TensorId tid = ParseTensorName(node_name_and_port);
48bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    retval.emplace(tid.first.ToString());
49bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
50bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  return retval;
51bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower}
52bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
53bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlowerNode* FindMutableNodeByName(const string& name, Graph* graph) {
54bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  for (Node* node : graph->nodes()) {
55bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    if (node != nullptr && node->name() == name) {
56bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      return node;
57bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    }
58bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
59bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  return nullptr;
60bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower}
61bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
62bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlowerconst NodeDef* FindNodeDefByName(const string& input,
63bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower                                 const GraphDef& graph_def) {
64bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  const TensorId tid = ParseTensorName(input);
65bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  const string name = tid.first.ToString();
66bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  for (const NodeDef& node_def : graph_def.node()) {
67bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    if (node_def.name() == name) {
68bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      return &node_def;
69bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    }
70bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
71bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  return nullptr;
72bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower}
73bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
741c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlowerbool IsSameNodeName(const NodeDef& node_def, const string& node_name_and_port,
751c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower                    TensorId* tid) {
761c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  CHECK_NOTNULL(tid);
771c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  *tid = ParseTensorName(node_name_and_port);
781c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  if (node_def.name() == tid->first.ToString()) {
791c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    return true;
801c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  }
811c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  return false;
821c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower}
831c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower
841c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlowerbool ContainsSameTensorId(const string& tensor_name,
851c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower                          const std::vector<string>& tensor_names) {
861c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  const TensorId tid0 = ParseTensorName(tensor_name);
871c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  for (const string& name : tensor_names) {
881c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    const TensorId tid1 = ParseTensorName(name);
891c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    if (tid0.first == tid1.first && tid0.second == tid1.second) {
901c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      return true;
911c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    }
921c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  }
931c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  return false;
941c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower}
951c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower
961c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlowervoid AppendDeliminator(string* str) {
971c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  CHECK_NOTNULL(str);
981c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  if (!str->empty()) {
991c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    *str += ":";
1001c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  }
1011c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower}
1021c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower
1031c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlowervoid ConvertMapToVector(const std::unordered_map<int, string>& in,
1041c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower                        std::vector<string>* out) {
1051c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  CHECK_NOTNULL(out);
1061c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  out->resize(in.size());
10790d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai  for (size_t i = 0; i < in.size(); ++i) {
1081c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    CHECK(in.count(i) > 0);
1091c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    out->at(i) = in.at(i);
1101c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  }
1111c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower}
1121c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower
113bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlowerstring DumpGraphDef(const GraphDef& graph_def) {
114bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  string out;
115bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  for (const NodeDef& node : graph_def.node()) {
116bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    out += strings::StrCat("node: ", node.name(), "\n    input: ");
117bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    for (const string& input : node.input()) {
118bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      out += strings::StrCat(input, ", ");
119bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    }
120bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    out += "\n";
121bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
122bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  return out;
123bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower}
124bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
125bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlowerstring DumpCluster(const RemoteFusedGraphExecuteUtils::ClusterInfo& cluster) {
126bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  string out;
127bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  out += "Nodes:\n";
128bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  for (const string& str : std::get<0>(cluster)) {
129bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    out += str + ", ";
130bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
131bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  out += "\nInput border:\n";
132bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  for (const string& str : std::get<1>(cluster)) {
133bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    out += str + ", ";
134bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
135bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  out += "\nOutput border:\n";
136bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  for (const string& str : std::get<2>(cluster)) {
137bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    out += str + ", ";
138bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
139bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  return out;
140bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower}
141bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
142bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower}  // namespace
143227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower
// Namespace-scope definitions for the static constexpr `const char*` members
// of RemoteFusedGraphExecuteUtils. None of these lines carries an initializer,
// so the values come from the in-class declarations in the header. Before
// C++17, a static constexpr data member that is ODR-used (e.g. bound to a
// reference) needs exactly one such out-of-class definition.
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES;
/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::
    ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::ATTR_NODE_TYPE;
// Argument names (TRANSFORM_ARG_*), presumably consumed by a graph-transform
// tool; confirm against the header.
/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils::
    TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES;
/* static */ constexpr const char* const
    RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES;
170f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower
171227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar::ExecutorBuildRegistrar(
172227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower    const string& name, ExecutorBuildFunc executor_build_func) {
173227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower  ExecutorBuildRegistry& executor_build_registry = *GetExecutorBuildRegistry();
174227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower  executor_build_registry[name] = std::move(executor_build_func);
175227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower}
176227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower
177227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower/* static */ const RemoteFusedGraphExecuteUtils::ExecutorBuildFunc*
178227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::GetExecutorBuildFunc(const string& name) {
179227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower  ExecutorBuildRegistry& executor_build_registry = *GetExecutorBuildRegistry();
180227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower  if (executor_build_registry.count(name) <= 0) {
181227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower    return nullptr;
182227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower  }
183227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower  return &executor_build_registry.at(name);
184227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower}
185227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower
186227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower/* static */ RemoteFusedGraphExecuteUtils::ExecutorBuildRegistry*
187227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() {
188227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower  static ExecutorBuildRegistry executor_builder_registry;
189227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower  return &executor_builder_registry;
190227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower}
191227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower
1923b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower/**
1933b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower * - DryRunInference
1943b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower * To determine shapes of output tensors of all nodes, dryrun the graph.
1953b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower * This function supplies memory allocation information when loading
1963b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower * the graph. This function is used to verify shape inference and actual
1973b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower * output shape.
1983b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower */
1993b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::DryRunInference(
2003b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    const GraphDef& graph_def,
2013b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
2023b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    const std::vector<string>& output_node_names, const bool initialize_by_zero,
2033b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    std::vector<tensorflow::Tensor>* output_tensors) {
2043b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  // Create input tensor vector.  If "initialize_by_zero" is true,
2053b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  // input tensor fields are initialized by 0.
2063b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  std::vector<std::pair<string, tensorflow::Tensor>> input_tensors;
2073b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  for (const std::pair<string, Tensor>& input : input_node_info_list) {
2083b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    CHECK(input.second.IsInitialized());
2093b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    if (!initialize_by_zero) {
2103b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower      input_tensors.push_back({input.first, input.second});
2113b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower      continue;
2123b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    }
2133b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    // If input tensor is not initialized, initialize by 0-filling
2143b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    const DataType data_type = input.second.dtype();
2153b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    const TensorShape& shape = input.second.shape();
2163b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    Tensor input_tensor(data_type, shape);
2173b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    switch (data_type) {
2183b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower      case DT_INT32: {
2193b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower        auto int_tensor = input_tensor.flat<int32>();
2203b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower        int_tensor = int_tensor.constant(0);
2213b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower        break;
2223b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower      }
2233b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower      case DT_FLOAT: {
2243b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower        auto float_tensor = input_tensor.flat<float>();
2253b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower        float_tensor = float_tensor.constant(0.0f);
2263b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower        break;
2273b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower      }
2283b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower      case DT_QUINT8: {
2293b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower        auto int_tensor = input_tensor.flat<quint8>();
2303b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower        int_tensor = int_tensor.constant(0);
2313b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower        break;
2323b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower      }
2333b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower      default:
2343b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower        LOG(FATAL) << "Unsupported input type: " << data_type;
2353b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    }
2363b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    input_tensors.push_back({input.first, input_tensor});
2373b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  }
2383b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower
2393b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  // Setup session
2403b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  CHECK(output_tensors != nullptr);
2413b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  SessionOptions session_options;
2423b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  session_options.env = Env::Default();
2433b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  std::unique_ptr<Session> session =
2443b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower      std::unique_ptr<Session>(NewSession(session_options));
2453b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  Status status = session->Create(graph_def);
2463b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  if (!status.ok()) {
2473b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    return status;
2483b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  }
2493b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower
2503b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  // Setup session arguments
2513b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  RunOptions run_options;
2523b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  run_options.set_trace_level(RunOptions::FULL_TRACE);
2533b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  RunMetadata run_metadata;
2543b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower
2553b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  // Run inference with all node as output
2563b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  status = session->Run(run_options, input_tensors, output_node_names, {},
2573b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower                        output_tensors, &run_metadata);
2583b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  if (!status.ok()) {
2593b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    LOG(ERROR) << "Error during inference: " << status;
2603b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    return status;
2613b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  }
2623b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  return Status();
2633b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower}
2643b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower
// Dry-runs `graph_def` and fetches every output port of every node that is
// not fed by `input_node_info_list`. Each fetched tensor, and then each fed
// input tensor, is passed to EmplaceTensorShapeType together with its name so
// that `*tensor_shape_map` is populated.
// Returns the error if the graph cannot be imported or the dry run fails.
// `tensor_shape_map` must be non-null (CHECKed).
/* static */ Status RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
    const bool initialize_by_zero,
    RemoteFusedGraphExecuteUtils::TensorShapeMap* tensor_shape_map) {
  CHECK(tensor_shape_map != nullptr);
  std::vector<Tensor> output_tensors;
  output_tensors.reserve(graph_def.node_size());
  std::vector<string> output_node_names;

  // Within this function, the imported Graph is used only to enumerate nodes
  // and their output counts.
  Graph graph(OpRegistry::Global());
  Status status = ImportGraphDef({}, graph_def, &graph, nullptr);
  if (!status.ok()) {
    return status;
  }

  // Request "<node>:<port>" for every output port of every non-input node.
  for (const Node* node : graph.nodes()) {
    if (IsInputNode(input_node_info_list, node->name())) {
      continue;
    }
    for (int i = 0; i < node->num_outputs(); ++i) {
      output_node_names.emplace_back(strings::StrCat(node->name(), ":", i));
    }
  }

  status = DryRunInference(graph_def, input_node_info_list, output_node_names,
                           initialize_by_zero, &output_tensors);
  if (!status.ok()) {
    VLOG(1) << "Failed to dryrun " << status;
    return status;
  }

  // The session must return exactly one tensor per requested name.
  CHECK_EQ(output_node_names.size(), output_tensors.size())
      << output_node_names.size() << ", " << output_tensors.size();

  // Append the fed input tensors after the fetched outputs, so that input i
  // sits at index output_node_names.size() + i. Every append finishes before
  // the first EmplaceTensorShapeType call, so output_tensors never reallocates
  // while the map is being filled. NOTE(review): that ordering only matters if
  // EmplaceTensorShapeType keeps references into output_tensors; confirm.
  for (const std::pair<string, Tensor>& input_node_info :
       input_node_info_list) {
    output_tensors.push_back(input_node_info.second);
  }

  for (int i = 0; static_cast<size_t>(i) < output_node_names.size(); ++i) {
    const string& name = output_node_names.at(i);
    const Tensor& tensor = output_tensors.at(i);
    EmplaceTensorShapeType(name, tensor, tensor_shape_map);
  }
  for (int i = 0; static_cast<size_t>(i) < input_node_info_list.size(); ++i) {
    const string& name = input_node_info_list.at(i).first;
    const Tensor& tensor = output_tensors.at(output_node_names.size() + i);
    EmplaceTensorShapeType(name, tensor, tensor_shape_map);
  }
  CHECK_EQ(output_node_names.size() + input_node_info_list.size(),
           output_tensors.size());
  return status;
}
3213b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower
3223b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower/* static */ bool RemoteFusedGraphExecuteUtils::IsInputNode(
3233b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    const std::vector<std::pair<string, Tensor>>& input_tensor_vector,
3243b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    const string& node_name) {
3253b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  for (const std::pair<string, Tensor>& pair : input_tensor_vector) {
326a11e669c5a4c855b0b507da97378bc7e03a08f86A. Unique TensorFlower    const TensorId tid = ParseTensorName(pair.first);
327a11e669c5a4c855b0b507da97378bc7e03a08f86A. Unique TensorFlower    if (node_name == tid.first.ToString()) {
3283b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower      return true;
3293b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    }
3303b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  }
3313b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  return false;
3323b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower}
3333b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower
3343b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower/* static */ void RemoteFusedGraphExecuteUtils::ConvertToTensorShapeMap(
3353b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
3363b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    const std::vector<string>& output_node_names,
3373b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    const std::vector<tensorflow::Tensor>& output_tensors,
3383b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    TensorShapeMap* tensor_shape_map) {
3393b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  CHECK_NE(tensor_shape_map, nullptr);
3403b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  tensor_shape_map->clear();
3413b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  tensor_shape_map->reserve(input_node_info_list.size() +
3423b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower                            output_node_names.size());
3433b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  const int output_node_count = output_node_names.size();
3443b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  CHECK_EQ(output_node_count, output_tensors.size());
3453b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  for (int i = 0; i < output_node_count; ++i) {
3463b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    const string& node_name = output_node_names.at(i);
3473b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower    const Tensor& tensor = output_tensors.at(i);
348f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    EmplaceTensorShapeType(node_name, tensor, tensor_shape_map);
3493b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower  }
3503b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower}
3513b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower
352f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::MakeTensorFromProto(
353f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower    const TensorProto& tensor_proto, Tensor* tensor) {
354f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower  if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) {
355f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower    Tensor parsed(tensor_proto.dtype());
356f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower    if (parsed.FromProto(cpu_allocator(), tensor_proto)) {
357f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower      *tensor = parsed;
358f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower      return Status::OK();
359f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower    }
360f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower  }
361f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower  return errors::InvalidArgument("Cannot parse tensor from proto");
362f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower}
363f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower
364f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower/* static */ bool RemoteFusedGraphExecuteUtils::AddOutputTensorShapeType(
365f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower    const std::vector<DataType>& data_types,
366f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower    const std::vector<TensorShape>& shapes, NodeDef* node_def) {
367f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower  AddNodeAttr(ATTR_OUTPUT_DATA_TYPES, data_types, node_def);
368f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower  AddNodeAttr(ATTR_OUTPUT_SHAPES, shapes, node_def);
369f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower  return true;
370f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower}
371f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower
372f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower/* static */ Status
373f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap(
374f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    const TensorShapeMap& tensor_shape_map, NodeDef* node_def) {
375f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  CHECK_NE(node_def, nullptr);
376f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  std::priority_queue<std::tuple<int, const TensorShapeType*>> queue;
377f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  auto its = tensor_shape_map.equal_range(node_def->name());
378f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  for (auto it = its.first; it != its.second; ++it) {
379f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    queue.emplace(std::make_tuple(it->second.first, &it->second.second));
380f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  }
381f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  int last_port = queue.size();
382f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  std::vector<DataType> data_types;
383f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  std::vector<TensorShape> shapes;
384f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  while (!queue.empty()) {
385f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    const int port = std::get<0>(queue.top());
386f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    const TensorShapeType* tst = std::get<1>(queue.top());
387f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    CHECK_NE(tst, nullptr);
388f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    data_types.emplace(data_types.begin(), tst->first);
389f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    shapes.emplace(shapes.begin(), tst->second);
390f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    CHECK_EQ(last_port - 1, port);
391f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    last_port = port;
392f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    queue.pop();
393f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  }
394f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  AddOutputTensorShapeType(data_types, shapes, node_def);
395f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  return Status::OK();
396f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower}
397f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower
3980a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
39973882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving    AttrSlice attrs, std::vector<DataType>* data_types,
4000a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower    std::vector<TensorShape>* shapes) {
4010a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower  Status status;
4020a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower  if (data_types != nullptr) {
40373882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving    status = GetNodeAttr(attrs, ATTR_OUTPUT_DATA_TYPES, data_types);
4040a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower  }
4050a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower  if (!status.ok()) {
4060a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower    return status;
4070a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower  }
4080a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower  if (shapes != nullptr) {
40973882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving    status = GetNodeAttr(attrs, ATTR_OUTPUT_SHAPES, shapes);
4100a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower    if (status.ok() && data_types != nullptr) {
4110a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower      CHECK_EQ(data_types->size(), shapes->size());
4120a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower    }
4130a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower  }
4140a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower
4150a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower  return status;
4160a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower}
4170a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower
418bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ bool RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType(
419bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const GraphDef& graph_def, const string& name_and_port, DataType* data_type,
420bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    TensorShape* shape) {
421bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  std::vector<DataType> data_types;
422bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  std::vector<TensorShape> shapes;
423bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  const TensorId tid = ParseTensorName(name_and_port);
424bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  const string node_name = tid.first.ToString();
425bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  const int port = tid.second;
426bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  const NodeDef* node_def = FindNodeDefByName(node_name, graph_def);
427bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  CHECK_NOTNULL(node_def);
428bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  GetOutputTensorShapeType(*node_def, &data_types, &shapes).IgnoreError();
429bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  if (data_types.empty()) {
430bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    return false;
431bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
432bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  CHECK(data_types.size() > port);
433bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  *data_type = data_types.at(port);
434bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  *shape = shapes.at(port);
435bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  return true;
436bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower}
437bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
4387b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::PropagateShapeInference(
4397b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    const GraphDef& graph_def,
4407b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    const std::vector<std::pair<string, Tensor>>& input_node_info_list,
4417b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    Graph* graph, ShapeRefiner* shape_refiner) {
4427b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower  Status status;
4437b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower  auto visit = [&shape_refiner, &input_node_info_list, &status](Node* node) {
4447b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    if (!status.ok()) {
4457b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower      return;
4467b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    }
4477b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    CHECK_NE(node, nullptr);
4487b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    // If we visit an input node, we use the shape provided and set the
4497b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    // shape accordingly.
4507b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    bool is_input_node = false;
4517b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    for (const std::pair<string, Tensor>& input_node_info :
4527b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower         input_node_info_list) {
4537b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower      if (node->name() == input_node_info.first) {
4547b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower        shape_inference::InferenceContext* context =
4557b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower            shape_refiner->GetContext(node);
4567b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower        shape_inference::ShapeHandle handle;
4577b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower        status = context->MakeShapeFromTensorShape(
4587b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower            input_node_info.second.shape(), &handle);
4590a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower        if (!status.ok()) {
4600a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower          break;
4610a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower        }
4620a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower        status = shape_refiner->SetShape(node, 0, handle);
4630a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower        if (!status.ok()) {
4640a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower          break;
4650a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower        }
4667b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower        is_input_node = true;
4677b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower      }
4687b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower      if (!status.ok()) {
4697b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower        break;
4707b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower      }
4717b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    }
4727b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    // If not an input node call AddNode() that recomputes the shape.
4737b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    if (!is_input_node && status.ok()) {
4747b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower      status = shape_refiner->AddNode(node);
4750a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower    }
4760a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower    if (!status.ok()) {
4770a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower      VLOG(1) << "Shape inference failed for node: " << node->name();
4787b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    }
4797b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower  };
4807b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower
4817b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower  ReverseDFS(*graph, {}, visit);
4827b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower
4837b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower  return status;
4847b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower}
4857b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower
4867b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::BuildTensorShapeMapFromGraph(
4877b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    const Graph& graph, const ShapeRefiner& shape_refiner,
4887b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    TensorShapeMap* tensor_shape_map) {
4897b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower  for (int i = 0; i < graph.num_node_ids(); ++i) {
4907b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    const Node* node = graph.FindNodeId(i);
4917b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    CHECK_NE(node, nullptr);
4927b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    for (int j = 0; j < node->num_outputs(); ++j) {
4937b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower      const int output_index = j;
4947b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower      const DataType dt = node->output_type(output_index);
4957b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower      shape_inference::InferenceContext* context =
4967b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower          shape_refiner.GetContext(node);
4977b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower      CHECK_NE(context, nullptr);
4987b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower      shape_inference::ShapeHandle shape_handle = context->output(output_index);
4997b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower      if (context->RankKnown(shape_handle)) {
5007b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower        TensorShape ts;
5017b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower        for (int k = 0; k < context->Rank(shape_handle); ++k) {
5027b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower          shape_inference::DimensionHandle dh = context->Dim(shape_handle, k);
5037b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower          CHECK(context->ValueKnown(dh));
5047b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower          ts.AddDim(context->Value(dh));
5057b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower        }
5067b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower        const string& node_name = node->name();
5077b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower        CHECK(tensor_shape_map->count(node_name) == 0);
508f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower        tensor_shape_map->emplace(node_name,
509f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower                                  std::make_pair(j, std::make_pair(dt, ts)));
5107b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower      } else {
5117b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower        return errors::InvalidArgument("Graph contains unknow shapes");
5127b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower      }
5137b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower    }
5147b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower  }
5157b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower  return Status::OK();
5167b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower}
5177b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower
518f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower/* static */ const RemoteFusedGraphExecuteUtils::TensorShapeType*
519f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::GetTensorShapeType(
520f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    const TensorShapeMap& tensor_shape_map, const string& node_name) {
521f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  if (node_name.find(':') != string::npos) {
522f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    const TensorId tid = ParseTensorName(node_name);
523f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    return GetTensorShapeType(tensor_shape_map, tid.first.ToString(),
524f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower                              tid.second);
525f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  } else {
526f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    return GetTensorShapeType(tensor_shape_map, node_name, 0);
527f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  }
528f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower}
529f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower
530f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower/* static */ const RemoteFusedGraphExecuteUtils::TensorShapeType*
531f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::GetTensorShapeType(
532f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    const TensorShapeMap& tensor_shape_map, const string& node_name,
533f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    const int port) {
534f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  CHECK_EQ(node_name.find(':'), string::npos);
535f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  if (tensor_shape_map.count(node_name) <= 0) {
536f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    return nullptr;
537f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  }
538f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  auto its = tensor_shape_map.equal_range(node_name);
539f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  for (auto it = its.first; it != its.second; ++it) {
540f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    if (it->second.first == port) {
541f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower      return &it->second.second;
542f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    }
543f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  }
544f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  return nullptr;
545f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower}
546f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower
54739af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower/* static */ void
54839af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto(
54939af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower    const RemoteFusedGraphExecuteInfo& proto,
55039af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower    std::vector<std::pair<string, Tensor>>* inputs,
55139af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower    std::vector<string>* outputs) {
55239af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower  CHECK_EQ(proto.graph_input_node_name_size(),
55339af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower           proto.default_graph_input_tensor_shape_size());
55439af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower  for (int i = 0; i < proto.graph_input_node_name_size(); ++i) {
55539af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower    inputs->emplace_back(
55639af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower        proto.graph_input_node_name(i),
55739af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower        Tensor(proto.default_graph_input_tensor_shape(i).dtype(),
55839af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower               TensorShape(proto.default_graph_input_tensor_shape(i).shape())));
55939af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower  }
56039af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower  for (const string& output_node_name : proto.graph_output_node_name()) {
56139af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower    outputs->emplace_back(output_node_name);
56239af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower  }
56339af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower}
56439af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower
565f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower/* static */ void RemoteFusedGraphExecuteUtils::EmplaceTensorShapeType(
566f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    const string& name, const Tensor& tensor,
567f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower    TensorShapeMap* tensor_shape_map) {
568f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  const TensorId tid = ParseTensorName(name);
569f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  CHECK_EQ(tensor_shape_map->count(name), 0);
570f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower  tensor_shape_map->emplace(
571f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower      tid.first.ToString(),
572f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower      std::make_pair(tid.second,
573f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower                     std::make_pair(tensor.dtype(), tensor.shape())));
574f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower}
575f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower
576bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes(
577bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const std::vector<std::pair<string, Tensor>>& input_tensors,
578bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const bool dry_run_inference, GraphDef* graph_def) {
579bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  TensorShapeMap tensor_shape_map;
580bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  if (dry_run_inference) {
581bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    TF_RETURN_IF_ERROR(DryRunInferenceForAllNode(*graph_def, input_tensors,
582bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower                                                 /*initialize_by_zero=*/true,
583bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower                                                 &tensor_shape_map));
584bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  } else {
585bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    ImportGraphDefOptions opts;
586bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    Graph graph(OpRegistry::Global());
587e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving    ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
588bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    TF_RETURN_IF_ERROR(
589bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        ImportGraphDef(opts, *graph_def, &graph, &shape_refiner));
590bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    TF_RETURN_IF_ERROR(PropagateShapeInference(*graph_def, input_tensors,
591bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower                                               &graph, &shape_refiner));
592bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    TF_RETURN_IF_ERROR(
593bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        BuildTensorShapeMapFromGraph(graph, shape_refiner, &tensor_shape_map));
594bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
595bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
596bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  for (NodeDef& node_def : *graph_def->mutable_node()) {
597bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    TF_RETURN_IF_ERROR(
598bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map, &node_def));
599bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
600bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
601bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  return Status::OK();
602bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower}
603bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
604bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status
605bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo(
606bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const string& executor_name, const GraphDef& subgraph_def,
607bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const std::vector<string>& inputs, const std::vector<string>& outputs,
608bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const bool require_shape_type, RemoteFusedGraphExecuteInfo* execute_info,
609bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    DataTypeVector* input_types, DataTypeVector* output_types) {
610bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  CHECK_NOTNULL(execute_info);
611bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  CHECK_NOTNULL(input_types);
612bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  CHECK_NOTNULL(output_types);
613bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
614bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  execute_info->Clear();
615bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  execute_info->set_executor_name(executor_name);
616bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
617bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  // copy graph
618bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  *execute_info->mutable_remote_graph() = subgraph_def;
619bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
620bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  for (const string& input : inputs) {
621bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    DataType dt;
622bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    TensorShape shape;
623bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const bool has_shapetype =
624bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        GetOutputTensorShapeType(subgraph_def, input, &dt, &shape);
625bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
626bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    execute_info->add_graph_input_node_name(input);
627bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    if (has_shapetype) {
628bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
629bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          *execute_info->add_default_graph_input_tensor_shape();
630bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      tensor_shape_type.set_dtype(dt);
631bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      TensorShapeProto& tensor_shape_proto = *tensor_shape_type.mutable_shape();
632bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      for (const int64 dim : shape.dim_sizes()) {
633bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        tensor_shape_proto.add_dim()->set_size(dim);
634bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      }
635bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      input_types->push_back(dt);
636bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    } else {
637bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      CHECK(!require_shape_type)
638bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          << "No shape type found for " << input << DumpGraphDef(subgraph_def);
639bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      // Assuming input type is float if no data provided.
640bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      input_types->push_back(DT_FLOAT);
641bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    }
642bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
643bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
644bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  for (const string& output : outputs) {
645bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    DataType dt;
646bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    TensorShape shape;
647bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const bool has_shapetype =
648bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        GetOutputTensorShapeType(subgraph_def, output, &dt, &shape);
649bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
650bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    execute_info->add_graph_output_node_name(output);
651bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    if (has_shapetype) {
652bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      RemoteFusedGraphExecuteInfo::TensorShapeTypeProto&
653bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          tensor_shape_type_proto =
654bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower              *execute_info->add_default_graph_output_tensor_shape();
655bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      tensor_shape_type_proto.set_dtype(dt);
656bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      TensorShapeProto& tensor_shape_proto =
657bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          *tensor_shape_type_proto.mutable_shape();
658bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      for (const int64 dim : shape.dim_sizes()) {
659bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        tensor_shape_proto.add_dim()->set_size(dim);
660bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      }
661bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      output_types->push_back(dt);
662bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    } else {
663bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      CHECK(!require_shape_type)
664bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          << "No shape type found for " << output << DumpGraphDef(subgraph_def);
665bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      // Assuming output type is float if no data provided.
666bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      output_types->push_back(DT_FLOAT);
667bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    }
668bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
669bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
670bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  return Status::OK();
671bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower}
672bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
673bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status
674bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
675bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const string& node_name, const string& executor_name,
676bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const GraphDef& subgraph_def, const std::vector<string>& inputs,
677bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const std::vector<string>& outputs, const bool require_shape_type,
678bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    Graph* graph, Node** created_node) {
679bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  CHECK_NOTNULL(graph);
680bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  CHECK_NOTNULL(created_node);
681bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
682bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  RemoteFusedGraphExecuteInfo execute_info;
683bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  DataTypeVector input_types;
684bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  DataTypeVector output_types;
685bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
686bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  TF_CHECK_OK(RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo(
687bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      executor_name, subgraph_def, inputs, outputs, require_shape_type,
688bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      &execute_info, &input_types, &output_types));
689bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
690bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  std::vector<NodeBuilder::NodeOut> node_out_list;
691bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  for (const string& input : inputs) {
692bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const TensorId tid = ParseTensorName(input);
693bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    Node* node = FindMutableNodeByName(tid.first.ToString(), graph);
694bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    CHECK_NOTNULL(node);
695bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    node_out_list.emplace_back(node, tid.second);
696bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
697bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
698bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  const string execute_info_str = execute_info.SerializeAsString();
699bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
700bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  auto builder =
701bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      NodeBuilder(node_name, "RemoteFusedGraphExecute")
702bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          .Input(node_out_list)
703bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          .Attr("Tinputs", input_types)
704bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          .Attr("Toutputs", output_types)
705bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          .Attr("serialized_remote_fused_graph_execute_info", execute_info_str);
706bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
707bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  TF_RETURN_IF_ERROR(builder.Finalize(graph, created_node));
708bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  return Status::OK();
709bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower}
710bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
711bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::BuildIdentityOpNode(
712bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const string& node_name, const string& input_node_name,
713bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const int input_node_port, const DataType dt, Graph* graph,
714bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    Node** created_node) {
715bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  Node* node = FindMutableNodeByName(input_node_name, graph);
716bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  CHECK_NOTNULL(node);
717bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  NodeBuilder::NodeOut node_out(node, input_node_port);
718bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
719bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  auto builder =
720bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      NodeBuilder(node_name, "Identity").Input(node_out).Attr("T", dt);
721bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
722bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  TF_RETURN_IF_ERROR(builder.Finalize(graph, created_node));
723bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  return Status::OK();
724bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower}
725bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
726bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::ClusterizeNodes(
727bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const std::unordered_set<string>& node_names, const GraphDef& graph_def,
728bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    std::vector<ClusterInfo>* cluster_infos) {
729bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  Graph graph(OpRegistry::Global());
730e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
731bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));
732bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  std::unordered_set<string> remaining_nodes = node_names;
733bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
734bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  while (!remaining_nodes.empty()) {
735bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    ClusterInfo ci;
736bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
737bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    // Determine one cluster nodes
738bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    std::unordered_set<const Node*> visited;
739bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    std::deque<const Node*> queue;
740bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    queue.emplace_back(FindNodeByName(*remaining_nodes.begin(), graph));
741bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    while (!queue.empty()) {
742bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      const Node* node = queue.front();
743bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      CHECK_NOTNULL(node);
744bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      queue.pop_front();
745bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      const string& node_name = node->name();
746bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      if (node_names.count(node_name) > 0) {
747bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        std::get<0>(ci).emplace(node_name);
748bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        remaining_nodes.erase(node_name);
749bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      } else {
750bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        // Edge of subgraph.  Do nothing.
751bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        continue;
752bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      }
753bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      for (const Node* in : node->in_nodes()) {
754bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        if (visited.insert(in).second) {
755bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          queue.push_back(in);
756bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        }
757bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      }
758bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      for (const Node* out : node->out_nodes()) {
759bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        if (visited.insert(out).second) {
760bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          queue.push_back(out);
761bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        }
762bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      }
763bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    }
764bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
765bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    // Determine one cluster border
766bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    std::vector<string>& border_inputs = std::get<1>(ci);
767bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    std::vector<string>& border_outputs = std::get<2>(ci);
768bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    for (const string& node_name : node_names) {
769bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      Node* node = FindMutableNodeByName(node_name, &graph);
770bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      CHECK_NOTNULL(node);
771bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      int input_count = 0;
772bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      for (const Edge* in_edge : node->in_edges()) {
773bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        const Node* src_node = in_edge->src();
774bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        const bool src_is_outside =
775bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower            node_names.count(src_node->name()) <= 0 && !src_node->IsSource();
776bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        if (src_is_outside) {
777bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          const string src_name =
778bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower              strings::StrCat(src_node->name(), ":", in_edge->src_output());
779bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          CHECK_EQ(1, src_node->num_outputs())
780bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower              << "output count of input border node must be one."
781bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower              << src_node->name();
782bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          if (std::find(border_inputs.begin(), border_inputs.end(), src_name) ==
783bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower              border_inputs.end()) {
784bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower            border_inputs.emplace_back(src_name);
785bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          }
786bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        } else {
787bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          ++input_count;
788bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        }
789bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      }
790f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower      CHECK(input_count == 0 || input_count == node->in_edges().size())
791f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower          << "Invalid input_count(" << input_count << ", "
792f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower          << node->in_edges().size() << ") " << node_name;
793bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
794bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      for (const Edge* out_edge : node->out_edges()) {
795bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        const Node* dst_node = out_edge->dst();
796bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        CHECK_NOTNULL(dst_node);
797bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        const bool dst_is_outside = node_names.count(dst_node->name()) <= 0;
798bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        const string dst_name =
799bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower            strings::StrCat(node->name(), ":", out_edge->src_output());
800bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        if (dst_is_outside) {
801bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          if (dst_node->IsSink()) {
802bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower            CHECK_EQ(1, node->num_outputs())
803bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower                << "If you want to specify output node as subgraph output node "
804bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower                << "the output count of the node must be 1 "
805bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower                << "because that node is replaced by identity node.";
806bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower            const string identity_dst_name =
807bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower                strings::StrCat(node->name(), ":", 0);
808bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower            if (std::find(border_outputs.begin(), border_outputs.end(),
809bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower                          identity_dst_name) == border_outputs.end()) {
810bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower              border_outputs.emplace_back(identity_dst_name);
811bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower            }
812bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          } else {
813bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower            if (std::find(border_outputs.begin(), border_outputs.end(),
814bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower                          dst_name) == border_outputs.end()) {
815bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower              border_outputs.emplace_back(dst_name);
816bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower            }
817bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          }
818bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        }
819bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      }
820bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    }
821bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    cluster_infos->emplace_back(ci);
822bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    VLOG(1) << DumpCluster(ci);
823bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
824bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  return Status::OK();
825bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower}
826bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
827bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::BuildClusterSubgraphDef(
828bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const ClusterInfo& cluster, const GraphDef& graph_def,
829bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    GraphDef* subgraph_def) {
830bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  const std::unordered_set<string>& node_names = std::get<0>(cluster);
831bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  const std::unordered_set<string>& border_input_names =
832bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      BuildNodeSetFromNodeNamesAndPorts(std::get<1>(cluster));
833bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
834bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  Graph graph(OpRegistry::Global());
835e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
836bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));
837bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
838bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  for (Node* node : graph.nodes()) {
839bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    if (node != nullptr && node_names.count(node->name()) <= 0 &&
840bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        border_input_names.count(node->name()) <= 0 && !node->IsSource() &&
841bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        !node->IsSink()) {
842bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      graph.RemoveNode(node);
843bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    }
844bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
845bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  graph.ToGraphDef(subgraph_def);
846bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
847bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  for (const string& subgraph_input : std::get<1>(cluster)) {
848bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const TensorId tid = ParseTensorName(subgraph_input);
849bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const string subgraph_input_name = tid.first.ToString();
850bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const int subgraph_input_port = tid.second;
851bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const NodeDef* node_def = FindNodeDefByName(subgraph_input_name, graph_def);
852bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    CHECK_NOTNULL(node_def);
853bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    std::vector<DataType> dt_vec;
854bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    std::vector<TensorShape> shape_vec;
855bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    GetOutputTensorShapeType(*node_def, &dt_vec, &shape_vec).IgnoreError();
856bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const DataType& dt =
857bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        dt_vec.empty() ? DT_FLOAT : dt_vec.at(subgraph_input_port);
858bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const TensorShape& shape =
859bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        shape_vec.empty() ? TensorShape({}) : shape_vec.at(subgraph_input_port);
860bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
861bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    TF_RETURN_IF_ERROR(ReplaceInputNodeByPlaceHolder(subgraph_input_name, dt,
862bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower                                                     shape, subgraph_def));
863bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
8649ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower
8659ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower  // sort subgraph_def to align order in graph_def
8669ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower  std::unordered_map<string, int> name_to_id_map;
8679ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower  for (int i = 0; i < graph_def.node_size(); ++i) {
8689ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower    name_to_id_map.emplace(graph_def.node(i).name(), i);
8699ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower  }
8709ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower  std::sort(subgraph_def->mutable_node()->begin(),
8719ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower            subgraph_def->mutable_node()->end(),
8729ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower            [&name_to_id_map](const NodeDef& node0, const NodeDef& node1) {
8739ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower              CHECK(name_to_id_map.count(node0.name()) > 0);
8749ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower              CHECK(name_to_id_map.count(node1.name()) > 0);
8759ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower              const int id0 = name_to_id_map.at(node0.name());
8769ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower              const int id1 = name_to_id_map.at(node1.name());
8779ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower              return id0 < id1;
8789ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower            });
8799ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower
880bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  VLOG(1) << DumpGraphDef(*subgraph_def);
881bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  return Status::OK();
882bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower}
883bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
884bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
885bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const std::vector<string>& border_inputs,
886bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const std::vector<string>& border_outputs, const GraphDef& graph_def,
887bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    ClusterInfo* cluster) {
888bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  Graph graph(OpRegistry::Global());
889e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
890bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner));
891bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
892bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  std::unordered_set<const Node*> visited;
893bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  std::deque<const Node*> queue;
894bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  for (const string& output : border_outputs) {
895bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const TensorId tid = ParseTensorName(output);
896bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const string& output_node_name = tid.first.ToString();
897bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    for (const Node* node : graph.nodes()) {
898bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      if (output_node_name == node->name()) {
899bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        queue.push_back(node);
900bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        visited.insert(node);
901bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      }
902bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    }
903bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
904bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
905bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  std::unordered_set<const Node*> border_input_nodes;
906bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  // propagate visit to parent nodes until input nodes
907bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  while (!queue.empty()) {
908bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const Node* node = queue.front();
909bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    queue.pop_front();
910bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    for (const Edge* edge : node->in_edges()) {
911bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      const Node* src_node = edge->src();
912bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      CHECK_NOTNULL(src_node);
913bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      const int src_port = edge->src_output();
914bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      bool input_found = false;
915bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      for (const string& input : border_inputs) {
916bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        const TensorId tid = ParseTensorName(input);
917bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        if (tid.first.ToString() == src_node->name() &&
918bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower            tid.second == src_port) {
919bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          input_found = true;
920bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          border_input_nodes.insert(src_node);
921bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        }
922bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      }
923bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      if (visited.insert(src_node).second) {
924bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        if (!input_found) {
925bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          queue.push_back(src_node);
926bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        }
927bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      }
928bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    }
929bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
930bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
931bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  for (const Node* node : visited) {
932bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    if (node != nullptr && !node->IsSource() && !node->IsSink() &&
933bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        border_input_nodes.count(node) <= 0) {
934bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      std::get<0>(*cluster).insert(node->name());
935bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    }
936bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
937bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  std::get<1>(*cluster) = border_inputs;
938bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  std::get<2>(*cluster) = border_outputs;
939bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  return Status::OK();
940bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower}
941bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
942bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::FuseCluster(
943bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const GraphDef& input_graph_def, const std::vector<string>& inputs,
944bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const std::vector<string>& outputs,
945bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const string& remote_fused_graph_node_name, const ClusterInfo& cluster,
946bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const string& remote_graph_executor_name, const bool require_shape_type,
947bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    GraphDef* output_graph_def) {
948bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  LOG(INFO) << "Transforming quantized stripped model to a remote fused "
949bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower               "graph execute op by fusing a specified subgraph...";
950bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
951bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  CHECK(!remote_graph_executor_name.empty());
952bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
953bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  const std::vector<string>& border_inputs = std::get<1>(cluster);
954bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  const std::vector<string>& border_outputs = std::get<2>(cluster);
955bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
956bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  GraphDef subgraph_def;
957bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  TF_RETURN_IF_ERROR(
958bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      BuildClusterSubgraphDef(cluster, input_graph_def, &subgraph_def));
959bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
960bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  Graph graph(OpRegistry::Global());
961e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving  ShapeRefiner shape_refiner(graph.versions(), graph.op_registry());
962bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  TF_RETURN_IF_ERROR(
963bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      ImportGraphDef({}, input_graph_def, &graph, &shape_refiner));
964bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
965bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  Node* fused_node;
966bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  TF_RETURN_IF_ERROR(BuildRemoteFusedGraphExecuteOpNode(
967bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      remote_fused_graph_node_name, remote_graph_executor_name, subgraph_def,
968bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      border_inputs, border_outputs, require_shape_type, &graph, &fused_node));
969bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
970bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  for (const Node* node : graph.nodes()) {
971bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    for (int i = 0; i < node->num_inputs(); ++i) {
972bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      const Edge* edge = nullptr;
973bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      TF_RETURN_IF_ERROR(node->input_edge(i, &edge));
974bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      for (int j = 0; j < border_outputs.size(); ++j) {
975bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        const string& output = border_outputs.at(j);
976bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        const TensorId tid = ParseTensorName(output);
977bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        const string output_name = tid.first.ToString();
978bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        Node* src_node = edge->src();
979bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        if (src_node != nullptr && src_node->name() == output_name &&
980bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower            edge->src_output() == tid.second) {
981bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          // Source node is replaced by new fused node.
982bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          Node* dst_node = edge->dst();
983bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          const int dst_input = edge->dst_input();
984bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          LOG(INFO) << "Removing existing edge to " << edge->dst()->name()
985bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower                    << " from " << edge->src()->name();
986bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          graph.RemoveEdge(edge);
987bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          graph.AddEdge(fused_node, j, dst_node, dst_input);
988bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        }
989bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      }
990bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    }
991bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
992bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
993bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  // Replace output nodes by identity nodes which forward outputs from
994bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  // RemoteFusedGraphExecuteOpNode
995bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  for (const string& output : outputs) {
996bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const TensorId output_tid = ParseTensorName(output);
997bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const string output_name = output_tid.first.ToString();
99890d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai    for (size_t i = 0; i < border_outputs.size(); ++i) {
999bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      const TensorId subgraph_output_tid =
1000bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower          ParseTensorName(border_outputs.at(i));
1001bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      const string& subgraph_output_name = subgraph_output_tid.first.ToString();
1002bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      if (output_name == subgraph_output_name) {
1003bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        LOG(INFO) << "As graph output and subgraph output are same, "
1004bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower                  << "the graph output node is replaced by identity node";
10051c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        Node* original_output_node = FindMutableNodeByName(output_name, &graph);
1006bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        CHECK_NOTNULL(original_output_node);
1007bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        CHECK_EQ(1, original_output_node->num_outputs())
1008bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower            << "Num outputs should be 1 for " << output << ".";
1009bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        graph.RemoveNode(original_output_node);
1010bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        Node* new_node;
10111c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        TF_RETURN_IF_ERROR(BuildIdentityOpNode(output_name,
1012bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower                                               remote_fused_graph_node_name, i,
1013bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower                                               DT_FLOAT, &graph, &new_node));
1014bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        CHECK_NOTNULL(new_node);
1015bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      }
1016bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    }
1017bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
1018bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
1019bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  GraphDef result_graph_def;
1020bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
1021bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  graph.ToGraphDef(&result_graph_def);
1022bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
1023bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  ClusterInfo graph_cluster;
1024bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  TF_RETURN_IF_ERROR(
1025bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      BuildClusterByBorder(inputs, outputs, result_graph_def, &graph_cluster));
1026bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
1027bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  // Remove unvisited nodes
1028bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  TF_RETURN_IF_ERROR(BuildClusterSubgraphDef(graph_cluster, result_graph_def,
1029bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower                                             output_graph_def));
1030bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
1031bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  return Status::OK();
1032bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower}
1033bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
1034bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames(
1035bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const GraphDef& input_graph_def, const std::vector<string>& inputs,
1036bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const std::vector<string>& outputs,
1037bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const string& remote_fused_graph_node_name_prefix,
1038bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const std::unordered_set<string>& subgraph_nodes,
1039bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const string& remote_fused_graph_executor_name,
1040bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const bool require_shape_type, GraphDef* output_graph_def) {
1041bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  std::vector<ClusterInfo> ci_vec;
1042bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::ClusterizeNodes(
1043bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      subgraph_nodes, input_graph_def, &ci_vec));
1044bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
104590d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai  for (size_t i = 0; i < ci_vec.size(); ++i) {
1046bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const string remote_fused_graph_node_name =
1047bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower        strings::StrCat(remote_fused_graph_node_name_prefix, "/", i);
1048bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    TF_RETURN_IF_ERROR(FuseCluster(input_graph_def, inputs, outputs,
1049bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower                                   remote_fused_graph_node_name, ci_vec.at(i),
1050bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower                                   remote_fused_graph_executor_name,
1051bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower                                   require_shape_type, output_graph_def));
1052bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
1053bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  return Status::OK();
1054bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower}
1055bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
1056bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByBorder(
1057bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const GraphDef& input_graph_def, const std::vector<string>& inputs,
1058bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const std::vector<string>& outputs,
1059bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const string& remote_fused_graph_node_name,
1060bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const std::vector<string>& border_inputs,
1061bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const std::vector<string>& border_outputs,
1062bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const string& remote_graph_executor_name, const bool require_shape_type,
1063bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    GraphDef* output_graph_def) {
1064bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  ClusterInfo cluster;
1065bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildClusterByBorder(
1066bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      border_inputs, border_outputs, input_graph_def, &cluster));
1067bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
1068bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  return FuseCluster(
1069bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      input_graph_def, inputs, outputs, remote_fused_graph_node_name, cluster,
1070bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      remote_graph_executor_name, require_shape_type, output_graph_def);
1071bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower}
1072bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
1073f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByOpTypes(
1074f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower    const GraphDef& input_graph_def, const std::vector<string>& inputs,
1075f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower    const std::vector<string>& outputs,
1076f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower    const string& remote_fused_graph_node_name_prefix,
1077f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower    const std::unordered_set<string>& fused_op_types,
1078f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower    const string& remote_fused_graph_executor_name,
1079f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower    const bool require_shape_type, GraphDef* output_graph_def) {
1080f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower  const std::unordered_set<string> fused_nodes_filtered_by_op_types =
1081f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower      BuildNodeMapFromOpTypes(input_graph_def, fused_op_types);
1082f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower
1083f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower  return FuseRemoteGraphByNodeNames(
1084f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower      input_graph_def, inputs, outputs, remote_fused_graph_node_name_prefix,
1085f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower      fused_nodes_filtered_by_op_types, remote_fused_graph_executor_name,
1086f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower      require_shape_type, output_graph_def);
1087f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower}
1088f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower
1089d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor(
1090d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower    const GraphDef& input_graph_def, const std::vector<string>& inputs,
1091d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower    const std::vector<string>& outputs, const string& executor_name,
1092d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower    GraphDef* output_graph_def) {
1093d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower  const ExecutorBuildFunc* build_func = GetExecutorBuildFunc(executor_name);
1094d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower  if (build_func == nullptr) {
1095d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower    return errors::InvalidArgument("Unknown executor name: " + executor_name);
1096d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower  }
1097d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower  std::unique_ptr<IRemoteFusedGraphExecutor> executor;
1098d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower  TF_RETURN_IF_ERROR((*build_func)(&executor));
1099d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower  CHECK_NOTNULL(executor.get());
1100d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower  if (!executor->IsEnabled()) {
1101d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower    // As this executor is not enabled, just return original graph as is.
1102d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower    *output_graph_def = input_graph_def;
1103d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower    return Status::OK();
1104d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower  }
1105d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower  return executor->FuseRemoteGraph(input_graph_def, inputs, outputs,
1106d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower                                   output_graph_def);
1107d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower}
1108d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower
11091c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::PlaceRemoteGraphArguments(
11101c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    const std::vector<string>& inputs, const std::vector<string>& outputs,
11111c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    const std::unordered_set<string>& fused_node_names,
11121c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    const std::vector<string>& border_inputs,
11131c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    const std::vector<string>& border_outputs,
1114f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower    const std::unordered_set<string>& fused_op_types,
11151c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    const string& remote_fused_graph_node_name,
11161c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    const string& remote_graph_executor_name, GraphDef* graph_def) {
11171c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  CHECK_NOTNULL(graph_def);
1118f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower
1119f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower  const std::unordered_set<string> fused_nodes_filtered_by_op_types =
1120f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower      BuildNodeMapFromOpTypes(*graph_def, fused_op_types);
1121f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower
11221c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  for (NodeDef& node_def : *graph_def->mutable_node()) {
11231c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    string attr_str;
11241c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    TensorId tid;
112590d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai    for (size_t i = 0; i < inputs.size(); ++i) {
11261c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      if (IsSameNodeName(node_def, inputs.at(i), &tid)) {
11271c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        AppendDeliminator(&attr_str);
11281c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::GRAPH_INPUT,
11291c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower                                      tid.second, i, remote_graph_executor_name,
11301c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower                                      remote_fused_graph_node_name);
11311c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      }
11321c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    }
113390d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai    for (size_t i = 0; i < outputs.size(); ++i) {
11341c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      if (IsSameNodeName(node_def, outputs.at(i), &tid)) {
11351c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        AppendDeliminator(&attr_str);
11361c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::GRAPH_OUTPUT,
11371c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower                                      tid.second, i);
11381c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      }
11391c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    }
11401c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    for (const string& fused_node_name : fused_node_names) {
11411c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      if (fused_node_name == node_def.name()) {
11421c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        AppendDeliminator(&attr_str);
11431c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::FUSED_NODE);
11441c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      }
11451c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    }
1146f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower    for (const string& fused_node_name : fused_nodes_filtered_by_op_types) {
1147f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower      if (fused_node_name == node_def.name()) {
1148f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower        AppendDeliminator(&attr_str);
1149f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower        attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::FUSED_NODE);
1150f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower      }
1151f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower    }
115290d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai    for (size_t i = 0; i < border_inputs.size(); ++i) {
11531c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      if (IsSameNodeName(node_def, border_inputs.at(i), &tid)) {
11541c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        AppendDeliminator(&attr_str);
11551c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::BORDER_INPUT,
11561c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower                                      tid.second, i);
11571c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      }
11581c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    }
115990d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai    for (size_t i = 0; i < border_outputs.size(); ++i) {
11601c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      if (IsSameNodeName(node_def, border_outputs.at(i), &tid)) {
11611c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        AppendDeliminator(&attr_str);
11621c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        attr_str += BuildNodeTypeAttr(
11631c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower            RemoteFusedGraphExecuteInfo::BORDER_OUTPUT, tid.second, i);
11641c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      }
11651c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    }
11661c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    if (attr_str.empty()) {
11671c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::UNUSED);
11681c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    }
11691c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    AddNodeAttr(ATTR_NODE_TYPE, attr_str, &node_def);
11701c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  }
11711c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  return Status::OK();
11721c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower}
11731c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower
11741c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower/* static */ Status
11751c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments(
11761c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    const GraphDef& input_graph_def,
11771c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    const std::vector<std::pair<string, Tensor>>& input_tensors,
11781c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    GraphDef* output_graph_def) {
11791c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  std::unordered_map<int, string> input_map;
11801c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  std::unordered_map<int, string> output_map;
11811c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  std::unordered_set<string> fused_node_names;
11821c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  std::unordered_map<int, string> border_input_map;
11831c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  std::unordered_map<int, string> border_output_map;
11841c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  string remote_graph_executor_name;
11851c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  string remote_fused_graph_node_name;
11861c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower
11871c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  for (const NodeDef& node_def : input_graph_def.node()) {
11881c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    string attr_str;
11891c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    TF_RETURN_IF_ERROR(GetNodeAttr(node_def, ATTR_NODE_TYPE, &attr_str));
11901c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    std::vector<std::vector<string>> attr_strs;
11911c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    for (const string& str : str_util::Split(attr_str, ":")) {
11921c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      attr_strs.emplace_back(str_util::Split(str, ","));
11931c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    }
11941c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    if (attr_strs.empty()) {
11951c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      return errors::InvalidArgument("Remote graph node type not found.");
11961c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    }
11971c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    for (const std::vector<string>& attr : attr_strs) {
11981c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      if (attr.empty()) {
11991c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        return errors::InvalidArgument("Empty remote graph node type attr.");
12001c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      }
12011c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      int node_type_int;
12021c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      CHECK(strings::safe_strto32(attr.at(0), &node_type_int)) << attr.at(0);
12031c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      const RemoteFusedGraphExecuteInfo::NodeType node_type =
12041c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          static_cast<RemoteFusedGraphExecuteInfo::NodeType>(node_type_int);
12051c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      const string& name = node_def.name();
12061c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      int port;
12071c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      int index;
12081c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower
12091c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      switch (node_type) {
12101c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        case RemoteFusedGraphExecuteInfo::GRAPH_INPUT:
12111c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          VLOG(2) << "Graph input: " << name;
12121c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          CHECK_EQ(5, attr.size());
12131c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          CHECK(strings::safe_strto32(attr.at(1), &port));
12141c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          CHECK(strings::safe_strto32(attr.at(2), &index));
12151c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          CHECK(!attr.at(3).empty());
12161c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          remote_graph_executor_name = attr.at(3);
12171c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          CHECK(!attr.at(4).empty());
12181c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          remote_fused_graph_node_name = attr.at(4);
12191c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          input_map.emplace(index, strings::StrCat(name, ":", port));
12201c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          if (GetExecutorBuildFunc(remote_graph_executor_name) == nullptr) {
12211c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower            LOG(INFO) << "Executor for " << remote_graph_executor_name
12221c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower                      << " not registered.  Do not fuse.";
12231c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower            *output_graph_def = input_graph_def;
12241c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower            return Status::OK();
12251c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          }
12261c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          break;
12271c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        case RemoteFusedGraphExecuteInfo::GRAPH_OUTPUT:
12281c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          VLOG(2) << "Graph output: " << name;
12291c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          CHECK_EQ(3, attr.size());
12301c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          CHECK(strings::safe_strto32(attr.at(1), &port));
12311c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          CHECK(strings::safe_strto32(attr.at(2), &index));
12321c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          output_map.emplace(index, strings::StrCat(name, ":", port));
12331c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          break;
12341c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        case RemoteFusedGraphExecuteInfo::FUSED_NODE:
12351c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          VLOG(2) << "Fused node: " << name;
12361c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          CHECK_EQ(1, attr.size());
12371c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          fused_node_names.emplace(name);
12381c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          break;
12391c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        case RemoteFusedGraphExecuteInfo::BORDER_INPUT:
12401c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          VLOG(2) << "Border input: " << name;
12411c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          CHECK_EQ(3, attr.size());
12421c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          CHECK(strings::safe_strto32(attr.at(1), &port));
12431c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          CHECK(strings::safe_strto32(attr.at(2), &index));
12441c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          border_input_map.emplace(index, strings::StrCat(name, ":", port));
12451c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          break;
12461c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        case RemoteFusedGraphExecuteInfo::BORDER_OUTPUT:
12471c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          VLOG(2) << "Border output: " << name;
12481c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          CHECK_EQ(3, attr.size());
12491c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          CHECK(strings::safe_strto32(attr.at(1), &port));
12501c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          CHECK(strings::safe_strto32(attr.at(2), &index));
12511c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          border_output_map.emplace(index, strings::StrCat(name, ":", port));
12521c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          break;
12531c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        case RemoteFusedGraphExecuteInfo::UNUSED:
12541c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          // do nothing
12551c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          break;
12561c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        default:
12571c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          // unsupported value
125888cdf1f81fa1938c5bb81c5d293fc0ed0758cadcA. Unique TensorFlower          LOG(FATAL);
12591c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      }
12601c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    }
12611c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  }
12621c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  bool require_shape_type = false;
12631c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  std::vector<string> inputs;
12641c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  std::vector<string> outputs;
12651c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  std::vector<string> border_inputs;
12661c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  std::vector<string> border_outputs;
12671c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  ConvertMapToVector(input_map, &inputs);
12681c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  ConvertMapToVector(output_map, &outputs);
12691c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  ConvertMapToVector(border_input_map, &border_inputs);
12701c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  ConvertMapToVector(border_output_map, &border_outputs);
12711c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower
12721c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  if (!input_tensors.empty()) {
12731c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    bool input_match = false;
12741c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    if (inputs.size() == input_tensors.size()) {
12751c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      for (const std::pair<string, Tensor>& input_tensor : input_tensors) {
12761c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        if (!ContainsSameTensorId(input_tensor.first, inputs)) {
12771c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          break;
12781c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        }
12791c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        DataType data_type;
12801c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        TensorShape shape;
12811c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        if (GetOutputTensorShapeType(input_graph_def, input_tensor.first,
12821c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower                                     &data_type, &shape)) {
12831c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          if (data_type == input_tensor.second.dtype() &&
12841c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower              shape == input_tensor.second.shape()) {
12851c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower            VLOG(2) << "Input matched!";
12861c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower            // Shape type matched.
12871c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower            input_match = true;
12881c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower            require_shape_type = true;
12891c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          }
12901c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        } else {
12911c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          // Shape type not required.
12921c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower          input_match = true;
12931c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        }
12941c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      }
12951c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    }
12961c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    if (!input_match) {
12971c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      // Input mismatch.  Just copy original graph
12981c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      *output_graph_def = input_graph_def;
12991c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      return Status::OK();
13001c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    }
13011c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  }
13021c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower
13031c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  if (!fused_node_names.empty()) {
13041c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    TF_RETURN_IF_ERROR(FuseRemoteGraphByNodeNames(
13051c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        input_graph_def, inputs, outputs, remote_fused_graph_node_name,
13061c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        fused_node_names, remote_graph_executor_name, require_shape_type,
13071c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        output_graph_def));
13081c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  } else if (!border_inputs.empty() || !border_outputs.empty()) {
13091c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    TF_RETURN_IF_ERROR(FuseRemoteGraphByBorder(
13101c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        input_graph_def, inputs, outputs, remote_fused_graph_node_name,
13111c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        border_inputs, border_outputs, remote_graph_executor_name,
13121c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower        require_shape_type, output_graph_def));
13131c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  } else {
13141c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    *output_graph_def = input_graph_def;
13151c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  }
13161c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower
13171c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  return Status::OK();
13181c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower}
13191c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower
13201c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower/* static */ bool RemoteFusedGraphExecuteUtils::IsFuseReady(
13211c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    const GraphDef& graph_def,
13221c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    const std::vector<std::pair<string, Tensor>>& input_tensors) {
13231c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  for (const std::pair<string, Tensor>& input_tensor : input_tensors) {
13241c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    const NodeDef* node_def = FindNodeDefByName(input_tensor.first, graph_def);
13251c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    if (node_def == nullptr) {
13261c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      return false;
13271c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    }
13281c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    string attr;
13291c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    const Status status = GetNodeAttr(*node_def, ATTR_NODE_TYPE, &attr);
13301c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    if (!status.ok() || attr.empty()) {
13311c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower      return false;
13321c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    }
13331c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  }
13341c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  return true;
13351c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower}
13361c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower
1337351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor(
1338351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower    const void* src_ptr, const int src_size, Tensor* tensor) {
1339351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower  CHECK(tensor->TotalBytes() >= src_size)
1340351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      << tensor->TotalBytes() << ", " << src_size;
1341351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower  void* dst_ptr;
1342351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower  switch (tensor->dtype()) {
1343351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower    case DT_FLOAT:
1344351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      dst_ptr = tensor->flat<float>().data();
1345351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      break;
1346351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower    case DT_DOUBLE:
1347351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      dst_ptr = tensor->flat<double>().data();
1348351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      break;
1349351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower    case DT_INT32:
1350351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      dst_ptr = tensor->flat<int32>().data();
1351351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      break;
1352351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower    case DT_UINT8:
1353351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      dst_ptr = tensor->flat<uint8>().data();
1354351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      break;
1355351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower    case DT_INT16:
1356351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      dst_ptr = tensor->flat<int16>().data();
1357351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      break;
1358351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower    case DT_INT8:
1359351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      dst_ptr = tensor->flat<int8>().data();
1360351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      break;
1361351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower    case DT_STRING:
1362351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      dst_ptr = tensor->flat<string>().data();
1363351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      break;
1364351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower    case DT_INT64:
1365351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      dst_ptr = tensor->flat<int64>().data();
1366351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      break;
1367351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower    case DT_BOOL:
1368351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      dst_ptr = tensor->flat<bool>().data();
1369351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      break;
1370351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower    case DT_QINT8:
1371351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      dst_ptr = tensor->flat<qint8>().data();
1372351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      break;
1373351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower    case DT_QUINT8:
1374351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      dst_ptr = tensor->flat<quint8>().data();
1375351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      break;
1376351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower    case DT_QINT32:
1377351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      dst_ptr = tensor->flat<qint32>().data();
1378351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      break;
1379351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower    case DT_BFLOAT16:
1380351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      dst_ptr = tensor->flat<bfloat16>().data();
1381351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      break;
1382351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower    case DT_QINT16:
1383351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      dst_ptr = tensor->flat<qint16>().data();
1384351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      break;
1385351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower    case DT_QUINT16:
1386351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      dst_ptr = tensor->flat<quint16>().data();
1387351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      break;
1388351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower    case DT_UINT16:
1389351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      dst_ptr = tensor->flat<uint16>().data();
1390351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      break;
1391351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower    default:
139288cdf1f81fa1938c5bb81c5d293fc0ed0758cadcA. Unique TensorFlower      LOG(FATAL) << "type " << tensor->dtype() << " is not supported.";
1393351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower      break;
1394351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower  }
1395351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower  CHECK_NOTNULL(dst_ptr);
1396351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower  std::memcpy(dst_ptr, src_ptr, src_size);
1397351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower  return Status::OK();
1398351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower}
1399351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower
1400f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower/* static */ std::unordered_set<string>
1401f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlowerRemoteFusedGraphExecuteUtils::BuildNodeMapFromOpTypes(
1402f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower    const GraphDef& graph_def, const std::unordered_set<string>& op_types) {
1403f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower  std::unordered_set<string> retval;
1404f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower  for (const NodeDef& node_def : graph_def.node()) {
1405f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower    if (op_types.count(node_def.op()) > 0) {
1406f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower      retval.emplace(node_def.name());
1407f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower    }
1408f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower  }
1409f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower  return retval;
1410f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower}
1411f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower
1412d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower/* static */ std::unordered_set<string>
1413d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
1414d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower    const GraphDef& graph_def,
1415d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower    const IRemoteFusedGraphOpsDefinitions& ops_definitions) {
1416d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower  std::unordered_set<string> retval;
1417d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower  for (const NodeDef& node_def : graph_def.node()) {
1418d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower    std::vector<DataType> dt_vec;
1419d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower    std::vector<TensorShape> shape_vec;
1420d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower    const Status status =
1421d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower        GetOutputTensorShapeType(node_def, &dt_vec, &shape_vec);
1422d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower    if (!status.ok()) {
1423d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower      shape_vec.clear();
1424d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower    }
1425d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower    if (ops_definitions.GetOpIdFor(
1426d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower            node_def.op(), DataTypeVector(dt_vec.begin(), dt_vec.end())) !=
1427d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower        IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) {
1428d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower      retval.emplace(node_def.name());
1429d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower    }
1430d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower  }
1431d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower  return retval;
1432d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower}
1433d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower
1434bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::ReplaceInputNodeByPlaceHolder(
1435bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    const string& input, const DataType type, const TensorShape& shape,
1436bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    GraphDef* graph_def) {
1437bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  const TensorId tid = ParseTensorName(input);
1438bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  CHECK_EQ(0, tid.second);
1439bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  const string node_name = tid.first.ToString();
1440bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  for (NodeDef& node : *graph_def->mutable_node()) {
1441bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    if (node.name() != node_name) {
1442bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      continue;
1443bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    }
1444bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    if (node.op() == "Placeholder") {
1445bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      return Status::OK();
1446bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    } else {
1447bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      NodeDef placeholder_node;
1448bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      placeholder_node.set_op("Placeholder");
1449bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      placeholder_node.set_name(node_name);
1450bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      AddNodeAttr("dtype", type, &placeholder_node);
1451bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      AddNodeAttr("shape", shape, &placeholder_node);
1452bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      // TODO(satok): Remove once we merge attributes
1453bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      AddOutputTensorShapeType({type}, {shape}, &placeholder_node);
1454bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      node.Clear();
1455bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      node = placeholder_node;
1456bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      return Status::OK();
1457bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower    }
1458bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  }
1459bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower  return errors::InvalidArgument(
1460bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower      strings::StrCat(node_name, " not found for replacement."));
1461bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower}
1462bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower
14631c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
14641c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port,
14651c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    const int index, const string& executor_name, const string& node_name) {
14661c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  return strings::StrCat(static_cast<int>(node_type), ",", port, ",", index,
14671c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower                         ",", executor_name, ",", node_name);
14681c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower}
14691c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower
14701c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
14711c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port,
14721c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    const int index) {
14731c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  return strings::StrCat(static_cast<int>(node_type), ",", port, ",", index);
14741c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower}
14751c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower
14761c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
14771c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower    const RemoteFusedGraphExecuteInfo::NodeType node_type) {
14781c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower  return strings::StrCat(static_cast<int>(node_type));
14791c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower}
14801c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower
1481227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower}  // namespace tensorflow
1482