11588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower//
31588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Licensed under the Apache License, Version 2.0 (the "License");
41588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// you may not use this file except in compliance with the License.
51588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// You may obtain a copy of the License at
61588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower//
71588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower//     http://www.apache.org/licenses/LICENSE-2.0
81588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower//
91588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Unless required by applicable law or agreed to in writing, software
101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// distributed under the License is distributed on an "AS IS" BASIS,
111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// See the License for the specific language governing permissions and
131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// limitations under the License.
141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// =============================================================================
15f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_
16f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_
#include <memory>
#include <utility>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorflow {
211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorforest {
221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
234de7361c43adc65c442359104a17c449a02eee7aA. Unique TensorFlowertypedef TTypes<float, 1>::UnalignedConstTensor SingleDimStorageType;
241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Base class for classes that hold labels and weights. Mostly for testing
261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// purposes, because it's inconvenient to construct nasty Eigen::things.
271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerclass InputTarget {
281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower public:
291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual ~InputTarget() {}
301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual int32 GetTargetAsClassIndex(int example_index,
311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                                      int target_index) const = 0;
321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual float GetTargetWeight(int example_index) const = 0;
341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual float GetTargetAsContinuous(int example_index,
361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                                      int target_index) const = 0;
371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower};
381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowertemplate <typename T>
401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerclass StoredInputTarget : public InputTarget {
411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower protected:
4214101a8ba173b44179b2a9317781f140eb61b0a1A. Unique TensorFlower  // Takes ownership of t and w with a std::unique_ptr.
431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  StoredInputTarget(const T* t, const T* w, int num_targets)
441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      : target_(t), weight_(w), num_targets_(num_targets) {}
451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4614101a8ba173b44179b2a9317781f140eb61b0a1A. Unique TensorFlower  const std::unique_ptr<const T> target_;
4714101a8ba173b44179b2a9317781f140eb61b0a1A. Unique TensorFlower  const std::unique_ptr<const T> weight_;
481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int num_targets_;
491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower};
501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Holds labels/targets and weights. Assumes that tensors are passed as
521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// t.unaligned_flat<float>(). For multi-output, specifying the number of
531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// outputs will correctly index the flattened data.
541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerclass TensorInputTarget : public StoredInputTarget<SingleDimStorageType> {
551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower public:
5614101a8ba173b44179b2a9317781f140eb61b0a1A. Unique TensorFlower  TensorInputTarget(const Tensor& target, const Tensor& weight, int num_targets)
574de7361c43adc65c442359104a17c449a02eee7aA. Unique TensorFlower      : StoredInputTarget(
584de7361c43adc65c442359104a17c449a02eee7aA. Unique TensorFlower            new SingleDimStorageType(target.unaligned_flat<float>()),
594de7361c43adc65c442359104a17c449a02eee7aA. Unique TensorFlower            new SingleDimStorageType(weight.unaligned_flat<float>()),
604de7361c43adc65c442359104a17c449a02eee7aA. Unique TensorFlower            num_targets),
6114101a8ba173b44179b2a9317781f140eb61b0a1A. Unique TensorFlower        original_tensor_(target) {}
621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int32 GetTargetAsClassIndex(int example_index,
641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                              int target_index) const override {
651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return static_cast<int32>(
661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        GetTargetAsContinuous(example_index, target_index));
671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float GetTargetWeight(int example_index) const override {
701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const size_t num_weights = weight_->size();
711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return num_weights > 0 && example_index < num_weights
721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower               ? (*weight_)(example_index)
731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower               : 1.0;
741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float GetTargetAsContinuous(int example_index,
771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                              int target_index) const override {
781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    QCHECK_LT(target_index, num_targets_);
791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return (*target_)(example_index * num_targets_ + target_index);
801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
824463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower  const Tensor& original_tensor() const { return original_tensor_; }
831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower protected:
851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  Tensor original_tensor_;
861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower};
871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}  // namespace tensorforest
881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}  // namespace tensorflow
891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
90f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#endif  // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_
91