1// Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15 16#include "tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h" 17#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h" 18#include "tensorflow/core/framework/op.h" 19#include "tensorflow/core/framework/op_kernel.h" 20#include "tensorflow/core/framework/shape_inference.h" 21#include "tensorflow/core/framework/tensor.h" 22#include "tensorflow/core/lib/gtl/top_n.h" 23#include "tensorflow/core/lib/math/math_util.h" 24#include "tensorflow/core/platform/types.h" 25#include "tensorflow/core/util/work_sharder.h" 26 27namespace tensorflow { 28 29using shape_inference::InferenceContext; 30using shape_inference::ShapeHandle; 31 32REGISTER_OP("UnpackPath") 33 .Input("path: int32") 34 .Input("path_values: float") 35 .Output("unpacked_path: float") 36 .SetShapeFn([](InferenceContext* c) { 37 ShapeHandle input, params; 38 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input)); 39 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, ¶ms)); 40 41 auto num_points = c->Dim(input, 0); 42 43 auto tree_depth = c->Dim(params, 1); 44 int64 num_nodes = InferenceContext::kUnknownDim; 45 if (c->ValueKnown(tree_depth)) { 46 num_nodes = (static_cast<int64>(1) << c->Value(tree_depth)) - 1; 47 } 48 49 c->set_output(0, c->Matrix(num_points, num_nodes)); 50 return Status::OK(); 51 }) 52 .Doc(R"doc( 53 Takes a batch of paths through a tree and a batch of values along those paths 54 and returns a batch_size by num_nodes encoding of the path values. 55 56 path: `path[i][j]` gives the jth node in the path taken by the ith data 57 instance. 58 path_values: `path_values[i][j]` gives the value associated with node j in the 59 path defined by the ith instance 60 61 unpacked_paths: `unpacked_paths[i][path[i][k]]` is path_values[i][k] for k in 62 [0, tree_depth). All other elements of unpacked_paths are zero. 63)doc"); 64 65class UnpackPath : public OpKernel { 66 public: 67 explicit UnpackPath(OpKernelConstruction* context) : OpKernel(context) {} 68 69 void Compute(OpKernelContext* context) override { 70 VLOG(1) << "unpack start"; 71 const Tensor& path_tensor = context->input(0); 72 const Tensor& path_values_tensor = context->input(1); 73 74 const int32 num_data = static_cast<int32>(path_tensor.shape().dim_size(0)); 75 const int32 tree_depth = 76 static_cast<int32>(path_tensor.shape().dim_size(1)); 77 78 const int32 num_nodes = MathUtil::IPow(2, tree_depth) - 1; 79 80 VLOG(1) << "num_data: " << num_data; 81 VLOG(1) << "tree_depth: " << tree_depth; 82 VLOG(1) << "num_nodes: " << num_nodes; 83 84 Tensor* output = nullptr; 85 TensorShape output_shape; 86 output_shape.AddDim(num_data); 87 output_shape.AddDim(num_nodes); 88 89 OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); 90 VLOG(1) << "unpack before init"; 91 tensorforest::Initialize(*output, 0.0f); 92 VLOG(1) << "unpack after init"; 93 94 auto out = output->tensor<float, 2>(); 95 96 const auto path = path_tensor.tensor<int32, 2>(); 97 const auto path_values = path_values_tensor.tensor<float, 2>(); 98 99 for (int i = 0; i < num_data; i++) { 100 for (int j = 0; j < tree_depth; j++) { 101 CHECK_LT(path(i, j), num_nodes); 102 out(i, path(i, j)) = path_values(i, j); 103 } 104 } 105 VLOG(1) << "unpack end"; 106 } 107}; 108 109REGISTER_KERNEL_BUILDER(Name("UnpackPath").Device(DEVICE_CPU), UnpackPath); 110 111} // namespace tensorflow 112