11588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 21588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// 31588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Licensed under the Apache License, Version 2.0 (the "License"); 41588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// you may not use this file except in compliance with the License. 51588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// You may obtain a copy of the License at 61588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// 71588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// http://www.apache.org/licenses/LICENSE-2.0 81588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// 91588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Unless required by applicable law or agreed to in writing, software 101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// distributed under the License is distributed on an "AS IS" BASIS, 111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// See the License for the specific language governing permissions and 131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// limitations under the License. 141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// ============================================================================= 15f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_ 16f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_ 171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/core/framework/tensor.h" 181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/core/framework/tensor_types.h" 191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorflow { 211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorforest { 221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 234de7361c43adc65c442359104a17c449a02eee7aA. Unique TensorFlowertypedef TTypes<float, 1>::UnalignedConstTensor SingleDimStorageType; 241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Base class for classes that hold labels and weights. Mostly for testing 261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// purposes, because it's inconvenient to construct nasty Eigen::things. 271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerclass InputTarget { 281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower public: 291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual ~InputTarget() {} 301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual int32 GetTargetAsClassIndex(int example_index, 311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int target_index) const = 0; 321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual float GetTargetWeight(int example_index) const = 0; 341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual float GetTargetAsContinuous(int example_index, 361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int target_index) const = 0; 371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}; 381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowertemplate <typename T> 401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerclass StoredInputTarget : public InputTarget { 411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower protected: 4214101a8ba173b44179b2a9317781f140eb61b0a1A. Unique TensorFlower // Takes ownership of t and w with a std::unique_ptr. 431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower StoredInputTarget(const T* t, const T* w, int num_targets) 441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower : target_(t), weight_(w), num_targets_(num_targets) {} 451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4614101a8ba173b44179b2a9317781f140eb61b0a1A. Unique TensorFlower const std::unique_ptr<const T> target_; 4714101a8ba173b44179b2a9317781f140eb61b0a1A. Unique TensorFlower const std::unique_ptr<const T> weight_; 481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int num_targets_; 491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}; 501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Holds labels/targets and weights. Assumes that tensors are passed as 521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// t.unaligned_flat<float>(). For multi-output, specifying the number of 531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// outputs will correctly index the flattened data. 541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerclass TensorInputTarget : public StoredInputTarget<SingleDimStorageType> { 551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower public: 5614101a8ba173b44179b2a9317781f140eb61b0a1A. Unique TensorFlower TensorInputTarget(const Tensor& target, const Tensor& weight, int num_targets) 574de7361c43adc65c442359104a17c449a02eee7aA. Unique TensorFlower : StoredInputTarget( 584de7361c43adc65c442359104a17c449a02eee7aA. Unique TensorFlower new SingleDimStorageType(target.unaligned_flat<float>()), 594de7361c43adc65c442359104a17c449a02eee7aA. Unique TensorFlower new SingleDimStorageType(weight.unaligned_flat<float>()), 604de7361c43adc65c442359104a17c449a02eee7aA. Unique TensorFlower num_targets), 6114101a8ba173b44179b2a9317781f140eb61b0a1A. Unique TensorFlower original_tensor_(target) {} 621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int32 GetTargetAsClassIndex(int example_index, 641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int target_index) const override { 651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return static_cast<int32>( 661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower GetTargetAsContinuous(example_index, target_index)); 671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float GetTargetWeight(int example_index) const override { 701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const size_t num_weights = weight_->size(); 711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return num_weights > 0 && example_index < num_weights 721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ? (*weight_)(example_index) 731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower : 1.0; 741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float GetTargetAsContinuous(int example_index, 771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int target_index) const override { 781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower QCHECK_LT(target_index, num_targets_); 791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return (*target_)(example_index * num_targets_ + target_index); 801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 824463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower const Tensor& original_tensor() const { return original_tensor_; } 831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower protected: 851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower Tensor original_tensor_; 861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}; 871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} // namespace tensorforest 881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} // namespace tensorflow 891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 90f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_ 91