1// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15#include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h" 16#include <cfloat> 17 18#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" 19 20namespace tensorflow { 21namespace tensorforest { 22 23// When using smoothing but only tracking sum and squares, and we're adding 24// num_classes for smoothing each class, then Gini looks more like this: 25// Gini = 1 - \sum_i (c_i + 1)^2 / C^2 26// = 1 - (1 / C^2) ( (\sum_i c_i)^2 + 2 (\sum_i c_i) + (\sum_i 1)) 27// = 1 - (1 / C^2) ( stats.square() + 2 stats.sum() + #_classes) 28// = 1 - ( stats.square() + 2 stats.sum() + #_classes) / (smoothed_sum * 29// smoothed_sum) 30// 31// where 32// smoothed_sum = stats.sum() + #_classes 33float GiniImpurity(const LeafStat& stats, int32 num_classes) { 34 const float smoothed_sum = num_classes + stats.weight_sum(); 35 return 1.0 - ((stats.classification().gini().square() + 36 2 * stats.weight_sum() + num_classes) / 37 (smoothed_sum * smoothed_sum)); 38} 39 40float WeightedGiniImpurity(const LeafStat& stats, int32 num_classes) { 41 return stats.weight_sum() * GiniImpurity(stats, num_classes); 42} 43 44void UpdateGini(LeafStat* stats, float old_val, float weight) { 45 stats->set_weight_sum(stats->weight_sum() + weight); 46 // Equivalent to stats->square() - old_val * old_val + new_val * new_val, 47 // (for new_val = old_val + weight), but more numerically stable. 48 stats->mutable_classification()->mutable_gini()->set_square( 49 stats->classification().gini().square() + weight * weight + 50 2 * old_val * weight); 51} 52 53float Variance(const LeafStat& stats, int output) { 54 if (stats.weight_sum() == 0) { 55 return 0; 56 } 57 const float e_x = 58 stats.regression().mean_output().value(output).float_value() / 59 stats.weight_sum(); 60 const auto e_x2 = 61 stats.regression().mean_output_squares().value(output).float_value() / 62 stats.weight_sum(); 63 return e_x2 - e_x * e_x; 64} 65 66float TotalVariance(const LeafStat& stats) { 67 float sum = 0; 68 for (int i = 0; i < stats.regression().mean_output().value_size(); ++i) { 69 sum += Variance(stats, i); 70 } 71 return sum; 72} 73 74float SmoothedGini(float sum, float square, int num_classes) { 75 // See comments for GiniImpurity above. 76 const float smoothed_sum = num_classes + sum; 77 return 1.0 - (square + 2 * sum + num_classes) / (smoothed_sum * smoothed_sum); 78} 79 80float WeightedSmoothedGini(float sum, float square, int num_classes) { 81 return sum * SmoothedGini(sum, square, num_classes); 82} 83 84} // namespace tensorforest 85} // namespace tensorflow 86