1// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15#include <stdlib.h>
16#include <time.h>
17#include <algorithm>
18#include <cmath>
19#include <memory>
20#include <unordered_map>
21#include <unordered_set>
22#include <utility>
23#include <vector>
24
25#include "tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h"
26#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
27#include "tensorflow/core/framework/op.h"
28#include "tensorflow/core/framework/op_kernel.h"
29#include "tensorflow/core/framework/shape_inference.h"
30#include "tensorflow/core/framework/tensor.h"
31#include "tensorflow/core/lib/gtl/top_n.h"
32#include "tensorflow/core/platform/types.h"
33#include "tensorflow/core/util/work_sharder.h"
34
35namespace tensorflow {
36
37using shape_inference::InferenceContext;
38using shape_inference::ShapeHandle;
39
40using tensorforest::LeftProbability;
41
42// This op computes the derivative of the routing loss with respect to each
43// decision node.
44REGISTER_OP("StochasticHardRoutingGradient")
45    .Attr("tree_depth: int")
46    .Input("input_data: float")
47    .Input("tree_parameters: float")
48    .Input("tree_biases: float")
49    .Input("path_probability: float")
50    .Input("path: int32")
51    .Output("routing_gradient: float")
52    .Output("data_gradient: float")
53    .Output("parameter_gradient: float")
54    .Output("bias_gradient: float")
55    .SetShapeFn([](InferenceContext* c) {
56      ShapeHandle input, params;
57      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
58      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &params));
59
60      auto num_points = c->Dim(input, 0);
61      auto num_features = c->Dim(input, 1);
62      auto num_nodes = c->Dim(params, 0);
63
64      c->set_output(0, c->Matrix(num_points, num_nodes));
65      c->set_output(1, c->Matrix(num_nodes, num_features));
66      c->set_output(2, c->MakeShape({num_points, num_nodes, num_features}));
67      c->set_output(3, c->Vector(num_nodes));
68      return Status::OK();
69    })
70    .Doc(R"doc(
71  Computes the derivative of the routing loss with respect to each decision
72  node.
73
74  tree_depth: The depth of the decision tree.
75
76  input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
77   gives the j-th feature of the i-th input
78  tree_parameters: `tree_parameters[i]` gives the weight of
79   the logistic regression model that translates from node features to
80   probabilities.
81  tree_biases: `tree_biases[i]` gives the bias of the logistic
82   regression model that translates from node features to
83   probabilities.
84  path_probility: `path_probability[i]` gives the probability of reaching each
85   node in `path[i]`.
86  path: `path[i][j]` gives the jth node in the path taken by the ith data
87   instance.
88
89  routing_gradient: `routing_gradient` provides du / df, where u is the routing
90   function and f is the (vector of) decision functions.  A decision function
91   f_i computes the routing decision at node i.
92  data_gradient: `data_gradient` provides df / dx, where f is the (vector
93   of) decision functions and x is a batch of data.
94  parameter_gradient: `parameter_gradient` provides df / dw, where f is the
95   (vector of) decision functions and w is the matrix of parameters that
96   determine how instances are routed through a tree.
97  bias_gradient: `bias_gradient` provides df / db, where f is the
98   (vector of) decision functions and b is the vector of bias parameters that
99   determine how instances are routed through a tree.
100
101  f_i is parameterized by t_i (parameters) and b_i (bias) and takes data x as
102  input.  This op is called in training_ops.py to compute du / df, and we use
103  that to compute
104
105     du / dx = du / df * df / dx,
106     du / dt = du / df * df / dt, and
107     du / db = du / df * df / db.
108)doc");
109
110class StochasticHardRoutingGradient : public OpKernel {
111 public:
112  explicit StochasticHardRoutingGradient(OpKernelConstruction* context)
113      : OpKernel(context) {
114    OP_REQUIRES_OK(context, context->GetAttr("tree_depth", &tree_depth_));
115  }
116
117  void Compute(OpKernelContext* context) override {
118    VLOG(1) << "stochastic gradient start";
119    const Tensor& input_data = context->input(0);
120    const Tensor& tree_parameters_tensor = context->input(1);
121    const Tensor& tree_biases_tensor = context->input(2);
122
123    const Tensor& path_probability_tensor = context->input(3);
124    const Tensor& path_tensor = context->input(4);
125
126    const int32 num_data = static_cast<int32>(input_data.shape().dim_size(0));
127    const int32 num_features =
128        static_cast<int32>(input_data.shape().dim_size(1));
129    const int32 num_nodes =
130        static_cast<int32>(tree_parameters_tensor.shape().dim_size(0));
131
132    Tensor* output_routing = nullptr;
133    TensorShape output_routing_shape;
134    output_routing_shape.AddDim(num_data);
135    output_routing_shape.AddDim(num_nodes);
136
137    Tensor* output_data = nullptr;
138    TensorShape output_data_shape;
139    output_data_shape.AddDim(num_nodes);
140    output_data_shape.AddDim(num_features);
141
142    Tensor* output_parameters = nullptr;
143    TensorShape output_parameters_shape;
144    output_parameters_shape.AddDim(num_data);
145    output_parameters_shape.AddDim(num_nodes);
146    output_parameters_shape.AddDim(num_features);
147
148    Tensor* output_bias = nullptr;
149    TensorShape output_bias_shape;
150    output_bias_shape.AddDim(num_data);
151
152    OP_REQUIRES_OK(context, context->allocate_output(0, output_routing_shape,
153                                                     &output_routing));
154    OP_REQUIRES_OK(
155        context, context->allocate_output(1, output_data_shape, &output_data));
156    OP_REQUIRES_OK(context, context->allocate_output(2, output_parameters_shape,
157                                                     &output_parameters));
158    OP_REQUIRES_OK(
159        context, context->allocate_output(3, output_bias_shape, &output_bias));
160
161    tensorforest::Initialize(*output_routing, 0.0);
162    tensorforest::Initialize(*output_data, 0.0);
163    tensorforest::Initialize(*output_parameters, 0.0);
164    tensorforest::Initialize(*output_bias, 0.0);
165
166    auto out_routing = output_routing->tensor<float, 2>();
167    auto out_data = output_data->tensor<float, 2>();
168    auto out_parameters = output_parameters->tensor<float, 3>();
169    auto out_bias = output_bias->tensor<float, 1>();
170
171    const auto data = input_data.tensor<float, 2>();
172    const auto tree_parameters = tree_parameters_tensor.tensor<float, 2>();
173    const auto tree_biases = tree_biases_tensor.tensor<float, 1>();
174    const auto path_probability = path_probability_tensor.tensor<float, 2>();
175    const auto path = path_tensor.tensor<int32, 2>();
176
177    for (int i = 0; i < num_data; i++) {
178      const Tensor point = input_data.Slice(i, i + 1);
179
180      // Traverses the tree from the bottom up.
181      for (int j = tree_depth_ - 1; j > -1; j--) {
182        int32 node = path(i, j);
183
184        CHECK_LT(node, num_nodes);
185        CHECK_GT(node, -1);
186
187        // Compute data, parameter, and bias gradients.
188        // TODO(atwoodj): Should these be normalized?  Loss looks pretty large.
189        for (int k = 0; k < num_features; k++) {
190          out_data(node, k) = tree_parameters(node, k);
191          out_parameters(i, node, k) = out_parameters(i, node, k) + data(i, k);
192        }
193        out_bias(node) = out_bias(node) + 1.0;
194
195        // Compute decision gradient.
196        // node is a leaf
197        if (node >= num_nodes / 2) {
198          CHECK_LT(node, num_nodes);
199          out_routing(i, node) = path_probability(i, j);
200        } else {  // node is not a leaf
201          int32 left_child = 2 * j + 1;
202
203          float left_prob =
204              LeftProbability(point, tree_parameters_tensor.Slice(j, j + 1),
205                              tree_biases(j), num_features);
206
207          float right_prob = 1 - left_prob;
208
209          CHECK_GT(j - 1, -1);
210          if (path(i, j - 1) == left_child) {
211            CHECK_LT(node, num_nodes);
212            out_routing(i, node) = right_prob * path_probability(i, j - 1);
213          } else {
214            CHECK_LT(node, num_nodes);
215            out_routing(i, node) = left_prob * path_probability(i, j - 1);
216          }
217        }
218      }
219    }
220    VLOG(1) << "stochastic gradient end";
221  }
222
223 private:
224  int32 tree_depth_;
225};
226
227REGISTER_KERNEL_BUILDER(
228    Name("StochasticHardRoutingGradient").Device(DEVICE_CPU),
229    StochasticHardRoutingGradient);
230}  // namespace tensorflow
231