grow_stats.cc revision 75f03e2d509d016021f8508555f9ab96af2c7cfe
11588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower//
31588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Licensed under the Apache License, Version 2.0 (the "License");
41588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// you may not use this file except in compliance with the License.
51588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// You may obtain a copy of the License at
61588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower//
71588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower//     http://www.apache.org/licenses/LICENSE-2.0
81588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower//
91588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Unless required by applicable law or agreed to in writing, software
101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// distributed under the License is distributed on an "AS IS" BASIS,
111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// See the License for the specific language governing permissions and
131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// limitations under the License.
141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// =============================================================================
151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h"
161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include <cfloat>
181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include <queue>
191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h"
211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/core/lib/random/distribution_sampler.h"
221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorflow {
251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorforest {
261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// When creating evaluators for the split candidates, use these
281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// for the left and right return values.
291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerstatic const int32 LEFT_INDEX = 0;
301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerstatic const int32 RIGHT_INDEX = 1;
311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerGrowStats::GrowStats(const TensorForestParams& params, int32 depth)
3308a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower    : weight_sum_(0),
3408a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower      depth_(depth),
351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      params_(params),
361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      split_after_samples_(ResolveParam(params.split_after_samples(), depth)),
371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      num_splits_to_consider_(
381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower          ResolveParam(params.num_splits_to_consider(), depth)),
391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      num_outputs_(params.num_outputs()) {}
401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlowervoid GrowStats::AddSplit(const decision_trees::BinaryNode& split,
4275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower                         const std::unique_ptr<TensorDataSet>& input_data,
4375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower                         const InputTarget* target, int example) {
4475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  // It's possible that the split collection calls AddSplit, but we actually
4575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  // have all the splits we need and are just waiting for them to be fully
4675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  // initialized.
4775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  if (splits_.size() < num_splits_to_consider_) {
4875f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    splits_.push_back(split);
4975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    evaluators_.emplace_back(
5075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower        CreateBinaryDecisionNodeEvaluator(split, LEFT_INDEX, RIGHT_INDEX));
5175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    AddSplitStats(target, example);
5275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  }
5375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower
5475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  if (input_data != nullptr && target != nullptr &&
5575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower      params_.initialize_average_splits()) {
5675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    AdditionalInitializationExample(input_data, target, example);
5775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  }
581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid GrowStats::RemoveSplit(int split_num) {
611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  splits_.erase(splits_.begin() + split_num);
621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  evaluators_.erase(evaluators_.begin() + split_num);
631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  RemoveSplitStats(split_num);
641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// ------------------------ Classification --------------------------- //
671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerClassificationStats::ClassificationStats(const TensorForestParams& params,
691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                                         int32 depth)
701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    : GrowStats(params, depth), finish_early_(false) {
711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Early splitting params.
721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (params.finish_type().type() == SPLIT_FINISH_BASIC) {
731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    min_split_samples_ = split_after_samples_;
7408a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower    finish_sample_epoch_ = 1;
7508a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower    finish_check_every_ = split_after_samples_ * 2;
761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  } else {
771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (!params.has_dominate_fraction() || !params.has_min_split_samples()) {
781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      LOG(FATAL) << "dominate_fraction and min_split_samples "
791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                 << "required for early-finish strategy.";
801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    } else {
811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      min_split_samples_ = ResolveParam(params.min_split_samples(), depth);
821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      finish_check_every_ =
831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower          ResolveParam(params.finish_type().check_every_steps(), depth);
841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      finish_sample_epoch_ = min_split_samples_ / finish_check_every_;
851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_);
871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      if (dominate_fraction_ <= 0 || dominate_fraction_ > 1.0) {
881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        LOG(FATAL) << "Invalid dominate fraction " << dominate_fraction_;
891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      }
901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Pruning params.
941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (params.pruning_type().type() != SPLIT_PRUNE_NONE) {
951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    prune_check_every_ =
961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        ResolveParam(params.pruning_type().prune_every_samples(), depth);
971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    prune_sample_epoch_ = 1;
981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    prune_fraction_ = 0.0;
991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    switch (params_.pruning_type().type()) {
1001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      case SPLIT_PRUNE_HALF:
1011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        prune_fraction_ = 0.5;
1021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        break;
1031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      case SPLIT_PRUNE_QUARTER:
1041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        prune_fraction_ = 0.25;
1051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        break;
1061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      case SPLIT_PRUNE_10_PERCENT:
1071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        prune_fraction_ = 0.10;
1081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        break;
1091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      case SPLIT_PRUNE_HOEFFDING:
1101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_);
1111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        half_ln_dominate_frac_ = 0.5 * log(1.0 / (1.0 - dominate_fraction_));
1121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        break;
1131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      default:
1141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        LOG(WARNING) << "Unknown pruning type";
1151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
11608a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower  } else {
11708a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower    prune_check_every_ = split_after_samples_ * 2;
11808a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower    prune_sample_epoch_ = 1;
1191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
1201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (params.use_running_stats_method()) {
1221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_gini_.reset(new RunningGiniScores());
1231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    right_gini_.reset(new RunningGiniScores());
1241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
1251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  uint64 time_seed = static_cast<uint64>(std::clock());
1271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  single_rand_ = std::unique_ptr<random::PhiloxRandom>(
1281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      new random::PhiloxRandom(time_seed));
1291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  rng_ = std::unique_ptr<random::SimplePhilox>(
1301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      new random::SimplePhilox(single_rand_.get()));
1311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
1321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
13375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlowervoid ClassificationStats::AdditionalInitializationExample(
13475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
13575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    int example) {
13675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  const int32 new_target = target->GetTargetAsClassIndex(example, 0);
13775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  std::unordered_set<int> to_erase;
13875f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  for (auto it = half_initialized_splits_.begin();
13975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower       it != half_initialized_splits_.end(); ++it) {
14075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    if (it->second != new_target) {
14175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower      auto& split = splits_[it->first];
14275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower      if (split.has_inequality_left_child_test()) {
14375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower        auto& test = split.inequality_left_child_test();
14475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower        auto* thresh =
14575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower            split.mutable_inequality_left_child_test()->mutable_threshold();
14675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower        if (test.has_feature_id()) {
14775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower          const float val =
14875f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower              input_data->GetExampleValue(example, test.feature_id());
14975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower          thresh->set_float_value((thresh->float_value() + val) / 2);
15075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower        }
15175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower      }
15275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower      to_erase.insert(it->first);
15375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    }
15475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  }
15575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower
15675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  for (const int split_id : to_erase) {
15775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    half_initialized_splits_.erase(split_id);
15875f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  }
15975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower}
16075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower
1611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerbool ClassificationStats::IsFinished() const {
1621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  bool basic = weight_sum_ >= split_after_samples_ && num_outputs_seen() > 1;
1631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  return basic || finish_early_;
1641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
1651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat ClassificationStats::MaybeCachedGiniScore(int split, float* left_sum,
1671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                                                float* right_sum) const {
1681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (left_gini_ == nullptr) {
1691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return GiniScore(split, left_sum, right_sum);
1701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  } else {
1711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    *left_sum = left_gini_->sum(split);
1721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const float left = WeightedSmoothedGini(
1731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        *left_sum, left_gini_->square(split), num_outputs_);
1741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    *right_sum = right_gini_->sum(split);
1761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const float right = WeightedSmoothedGini(
1771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        *right_sum, right_gini_->square(split), num_outputs_);
1781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return left + right;
1801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
1811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
1821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::AddExample(
1841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
1851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    int example) {
1861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int64 int_label = target->GetTargetAsClassIndex(example, 0);
1871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const float weight = target->GetTargetWeight(example);
1881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_splits(); ++i) {
1901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto& eval = evaluators_[i];
1911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (eval->Decide(input_data, example) == LEFT_INDEX) {
1921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      if (left_gini_ != nullptr) {
1931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        left_gini_->update(i, left_count(i, int_label), weight);
1941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      }
1951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      ClassificationAddLeftExample(i, int_label, weight);
1961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    } else if (right_gini_ != nullptr) {
1971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      right_gini_->update(i, right_count(i, int_label), weight);
1981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
1991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  ClassificationAddTotalExample(int_label, weight);
2021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  weight_sum_ += weight;
2041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  CheckFinishEarly();
2061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  CheckPrune();
2071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
2081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckPrune() {
2101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (IsFinished() || weight_sum_ < prune_sample_epoch_ * prune_check_every_) {
2111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return;
2121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  ++prune_sample_epoch_;
2141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (params_.pruning_type().type() == SPLIT_PRUNE_HOEFFDING) {
2161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    CheckPruneHoeffding();
2171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return;
2181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int to_remove = num_splits() * prune_fraction_;
2211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (to_remove <= 0) {
2221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return;
2231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // pair ordering is first-then-second by default, no need for custom
2261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // comparison.  Use std::greater to make it a min-heap.
2271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int>>,
2281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                      std::greater<std::pair<float, int>>>
2291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      worst;
2301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Track indices that are in the heap so we can iterate over them
2321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // by largest-first later.
2331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::set<int> indices;
2341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_splits(); ++i) {
2361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    float left, right;
2371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const float split_score = MaybeCachedGiniScore(i, &left, &right);
2381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (worst.size() < to_remove) {
2391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      worst.push(std::pair<float, int>(split_score, i));
2401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      indices.insert(i);
2411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    } else if (worst.top().first < split_score) {
2421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      indices.erase(worst.top().second);
2431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      worst.pop();
2441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      worst.push(std::pair<float, int>(split_score, i));
2451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      indices.insert(i);
2461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
2471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // traverse indices from the back so that they are removed correctly.
2501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (auto it = indices.rbegin(); it != indices.rend(); ++it) {
2511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    RemoveSplit(*it);
2521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
2541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckPruneHoeffding() {
2561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::vector<float> split_scores(num_splits());
2571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Find best split score
2581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float best_split_score = FLT_MAX;
2591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_splits(); ++i) {
2601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    float left, right;
2611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    split_scores[i] = MaybeCachedGiniScore(i, &left, &right);
2621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (split_scores[i] < best_split_score) {
2631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      best_split_score = split_scores[i];
2641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
2651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // We apply the Hoeffding bound to the difference between the best split
2681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // score and the i-th split score.
2691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Raw Gini ranges from 0 to 1 - (1/n), but our gini score is weighted.
2701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const float num_classes = params_.num_outputs();
2711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const float gini_diff_range = weight_sum_ * (1.0 - 1.0 / num_classes);
2721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float epsilon = gini_diff_range * sqrt(half_ln_dominate_frac_ / weight_sum_);
2731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = num_splits() - 1; i >= 0; i--) {
2741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (split_scores[i] - best_split_score > epsilon) {
2751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      RemoveSplit(i);
2761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
2771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
2791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckFinishEarly() {
2811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (weight_sum_ < min_split_samples_ ||
2821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      weight_sum_ < finish_sample_epoch_ * finish_check_every_) {
2831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return;
2841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  ++finish_sample_epoch_;
2861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_HOEFFDING) {
2881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    CheckFinishEarlyHoeffding();
2891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  } else if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_BOOTSTRAP) {
2901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    CheckFinishEarlyBootstrap();
2911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
2931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckFinishEarlyHoeffding() {
2951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Each term in the Gini impurity can range from 0 to 0.5 * 0.5.
2961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float range = 0.25 * static_cast<float>(params_.num_outputs()) * weight_sum_;
2971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float hoeffding_bound =
2991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      range * sqrt(log(1.0 / (1.0 - dominate_fraction_)) / (2.0 * weight_sum_));
3001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float unused_left_sum, unused_right_sum;
3021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::function<float(int)> score_fn =
3031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      std::bind(&ClassificationStats::MaybeCachedGiniScore, this,
3041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                std::placeholders::_1, &unused_left_sum, &unused_right_sum);
3051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float best_score;
3071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int32 best_index;
3081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float second_best_score;
3091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int32 second_best_index;
3101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  GetTwoBest(num_splits(), score_fn, &best_score, &best_index,
3111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower             &second_best_score, &second_best_index);
3121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  finish_early_ = (second_best_score - best_score) > hoeffding_bound;
3141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
3151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::MakeBootstrapWeights(int index,
3171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                                               std::vector<float>* weights) {
3181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int n = weight_sum_;
3191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float denom = static_cast<float>(n) + static_cast<float>(num_outputs_);
3201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_outputs_; ++i) {
3211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    // Use the Laplace smoothed per-class probabilities when generating the
3221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    // bootstrap samples.
3231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    (*weights)[i] = (left_count(index, i) + 1.0) / denom;
3241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    (*weights)[num_outputs_ + i] = (right_count(index, i) + 1.0) / denom;
3251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
3271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerint ClassificationStats::NumBootstrapSamples() const {
3291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float p = 1.0 - dominate_fraction_;
3301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int bootstrap_samples = 1;
3311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  while (p < 1.0) {
3321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    ++bootstrap_samples;
3331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    p = p * 2;
3341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  return bootstrap_samples;
3361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
3371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckFinishEarlyBootstrap() {
3391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float unused_left_sum, unused_right_sum;
3401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::function<float(int)> score_fn =
3411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      std::bind(&ClassificationStats::MaybeCachedGiniScore, this,
3421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                std::placeholders::_1, &unused_left_sum, &unused_right_sum);
3431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float best_score;
3451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int32 best_index;
3461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float second_best_score;
3471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int32 second_best_index;
3481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  GetTwoBest(num_splits(), score_fn, &best_score, &best_index,
3491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower             &second_best_score, &second_best_index);
3501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::vector<float> weights1(num_outputs_ * 2);
3521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  MakeBootstrapWeights(best_index, &weights1);
3531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  random::DistributionSampler ds1(weights1);
3541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::vector<float> weights2(num_outputs_ * 2);
3561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  MakeBootstrapWeights(second_best_index, &weights2);
3571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  random::DistributionSampler ds2(weights2);
3581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int bootstrap_samples = NumBootstrapSamples();
3601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int worst_g1 = 0;
3621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < bootstrap_samples; i++) {
3631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    int g1 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds1, rng_.get());
3641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    worst_g1 = std::max(worst_g1, g1);
3651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int best_g2 = 99;
3681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < bootstrap_samples; i++) {
3691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    int g2 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds2, rng_.get());
3701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    best_g2 = std::min(best_g2, g2);
3711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  finish_early_ = worst_g1 < best_g2;
3741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
3751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// ------------------------ Dense Classification --------------------------- //
3771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid DenseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
3781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  Initialize();
3791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (!slot.has_post_init_leaf_stats()) {
3801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return;
3811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int32 num_classes = params_.num_outputs();
3831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
3841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const auto& class_stats =
3851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      slot.post_init_leaf_stats().classification().dense_counts();
3861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Total counts.
3881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_classes; ++i) {
3891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_counts_[i] = class_stats.value(i).float_value();
3901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    num_outputs_seen_ += total_counts_[i] != 0;
3911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Candidate counts and splits.
3941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int split_num = 0;
3951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (const auto& cand : slot.candidates()) {
39675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    AddSplit(cand.split(), nullptr, nullptr, -1);
3971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const auto& left_stats = cand.left_stats().classification().dense_counts();
3981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    for (int i = 0; i < num_classes; ++i) {
3991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      const float val = left_stats.value(i).float_value();
4001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      mutable_left_count(split_num, i) = val;
4011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      MaybeInitializeRunningCount(split_num, val);
4021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
4031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    ++split_num;
4041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
4051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
4061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid DenseClassificationGrowStats::PackToProto(FertileSlot* slot) const {
4081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* slot_stats = slot->mutable_post_init_leaf_stats();
4091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  slot_stats->set_weight_sum(weight_sum_);
4101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* class_stats = slot->mutable_post_init_leaf_stats()
4121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                          ->mutable_classification()
4131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                          ->mutable_dense_counts();
4141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_outputs_; ++i) {
4151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    class_stats->add_value()->set_float_value(total_counts_[i]);
4161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
4171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int split_num = 0;  split_num < num_splits(); ++split_num) {
4191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto* cand = slot->add_candidates();
4201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    *cand->mutable_split() = splits_[split_num];
4211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto* left_stats = cand->mutable_left_stats()
4221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                           ->mutable_classification()
4231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                           ->mutable_dense_counts();
4241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    for (int i = 0; i < num_outputs_; ++i) {
4251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower       left_stats->add_value()->set_float_value(left_count(split_num, i));
4261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
4271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
4281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
4291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat DenseClassificationGrowStats::GiniScore(int split, float* left_sum,
4311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                                              float* right_sum) const {
4321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float left_square = 0, right_square = 0;
4331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  *left_sum = 0;
4341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  *right_sum = 0;
4351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int j = 0; j < num_outputs_; ++j) {
4361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const float left = left_count(split, j);
4371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    *left_sum += left;
4381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_square += left * left;
4391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const float right = right_count(split, j);
4401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    *right_sum += right;
4411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    right_square += right * right;
4421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
4431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const float left_score =
4451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      WeightedSmoothedGini(*left_sum, left_square, num_outputs_);
4461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const float right_score =
4471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      WeightedSmoothedGini(*right_sum, right_square, num_outputs_);
4481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  return left_score + right_score;
4491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
4501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerbool DenseClassificationGrowStats::BestSplit(SplitCandidate* best) const {
4521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float min_score = FLT_MAX;
4531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int best_index = -1;
4541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float best_left_sum, best_right_sum;
4551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Calculate sums.
4571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_splits(); ++i) {
4581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    float left_sum, right_sum;
4591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum);
4601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    // Find the lowest gini.
4611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (left_sum > 0 && right_sum > 0 &&
4621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        split_score < min_score) {  // useless check
4631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      min_score = split_score;
4641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      best_index = i;
4651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      best_left_sum = left_sum;
4661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      best_right_sum = right_sum;
4671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
4681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
4691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // This could happen if all the splits are useless.
4711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (best_index < 0) {
4721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return false;
4731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
4741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Fill in stats to be used for leaf model.
4761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  *best->mutable_split() = splits_[best_index];
4771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Left
4781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* left = best->mutable_left_stats();
4791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* left_class_stats = left->mutable_classification();
4801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  left->set_weight_sum(best_left_sum);
4811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* left_counts = left_class_stats->mutable_dense_counts();
4821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < params_.num_outputs(); ++i) {
4831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_counts->add_value()->set_float_value(
4841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        left_count(best_index, i));
4851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
4861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Right
4881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* right = best->mutable_right_stats();
4891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* right_class_stats = right->mutable_classification();
4901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  right->set_weight_sum(best_right_sum);
4911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* right_counts = right_class_stats->mutable_dense_counts();
4921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < params_.num_outputs(); ++i) {
4931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    right_counts->add_value()->set_float_value(
4941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        total_counts_[i] - left_count(best_index, i));
4951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
4961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  return true;
4971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
4981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// ------------------------ Sparse Classification --------------------------- //
5001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid SparseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
5011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  Initialize();
5021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (!slot.has_post_init_leaf_stats()) {
5031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return;
5041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
5061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const auto& class_stats =
5071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      slot.post_init_leaf_stats().classification().sparse_counts();
5081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Total counts.
5101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (auto const& entry : class_stats.sparse_value()) {
5111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_counts_[entry.first] = entry.second.float_value();
5121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Candidate counts and splits.
5151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int split_num = 0;
5161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (const auto& cand : slot.candidates()) {
51775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    AddSplit(cand.split(), nullptr, nullptr, -1);
5181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const auto& left_stats = cand.left_stats().classification().sparse_counts();
5191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    for (auto const& entry : left_stats.sparse_value()) {
5201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      const float val = entry.second.float_value();
5211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      left_counts_[split_num][entry.first] = val;
5221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      MaybeInitializeRunningCount(split_num, val);
5231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
5241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    ++split_num;
5251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
5271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid SparseClassificationGrowStats::PackToProto(FertileSlot* slot) const {
5291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* slot_stats = slot->mutable_post_init_leaf_stats();
5301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  slot_stats->set_weight_sum(weight_sum_);
5311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* class_stats = slot->mutable_post_init_leaf_stats()
5331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                          ->mutable_classification()
5341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                          ->mutable_sparse_counts()
5351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                          ->mutable_sparse_value();
5361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (const auto& entry : total_counts_) {
5371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    decision_trees::Value val;
5381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    val.set_float_value(entry.second);
5391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    (*class_stats)[entry.first] = val;
5401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int split_num = 0;  split_num < num_splits(); ++split_num) {
5431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto* cand = slot->add_candidates();
5441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    *cand->mutable_split() = splits_[split_num];
5451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto* left_stats = cand->mutable_left_stats()
5461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                           ->mutable_classification()
5471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                           ->mutable_sparse_counts()
5481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                           ->mutable_sparse_value();
5491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    for (const auto& entry : left_counts_[split_num]) {
5501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      decision_trees::Value val;
5511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      val.set_float_value(entry.second);
5521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      (*left_stats)[entry.first] = val;
5531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
5541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
5561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat SparseClassificationGrowStats::GiniScore(
5581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    int split, float* left_sum, float* right_sum) const {
5591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float left_square = 0, right_square = 0;
5601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  *left_sum = 0;
5611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  *right_sum = 0;
5621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (const auto& entry : total_counts_) {
5631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const int label = entry.first;
5641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    float left = 0;
5651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    float right = 0;
5661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto it = left_counts_[split].find(label);
5671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (it == left_counts_[split].end()) {
5681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      right = entry.second;
5691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    } else {
5701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      left = it->second;
5711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      right = entry.second - it->second;
5721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
5731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    *left_sum += left;
5741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_square += left * left;
5751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    *right_sum += right;
5761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    right_square += right * right;
5771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int32 num_classes = params_.num_outputs();
5791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const float left_score =
5801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      WeightedSmoothedGini(*left_sum, left_square, num_classes);
5811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const float right_score =
5821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      WeightedSmoothedGini(*right_sum, right_square, num_classes);
5831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  return left_score + right_score;
5841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
5851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerbool SparseClassificationGrowStats::BestSplit(SplitCandidate* best) const {
5871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float min_score = FLT_MAX;
5881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int best_index = -1;
5891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float best_left_sum = -1;
5901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float best_right_sum = -1;
5911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Find the lowest gini.
5931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_splits(); ++i) {
5941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    float left_sum, right_sum;
5951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum);
5961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (left_sum > 0 && right_sum > 0 &&
5971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        split_score < min_score) {  // useless check
5981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      min_score = split_score;
5991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      best_index = i;
6001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      best_left_sum = left_sum;
6011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      best_right_sum = right_sum;
6021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
6031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
6041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
6051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // This could happen if all the splits are useless.
6061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (best_index < 0) {
6071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return false;
6081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
6091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
6101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Fill in stats to be used for leaf model.
6111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  *best->mutable_split() = splits_[best_index];
6121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Left
6131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* left = best->mutable_left_stats();
6141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* left_class_stats = left->mutable_classification();
6151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  left->set_weight_sum(best_left_sum);
6161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* left_counts =
6171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      left_class_stats->mutable_sparse_counts()->mutable_sparse_value();
6181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
6191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Right
6201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* right = best->mutable_right_stats();
6211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* right_class_stats = right->mutable_classification();
6221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  right->set_weight_sum(best_right_sum);
6231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* right_counts =
6241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      right_class_stats->mutable_sparse_counts()->mutable_sparse_value();
6251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
6261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (const auto& entry : total_counts_) {
6271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto it = left_counts_[best_index].find(entry.first);
6281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (it == left_counts_[best_index].end()) {
6291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      (*right_counts)[entry.first].set_float_value(entry.second);
6301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    } else {
6311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      const float left = it->second;
6321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      const float right = entry.second - it->second;
6331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      (*left_counts)[entry.first].set_float_value(left);
6341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      if (right > 0) {
6351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        (*right_counts)[entry.first].set_float_value(right);
6361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      }
6371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
6381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
6391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  return true;
6401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
6411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
6421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// --------------------- Least Squares Regression --------------------------- //
6431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid LeastSquaresRegressionGrowStats::ExtractFromProto(
6441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const FertileSlot& slot) {
6451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int32 num_outputs = params_.num_outputs();
6461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  Initialize();
6471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (!slot.has_post_init_leaf_stats()) {
6481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return;
6491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
6501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
6511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const auto& total_sums =
6521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      slot.post_init_leaf_stats().regression().mean_output();
6531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const auto& total_squares =
6541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      slot.post_init_leaf_stats().regression().mean_output_squares();
6551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
6561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Total counts.
6571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_outputs; ++i) {
6581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_sum_[i] = total_sums.value(i).float_value();
6591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_sum_squares_[i] = total_squares.value(i).float_value();
6601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
6611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
6621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Candidate counts and splits.
6631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int split_num = 0;
6641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (const auto& cand : slot.candidates()) {
66575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    AddSplit(cand.split(), nullptr, nullptr, -1);
6661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const auto& sums = cand.left_stats().regression().mean_output();
6671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const auto& squares = cand.left_stats().regression().mean_output_squares();
6681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    for (int i = 0; i < num_outputs; ++i) {
6691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      left_sum(split_num, i) = sums.value(i).float_value();
6701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      left_square(split_num, i) = squares.value(i).float_value();
6711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
6721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_counts_[split_num] = cand.left_stats().weight_sum();
6731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    ++split_num;
6741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
6751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
6761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
6771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid LeastSquaresRegressionGrowStats::PackToProto(FertileSlot* slot) const {
6781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int32 num_outputs = params_.num_outputs();
6791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* slot_stats = slot->mutable_post_init_leaf_stats();
6801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  slot_stats->set_weight_sum(weight_sum_);
6811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
6821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* total_sums = slot->mutable_post_init_leaf_stats()
6831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                         ->mutable_regression()
6841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                         ->mutable_mean_output();
6851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* total_squares = slot->mutable_post_init_leaf_stats()
6861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                            ->mutable_regression()
6871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                            ->mutable_mean_output_squares();
6881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
6891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < total_sum_.size(); ++i) {
6901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_sums->add_value()->set_float_value(total_sum_[i]);
6911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_squares->add_value()->set_float_value(total_sum_squares_[i]);
6921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
6931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
6941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int split_num = 0;  split_num < num_splits(); ++split_num) {
6951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto* cand = slot->add_candidates();
6961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    *cand->mutable_split() = splits_[split_num];
6971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto* sums = cand->mutable_left_stats()
6981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                           ->mutable_regression()
6991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                           ->mutable_mean_output();
7001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto* squares = cand->mutable_left_stats()
7011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                        ->mutable_regression()
7021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                        ->mutable_mean_output_squares();
7031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    for (int i = 0; i < num_outputs; ++i) {
7041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      sums->add_value()->set_float_value(left_sum(split_num, i));
7051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      squares->add_value()->set_float_value(left_square(split_num, i));
7061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
7071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    cand->mutable_left_stats()->set_weight_sum(left_counts_[split_num]);
7081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
7091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
7101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
7111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid LeastSquaresRegressionGrowStats::AddExample(
7121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
7131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    int example) {
7141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int32 num_outputs = params_.num_outputs();
7151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Update splits.
7161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_splits(); ++i) {
7171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto& eval = evaluators_[i];
7181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (eval->Decide(input_data, example) == LEFT_INDEX) {
7191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      for (int j = 0; j < num_outputs; ++j) {
7201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        const float output = target->GetTargetAsContinuous(example, j);
7211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        left_sum(i, j) += output;
7221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        left_square(i, j) += output * output;
7231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      }
7241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      ++left_counts_[i];
7251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
7261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
7271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
7281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Update totals.
7291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_outputs; ++i) {
7301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const float output = target->GetTargetAsContinuous(example, i);
7311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_sum_[i] += output;
7321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_sum_squares_[i] += output * output;
7331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
7341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  weight_sum_ += 1.0;
7351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
7361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
7371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat LeastSquaresRegressionGrowStats::SplitVariance(int split) const {
7381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float total_variance = 0;
7391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < params_.num_outputs(); ++i) {
7401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    // Left side
7411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const float le_x =
7421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        left_sum(split, i) / left_counts_[split];
7431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
7441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const float le_x2 =
7451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        left_square(split, i) / left_counts_[split];
7461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_variance += le_x2 - le_x * le_x;
7471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
7481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    // Right side
7491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const float re_x = (total_sum_[i] - left_sum(split, i)) /
7501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                       (weight_sum_ - left_counts_[split]);
7511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
7521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const float re_x2 =
7531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        (total_sum_squares_[i] - left_square(split, i)) /
7541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        (weight_sum_ - left_counts_[split]);
7551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_variance += re_x2 - re_x * re_x;
7561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
7571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  return total_variance;
7581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
7591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
7601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerbool LeastSquaresRegressionGrowStats::BestSplit(SplitCandidate* best) const {
7611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float min_score = FLT_MAX;
7621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int best_index = -1;
7631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int32 num_outputs = params_.num_outputs();
7641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_splits(); ++i) {
7651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (left_counts_[i] > 0 && weight_sum_ - left_counts_[i] > 0) {
7661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      const float split_score = SplitVariance(i);
7671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      if (split_score < min_score) {
7681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        min_score = split_score;
7691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        best_index = i;
7701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      }
7711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
7721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
7731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
7741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // This could happen if all the splits are useless.
7751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (best_index < 0) {
7761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return false;
7771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
7781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
7791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Fill in right stats to be used for leaf model.
7801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  *best->mutable_split() = splits_[best_index];
7811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Left
7821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* left = best->mutable_left_stats();
7831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* left_reg_stats = left->mutable_regression();
7841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  left->set_weight_sum(left_counts_[best_index]);
7851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* left_output_sum = left_reg_stats->mutable_mean_output();
7861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_outputs; ++i) {
7871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_output_sum->add_value()->set_float_value(
7881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        left_sum(best_index, i));
7891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
7901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
7911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Right
7921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* right = best->mutable_right_stats();
7931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* right_reg_stats = right->mutable_regression();
7941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  right->set_weight_sum(weight_sum_ - left_counts_[best_index]);
7951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* right_output_sum = right_reg_stats->mutable_mean_output();
7961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_outputs; ++i) {
7971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    right_output_sum->add_value()->set_float_value(
7981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        total_sum_[i] - left_sum(best_index, i));
7991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
8001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  return true;
8011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
8021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
8031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerbool LeastSquaresRegressionGrowStats::IsFinished() const {
8041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  return weight_sum_ >= split_after_samples_;
8051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
8061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
8071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}  // namespace tensorforest
8081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}  // namespace tensorflow
809