1// Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15#include <stdlib.h> 16#include <time.h> 17#include <algorithm> 18#include <cmath> 19#include <memory> 20#include <unordered_map> 21#include <unordered_set> 22#include <utility> 23#include <vector> 24 25#include "tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h" 26#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h" 27#include "tensorflow/core/framework/op.h" 28#include "tensorflow/core/framework/op_kernel.h" 29#include "tensorflow/core/framework/shape_inference.h" 30#include "tensorflow/core/framework/tensor.h" 31#include "tensorflow/core/lib/gtl/top_n.h" 32#include "tensorflow/core/platform/types.h" 33#include "tensorflow/core/util/work_sharder.h" 34 35namespace tensorflow { 36 37using shape_inference::InferenceContext; 38using shape_inference::ShapeHandle; 39 40using tensorforest::LeftProbability; 41 42// This op computes the derivative of the routing loss with respect to each 43// decision node. 44REGISTER_OP("RoutingGradient") 45 .Attr("max_nodes: int") 46 .Input("input_data: float") 47 .Input("tree_parameters: float") 48 .Input("tree_biases: float") 49 .Input("routes: float") 50 .Output("routing_gradient: float") 51 .SetShapeFn([](InferenceContext* c) { 52 ShapeHandle input, params; 53 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input)); 54 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, ¶ms)); 55 56 c->set_output(0, c->Matrix(c->Dim(input, 0), c->Dim(params, 0))); 57 return Status::OK(); 58 }) 59 .Doc(R"doc( 60 Computes the derivative of the routing loss with respect to each decision 61 node. 62 63 max_nodes: The number of nodes in the tree. 64 65 tree_parameters: `tree_parameters[i]` gives the weight of 66 the logistic regression model that translates from node features to 67 probabilities. 68 tree_biases: `tree_biases[i]` gives the bias of the logistic 69 regression model that translates from node features to 70 probabilities. 71 routes: The routes computed by routing_function_op. 72 73 routing_gradient: `routing_gradient` provides du / df, where u is the routing 74 function and f is the (vector of) decision functions. A decision function 75 f_i computes the routing decision at node i. 76 77 f_i is parameterized by t_i (parameters) and b_i (bias) and takes data x as 78 input. This op is called in training_ops.py to compute du / df, and we use 79 that to compute 80 81 du / dx = du / df * df / dx, 82 du / dt = du / df * df / dt, and 83 du / db = du / df * df / db. 84)doc"); 85 86class RoutingGradient : public OpKernel { 87 public: 88 explicit RoutingGradient(OpKernelConstruction* context) : OpKernel(context) { 89 OP_REQUIRES_OK(context, context->GetAttr("max_nodes", &max_nodes_)); 90 } 91 92 void Compute(OpKernelContext* context) override { 93 const Tensor& input_data = context->input(0); 94 const Tensor& tree_parameters_tensor = context->input(1); 95 const Tensor& tree_biases_tensor = context->input(2); 96 const Tensor& routing_tensor = context->input(3); 97 98 // TODO(atwoodj): Add dimension checks. 99 100 const int32 num_data = static_cast<int32>(input_data.shape().dim_size(0)); 101 const int32 num_features = 102 static_cast<int32>(input_data.shape().dim_size(1)); 103 104 Tensor* output = nullptr; 105 TensorShape output_shape; 106 output_shape.AddDim(num_data); 107 output_shape.AddDim(max_nodes_); 108 109 OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); 110 111 auto out = output->tensor<float, 2>(); 112 const auto tree_biases = tree_biases_tensor.tensor<float, 1>(); 113 const auto routes = routing_tensor.tensor<float, 2>(); 114 115 // A derivation of the gradient can be found at go/routingderivation. 116 for (int i = 0; i < num_data; i++) { 117 const Tensor point = input_data.Slice(i, i + 1); 118 119 // Traverses the tree from the bottom up. 120 for (int j = max_nodes_ - 1; j >= 0; j--) { 121 // j is a leaf node 122 if (j >= max_nodes_ / 2) { 123 out(i, j) = routes(i, j); 124 } else { // j is not a leaf node 125 int32 left_child = 2 * j + 1; 126 int32 right_child = left_child + 1; 127 float left_prob = 128 LeftProbability(point, tree_parameters_tensor.Slice(j, j + 1), 129 tree_biases(j), num_features); 130 131 float right_prob = 1 - left_prob; 132 133 out(i, j) = (right_prob * routes(i, left_child) + 134 left_prob * routes(i, right_child)); 135 } 136 } 137 } 138 } 139 140 private: 141 int32 max_nodes_; 142}; 143 144REGISTER_KERNEL_BUILDER(Name("RoutingGradient").Device(DEVICE_CPU), 145 RoutingGradient); 146} // namespace tensorflow 147