1227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower 3227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerLicensed under the Apache License, Version 2.0 (the "License"); 4227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFloweryou may not use this file except in compliance with the License. 5227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerYou may obtain a copy of the License at 6227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower 7227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower http://www.apache.org/licenses/LICENSE-2.0 8227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower 9227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerUnless required by applicable law or agreed to in writing, software 10227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerdistributed under the License is distributed on an "AS IS" BASIS, 11227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerSee the License for the specific language governing permissions and 13227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerlimitations under the License. 14227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower==============================================================================*/ 15227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower 16227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h" 17227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower 187b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower#include <algorithm> 19f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower#include <queue> 20227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower#include <utility> 21227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower 227b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower#include "tensorflow/core/common_runtime/shape_refiner.h" 23f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower#include "tensorflow/core/framework/node_def_util.h" 24e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving#include "tensorflow/core/framework/tensor.pb.h" 25e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving#include "tensorflow/core/framework/tensor_shape.pb.h" 267b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower#include "tensorflow/core/graph/algorithm.h" 27bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower#include "tensorflow/core/graph/node_builder.h" 283b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower#include "tensorflow/core/public/session.h" 293b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower#include "tensorflow/core/public/session_options.h" 303b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower 31227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowernamespace tensorflow { 32bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlowernamespace { 33bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlowerconst Node* FindNodeByName(const string& name, const Graph& graph) { 34bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const Node* node : graph.nodes()) { 35bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_NOTNULL(node); 36bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (node->name() == name) { 37bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return node; 38bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 39bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 40bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return nullptr; 41bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower} 42bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 43bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlowerstd::unordered_set<string> BuildNodeSetFromNodeNamesAndPorts( 44bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const std::vector<string>& node_names_and_ports) { 45bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::unordered_set<string> retval; 46bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const string& node_name_and_port : node_names_and_ports) { 47bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const TensorId tid = ParseTensorName(node_name_and_port); 48bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower retval.emplace(tid.first.ToString()); 49bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 50bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return retval; 51bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower} 52bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 53bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlowerNode* FindMutableNodeByName(const string& name, Graph* graph) { 54bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (Node* node : graph->nodes()) { 55bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (node != nullptr && node->name() == name) { 56bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return node; 57bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 58bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 59bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return nullptr; 60bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower} 61bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 62bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlowerconst NodeDef* FindNodeDefByName(const string& input, 63bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const GraphDef& graph_def) { 64bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const TensorId tid = ParseTensorName(input); 65bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string name = tid.first.ToString(); 66bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const NodeDef& node_def : graph_def.node()) { 67bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (node_def.name() == name) { 68bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return &node_def; 69bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 70bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 71bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return nullptr; 72bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower} 73bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 741c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlowerbool IsSameNodeName(const NodeDef& node_def, const string& node_name_and_port, 751c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower TensorId* tid) { 761c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK_NOTNULL(tid); 771c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower *tid = ParseTensorName(node_name_and_port); 781c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (node_def.name() == tid->first.ToString()) { 791c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower return true; 801c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 811c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower return false; 821c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower} 831c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower 841c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlowerbool ContainsSameTensorId(const string& tensor_name, 851c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const std::vector<string>& tensor_names) { 861c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const TensorId tid0 = ParseTensorName(tensor_name); 871c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower for (const string& name : tensor_names) { 881c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const TensorId tid1 = ParseTensorName(name); 891c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (tid0.first == tid1.first && tid0.second == tid1.second) { 901c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower return true; 911c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 921c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 931c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower return false; 941c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower} 951c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower 961c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlowervoid AppendDeliminator(string* str) { 971c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK_NOTNULL(str); 981c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (!str->empty()) { 991c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower *str += ":"; 1001c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 1011c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower} 1021c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower 1031c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlowervoid ConvertMapToVector(const std::unordered_map<int, string>& in, 1041c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower std::vector<string>* out) { 1051c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK_NOTNULL(out); 1061c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower out->resize(in.size()); 10790d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai for (size_t i = 0; i < in.size(); ++i) { 1081c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK(in.count(i) > 0); 1091c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower out->at(i) = in.at(i); 1101c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 1111c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower} 1121c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower 113bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlowerstring DumpGraphDef(const GraphDef& graph_def) { 114bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower string out; 115bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const NodeDef& node : graph_def.node()) { 116bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower out += strings::StrCat("node: ", node.name(), "\n input: "); 117bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const string& input : node.input()) { 118bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower out += strings::StrCat(input, ", "); 119bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 120bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower out += "\n"; 121bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 122bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return out; 123bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower} 124bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 125bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlowerstring DumpCluster(const RemoteFusedGraphExecuteUtils::ClusterInfo& cluster) { 126bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower string out; 127bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower out += "Nodes:\n"; 128bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const string& str : std::get<0>(cluster)) { 129bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower out += str + ", "; 130bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 131bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower out += "\nInput border:\n"; 132bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const string& str : std::get<1>(cluster)) { 133bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower out += str + ", "; 134bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 135bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower out += "\nOutput border:\n"; 136bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const string& str : std::get<2>(cluster)) { 137bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower out += str + ", "; 138bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 139bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return out; 140bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower} 141bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 142bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower} // namespace 143227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower 144f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower/* static */ constexpr const char* const 145f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES; 146f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower/* static */ constexpr const char* const 147f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES; 1489ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils:: 1499ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower ATTR_SERIALIZED_REMOTE_FUSED_GRAPH_EXECUTE_INFO; 1501c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower/* static */ constexpr const char* const 1511c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower RemoteFusedGraphExecuteUtils::ATTR_NODE_TYPE; 1529ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower/* static */ constexpr const char* const RemoteFusedGraphExecuteUtils:: 1539ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower TRANSFORM_ARG_REMOTE_FUSED_GRAPH_EXECUTOR_NAME; 1549ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower/* static */ constexpr const char* const 1559ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_REMOTE_FUSED_GRAPH_NODE_NAME; 1569ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower/* static */ constexpr const char* const 1579ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_NODES; 1589ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower/* static */ constexpr const char* const 1599ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_INPUTS; 1609ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower/* static */ constexpr const char* const 1619ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_BORDER_OUTPUTS; 1629ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower/* static */ constexpr const char* const 163f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSED_OP_TYPES; 164f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower/* static */ constexpr const char* const 165d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_FUSE_BY_EXECUTOR; 166d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower/* static */ constexpr const char* const 1679ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_TYPES; 1689ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower/* static */ constexpr const char* const 1699ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower RemoteFusedGraphExecuteUtils::TRANSFORM_ARG_INPUT_SHAPES; 170f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower 171227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar::ExecutorBuildRegistrar( 172227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower const string& name, ExecutorBuildFunc executor_build_func) { 173227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower ExecutorBuildRegistry& executor_build_registry = *GetExecutorBuildRegistry(); 174227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower executor_build_registry[name] = std::move(executor_build_func); 175227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower} 176227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower 177227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower/* static */ const RemoteFusedGraphExecuteUtils::ExecutorBuildFunc* 178227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::GetExecutorBuildFunc(const string& name) { 179227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower ExecutorBuildRegistry& executor_build_registry = *GetExecutorBuildRegistry(); 180227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower if (executor_build_registry.count(name) <= 0) { 181227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower return nullptr; 182227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower } 183227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower return &executor_build_registry.at(name); 184227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower} 185227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower 186227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower/* static */ RemoteFusedGraphExecuteUtils::ExecutorBuildRegistry* 187227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::GetExecutorBuildRegistry() { 188227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower static ExecutorBuildRegistry executor_builder_registry; 189227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower return &executor_builder_registry; 190227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower} 191227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower 1923b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower/** 1933b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower * - DryRunInference 1943b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower * To determine shapes of output tensors of all nodes, dryrun the graph. 1953b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower * This function supplies memory allocation information when loading 1963b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower * the graph. This function is used to verify shape inference and actual 1973b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower * output shape. 1983b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower */ 1993b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::DryRunInference( 2003b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const GraphDef& graph_def, 2013b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const std::vector<std::pair<string, Tensor>>& input_node_info_list, 2023b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const std::vector<string>& output_node_names, const bool initialize_by_zero, 2033b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower std::vector<tensorflow::Tensor>* output_tensors) { 2043b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower // Create input tensor vector. If "initialize_by_zero" is true, 2053b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower // input tensor fields are initialized by 0. 2063b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower std::vector<std::pair<string, tensorflow::Tensor>> input_tensors; 2073b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower for (const std::pair<string, Tensor>& input : input_node_info_list) { 2083b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower CHECK(input.second.IsInitialized()); 2093b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower if (!initialize_by_zero) { 2103b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower input_tensors.push_back({input.first, input.second}); 2113b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower continue; 2123b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower } 2133b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower // If input tensor is not initialized, initialize by 0-filling 2143b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const DataType data_type = input.second.dtype(); 2153b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const TensorShape& shape = input.second.shape(); 2163b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower Tensor input_tensor(data_type, shape); 2173b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower switch (data_type) { 2183b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower case DT_INT32: { 2193b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower auto int_tensor = input_tensor.flat<int32>(); 2203b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower int_tensor = int_tensor.constant(0); 2213b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower break; 2223b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower } 2233b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower case DT_FLOAT: { 2243b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower auto float_tensor = input_tensor.flat<float>(); 2253b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower float_tensor = float_tensor.constant(0.0f); 2263b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower break; 2273b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower } 2283b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower case DT_QUINT8: { 2293b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower auto int_tensor = input_tensor.flat<quint8>(); 2303b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower int_tensor = int_tensor.constant(0); 2313b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower break; 2323b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower } 2333b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower default: 2343b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower LOG(FATAL) << "Unsupported input type: " << data_type; 2353b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower } 2363b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower input_tensors.push_back({input.first, input_tensor}); 2373b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower } 2383b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower 2393b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower // Setup session 2403b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower CHECK(output_tensors != nullptr); 2413b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower SessionOptions session_options; 2423b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower session_options.env = Env::Default(); 2433b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower std::unique_ptr<Session> session = 2443b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower std::unique_ptr<Session>(NewSession(session_options)); 2453b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower Status status = session->Create(graph_def); 2463b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower if (!status.ok()) { 2473b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower return status; 2483b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower } 2493b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower 2503b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower // Setup session arguments 2513b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower RunOptions run_options; 2523b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower run_options.set_trace_level(RunOptions::FULL_TRACE); 2533b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower RunMetadata run_metadata; 2543b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower 2553b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower // Run inference with all node as output 2563b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower status = session->Run(run_options, input_tensors, output_node_names, {}, 2573b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower output_tensors, &run_metadata); 2583b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower if (!status.ok()) { 2593b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower LOG(ERROR) << "Error during inference: " << status; 2603b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower return status; 2613b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower } 2623b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower return Status(); 2633b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower} 2643b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower 2653b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode( 2663b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const GraphDef& graph_def, 2673b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const std::vector<std::pair<string, Tensor>>& input_node_info_list, 2683b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const bool initialize_by_zero, 2693b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower RemoteFusedGraphExecuteUtils::TensorShapeMap* tensor_shape_map) { 2703b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower CHECK(tensor_shape_map != nullptr); 2713b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower std::vector<Tensor> output_tensors; 2723b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower output_tensors.reserve(graph_def.node_size()); 2733b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower std::vector<string> output_node_names; 2740a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower 2750a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower Graph graph(OpRegistry::Global()); 2760a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower Status status = ImportGraphDef({}, graph_def, &graph, nullptr); 2770a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower if (!status.ok()) { 2780a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower return status; 2790a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower } 2800a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower 2810a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower for (const Node* node : graph.nodes()) { 2820a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower if (IsInputNode(input_node_info_list, node->name())) { 2830a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower continue; 2840a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower } 2850a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower for (int i = 0; i < node->num_outputs(); ++i) { 2860a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower output_node_names.emplace_back(strings::StrCat(node->name(), ":", i)); 2873b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower } 2883b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower } 2890a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower 2900a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower status = DryRunInference(graph_def, input_node_info_list, output_node_names, 2910a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower initialize_by_zero, &output_tensors); 2923b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower if (!status.ok()) { 2933b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower VLOG(1) << "Failed to dryrun " << status; 2943b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower return status; 2953b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower } 2960a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower 2973b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower CHECK_EQ(output_node_names.size(), output_tensors.size()) 2983b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower << output_node_names.size() << ", " << output_tensors.size(); 2993b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower 3003b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower // Append output tensor of input node in advance to create a map 3013b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower // to avoid memory reallocation inside vector 3023b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower for (const std::pair<string, Tensor>& input_node_info : 3033b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower input_node_info_list) { 3043b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower output_tensors.push_back(input_node_info.second); 3053b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower } 3063b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower 307e18c7a5332b97ddcaad7f5485c46ad8dc8fcd274Suharsh Sivakumar for (int i = 0; static_cast<size_t>(i) < output_node_names.size(); ++i) { 3083b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const string& name = output_node_names.at(i); 3093b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const Tensor& tensor = output_tensors.at(i); 310f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower EmplaceTensorShapeType(name, tensor, tensor_shape_map); 3113b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower } 312e18c7a5332b97ddcaad7f5485c46ad8dc8fcd274Suharsh Sivakumar for (int i = 0; static_cast<size_t>(i) < input_node_info_list.size(); ++i) { 3133b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const string& name = input_node_info_list.at(i).first; 3143b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const Tensor& tensor = output_tensors.at(output_node_names.size() + i); 315f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower EmplaceTensorShapeType(name, tensor, tensor_shape_map); 3163b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower } 3170a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower CHECK_EQ(output_node_names.size() + input_node_info_list.size(), 3180a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower output_tensors.size()); 3193b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower return status; 3203b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower} 3213b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower 3223b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower/* static */ bool RemoteFusedGraphExecuteUtils::IsInputNode( 3233b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const std::vector<std::pair<string, Tensor>>& input_tensor_vector, 3243b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const string& node_name) { 3253b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower for (const std::pair<string, Tensor>& pair : input_tensor_vector) { 326a11e669c5a4c855b0b507da97378bc7e03a08f86A. Unique TensorFlower const TensorId tid = ParseTensorName(pair.first); 327a11e669c5a4c855b0b507da97378bc7e03a08f86A. Unique TensorFlower if (node_name == tid.first.ToString()) { 3283b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower return true; 3293b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower } 3303b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower } 3313b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower return false; 3323b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower} 3333b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower 3343b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower/* static */ void RemoteFusedGraphExecuteUtils::ConvertToTensorShapeMap( 3353b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const std::vector<std::pair<string, Tensor>>& input_node_info_list, 3363b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const std::vector<string>& output_node_names, 3373b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const std::vector<tensorflow::Tensor>& output_tensors, 3383b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower TensorShapeMap* tensor_shape_map) { 3393b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower CHECK_NE(tensor_shape_map, nullptr); 3403b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower tensor_shape_map->clear(); 3413b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower tensor_shape_map->reserve(input_node_info_list.size() + 3423b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower output_node_names.size()); 3433b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const int output_node_count = output_node_names.size(); 3443b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower CHECK_EQ(output_node_count, output_tensors.size()); 3453b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower for (int i = 0; i < output_node_count; ++i) { 3463b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const string& node_name = output_node_names.at(i); 3473b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower const Tensor& tensor = output_tensors.at(i); 348f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower EmplaceTensorShapeType(node_name, tensor, tensor_shape_map); 3493b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower } 3503b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower} 3513b596922684f6ad444da14075fa94b60618fbb35A. Unique TensorFlower 352f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::MakeTensorFromProto( 353f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower const TensorProto& tensor_proto, Tensor* tensor) { 354f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower if (tensor_proto.dtype() > 0 && tensor_proto.dtype() <= DataType_MAX) { 355f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower Tensor parsed(tensor_proto.dtype()); 356f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower if (parsed.FromProto(cpu_allocator(), tensor_proto)) { 357f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower *tensor = parsed; 358f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower return Status::OK(); 359f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower } 360f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower } 361f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower return errors::InvalidArgument("Cannot parse tensor from proto"); 362f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower} 363f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower 364f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower/* static */ bool RemoteFusedGraphExecuteUtils::AddOutputTensorShapeType( 365f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower const std::vector<DataType>& data_types, 366f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower const std::vector<TensorShape>& shapes, NodeDef* node_def) { 367f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower AddNodeAttr(ATTR_OUTPUT_DATA_TYPES, data_types, node_def); 368f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower AddNodeAttr(ATTR_OUTPUT_SHAPES, shapes, node_def); 369f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower return true; 370f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower} 371f51d8c75687060313f2a06a7d6713d6e4326f37fA. Unique TensorFlower 372f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower/* static */ Status 373f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap( 374f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower const TensorShapeMap& tensor_shape_map, NodeDef* node_def) { 375f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower CHECK_NE(node_def, nullptr); 376f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower std::priority_queue<std::tuple<int, const TensorShapeType*>> queue; 377f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower auto its = tensor_shape_map.equal_range(node_def->name()); 378f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower for (auto it = its.first; it != its.second; ++it) { 379f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower queue.emplace(std::make_tuple(it->second.first, &it->second.second)); 380f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower } 381f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower int last_port = queue.size(); 382f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower std::vector<DataType> data_types; 383f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower std::vector<TensorShape> shapes; 384f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower while (!queue.empty()) { 385f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower const int port = std::get<0>(queue.top()); 386f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower const TensorShapeType* tst = std::get<1>(queue.top()); 387f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower CHECK_NE(tst, nullptr); 388f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower data_types.emplace(data_types.begin(), tst->first); 389f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower shapes.emplace(shapes.begin(), tst->second); 390f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower CHECK_EQ(last_port - 1, port); 391f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower last_port = port; 392f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower queue.pop(); 393f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower } 394f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower AddOutputTensorShapeType(data_types, shapes, node_def); 395f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower return Status::OK(); 396f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower} 397f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower 3980a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType( 39973882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving AttrSlice attrs, std::vector<DataType>* data_types, 4000a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower std::vector<TensorShape>* shapes) { 4010a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower Status status; 4020a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower if (data_types != nullptr) { 40373882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving status = GetNodeAttr(attrs, ATTR_OUTPUT_DATA_TYPES, data_types); 4040a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower } 4050a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower if (!status.ok()) { 4060a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower return status; 4070a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower } 4080a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower if (shapes != nullptr) { 40973882f257ffb1bc9e1a828571c085d080b1d9266Geoffrey Irving status = GetNodeAttr(attrs, ATTR_OUTPUT_SHAPES, shapes); 4100a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower if (status.ok() && data_types != nullptr) { 4110a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower CHECK_EQ(data_types->size(), shapes->size()); 4120a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower } 4130a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower } 4140a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower 4150a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower return status; 4160a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower} 4170a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower 418bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ bool RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType( 419bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const GraphDef& graph_def, const string& name_and_port, DataType* data_type, 420bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TensorShape* shape) { 421bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::vector<DataType> data_types; 422bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::vector<TensorShape> shapes; 423bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const TensorId tid = ParseTensorName(name_and_port); 424bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string node_name = tid.first.ToString(); 425bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const int port = tid.second; 426bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const NodeDef* node_def = FindNodeDefByName(node_name, graph_def); 427bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_NOTNULL(node_def); 428bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower GetOutputTensorShapeType(*node_def, &data_types, &shapes).IgnoreError(); 429bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (data_types.empty()) { 430bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return false; 431bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 432bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK(data_types.size() > port); 433bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower *data_type = data_types.at(port); 434bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower *shape = shapes.at(port); 435bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return true; 436bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower} 437bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 4387b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::PropagateShapeInference( 4397b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower const GraphDef& graph_def, 4407b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower const std::vector<std::pair<string, Tensor>>& input_node_info_list, 4417b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower Graph* graph, ShapeRefiner* shape_refiner) { 4427b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower Status status; 4437b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower auto visit = [&shape_refiner, &input_node_info_list, &status](Node* node) { 4447b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower if (!status.ok()) { 4457b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower return; 4467b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower } 4477b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower CHECK_NE(node, nullptr); 4487b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower // If we visit an input node, we use the shape provided and set the 4497b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower // shape accordingly. 4507b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower bool is_input_node = false; 4517b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower for (const std::pair<string, Tensor>& input_node_info : 4527b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower input_node_info_list) { 4537b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower if (node->name() == input_node_info.first) { 4547b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower shape_inference::InferenceContext* context = 4557b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower shape_refiner->GetContext(node); 4567b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower shape_inference::ShapeHandle handle; 4577b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower status = context->MakeShapeFromTensorShape( 4587b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower input_node_info.second.shape(), &handle); 4590a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower if (!status.ok()) { 4600a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower break; 4610a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower } 4620a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower status = shape_refiner->SetShape(node, 0, handle); 4630a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower if (!status.ok()) { 4640a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower break; 4650a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower } 4667b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower is_input_node = true; 4677b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower } 4687b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower if (!status.ok()) { 4697b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower break; 4707b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower } 4717b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower } 4727b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower // If not an input node call AddNode() that recomputes the shape. 4737b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower if (!is_input_node && status.ok()) { 4747b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower status = shape_refiner->AddNode(node); 4750a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower } 4760a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower if (!status.ok()) { 4770a6bbfaf2f3c4c9184cae9c239b99b7b855638a4A. Unique TensorFlower VLOG(1) << "Shape inference failed for node: " << node->name(); 4787b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower } 4797b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower }; 4807b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower 4817b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower ReverseDFS(*graph, {}, visit); 4827b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower 4837b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower return status; 4847b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower} 4857b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower 4867b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::BuildTensorShapeMapFromGraph( 4877b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower const Graph& graph, const ShapeRefiner& shape_refiner, 4887b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower TensorShapeMap* tensor_shape_map) { 4897b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower for (int i = 0; i < graph.num_node_ids(); ++i) { 4907b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower const Node* node = graph.FindNodeId(i); 4917b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower CHECK_NE(node, nullptr); 4927b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower for (int j = 0; j < node->num_outputs(); ++j) { 4937b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower const int output_index = j; 4947b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower const DataType dt = node->output_type(output_index); 4957b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower shape_inference::InferenceContext* context = 4967b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower shape_refiner.GetContext(node); 4977b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower CHECK_NE(context, nullptr); 4987b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower shape_inference::ShapeHandle shape_handle = context->output(output_index); 4997b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower if (context->RankKnown(shape_handle)) { 5007b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower TensorShape ts; 5017b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower for (int k = 0; k < context->Rank(shape_handle); ++k) { 5027b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower shape_inference::DimensionHandle dh = context->Dim(shape_handle, k); 5037b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower CHECK(context->ValueKnown(dh)); 5047b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower ts.AddDim(context->Value(dh)); 5057b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower } 5067b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower const string& node_name = node->name(); 5077b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower CHECK(tensor_shape_map->count(node_name) == 0); 508f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower tensor_shape_map->emplace(node_name, 509f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower std::make_pair(j, std::make_pair(dt, ts))); 5107b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower } else { 5117b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower return errors::InvalidArgument("Graph contains unknow shapes"); 5127b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower } 5137b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower } 5147b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower } 5157b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower return Status::OK(); 5167b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower} 5177b8e31c58140fe6c6bdd3a0d946b978c2a216702A. Unique TensorFlower 518f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower/* static */ const RemoteFusedGraphExecuteUtils::TensorShapeType* 519f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::GetTensorShapeType( 520f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower const TensorShapeMap& tensor_shape_map, const string& node_name) { 521f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower if (node_name.find(':') != string::npos) { 522f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower const TensorId tid = ParseTensorName(node_name); 523f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower return GetTensorShapeType(tensor_shape_map, tid.first.ToString(), 524f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower tid.second); 525f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower } else { 526f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower return GetTensorShapeType(tensor_shape_map, node_name, 0); 527f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower } 528f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower} 529f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower 530f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower/* static */ const RemoteFusedGraphExecuteUtils::TensorShapeType* 531f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::GetTensorShapeType( 532f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower const TensorShapeMap& tensor_shape_map, const string& node_name, 533f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower const int port) { 534f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower CHECK_EQ(node_name.find(':'), string::npos); 535f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower if (tensor_shape_map.count(node_name) <= 0) { 536f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower return nullptr; 537f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower } 538f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower auto its = tensor_shape_map.equal_range(node_name); 539f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower for (auto it = its.first; it != its.second; ++it) { 540f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower if (it->second.first == port) { 541f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower return &it->second.second; 542f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower } 543f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower } 544f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower return nullptr; 545f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower} 546f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower 54739af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower/* static */ void 54839af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto( 54939af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower const RemoteFusedGraphExecuteInfo& proto, 55039af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower std::vector<std::pair<string, Tensor>>* inputs, 55139af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower std::vector<string>* outputs) { 55239af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower CHECK_EQ(proto.graph_input_node_name_size(), 55339af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower proto.default_graph_input_tensor_shape_size()); 55439af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower for (int i = 0; i < proto.graph_input_node_name_size(); ++i) { 55539af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower inputs->emplace_back( 55639af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower proto.graph_input_node_name(i), 55739af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower Tensor(proto.default_graph_input_tensor_shape(i).dtype(), 55839af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower TensorShape(proto.default_graph_input_tensor_shape(i).shape()))); 55939af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower } 56039af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower for (const string& output_node_name : proto.graph_output_node_name()) { 56139af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower outputs->emplace_back(output_node_name); 56239af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower } 56339af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower} 56439af8956057f7f44f3d2fea0845c69225cabc195A. Unique TensorFlower 565f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower/* static */ void RemoteFusedGraphExecuteUtils::EmplaceTensorShapeType( 566f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower const string& name, const Tensor& tensor, 567f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower TensorShapeMap* tensor_shape_map) { 568f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower const TensorId tid = ParseTensorName(name); 569f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower CHECK_EQ(tensor_shape_map->count(name), 0); 570f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower tensor_shape_map->emplace( 571f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower tid.first.ToString(), 572f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower std::make_pair(tid.second, 573f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower std::make_pair(tensor.dtype(), tensor.shape()))); 574f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower} 575f1f8e72b10645910f7002068d4f8c4edb7a90d96A. Unique TensorFlower 576bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::BuildAndAddTensorShapes( 577bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const std::vector<std::pair<string, Tensor>>& input_tensors, 578bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const bool dry_run_inference, GraphDef* graph_def) { 579bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TensorShapeMap tensor_shape_map; 580bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (dry_run_inference) { 581bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR(DryRunInferenceForAllNode(*graph_def, input_tensors, 582bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower /*initialize_by_zero=*/true, 583bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower &tensor_shape_map)); 584bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } else { 585bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower ImportGraphDefOptions opts; 586bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower Graph graph(OpRegistry::Global()); 587e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); 588bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR( 589bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower ImportGraphDef(opts, *graph_def, &graph, &shape_refiner)); 590bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR(PropagateShapeInference(*graph_def, input_tensors, 591bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower &graph, &shape_refiner)); 592bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR( 593bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower BuildTensorShapeMapFromGraph(graph, shape_refiner, &tensor_shape_map)); 594bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 595bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 596bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (NodeDef& node_def : *graph_def->mutable_node()) { 597bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR( 598bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map, &node_def)); 599bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 600bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 601bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return Status::OK(); 602bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower} 603bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 604bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status 605bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo( 606bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string& executor_name, const GraphDef& subgraph_def, 607bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const std::vector<string>& inputs, const std::vector<string>& outputs, 608bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const bool require_shape_type, RemoteFusedGraphExecuteInfo* execute_info, 609bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower DataTypeVector* input_types, DataTypeVector* output_types) { 610bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_NOTNULL(execute_info); 611bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_NOTNULL(input_types); 612bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_NOTNULL(output_types); 613bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 614bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower execute_info->Clear(); 615bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower execute_info->set_executor_name(executor_name); 616bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 617bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower // copy graph 618bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower *execute_info->mutable_remote_graph() = subgraph_def; 619bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 620bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const string& input : inputs) { 621bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower DataType dt; 622bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TensorShape shape; 623bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const bool has_shapetype = 624bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower GetOutputTensorShapeType(subgraph_def, input, &dt, &shape); 625bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 626bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower execute_info->add_graph_input_node_name(input); 627bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (has_shapetype) { 628bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type = 629bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower *execute_info->add_default_graph_input_tensor_shape(); 630bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower tensor_shape_type.set_dtype(dt); 631bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TensorShapeProto& tensor_shape_proto = *tensor_shape_type.mutable_shape(); 632bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const int64 dim : shape.dim_sizes()) { 633bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower tensor_shape_proto.add_dim()->set_size(dim); 634bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 635bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower input_types->push_back(dt); 636bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } else { 637bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK(!require_shape_type) 638bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower << "No shape type found for " << input << DumpGraphDef(subgraph_def); 639bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower // Assuming input type is float if no data provided. 640bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower input_types->push_back(DT_FLOAT); 641bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 642bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 643bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 644bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const string& output : outputs) { 645bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower DataType dt; 646bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TensorShape shape; 647bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const bool has_shapetype = 648bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower GetOutputTensorShapeType(subgraph_def, output, &dt, &shape); 649bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 650bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower execute_info->add_graph_output_node_name(output); 651bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (has_shapetype) { 652bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& 653bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower tensor_shape_type_proto = 654bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower *execute_info->add_default_graph_output_tensor_shape(); 655bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower tensor_shape_type_proto.set_dtype(dt); 656bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TensorShapeProto& tensor_shape_proto = 657bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower *tensor_shape_type_proto.mutable_shape(); 658bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const int64 dim : shape.dim_sizes()) { 659bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower tensor_shape_proto.add_dim()->set_size(dim); 660bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 661bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower output_types->push_back(dt); 662bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } else { 663bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK(!require_shape_type) 664bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower << "No shape type found for " << output << DumpGraphDef(subgraph_def); 665bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower // Assuming output type is float if no data provided. 666bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower output_types->push_back(DT_FLOAT); 667bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 668bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 669bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 670bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return Status::OK(); 671bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower} 672bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 673bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status 674bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( 675bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string& node_name, const string& executor_name, 676bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const GraphDef& subgraph_def, const std::vector<string>& inputs, 677bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const std::vector<string>& outputs, const bool require_shape_type, 678bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower Graph* graph, Node** created_node) { 679bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_NOTNULL(graph); 680bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_NOTNULL(created_node); 681bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 682bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower RemoteFusedGraphExecuteInfo execute_info; 683bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower DataTypeVector input_types; 684bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower DataTypeVector output_types; 685bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 686bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_CHECK_OK(RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteInfo( 687bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower executor_name, subgraph_def, inputs, outputs, require_shape_type, 688bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower &execute_info, &input_types, &output_types)); 689bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 690bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::vector<NodeBuilder::NodeOut> node_out_list; 691bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const string& input : inputs) { 692bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const TensorId tid = ParseTensorName(input); 693bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower Node* node = FindMutableNodeByName(tid.first.ToString(), graph); 694bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_NOTNULL(node); 695bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower node_out_list.emplace_back(node, tid.second); 696bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 697bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 698bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string execute_info_str = execute_info.SerializeAsString(); 699bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 700bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower auto builder = 701bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower NodeBuilder(node_name, "RemoteFusedGraphExecute") 702bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower .Input(node_out_list) 703bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower .Attr("Tinputs", input_types) 704bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower .Attr("Toutputs", output_types) 705bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower .Attr("serialized_remote_fused_graph_execute_info", execute_info_str); 706bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 707bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR(builder.Finalize(graph, created_node)); 708bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return Status::OK(); 709bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower} 710bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 711bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::BuildIdentityOpNode( 712bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string& node_name, const string& input_node_name, 713bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const int input_node_port, const DataType dt, Graph* graph, 714bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower Node** created_node) { 715bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower Node* node = FindMutableNodeByName(input_node_name, graph); 716bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_NOTNULL(node); 717bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower NodeBuilder::NodeOut node_out(node, input_node_port); 718bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 719bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower auto builder = 720bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower NodeBuilder(node_name, "Identity").Input(node_out).Attr("T", dt); 721bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 722bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR(builder.Finalize(graph, created_node)); 723bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return Status::OK(); 724bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower} 725bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 726bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::ClusterizeNodes( 727bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const std::unordered_set<string>& node_names, const GraphDef& graph_def, 728bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::vector<ClusterInfo>* cluster_infos) { 729bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower Graph graph(OpRegistry::Global()); 730e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); 731bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner)); 732bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::unordered_set<string> remaining_nodes = node_names; 733bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 734bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower while (!remaining_nodes.empty()) { 735bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower ClusterInfo ci; 736bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 737bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower // Determine one cluster nodes 738bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::unordered_set<const Node*> visited; 739bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::deque<const Node*> queue; 740bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower queue.emplace_back(FindNodeByName(*remaining_nodes.begin(), graph)); 741bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower while (!queue.empty()) { 742bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const Node* node = queue.front(); 743bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_NOTNULL(node); 744bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower queue.pop_front(); 745bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string& node_name = node->name(); 746bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (node_names.count(node_name) > 0) { 747bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::get<0>(ci).emplace(node_name); 748bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower remaining_nodes.erase(node_name); 749bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } else { 750bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower // Edge of subgraph. Do nothing. 751bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower continue; 752bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 753bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const Node* in : node->in_nodes()) { 754bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (visited.insert(in).second) { 755bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower queue.push_back(in); 756bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 757bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 758bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const Node* out : node->out_nodes()) { 759bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (visited.insert(out).second) { 760bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower queue.push_back(out); 761bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 762bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 763bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 764bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 765bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower // Determine one cluster border 766bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::vector<string>& border_inputs = std::get<1>(ci); 767bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::vector<string>& border_outputs = std::get<2>(ci); 768bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const string& node_name : node_names) { 769bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower Node* node = FindMutableNodeByName(node_name, &graph); 770bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_NOTNULL(node); 771bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower int input_count = 0; 772bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const Edge* in_edge : node->in_edges()) { 773bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const Node* src_node = in_edge->src(); 774bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const bool src_is_outside = 775bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower node_names.count(src_node->name()) <= 0 && !src_node->IsSource(); 776bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (src_is_outside) { 777bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string src_name = 778bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower strings::StrCat(src_node->name(), ":", in_edge->src_output()); 779bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_EQ(1, src_node->num_outputs()) 780bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower << "output count of input border node must be one." 781bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower << src_node->name(); 782bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (std::find(border_inputs.begin(), border_inputs.end(), src_name) == 783bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower border_inputs.end()) { 784bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower border_inputs.emplace_back(src_name); 785bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 786bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } else { 787bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower ++input_count; 788bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 789bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 790f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower CHECK(input_count == 0 || input_count == node->in_edges().size()) 791f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower << "Invalid input_count(" << input_count << ", " 792f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower << node->in_edges().size() << ") " << node_name; 793bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 794bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const Edge* out_edge : node->out_edges()) { 795bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const Node* dst_node = out_edge->dst(); 796bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_NOTNULL(dst_node); 797bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const bool dst_is_outside = node_names.count(dst_node->name()) <= 0; 798bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string dst_name = 799bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower strings::StrCat(node->name(), ":", out_edge->src_output()); 800bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (dst_is_outside) { 801bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (dst_node->IsSink()) { 802bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_EQ(1, node->num_outputs()) 803bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower << "If you want to specify output node as subgraph output node " 804bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower << "the output count of the node must be 1 " 805bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower << "because that node is replaced by identity node."; 806bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string identity_dst_name = 807bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower strings::StrCat(node->name(), ":", 0); 808bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (std::find(border_outputs.begin(), border_outputs.end(), 809bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower identity_dst_name) == border_outputs.end()) { 810bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower border_outputs.emplace_back(identity_dst_name); 811bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 812bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } else { 813bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (std::find(border_outputs.begin(), border_outputs.end(), 814bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower dst_name) == border_outputs.end()) { 815bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower border_outputs.emplace_back(dst_name); 816bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 817bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 818bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 819bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 820bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 821bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower cluster_infos->emplace_back(ci); 822bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower VLOG(1) << DumpCluster(ci); 823bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 824bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return Status::OK(); 825bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower} 826bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 827bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::BuildClusterSubgraphDef( 828bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const ClusterInfo& cluster, const GraphDef& graph_def, 829bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower GraphDef* subgraph_def) { 830bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const std::unordered_set<string>& node_names = std::get<0>(cluster); 831bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const std::unordered_set<string>& border_input_names = 832bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower BuildNodeSetFromNodeNamesAndPorts(std::get<1>(cluster)); 833bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 834bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower Graph graph(OpRegistry::Global()); 835e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); 836bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner)); 837bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 838bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (Node* node : graph.nodes()) { 839bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (node != nullptr && node_names.count(node->name()) <= 0 && 840bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower border_input_names.count(node->name()) <= 0 && !node->IsSource() && 841bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower !node->IsSink()) { 842bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower graph.RemoveNode(node); 843bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 844bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 845bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower graph.ToGraphDef(subgraph_def); 846bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 847bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const string& subgraph_input : std::get<1>(cluster)) { 848bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const TensorId tid = ParseTensorName(subgraph_input); 849bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string subgraph_input_name = tid.first.ToString(); 850bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const int subgraph_input_port = tid.second; 851bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const NodeDef* node_def = FindNodeDefByName(subgraph_input_name, graph_def); 852bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_NOTNULL(node_def); 853bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::vector<DataType> dt_vec; 854bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::vector<TensorShape> shape_vec; 855bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower GetOutputTensorShapeType(*node_def, &dt_vec, &shape_vec).IgnoreError(); 856bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const DataType& dt = 857bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower dt_vec.empty() ? DT_FLOAT : dt_vec.at(subgraph_input_port); 858bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const TensorShape& shape = 859bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower shape_vec.empty() ? TensorShape({}) : shape_vec.at(subgraph_input_port); 860bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 861bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR(ReplaceInputNodeByPlaceHolder(subgraph_input_name, dt, 862bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower shape, subgraph_def)); 863bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 8649ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower 8659ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower // sort subgraph_def to align order in graph_def 8669ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower std::unordered_map<string, int> name_to_id_map; 8679ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower for (int i = 0; i < graph_def.node_size(); ++i) { 8689ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower name_to_id_map.emplace(graph_def.node(i).name(), i); 8699ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower } 8709ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower std::sort(subgraph_def->mutable_node()->begin(), 8719ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower subgraph_def->mutable_node()->end(), 8729ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower [&name_to_id_map](const NodeDef& node0, const NodeDef& node1) { 8739ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower CHECK(name_to_id_map.count(node0.name()) > 0); 8749ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower CHECK(name_to_id_map.count(node1.name()) > 0); 8759ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower const int id0 = name_to_id_map.at(node0.name()); 8769ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower const int id1 = name_to_id_map.at(node1.name()); 8779ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower return id0 < id1; 8789ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower }); 8799ad851e54d014532dd3b3c8308396769f9a7aeeeA. Unique TensorFlower 880bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower VLOG(1) << DumpGraphDef(*subgraph_def); 881bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return Status::OK(); 882bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower} 883bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 884bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::BuildClusterByBorder( 885bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const std::vector<string>& border_inputs, 886bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const std::vector<string>& border_outputs, const GraphDef& graph_def, 887bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower ClusterInfo* cluster) { 888bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower Graph graph(OpRegistry::Global()); 889e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); 890bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, &shape_refiner)); 891bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 892bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::unordered_set<const Node*> visited; 893bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::deque<const Node*> queue; 894bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const string& output : border_outputs) { 895bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const TensorId tid = ParseTensorName(output); 896bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string& output_node_name = tid.first.ToString(); 897bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const Node* node : graph.nodes()) { 898bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (output_node_name == node->name()) { 899bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower queue.push_back(node); 900bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower visited.insert(node); 901bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 902bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 903bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 904bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 905bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::unordered_set<const Node*> border_input_nodes; 906bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower // propagate visit to parent nodes until input nodes 907bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower while (!queue.empty()) { 908bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const Node* node = queue.front(); 909bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower queue.pop_front(); 910bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const Edge* edge : node->in_edges()) { 911bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const Node* src_node = edge->src(); 912bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_NOTNULL(src_node); 913bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const int src_port = edge->src_output(); 914bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower bool input_found = false; 915bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const string& input : border_inputs) { 916bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const TensorId tid = ParseTensorName(input); 917bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (tid.first.ToString() == src_node->name() && 918bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower tid.second == src_port) { 919bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower input_found = true; 920bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower border_input_nodes.insert(src_node); 921bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 922bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 923bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (visited.insert(src_node).second) { 924bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (!input_found) { 925bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower queue.push_back(src_node); 926bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 927bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 928bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 929bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 930bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 931bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const Node* node : visited) { 932bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (node != nullptr && !node->IsSource() && !node->IsSink() && 933bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower border_input_nodes.count(node) <= 0) { 934bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::get<0>(*cluster).insert(node->name()); 935bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 936bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 937bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::get<1>(*cluster) = border_inputs; 938bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::get<2>(*cluster) = border_outputs; 939bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return Status::OK(); 940bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower} 941bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 942bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::FuseCluster( 943bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const GraphDef& input_graph_def, const std::vector<string>& inputs, 944bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const std::vector<string>& outputs, 945bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string& remote_fused_graph_node_name, const ClusterInfo& cluster, 946bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string& remote_graph_executor_name, const bool require_shape_type, 947bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower GraphDef* output_graph_def) { 948bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower LOG(INFO) << "Transforming quantized stripped model to a remote fused " 949bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower "graph execute op by fusing a specified subgraph..."; 950bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 951bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK(!remote_graph_executor_name.empty()); 952bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 953bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const std::vector<string>& border_inputs = std::get<1>(cluster); 954bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const std::vector<string>& border_outputs = std::get<2>(cluster); 955bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 956bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower GraphDef subgraph_def; 957bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR( 958bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower BuildClusterSubgraphDef(cluster, input_graph_def, &subgraph_def)); 959bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 960bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower Graph graph(OpRegistry::Global()); 961e85d3df92deb9d717befdf173966a2913ac2aea0Geoffrey Irving ShapeRefiner shape_refiner(graph.versions(), graph.op_registry()); 962bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR( 963bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower ImportGraphDef({}, input_graph_def, &graph, &shape_refiner)); 964bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 965bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower Node* fused_node; 966bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR(BuildRemoteFusedGraphExecuteOpNode( 967bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower remote_fused_graph_node_name, remote_graph_executor_name, subgraph_def, 968bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower border_inputs, border_outputs, require_shape_type, &graph, &fused_node)); 969bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 970bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const Node* node : graph.nodes()) { 971bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (int i = 0; i < node->num_inputs(); ++i) { 972bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const Edge* edge = nullptr; 973bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR(node->input_edge(i, &edge)); 974bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (int j = 0; j < border_outputs.size(); ++j) { 975bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string& output = border_outputs.at(j); 976bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const TensorId tid = ParseTensorName(output); 977bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string output_name = tid.first.ToString(); 978bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower Node* src_node = edge->src(); 979bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (src_node != nullptr && src_node->name() == output_name && 980bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower edge->src_output() == tid.second) { 981bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower // Source node is replaced by new fused node. 982bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower Node* dst_node = edge->dst(); 983bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const int dst_input = edge->dst_input(); 984bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower LOG(INFO) << "Removing existing edge to " << edge->dst()->name() 985bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower << " from " << edge->src()->name(); 986bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower graph.RemoveEdge(edge); 987bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower graph.AddEdge(fused_node, j, dst_node, dst_input); 988bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 989bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 990bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 991bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 992bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 993bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower // Replace output nodes by identity nodes which forward outputs from 994bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower // RemoteFusedGraphExecuteOpNode 995bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (const string& output : outputs) { 996bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const TensorId output_tid = ParseTensorName(output); 997bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string output_name = output_tid.first.ToString(); 99890d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai for (size_t i = 0; i < border_outputs.size(); ++i) { 999bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const TensorId subgraph_output_tid = 1000bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower ParseTensorName(border_outputs.at(i)); 1001bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string& subgraph_output_name = subgraph_output_tid.first.ToString(); 1002bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (output_name == subgraph_output_name) { 1003bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower LOG(INFO) << "As graph output and subgraph output are same, " 1004bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower << "the graph output node is replaced by identity node"; 10051c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower Node* original_output_node = FindMutableNodeByName(output_name, &graph); 1006bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_NOTNULL(original_output_node); 1007bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_EQ(1, original_output_node->num_outputs()) 1008bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower << "Num outputs should be 1 for " << output << "."; 1009bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower graph.RemoveNode(original_output_node); 1010bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower Node* new_node; 10111c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower TF_RETURN_IF_ERROR(BuildIdentityOpNode(output_name, 1012bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower remote_fused_graph_node_name, i, 1013bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower DT_FLOAT, &graph, &new_node)); 1014bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_NOTNULL(new_node); 1015bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 1016bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 1017bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 1018bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 1019bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower GraphDef result_graph_def; 1020bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 1021bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower graph.ToGraphDef(&result_graph_def); 1022bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 1023bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower ClusterInfo graph_cluster; 1024bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR( 1025bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower BuildClusterByBorder(inputs, outputs, result_graph_def, &graph_cluster)); 1026bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 1027bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower // Remove unvisited nodes 1028bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR(BuildClusterSubgraphDef(graph_cluster, result_graph_def, 1029bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower output_graph_def)); 1030bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 1031bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return Status::OK(); 1032bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower} 1033bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 1034bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames( 1035bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const GraphDef& input_graph_def, const std::vector<string>& inputs, 1036bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const std::vector<string>& outputs, 1037bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string& remote_fused_graph_node_name_prefix, 1038bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const std::unordered_set<string>& subgraph_nodes, 1039bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string& remote_fused_graph_executor_name, 1040bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const bool require_shape_type, GraphDef* output_graph_def) { 1041bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower std::vector<ClusterInfo> ci_vec; 1042bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::ClusterizeNodes( 1043bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower subgraph_nodes, input_graph_def, &ci_vec)); 1044bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 104590d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai for (size_t i = 0; i < ci_vec.size(); ++i) { 1046bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string remote_fused_graph_node_name = 1047bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower strings::StrCat(remote_fused_graph_node_name_prefix, "/", i); 1048bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR(FuseCluster(input_graph_def, inputs, outputs, 1049bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower remote_fused_graph_node_name, ci_vec.at(i), 1050bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower remote_fused_graph_executor_name, 1051bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower require_shape_type, output_graph_def)); 1052bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 1053bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return Status::OK(); 1054bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower} 1055bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 1056bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByBorder( 1057bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const GraphDef& input_graph_def, const std::vector<string>& inputs, 1058bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const std::vector<string>& outputs, 1059bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string& remote_fused_graph_node_name, 1060bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const std::vector<string>& border_inputs, 1061bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const std::vector<string>& border_outputs, 1062bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string& remote_graph_executor_name, const bool require_shape_type, 1063bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower GraphDef* output_graph_def) { 1064bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower ClusterInfo cluster; 1065bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildClusterByBorder( 1066bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower border_inputs, border_outputs, input_graph_def, &cluster)); 1067bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 1068bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return FuseCluster( 1069bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower input_graph_def, inputs, outputs, remote_fused_graph_node_name, cluster, 1070bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower remote_graph_executor_name, require_shape_type, output_graph_def); 1071bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower} 1072bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 1073f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByOpTypes( 1074f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower const GraphDef& input_graph_def, const std::vector<string>& inputs, 1075f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower const std::vector<string>& outputs, 1076f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower const string& remote_fused_graph_node_name_prefix, 1077f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower const std::unordered_set<string>& fused_op_types, 1078f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower const string& remote_fused_graph_executor_name, 1079f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower const bool require_shape_type, GraphDef* output_graph_def) { 1080f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower const std::unordered_set<string> fused_nodes_filtered_by_op_types = 1081f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower BuildNodeMapFromOpTypes(input_graph_def, fused_op_types); 1082f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower 1083f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower return FuseRemoteGraphByNodeNames( 1084f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower input_graph_def, inputs, outputs, remote_fused_graph_node_name_prefix, 1085f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower fused_nodes_filtered_by_op_types, remote_fused_graph_executor_name, 1086f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower require_shape_type, output_graph_def); 1087f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower} 1088f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower 1089d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::FuseRemoteGraphByExecutor( 1090d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower const GraphDef& input_graph_def, const std::vector<string>& inputs, 1091d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower const std::vector<string>& outputs, const string& executor_name, 1092d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower GraphDef* output_graph_def) { 1093d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower const ExecutorBuildFunc* build_func = GetExecutorBuildFunc(executor_name); 1094d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower if (build_func == nullptr) { 1095d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower return errors::InvalidArgument("Unknown executor name: " + executor_name); 1096d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower } 1097d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower std::unique_ptr<IRemoteFusedGraphExecutor> executor; 1098d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower TF_RETURN_IF_ERROR((*build_func)(&executor)); 1099d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower CHECK_NOTNULL(executor.get()); 1100d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower if (!executor->IsEnabled()) { 1101d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower // As this executor is not enabled, just return original graph as is. 1102d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower *output_graph_def = input_graph_def; 1103d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower return Status::OK(); 1104d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower } 1105d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower return executor->FuseRemoteGraph(input_graph_def, inputs, outputs, 1106d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower output_graph_def); 1107d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower} 1108d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower 11091c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::PlaceRemoteGraphArguments( 11101c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const std::vector<string>& inputs, const std::vector<string>& outputs, 11111c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const std::unordered_set<string>& fused_node_names, 11121c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const std::vector<string>& border_inputs, 11131c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const std::vector<string>& border_outputs, 1114f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower const std::unordered_set<string>& fused_op_types, 11151c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const string& remote_fused_graph_node_name, 11161c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const string& remote_graph_executor_name, GraphDef* graph_def) { 11171c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK_NOTNULL(graph_def); 1118f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower 1119f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower const std::unordered_set<string> fused_nodes_filtered_by_op_types = 1120f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower BuildNodeMapFromOpTypes(*graph_def, fused_op_types); 1121f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower 11221c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower for (NodeDef& node_def : *graph_def->mutable_node()) { 11231c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower string attr_str; 11241c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower TensorId tid; 112590d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai for (size_t i = 0; i < inputs.size(); ++i) { 11261c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (IsSameNodeName(node_def, inputs.at(i), &tid)) { 11271c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower AppendDeliminator(&attr_str); 11281c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::GRAPH_INPUT, 11291c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower tid.second, i, remote_graph_executor_name, 11301c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower remote_fused_graph_node_name); 11311c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 11321c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 113390d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai for (size_t i = 0; i < outputs.size(); ++i) { 11341c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (IsSameNodeName(node_def, outputs.at(i), &tid)) { 11351c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower AppendDeliminator(&attr_str); 11361c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::GRAPH_OUTPUT, 11371c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower tid.second, i); 11381c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 11391c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 11401c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower for (const string& fused_node_name : fused_node_names) { 11411c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (fused_node_name == node_def.name()) { 11421c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower AppendDeliminator(&attr_str); 11431c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::FUSED_NODE); 11441c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 11451c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 1146f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower for (const string& fused_node_name : fused_nodes_filtered_by_op_types) { 1147f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower if (fused_node_name == node_def.name()) { 1148f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower AppendDeliminator(&attr_str); 1149f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::FUSED_NODE); 1150f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower } 1151f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower } 115290d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai for (size_t i = 0; i < border_inputs.size(); ++i) { 11531c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (IsSameNodeName(node_def, border_inputs.at(i), &tid)) { 11541c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower AppendDeliminator(&attr_str); 11551c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::BORDER_INPUT, 11561c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower tid.second, i); 11571c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 11581c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 115990d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai for (size_t i = 0; i < border_outputs.size(); ++i) { 11601c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (IsSameNodeName(node_def, border_outputs.at(i), &tid)) { 11611c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower AppendDeliminator(&attr_str); 11621c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower attr_str += BuildNodeTypeAttr( 11631c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower RemoteFusedGraphExecuteInfo::BORDER_OUTPUT, tid.second, i); 11641c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 11651c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 11661c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (attr_str.empty()) { 11671c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::UNUSED); 11681c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 11691c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower AddNodeAttr(ATTR_NODE_TYPE, attr_str, &node_def); 11701c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 11711c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower return Status::OK(); 11721c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower} 11731c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower 11741c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower/* static */ Status 11751c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments( 11761c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const GraphDef& input_graph_def, 11771c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const std::vector<std::pair<string, Tensor>>& input_tensors, 11781c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower GraphDef* output_graph_def) { 11791c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower std::unordered_map<int, string> input_map; 11801c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower std::unordered_map<int, string> output_map; 11811c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower std::unordered_set<string> fused_node_names; 11821c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower std::unordered_map<int, string> border_input_map; 11831c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower std::unordered_map<int, string> border_output_map; 11841c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower string remote_graph_executor_name; 11851c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower string remote_fused_graph_node_name; 11861c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower 11871c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower for (const NodeDef& node_def : input_graph_def.node()) { 11881c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower string attr_str; 11891c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower TF_RETURN_IF_ERROR(GetNodeAttr(node_def, ATTR_NODE_TYPE, &attr_str)); 11901c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower std::vector<std::vector<string>> attr_strs; 11911c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower for (const string& str : str_util::Split(attr_str, ":")) { 11921c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower attr_strs.emplace_back(str_util::Split(str, ",")); 11931c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 11941c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (attr_strs.empty()) { 11951c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower return errors::InvalidArgument("Remote graph node type not found."); 11961c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 11971c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower for (const std::vector<string>& attr : attr_strs) { 11981c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (attr.empty()) { 11991c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower return errors::InvalidArgument("Empty remote graph node type attr."); 12001c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 12011c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower int node_type_int; 12021c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK(strings::safe_strto32(attr.at(0), &node_type_int)) << attr.at(0); 12031c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const RemoteFusedGraphExecuteInfo::NodeType node_type = 12041c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower static_cast<RemoteFusedGraphExecuteInfo::NodeType>(node_type_int); 12051c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const string& name = node_def.name(); 12061c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower int port; 12071c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower int index; 12081c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower 12091c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower switch (node_type) { 12101c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower case RemoteFusedGraphExecuteInfo::GRAPH_INPUT: 12111c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower VLOG(2) << "Graph input: " << name; 12121c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK_EQ(5, attr.size()); 12131c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK(strings::safe_strto32(attr.at(1), &port)); 12141c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK(strings::safe_strto32(attr.at(2), &index)); 12151c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK(!attr.at(3).empty()); 12161c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower remote_graph_executor_name = attr.at(3); 12171c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK(!attr.at(4).empty()); 12181c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower remote_fused_graph_node_name = attr.at(4); 12191c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower input_map.emplace(index, strings::StrCat(name, ":", port)); 12201c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (GetExecutorBuildFunc(remote_graph_executor_name) == nullptr) { 12211c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower LOG(INFO) << "Executor for " << remote_graph_executor_name 12221c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower << " not registered. Do not fuse."; 12231c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower *output_graph_def = input_graph_def; 12241c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower return Status::OK(); 12251c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 12261c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower break; 12271c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower case RemoteFusedGraphExecuteInfo::GRAPH_OUTPUT: 12281c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower VLOG(2) << "Graph output: " << name; 12291c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK_EQ(3, attr.size()); 12301c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK(strings::safe_strto32(attr.at(1), &port)); 12311c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK(strings::safe_strto32(attr.at(2), &index)); 12321c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower output_map.emplace(index, strings::StrCat(name, ":", port)); 12331c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower break; 12341c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower case RemoteFusedGraphExecuteInfo::FUSED_NODE: 12351c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower VLOG(2) << "Fused node: " << name; 12361c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK_EQ(1, attr.size()); 12371c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower fused_node_names.emplace(name); 12381c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower break; 12391c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower case RemoteFusedGraphExecuteInfo::BORDER_INPUT: 12401c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower VLOG(2) << "Border input: " << name; 12411c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK_EQ(3, attr.size()); 12421c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK(strings::safe_strto32(attr.at(1), &port)); 12431c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK(strings::safe_strto32(attr.at(2), &index)); 12441c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower border_input_map.emplace(index, strings::StrCat(name, ":", port)); 12451c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower break; 12461c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower case RemoteFusedGraphExecuteInfo::BORDER_OUTPUT: 12471c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower VLOG(2) << "Border output: " << name; 12481c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK_EQ(3, attr.size()); 12491c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK(strings::safe_strto32(attr.at(1), &port)); 12501c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower CHECK(strings::safe_strto32(attr.at(2), &index)); 12511c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower border_output_map.emplace(index, strings::StrCat(name, ":", port)); 12521c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower break; 12531c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower case RemoteFusedGraphExecuteInfo::UNUSED: 12541c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower // do nothing 12551c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower break; 12561c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower default: 12571c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower // unsupported value 125888cdf1f81fa1938c5bb81c5d293fc0ed0758cadcA. Unique TensorFlower LOG(FATAL); 12591c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 12601c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 12611c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 12621c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower bool require_shape_type = false; 12631c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower std::vector<string> inputs; 12641c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower std::vector<string> outputs; 12651c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower std::vector<string> border_inputs; 12661c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower std::vector<string> border_outputs; 12671c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower ConvertMapToVector(input_map, &inputs); 12681c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower ConvertMapToVector(output_map, &outputs); 12691c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower ConvertMapToVector(border_input_map, &border_inputs); 12701c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower ConvertMapToVector(border_output_map, &border_outputs); 12711c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower 12721c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (!input_tensors.empty()) { 12731c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower bool input_match = false; 12741c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (inputs.size() == input_tensors.size()) { 12751c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower for (const std::pair<string, Tensor>& input_tensor : input_tensors) { 12761c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (!ContainsSameTensorId(input_tensor.first, inputs)) { 12771c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower break; 12781c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 12791c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower DataType data_type; 12801c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower TensorShape shape; 12811c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (GetOutputTensorShapeType(input_graph_def, input_tensor.first, 12821c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower &data_type, &shape)) { 12831c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (data_type == input_tensor.second.dtype() && 12841c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower shape == input_tensor.second.shape()) { 12851c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower VLOG(2) << "Input matched!"; 12861c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower // Shape type matched. 12871c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower input_match = true; 12881c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower require_shape_type = true; 12891c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 12901c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } else { 12911c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower // Shape type not required. 12921c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower input_match = true; 12931c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 12941c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 12951c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 12961c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (!input_match) { 12971c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower // Input mismatch. Just copy original graph 12981c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower *output_graph_def = input_graph_def; 12991c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower return Status::OK(); 13001c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 13011c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 13021c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower 13031c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (!fused_node_names.empty()) { 13041c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower TF_RETURN_IF_ERROR(FuseRemoteGraphByNodeNames( 13051c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower input_graph_def, inputs, outputs, remote_fused_graph_node_name, 13061c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower fused_node_names, remote_graph_executor_name, require_shape_type, 13071c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower output_graph_def)); 13081c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } else if (!border_inputs.empty() || !border_outputs.empty()) { 13091c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower TF_RETURN_IF_ERROR(FuseRemoteGraphByBorder( 13101c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower input_graph_def, inputs, outputs, remote_fused_graph_node_name, 13111c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower border_inputs, border_outputs, remote_graph_executor_name, 13121c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower require_shape_type, output_graph_def)); 13131c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } else { 13141c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower *output_graph_def = input_graph_def; 13151c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 13161c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower 13171c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower return Status::OK(); 13181c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower} 13191c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower 13201c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower/* static */ bool RemoteFusedGraphExecuteUtils::IsFuseReady( 13211c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const GraphDef& graph_def, 13221c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const std::vector<std::pair<string, Tensor>>& input_tensors) { 13231c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower for (const std::pair<string, Tensor>& input_tensor : input_tensors) { 13241c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const NodeDef* node_def = FindNodeDefByName(input_tensor.first, graph_def); 13251c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (node_def == nullptr) { 13261c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower return false; 13271c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 13281c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower string attr; 13291c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const Status status = GetNodeAttr(*node_def, ATTR_NODE_TYPE, &attr); 13301c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower if (!status.ok() || attr.empty()) { 13311c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower return false; 13321c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 13331c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower } 13341c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower return true; 13351c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower} 13361c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower 1337351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor( 1338351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower const void* src_ptr, const int src_size, Tensor* tensor) { 1339351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower CHECK(tensor->TotalBytes() >= src_size) 1340351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower << tensor->TotalBytes() << ", " << src_size; 1341351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower void* dst_ptr; 1342351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower switch (tensor->dtype()) { 1343351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower case DT_FLOAT: 1344351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower dst_ptr = tensor->flat<float>().data(); 1345351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower break; 1346351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower case DT_DOUBLE: 1347351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower dst_ptr = tensor->flat<double>().data(); 1348351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower break; 1349351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower case DT_INT32: 1350351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower dst_ptr = tensor->flat<int32>().data(); 1351351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower break; 1352351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower case DT_UINT8: 1353351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower dst_ptr = tensor->flat<uint8>().data(); 1354351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower break; 1355351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower case DT_INT16: 1356351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower dst_ptr = tensor->flat<int16>().data(); 1357351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower break; 1358351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower case DT_INT8: 1359351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower dst_ptr = tensor->flat<int8>().data(); 1360351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower break; 1361351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower case DT_STRING: 1362351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower dst_ptr = tensor->flat<string>().data(); 1363351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower break; 1364351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower case DT_INT64: 1365351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower dst_ptr = tensor->flat<int64>().data(); 1366351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower break; 1367351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower case DT_BOOL: 1368351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower dst_ptr = tensor->flat<bool>().data(); 1369351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower break; 1370351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower case DT_QINT8: 1371351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower dst_ptr = tensor->flat<qint8>().data(); 1372351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower break; 1373351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower case DT_QUINT8: 1374351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower dst_ptr = tensor->flat<quint8>().data(); 1375351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower break; 1376351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower case DT_QINT32: 1377351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower dst_ptr = tensor->flat<qint32>().data(); 1378351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower break; 1379351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower case DT_BFLOAT16: 1380351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower dst_ptr = tensor->flat<bfloat16>().data(); 1381351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower break; 1382351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower case DT_QINT16: 1383351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower dst_ptr = tensor->flat<qint16>().data(); 1384351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower break; 1385351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower case DT_QUINT16: 1386351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower dst_ptr = tensor->flat<quint16>().data(); 1387351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower break; 1388351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower case DT_UINT16: 1389351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower dst_ptr = tensor->flat<uint16>().data(); 1390351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower break; 1391351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower default: 139288cdf1f81fa1938c5bb81c5d293fc0ed0758cadcA. Unique TensorFlower LOG(FATAL) << "type " << tensor->dtype() << " is not supported."; 1393351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower break; 1394351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower } 1395351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower CHECK_NOTNULL(dst_ptr); 1396351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower std::memcpy(dst_ptr, src_ptr, src_size); 1397351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower return Status::OK(); 1398351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower} 1399351e1673beffa8583ff75046eb516893b9e5c79dA. Unique TensorFlower 1400f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower/* static */ std::unordered_set<string> 1401f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlowerRemoteFusedGraphExecuteUtils::BuildNodeMapFromOpTypes( 1402f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower const GraphDef& graph_def, const std::unordered_set<string>& op_types) { 1403f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower std::unordered_set<string> retval; 1404f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower for (const NodeDef& node_def : graph_def.node()) { 1405f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower if (op_types.count(node_def.op()) > 0) { 1406f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower retval.emplace(node_def.name()); 1407f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower } 1408f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower } 1409f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower return retval; 1410f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower} 1411f39dd926259bee979915df2edaeb2369eebeefcbA. Unique TensorFlower 1412d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower/* static */ std::unordered_set<string> 1413d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlowerRemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions( 1414d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower const GraphDef& graph_def, 1415d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower const IRemoteFusedGraphOpsDefinitions& ops_definitions) { 1416d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower std::unordered_set<string> retval; 1417d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower for (const NodeDef& node_def : graph_def.node()) { 1418d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower std::vector<DataType> dt_vec; 1419d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower std::vector<TensorShape> shape_vec; 1420d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower const Status status = 1421d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower GetOutputTensorShapeType(node_def, &dt_vec, &shape_vec); 1422d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower if (!status.ok()) { 1423d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower shape_vec.clear(); 1424d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower } 1425d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower if (ops_definitions.GetOpIdFor( 1426d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower node_def.op(), DataTypeVector(dt_vec.begin(), dt_vec.end())) != 1427d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID) { 1428d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower retval.emplace(node_def.name()); 1429d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower } 1430d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower } 1431d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower return retval; 1432d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower} 1433d5f4d9bbac520ad9eae6614fe678e9d1568435a4A. Unique TensorFlower 1434bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower/* static */ Status RemoteFusedGraphExecuteUtils::ReplaceInputNodeByPlaceHolder( 1435bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string& input, const DataType type, const TensorShape& shape, 1436bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower GraphDef* graph_def) { 1437bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const TensorId tid = ParseTensorName(input); 1438bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower CHECK_EQ(0, tid.second); 1439bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower const string node_name = tid.first.ToString(); 1440bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower for (NodeDef& node : *graph_def->mutable_node()) { 1441bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (node.name() != node_name) { 1442bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower continue; 1443bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 1444bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower if (node.op() == "Placeholder") { 1445bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return Status::OK(); 1446bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } else { 1447bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower NodeDef placeholder_node; 1448bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower placeholder_node.set_op("Placeholder"); 1449bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower placeholder_node.set_name(node_name); 1450bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower AddNodeAttr("dtype", type, &placeholder_node); 1451bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower AddNodeAttr("shape", shape, &placeholder_node); 1452bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower // TODO(satok): Remove once we merge attributes 1453bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower AddOutputTensorShapeType({type}, {shape}, &placeholder_node); 1454bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower node.Clear(); 1455bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower node = placeholder_node; 1456bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return Status::OK(); 1457bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 1458bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower } 1459bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower return errors::InvalidArgument( 1460bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower strings::StrCat(node_name, " not found for replacement.")); 1461bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower} 1462bf3e5296f7e2e98a95f260ab797ce2302092fd35A. Unique TensorFlower 14631c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr( 14641c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port, 14651c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const int index, const string& executor_name, const string& node_name) { 14661c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower return strings::StrCat(static_cast<int>(node_type), ",", port, ",", index, 14671c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower ",", executor_name, ",", node_name); 14681c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower} 14691c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower 14701c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr( 14711c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port, 14721c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const int index) { 14731c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower return strings::StrCat(static_cast<int>(node_type), ",", port, ",", index); 14741c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower} 14751c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower 14761c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr( 14771c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower const RemoteFusedGraphExecuteInfo::NodeType node_type) { 14781c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower return strings::StrCat(static_cast<int>(node_type)); 14791c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower} 14801c5715078cf726411fc3f5667c503ab33c9f1612A. Unique TensorFlower 1481227877f2990f69bf2db56b2b5d545d778bcfeee8A. Unique TensorFlower} // namespace tensorflow 1482