1// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
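// HardRoutingFunction deterministically follows a single root-to-leaf path
// through a soft decision tree for each input instance and returns the nodes
// on that path together with the probability of reaching each of them.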
17
18#include <stdlib.h>
19#include <time.h>
20#include <algorithm>
21#include <cmath>
22#include <memory>
23#include <unordered_map>
24#include <unordered_set>
25#include <utility>
26#include <vector>
27
28#include "tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h"
29#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
30#include "tensorflow/core/framework/op.h"
31#include "tensorflow/core/framework/op_kernel.h"
32#include "tensorflow/core/framework/shape_inference.h"
33#include "tensorflow/core/framework/tensor.h"
34#include "tensorflow/core/lib/gtl/top_n.h"
35#include "tensorflow/core/platform/types.h"
36#include "tensorflow/core/util/work_sharder.h"
37
38namespace tensorflow {
39
40using shape_inference::InferenceContext;
41using shape_inference::ShapeHandle;
42
43using tensorforest::CheckTensorBounds;
44using tensorforest::LeftProbability;
45
46// The term 'routing function' is synonymous with 'the probability
47// that an instance is routed to each leaf node.'  It is defined in
48// 'Deep Neural Decision Forests' by Kontschieder et al.
49REGISTER_OP("HardRoutingFunction")
50    .Attr("max_nodes: int")
51    .Attr("tree_depth: int")
52    .Input("input_data: float")
53    .Input("tree_parameters: float")
54    .Input("tree_biases: float")
55    .Output("path_probability: float")
56    .Output("path: int32")
57    .SetShapeFn([](InferenceContext* c) {
58      ShapeHandle input;
59      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
60      int64 tree_depth;
61      TF_RETURN_IF_ERROR(c->GetAttr("tree_depth", &tree_depth));
62
63      auto out = c->Matrix(c->Dim(input, 0), tree_depth);
64      c->set_output(0, out);
65      c->set_output(1, out);
66      return Status::OK();
67    })
68    .Doc(R"doc(
69  Chooses a single path for each instance in `input_data` and returns the leaf
70  the probability of the path and the path taken.
71
72  tree_depth: The depth of the decision tree.
73
74  input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
75   gives the j-th feature of the i-th input.
76  tree_parameters: `tree_parameters[i]` gives the weight of
77   the logistic regression model that translates from node features to
78   probabilities.
79  tree_biases: `tree_biases[i]` gives the bias of the logistic
80   regression model that translates from node features to
81   probabilities.
82
83  path_probility: `path_probability[i]` gives the probability of reaching each
84   node in `path[i]`.
85  path: `path[i][j]` gives the jth node in the path taken by the ith data
86   instance.
87)doc");
88
89class HardRoutingFunction : public OpKernel {
90 public:
91  explicit HardRoutingFunction(OpKernelConstruction* context)
92      : OpKernel(context) {
93    OP_REQUIRES_OK(context, context->GetAttr("tree_depth", &tree_depth_));
94  }
95
96  void Compute(OpKernelContext* context) override {
97    const Tensor& input_data = context->input(0);
98    const Tensor& tree_parameters_tensor = context->input(1);
99    const Tensor& tree_biases_tensor = context->input(2);
100
101    if (input_data.shape().dim_size(0) > 0) {
102      OP_REQUIRES(
103          context, input_data.shape().dims() == 2,
104          errors::InvalidArgument("input_data should be two-dimensional"));
105    }
106
107    // Check tensor bounds.
108    if (!CheckTensorBounds(context, input_data)) return;
109
110    const int32 num_data = static_cast<int32>(input_data.shape().dim_size(0));
111    const int32 num_features =
112        static_cast<int32>(input_data.shape().dim_size(1));
113
114    Tensor* output_probability = nullptr;
115    TensorShape output_probability_shape;
116    output_probability_shape.AddDim(num_data);
117    output_probability_shape.AddDim(tree_depth_);
118
119    Tensor* output_path = nullptr;
120    TensorShape output_path_shape;
121    output_path_shape.AddDim(num_data);
122    output_path_shape.AddDim(tree_depth_);
123
124    OP_REQUIRES_OK(context,
125                   context->allocate_output(0, output_probability_shape,
126                                            &output_probability));
127    OP_REQUIRES_OK(
128        context, context->allocate_output(1, output_path_shape, &output_path));
129
130    auto out_probability = output_probability->tensor<float, 2>();
131    auto out_path = output_path->tensor<int32, 2>();
132
133    const auto data = input_data.tensor<float, 2>();
134    const auto tree_parameters = tree_parameters_tensor.tensor<float, 2>();
135    const auto tree_biases = tree_biases_tensor.tensor<float, 1>();
136
137    // Deterministically traverse the tree to a leaf.
138    for (int i = 0; i < num_data; i++) {
139      const Tensor point = input_data.Slice(i, i + 1);
140      int32 node = 0;
141
142      out_probability(i, 0) = 1.0;
143      out_path(i, 0) = 0;
144      for (int j = 0; j < tree_depth_ - 1; j++) {
145        float left_prob =
146            LeftProbability(point, tree_parameters_tensor.Slice(j, j + 1),
147                            tree_biases(j), num_features);
148
149        int32 left_child = 2 * node + 1;
150        int32 right_child = left_child + 1;
151
152        float dot_product = 0.0;
153        for (int k = 0; k < num_features; k++) {
154          dot_product += tree_parameters(j, k) * data(i, k);
155        }
156        if (dot_product < tree_biases(j)) {
157          out_probability(i, j + 1) = left_prob * out_probability(i, j);
158          out_path(i, j + 1) = left_child;
159          node = left_child;
160        } else {
161          out_probability(i, j + 1) = (1.0 - left_prob) * out_probability(i, j);
162          out_path(i, j + 1) = right_child;
163          node = right_child;
164        }
165      }
166    }
167  }
168
169 private:
170  int32 tree_depth_;
171};
172
173REGISTER_KERNEL_BUILDER(Name("HardRoutingFunction").Device(DEVICE_CPU),
174                        HardRoutingFunction);
175}  // namespace tensorflow
176