1// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
// RoutingFunction returns, for each input, the probability of reaching each
// node (the root, the internal nodes and the leaves) of a soft decision tree.
17
18#include <stdlib.h>
19#include <time.h>
20#include <algorithm>
21#include <cmath>
22#include <memory>
23#include <unordered_map>
24#include <unordered_set>
25#include <utility>
26#include <vector>
27
28#include "tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h"
29#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
30#include "tensorflow/core/framework/op.h"
31#include "tensorflow/core/framework/op_kernel.h"
32#include "tensorflow/core/framework/shape_inference.h"
33#include "tensorflow/core/framework/tensor.h"
34#include "tensorflow/core/lib/gtl/top_n.h"
35#include "tensorflow/core/platform/types.h"
36#include "tensorflow/core/util/work_sharder.h"
37
38namespace tensorflow {
39
40using shape_inference::InferenceContext;
41using shape_inference::ShapeHandle;
42
43using tensorforest::CheckTensorBounds;
44using tensorforest::LeftProbability;
45
46// The term 'routing function' is synonymous with 'the probability
47// that an instance is routed to each leaf node.'  It is defined in
48// 'Deep Neural Decision Forests' by Kontschieder et al.
49REGISTER_OP("RoutingFunction")
50    .Attr("max_nodes: int")
51    .Input("input_data: float")
52    .Input("tree_parameters: float")
53    .Input("tree_biases: float")
54    .Output("probabilities: float")
55    .SetShapeFn([](InferenceContext* c) {
56      ShapeHandle input, params;
57      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
58      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &params));
59
60      c->set_output(0, c->Matrix(c->Dim(input, 0), c->Dim(params, 0)));
61      return Status::OK();
62    })
63    .Doc(R"doc(
64  Returns the probability that each input will reach each leaf node.
65
66  max_nodes: The number of nodes in the tree.
67
68  input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
69   gives the j-th feature of the i-th input.
70  tree_parameters: `tree_parameters[i]` gives the weight of
71   the logistic regression model that translates from node features to
72   probabilities.
73  tree_biases: `tree_biases[i]` gives the bias of the logistic
74   regression model that translates from node features to
75   probabilities.
76
77  probabilities: `probabilities[i][j]` is the probability that input i
78   will reach node j.
79)doc");
80
81class RoutingFunction : public OpKernel {
82 public:
83  explicit RoutingFunction(OpKernelConstruction* context) : OpKernel(context) {
84    OP_REQUIRES_OK(context, context->GetAttr("max_nodes", &max_nodes_));
85  }
86
87  void Compute(OpKernelContext* context) override {
88    const Tensor& input_data = context->input(0);
89    const Tensor& tree_parameters_tensor = context->input(1);
90    const Tensor& tree_biases_tensor = context->input(2);
91
92    if (input_data.shape().dim_size(0) > 0) {
93      OP_REQUIRES(
94          context, input_data.shape().dims() == 2,
95          errors::InvalidArgument("input_data should be two-dimensional"));
96    }
97
98    // Check tensor bounds.
99    if (!CheckTensorBounds(context, input_data)) return;
100
101    const int32 num_data = static_cast<int32>(input_data.shape().dim_size(0));
102    const int32 num_features =
103        static_cast<int32>(input_data.shape().dim_size(1));
104
105    Tensor* output_probabilities = nullptr;
106    TensorShape output_shape;
107    output_shape.AddDim(num_data);
108    output_shape.AddDim(max_nodes_);
109
110    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape,
111                                                     &output_probabilities));
112
113    auto out_probs = output_probabilities->tensor<float, 2>();
114    const auto tree_biases = tree_biases_tensor.tensor<float, 1>();
115
116    // Iteratively compute the probability of reaching each leaf.
117    for (int i = 0; i < num_data; i++) {
118      const Tensor point = input_data.Slice(i, i + 1);
119
120      out_probs(i, 0) = 1.0;
121
122      for (int j = 0; j < max_nodes_ / 2; j++) {
123        int32 left_child = 2 * j + 1;
124        int32 right_child = left_child + 1;
125
126        float prob = out_probs(i, j);
127        float left_prob =
128            LeftProbability(point, tree_parameters_tensor.Slice(j, j + 1),
129                            tree_biases(j), num_features);
130
131        out_probs(i, left_child) = prob * left_prob;
132        out_probs(i, right_child) = prob * (1.0 - left_prob);
133      }
134    }
135  }
136
137 private:
138  int32 max_nodes_;
139};
140
141REGISTER_KERNEL_BUILDER(Name("RoutingFunction").Device(DEVICE_CPU),
142                        RoutingFunction);
143}  // namespace tensorflow
144