11588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower//
31588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Licensed under the Apache License, Version 2.0 (the "License");
41588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// you may not use this file except in compliance with the License.
51588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// You may obtain a copy of the License at
61588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower//
71588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower//     http://www.apache.org/licenses/LICENSE-2.0
81588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower//
91588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Unless required by applicable law or agreed to in writing, software
101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// distributed under the License is distributed on an "AS IS" BASIS,
111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// See the License for the specific language governing permissions and
131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// limitations under the License.
141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// =============================================================================
151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h"
161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include <cfloat>
181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include <queue>
191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h"
211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/core/lib/random/distribution_sampler.h"
221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorflow {
241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorforest {
251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// When creating evaluators for the split candidates, use these
271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// for the left and right return values.
281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerstatic const int32 LEFT_INDEX = 0;
291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerstatic const int32 RIGHT_INDEX = 1;
301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerGrowStats::GrowStats(const TensorForestParams& params, int32 depth)
3208a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower    : weight_sum_(0),
3308a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower      depth_(depth),
341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      params_(params),
351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      split_after_samples_(ResolveParam(params.split_after_samples(), depth)),
361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      num_splits_to_consider_(
371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower          ResolveParam(params.num_splits_to_consider(), depth)),
381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      num_outputs_(params.num_outputs()) {}
391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlowervoid GrowStats::AddSplit(const decision_trees::BinaryNode& split,
4175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower                         const std::unique_ptr<TensorDataSet>& input_data,
4275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower                         const InputTarget* target, int example) {
4375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  // It's possible that the split collection calls AddSplit, but we actually
4475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  // have all the splits we need and are just waiting for them to be fully
4575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  // initialized.
4675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  if (splits_.size() < num_splits_to_consider_) {
4775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    splits_.push_back(split);
4875f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    evaluators_.emplace_back(
4975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower        CreateBinaryDecisionNodeEvaluator(split, LEFT_INDEX, RIGHT_INDEX));
5075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    AddSplitStats(target, example);
5175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  }
5275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower
5375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  if (input_data != nullptr && target != nullptr &&
5475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower      params_.initialize_average_splits()) {
5575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    AdditionalInitializationExample(input_data, target, example);
5675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  }
571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid GrowStats::RemoveSplit(int split_num) {
601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  splits_.erase(splits_.begin() + split_num);
611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  evaluators_.erase(evaluators_.begin() + split_num);
621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  RemoveSplitStats(split_num);
631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// ------------------------ Classification --------------------------- //
661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerClassificationStats::ClassificationStats(const TensorForestParams& params,
681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                                         int32 depth)
691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    : GrowStats(params, depth), finish_early_(false) {
701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Early splitting params.
711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (params.finish_type().type() == SPLIT_FINISH_BASIC) {
721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    min_split_samples_ = split_after_samples_;
7308a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower    finish_sample_epoch_ = 1;
7408a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower    finish_check_every_ = split_after_samples_ * 2;
751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  } else {
761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (!params.has_dominate_fraction() || !params.has_min_split_samples()) {
771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      LOG(FATAL) << "dominate_fraction and min_split_samples "
781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                 << "required for early-finish strategy.";
791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    } else {
801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      min_split_samples_ = ResolveParam(params.min_split_samples(), depth);
811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      finish_check_every_ =
821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower          ResolveParam(params.finish_type().check_every_steps(), depth);
831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      finish_sample_epoch_ = min_split_samples_ / finish_check_every_;
841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_);
861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      if (dominate_fraction_ <= 0 || dominate_fraction_ > 1.0) {
871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        LOG(FATAL) << "Invalid dominate fraction " << dominate_fraction_;
881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      }
891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Pruning params.
931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (params.pruning_type().type() != SPLIT_PRUNE_NONE) {
941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    prune_check_every_ =
951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        ResolveParam(params.pruning_type().prune_every_samples(), depth);
961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    prune_sample_epoch_ = 1;
971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    prune_fraction_ = 0.0;
981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    switch (params_.pruning_type().type()) {
991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      case SPLIT_PRUNE_HALF:
1001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        prune_fraction_ = 0.5;
1011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        break;
1021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      case SPLIT_PRUNE_QUARTER:
1031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        prune_fraction_ = 0.25;
1041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        break;
1051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      case SPLIT_PRUNE_10_PERCENT:
1061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        prune_fraction_ = 0.10;
1071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        break;
1081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      case SPLIT_PRUNE_HOEFFDING:
1091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_);
1101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        half_ln_dominate_frac_ = 0.5 * log(1.0 / (1.0 - dominate_fraction_));
1111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        break;
1121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      default:
1131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        LOG(WARNING) << "Unknown pruning type";
1141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
11508a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower  } else {
11608a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower    prune_check_every_ = split_after_samples_ * 2;
11708a23d412b7d18a4d8ee41e24b03f4c616e32eb0A. Unique TensorFlower    prune_sample_epoch_ = 1;
1181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
1191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (params.use_running_stats_method()) {
1211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_gini_.reset(new RunningGiniScores());
1221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    right_gini_.reset(new RunningGiniScores());
1231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
1241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  uint64 time_seed = static_cast<uint64>(std::clock());
1261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  single_rand_ = std::unique_ptr<random::PhiloxRandom>(
1271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      new random::PhiloxRandom(time_seed));
1281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  rng_ = std::unique_ptr<random::SimplePhilox>(
1291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      new random::SimplePhilox(single_rand_.get()));
1301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
1311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
13275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlowervoid ClassificationStats::AdditionalInitializationExample(
13375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
13475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    int example) {
13575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  const int32 new_target = target->GetTargetAsClassIndex(example, 0);
13675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  std::unordered_set<int> to_erase;
13775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  for (auto it = half_initialized_splits_.begin();
13875f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower       it != half_initialized_splits_.end(); ++it) {
13975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    if (it->second != new_target) {
14075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower      auto& split = splits_[it->first];
14175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower      if (split.has_inequality_left_child_test()) {
14275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower        auto& test = split.inequality_left_child_test();
14375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower        auto* thresh =
14475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower            split.mutable_inequality_left_child_test()->mutable_threshold();
14575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower        if (test.has_feature_id()) {
14675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower          const float val =
14775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower              input_data->GetExampleValue(example, test.feature_id());
14875f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower          thresh->set_float_value((thresh->float_value() + val) / 2);
14975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower        }
15075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower      }
15175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower      to_erase.insert(it->first);
15275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    }
15375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  }
15475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower
15575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  for (const int split_id : to_erase) {
15675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    half_initialized_splits_.erase(split_id);
15775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  }
15875f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower}
15975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower
1601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerbool ClassificationStats::IsFinished() const {
161054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  bool basic = (weight_sum_ >= split_after_samples_) && !is_pure();
1621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  return basic || finish_early_;
1631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
1641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat ClassificationStats::MaybeCachedGiniScore(int split, float* left_sum,
1661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                                                float* right_sum) const {
1671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (left_gini_ == nullptr) {
1681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return GiniScore(split, left_sum, right_sum);
1691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  } else {
1701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    *left_sum = left_gini_->sum(split);
1711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const float left = WeightedSmoothedGini(
1721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        *left_sum, left_gini_->square(split), num_outputs_);
1731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    *right_sum = right_gini_->sum(split);
1751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const float right = WeightedSmoothedGini(
1761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        *right_sum, right_gini_->square(split), num_outputs_);
1771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return left + right;
1791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
1801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
1811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::AddExample(
1831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
1841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    int example) {
1851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int64 int_label = target->GetTargetAsClassIndex(example, 0);
1861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const float weight = target->GetTargetWeight(example);
1871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_splits(); ++i) {
1891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto& eval = evaluators_[i];
1901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (eval->Decide(input_data, example) == LEFT_INDEX) {
1911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      if (left_gini_ != nullptr) {
1921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        left_gini_->update(i, left_count(i, int_label), weight);
1931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      }
1941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      ClassificationAddLeftExample(i, int_label, weight);
195054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    } else {
196054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower      if (right_gini_ != nullptr) {
197054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower        right_gini_->update(i, right_count(i, int_label), weight);
198054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower      }
199054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower      ClassificationAddRightExample(i, int_label, weight);
2001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
2011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  ClassificationAddTotalExample(int_label, weight);
2041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  weight_sum_ += weight;
2061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  CheckFinishEarly();
2081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  CheckPrune();
2091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
2101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckPrune() {
2123d72bc69ee838e8c6b0b801e274aac4c31647b22A. Unique TensorFlower  if (params_.pruning_type().type() == SPLIT_PRUNE_NONE || IsFinished() ||
2133d72bc69ee838e8c6b0b801e274aac4c31647b22A. Unique TensorFlower      weight_sum_ < prune_sample_epoch_ * prune_check_every_) {
2141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return;
2151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  ++prune_sample_epoch_;
2171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (params_.pruning_type().type() == SPLIT_PRUNE_HOEFFDING) {
2191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    CheckPruneHoeffding();
2201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return;
2211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int to_remove = num_splits() * prune_fraction_;
2241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (to_remove <= 0) {
2251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return;
2261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // pair ordering is first-then-second by default, no need for custom
2291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // comparison.  Use std::greater to make it a min-heap.
2301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int>>,
2311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                      std::greater<std::pair<float, int>>>
2321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      worst;
2331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Track indices that are in the heap so we can iterate over them
2351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // by largest-first later.
2361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::set<int> indices;
2371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_splits(); ++i) {
2391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    float left, right;
2401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const float split_score = MaybeCachedGiniScore(i, &left, &right);
2411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (worst.size() < to_remove) {
2421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      worst.push(std::pair<float, int>(split_score, i));
2431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      indices.insert(i);
2441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    } else if (worst.top().first < split_score) {
2451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      indices.erase(worst.top().second);
2461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      worst.pop();
2471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      worst.push(std::pair<float, int>(split_score, i));
2481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      indices.insert(i);
2491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
2501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // traverse indices from the back so that they are removed correctly.
2531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (auto it = indices.rbegin(); it != indices.rend(); ++it) {
2541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    RemoveSplit(*it);
2551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
2571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckPruneHoeffding() {
2591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::vector<float> split_scores(num_splits());
2601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Find best split score
2611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float best_split_score = FLT_MAX;
2621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_splits(); ++i) {
2631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    float left, right;
2641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    split_scores[i] = MaybeCachedGiniScore(i, &left, &right);
2651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (split_scores[i] < best_split_score) {
2661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      best_split_score = split_scores[i];
2671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
2681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // We apply the Hoeffding bound to the difference between the best split
2711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // score and the i-th split score.
2721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Raw Gini ranges from 0 to 1 - (1/n), but our gini score is weighted.
2731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const float num_classes = params_.num_outputs();
2741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const float gini_diff_range = weight_sum_ * (1.0 - 1.0 / num_classes);
2751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float epsilon = gini_diff_range * sqrt(half_ln_dominate_frac_ / weight_sum_);
2761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = num_splits() - 1; i >= 0; i--) {
2771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (split_scores[i] - best_split_score > epsilon) {
2781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      RemoveSplit(i);
2791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
2801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
2821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckFinishEarly() {
2841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (weight_sum_ < min_split_samples_ ||
2851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      weight_sum_ < finish_sample_epoch_ * finish_check_every_) {
2861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return;
2871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  ++finish_sample_epoch_;
2891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_HOEFFDING) {
2911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    CheckFinishEarlyHoeffding();
2921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  } else if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_BOOTSTRAP) {
2931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    CheckFinishEarlyBootstrap();
2941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
2961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckFinishEarlyHoeffding() {
2981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Each term in the Gini impurity can range from 0 to 0.5 * 0.5.
2991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float range = 0.25 * static_cast<float>(params_.num_outputs()) * weight_sum_;
3001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float hoeffding_bound =
3021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      range * sqrt(log(1.0 / (1.0 - dominate_fraction_)) / (2.0 * weight_sum_));
3031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float unused_left_sum, unused_right_sum;
3051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::function<float(int)> score_fn =
3061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      std::bind(&ClassificationStats::MaybeCachedGiniScore, this,
3071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                std::placeholders::_1, &unused_left_sum, &unused_right_sum);
3081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float best_score;
3101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int32 best_index;
3111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float second_best_score;
3121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int32 second_best_index;
3131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  GetTwoBest(num_splits(), score_fn, &best_score, &best_index,
3141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower             &second_best_score, &second_best_index);
3151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  finish_early_ = (second_best_score - best_score) > hoeffding_bound;
3171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
3181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::MakeBootstrapWeights(int index,
3201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                                               std::vector<float>* weights) {
3211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int n = weight_sum_;
3221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float denom = static_cast<float>(n) + static_cast<float>(num_outputs_);
3231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_outputs_; ++i) {
3241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    // Use the Laplace smoothed per-class probabilities when generating the
3251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    // bootstrap samples.
3261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    (*weights)[i] = (left_count(index, i) + 1.0) / denom;
3271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    (*weights)[num_outputs_ + i] = (right_count(index, i) + 1.0) / denom;
3281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
3301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerint ClassificationStats::NumBootstrapSamples() const {
3321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float p = 1.0 - dominate_fraction_;
3331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int bootstrap_samples = 1;
3341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  while (p < 1.0) {
3351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    ++bootstrap_samples;
3361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    p = p * 2;
3371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  return bootstrap_samples;
3391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
3401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid ClassificationStats::CheckFinishEarlyBootstrap() {
3421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float unused_left_sum, unused_right_sum;
3431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::function<float(int)> score_fn =
3441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      std::bind(&ClassificationStats::MaybeCachedGiniScore, this,
3451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                std::placeholders::_1, &unused_left_sum, &unused_right_sum);
3461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float best_score;
3481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int32 best_index;
3491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float second_best_score;
3501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int32 second_best_index;
3511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  GetTwoBest(num_splits(), score_fn, &best_score, &best_index,
3521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower             &second_best_score, &second_best_index);
3531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::vector<float> weights1(num_outputs_ * 2);
3551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  MakeBootstrapWeights(best_index, &weights1);
3561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  random::DistributionSampler ds1(weights1);
3571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::vector<float> weights2(num_outputs_ * 2);
3591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  MakeBootstrapWeights(second_best_index, &weights2);
3601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  random::DistributionSampler ds2(weights2);
3611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int bootstrap_samples = NumBootstrapSamples();
3631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int worst_g1 = 0;
3651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < bootstrap_samples; i++) {
3661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    int g1 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds1, rng_.get());
3671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    worst_g1 = std::max(worst_g1, g1);
3681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int best_g2 = 99;
3711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < bootstrap_samples; i++) {
3721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    int g2 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds2, rng_.get());
3731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    best_g2 = std::min(best_g2, g2);
3741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  finish_early_ = worst_g1 < best_g2;
3771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
3781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
379054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowerbool ClassificationStats::BestSplit(SplitCandidate* best) const {
380054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  float min_score = FLT_MAX;
381054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  int best_index = -1;
382054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  float best_left_sum, best_right_sum;
383054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
384054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // Calculate sums.
385054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  for (int i = 0; i < num_splits(); ++i) {
386054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    float left_sum, right_sum;
387054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum);
388054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    // Find the lowest gini.
389054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    if (left_sum > 0 && right_sum > 0 &&
390054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower        split_score < min_score) {  // useless check
391054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower      min_score = split_score;
392054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower      best_index = i;
393054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower      best_left_sum = left_sum;
394054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower      best_right_sum = right_sum;
395054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    }
396054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
397054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
398054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // This could happen if all the splits are useless.
399054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  if (best_index < 0) {
400054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    return false;
401054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
402054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
403054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // Fill in stats to be used for leaf model.
404054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  *best->mutable_split() = splits_[best_index];
405054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  auto* left = best->mutable_left_stats();
406054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  left->set_weight_sum(best_left_sum);
407054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  auto* right = best->mutable_right_stats();
408054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  right->set_weight_sum(best_right_sum);
409054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  InitLeafClassStats(best_index, left, right);
410054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
411054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  return true;
412054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower}
413054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
4141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// ------------------------ Dense Classification --------------------------- //
4151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid DenseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
4161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  Initialize();
4171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (!slot.has_post_init_leaf_stats()) {
4181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return;
4191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
4201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int32 num_classes = params_.num_outputs();
4211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
4221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const auto& class_stats =
4231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      slot.post_init_leaf_stats().classification().dense_counts();
4241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Total counts.
4261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_classes; ++i) {
4271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_counts_[i] = class_stats.value(i).float_value();
4281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    num_outputs_seen_ += total_counts_[i] != 0;
4291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
4301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Candidate counts and splits.
4321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int split_num = 0;
4331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (const auto& cand : slot.candidates()) {
43475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    AddSplit(cand.split(), nullptr, nullptr, -1);
4351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const auto& left_stats = cand.left_stats().classification().dense_counts();
4361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    for (int i = 0; i < num_classes; ++i) {
4371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      const float val = left_stats.value(i).float_value();
4381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      mutable_left_count(split_num, i) = val;
4391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      MaybeInitializeRunningCount(split_num, val);
4401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
4411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    ++split_num;
4421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
4431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
4441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid DenseClassificationGrowStats::PackToProto(FertileSlot* slot) const {
4461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* slot_stats = slot->mutable_post_init_leaf_stats();
4471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  slot_stats->set_weight_sum(weight_sum_);
4481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* class_stats = slot->mutable_post_init_leaf_stats()
4501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                          ->mutable_classification()
4511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                          ->mutable_dense_counts();
4521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_outputs_; ++i) {
4531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    class_stats->add_value()->set_float_value(total_counts_[i]);
4541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
4551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4564463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower  for (int split_num = 0; split_num < num_splits(); ++split_num) {
4571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto* cand = slot->add_candidates();
4581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    *cand->mutable_split() = splits_[split_num];
4591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto* left_stats = cand->mutable_left_stats()
4601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                           ->mutable_classification()
4611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                           ->mutable_dense_counts();
4621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    for (int i = 0; i < num_outputs_; ++i) {
4634463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower      left_stats->add_value()->set_float_value(left_count(split_num, i));
4641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
4651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
4661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
4671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat DenseClassificationGrowStats::GiniScore(int split, float* left_sum,
4691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                                              float* right_sum) const {
4701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float left_square = 0, right_square = 0;
4711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  *left_sum = 0;
4721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  *right_sum = 0;
4731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int j = 0; j < num_outputs_; ++j) {
4741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const float left = left_count(split, j);
4751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    *left_sum += left;
4761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_square += left * left;
4771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const float right = right_count(split, j);
4781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    *right_sum += right;
4791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    right_square += right * right;
4801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
4811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const float left_score =
4831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      WeightedSmoothedGini(*left_sum, left_square, num_outputs_);
4841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const float right_score =
4851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      WeightedSmoothedGini(*right_sum, right_square, num_outputs_);
4861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  return left_score + right_score;
4871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
4881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
489054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowervoid DenseClassificationGrowStats::InitLeafClassStats(
490054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
491054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  auto* left_class_stats = left_stats->mutable_classification();
4921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* left_counts = left_class_stats->mutable_dense_counts();
4931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < params_.num_outputs(); ++i) {
494054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    left_counts->add_value()->set_float_value(left_count(best_split_index, i));
4951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
4961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
497054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  auto* right_class_stats = right_stats->mutable_classification();
4981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* right_counts = right_class_stats->mutable_dense_counts();
4991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < params_.num_outputs(); ++i) {
500054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    right_counts->add_value()->set_float_value(total_counts_[i] -
501054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower                                               left_count(best_split_index, i));
5021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
5041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// ------------------------ Sparse Classification --------------------------- //
5061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid SparseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
5071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  Initialize();
5081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (!slot.has_post_init_leaf_stats()) {
5091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return;
5101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
5121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const auto& class_stats =
5131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      slot.post_init_leaf_stats().classification().sparse_counts();
5141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Total counts.
5161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (auto const& entry : class_stats.sparse_value()) {
5171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_counts_[entry.first] = entry.second.float_value();
5181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Candidate counts and splits.
5211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int split_num = 0;
5221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (const auto& cand : slot.candidates()) {
52375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    AddSplit(cand.split(), nullptr, nullptr, -1);
5241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const auto& left_stats = cand.left_stats().classification().sparse_counts();
5251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    for (auto const& entry : left_stats.sparse_value()) {
5261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      const float val = entry.second.float_value();
5271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      left_counts_[split_num][entry.first] = val;
5281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      MaybeInitializeRunningCount(split_num, val);
5291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
5301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    ++split_num;
5311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
5331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid SparseClassificationGrowStats::PackToProto(FertileSlot* slot) const {
5351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* slot_stats = slot->mutable_post_init_leaf_stats();
5361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  slot_stats->set_weight_sum(weight_sum_);
5371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* class_stats = slot->mutable_post_init_leaf_stats()
5391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                          ->mutable_classification()
5401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                          ->mutable_sparse_counts()
5411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                          ->mutable_sparse_value();
5421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (const auto& entry : total_counts_) {
5431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    decision_trees::Value val;
5441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    val.set_float_value(entry.second);
5451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    (*class_stats)[entry.first] = val;
5461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5484463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower  for (int split_num = 0; split_num < num_splits(); ++split_num) {
5491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto* cand = slot->add_candidates();
5501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    *cand->mutable_split() = splits_[split_num];
5511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto* left_stats = cand->mutable_left_stats()
5521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                           ->mutable_classification()
5531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                           ->mutable_sparse_counts()
5541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                           ->mutable_sparse_value();
5551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    for (const auto& entry : left_counts_[split_num]) {
5561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      decision_trees::Value val;
5571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      val.set_float_value(entry.second);
5581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      (*left_stats)[entry.first] = val;
5591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
5601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
5621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5634463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlowerfloat SparseClassificationGrowStats::GiniScore(int split, float* left_sum,
5644463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower                                               float* right_sum) const {
5651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float left_square = 0, right_square = 0;
5661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  *left_sum = 0;
5671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  *right_sum = 0;
5681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (const auto& entry : total_counts_) {
5691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const int label = entry.first;
5701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    float left = 0;
5711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    float right = 0;
5721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto it = left_counts_[split].find(label);
5731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (it == left_counts_[split].end()) {
5741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      right = entry.second;
5751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    } else {
5761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      left = it->second;
5771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      right = entry.second - it->second;
5781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
5791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    *left_sum += left;
5801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_square += left * left;
5811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    *right_sum += right;
5821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    right_square += right * right;
5831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int32 num_classes = params_.num_outputs();
5851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const float left_score =
5861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      WeightedSmoothedGini(*left_sum, left_square, num_classes);
5871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const float right_score =
5881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      WeightedSmoothedGini(*right_sum, right_square, num_classes);
5891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  return left_score + right_score;
5901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
5911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
592054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowervoid SparseClassificationGrowStats::InitLeafClassStats(
593054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
594054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  auto* left_class_stats = left_stats->mutable_classification();
5951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* left_counts =
5961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      left_class_stats->mutable_sparse_counts()->mutable_sparse_value();
597054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  auto* right_class_stats = right_stats->mutable_classification();
5981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* right_counts =
5991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      right_class_stats->mutable_sparse_counts()->mutable_sparse_value();
6001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
6011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (const auto& entry : total_counts_) {
602054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    auto it = left_counts_[best_split_index].find(entry.first);
603054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    if (it == left_counts_[best_split_index].end()) {
6041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      (*right_counts)[entry.first].set_float_value(entry.second);
6051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    } else {
6061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      const float left = it->second;
6071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      const float right = entry.second - it->second;
6081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      (*left_counts)[entry.first].set_float_value(left);
6091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      if (right > 0) {
6101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        (*right_counts)[entry.first].set_float_value(right);
6111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      }
6121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
6131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
614054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower}
615054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
616054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower// -------------------- FixedSizeClassStats --------------------------------- //
617054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
618054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower// FixedSizeClassStats implements the "SpaceSaving" algorithm by
619054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower// Ahmed Metwally, Divyakant Agrawal and Amr El Abbadi.  See for example
620054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower// https://pdfs.semanticscholar.org/72f1/5aba2e67b1cc9cd1fb12c99e101c4c1aae4b.pdf
621054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
// Returns the key in `m` whose value is smallest, or -1 when no value is
// strictly below FLT_MAX (which includes the empty map).
int argmin(const std::unordered_map<int, float>& m) {
  int min_key = -1;
  float min_value = FLT_MAX;
  for (auto entry = m.begin(); entry != m.end(); ++entry) {
    // Strict comparison: the first minimum encountered is kept.
    if (entry->second < min_value) {
      min_value = entry->second;
      min_key = entry->first;
    }
  }
  return min_key;
}
633054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
634054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowervoid FixedSizeClassStats::accumulate(int c, float w) {
635054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  auto it = class_weights_.find(c);
636054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  if (it != class_weights_.end()) {
637054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    it->second += w;
638054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    if (c == smallest_weight_class_) {
639054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower      smallest_weight_class_ = argmin(class_weights_);
640054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    }
641054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    return;
642054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
643054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
644054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  if (class_weights_.size() < n_) {
645054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    class_weights_.insert(it, std::pair<int, float>(c, w));
646054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    if (class_weights_.size() == n_) {
647054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower      // Can't assume last added has the smallest weight, because the
648054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower      // w's might be all different.
649054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower      smallest_weight_class_ = argmin(class_weights_);
650054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    }
651054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    return;
652054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
653054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
654054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // This is the slightly unintuitive heart of the SpaceSaving algorithm:
655054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // if the map is full and we see a new class, we find the entry with the
656054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // smallest weight and "take it over":  we add our weight to its weight,
657054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // and assign it all to the new seen class.
658054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  it = class_weights_.find(smallest_weight_class_);
659054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  float new_weight = it->second + w;
660054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  class_weights_.erase(it);
661054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  class_weights_[c] = new_weight;
662054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  smallest_weight_class_ = argmin(class_weights_);
663054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower}
664054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
665054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowerfloat FixedSizeClassStats::get_weight(int c) const {
666054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // Every entry in class_weights_ might be overstated by as much as the
667054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // smallest_weight.  We therefore assume that each has been overstated
668054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // by smallest_weight / 2.0, and we re-distribute that mass over all
669054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // num_classes_ classes.
670054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  float smallest_weight = 0.0;
671054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  auto it = class_weights_.find(smallest_weight_class_);
672054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  if (it != class_weights_.end()) {
673054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    smallest_weight = it->second;
674054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
675054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  float w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_);
676054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  it = class_weights_.find(c);
677054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  if (it != class_weights_.end()) {
678054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    w += it->second - smallest_weight / 2.0;
679054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
680054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  return w;
681054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower}
682054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
683054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowervoid FixedSizeClassStats::set_sum_and_square(float* sum, float* square) const {
684054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  *sum = 0.0;
685054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  *square = 0.0;
686054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
687054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  float smallest_weight = 0.0;
688054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  auto it = class_weights_.find(smallest_weight_class_);
689054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  if (it != class_weights_.end()) {
690054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    smallest_weight = it->second;
691054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
692054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
693054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  float w;
694054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  for (const auto it : class_weights_) {
695054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    *sum += it.second;
696054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    w = get_weight(it.first);
697054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    *square += w * w;
698054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
699054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
700054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_);
701054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  *square += (num_classes_ - n_) * w * w;
702054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower}
703054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
704054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowervoid FixedSizeClassStats::ExtractFromProto(
705054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    const decision_trees::SparseVector& sparse_vector) {
706054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  for (const auto& it : sparse_vector.sparse_value()) {
707054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    class_weights_[it.first] = it.second.float_value();
708054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
709054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  if (class_weights_.size() == n_) {
710054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    smallest_weight_class_ = argmin(class_weights_);
711054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
712054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower}
713054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
714054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowervoid FixedSizeClassStats::PackToProto(
715054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    decision_trees::SparseVector* sparse_vector) const {
716054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  for (const auto it : class_weights_) {
717054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    (*sparse_vector->mutable_sparse_value())[it.first].set_float_value(
718054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower        it.second);
719054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
720054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower}
721054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
722054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower// --------------------- FixedSizeSparseClassificationGrowStats ------------- //
723054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
724054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowervoid FixedSizeSparseClassificationGrowStats::ExtractFromProto(
725054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    const FertileSlot& slot) {
726054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  Initialize();
727054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  if (!slot.has_post_init_leaf_stats()) {
728054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    return;
729054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
730054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
731054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
732054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // Candidate counts and splits.
733054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  int split_num = 0;
734054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  left_counts_.clear();
735054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  right_counts_.clear();
736054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  for (const auto& cand : slot.candidates()) {
737054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    AddSplit(cand.split(), nullptr, nullptr, -1);
738054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    const auto& left_stats = cand.left_stats().classification().sparse_counts();
739054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    left_counts_.emplace_back(params_.num_classes_to_track(),
740054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower                              params_.num_outputs());
741054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    left_counts_[split_num].ExtractFromProto(left_stats);
742054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    const auto& right_stats =
743054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower        cand.right_stats().classification().sparse_counts();
744054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    right_counts_.emplace_back(params_.num_classes_to_track(),
745054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower                               params_.num_outputs());
746054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    right_counts_[split_num].ExtractFromProto(right_stats);
747054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    ++split_num;
748054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
749054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower}
750054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
751054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowervoid FixedSizeSparseClassificationGrowStats::PackToProto(
752054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    FertileSlot* slot) const {
753054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  auto* slot_stats = slot->mutable_post_init_leaf_stats();
754054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  slot_stats->set_weight_sum(weight_sum_);
755054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
756054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  for (int split_num = 0; split_num < num_splits(); ++split_num) {
757054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    auto* cand = slot->add_candidates();
758054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    *cand->mutable_split() = splits_[split_num];
759054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    auto* left_stats = cand->mutable_left_stats()
760054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower                           ->mutable_classification()
761054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower                           ->mutable_sparse_counts();
762054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    left_counts_[split_num].PackToProto(left_stats);
763054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    auto* right_stats = cand->mutable_right_stats()
764054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower                            ->mutable_classification()
765054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower                            ->mutable_sparse_counts();
766054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    right_counts_[split_num].PackToProto(right_stats);
767054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
768054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower}
769054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
770054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowerfloat FixedSizeSparseClassificationGrowStats::GiniScore(
771054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    int split, float* left_sum, float* right_sum) const {
772054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  float left_square, right_square;
773054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  left_counts_[split].set_sum_and_square(left_sum, &left_square);
774054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  right_counts_[split].set_sum_and_square(right_sum, &right_square);
775054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  const int32 num_classes = params_.num_outputs();
776054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  const float left_score =
777054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower      WeightedSmoothedGini(*left_sum, left_square, num_classes);
778054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  const float right_score =
779054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower      WeightedSmoothedGini(*right_sum, right_square, num_classes);
780054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  return left_score + right_score;
781054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower}
782054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
783054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowervoid FixedSizeSparseClassificationGrowStats::InitLeafClassStats(
784054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
785054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  auto* left_class_stats = left_stats->mutable_classification();
786054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  auto* left_counts = left_class_stats->mutable_sparse_counts();
787054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  left_counts_[best_split_index].PackToProto(left_counts);
788054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
789054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  auto* right_class_stats = right_stats->mutable_classification();
790054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  auto* right_counts = right_class_stats->mutable_sparse_counts();
791054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  right_counts_[best_split_index].PackToProto(right_counts);
7921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
7931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
7941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// --------------------- Least Squares Regression --------------------------- //
7951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid LeastSquaresRegressionGrowStats::ExtractFromProto(
7961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const FertileSlot& slot) {
7971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int32 num_outputs = params_.num_outputs();
7981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  Initialize();
7991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (!slot.has_post_init_leaf_stats()) {
8001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return;
8011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
8021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
8031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const auto& total_sums =
8041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      slot.post_init_leaf_stats().regression().mean_output();
8051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const auto& total_squares =
8061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      slot.post_init_leaf_stats().regression().mean_output_squares();
8071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
8081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Total counts.
8091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_outputs; ++i) {
8101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_sum_[i] = total_sums.value(i).float_value();
8111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_sum_squares_[i] = total_squares.value(i).float_value();
8121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
8131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
8141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Candidate counts and splits.
8151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int split_num = 0;
8161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (const auto& cand : slot.candidates()) {
81775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    AddSplit(cand.split(), nullptr, nullptr, -1);
8181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const auto& sums = cand.left_stats().regression().mean_output();
8191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const auto& squares = cand.left_stats().regression().mean_output_squares();
8201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    for (int i = 0; i < num_outputs; ++i) {
8211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      left_sum(split_num, i) = sums.value(i).float_value();
8221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      left_square(split_num, i) = squares.value(i).float_value();
8231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
8241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_counts_[split_num] = cand.left_stats().weight_sum();
8251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    ++split_num;
8261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
8271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
8281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
8291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid LeastSquaresRegressionGrowStats::PackToProto(FertileSlot* slot) const {
8301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int32 num_outputs = params_.num_outputs();
8311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* slot_stats = slot->mutable_post_init_leaf_stats();
8321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  slot_stats->set_weight_sum(weight_sum_);
8331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
8341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* total_sums = slot->mutable_post_init_leaf_stats()
8351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                         ->mutable_regression()
8361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                         ->mutable_mean_output();
8371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* total_squares = slot->mutable_post_init_leaf_stats()
8381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                            ->mutable_regression()
8391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                            ->mutable_mean_output_squares();
8401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
8411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < total_sum_.size(); ++i) {
8421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_sums->add_value()->set_float_value(total_sum_[i]);
8431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_squares->add_value()->set_float_value(total_sum_squares_[i]);
8441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
8451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
8464463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower  for (int split_num = 0; split_num < num_splits(); ++split_num) {
8471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto* cand = slot->add_candidates();
8481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    *cand->mutable_split() = splits_[split_num];
8494463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower    auto* sums =
8504463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower        cand->mutable_left_stats()->mutable_regression()->mutable_mean_output();
8511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto* squares = cand->mutable_left_stats()
8521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                        ->mutable_regression()
8531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                        ->mutable_mean_output_squares();
8541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    for (int i = 0; i < num_outputs; ++i) {
8551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      sums->add_value()->set_float_value(left_sum(split_num, i));
8561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      squares->add_value()->set_float_value(left_square(split_num, i));
8571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
8581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    cand->mutable_left_stats()->set_weight_sum(left_counts_[split_num]);
8591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
8601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
8611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
8621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowervoid LeastSquaresRegressionGrowStats::AddExample(
8631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
8641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    int example) {
8651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int32 num_outputs = params_.num_outputs();
8661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Update splits.
8671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_splits(); ++i) {
8681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    auto& eval = evaluators_[i];
8691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (eval->Decide(input_data, example) == LEFT_INDEX) {
8701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      for (int j = 0; j < num_outputs; ++j) {
8711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        const float output = target->GetTargetAsContinuous(example, j);
8721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        left_sum(i, j) += output;
8731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        left_square(i, j) += output * output;
8741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      }
8751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      ++left_counts_[i];
8761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
8771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
8781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
8791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Update totals.
8801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_outputs; ++i) {
8811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const float output = target->GetTargetAsContinuous(example, i);
8821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_sum_[i] += output;
8831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_sum_squares_[i] += output * output;
8841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
8851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  weight_sum_ += 1.0;
8861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
8871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
8881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerfloat LeastSquaresRegressionGrowStats::SplitVariance(int split) const {
8891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float total_variance = 0;
8901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < params_.num_outputs(); ++i) {
8911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    // Left side
8924463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower    const float le_x = left_sum(split, i) / left_counts_[split];
8931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
8944463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower    const float le_x2 = left_square(split, i) / left_counts_[split];
8951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_variance += le_x2 - le_x * le_x;
8961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
8971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    // Right side
8981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    const float re_x = (total_sum_[i] - left_sum(split, i)) /
8991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                       (weight_sum_ - left_counts_[split]);
9001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
9014463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower    const float re_x2 = (total_sum_squares_[i] - left_square(split, i)) /
9024463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower                        (weight_sum_ - left_counts_[split]);
9031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_variance += re_x2 - re_x * re_x;
9041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
9051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  return total_variance;
9061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
9071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
9081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerbool LeastSquaresRegressionGrowStats::BestSplit(SplitCandidate* best) const {
9091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float min_score = FLT_MAX;
9101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int best_index = -1;
9111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int32 num_outputs = params_.num_outputs();
9121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_splits(); ++i) {
9131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (left_counts_[i] > 0 && weight_sum_ - left_counts_[i] > 0) {
9141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      const float split_score = SplitVariance(i);
9151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      if (split_score < min_score) {
9161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        min_score = split_score;
9171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower        best_index = i;
9181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      }
9191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
9201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
9211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
9221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // This could happen if all the splits are useless.
9231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  if (best_index < 0) {
9241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return false;
9251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
9261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
9271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Fill in right stats to be used for leaf model.
9281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  *best->mutable_split() = splits_[best_index];
9291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Left
9301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* left = best->mutable_left_stats();
9311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* left_reg_stats = left->mutable_regression();
9321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  left->set_weight_sum(left_counts_[best_index]);
9331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* left_output_sum = left_reg_stats->mutable_mean_output();
9341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_outputs; ++i) {
9354463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower    left_output_sum->add_value()->set_float_value(left_sum(best_index, i));
9361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
9371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
9381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Right
9391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* right = best->mutable_right_stats();
9401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* right_reg_stats = right->mutable_regression();
9411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  right->set_weight_sum(weight_sum_ - left_counts_[best_index]);
9421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  auto* right_output_sum = right_reg_stats->mutable_mean_output();
9431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  for (int i = 0; i < num_outputs; ++i) {
9444463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower    right_output_sum->add_value()->set_float_value(total_sum_[i] -
9454463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower                                                   left_sum(best_index, i));
9461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
9471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  return true;
9481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
9491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
9501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerbool LeastSquaresRegressionGrowStats::IsFinished() const {
9511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  return weight_sum_ >= split_after_samples_;
9521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}
9531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
9541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}  // namespace tensorforest
9551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}  // namespace tensorflow
956