1// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15#include "tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h"
16
17namespace tensorflow {
18namespace tensorforest {
19
20using decision_trees::Leaf;
21
22std::unique_ptr<LeafModelOperator>
23LeafModelOperatorFactory::CreateLeafModelOperator(
24    const TensorForestParams& params) {
25  switch (params.leaf_type()) {
26    case MODEL_DENSE_CLASSIFICATION:
27      return std::unique_ptr<LeafModelOperator>(
28          new DenseClassificationLeafModelOperator(params));
29
30    case MODEL_SPARSE_CLASSIFICATION:
31      return std::unique_ptr<LeafModelOperator>(
32          new SparseClassificationLeafModelOperator(params));
33
34    case MODEL_SPARSE_OR_DENSE_CLASSIFICATION:
35      return std::unique_ptr<LeafModelOperator>(
36          new SparseOrDenseClassificationLeafModelOperator(params));
37
38    case MODEL_REGRESSION:
39      return std::unique_ptr<LeafModelOperator>(
40          new RegressionLeafModelOperator(params));
41
42    default:
43      LOG(ERROR) << "Unknown model operator: " << params.leaf_type();
44      return nullptr;
45  }
46}
47
48// ------------------------ Dense ----------------------------- //
49float DenseClassificationLeafModelOperator::GetOutputValue(
50    const decision_trees::Leaf& leaf, int32 o) const {
51  return leaf.vector().value(o).float_value();
52}
53
54void DenseClassificationLeafModelOperator::UpdateModel(
55    Leaf* leaf, const InputTarget* target, int example) const {
56  const int32 int_label = target->GetTargetAsClassIndex(example, 0);
57  QCHECK_LT(int_label, params_.num_outputs())
58      << "Got label greater than indicated number of classes. Is "
59         "params.num_classes set correctly?";
60  QCHECK_GE(int_label, 0);
61  auto* val = leaf->mutable_vector()->mutable_value(int_label);
62
63  float weight = target->GetTargetWeight(example);
64  val->set_float_value(val->float_value() + weight);
65}
66
67void DenseClassificationLeafModelOperator::InitModel(Leaf* leaf) const {
68  for (int i = 0; i < params_.num_outputs(); ++i) {
69    leaf->mutable_vector()->add_value();
70  }
71}
72
73void DenseClassificationLeafModelOperator::ExportModel(
74    const LeafStat& stat, decision_trees::Leaf* leaf) const {
75  *leaf->mutable_vector() = stat.classification().dense_counts();
76}
77
78// ------------------------- Sparse -------------------------- //
79float SparseClassificationLeafModelOperator::GetOutputValue(
80    const decision_trees::Leaf& leaf, int32 o) const {
81  const auto it = leaf.sparse_vector().sparse_value().find(o);
82  if (it == leaf.sparse_vector().sparse_value().end()) {
83    return 0;  // default value
84  } else {
85    return it->second.float_value();
86  }
87}
88
89void SparseClassificationLeafModelOperator::UpdateModel(
90    Leaf* leaf, const InputTarget* target, int example) const {
91  const int32 int_label = target->GetTargetAsClassIndex(example, 0);
92  QCHECK_LT(int_label, params_.num_outputs())
93      << "Got label greater than indicated number of classes. Is "
94         "params.num_classes set correctly?";
95  QCHECK_GE(int_label, 0);
96  const float weight = target->GetTargetWeight(example);
97
98  auto value_map = leaf->mutable_sparse_vector()->mutable_sparse_value();
99  auto it = value_map->find(int_label);
100  if (it == value_map->end()) {
101    (*value_map)[int_label].set_float_value(weight);
102  } else {
103    it->second.set_float_value(it->second.float_value() + weight);
104  }
105}
106
107void SparseClassificationLeafModelOperator::ExportModel(
108    const LeafStat& stat, decision_trees::Leaf* leaf) const {
109  *leaf->mutable_sparse_vector() = stat.classification().sparse_counts();
110}
111
112// ------------------------- SparseOrDense -------------------------- //
113float SparseOrDenseClassificationLeafModelOperator::GetOutputValue(
114    const decision_trees::Leaf& leaf, int32 o) const {
115  if (leaf.has_vector()) {
116    return dense_->GetOutputValue(leaf, o);
117  } else {
118    return sparse_->GetOutputValue(leaf, o);
119  }
120}
121
122void SparseOrDenseClassificationLeafModelOperator::UpdateModel(
123    Leaf* leaf, const InputTarget* target, int example) const {
124  if (leaf->has_vector()) {
125    return dense_->UpdateModel(leaf, target, example);
126  } else {
127    return sparse_->UpdateModel(leaf, target, example);
128  }
129}
130
131void SparseOrDenseClassificationLeafModelOperator::ExportModel(
132    const LeafStat& stat, decision_trees::Leaf* leaf) const {
133  if (stat.classification().has_dense_counts()) {
134    return dense_->ExportModel(stat, leaf);
135  } else {
136    return sparse_->ExportModel(stat, leaf);
137  }
138}
139
140// ------------------------ Regression ----------------------------- //
141float RegressionLeafModelOperator::GetOutputValue(
142    const decision_trees::Leaf& leaf, int32 o) const {
143  return leaf.vector().value(o).float_value();
144}
145
146void RegressionLeafModelOperator::InitModel(Leaf* leaf) const {
147  for (int i = 0; i < params_.num_outputs(); ++i) {
148    leaf->mutable_vector()->add_value();
149  }
150}
151
152void RegressionLeafModelOperator::ExportModel(
153    const LeafStat& stat, decision_trees::Leaf* leaf) const {
154  leaf->clear_vector();
155  for (int i = 0; i < params_.num_outputs(); ++i) {
156    const float new_val =
157        stat.regression().mean_output().value(i).float_value() /
158        stat.weight_sum();
159    leaf->mutable_vector()->add_value()->set_float_value(new_val);
160  }
161}
162
163}  // namespace tensorforest
164}  // namespace tensorflow
165