1// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15#include "tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h"
16
17#include "tensorflow/core/framework/graph.pb.h"
18#include "tensorflow/core/lib/io/path.h"
19#include "tensorflow/core/platform/env.h"
20
21namespace tensorflow {
22namespace tensorforest {
23
// Names of ops and fetchable tensors in the graph to run. In this file,
// kInitializeOp and kAddExampleOp are passed to RunOp() as the op to execute;
// the remaining names are requested as output tensors.
constexpr char kInitializeOp[] = "init";
constexpr char kAddExampleOp[] = "add_example";
constexpr char kSplitScoreName[] = "split_score";
constexpr char kGetSplitName[] = "get_split";
constexpr char kGetLeftStatsName[] = "get_left_stats";
constexpr char kGetRightStatsName[] = "get_right_stats";

// Names of files written by python graph builder. Only kGraphFilename is
// read by the constructor below; it is joined onto the graph directory.
constexpr char kGraphFilename[] = "graph";
constexpr char kSaverDefFilename[] = "saver";
constexpr char kMetaDefFilename[] = "meta";

// Names of Tensor inputs fed to the graph by AddExample().
constexpr char kFeaturesName[] = "features";
constexpr char kInputDataName[] = "input_data";
constexpr char kTargetsName[] = "targets";
constexpr char kExamplesName[] = "examples";

// Sentinel op name: when RunOp() receives this name it runs no target op and
// only fetches the requested output tensors.
constexpr char kNoOp[] = "none";
44
45CandidateGraphRunner::CandidateGraphRunner(
46    const string& graph_dir, const decision_trees::BinaryNode& split)
47    : split_(split) {
48  // read graph from file.
49  GraphDef graph_def;
50  TF_CHECK_OK(ReadBinaryProto(
51      Env::Default(), io::JoinPath(graph_dir, kGraphFilename), &graph_def))
52      << "Could not read graph def.";
53
54  // create session.
55  session_.reset(::tensorflow::NewSession(SessionOptions()));
56  TF_CHECK_OK(session_->Create(graph_def)) << "Failed to create session";
57
58  // Features don't change, store them in a tensor.
59  const auto& oblique = split.inequality_left_child_test().oblique();
60  const int32 feat_size = oblique.features_size();
61  features_.reset(new Tensor(tensorflow::DT_INT32, TensorShape({feat_size})));
62  auto feat = features_->flat<int32>();
63  int i = 0;
64  for (const auto& id : oblique.features()) {
65    safe_strto32(id.id().value(), &feat(i++));
66  }
67}
68
69void CandidateGraphRunner::RunOp(const string& name,
70                                 const TensorNameValueList& inputs,
71                                 const std::vector<string>& output_tensor_names,
72                                 std::vector<Tensor>* outputs) {
73  std::vector<string> op_name;
74  if (name != kNoOp) {
75    op_name.push_back(name);
76  }
77  TF_CHECK_OK(session_->Run(inputs, output_tensor_names, op_name, outputs))
78      << "Failed to run: " << name;
79}
80
81void CandidateGraphRunner::Init() {
82  RunOp(kInitializeOp, TensorNameValueList(), std::vector<string>(), nullptr);
83}
84
85void CandidateGraphRunner::AddExample(const Tensor& input_data,
86                                      const Tensor& target,
87                                      const Tensor& examples) {
88  TensorNameValueList inputs;
89  inputs.emplace_back(kFeaturesName, *features_);
90  inputs.emplace_back(kExamplesName, examples);
91  inputs.emplace_back(kInputDataName, input_data);
92  inputs.emplace_back(kTargetsName, target);
93
94  RunOp(kAddExampleOp, inputs, std::vector<string>(), nullptr);
95}
96
97float CandidateGraphRunner::SplitScore() {
98  std::vector<Tensor> outputs;
99  RunOp(kNoOp, TensorNameValueList(), {kSplitScoreName}, &outputs);
100  return outputs[0].unaligned_flat<float>()(0);
101}
102
103void CandidateGraphRunner::GetSplit(decision_trees::BinaryNode* node) {
104  std::vector<Tensor> outputs;
105  RunOp(kNoOp, TensorNameValueList(), {kGetSplitName}, &outputs);
106  ParseProtoUnlimited(node, outputs[0].unaligned_flat<string>()(0));
107  const auto& oblique = split_.inequality_left_child_test().oblique();
108  auto* new_split =
109      node->mutable_inequality_left_child_test()->mutable_oblique();
110  for (const auto& id : oblique.features()) {
111    *new_split->add_features() = id;
112  }
113}
114
115void CandidateGraphRunner::GetLeftStats(LeafStat* stats) {
116  std::vector<Tensor> outputs;
117  RunOp(kNoOp, TensorNameValueList(), {kGetLeftStatsName}, &outputs);
118  const auto& counts = outputs[0].unaligned_flat<float>();
119  auto* dense = stats->mutable_classification()->mutable_dense_counts();
120  for (int i = 0; i < counts.size(); ++i) {
121    dense->add_value()->set_float_value(counts(i));
122  }
123}
124
125void CandidateGraphRunner::GetRightStats(LeafStat* stats) {
126  std::vector<Tensor> outputs;
127  RunOp(kNoOp, TensorNameValueList(), {kGetRightStatsName}, &outputs);
128  const auto& counts = outputs[0].unaligned_flat<float>();
129  auto* dense = stats->mutable_classification()->mutable_dense_counts();
130  for (int i = 0; i < counts.size(); ++i) {
131    dense->add_value()->set_float_value(counts(i));
132  }
133}
134
135}  // namespace tensorforest
136}  // namespace tensorflow
137