1// Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15// RoutingFunction returns the probability of reaching each leaf node 16// in a soft decision tree. 17 18#include <stdlib.h> 19#include <time.h> 20#include <algorithm> 21#include <cmath> 22#include <memory> 23#include <unordered_map> 24#include <unordered_set> 25#include <utility> 26#include <vector> 27 28#include "tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h" 29#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h" 30#include "tensorflow/core/framework/op.h" 31#include "tensorflow/core/framework/op_kernel.h" 32#include "tensorflow/core/framework/shape_inference.h" 33#include "tensorflow/core/framework/tensor.h" 34#include "tensorflow/core/lib/gtl/top_n.h" 35#include "tensorflow/core/platform/types.h" 36#include "tensorflow/core/util/work_sharder.h" 37 38namespace tensorflow { 39 40using shape_inference::InferenceContext; 41using shape_inference::ShapeHandle; 42 43using tensorforest::CheckTensorBounds; 44using tensorforest::LeftProbability; 45 46// The term 'routing function' is synonymous with 'the probability 47// that an instance is routed to each leaf node.' It is defined in 48// 'Deep Neural Decision Forests' by Kontschieder et al. 49REGISTER_OP("HardRoutingFunction") 50 .Attr("max_nodes: int") 51 .Attr("tree_depth: int") 52 .Input("input_data: float") 53 .Input("tree_parameters: float") 54 .Input("tree_biases: float") 55 .Output("path_probability: float") 56 .Output("path: int32") 57 .SetShapeFn([](InferenceContext* c) { 58 ShapeHandle input; 59 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input)); 60 int64 tree_depth; 61 TF_RETURN_IF_ERROR(c->GetAttr("tree_depth", &tree_depth)); 62 63 auto out = c->Matrix(c->Dim(input, 0), tree_depth); 64 c->set_output(0, out); 65 c->set_output(1, out); 66 return Status::OK(); 67 }) 68 .Doc(R"doc( 69 Chooses a single path for each instance in `input_data` and returns the leaf 70 the probability of the path and the path taken. 71 72 tree_depth: The depth of the decision tree. 73 74 input_data: The training batch's features as a 2-d tensor; `input_data[i][j]` 75 gives the j-th feature of the i-th input. 76 tree_parameters: `tree_parameters[i]` gives the weight of 77 the logistic regression model that translates from node features to 78 probabilities. 79 tree_biases: `tree_biases[i]` gives the bias of the logistic 80 regression model that translates from node features to 81 probabilities. 82 83 path_probility: `path_probability[i]` gives the probability of reaching each 84 node in `path[i]`. 85 path: `path[i][j]` gives the jth node in the path taken by the ith data 86 instance. 87)doc"); 88 89class HardRoutingFunction : public OpKernel { 90 public: 91 explicit HardRoutingFunction(OpKernelConstruction* context) 92 : OpKernel(context) { 93 OP_REQUIRES_OK(context, context->GetAttr("tree_depth", &tree_depth_)); 94 } 95 96 void Compute(OpKernelContext* context) override { 97 const Tensor& input_data = context->input(0); 98 const Tensor& tree_parameters_tensor = context->input(1); 99 const Tensor& tree_biases_tensor = context->input(2); 100 101 if (input_data.shape().dim_size(0) > 0) { 102 OP_REQUIRES( 103 context, input_data.shape().dims() == 2, 104 errors::InvalidArgument("input_data should be two-dimensional")); 105 } 106 107 // Check tensor bounds. 108 if (!CheckTensorBounds(context, input_data)) return; 109 110 const int32 num_data = static_cast<int32>(input_data.shape().dim_size(0)); 111 const int32 num_features = 112 static_cast<int32>(input_data.shape().dim_size(1)); 113 114 Tensor* output_probability = nullptr; 115 TensorShape output_probability_shape; 116 output_probability_shape.AddDim(num_data); 117 output_probability_shape.AddDim(tree_depth_); 118 119 Tensor* output_path = nullptr; 120 TensorShape output_path_shape; 121 output_path_shape.AddDim(num_data); 122 output_path_shape.AddDim(tree_depth_); 123 124 OP_REQUIRES_OK(context, 125 context->allocate_output(0, output_probability_shape, 126 &output_probability)); 127 OP_REQUIRES_OK( 128 context, context->allocate_output(1, output_path_shape, &output_path)); 129 130 auto out_probability = output_probability->tensor<float, 2>(); 131 auto out_path = output_path->tensor<int32, 2>(); 132 133 const auto data = input_data.tensor<float, 2>(); 134 const auto tree_parameters = tree_parameters_tensor.tensor<float, 2>(); 135 const auto tree_biases = tree_biases_tensor.tensor<float, 1>(); 136 137 // Deterministically traverse the tree to a leaf. 138 for (int i = 0; i < num_data; i++) { 139 const Tensor point = input_data.Slice(i, i + 1); 140 int32 node = 0; 141 142 out_probability(i, 0) = 1.0; 143 out_path(i, 0) = 0; 144 for (int j = 0; j < tree_depth_ - 1; j++) { 145 float left_prob = 146 LeftProbability(point, tree_parameters_tensor.Slice(j, j + 1), 147 tree_biases(j), num_features); 148 149 int32 left_child = 2 * node + 1; 150 int32 right_child = left_child + 1; 151 152 float dot_product = 0.0; 153 for (int k = 0; k < num_features; k++) { 154 dot_product += tree_parameters(j, k) * data(i, k); 155 } 156 if (dot_product < tree_biases(j)) { 157 out_probability(i, j + 1) = left_prob * out_probability(i, j); 158 out_path(i, j + 1) = left_child; 159 node = left_child; 160 } else { 161 out_probability(i, j + 1) = (1.0 - left_prob) * out_probability(i, j); 162 out_path(i, j + 1) = right_child; 163 node = right_child; 164 } 165 } 166 } 167 } 168 169 private: 170 int32 tree_depth_; 171}; 172 173REGISTER_KERNEL_BUILDER(Name("HardRoutingFunction").Device(DEVICE_CPU), 174 HardRoutingFunction); 175} // namespace tensorflow 176