1// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15#include "tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h" 16 17#include "tensorflow/core/framework/graph.pb.h" 18#include "tensorflow/core/lib/io/path.h" 19#include "tensorflow/core/platform/env.h" 20 21namespace tensorflow { 22namespace tensorforest { 23 24// Names of ops in the graph to run. 25constexpr char kInitializeOp[] = "init"; 26constexpr char kAddExampleOp[] = "add_example"; 27constexpr char kSplitScoreName[] = "split_score"; 28constexpr char kGetSplitName[] = "get_split"; 29constexpr char kGetLeftStatsName[] = "get_left_stats"; 30constexpr char kGetRightStatsName[] = "get_right_stats"; 31 32// Names of files written by python graph builder. 33constexpr char kGraphFilename[] = "graph"; 34constexpr char kSaverDefFilename[] = "saver"; 35constexpr char kMetaDefFilename[] = "meta"; 36 37// Names of Tensor inputs. 38constexpr char kFeaturesName[] = "features"; 39constexpr char kInputDataName[] = "input_data"; 40constexpr char kTargetsName[] = "targets"; 41constexpr char kExamplesName[] = "examples"; 42 43constexpr char kNoOp[] = "none"; 44 45CandidateGraphRunner::CandidateGraphRunner( 46 const string& graph_dir, const decision_trees::BinaryNode& split) 47 : split_(split) { 48 // read graph from file. 49 GraphDef graph_def; 50 TF_CHECK_OK(ReadBinaryProto( 51 Env::Default(), io::JoinPath(graph_dir, kGraphFilename), &graph_def)) 52 << "Could not read graph def."; 53 54 // create session. 55 session_.reset(::tensorflow::NewSession(SessionOptions())); 56 TF_CHECK_OK(session_->Create(graph_def)) << "Failed to create session"; 57 58 // Features don't change, store them in a tensor. 59 const auto& oblique = split.inequality_left_child_test().oblique(); 60 const int32 feat_size = oblique.features_size(); 61 features_.reset(new Tensor(tensorflow::DT_INT32, TensorShape({feat_size}))); 62 auto feat = features_->flat<int32>(); 63 int i = 0; 64 for (const auto& id : oblique.features()) { 65 safe_strto32(id.id().value(), &feat(i++)); 66 } 67} 68 69void CandidateGraphRunner::RunOp(const string& name, 70 const TensorNameValueList& inputs, 71 const std::vector<string>& output_tensor_names, 72 std::vector<Tensor>* outputs) { 73 std::vector<string> op_name; 74 if (name != kNoOp) { 75 op_name.push_back(name); 76 } 77 TF_CHECK_OK(session_->Run(inputs, output_tensor_names, op_name, outputs)) 78 << "Failed to run: " << name; 79} 80 81void CandidateGraphRunner::Init() { 82 RunOp(kInitializeOp, TensorNameValueList(), std::vector<string>(), nullptr); 83} 84 85void CandidateGraphRunner::AddExample(const Tensor& input_data, 86 const Tensor& target, 87 const Tensor& examples) { 88 TensorNameValueList inputs; 89 inputs.emplace_back(kFeaturesName, *features_); 90 inputs.emplace_back(kExamplesName, examples); 91 inputs.emplace_back(kInputDataName, input_data); 92 inputs.emplace_back(kTargetsName, target); 93 94 RunOp(kAddExampleOp, inputs, std::vector<string>(), nullptr); 95} 96 97float CandidateGraphRunner::SplitScore() { 98 std::vector<Tensor> outputs; 99 RunOp(kNoOp, TensorNameValueList(), {kSplitScoreName}, &outputs); 100 return outputs[0].unaligned_flat<float>()(0); 101} 102 103void CandidateGraphRunner::GetSplit(decision_trees::BinaryNode* node) { 104 std::vector<Tensor> outputs; 105 RunOp(kNoOp, TensorNameValueList(), {kGetSplitName}, &outputs); 106 ParseProtoUnlimited(node, outputs[0].unaligned_flat<string>()(0)); 107 const auto& oblique = split_.inequality_left_child_test().oblique(); 108 auto* new_split = 109 node->mutable_inequality_left_child_test()->mutable_oblique(); 110 for (const auto& id : oblique.features()) { 111 *new_split->add_features() = id; 112 } 113} 114 115void CandidateGraphRunner::GetLeftStats(LeafStat* stats) { 116 std::vector<Tensor> outputs; 117 RunOp(kNoOp, TensorNameValueList(), {kGetLeftStatsName}, &outputs); 118 const auto& counts = outputs[0].unaligned_flat<float>(); 119 auto* dense = stats->mutable_classification()->mutable_dense_counts(); 120 for (int i = 0; i < counts.size(); ++i) { 121 dense->add_value()->set_float_value(counts(i)); 122 } 123} 124 125void CandidateGraphRunner::GetRightStats(LeafStat* stats) { 126 std::vector<Tensor> outputs; 127 RunOp(kNoOp, TensorNameValueList(), {kGetRightStatsName}, &outputs); 128 const auto& counts = outputs[0].unaligned_flat<float>(); 129 auto* dense = stats->mutable_classification()->mutable_dense_counts(); 130 for (int i = 0; i < counts.size(); ++i) { 131 dense->add_value()->set_float_value(counts(i)); 132 } 133} 134 135} // namespace tensorforest 136} // namespace tensorflow 137