// grow_stats.cc revision 75f03e2d509d016021f8508555f9ab96af2c7cfe
11588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 21588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// 31588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Licensed under the Apache License, Version 2.0 (the "License"); 41588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// you may not use this file except in compliance with the License. 51588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// You may obtain a copy of the License at 61588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// 71588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// http://www.apache.org/licenses/LICENSE-2.0 81588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// 91588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Unless required by applicable law or agreed to in writing, software 101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// distributed under the License is distributed on an "AS IS" BASIS, 111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// See the License for the specific language governing permissions and 131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// limitations under the License. 141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// ============================================================================= 151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h" 161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include <cfloat> 181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include <queue> 191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h" 201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h" 211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/core/lib/random/distribution_sampler.h" 221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorflow { 251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorforest { 261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// When creating evaluators for the split candidates, use these 281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// for the left and right return values. 291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerstatic const int32 LEFT_INDEX = 0; 301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerstatic const int32 RIGHT_INDEX = 1; 311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerGrowStats::GrowStats(const TensorForestParams& params, int32 depth) 3308a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower : weight_sum_(0), 3408a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower depth_(depth), 351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower params_(params), 361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower split_after_samples_(ResolveParam(params.split_after_samples(), depth)), 371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower num_splits_to_consider_( 381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ResolveParam(params.num_splits_to_consider(), depth)), 391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower num_outputs_(params.num_outputs()) {} 401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlowervoid GrowStats::AddSplit(const decision_trees::BinaryNode& split, 4275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower const std::unique_ptr<TensorDataSet>& input_data, 4375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower const InputTarget* target, int example) { 4475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower // It's possible that the split collection calls AddSplit, but we actually 4575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower // have all the splits we need and are just waiting for them to be fully 4675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower // initialized. 4775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower if (splits_.size() < num_splits_to_consider_) { 4875f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower splits_.push_back(split); 4975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower evaluators_.emplace_back( 5075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower CreateBinaryDecisionNodeEvaluator(split, LEFT_INDEX, RIGHT_INDEX)); 5175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower AddSplitStats(target, example); 5275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower } 5375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower 5475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower if (input_data != nullptr && target != nullptr && 5575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower params_.initialize_average_splits()) { 5675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower AdditionalInitializationExample(input_data, target, example); 5775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower } 581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower} 591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid GrowStats::RemoveSplit(int split_num) { 611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower splits_.erase(splits_.begin() + split_num); 621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower evaluators_.erase(evaluators_.begin() + split_num); 631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower RemoveSplitStats(split_num); 641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// ------------------------ Classification --------------------------- // 671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerClassificationStats::ClassificationStats(const TensorForestParams& params, 691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int32 depth) 701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower : GrowStats(params, depth), finish_early_(false) { 711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Early splitting params. 721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (params.finish_type().type() == SPLIT_FINISH_BASIC) { 731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower min_split_samples_ = split_after_samples_; 7408a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower finish_sample_epoch_ = 1; 7508a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower finish_check_every_ = split_after_samples_ * 2; 761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } else { 771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (!params.has_dominate_fraction() || !params.has_min_split_samples()) { 781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower LOG(FATAL) << "dominate_fraction and min_split_samples " 791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower << "required for early-finish strategy."; 801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } else { 811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower min_split_samples_ = ResolveParam(params.min_split_samples(), depth); 821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower finish_check_every_ = 831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ResolveParam(params.finish_type().check_every_steps(), depth); 841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower finish_sample_epoch_ = min_split_samples_ / finish_check_every_; 851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_); 871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (dominate_fraction_ <= 0 || dominate_fraction_ > 1.0) { 881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower LOG(FATAL) << "Invalid dominate fraction " << dominate_fraction_; 891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Pruning params. 941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (params.pruning_type().type() != SPLIT_PRUNE_NONE) { 951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower prune_check_every_ = 961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ResolveParam(params.pruning_type().prune_every_samples(), depth); 971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower prune_sample_epoch_ = 1; 981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower prune_fraction_ = 0.0; 991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower switch (params_.pruning_type().type()) { 1001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower case SPLIT_PRUNE_HALF: 1011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower prune_fraction_ = 0.5; 1021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower break; 1031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower case SPLIT_PRUNE_QUARTER: 1041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower prune_fraction_ = 0.25; 1051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower break; 1061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower case SPLIT_PRUNE_10_PERCENT: 1071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower prune_fraction_ = 0.10; 1081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower break; 1091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower case SPLIT_PRUNE_HOEFFDING: 1101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_); 1111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower half_ln_dominate_frac_ = 0.5 * log(1.0 / (1.0 - dominate_fraction_)); 1121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower break; 1131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower default: 1141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower LOG(WARNING) << "Unknown pruning type"; 1151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 11608a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower } else { 11708a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower prune_check_every_ = split_after_samples_ * 2; 11808a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. 
Unique TensorFlower prune_sample_epoch_ = 1; 1191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 1201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (params.use_running_stats_method()) { 1221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_gini_.reset(new RunningGiniScores()); 1231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right_gini_.reset(new RunningGiniScores()); 1241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 1251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower uint64 time_seed = static_cast<uint64>(std::clock()); 1271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower single_rand_ = std::unique_ptr<random::PhiloxRandom>( 1281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower new random::PhiloxRandom(time_seed)); 1291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower rng_ = std::unique_ptr<random::SimplePhilox>( 1301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower new random::SimplePhilox(single_rand_.get())); 1311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 1321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 13375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlowervoid ClassificationStats::AdditionalInitializationExample( 13475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target, 13575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower int example) { 13675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower const int32 new_target = target->GetTargetAsClassIndex(example, 0); 13775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower std::unordered_set<int> to_erase; 13875f03e2d509d016021f8508555f9ab96af2c7cfeA. 
Unique TensorFlower for (auto it = half_initialized_splits_.begin(); 13975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower it != half_initialized_splits_.end(); ++it) { 14075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower if (it->second != new_target) { 14175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower auto& split = splits_[it->first]; 14275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower if (split.has_inequality_left_child_test()) { 14375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower auto& test = split.inequality_left_child_test(); 14475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower auto* thresh = 14575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower split.mutable_inequality_left_child_test()->mutable_threshold(); 14675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower if (test.has_feature_id()) { 14775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower const float val = 14875f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower input_data->GetExampleValue(example, test.feature_id()); 14975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower thresh->set_float_value((thresh->float_value() + val) / 2); 15075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower } 15175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower } 15275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower to_erase.insert(it->first); 15375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower } 15475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower } 15575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower 15675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower for (const int split_id : to_erase) { 15775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower half_initialized_splits_.erase(split_id); 15875f03e2d509d016021f8508555f9ab96af2c7cfeA. 
Unique TensorFlower } 15975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower} 16075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower 1611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerbool ClassificationStats::IsFinished() const { 1621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower bool basic = weight_sum_ >= split_after_samples_ && num_outputs_seen() > 1; 1631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return basic || finish_early_; 1641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 1651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat ClassificationStats::MaybeCachedGiniScore(int split, float* left_sum, 1671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float* right_sum) const { 1681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (left_gini_ == nullptr) { 1691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return GiniScore(split, left_sum, right_sum); 1701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } else { 1711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *left_sum = left_gini_->sum(split); 1721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float left = WeightedSmoothedGini( 1731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *left_sum, left_gini_->square(split), num_outputs_); 1741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *right_sum = right_gini_->sum(split); 1761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float right = WeightedSmoothedGini( 1771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *right_sum, right_gini_->square(split), num_outputs_); 1781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower return left + right; 1801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 1811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 1821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::AddExample( 1841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target, 1851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int example) { 1861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int64 int_label = target->GetTargetAsClassIndex(example, 0); 1871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float weight = target->GetTargetWeight(example); 1881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_splits(); ++i) { 1901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto& eval = evaluators_[i]; 1911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (eval->Decide(input_data, example) == LEFT_INDEX) { 1921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (left_gini_ != nullptr) { 1931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_gini_->update(i, left_count(i, int_label), weight); 1941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 1951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ClassificationAddLeftExample(i, int_label, weight); 1961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } else if (right_gini_ != nullptr) { 1971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right_gini_->update(i, right_count(i, int_label), weight); 1981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 1991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower } 2001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ClassificationAddTotalExample(int_label, weight); 2021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower weight_sum_ += weight; 2041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower CheckFinishEarly(); 2061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower CheckPrune(); 2071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 2081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckPrune() { 2101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (IsFinished() || weight_sum_ < prune_sample_epoch_ * prune_check_every_) { 2111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return; 2121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ++prune_sample_epoch_; 2141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (params_.pruning_type().type() == SPLIT_PRUNE_HOEFFDING) { 2161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower CheckPruneHoeffding(); 2171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return; 2181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int to_remove = num_splits() * prune_fraction_; 2211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (to_remove <= 0) { 2221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower return; 2231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // pair ordering is first-then-second by default, no need for custom 2261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // comparison. Use std::greater to make it a min-heap. 2271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int>>, 2281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::greater<std::pair<float, int>>> 2291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower worst; 2301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Track indices that are in the heap so we can iterate over them 2321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // by largest-first later. 2331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::set<int> indices; 2341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_splits(); ++i) { 2361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float left, right; 2371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float split_score = MaybeCachedGiniScore(i, &left, &right); 2381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (worst.size() < to_remove) { 2391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower worst.push(std::pair<float, int>(split_score, i)); 2401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower indices.insert(i); 2411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } else if (worst.top().first < split_score) { 2421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower indices.erase(worst.top().second); 2431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower worst.pop(); 2441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower worst.push(std::pair<float, int>(split_score, i)); 2451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower indices.insert(i); 2461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // traverse indices from the back so that they are removed correctly. 2501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (auto it = indices.rbegin(); it != indices.rend(); ++it) { 2511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower RemoveSplit(*it); 2521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 2541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckPruneHoeffding() { 2561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<float> split_scores(num_splits()); 2571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Find best split score 2581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float best_split_score = FLT_MAX; 2591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_splits(); ++i) { 2601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float left, right; 2611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower split_scores[i] = MaybeCachedGiniScore(i, &left, &right); 2621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (split_scores[i] < best_split_score) { 2631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower best_split_score = split_scores[i]; 2641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // We apply the Hoeffding bound to the difference between the best split 2681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // score and the i-th split score. 2691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Raw Gini ranges from 0 to 1 - (1/n), but our gini score is weighted. 2701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float num_classes = params_.num_outputs(); 2711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float gini_diff_range = weight_sum_ * (1.0 - 1.0 / num_classes); 2721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float epsilon = gini_diff_range * sqrt(half_ln_dominate_frac_ / weight_sum_); 2731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = num_splits() - 1; i >= 0; i--) { 2741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (split_scores[i] - best_split_score > epsilon) { 2751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower RemoveSplit(i); 2761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 2791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckFinishEarly() { 2811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (weight_sum_ < min_split_samples_ || 2821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower weight_sum_ < finish_sample_epoch_ * finish_check_every_) { 2831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return; 2841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ++finish_sample_epoch_; 2861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_HOEFFDING) { 2881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower CheckFinishEarlyHoeffding(); 2891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } else if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_BOOTSTRAP) { 2901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower CheckFinishEarlyBootstrap(); 2911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 2931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckFinishEarlyHoeffding() { 2951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Each term in the Gini impurity can range from 0 to 0.5 * 0.5. 2961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float range = 0.25 * static_cast<float>(params_.num_outputs()) * weight_sum_; 2971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float hoeffding_bound = 2991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower range * sqrt(log(1.0 / (1.0 - dominate_fraction_)) / (2.0 * weight_sum_)); 3001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float unused_left_sum, unused_right_sum; 3021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower std::function<float(int)> score_fn = 3031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::bind(&ClassificationStats::MaybeCachedGiniScore, this, 3041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::placeholders::_1, &unused_left_sum, &unused_right_sum); 3051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float best_score; 3071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int32 best_index; 3081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float second_best_score; 3091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int32 second_best_index; 3101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower GetTwoBest(num_splits(), score_fn, &best_score, &best_index, 3111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower &second_best_score, &second_best_index); 3121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower finish_early_ = (second_best_score - best_score) > hoeffding_bound; 3141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 3151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::MakeBootstrapWeights(int index, 3171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<float>* weights) { 3181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int n = weight_sum_; 3191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float denom = static_cast<float>(n) + static_cast<float>(num_outputs_); 3201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_outputs_; ++i) { 3211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower // Use the Laplace smoothed per-class probabilities when generating the 3221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // bootstrap samples. 3231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower (*weights)[i] = (left_count(index, i) + 1.0) / denom; 3241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower (*weights)[num_outputs_ + i] = (right_count(index, i) + 1.0) / denom; 3251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 3271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerint ClassificationStats::NumBootstrapSamples() const { 3291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float p = 1.0 - dominate_fraction_; 3301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int bootstrap_samples = 1; 3311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower while (p < 1.0) { 3321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ++bootstrap_samples; 3331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower p = p * 2; 3341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return bootstrap_samples; 3361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 3371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckFinishEarlyBootstrap() { 3391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float unused_left_sum, unused_right_sum; 3401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::function<float(int)> score_fn = 3411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::bind(&ClassificationStats::MaybeCachedGiniScore, this, 3421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower std::placeholders::_1, &unused_left_sum, &unused_right_sum); 3431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float best_score; 3451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int32 best_index; 3461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float second_best_score; 3471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int32 second_best_index; 3481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower GetTwoBest(num_splits(), score_fn, &best_score, &best_index, 3491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower &second_best_score, &second_best_index); 3501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<float> weights1(num_outputs_ * 2); 3521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower MakeBootstrapWeights(best_index, &weights1); 3531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower random::DistributionSampler ds1(weights1); 3541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<float> weights2(num_outputs_ * 2); 3561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower MakeBootstrapWeights(second_best_index, &weights2); 3571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower random::DistributionSampler ds2(weights2); 3581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int bootstrap_samples = NumBootstrapSamples(); 3601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int worst_g1 = 0; 3621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower for (int i = 0; i < bootstrap_samples; i++) { 3631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int g1 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds1, rng_.get()); 3641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower worst_g1 = std::max(worst_g1, g1); 3651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int best_g2 = 99; 3681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < bootstrap_samples; i++) { 3691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int g2 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds2, rng_.get()); 3701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower best_g2 = std::min(best_g2, g2); 3711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower finish_early_ = worst_g1 < best_g2; 3741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 3751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// ------------------------ Dense Classification --------------------------- // 3771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid DenseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) { 3781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower Initialize(); 3791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (!slot.has_post_init_leaf_stats()) { 3801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return; 3811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower const int32 num_classes = params_.num_outputs(); 3831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower weight_sum_ = slot.post_init_leaf_stats().weight_sum(); 3841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const auto& class_stats = 3851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower slot.post_init_leaf_stats().classification().dense_counts(); 3861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Total counts. 3881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_classes; ++i) { 3891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_counts_[i] = class_stats.value(i).float_value(); 3901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower num_outputs_seen_ += total_counts_[i] != 0; 3911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Candidate counts and splits. 3941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int split_num = 0; 3951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (const auto& cand : slot.candidates()) { 39675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower AddSplit(cand.split(), nullptr, nullptr, -1); 3971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const auto& left_stats = cand.left_stats().classification().dense_counts(); 3981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_classes; ++i) { 3991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float val = left_stats.value(i).float_value(); 4001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower mutable_left_count(split_num, i) = val; 4011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower MaybeInitializeRunningCount(split_num, val); 4021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ++split_num; 4041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 4061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid DenseClassificationGrowStats::PackToProto(FertileSlot* slot) const { 4081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* slot_stats = slot->mutable_post_init_leaf_stats(); 4091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower slot_stats->set_weight_sum(weight_sum_); 4101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* class_stats = slot->mutable_post_init_leaf_stats() 4121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_classification() 4131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_dense_counts(); 4141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_outputs_; ++i) { 4151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower class_stats->add_value()->set_float_value(total_counts_[i]); 4161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int split_num = 0; split_num < num_splits(); ++split_num) { 4191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* cand = slot->add_candidates(); 4201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *cand->mutable_split() = splits_[split_num]; 4211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower auto* left_stats = cand->mutable_left_stats() 4221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_classification() 4231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_dense_counts(); 4241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_outputs_; ++i) { 4251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_stats->add_value()->set_float_value(left_count(split_num, i)); 4261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 4291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat DenseClassificationGrowStats::GiniScore(int split, float* left_sum, 4311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float* right_sum) const { 4321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float left_square = 0, right_square = 0; 4331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *left_sum = 0; 4341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *right_sum = 0; 4351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int j = 0; j < num_outputs_; ++j) { 4361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float left = left_count(split, j); 4371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *left_sum += left; 4381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_square += left * left; 4391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float right = right_count(split, j); 4401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *right_sum += right; 4411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right_square += right * right; 4421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower } 4431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float left_score = 4451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower WeightedSmoothedGini(*left_sum, left_square, num_outputs_); 4461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float right_score = 4471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower WeightedSmoothedGini(*right_sum, right_square, num_outputs_); 4481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return left_score + right_score; 4491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 4501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerbool DenseClassificationGrowStats::BestSplit(SplitCandidate* best) const { 4521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float min_score = FLT_MAX; 4531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int best_index = -1; 4541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float best_left_sum, best_right_sum; 4551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Calculate sums. 4571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_splits(); ++i) { 4581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float left_sum, right_sum; 4591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum); 4601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Find the lowest gini. 4611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (left_sum > 0 && right_sum > 0 && 4621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower split_score < min_score) { // useless check 4631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower min_score = split_score; 4641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower best_index = i; 4651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower best_left_sum = left_sum; 4661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower best_right_sum = right_sum; 4671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // This could happen if all the splits are useless. 4711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (best_index < 0) { 4721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return false; 4731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Fill in stats to be used for leaf model. 4761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *best->mutable_split() = splits_[best_index]; 4771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Left 4781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* left = best->mutable_left_stats(); 4791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* left_class_stats = left->mutable_classification(); 4801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left->set_weight_sum(best_left_sum); 4811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* left_counts = left_class_stats->mutable_dense_counts(); 4821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < params_.num_outputs(); ++i) { 4831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower left_counts->add_value()->set_float_value( 4841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_count(best_index, i)); 4851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Right 4881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* right = best->mutable_right_stats(); 4891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* right_class_stats = right->mutable_classification(); 4901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right->set_weight_sum(best_right_sum); 4911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* right_counts = right_class_stats->mutable_dense_counts(); 4921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < params_.num_outputs(); ++i) { 4931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right_counts->add_value()->set_float_value( 4941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_counts_[i] - left_count(best_index, i)); 4951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return true; 4971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 4981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// ------------------------ Sparse Classification --------------------------- // 5001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid SparseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) { 5011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower Initialize(); 5021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (!slot.has_post_init_leaf_stats()) { 5031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower return; 5041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower weight_sum_ = slot.post_init_leaf_stats().weight_sum(); 5061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const auto& class_stats = 5071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower slot.post_init_leaf_stats().classification().sparse_counts(); 5081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Total counts. 5101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (auto const& entry : class_stats.sparse_value()) { 5111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_counts_[entry.first] = entry.second.float_value(); 5121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Candidate counts and splits. 5151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int split_num = 0; 5161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (const auto& cand : slot.candidates()) { 51775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower AddSplit(cand.split(), nullptr, nullptr, -1); 5181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const auto& left_stats = cand.left_stats().classification().sparse_counts(); 5191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (auto const& entry : left_stats.sparse_value()) { 5201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float val = entry.second.float_value(); 5211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_counts_[split_num][entry.first] = val; 5221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower MaybeInitializeRunningCount(split_num, val); 5231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ++split_num; 5251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 5271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid SparseClassificationGrowStats::PackToProto(FertileSlot* slot) const { 5291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* slot_stats = slot->mutable_post_init_leaf_stats(); 5301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower slot_stats->set_weight_sum(weight_sum_); 5311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* class_stats = slot->mutable_post_init_leaf_stats() 5331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_classification() 5341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_sparse_counts() 5351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_sparse_value(); 5361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (const auto& entry : total_counts_) { 5371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower decision_trees::Value val; 5381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower val.set_float_value(entry.second); 5391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower (*class_stats)[entry.first] = val; 5401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int split_num = 0; split_num < num_splits(); ++split_num) { 5431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower auto* cand = slot->add_candidates(); 5441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *cand->mutable_split() = splits_[split_num]; 5451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* left_stats = cand->mutable_left_stats() 5461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_classification() 5471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_sparse_counts() 5481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_sparse_value(); 5491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (const auto& entry : left_counts_[split_num]) { 5501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower decision_trees::Value val; 5511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower val.set_float_value(entry.second); 5521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower (*left_stats)[entry.first] = val; 5531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 5561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat SparseClassificationGrowStats::GiniScore( 5581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int split, float* left_sum, float* right_sum) const { 5591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float left_square = 0, right_square = 0; 5601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *left_sum = 0; 5611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *right_sum = 0; 5621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (const auto& entry : total_counts_) { 5631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int label = entry.first; 5641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower float left = 0; 5651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float right = 0; 5661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto it = left_counts_[split].find(label); 5671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (it == left_counts_[split].end()) { 5681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right = entry.second; 5691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } else { 5701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left = it->second; 5711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right = entry.second - it->second; 5721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *left_sum += left; 5741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_square += left * left; 5751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *right_sum += right; 5761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right_square += right * right; 5771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int32 num_classes = params_.num_outputs(); 5791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float left_score = 5801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower WeightedSmoothedGini(*left_sum, left_square, num_classes); 5811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float right_score = 5821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower WeightedSmoothedGini(*right_sum, right_square, num_classes); 5831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return left_score + right_score; 5841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 5851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower 5861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerbool SparseClassificationGrowStats::BestSplit(SplitCandidate* best) const { 5871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float min_score = FLT_MAX; 5881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int best_index = -1; 5891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float best_left_sum = -1; 5901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float best_right_sum = -1; 5911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Find the lowest gini. 5931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_splits(); ++i) { 5941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float left_sum, right_sum; 5951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum); 5961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (left_sum > 0 && right_sum > 0 && 5971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower split_score < min_score) { // useless check 5981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower min_score = split_score; 5991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower best_index = i; 6001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower best_left_sum = left_sum; 6011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower best_right_sum = right_sum; 6021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 6031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 6041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 6051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // This could happen if all the splits are useless. 6061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower if (best_index < 0) { 6071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return false; 6081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 6091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 6101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Fill in stats to be used for leaf model. 6111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *best->mutable_split() = splits_[best_index]; 6121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Left 6131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* left = best->mutable_left_stats(); 6141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* left_class_stats = left->mutable_classification(); 6151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left->set_weight_sum(best_left_sum); 6161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* left_counts = 6171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_class_stats->mutable_sparse_counts()->mutable_sparse_value(); 6181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 6191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Right 6201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* right = best->mutable_right_stats(); 6211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* right_class_stats = right->mutable_classification(); 6221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right->set_weight_sum(best_right_sum); 6231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* right_counts = 6241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right_class_stats->mutable_sparse_counts()->mutable_sparse_value(); 6251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 6261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower for (const auto& entry : total_counts_) { 6271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto it = left_counts_[best_index].find(entry.first); 6281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (it == left_counts_[best_index].end()) { 6291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower (*right_counts)[entry.first].set_float_value(entry.second); 6301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } else { 6311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float left = it->second; 6321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float right = entry.second - it->second; 6331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower (*left_counts)[entry.first].set_float_value(left); 6341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (right > 0) { 6351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower (*right_counts)[entry.first].set_float_value(right); 6361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 6371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 6381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 6391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return true; 6401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 6411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 6421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// --------------------- Least Squares Regression --------------------------- // 6431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid LeastSquaresRegressionGrowStats::ExtractFromProto( 6441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const FertileSlot& slot) { 6451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int32 num_outputs = params_.num_outputs(); 6461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower Initialize(); 6471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (!slot.has_post_init_leaf_stats()) { 6481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return; 6491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 6501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower weight_sum_ = slot.post_init_leaf_stats().weight_sum(); 6511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const auto& total_sums = 6521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower slot.post_init_leaf_stats().regression().mean_output(); 6531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const auto& total_squares = 6541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower slot.post_init_leaf_stats().regression().mean_output_squares(); 6551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 6561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Total counts. 6571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_outputs; ++i) { 6581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_sum_[i] = total_sums.value(i).float_value(); 6591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_sum_squares_[i] = total_squares.value(i).float_value(); 6601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 6611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 6621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Candidate counts and splits. 6631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int split_num = 0; 6641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (const auto& cand : slot.candidates()) { 66575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower AddSplit(cand.split(), nullptr, nullptr, -1); 6661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower const auto& sums = cand.left_stats().regression().mean_output(); 6671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const auto& squares = cand.left_stats().regression().mean_output_squares(); 6681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_outputs; ++i) { 6691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_sum(split_num, i) = sums.value(i).float_value(); 6701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_square(split_num, i) = squares.value(i).float_value(); 6711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 6721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_counts_[split_num] = cand.left_stats().weight_sum(); 6731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ++split_num; 6741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 6751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 6761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 6771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid LeastSquaresRegressionGrowStats::PackToProto(FertileSlot* slot) const { 6781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int32 num_outputs = params_.num_outputs(); 6791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* slot_stats = slot->mutable_post_init_leaf_stats(); 6801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower slot_stats->set_weight_sum(weight_sum_); 6811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 6821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* total_sums = slot->mutable_post_init_leaf_stats() 6831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_regression() 6841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_mean_output(); 6851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower auto* total_squares = slot->mutable_post_init_leaf_stats() 6861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_regression() 6871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_mean_output_squares(); 6881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 6891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < total_sum_.size(); ++i) { 6901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_sums->add_value()->set_float_value(total_sum_[i]); 6911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_squares->add_value()->set_float_value(total_sum_squares_[i]); 6921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 6931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 6941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int split_num = 0; split_num < num_splits(); ++split_num) { 6951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* cand = slot->add_candidates(); 6961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *cand->mutable_split() = splits_[split_num]; 6971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* sums = cand->mutable_left_stats() 6981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_regression() 6991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_mean_output(); 7001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* squares = cand->mutable_left_stats() 7011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_regression() 7021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_mean_output_squares(); 7031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_outputs; ++i) { 7041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower sums->add_value()->set_float_value(left_sum(split_num, i)); 7051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower squares->add_value()->set_float_value(left_square(split_num, i)); 7061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 7071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower cand->mutable_left_stats()->set_weight_sum(left_counts_[split_num]); 7081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 7091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 7101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 7111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid LeastSquaresRegressionGrowStats::AddExample( 7121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target, 7131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int example) { 7141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int32 num_outputs = params_.num_outputs(); 7151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Update splits. 7161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_splits(); ++i) { 7171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto& eval = evaluators_[i]; 7181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (eval->Decide(input_data, example) == LEFT_INDEX) { 7191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int j = 0; j < num_outputs; ++j) { 7201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float output = target->GetTargetAsContinuous(example, j); 7211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_sum(i, j) += output; 7221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_square(i, j) += output * output; 7231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower } 7241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ++left_counts_[i]; 7251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 7261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 7271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 7281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Update totals. 7291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_outputs; ++i) { 7301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float output = target->GetTargetAsContinuous(example, i); 7311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_sum_[i] += output; 7321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_sum_squares_[i] += output * output; 7331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 7341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower weight_sum_ += 1.0; 7351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 7361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 7371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat LeastSquaresRegressionGrowStats::SplitVariance(int split) const { 7381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float total_variance = 0; 7391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < params_.num_outputs(); ++i) { 7401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Left side 7411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float le_x = 7421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_sum(split, i) / left_counts_[split]; 7431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 7441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float le_x2 = 7451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower left_square(split, i) / left_counts_[split]; 7461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_variance += le_x2 - le_x * le_x; 7471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 7481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Right side 7491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float re_x = (total_sum_[i] - left_sum(split, i)) / 7501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower (weight_sum_ - left_counts_[split]); 7511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 7521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float re_x2 = 7531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower (total_sum_squares_[i] - left_square(split, i)) / 7541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower (weight_sum_ - left_counts_[split]); 7551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_variance += re_x2 - re_x * re_x; 7561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 7571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return total_variance; 7581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 7591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 7601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerbool LeastSquaresRegressionGrowStats::BestSplit(SplitCandidate* best) const { 7611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float min_score = FLT_MAX; 7621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int best_index = -1; 7631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int32 num_outputs = params_.num_outputs(); 7641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_splits(); ++i) { 7651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower if (left_counts_[i] > 0 && weight_sum_ - left_counts_[i] > 0) { 7661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float split_score = SplitVariance(i); 7671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (split_score < min_score) { 7681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower min_score = split_score; 7691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower best_index = i; 7701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 7711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 7721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 7731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 7741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // This could happen if all the splits are useless. 7751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (best_index < 0) { 7761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return false; 7771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 7781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 7791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Fill in right stats to be used for leaf model. 7801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *best->mutable_split() = splits_[best_index]; 7811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Left 7821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* left = best->mutable_left_stats(); 7831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* left_reg_stats = left->mutable_regression(); 7841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left->set_weight_sum(left_counts_[best_index]); 7851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* left_output_sum = left_reg_stats->mutable_mean_output(); 7861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
Unique TensorFlower for (int i = 0; i < num_outputs; ++i) { 7871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_output_sum->add_value()->set_float_value( 7881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_sum(best_index, i)); 7891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 7901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 7911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Right 7921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* right = best->mutable_right_stats(); 7931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* right_reg_stats = right->mutable_regression(); 7941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right->set_weight_sum(weight_sum_ - left_counts_[best_index]); 7951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* right_output_sum = right_reg_stats->mutable_mean_output(); 7961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_outputs; ++i) { 7971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right_output_sum->add_value()->set_float_value( 7981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_sum_[i] - left_sum(best_index, i)); 7991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 8001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return true; 8011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 8021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 8031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerbool LeastSquaresRegressionGrowStats::IsFinished() const { 8041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return weight_sum_ >= split_after_samples_; 8051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 8061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 8071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. 
}  // namespace tensorforest
}  // namespace tensorflow