1// Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15// RoutingFunction returns the probability of reaching each leaf node 16// in a soft decision tree. 17 18#include <stdlib.h> 19#include <time.h> 20#include <algorithm> 21#include <cmath> 22#include <memory> 23#include <unordered_map> 24#include <unordered_set> 25#include <utility> 26#include <vector> 27 28#include "tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h" 29#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h" 30#include "tensorflow/core/framework/op.h" 31#include "tensorflow/core/framework/op_kernel.h" 32#include "tensorflow/core/framework/shape_inference.h" 33#include "tensorflow/core/framework/tensor.h" 34#include "tensorflow/core/lib/gtl/top_n.h" 35#include "tensorflow/core/platform/types.h" 36#include "tensorflow/core/util/work_sharder.h" 37 38namespace tensorflow { 39 40using shape_inference::InferenceContext; 41using shape_inference::ShapeHandle; 42 43using tensorforest::CheckTensorBounds; 44using tensorforest::LeftProbability; 45 46// The term 'routing function' is synonymous with 'the probability 47// that an instance is routed to each leaf node.' It is defined in 48// 'Deep Neural Decision Forests' by Kontschieder et al. 49REGISTER_OP("RoutingFunction") 50 .Attr("max_nodes: int") 51 .Input("input_data: float") 52 .Input("tree_parameters: float") 53 .Input("tree_biases: float") 54 .Output("probabilities: float") 55 .SetShapeFn([](InferenceContext* c) { 56 ShapeHandle input, params; 57 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input)); 58 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, ¶ms)); 59 60 c->set_output(0, c->Matrix(c->Dim(input, 0), c->Dim(params, 0))); 61 return Status::OK(); 62 }) 63 .Doc(R"doc( 64 Returns the probability that each input will reach each leaf node. 65 66 max_nodes: The number of nodes in the tree. 67 68 input_data: The training batch's features as a 2-d tensor; `input_data[i][j]` 69 gives the j-th feature of the i-th input. 70 tree_parameters: `tree_parameters[i]` gives the weight of 71 the logistic regression model that translates from node features to 72 probabilities. 73 tree_biases: `tree_biases[i]` gives the bias of the logistic 74 regression model that translates from node features to 75 probabilities. 76 77 probabilities: `probabilities[i][j]` is the probability that input i 78 will reach node j. 79)doc"); 80 81class RoutingFunction : public OpKernel { 82 public: 83 explicit RoutingFunction(OpKernelConstruction* context) : OpKernel(context) { 84 OP_REQUIRES_OK(context, context->GetAttr("max_nodes", &max_nodes_)); 85 } 86 87 void Compute(OpKernelContext* context) override { 88 const Tensor& input_data = context->input(0); 89 const Tensor& tree_parameters_tensor = context->input(1); 90 const Tensor& tree_biases_tensor = context->input(2); 91 92 if (input_data.shape().dim_size(0) > 0) { 93 OP_REQUIRES( 94 context, input_data.shape().dims() == 2, 95 errors::InvalidArgument("input_data should be two-dimensional")); 96 } 97 98 // Check tensor bounds. 99 if (!CheckTensorBounds(context, input_data)) return; 100 101 const int32 num_data = static_cast<int32>(input_data.shape().dim_size(0)); 102 const int32 num_features = 103 static_cast<int32>(input_data.shape().dim_size(1)); 104 105 Tensor* output_probabilities = nullptr; 106 TensorShape output_shape; 107 output_shape.AddDim(num_data); 108 output_shape.AddDim(max_nodes_); 109 110 OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, 111 &output_probabilities)); 112 113 auto out_probs = output_probabilities->tensor<float, 2>(); 114 const auto tree_biases = tree_biases_tensor.tensor<float, 1>(); 115 116 // Iteratively compute the probability of reaching each leaf. 117 for (int i = 0; i < num_data; i++) { 118 const Tensor point = input_data.Slice(i, i + 1); 119 120 out_probs(i, 0) = 1.0; 121 122 for (int j = 0; j < max_nodes_ / 2; j++) { 123 int32 left_child = 2 * j + 1; 124 int32 right_child = left_child + 1; 125 126 float prob = out_probs(i, j); 127 float left_prob = 128 LeftProbability(point, tree_parameters_tensor.Slice(j, j + 1), 129 tree_biases(j), num_features); 130 131 out_probs(i, left_child) = prob * left_prob; 132 out_probs(i, right_child) = prob * (1.0 - left_prob); 133 } 134 } 135 } 136 137 private: 138 int32 max_nodes_; 139}; 140 141REGISTER_KERNEL_BUILDER(Name("RoutingFunction").Device(DEVICE_CPU), 142 RoutingFunction); 143} // namespace tensorflow 144