1// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15#include "tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h" 16 17namespace tensorflow { 18namespace tensorforest { 19 20using decision_trees::Leaf; 21 22std::unique_ptr<LeafModelOperator> 23LeafModelOperatorFactory::CreateLeafModelOperator( 24 const TensorForestParams& params) { 25 switch (params.leaf_type()) { 26 case MODEL_DENSE_CLASSIFICATION: 27 return std::unique_ptr<LeafModelOperator>( 28 new DenseClassificationLeafModelOperator(params)); 29 30 case MODEL_SPARSE_CLASSIFICATION: 31 return std::unique_ptr<LeafModelOperator>( 32 new SparseClassificationLeafModelOperator(params)); 33 34 case MODEL_SPARSE_OR_DENSE_CLASSIFICATION: 35 return std::unique_ptr<LeafModelOperator>( 36 new SparseOrDenseClassificationLeafModelOperator(params)); 37 38 case MODEL_REGRESSION: 39 return std::unique_ptr<LeafModelOperator>( 40 new RegressionLeafModelOperator(params)); 41 42 default: 43 LOG(ERROR) << "Unknown model operator: " << params.leaf_type(); 44 return nullptr; 45 } 46} 47 48// ------------------------ Dense ----------------------------- // 49float DenseClassificationLeafModelOperator::GetOutputValue( 50 const decision_trees::Leaf& leaf, int32 o) const { 51 return leaf.vector().value(o).float_value(); 52} 53 54void DenseClassificationLeafModelOperator::UpdateModel( 55 Leaf* leaf, const InputTarget* target, int example) const { 56 const int32 int_label = target->GetTargetAsClassIndex(example, 0); 57 QCHECK_LT(int_label, params_.num_outputs()) 58 << "Got label greater than indicated number of classes. Is " 59 "params.num_classes set correctly?"; 60 QCHECK_GE(int_label, 0); 61 auto* val = leaf->mutable_vector()->mutable_value(int_label); 62 63 float weight = target->GetTargetWeight(example); 64 val->set_float_value(val->float_value() + weight); 65} 66 67void DenseClassificationLeafModelOperator::InitModel(Leaf* leaf) const { 68 for (int i = 0; i < params_.num_outputs(); ++i) { 69 leaf->mutable_vector()->add_value(); 70 } 71} 72 73void DenseClassificationLeafModelOperator::ExportModel( 74 const LeafStat& stat, decision_trees::Leaf* leaf) const { 75 *leaf->mutable_vector() = stat.classification().dense_counts(); 76} 77 78// ------------------------- Sparse -------------------------- // 79float SparseClassificationLeafModelOperator::GetOutputValue( 80 const decision_trees::Leaf& leaf, int32 o) const { 81 const auto it = leaf.sparse_vector().sparse_value().find(o); 82 if (it == leaf.sparse_vector().sparse_value().end()) { 83 return 0; // default value 84 } else { 85 return it->second.float_value(); 86 } 87} 88 89void SparseClassificationLeafModelOperator::UpdateModel( 90 Leaf* leaf, const InputTarget* target, int example) const { 91 const int32 int_label = target->GetTargetAsClassIndex(example, 0); 92 QCHECK_LT(int_label, params_.num_outputs()) 93 << "Got label greater than indicated number of classes. Is " 94 "params.num_classes set correctly?"; 95 QCHECK_GE(int_label, 0); 96 const float weight = target->GetTargetWeight(example); 97 98 auto value_map = leaf->mutable_sparse_vector()->mutable_sparse_value(); 99 auto it = value_map->find(int_label); 100 if (it == value_map->end()) { 101 (*value_map)[int_label].set_float_value(weight); 102 } else { 103 it->second.set_float_value(it->second.float_value() + weight); 104 } 105} 106 107void SparseClassificationLeafModelOperator::ExportModel( 108 const LeafStat& stat, decision_trees::Leaf* leaf) const { 109 *leaf->mutable_sparse_vector() = stat.classification().sparse_counts(); 110} 111 112// ------------------------- SparseOrDense -------------------------- // 113float SparseOrDenseClassificationLeafModelOperator::GetOutputValue( 114 const decision_trees::Leaf& leaf, int32 o) const { 115 if (leaf.has_vector()) { 116 return dense_->GetOutputValue(leaf, o); 117 } else { 118 return sparse_->GetOutputValue(leaf, o); 119 } 120} 121 122void SparseOrDenseClassificationLeafModelOperator::UpdateModel( 123 Leaf* leaf, const InputTarget* target, int example) const { 124 if (leaf->has_vector()) { 125 return dense_->UpdateModel(leaf, target, example); 126 } else { 127 return sparse_->UpdateModel(leaf, target, example); 128 } 129} 130 131void SparseOrDenseClassificationLeafModelOperator::ExportModel( 132 const LeafStat& stat, decision_trees::Leaf* leaf) const { 133 if (stat.classification().has_dense_counts()) { 134 return dense_->ExportModel(stat, leaf); 135 } else { 136 return sparse_->ExportModel(stat, leaf); 137 } 138} 139 140// ------------------------ Regression ----------------------------- // 141float RegressionLeafModelOperator::GetOutputValue( 142 const decision_trees::Leaf& leaf, int32 o) const { 143 return leaf.vector().value(o).float_value(); 144} 145 146void RegressionLeafModelOperator::InitModel(Leaf* leaf) const { 147 for (int i = 0; i < params_.num_outputs(); ++i) { 148 leaf->mutable_vector()->add_value(); 149 } 150} 151 152void RegressionLeafModelOperator::ExportModel( 153 const LeafStat& stat, decision_trees::Leaf* leaf) const { 154 leaf->clear_vector(); 155 for (int i = 0; i < params_.num_outputs(); ++i) { 156 const float new_val = 157 stat.regression().mean_output().value(i).float_value() / 158 stat.weight_sum(); 159 leaf->mutable_vector()->add_value()->set_float_value(new_val); 160 } 161} 162 163} // namespace tensorforest 164} // namespace tensorflow 165