1// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15#include <stdlib.h>
16#include <time.h>
17#include <algorithm>
18#include <cmath>
19#include <memory>
20#include <unordered_map>
21#include <unordered_set>
22#include <utility>
23#include <vector>
24
25#include "tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h"
26#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
27#include "tensorflow/core/framework/op.h"
28#include "tensorflow/core/framework/op_kernel.h"
29#include "tensorflow/core/framework/shape_inference.h"
30#include "tensorflow/core/framework/tensor.h"
31#include "tensorflow/core/lib/gtl/top_n.h"
32#include "tensorflow/core/platform/types.h"
33#include "tensorflow/core/util/work_sharder.h"
34
35namespace tensorflow {
36
37using shape_inference::InferenceContext;
38using shape_inference::ShapeHandle;
39
40using tensorforest::LeftProbability;
41
42// This op computes the derivative of the routing loss with respect to each
43// decision node.
44REGISTER_OP("RoutingGradient")
45    .Attr("max_nodes: int")
46    .Input("input_data: float")
47    .Input("tree_parameters: float")
48    .Input("tree_biases: float")
49    .Input("routes: float")
50    .Output("routing_gradient: float")
51    .SetShapeFn([](InferenceContext* c) {
52      ShapeHandle input, params;
53      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
54      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &params));
55
56      c->set_output(0, c->Matrix(c->Dim(input, 0), c->Dim(params, 0)));
57      return Status::OK();
58    })
59    .Doc(R"doc(
60  Computes the derivative of the routing loss with respect to each decision
61  node.
62
63  max_nodes: The number of nodes in the tree.
64
65  tree_parameters: `tree_parameters[i]` gives the weight of
66   the logistic regression model that translates from node features to
67   probabilities.
68  tree_biases: `tree_biases[i]` gives the bias of the logistic
69   regression model that translates from node features to
70   probabilities.
71  routes: The routes computed by routing_function_op.
72
73  routing_gradient: `routing_gradient` provides du / df, where u is the routing
74   function and f is the (vector of) decision functions.  A decision function
75   f_i computes the routing decision at node i.
76
77   f_i is parameterized by t_i (parameters) and b_i (bias) and takes data x as
78   input.  This op is called in training_ops.py to compute du / df, and we use
79   that to compute
80
81     du / dx = du / df * df / dx,
82     du / dt = du / df * df / dt, and
83     du / db = du / df * df / db.
84)doc");
85
86class RoutingGradient : public OpKernel {
87 public:
88  explicit RoutingGradient(OpKernelConstruction* context) : OpKernel(context) {
89    OP_REQUIRES_OK(context, context->GetAttr("max_nodes", &max_nodes_));
90  }
91
92  void Compute(OpKernelContext* context) override {
93    const Tensor& input_data = context->input(0);
94    const Tensor& tree_parameters_tensor = context->input(1);
95    const Tensor& tree_biases_tensor = context->input(2);
96    const Tensor& routing_tensor = context->input(3);
97
98    // TODO(atwoodj): Add dimension checks.
99
100    const int32 num_data = static_cast<int32>(input_data.shape().dim_size(0));
101    const int32 num_features =
102        static_cast<int32>(input_data.shape().dim_size(1));
103
104    Tensor* output = nullptr;
105    TensorShape output_shape;
106    output_shape.AddDim(num_data);
107    output_shape.AddDim(max_nodes_);
108
109    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
110
111    auto out = output->tensor<float, 2>();
112    const auto tree_biases = tree_biases_tensor.tensor<float, 1>();
113    const auto routes = routing_tensor.tensor<float, 2>();
114
115    // A derivation of the gradient can be found at go/routingderivation.
116    for (int i = 0; i < num_data; i++) {
117      const Tensor point = input_data.Slice(i, i + 1);
118
119      // Traverses the tree from the bottom up.
120      for (int j = max_nodes_ - 1; j >= 0; j--) {
121        // j is a leaf node
122        if (j >= max_nodes_ / 2) {
123          out(i, j) = routes(i, j);
124        } else {  // j is not a leaf node
125          int32 left_child = 2 * j + 1;
126          int32 right_child = left_child + 1;
127          float left_prob =
128              LeftProbability(point, tree_parameters_tensor.Slice(j, j + 1),
129                              tree_biases(j), num_features);
130
131          float right_prob = 1 - left_prob;
132
133          out(i, j) = (right_prob * routes(i, left_child) +
134                       left_prob * routes(i, right_child));
135        }
136      }
137    }
138  }
139
140 private:
141  int32 max_nodes_;
142};
143
144REGISTER_KERNEL_BUILDER(Name("RoutingGradient").Device(DEVICE_CPU),
145                        RoutingGradient);
146}  // namespace tensorflow
147