1// Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15#include <stdlib.h> 16#include <time.h> 17#include <algorithm> 18#include <cmath> 19#include <memory> 20#include <unordered_map> 21#include <unordered_set> 22#include <utility> 23#include <vector> 24 25#include "tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h" 26#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h" 27#include "tensorflow/core/framework/op.h" 28#include "tensorflow/core/framework/op_kernel.h" 29#include "tensorflow/core/framework/tensor.h" 30#include "tensorflow/core/lib/gtl/top_n.h" 31#include "tensorflow/core/platform/types.h" 32#include "tensorflow/core/util/work_sharder.h" 33 34namespace tensorflow { 35 36using tensorforest::LeftProbabilityK; 37 38REGISTER_OP("KFeatureGradient") 39 .Attr("layer_num: int") 40 .Attr("random_seed: int") 41 .Input("input_data: float") 42 .Input("tree_parameters: float") 43 .Input("tree_biases: float") 44 .Input("routes: float") 45 .Output("routing_gradient: float") 46 .Output("data_gradient: float") 47 .Output("weight_gradient: float") 48 .Doc(R"doc( 49 Computes the derivative of the routing loss with respect to each decision 50 node. Each decision node is constrained to make a decision based on only 51 k features. 52 53 layer_num: The layer number of this tree. 54 random_seed: The base random seed. 55 56 input_data: The training batch's features as a 2-d tensor; 57 `input_data[i][j]` gives the j-th feature of the i-th input. 58 tree_parameters: `tree_parameters[i]` gives the weight of 59 the logistic regression model that translates from node features to 60 probabilities. 61 tree_biases: `tree_biases[i]` gives the bias of the logistic 62 regression model that translates from node features to 63 probabilities. 64 routes: The routes computed by routing_function_op. 65 66 routing_gradient: `routing_gradient` provides du / df, where u is the 67 routing function and f is the (vector of) decision functions. A decision 68 function f_i computes the routing decision at node i. 69 70 data_gradient: `data_gradient` provides df / dx, where f is the (vector 71 of) decision functions and x is a batch of data. 72 73 weights_gradient: `weights_gradient` provides df / dw, where f is the 74 (vector of) decision functions and w is the matrix of parameters that 75 determine how instances are routed through a tree. 76 77 f_i, the decision function at node i, is parameterized by t_i (parameters) 78 and b_i (bias) and takes data x as input. This op is called in 79 training_ops.py to compute du / df, and we use that to compute 80 81 du / dx = du / df * df / dx, 82 du / dt = du / df * df / dt, and 83 du / db = du / df * df / db. 84)doc"); 85 86class KFeatureGradient : public OpKernel { 87 public: 88 explicit KFeatureGradient(OpKernelConstruction* context) : OpKernel(context) { 89 OP_REQUIRES_OK(context, context->GetAttr("layer_num", &layer_num_)); 90 OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_)); 91 } 92 93 void Compute(OpKernelContext* context) override { 94 // Gather input. 95 const Tensor& input_data_tensor = context->input(0); 96 const Tensor& tree_parameters_tensor = context->input(1); 97 const Tensor& tree_biases_tensor = context->input(2); 98 const Tensor& routing_tensor = context->input(3); 99 100 // Extract dimensions from input tensors. 101 const int32 num_data = 102 static_cast<int32>(input_data_tensor.shape().dim_size(0)); 103 const int32 num_features = 104 static_cast<int32>(input_data_tensor.shape().dim_size(1)); 105 const int32 num_nodes = 106 static_cast<int32>(tree_parameters_tensor.shape().dim_size(0)); 107 const int32 num_features_per_node = 108 static_cast<int32>(tree_parameters_tensor.shape().dim_size(1)); 109 110 // Construct output tensors. 111 Tensor* out_routes = nullptr; 112 TensorShape out_routes_shape; 113 out_routes_shape.AddDim(num_data); 114 out_routes_shape.AddDim(num_nodes); 115 116 Tensor* out_data = nullptr; 117 TensorShape out_data_shape; 118 out_data_shape.AddDim(num_nodes); 119 out_data_shape.AddDim(num_features); 120 121 Tensor* out_weights = nullptr; 122 TensorShape out_weights_shape; 123 out_weights_shape.AddDim(num_data); 124 out_weights_shape.AddDim(num_nodes); 125 out_weights_shape.AddDim(num_features_per_node); 126 127 OP_REQUIRES_OK(context, 128 context->allocate_output(0, out_routes_shape, &out_routes)); 129 OP_REQUIRES_OK(context, 130 context->allocate_output(1, out_data_shape, &out_data)); 131 OP_REQUIRES_OK( 132 context, context->allocate_output(2, out_weights_shape, &out_weights)); 133 134 tensorforest::Initialize(*out_data, 0.0f); 135 136 // Compute output. 137 const auto input_data = input_data_tensor.tensor<float, 2>(); 138 const auto tree_parameters = tree_parameters_tensor.tensor<float, 2>(); 139 const auto tree_biases = tree_biases_tensor.tensor<float, 1>(); 140 const auto routes = routing_tensor.tensor<float, 2>(); 141 142 auto routes_grad = out_routes->tensor<float, 2>(); 143 auto data_grad = out_data->tensor<float, 2>(); 144 auto weights_grad = out_weights->tensor<float, 3>(); 145 146 std::vector<int32> feature_set; 147 for (int i = 0; i < num_data; i++) { 148 const Tensor point = input_data_tensor.Slice(i, i + 1); 149 feature_set.clear(); 150 151 // Traverse the tree from the bottom up. 152 for (int j = num_nodes - 1; j >= 0; j--) { 153 tensorforest::GetFeatureSet(layer_num_, j, random_seed_, num_features, 154 num_features_per_node, &feature_set); 155 156 // Compute routing gradient. 157 // j is a leaf node. 158 if (j >= num_nodes / 2) { 159 routes_grad(i, j) = routes(i, j); 160 } else { // j is not a leaf node 161 int32 left_child = 2 * j + 1; 162 int32 right_child = left_child + 1; 163 164 float left_prob = LeftProbabilityK( 165 point, feature_set, tree_parameters_tensor.Slice(j, j + 1), 166 tree_biases(j), num_features, num_features_per_node); 167 168 float right_prob = 1.0f - left_prob; 169 170 routes_grad(i, j) = (right_prob * routes(i, left_child) + 171 left_prob * routes(i, right_child)); 172 } 173 // Compute data and weight gradient. 174 for (int k = 0; k < num_features_per_node; k++) { 175 CHECK_LT(feature_set[k], num_features); 176 data_grad(j, feature_set[k]) = tree_parameters(j, k); 177 weights_grad(i, j, k) = input_data(i, feature_set[k]); 178 } 179 } 180 } 181 } 182 183 private: 184 int32 layer_num_; 185 int32 random_seed_; 186}; 187 188REGISTER_KERNEL_BUILDER(Name("KFeatureGradient").Device(DEVICE_CPU), 189 KFeatureGradient); 190} // namespace tensorflow 191