11588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 21588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// 31588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Licensed under the Apache License, Version 2.0 (the "License"); 41588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// you may not use this file except in compliance with the License. 51588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// You may obtain a copy of the License at 61588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// 71588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// http://www.apache.org/licenses/LICENSE-2.0 81588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// 91588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Unless required by applicable law or agreed to in writing, software 101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// distributed under the License is distributed on an "AS IS" BASIS, 111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// See the License for the specific language governing permissions and 131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// limitations under the License. 141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// ============================================================================= 151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h" 161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include <cfloat> 181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include <queue> 191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h" 201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h" 211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/core/lib/random/distribution_sampler.h" 221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorflow { 241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorforest { 251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// When creating evaluators for the split candidates, use these 271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// for the left and right return values. 281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerstatic const int32 LEFT_INDEX = 0; 291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerstatic const int32 RIGHT_INDEX = 1; 301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerGrowStats::GrowStats(const TensorForestParams& params, int32 depth) 3208a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower : weight_sum_(0), 3308a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower depth_(depth), 341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower params_(params), 351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower split_after_samples_(ResolveParam(params.split_after_samples(), depth)), 361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower num_splits_to_consider_( 371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ResolveParam(params.num_splits_to_consider(), depth)), 381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower num_outputs_(params.num_outputs()) {} 391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlowervoid GrowStats::AddSplit(const decision_trees::BinaryNode& split, 4175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower const std::unique_ptr<TensorDataSet>& input_data, 4275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower const InputTarget* target, int example) { 4375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower // It's possible that the split collection calls AddSplit, but we actually 4475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower // have all the splits we need and are just waiting for them to be fully 4575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower // initialized. 4675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower if (splits_.size() < num_splits_to_consider_) { 4775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower splits_.push_back(split); 4875f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower evaluators_.emplace_back( 4975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower CreateBinaryDecisionNodeEvaluator(split, LEFT_INDEX, RIGHT_INDEX)); 5075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower AddSplitStats(target, example); 5175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower } 5275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower 5375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower if (input_data != nullptr && target != nullptr && 5475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower params_.initialize_average_splits()) { 5575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower AdditionalInitializationExample(input_data, target, example); 5675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower } 571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid GrowStats::RemoveSplit(int split_num) { 601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower splits_.erase(splits_.begin() + split_num); 611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower evaluators_.erase(evaluators_.begin() + split_num); 621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower RemoveSplitStats(split_num); 631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// ------------------------ Classification --------------------------- // 661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerClassificationStats::ClassificationStats(const TensorForestParams& params, 681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int32 depth) 691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower : GrowStats(params, depth), finish_early_(false) { 701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Early splitting params. 711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (params.finish_type().type() == SPLIT_FINISH_BASIC) { 721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower min_split_samples_ = split_after_samples_; 7308a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower finish_sample_epoch_ = 1; 7408a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower finish_check_every_ = split_after_samples_ * 2; 751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } else { 761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (!params.has_dominate_fraction() || !params.has_min_split_samples()) { 771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower LOG(FATAL) << "dominate_fraction and min_split_samples " 781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower << "required for early-finish strategy."; 791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } else { 801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower min_split_samples_ = ResolveParam(params.min_split_samples(), depth); 811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower finish_check_every_ = 821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ResolveParam(params.finish_type().check_every_steps(), depth); 831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower finish_sample_epoch_ = min_split_samples_ / finish_check_every_; 841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_); 861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (dominate_fraction_ <= 0 || dominate_fraction_ > 1.0) { 871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower LOG(FATAL) << "Invalid dominate fraction " << dominate_fraction_; 881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Pruning params. 931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (params.pruning_type().type() != SPLIT_PRUNE_NONE) { 941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower prune_check_every_ = 951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ResolveParam(params.pruning_type().prune_every_samples(), depth); 961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower prune_sample_epoch_ = 1; 971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower prune_fraction_ = 0.0; 981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower switch (params_.pruning_type().type()) { 991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower case SPLIT_PRUNE_HALF: 1001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower prune_fraction_ = 0.5; 1011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower break; 1021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower case SPLIT_PRUNE_QUARTER: 1031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower prune_fraction_ = 0.25; 1041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower break; 1051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower case SPLIT_PRUNE_10_PERCENT: 1061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower prune_fraction_ = 0.10; 1071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower break; 1081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower case SPLIT_PRUNE_HOEFFDING: 1091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_); 1101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower half_ln_dominate_frac_ = 0.5 * log(1.0 / (1.0 - dominate_fraction_)); 1111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower break; 1121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower default: 1131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower LOG(WARNING) << "Unknown pruning type"; 1141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 11508a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower } else { 11608a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower prune_check_every_ = split_after_samples_ * 2; 11708a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower prune_sample_epoch_ = 1; 1181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 1191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (params.use_running_stats_method()) { 1211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_gini_.reset(new RunningGiniScores()); 1221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right_gini_.reset(new RunningGiniScores()); 1231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 1241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower uint64 time_seed = static_cast<uint64>(std::clock()); 1261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower single_rand_ = std::unique_ptr<random::PhiloxRandom>( 1271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower new random::PhiloxRandom(time_seed)); 1281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower rng_ = std::unique_ptr<random::SimplePhilox>( 1291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower new random::SimplePhilox(single_rand_.get())); 1301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 1311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 13275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlowervoid ClassificationStats::AdditionalInitializationExample( 13375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target, 13475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower int example) { 13575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower const int32 new_target = target->GetTargetAsClassIndex(example, 0); 13675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower std::unordered_set<int> to_erase; 13775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower for (auto it = half_initialized_splits_.begin(); 13875f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower it != half_initialized_splits_.end(); ++it) { 13975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower if (it->second != new_target) { 14075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower auto& split = splits_[it->first]; 14175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower if (split.has_inequality_left_child_test()) { 14275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower auto& test = split.inequality_left_child_test(); 14375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower auto* thresh = 14475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower split.mutable_inequality_left_child_test()->mutable_threshold(); 14575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower if (test.has_feature_id()) { 14675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower const float val = 14775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower input_data->GetExampleValue(example, test.feature_id()); 14875f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower thresh->set_float_value((thresh->float_value() + val) / 2); 14975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower } 15075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower } 15175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower to_erase.insert(it->first); 15275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower } 15375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower } 15475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower 15575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower for (const int split_id : to_erase) { 15675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower half_initialized_splits_.erase(split_id); 15775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower } 15875f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower} 15975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower 1601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerbool ClassificationStats::IsFinished() const { 161054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower bool basic = (weight_sum_ >= split_after_samples_) && !is_pure(); 1621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return basic || finish_early_; 1631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 1641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat ClassificationStats::MaybeCachedGiniScore(int split, float* left_sum, 1661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float* right_sum) const { 1671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (left_gini_ == nullptr) { 1681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return GiniScore(split, left_sum, right_sum); 1691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } else { 1701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *left_sum = left_gini_->sum(split); 1711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float left = WeightedSmoothedGini( 1721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *left_sum, left_gini_->square(split), num_outputs_); 1731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *right_sum = right_gini_->sum(split); 1751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float right = WeightedSmoothedGini( 1761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *right_sum, right_gini_->square(split), num_outputs_); 1771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return left + right; 1791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 1801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 1811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::AddExample( 1831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target, 1841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int example) { 1851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int64 int_label = target->GetTargetAsClassIndex(example, 0); 1861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float weight = target->GetTargetWeight(example); 1871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_splits(); ++i) { 1891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto& eval = evaluators_[i]; 1901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (eval->Decide(input_data, example) == LEFT_INDEX) { 1911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (left_gini_ != nullptr) { 1921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_gini_->update(i, left_count(i, int_label), weight); 1931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 1941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ClassificationAddLeftExample(i, int_label, weight); 195054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } else { 196054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower if (right_gini_ != nullptr) { 197054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower right_gini_->update(i, right_count(i, int_label), weight); 198054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 199054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower ClassificationAddRightExample(i, int_label, weight); 2001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ClassificationAddTotalExample(int_label, weight); 2041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower weight_sum_ += weight; 2061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower CheckFinishEarly(); 2081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower CheckPrune(); 2091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 2101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckPrune() { 2123d72bc69ee838e8c6b0b801e274aac4c31647b22A. Unique TensorFlower if (params_.pruning_type().type() == SPLIT_PRUNE_NONE || IsFinished() || 2133d72bc69ee838e8c6b0b801e274aac4c31647b22A. Unique TensorFlower weight_sum_ < prune_sample_epoch_ * prune_check_every_) { 2141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return; 2151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ++prune_sample_epoch_; 2171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (params_.pruning_type().type() == SPLIT_PRUNE_HOEFFDING) { 2191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower CheckPruneHoeffding(); 2201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return; 2211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int to_remove = num_splits() * prune_fraction_; 2241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (to_remove <= 0) { 2251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return; 2261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // pair ordering is first-then-second by default, no need for custom 2291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // comparison. Use std::greater to make it a min-heap. 2301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int>>, 2311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::greater<std::pair<float, int>>> 2321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower worst; 2331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Track indices that are in the heap so we can iterate over them 2351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // by largest-first later. 2361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::set<int> indices; 2371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_splits(); ++i) { 2391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float left, right; 2401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float split_score = MaybeCachedGiniScore(i, &left, &right); 2411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (worst.size() < to_remove) { 2421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower worst.push(std::pair<float, int>(split_score, i)); 2431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower indices.insert(i); 2441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } else if (worst.top().first < split_score) { 2451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower indices.erase(worst.top().second); 2461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower worst.pop(); 2471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower worst.push(std::pair<float, int>(split_score, i)); 2481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower indices.insert(i); 2491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // traverse indices from the back so that they are removed correctly. 2531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (auto it = indices.rbegin(); it != indices.rend(); ++it) { 2541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower RemoveSplit(*it); 2551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 2571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckPruneHoeffding() { 2591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<float> split_scores(num_splits()); 2601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Find best split score 2611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float best_split_score = FLT_MAX; 2621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_splits(); ++i) { 2631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float left, right; 2641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower split_scores[i] = MaybeCachedGiniScore(i, &left, &right); 2651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (split_scores[i] < best_split_score) { 2661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower best_split_score = split_scores[i]; 2671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // We apply the Hoeffding bound to the difference between the best split 2711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // score and the i-th split score. 2721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Raw Gini ranges from 0 to 1 - (1/n), but our gini score is weighted. 2731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float num_classes = params_.num_outputs(); 2741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float gini_diff_range = weight_sum_ * (1.0 - 1.0 / num_classes); 2751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float epsilon = gini_diff_range * sqrt(half_ln_dominate_frac_ / weight_sum_); 2761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = num_splits() - 1; i >= 0; i--) { 2771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (split_scores[i] - best_split_score > epsilon) { 2781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower RemoveSplit(i); 2791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 2821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckFinishEarly() { 2841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (weight_sum_ < min_split_samples_ || 2851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower weight_sum_ < finish_sample_epoch_ * finish_check_every_) { 2861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return; 2871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ++finish_sample_epoch_; 2891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_HOEFFDING) { 2911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower CheckFinishEarlyHoeffding(); 2921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } else if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_BOOTSTRAP) { 2931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower CheckFinishEarlyBootstrap(); 2941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 2961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckFinishEarlyHoeffding() { 2981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Each term in the Gini impurity can range from 0 to 0.5 * 0.5. 2991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float range = 0.25 * static_cast<float>(params_.num_outputs()) * weight_sum_; 3001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float hoeffding_bound = 3021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower range * sqrt(log(1.0 / (1.0 - dominate_fraction_)) / (2.0 * weight_sum_)); 3031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float unused_left_sum, unused_right_sum; 3051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::function<float(int)> score_fn = 3061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::bind(&ClassificationStats::MaybeCachedGiniScore, this, 3071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::placeholders::_1, &unused_left_sum, &unused_right_sum); 3081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float best_score; 3101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int32 best_index; 3111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float second_best_score; 3121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int32 second_best_index; 3131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower GetTwoBest(num_splits(), score_fn, &best_score, &best_index, 3141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower &second_best_score, &second_best_index); 3151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower finish_early_ = (second_best_score - best_score) > hoeffding_bound; 3171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 3181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::MakeBootstrapWeights(int index, 3201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<float>* weights) { 3211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int n = weight_sum_; 3221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float denom = static_cast<float>(n) + static_cast<float>(num_outputs_); 3231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_outputs_; ++i) { 3241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Use the Laplace smoothed per-class probabilities when generating the 3251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // bootstrap samples. 3261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower (*weights)[i] = (left_count(index, i) + 1.0) / denom; 3271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower (*weights)[num_outputs_ + i] = (right_count(index, i) + 1.0) / denom; 3281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 3301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerint ClassificationStats::NumBootstrapSamples() const { 3321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float p = 1.0 - dominate_fraction_; 3331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int bootstrap_samples = 1; 3341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower while (p < 1.0) { 3351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ++bootstrap_samples; 3361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower p = p * 2; 3371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return bootstrap_samples; 3391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 3401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckFinishEarlyBootstrap() { 3421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float unused_left_sum, unused_right_sum; 3431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::function<float(int)> score_fn = 3441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::bind(&ClassificationStats::MaybeCachedGiniScore, this, 3451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::placeholders::_1, &unused_left_sum, &unused_right_sum); 3461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float best_score; 3481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int32 best_index; 3491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float second_best_score; 3501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int32 second_best_index; 3511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower GetTwoBest(num_splits(), score_fn, &best_score, &best_index, 3521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower &second_best_score, &second_best_index); 3531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<float> weights1(num_outputs_ * 2); 3551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower MakeBootstrapWeights(best_index, &weights1); 3561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower random::DistributionSampler ds1(weights1); 3571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<float> weights2(num_outputs_ * 2); 3591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower MakeBootstrapWeights(second_best_index, &weights2); 3601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower random::DistributionSampler ds2(weights2); 3611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int bootstrap_samples = NumBootstrapSamples(); 3631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int worst_g1 = 0; 3651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < bootstrap_samples; i++) { 3661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int g1 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds1, rng_.get()); 3671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower worst_g1 = std::max(worst_g1, g1); 3681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int best_g2 = 99; 3711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < bootstrap_samples; i++) { 3721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int g2 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds2, rng_.get()); 3731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower best_g2 = std::min(best_g2, g2); 3741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower finish_early_ = worst_g1 < best_g2; 3771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 3781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 379054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowerbool ClassificationStats::BestSplit(SplitCandidate* best) const { 380054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower float min_score = FLT_MAX; 381054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower int best_index = -1; 382054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower float best_left_sum, best_right_sum; 383054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 384054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // Calculate sums. 385054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower for (int i = 0; i < num_splits(); ++i) { 386054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower float left_sum, right_sum; 387054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum); 388054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // Find the lowest gini. 389054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower if (left_sum > 0 && right_sum > 0 && 390054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower split_score < min_score) { // useless check 391054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower min_score = split_score; 392054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower best_index = i; 393054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower best_left_sum = left_sum; 394054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower best_right_sum = right_sum; 395054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 396054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 397054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 398054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // This could happen if all the splits are useless. 399054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower if (best_index < 0) { 400054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower return false; 401054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 402054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 403054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // Fill in stats to be used for leaf model. 404054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower *best->mutable_split() = splits_[best_index]; 405054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower auto* left = best->mutable_left_stats(); 406054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower left->set_weight_sum(best_left_sum); 407054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower auto* right = best->mutable_right_stats(); 408054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower right->set_weight_sum(best_right_sum); 409054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower InitLeafClassStats(best_index, left, right); 410054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 411054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower return true; 412054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower} 413054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 4141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// ------------------------ Dense Classification --------------------------- // 4151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid DenseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) { 4161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower Initialize(); 4171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (!slot.has_post_init_leaf_stats()) { 4181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return; 4191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int32 num_classes = params_.num_outputs(); 4211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower weight_sum_ = slot.post_init_leaf_stats().weight_sum(); 4221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const auto& class_stats = 4231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower slot.post_init_leaf_stats().classification().dense_counts(); 4241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Total counts. 4261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_classes; ++i) { 4271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_counts_[i] = class_stats.value(i).float_value(); 4281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower num_outputs_seen_ += total_counts_[i] != 0; 4291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Candidate counts and splits. 4321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int split_num = 0; 4331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (const auto& cand : slot.candidates()) { 43475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower AddSplit(cand.split(), nullptr, nullptr, -1); 4351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const auto& left_stats = cand.left_stats().classification().dense_counts(); 4361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_classes; ++i) { 4371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float val = left_stats.value(i).float_value(); 4381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower mutable_left_count(split_num, i) = val; 4391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower MaybeInitializeRunningCount(split_num, val); 4401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ++split_num; 4421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 4441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid DenseClassificationGrowStats::PackToProto(FertileSlot* slot) const { 4461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* slot_stats = slot->mutable_post_init_leaf_stats(); 4471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower slot_stats->set_weight_sum(weight_sum_); 4481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* class_stats = slot->mutable_post_init_leaf_stats() 4501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_classification() 4511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_dense_counts(); 4521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_outputs_; ++i) { 4531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower class_stats->add_value()->set_float_value(total_counts_[i]); 4541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4564463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower for (int split_num = 0; split_num < num_splits(); ++split_num) { 4571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* cand = slot->add_candidates(); 4581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *cand->mutable_split() = splits_[split_num]; 4591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* left_stats = cand->mutable_left_stats() 4601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_classification() 4611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_dense_counts(); 4621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_outputs_; ++i) { 4634463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower left_stats->add_value()->set_float_value(left_count(split_num, i)); 4641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 4671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat DenseClassificationGrowStats::GiniScore(int split, float* left_sum, 4691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float* right_sum) const { 4701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float left_square = 0, right_square = 0; 4711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *left_sum = 0; 4721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *right_sum = 0; 4731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int j = 0; j < num_outputs_; ++j) { 4741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float left = left_count(split, j); 4751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *left_sum += left; 4761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_square += left * left; 4771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float right = right_count(split, j); 4781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *right_sum += right; 4791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right_square += right * right; 4801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float left_score = 4831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower WeightedSmoothedGini(*left_sum, left_square, num_outputs_); 4841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float right_score = 4851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower WeightedSmoothedGini(*right_sum, right_square, num_outputs_); 4861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return left_score + right_score; 4871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 4881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 489054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowervoid DenseClassificationGrowStats::InitLeafClassStats( 490054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const { 491054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower auto* left_class_stats = left_stats->mutable_classification(); 4921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* left_counts = left_class_stats->mutable_dense_counts(); 4931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < params_.num_outputs(); ++i) { 494054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower left_counts->add_value()->set_float_value(left_count(best_split_index, i)); 4951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 497054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower auto* right_class_stats = right_stats->mutable_classification(); 4981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* right_counts = right_class_stats->mutable_dense_counts(); 4991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < params_.num_outputs(); ++i) { 500054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower right_counts->add_value()->set_float_value(total_counts_[i] - 501054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower left_count(best_split_index, i)); 5021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 5041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// ------------------------ Sparse Classification --------------------------- // 5061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid SparseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) { 5071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower Initialize(); 5081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (!slot.has_post_init_leaf_stats()) { 5091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return; 5101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower weight_sum_ = slot.post_init_leaf_stats().weight_sum(); 5121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const auto& class_stats = 5131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower slot.post_init_leaf_stats().classification().sparse_counts(); 5141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Total counts. 5161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (auto const& entry : class_stats.sparse_value()) { 5171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_counts_[entry.first] = entry.second.float_value(); 5181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Candidate counts and splits. 5211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int split_num = 0; 5221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (const auto& cand : slot.candidates()) { 52375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower AddSplit(cand.split(), nullptr, nullptr, -1); 5241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const auto& left_stats = cand.left_stats().classification().sparse_counts(); 5251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (auto const& entry : left_stats.sparse_value()) { 5261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float val = entry.second.float_value(); 5271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_counts_[split_num][entry.first] = val; 5281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower MaybeInitializeRunningCount(split_num, val); 5291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ++split_num; 5311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 5331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid SparseClassificationGrowStats::PackToProto(FertileSlot* slot) const { 5351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* slot_stats = slot->mutable_post_init_leaf_stats(); 5361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower slot_stats->set_weight_sum(weight_sum_); 5371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* class_stats = slot->mutable_post_init_leaf_stats() 5391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_classification() 5401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_sparse_counts() 5411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_sparse_value(); 5421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (const auto& entry : total_counts_) { 5431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower decision_trees::Value val; 5441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower val.set_float_value(entry.second); 5451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower (*class_stats)[entry.first] = val; 5461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5484463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower for (int split_num = 0; split_num < num_splits(); ++split_num) { 5491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* cand = slot->add_candidates(); 5501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *cand->mutable_split() = splits_[split_num]; 5511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* left_stats = cand->mutable_left_stats() 5521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_classification() 5531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_sparse_counts() 5541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_sparse_value(); 5551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (const auto& entry : left_counts_[split_num]) { 5561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower decision_trees::Value val; 5571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower val.set_float_value(entry.second); 5581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower (*left_stats)[entry.first] = val; 5591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 5621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5634463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlowerfloat SparseClassificationGrowStats::GiniScore(int split, float* left_sum, 5644463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower float* right_sum) const { 5651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float left_square = 0, right_square = 0; 5661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *left_sum = 0; 5671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *right_sum = 0; 5681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (const auto& entry : total_counts_) { 5691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int label = entry.first; 5701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float left = 0; 5711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float right = 0; 5721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto it = left_counts_[split].find(label); 5731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (it == left_counts_[split].end()) { 5741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right = entry.second; 5751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } else { 5761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left = it->second; 5771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right = entry.second - it->second; 5781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *left_sum += left; 5801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_square += left * left; 5811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *right_sum += right; 5821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right_square += right * right; 5831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int32 num_classes = params_.num_outputs(); 5851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float left_score = 5861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower WeightedSmoothedGini(*left_sum, left_square, num_classes); 5871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float right_score = 5881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower WeightedSmoothedGini(*right_sum, right_square, num_classes); 5891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return left_score + right_score; 5901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 5911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 592054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowervoid SparseClassificationGrowStats::InitLeafClassStats( 593054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const { 594054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower auto* left_class_stats = left_stats->mutable_classification(); 5951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* left_counts = 5961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_class_stats->mutable_sparse_counts()->mutable_sparse_value(); 597054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower auto* right_class_stats = right_stats->mutable_classification(); 5981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* right_counts = 5991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right_class_stats->mutable_sparse_counts()->mutable_sparse_value(); 6001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 6011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (const auto& entry : total_counts_) { 602054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower auto it = left_counts_[best_split_index].find(entry.first); 603054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower if (it == left_counts_[best_split_index].end()) { 6041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower (*right_counts)[entry.first].set_float_value(entry.second); 6051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } else { 6061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float left = it->second; 6071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float right = entry.second - it->second; 6081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower (*left_counts)[entry.first].set_float_value(left); 6091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (right > 0) { 6101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower (*right_counts)[entry.first].set_float_value(right); 6111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 6121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 6131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 614054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower} 615054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 616054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower// -------------------- FixedSizeClassStats --------------------------------- // 617054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 618054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower// FixedSizeClassStats implements the "SpaceSaving" algorithm by 619054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower// Ahmed Metwally, Divyakant Agrawal and Amr El Abbadi. See for example 620054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower// https://pdfs.semanticscholar.org/72f1/5aba2e67b1cc9cd1fb12c99e101c4c1aae4b.pdf 621054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 622054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowerint argmin(const std::unordered_map<int, float>& m) { 623054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower int c = -1; 624054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower float f = FLT_MAX; 625054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower for (const auto it : m) { 626054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower if (it.second < f) { 627054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower f = it.second; 628054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower c = it.first; 629054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 630054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 631054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower return c; 632054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower} 633054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 634054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowervoid FixedSizeClassStats::accumulate(int c, float w) { 635054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower auto it = class_weights_.find(c); 636054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower if (it != class_weights_.end()) { 637054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower it->second += w; 638054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower if (c == smallest_weight_class_) { 639054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower smallest_weight_class_ = argmin(class_weights_); 640054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 641054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower return; 642054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 643054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 644054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower if (class_weights_.size() < n_) { 645054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower class_weights_.insert(it, std::pair<int, float>(c, w)); 646054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower if (class_weights_.size() == n_) { 647054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // Can't assume last added has the smallest weight, because the 648054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // w's might be all different. 649054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower smallest_weight_class_ = argmin(class_weights_); 650054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 651054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower return; 652054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 653054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 654054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // This is the slightly unintuitive heart of the SpaceSaving algorithm: 655054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // if the map is full and we see a new class, we find the entry with the 656054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // smallest weight and "take it over": we add our weight to its weight, 657054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // and assign it all to the new seen class. 658054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower it = class_weights_.find(smallest_weight_class_); 659054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower float new_weight = it->second + w; 660054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower class_weights_.erase(it); 661054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower class_weights_[c] = new_weight; 662054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower smallest_weight_class_ = argmin(class_weights_); 663054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower} 664054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 665054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowerfloat FixedSizeClassStats::get_weight(int c) const { 666054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // Every entry in class_weights_ might be overstated by as much as the 667054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // smallest_weight. We therefore assume that each has been overstated 668054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // by smallest_weight / 2.0, and we re-distribute that mass over all 669054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // num_classes_ classes. 670054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower float smallest_weight = 0.0; 671054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower auto it = class_weights_.find(smallest_weight_class_); 672054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower if (it != class_weights_.end()) { 673054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower smallest_weight = it->second; 674054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 675054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower float w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_); 676054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower it = class_weights_.find(c); 677054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower if (it != class_weights_.end()) { 678054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower w += it->second - smallest_weight / 2.0; 679054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 680054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower return w; 681054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower} 682054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 683054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowervoid FixedSizeClassStats::set_sum_and_square(float* sum, float* square) const { 684054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower *sum = 0.0; 685054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower *square = 0.0; 686054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 687054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower float smallest_weight = 0.0; 688054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower auto it = class_weights_.find(smallest_weight_class_); 689054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower if (it != class_weights_.end()) { 690054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower smallest_weight = it->second; 691054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 692054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 693054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower float w; 694054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower for (const auto it : class_weights_) { 695054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower *sum += it.second; 696054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower w = get_weight(it.first); 697054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower *square += w * w; 698054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 699054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 700054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_); 701054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower *square += (num_classes_ - n_) * w * w; 702054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower} 703054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 704054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowervoid FixedSizeClassStats::ExtractFromProto( 705054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower const decision_trees::SparseVector& sparse_vector) { 706054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower for (const auto& it : sparse_vector.sparse_value()) { 707054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower class_weights_[it.first] = it.second.float_value(); 708054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 709054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower if (class_weights_.size() == n_) { 710054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower smallest_weight_class_ = argmin(class_weights_); 711054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 712054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower} 713054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 714054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowervoid FixedSizeClassStats::PackToProto( 715054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower decision_trees::SparseVector* sparse_vector) const { 716054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower for (const auto it : class_weights_) { 717054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower (*sparse_vector->mutable_sparse_value())[it.first].set_float_value( 718054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower it.second); 719054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 720054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower} 721054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 722054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower// --------------------- FixedSizeSparseClassificationGrowStats ------------- // 723054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 724054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowervoid FixedSizeSparseClassificationGrowStats::ExtractFromProto( 725054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower const FertileSlot& slot) { 726054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower Initialize(); 727054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower if (!slot.has_post_init_leaf_stats()) { 728054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower return; 729054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 730054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower weight_sum_ = slot.post_init_leaf_stats().weight_sum(); 731054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 732054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // Candidate counts and splits. 733054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower int split_num = 0; 734054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower left_counts_.clear(); 735054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower right_counts_.clear(); 736054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower for (const auto& cand : slot.candidates()) { 737054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower AddSplit(cand.split(), nullptr, nullptr, -1); 738054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower const auto& left_stats = cand.left_stats().classification().sparse_counts(); 739054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower left_counts_.emplace_back(params_.num_classes_to_track(), 740054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower params_.num_outputs()); 741054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower left_counts_[split_num].ExtractFromProto(left_stats); 742054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower const auto& right_stats = 743054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower cand.right_stats().classification().sparse_counts(); 744054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower right_counts_.emplace_back(params_.num_classes_to_track(), 745054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower params_.num_outputs()); 746054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower right_counts_[split_num].ExtractFromProto(right_stats); 747054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower ++split_num; 748054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 749054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower} 750054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 751054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowervoid FixedSizeSparseClassificationGrowStats::PackToProto( 752054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower FertileSlot* slot) const { 753054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower auto* slot_stats = slot->mutable_post_init_leaf_stats(); 754054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower slot_stats->set_weight_sum(weight_sum_); 755054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 756054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower for (int split_num = 0; split_num < num_splits(); ++split_num) { 757054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower auto* cand = slot->add_candidates(); 758054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower *cand->mutable_split() = splits_[split_num]; 759054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower auto* left_stats = cand->mutable_left_stats() 760054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower ->mutable_classification() 761054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower ->mutable_sparse_counts(); 762054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower left_counts_[split_num].PackToProto(left_stats); 763054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower auto* right_stats = cand->mutable_right_stats() 764054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower ->mutable_classification() 765054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower ->mutable_sparse_counts(); 766054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower right_counts_[split_num].PackToProto(right_stats); 767054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 768054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower} 769054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 770054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowerfloat FixedSizeSparseClassificationGrowStats::GiniScore( 771054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower int split, float* left_sum, float* right_sum) const { 772054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower float left_square, right_square; 773054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower left_counts_[split].set_sum_and_square(left_sum, &left_square); 774054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower right_counts_[split].set_sum_and_square(right_sum, &right_square); 775054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower const int32 num_classes = params_.num_outputs(); 776054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower const float left_score = 777054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower WeightedSmoothedGini(*left_sum, left_square, num_classes); 778054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower const float right_score = 779054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower WeightedSmoothedGini(*right_sum, right_square, num_classes); 780054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower return left_score + right_score; 781054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower} 782054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 783054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowervoid FixedSizeSparseClassificationGrowStats::InitLeafClassStats( 784054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const { 785054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower auto* left_class_stats = left_stats->mutable_classification(); 786054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower auto* left_counts = left_class_stats->mutable_sparse_counts(); 787054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower left_counts_[best_split_index].PackToProto(left_counts); 788054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 789054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower auto* right_class_stats = right_stats->mutable_classification(); 790054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower auto* right_counts = right_class_stats->mutable_sparse_counts(); 791054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower right_counts_[best_split_index].PackToProto(right_counts); 7921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 7931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 7941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// --------------------- Least Squares Regression --------------------------- // 7951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid LeastSquaresRegressionGrowStats::ExtractFromProto( 7961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const FertileSlot& slot) { 7971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int32 num_outputs = params_.num_outputs(); 7981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower Initialize(); 7991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (!slot.has_post_init_leaf_stats()) { 8001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return; 8011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 8021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower weight_sum_ = slot.post_init_leaf_stats().weight_sum(); 8031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const auto& total_sums = 8041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower slot.post_init_leaf_stats().regression().mean_output(); 8051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const auto& total_squares = 8061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower slot.post_init_leaf_stats().regression().mean_output_squares(); 8071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 8081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Total counts. 8091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_outputs; ++i) { 8101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_sum_[i] = total_sums.value(i).float_value(); 8111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_sum_squares_[i] = total_squares.value(i).float_value(); 8121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 8131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 8141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Candidate counts and splits. 8151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int split_num = 0; 8161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (const auto& cand : slot.candidates()) { 81775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower AddSplit(cand.split(), nullptr, nullptr, -1); 8181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const auto& sums = cand.left_stats().regression().mean_output(); 8191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const auto& squares = cand.left_stats().regression().mean_output_squares(); 8201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_outputs; ++i) { 8211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_sum(split_num, i) = sums.value(i).float_value(); 8221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_square(split_num, i) = squares.value(i).float_value(); 8231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 8241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_counts_[split_num] = cand.left_stats().weight_sum(); 8251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ++split_num; 8261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 8271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 8281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 8291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid LeastSquaresRegressionGrowStats::PackToProto(FertileSlot* slot) const { 8301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int32 num_outputs = params_.num_outputs(); 8311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* slot_stats = slot->mutable_post_init_leaf_stats(); 8321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower slot_stats->set_weight_sum(weight_sum_); 8331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 8341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* total_sums = slot->mutable_post_init_leaf_stats() 8351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_regression() 8361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_mean_output(); 8371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* total_squares = slot->mutable_post_init_leaf_stats() 8381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_regression() 8391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_mean_output_squares(); 8401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 8411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < total_sum_.size(); ++i) { 8421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_sums->add_value()->set_float_value(total_sum_[i]); 8431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_squares->add_value()->set_float_value(total_sum_squares_[i]); 8441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 8451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 8464463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower for (int split_num = 0; split_num < num_splits(); ++split_num) { 8471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* cand = slot->add_candidates(); 8481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *cand->mutable_split() = splits_[split_num]; 8494463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower auto* sums = 8504463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower cand->mutable_left_stats()->mutable_regression()->mutable_mean_output(); 8511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* squares = cand->mutable_left_stats() 8521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_regression() 8531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ->mutable_mean_output_squares(); 8541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_outputs; ++i) { 8551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower sums->add_value()->set_float_value(left_sum(split_num, i)); 8561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower squares->add_value()->set_float_value(left_square(split_num, i)); 8571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 8581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower cand->mutable_left_stats()->set_weight_sum(left_counts_[split_num]); 8591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 8601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 8611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 8621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid LeastSquaresRegressionGrowStats::AddExample( 8631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target, 8641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int example) { 8651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int32 num_outputs = params_.num_outputs(); 8661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Update splits. 8671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_splits(); ++i) { 8681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto& eval = evaluators_[i]; 8691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (eval->Decide(input_data, example) == LEFT_INDEX) { 8701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int j = 0; j < num_outputs; ++j) { 8711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float output = target->GetTargetAsContinuous(example, j); 8721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_sum(i, j) += output; 8731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_square(i, j) += output * output; 8741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 8751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ++left_counts_[i]; 8761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 8771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 8781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 8791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Update totals. 8801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_outputs; ++i) { 8811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float output = target->GetTargetAsContinuous(example, i); 8821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_sum_[i] += output; 8831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_sum_squares_[i] += output * output; 8841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 8851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower weight_sum_ += 1.0; 8861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 8871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 8881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat LeastSquaresRegressionGrowStats::SplitVariance(int split) const { 8891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float total_variance = 0; 8901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < params_.num_outputs(); ++i) { 8911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Left side 8924463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower const float le_x = left_sum(split, i) / left_counts_[split]; 8931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 8944463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower const float le_x2 = left_square(split, i) / left_counts_[split]; 8951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_variance += le_x2 - le_x * le_x; 8961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 8971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Right side 8981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float re_x = (total_sum_[i] - left_sum(split, i)) / 8991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower (weight_sum_ - left_counts_[split]); 9001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 9014463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower const float re_x2 = (total_sum_squares_[i] - left_square(split, i)) / 9024463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower (weight_sum_ - left_counts_[split]); 9031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_variance += re_x2 - re_x * re_x; 9041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 9051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return total_variance; 9061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 9071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 9081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerbool LeastSquaresRegressionGrowStats::BestSplit(SplitCandidate* best) const { 9091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float min_score = FLT_MAX; 9101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int best_index = -1; 9111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int32 num_outputs = params_.num_outputs(); 9121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_splits(); ++i) { 9131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (left_counts_[i] > 0 && weight_sum_ - left_counts_[i] > 0) { 9141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float split_score = SplitVariance(i); 9151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (split_score < min_score) { 9161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower min_score = split_score; 9171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower best_index = i; 9181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 9191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 9201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 9211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 9221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // This could happen if all the splits are useless. 9231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (best_index < 0) { 9241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return false; 9251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 9261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 9271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Fill in right stats to be used for leaf model. 9281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower *best->mutable_split() = splits_[best_index]; 9291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Left 9301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* left = best->mutable_left_stats(); 9311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* left_reg_stats = left->mutable_regression(); 9321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left->set_weight_sum(left_counts_[best_index]); 9331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* left_output_sum = left_reg_stats->mutable_mean_output(); 9341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_outputs; ++i) { 9354463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower left_output_sum->add_value()->set_float_value(left_sum(best_index, i)); 9361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 9371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 9381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Right 9391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* right = best->mutable_right_stats(); 9401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* right_reg_stats = right->mutable_regression(); 9411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right->set_weight_sum(weight_sum_ - left_counts_[best_index]); 9421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower auto* right_output_sum = right_reg_stats->mutable_mean_output(); 9431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower for (int i = 0; i < num_outputs; ++i) { 9444463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower right_output_sum->add_value()->set_float_value(total_sum_[i] - 9454463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower left_sum(best_index, i)); 9461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 9471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return true; 9481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 9491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 9501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerbool LeastSquaresRegressionGrowStats::IsFinished() const { 9511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return weight_sum_ >= split_after_samples_; 9521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} 9531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 9541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} // namespace tensorforest 9551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} // namespace tensorflow 956