1// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15#include <stdlib.h>
16#include <time.h>
17#include <algorithm>
18#include <cmath>
19#include <memory>
20#include <unordered_map>
21#include <unordered_set>
22#include <utility>
23#include <vector>
24
25#include "tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h"
26#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
27#include "tensorflow/core/framework/op.h"
28#include "tensorflow/core/framework/op_kernel.h"
29#include "tensorflow/core/framework/tensor.h"
30#include "tensorflow/core/lib/gtl/top_n.h"
31#include "tensorflow/core/platform/types.h"
32#include "tensorflow/core/util/work_sharder.h"
33
34namespace tensorflow {
35
36using tensorforest::LeftProbabilityK;
37
38REGISTER_OP("KFeatureGradient")
39    .Attr("layer_num: int")
40    .Attr("random_seed: int")
41    .Input("input_data: float")
42    .Input("tree_parameters: float")
43    .Input("tree_biases: float")
44    .Input("routes: float")
45    .Output("routing_gradient: float")
46    .Output("data_gradient: float")
47    .Output("weight_gradient: float")
48    .Doc(R"doc(
49    Computes the derivative of the routing loss with respect to each decision
50    node.  Each decision node is constrained to make a decision based on only
51    k features.
52
53    layer_num: The layer number of this tree.
54    random_seed: The base random seed.
55
56    input_data: The training batch's features as a 2-d tensor;
57     `input_data[i][j]` gives the j-th feature of the i-th input.
58    tree_parameters: `tree_parameters[i]` gives the weight of
59     the logistic regression model that translates from node features to
60     probabilities.
61    tree_biases: `tree_biases[i]` gives the bias of the logistic
62     regression model that translates from node features to
63     probabilities.
64    routes: The routes computed by routing_function_op.
65
66    routing_gradient: `routing_gradient` provides du / df, where u is the
67     routing function and f is the (vector of) decision functions.  A decision
68     function f_i computes the routing decision at node i.
69
70    data_gradient: `data_gradient` provides df / dx, where f is the (vector
71     of) decision functions and x is a batch of data.
72
73    weights_gradient: `weights_gradient` provides df / dw, where f is the
74     (vector of) decision functions and w is the matrix of parameters that
75     determine how instances are routed through a tree.
76
77    f_i, the decision function at node i, is parameterized by t_i (parameters)
78    and b_i (bias) and takes data x as input.  This op is called in
79    training_ops.py to compute du / df, and we use that to compute
80
81    du / dx = du / df * df / dx,
82    du / dt = du / df * df / dt, and
83    du / db = du / df * df / db.
84)doc");
85
86class KFeatureGradient : public OpKernel {
87 public:
88  explicit KFeatureGradient(OpKernelConstruction* context) : OpKernel(context) {
89    OP_REQUIRES_OK(context, context->GetAttr("layer_num", &layer_num_));
90    OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_));
91  }
92
93  void Compute(OpKernelContext* context) override {
94    // Gather input.
95    const Tensor& input_data_tensor = context->input(0);
96    const Tensor& tree_parameters_tensor = context->input(1);
97    const Tensor& tree_biases_tensor = context->input(2);
98    const Tensor& routing_tensor = context->input(3);
99
100    // Extract dimensions from input tensors.
101    const int32 num_data =
102        static_cast<int32>(input_data_tensor.shape().dim_size(0));
103    const int32 num_features =
104        static_cast<int32>(input_data_tensor.shape().dim_size(1));
105    const int32 num_nodes =
106        static_cast<int32>(tree_parameters_tensor.shape().dim_size(0));
107    const int32 num_features_per_node =
108        static_cast<int32>(tree_parameters_tensor.shape().dim_size(1));
109
110    // Construct output tensors.
111    Tensor* out_routes = nullptr;
112    TensorShape out_routes_shape;
113    out_routes_shape.AddDim(num_data);
114    out_routes_shape.AddDim(num_nodes);
115
116    Tensor* out_data = nullptr;
117    TensorShape out_data_shape;
118    out_data_shape.AddDim(num_nodes);
119    out_data_shape.AddDim(num_features);
120
121    Tensor* out_weights = nullptr;
122    TensorShape out_weights_shape;
123    out_weights_shape.AddDim(num_data);
124    out_weights_shape.AddDim(num_nodes);
125    out_weights_shape.AddDim(num_features_per_node);
126
127    OP_REQUIRES_OK(context,
128                   context->allocate_output(0, out_routes_shape, &out_routes));
129    OP_REQUIRES_OK(context,
130                   context->allocate_output(1, out_data_shape, &out_data));
131    OP_REQUIRES_OK(
132        context, context->allocate_output(2, out_weights_shape, &out_weights));
133
134    tensorforest::Initialize(*out_data, 0.0f);
135
136    // Compute output.
137    const auto input_data = input_data_tensor.tensor<float, 2>();
138    const auto tree_parameters = tree_parameters_tensor.tensor<float, 2>();
139    const auto tree_biases = tree_biases_tensor.tensor<float, 1>();
140    const auto routes = routing_tensor.tensor<float, 2>();
141
142    auto routes_grad = out_routes->tensor<float, 2>();
143    auto data_grad = out_data->tensor<float, 2>();
144    auto weights_grad = out_weights->tensor<float, 3>();
145
146    std::vector<int32> feature_set;
147    for (int i = 0; i < num_data; i++) {
148      const Tensor point = input_data_tensor.Slice(i, i + 1);
149      feature_set.clear();
150
151      // Traverse the tree from the bottom up.
152      for (int j = num_nodes - 1; j >= 0; j--) {
153        tensorforest::GetFeatureSet(layer_num_, j, random_seed_, num_features,
154                                    num_features_per_node, &feature_set);
155
156        // Compute routing gradient.
157        // j is a leaf node.
158        if (j >= num_nodes / 2) {
159          routes_grad(i, j) = routes(i, j);
160        } else {  // j is not a leaf node
161          int32 left_child = 2 * j + 1;
162          int32 right_child = left_child + 1;
163
164          float left_prob = LeftProbabilityK(
165              point, feature_set, tree_parameters_tensor.Slice(j, j + 1),
166              tree_biases(j), num_features, num_features_per_node);
167
168          float right_prob = 1.0f - left_prob;
169
170          routes_grad(i, j) = (right_prob * routes(i, left_child) +
171                               left_prob * routes(i, right_child));
172        }
173        // Compute data and weight gradient.
174        for (int k = 0; k < num_features_per_node; k++) {
175          CHECK_LT(feature_set[k], num_features);
176          data_grad(j, feature_set[k]) = tree_parameters(j, k);
177          weights_grad(i, j, k) = input_data(i, feature_set[k]);
178        }
179      }
180    }
181  }
182
183 private:
184  int32 layer_num_;
185  int32 random_seed_;
186};
187
188REGISTER_KERNEL_BUILDER(Name("KFeatureGradient").Device(DEVICE_CPU),
189                        KFeatureGradient);
190}  // namespace tensorflow
191