/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/


%include "tensorflow/python/platform/base.i"
%include "cluster.i"

// Input typemap: converts a Python bytes object holding a serialized
// MetaGraphDef into a `const tensorflow::MetaGraphDef&` argument.
// `temp` is a local of the generated wrapper function, so `$1` stays valid
// for the duration of the wrapped call.
// On failure a Python exception is set and the wrapper bails out via
// SWIG_fail.
%typemap(in) const tensorflow::MetaGraphDef& (tensorflow::MetaGraphDef temp) {
  char* c_string;
  Py_ssize_t py_size;
  if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
    // Python has raised an error (likely TypeError or UnicodeEncodeError).
    SWIG_fail;
  }

  // Passing the explicit size keeps embedded NUL bytes of the serialized
  // proto intact.
  if (!temp.ParseFromString(string(c_string, py_size))) {
    PyErr_SetString(
        PyExc_TypeError,
        "The MetaGraphDef could not be parsed as a valid protocol buffer");
    SWIG_fail;
  }
  $1 = &temp;
}

// Input typemap: converts a Python bytes object holding a serialized
// RewriterConfig into a `const tensorflow::RewriterConfig&` argument.
// It mirrors the MetaGraphDef typemap: on a bad input type or an
// unparsable payload it sets a Python exception and calls SWIG_fail.
%typemap(in) const tensorflow::RewriterConfig& (
    tensorflow::RewriterConfig temp) {
  char* c_string;
  Py_ssize_t py_size;
  if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {
    // Python has raised an error (likely TypeError or UnicodeEncodeError).
    SWIG_fail;
  }

  if (!temp.ParseFromString(string(c_string, py_size))) {
    PyErr_SetString(
        PyExc_TypeError,
        "The RewriterConfig could not be parsed as a valid protocol buffer");
    SWIG_fail;
  }
  $1 = &temp;
}

%{
  #include <memory>
  #include <unordered_map>
  #include <vector>
  #include "tensorflow/c/tf_status_helper.h"
  #include "tensorflow/core/lib/core/status.h"
  #include "tensorflow/core/common_runtime/device.h"
  #include "tensorflow/core/framework/device_base.h"
  #include "tensorflow/core/common_runtime/device_factory.h"
  #include "tensorflow/core/framework/device_attributes.pb.h"
  #include "tensorflow/core/framework/graph.pb.h"
  #include "tensorflow/core/grappler/grappler_item.h"
  #include "tensorflow/core/grappler/grappler_item_builder.h"
  #include "tensorflow/core/grappler/clusters/cluster.h"
  #include "tensorflow/core/grappler/clusters/utils.h"
  #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
  #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
  #include "tensorflow/core/protobuf/meta_graph.pb.h"
  #include "tensorflow/core/protobuf/rewriter_config.pb.h"
  #include "tensorflow/core/public/session_options.h"


// Creates the local devices for a default SessionOptions and records, keyed
// by device name, the grappler DeviceProperties obtained from
// grappler::GetDeviceInfo(parsed_name). The memory size is overwritten with
// the device's configured memory limit.
//
// Best-effort: if DeviceFactory::AddDevices reports an error, `device_map`
// is left unchanged and the error is not propagated.
//
// The created devices are only used to read their attributes and are
// destroyed before returning.
void DetectDevices(
    std::unordered_map<string, tensorflow::DeviceProperties>* device_map) {
  tensorflow::SessionOptions options;
  std::vector<tensorflow::Device*> raw_devices;
  const tensorflow::Status status =
      tensorflow::DeviceFactory::AddDevices(options, "", &raw_devices);

  // Take ownership of whatever AddDevices placed in the vector *before*
  // looking at `status`. Devices are then freed on every path, including
  // when AddDevices fails after creating some of them.
  std::vector<std::unique_ptr<tensorflow::Device>> devices;
  devices.reserve(raw_devices.size());
  for (tensorflow::Device* device : raw_devices) {
    devices.emplace_back(device);
  }
  if (!status.ok()) {
    return;
  }

  for (const std::unique_ptr<tensorflow::Device>& device : devices) {
    tensorflow::DeviceProperties& prop = (*device_map)[device->name()];
    prop = tensorflow::grappler::GetDeviceInfo(device->parsed_name());

    // Overwrite the memory limit since users might have requested to use
    // only a fraction of the available device memory.
    const tensorflow::DeviceAttributes& attr = device->attributes();
    prop.set_memory_size(attr.memory_limit());
  }
}

// Runs grappler's MetaOptimizer over `metagraph` and returns the optimized
// GraphDef, serialized, as a new Python bytes object.
//
// Steps:
//   1. Import the metagraph into a GrapplerItem named `graph_id`, with
//      function inlining and import-time optimizations disabled.
//   2. Optimize it on `cluster` according to `rewriter_config`.
//   3. If `verbose` is set, print the optimizer's result summary.
//
// Status reporting:
//   - If the import fails, `out_status` is set to TF_INVALID_ARGUMENT and
//     nullptr is returned.
//   - Otherwise `out_status` receives the optimizer's status. The serialized
//     output graph is returned even when that status is not OK.
//
// NOTE(review): returning nullptr without a pending Python exception relies
// on the TF_Status handling (presumably from base.i) to raise on the Python
// side; confirm.
PyObject* TF_OptimizeGraph(
    GCluster cluster,
    const tensorflow::RewriterConfig& rewriter_config,
    const tensorflow::MetaGraphDef& metagraph,
    bool verbose, const string& graph_id, TF_Status* out_status) {
  tensorflow::grappler::ItemConfig item_config;
  item_config.inline_functions = false;
  item_config.apply_optimizations = false;
  std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
      tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph,
                                                         item_config);

  if (!grappler_item) {
    TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                 "Failed to import metagraph, check error log for more info.");
    return nullptr;
  }

  // The MetaOptimizer is constructed without a CPU device.
  tensorflow::DeviceBase* cpu_device = nullptr;
  tensorflow::GraphDef out_graph;
  tensorflow::grappler::MetaOptimizer optimizer(cpu_device, rewriter_config);
  tensorflow::Status status =
      optimizer.Optimize(cluster.get(), *grappler_item, &out_graph);
  if (verbose) {
    optimizer.PrintResult();
  }
  tensorflow::Set_TF_Status_from_Status(out_status, status);
  string out_graph_str = out_graph.SerializeAsString();
  PyObject* ret = PyBytes_FromStringAndSize(out_graph_str.data(),
                                            out_graph_str.size());
  return ret;
}
%}


// Wrap this function
PyObject* TF_OptimizeGraph(
    GCluster cluster,
    const tensorflow::RewriterConfig& rewriter_config,
    const tensorflow::MetaGraphDef& metagraph, bool verbose,
    const string& graph_id, TF_Status* out_status);