1// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15
16#include "tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h"
17#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
18#include "tensorflow/core/framework/op.h"
19#include "tensorflow/core/framework/op_kernel.h"
20#include "tensorflow/core/framework/shape_inference.h"
21#include "tensorflow/core/framework/tensor.h"
22#include "tensorflow/core/lib/gtl/top_n.h"
23#include "tensorflow/core/lib/math/math_util.h"
24#include "tensorflow/core/platform/types.h"
25#include "tensorflow/core/util/work_sharder.h"
26
27namespace tensorflow {
28
29using shape_inference::InferenceContext;
30using shape_inference::ShapeHandle;
31
32REGISTER_OP("UnpackPath")
33    .Input("path: int32")
34    .Input("path_values: float")
35    .Output("unpacked_path: float")
36    .SetShapeFn([](InferenceContext* c) {
37      ShapeHandle input, params;
38      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
39      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &params));
40
41      auto num_points = c->Dim(input, 0);
42
43      auto tree_depth = c->Dim(params, 1);
44      int64 num_nodes = InferenceContext::kUnknownDim;
45      if (c->ValueKnown(tree_depth)) {
46        num_nodes = (static_cast<int64>(1) << c->Value(tree_depth)) - 1;
47      }
48
49      c->set_output(0, c->Matrix(num_points, num_nodes));
50      return Status::OK();
51    })
52    .Doc(R"doc(
53  Takes a batch of paths through a tree and a batch of values along those paths
54  and returns a batch_size by num_nodes encoding of the path values.
55
56  path: `path[i][j]` gives the jth node in the path taken by the ith data
57   instance.
58  path_values: `path_values[i][j]` gives the value associated with node j in the
59   path defined by the ith instance
60
61  unpacked_paths: `unpacked_paths[i][path[i][k]]` is path_values[i][k] for k in
62   [0, tree_depth).  All other elements of unpacked_paths are zero.
63)doc");
64
65class UnpackPath : public OpKernel {
66 public:
67  explicit UnpackPath(OpKernelConstruction* context) : OpKernel(context) {}
68
69  void Compute(OpKernelContext* context) override {
70    VLOG(1) << "unpack start";
71    const Tensor& path_tensor = context->input(0);
72    const Tensor& path_values_tensor = context->input(1);
73
74    const int32 num_data = static_cast<int32>(path_tensor.shape().dim_size(0));
75    const int32 tree_depth =
76        static_cast<int32>(path_tensor.shape().dim_size(1));
77
78    const int32 num_nodes = MathUtil::IPow(2, tree_depth) - 1;
79
80    VLOG(1) << "num_data: " << num_data;
81    VLOG(1) << "tree_depth: " << tree_depth;
82    VLOG(1) << "num_nodes: " << num_nodes;
83
84    Tensor* output = nullptr;
85    TensorShape output_shape;
86    output_shape.AddDim(num_data);
87    output_shape.AddDim(num_nodes);
88
89    OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
90    VLOG(1) << "unpack before init";
91    tensorforest::Initialize(*output, 0.0f);
92    VLOG(1) << "unpack after init";
93
94    auto out = output->tensor<float, 2>();
95
96    const auto path = path_tensor.tensor<int32, 2>();
97    const auto path_values = path_values_tensor.tensor<float, 2>();
98
99    for (int i = 0; i < num_data; i++) {
100      for (int j = 0; j < tree_depth; j++) {
101        CHECK_LT(path(i, j), num_nodes);
102        out(i, path(i, j)) = path_values(i, j);
103      }
104    }
105    VLOG(1) << "unpack end";
106  }
107};
108
109REGISTER_KERNEL_BUILDER(Name("UnpackPath").Device(DEVICE_CPU), UnpackPath);
110
111}  // namespace tensorflow
112