// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h"

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <ctime>
#include <functional>
#include <memory>
#include <queue>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h"
#include "tensorflow/core/lib/random/distribution_sampler.h"

namespace tensorflow {
namespace tensorforest {

// When creating evaluators for the split candidates, use these
// for the left and right return values.
static const int32 LEFT_INDEX = 0;
static const int32 RIGHT_INDEX = 1;

GrowStats::GrowStats(const TensorForestParams& params, int32 depth)
    : weight_sum_(0),
      depth_(depth),
      params_(params),
      split_after_samples_(ResolveParam(params.split_after_samples(), depth)),
      num_splits_to_consider_(
          ResolveParam(params.num_splits_to_consider(), depth)),
      num_outputs_(params.num_outputs()) {}

void GrowStats::AddSplit(const decision_trees::BinaryNode& split,
                         const std::unique_ptr<TensorDataSet>& input_data,
                         const InputTarget* target, int example) {
  // It's possible that the split collection calls AddSplit, but we actually
  // have all the splits we need and are just waiting for them to be fully
  // initialized.
  if (splits_.size() < num_splits_to_consider_) {
    splits_.push_back(split);
    evaluators_.emplace_back(
        CreateBinaryDecisionNodeEvaluator(split, LEFT_INDEX, RIGHT_INDEX));
    AddSplitStats(target, example);
  }

  if (input_data != nullptr && target != nullptr &&
      params_.initialize_average_splits()) {
    AdditionalInitializationExample(input_data, target, example);
  }
}

void GrowStats::RemoveSplit(int split_num) {
  splits_.erase(splits_.begin() + split_num);
  evaluators_.erase(evaluators_.begin() + split_num);
  RemoveSplitStats(split_num);
}

// ------------------------ Classification --------------------------- //

ClassificationStats::ClassificationStats(const TensorForestParams& params,
                                         int32 depth)
    : GrowStats(params, depth), finish_early_(false) {
  // Early splitting params.
  if (params.finish_type().type() == SPLIT_FINISH_BASIC) {
    min_split_samples_ = split_after_samples_;
    finish_sample_epoch_ = 1;
    finish_check_every_ = split_after_samples_ * 2;
  } else {
    if (!params.has_dominate_fraction() || !params.has_min_split_samples()) {
      LOG(FATAL) << "dominate_fraction and min_split_samples "
                 << "required for early-finish strategy.";
    } else {
      min_split_samples_ = ResolveParam(params.min_split_samples(), depth);
      finish_check_every_ =
          ResolveParam(params.finish_type().check_every_steps(), depth);
      finish_sample_epoch_ = min_split_samples_ / finish_check_every_;

      dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_);
      if (dominate_fraction_ <= 0 || dominate_fraction_ > 1.0) {
        LOG(FATAL) << "Invalid dominate fraction " << dominate_fraction_;
      }
    }
  }

  // Pruning params.
  if (params.pruning_type().type() != SPLIT_PRUNE_NONE) {
    prune_check_every_ =
        ResolveParam(params.pruning_type().prune_every_samples(), depth);
    prune_sample_epoch_ = 1;
    prune_fraction_ = 0.0;
    switch (params_.pruning_type().type()) {
      case SPLIT_PRUNE_HALF:
        prune_fraction_ = 0.5;
        break;
      case SPLIT_PRUNE_QUARTER:
        prune_fraction_ = 0.25;
        break;
      case SPLIT_PRUNE_10_PERCENT:
        prune_fraction_ = 0.10;
        break;
      case SPLIT_PRUNE_HOEFFDING:
        dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_);
        half_ln_dominate_frac_ = 0.5 * log(1.0 / (1.0 - dominate_fraction_));
        break;
      default:
        LOG(WARNING) << "Unknown pruning type";
    }
  } else {
    prune_check_every_ = split_after_samples_ * 2;
    prune_sample_epoch_ = 1;
  }

  if (params.use_running_stats_method()) {
    left_gini_.reset(new RunningGiniScores());
    right_gini_.reset(new RunningGiniScores());
  }

  uint64 time_seed = static_cast<uint64>(std::clock());
  single_rand_ = std::unique_ptr<random::PhiloxRandom>(
      new random::PhiloxRandom(time_seed));
  rng_ = std::unique_ptr<random::SimplePhilox>(
      new random::SimplePhilox(single_rand_.get()));
}

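// Completes the initialization of "average" splits: a split created from a
// single example starts with that example's feature value as its threshold.
// The first time we see an example of a different class, the threshold is
// averaged with that example's feature value, so it ends up between the two
// classes instead of on top of the first example.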
void ClassificationStats::AdditionalInitializationExample(
    const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
    int example) {
  const int32 new_target = target->GetTargetAsClassIndex(example, 0);
  std::unordered_set<int> to_erase;
  for (auto it = half_initialized_splits_.begin();
       it != half_initialized_splits_.end(); ++it) {
    if (it->second != new_target) {
      auto& split = splits_[it->first];
      if (split.has_inequality_left_child_test()) {
        auto& test = split.inequality_left_child_test();
        auto* thresh =
            split.mutable_inequality_left_child_test()->mutable_threshold();
        if (test.has_feature_id()) {
          const float val =
              input_data->GetExampleValue(example, test.feature_id());
          thresh->set_float_value((thresh->float_value() + val) / 2);
        }
      }
      to_erase.insert(it->first);
    }
  }

  for (const int split_id : to_erase) {
    half_initialized_splits_.erase(split_id);
  }
}

bool ClassificationStats::IsFinished() const {
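  // A node whose examples are all one class never finishes on sample count
  // alone; there is nothing to split, so it waits for more examples (or for
  // an early-finish signal).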
  bool basic = (weight_sum_ >= split_after_samples_) && !is_pure();
  return basic || finish_early_;
}

float ClassificationStats::MaybeCachedGiniScore(int split, float* left_sum,
                                                float* right_sum) const {
  if (left_gini_ == nullptr) {
    return GiniScore(split, left_sum, right_sum);
  } else {
    *left_sum = left_gini_->sum(split);
    const float left = WeightedSmoothedGini(
        *left_sum, left_gini_->square(split), num_outputs_);

    *right_sum = right_gini_->sum(split);
    const float right = WeightedSmoothedGini(
        *right_sum, right_gini_->square(split), num_outputs_);

    return left + right;
  }
}

void ClassificationStats::AddExample(
    const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
    int example) {
  const int64 int_label = target->GetTargetAsClassIndex(example, 0);
  const float weight = target->GetTargetWeight(example);

  for (int i = 0; i < num_splits(); ++i) {
    auto& eval = evaluators_[i];
    if (eval->Decide(input_data, example) == LEFT_INDEX) {
      if (left_gini_ != nullptr) {
        left_gini_->update(i, left_count(i, int_label), weight);
      }
      ClassificationAddLeftExample(i, int_label, weight);
    } else {
      if (right_gini_ != nullptr) {
        right_gini_->update(i, right_count(i, int_label), weight);
      }
      ClassificationAddRightExample(i, int_label, weight);
    }
  }

  ClassificationAddTotalExample(int_label, weight);

  weight_sum_ += weight;

  CheckFinishEarly();
  CheckPrune();
}

void ClassificationStats::CheckPrune() {
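  // prune_sample_epoch_ counts how many pruning passes have run; together
  // with prune_check_every_ it limits pruning to at most once per
  // prune_check_every_ worth of accumulated example weight.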
  if (params_.pruning_type().type() == SPLIT_PRUNE_NONE || IsFinished() ||
      weight_sum_ < prune_sample_epoch_ * prune_check_every_) {
    return;
  }
  ++prune_sample_epoch_;

  if (params_.pruning_type().type() == SPLIT_PRUNE_HOEFFDING) {
    CheckPruneHoeffding();
    return;
  }

  const int to_remove = num_splits() * prune_fraction_;
  if (to_remove <= 0) {
    return;
  }

  // Pair ordering is first-then-second by default, so no custom comparison
  // is needed.  std::greater makes this a min-heap.
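  // Lower gini is better, so the candidates to prune are the ones with the
  // highest scores.  A size-limited min-heap keeps the `to_remove` largest
  // scores seen so far: its top is the smallest of them, and is evicted
  // whenever a larger score comes along.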
  std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int>>,
                      std::greater<std::pair<float, int>>>
      worst;

  // Track indices that are in the heap so we can iterate over them
  // largest-first later.
  std::set<int> indices;

  for (int i = 0; i < num_splits(); ++i) {
    float left, right;
    const float split_score = MaybeCachedGiniScore(i, &left, &right);
    if (worst.size() < to_remove) {
      worst.push(std::pair<float, int>(split_score, i));
      indices.insert(i);
    } else if (worst.top().first < split_score) {
      indices.erase(worst.top().second);
      worst.pop();
      worst.push(std::pair<float, int>(split_score, i));
      indices.insert(i);
    }
  }

  // Remove the splits in descending index order so that each removal doesn't
  // shift the indices of the splits still to be removed.
  for (auto it = indices.rbegin(); it != indices.rend(); ++it) {
    RemoveSplit(*it);
  }
}

void ClassificationStats::CheckPruneHoeffding() {
  std::vector<float> split_scores(num_splits());
  // Find the best (lowest) split score.
  float best_split_score = FLT_MAX;
  for (int i = 0; i < num_splits(); ++i) {
    float left, right;
    split_scores[i] = MaybeCachedGiniScore(i, &left, &right);
    if (split_scores[i] < best_split_score) {
      best_split_score = split_scores[i];
    }
  }

  // We apply the Hoeffding bound to the difference between the best split
  // score and the i-th split score.
  // Raw Gini ranges from 0 to 1 - (1/n), but our gini score is weighted.
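  // Hoeffding's inequality for the mean of n samples bounded in a range R
  // gives epsilon = R * sqrt(ln(1 / delta) / (2 * n)) at confidence
  // 1 - delta, with delta = 1 - dominate_fraction_ here.  Since
  // half_ln_dominate_frac_ was precomputed as 0.5 * ln(1 / (1 - delta)),
  // the expression below equals R * sqrt(ln(1 / delta) / (2 * weight_sum_)).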
  const float num_classes = params_.num_outputs();
  const float gini_diff_range = weight_sum_ * (1.0 - 1.0 / num_classes);
  float epsilon = gini_diff_range * sqrt(half_ln_dominate_frac_ / weight_sum_);
  for (int i = num_splits() - 1; i >= 0; i--) {
    if (split_scores[i] - best_split_score > epsilon) {
      RemoveSplit(i);
    }
  }
}

void ClassificationStats::CheckFinishEarly() {
  if (weight_sum_ < min_split_samples_ ||
      weight_sum_ < finish_sample_epoch_ * finish_check_every_) {
    return;
  }
  ++finish_sample_epoch_;

  if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_HOEFFDING) {
    CheckFinishEarlyHoeffding();
  } else if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_BOOTSTRAP) {
    CheckFinishEarlyBootstrap();
  }
}

void ClassificationStats::CheckFinishEarlyHoeffding() {
  // Each term in the Gini impurity can range from 0 to 0.5 * 0.5.
  float range = 0.25 * static_cast<float>(params_.num_outputs()) * weight_sum_;

  float hoeffding_bound =
      range * sqrt(log(1.0 / (1.0 - dominate_fraction_)) / (2.0 * weight_sum_));

  float unused_left_sum, unused_right_sum;
  std::function<float(int)> score_fn =
      std::bind(&ClassificationStats::MaybeCachedGiniScore, this,
                std::placeholders::_1, &unused_left_sum, &unused_right_sum);

  float best_score;
  int32 best_index;
  float second_best_score;
  int32 second_best_index;
  GetTwoBest(num_splits(), score_fn, &best_score, &best_index,
             &second_best_score, &second_best_index);

  finish_early_ = (second_best_score - best_score) > hoeffding_bound;
}

void ClassificationStats::MakeBootstrapWeights(int index,
                                               std::vector<float>* weights) {
  int n = weight_sum_;
  float denom = static_cast<float>(n) + static_cast<float>(num_outputs_);
  for (int i = 0; i < num_outputs_; ++i) {
    // Use the Laplace smoothed per-class probabilities when generating the
    // bootstrap samples.
    (*weights)[i] = (left_count(index, i) + 1.0) / denom;
    (*weights)[num_outputs_ + i] = (right_count(index, i) + 1.0) / denom;
  }
}

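// The loop below doubles p = 1 - dominate_fraction_ until it reaches 1.0,
// so this returns roughly ceil(log2(1 / (1 - dominate_fraction_))) + 1.
// For example, dominate_fraction_ = 0.99 gives p = 0.01, which takes 7
// doublings to exceed 1.0, so 8 bootstrap samples are drawn per split.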
int ClassificationStats::NumBootstrapSamples() const {
  float p = 1.0 - dominate_fraction_;
  int bootstrap_samples = 1;
  while (p < 1.0) {
    ++bootstrap_samples;
    p = p * 2;
  }
  return bootstrap_samples;
}

void ClassificationStats::CheckFinishEarlyBootstrap() {
  float unused_left_sum, unused_right_sum;
  std::function<float(int)> score_fn =
      std::bind(&ClassificationStats::MaybeCachedGiniScore, this,
                std::placeholders::_1, &unused_left_sum, &unused_right_sum);

  float best_score;
  int32 best_index;
  float second_best_score;
  int32 second_best_index;
  GetTwoBest(num_splits(), score_fn, &best_score, &best_index,
             &second_best_score, &second_best_index);

  std::vector<float> weights1(num_outputs_ * 2);
  MakeBootstrapWeights(best_index, &weights1);
  random::DistributionSampler ds1(weights1);

  std::vector<float> weights2(num_outputs_ * 2);
  MakeBootstrapWeights(second_best_index, &weights2);
  random::DistributionSampler ds2(weights2);

  const int bootstrap_samples = NumBootstrapSamples();

  int worst_g1 = 0;
  for (int i = 0; i < bootstrap_samples; i++) {
    int g1 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds1, rng_.get());
    worst_g1 = std::max(worst_g1, g1);
  }

  int best_g2 = 99;
  for (int i = 0; i < bootstrap_samples; i++) {
    int g2 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds2, rng_.get());
    best_g2 = std::min(best_g2, g2);
  }

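  // The best split dominates if even its worst bootstrapped gini is below
  // the second-best split's best bootstrapped gini; at that point more
  // samples are unlikely to change the ranking, so we can finish early.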
  finish_early_ = worst_g1 < best_g2;
}

bool ClassificationStats::BestSplit(SplitCandidate* best) const {
  float min_score = FLT_MAX;
  int best_index = -1;
  float best_left_sum, best_right_sum;

  // Find the split with the lowest gini, tracking its left and right weight
  // sums as we go.
  for (int i = 0; i < num_splits(); ++i) {
    float left_sum, right_sum;
    const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum);
    // Skip splits that send all their weight to one side; they are useless.
    if (left_sum > 0 && right_sum > 0 && split_score < min_score) {
      min_score = split_score;
      best_index = i;
      best_left_sum = left_sum;
      best_right_sum = right_sum;
    }
  }

  // This could happen if all the splits are useless.
  if (best_index < 0) {
    return false;
  }

  // Fill in stats to be used for leaf model.
  *best->mutable_split() = splits_[best_index];
  auto* left = best->mutable_left_stats();
  left->set_weight_sum(best_left_sum);
  auto* right = best->mutable_right_stats();
  right->set_weight_sum(best_right_sum);
  InitLeafClassStats(best_index, left, right);

  return true;
}

// ------------------------ Dense Classification --------------------------- //
void DenseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  const int32 num_classes = params_.num_outputs();
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
  const auto& class_stats =
      slot.post_init_leaf_stats().classification().dense_counts();

  // Total counts.
  for (int i = 0; i < num_classes; ++i) {
    total_counts_[i] = class_stats.value(i).float_value();
    num_outputs_seen_ += total_counts_[i] != 0;
  }

  // Candidate counts and splits.
  int split_num = 0;
  for (const auto& cand : slot.candidates()) {
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& left_stats = cand.left_stats().classification().dense_counts();
    for (int i = 0; i < num_classes; ++i) {
      const float val = left_stats.value(i).float_value();
      mutable_left_count(split_num, i) = val;
      MaybeInitializeRunningCount(split_num, val);
    }
    ++split_num;
  }
}

void DenseClassificationGrowStats::PackToProto(FertileSlot* slot) const {
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);

  auto* class_stats = slot->mutable_post_init_leaf_stats()
                          ->mutable_classification()
                          ->mutable_dense_counts();
  for (int i = 0; i < num_outputs_; ++i) {
    class_stats->add_value()->set_float_value(total_counts_[i]);
  }

  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* left_stats = cand->mutable_left_stats()
                           ->mutable_classification()
                           ->mutable_dense_counts();
    for (int i = 0; i < num_outputs_; ++i) {
      left_stats->add_value()->set_float_value(left_count(split_num, i));
    }
  }
}

float DenseClassificationGrowStats::GiniScore(int split, float* left_sum,
                                              float* right_sum) const {
  float left_square = 0, right_square = 0;
  *left_sum = 0;
  *right_sum = 0;
  for (int j = 0; j < num_outputs_; ++j) {
    const float left = left_count(split, j);
    *left_sum += left;
    left_square += left * left;
    const float right = right_count(split, j);
    *right_sum += right;
    right_square += right * right;
  }

  const float left_score =
      WeightedSmoothedGini(*left_sum, left_square, num_outputs_);
  const float right_score =
      WeightedSmoothedGini(*right_sum, right_square, num_outputs_);
  return left_score + right_score;
}

void DenseClassificationGrowStats::InitLeafClassStats(
    int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
  auto* left_class_stats = left_stats->mutable_classification();
  auto* left_counts = left_class_stats->mutable_dense_counts();
  for (int i = 0; i < params_.num_outputs(); ++i) {
    left_counts->add_value()->set_float_value(left_count(best_split_index, i));
  }

  auto* right_class_stats = right_stats->mutable_classification();
  auto* right_counts = right_class_stats->mutable_dense_counts();
  for (int i = 0; i < params_.num_outputs(); ++i) {
    right_counts->add_value()->set_float_value(total_counts_[i] -
                                               left_count(best_split_index, i));
  }
}

// ------------------------ Sparse Classification --------------------------- //
void SparseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
  const auto& class_stats =
      slot.post_init_leaf_stats().classification().sparse_counts();

  // Total counts.
  for (auto const& entry : class_stats.sparse_value()) {
    total_counts_[entry.first] = entry.second.float_value();
  }

  // Candidate counts and splits.
  int split_num = 0;
  for (const auto& cand : slot.candidates()) {
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& left_stats = cand.left_stats().classification().sparse_counts();
    for (auto const& entry : left_stats.sparse_value()) {
      const float val = entry.second.float_value();
      left_counts_[split_num][entry.first] = val;
      MaybeInitializeRunningCount(split_num, val);
    }
    ++split_num;
  }
}

void SparseClassificationGrowStats::PackToProto(FertileSlot* slot) const {
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);

  auto* class_stats = slot->mutable_post_init_leaf_stats()
                          ->mutable_classification()
                          ->mutable_sparse_counts()
                          ->mutable_sparse_value();
  for (const auto& entry : total_counts_) {
    decision_trees::Value val;
    val.set_float_value(entry.second);
    (*class_stats)[entry.first] = val;
  }

  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* left_stats = cand->mutable_left_stats()
                           ->mutable_classification()
                           ->mutable_sparse_counts()
                           ->mutable_sparse_value();
    for (const auto& entry : left_counts_[split_num]) {
      decision_trees::Value val;
      val.set_float_value(entry.second);
      (*left_stats)[entry.first] = val;
    }
  }
}

float SparseClassificationGrowStats::GiniScore(int split, float* left_sum,
                                               float* right_sum) const {
  float left_square = 0, right_square = 0;
  *left_sum = 0;
  *right_sum = 0;
  for (const auto& entry : total_counts_) {
    const int label = entry.first;
    float left = 0;
    float right = 0;
    auto it = left_counts_[split].find(label);
    if (it == left_counts_[split].end()) {
      right = entry.second;
    } else {
      left = it->second;
      right = entry.second - it->second;
    }
    *left_sum += left;
    left_square += left * left;
    *right_sum += right;
    right_square += right * right;
  }
  const int32 num_classes = params_.num_outputs();
  const float left_score =
      WeightedSmoothedGini(*left_sum, left_square, num_classes);
  const float right_score =
      WeightedSmoothedGini(*right_sum, right_square, num_classes);
  return left_score + right_score;
}

void SparseClassificationGrowStats::InitLeafClassStats(
    int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
  auto* left_class_stats = left_stats->mutable_classification();
  auto* left_counts =
      left_class_stats->mutable_sparse_counts()->mutable_sparse_value();
  auto* right_class_stats = right_stats->mutable_classification();
  auto* right_counts =
      right_class_stats->mutable_sparse_counts()->mutable_sparse_value();

  for (const auto& entry : total_counts_) {
    auto it = left_counts_[best_split_index].find(entry.first);
    if (it == left_counts_[best_split_index].end()) {
      (*right_counts)[entry.first].set_float_value(entry.second);
    } else {
      const float left = it->second;
      const float right = entry.second - it->second;
      (*left_counts)[entry.first].set_float_value(left);
      if (right > 0) {
        (*right_counts)[entry.first].set_float_value(right);
      }
    }
  }
}

// -------------------- FixedSizeClassStats --------------------------------- //

// FixedSizeClassStats implements the "SpaceSaving" algorithm by
// Ahmed Metwally, Divyakant Agrawal and Amr El Abbadi.  See for example
// https://pdfs.semanticscholar.org/72f1/5aba2e67b1cc9cd1fb12c99e101c4c1aae4b.pdf

int argmin(const std::unordered_map<int, float>& m) {
  int c = -1;
  float f = FLT_MAX;
  for (const auto& it : m) {
    if (it.second < f) {
      f = it.second;
      c = it.first;
    }
  }
  return c;
}

void FixedSizeClassStats::accumulate(int c, float w) {
  auto it = class_weights_.find(c);
  if (it != class_weights_.end()) {
    it->second += w;
    if (c == smallest_weight_class_) {
      smallest_weight_class_ = argmin(class_weights_);
    }
    return;
  }

  if (class_weights_.size() < n_) {
    class_weights_.insert(it, std::pair<int, float>(c, w));
    if (class_weights_.size() == n_) {
      // We can't assume the last class added has the smallest weight, because
      // the incoming weights may all be different.
      smallest_weight_class_ = argmin(class_weights_);
    }
    return;
  }

  // This is the slightly unintuitive heart of the SpaceSaving algorithm:
  // if the map is full and we see a new class, we find the entry with the
  // smallest weight and "take it over": we add our weight to its weight,
  // and assign it all to the newly seen class.
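  // For example, with n_ = 2 and class_weights_ = {a: 5.0, b: 3.0}, a new
  // class c arriving with weight 1.0 erases b (the smallest entry) and
  // inserts c with weight 3.0 + 1.0 = 4.0.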
  it = class_weights_.find(smallest_weight_class_);
  float new_weight = it->second + w;
  class_weights_.erase(it);
  class_weights_[c] = new_weight;
  smallest_weight_class_ = argmin(class_weights_);
}

float FixedSizeClassStats::get_weight(int c) const {
  // Every entry in class_weights_ might be overstated by as much as the
  // smallest_weight.  We therefore assume that each has been overstated
  // by smallest_weight / 2.0, and we re-distribute that mass over all
  // num_classes_ classes.
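  // For example, with n_ = 3, num_classes_ = 10 and
  // class_weights_ = {a: 10, b: 6, c: 2}: smallest_weight is 2, the
  // redistributed mass per class is (2 / 2) * 3 / 10 = 0.3, so
  // get_weight(a) returns 10 - 1 + 0.3 = 9.3 and any untracked class
  // gets 0.3.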
  float smallest_weight = 0.0;
  auto it = class_weights_.find(smallest_weight_class_);
  if (it != class_weights_.end()) {
    smallest_weight = it->second;
  }
  float w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_);
  it = class_weights_.find(c);
  if (it != class_weights_.end()) {
    w += it->second - smallest_weight / 2.0;
  }
  return w;
}

void FixedSizeClassStats::set_sum_and_square(float* sum, float* square) const {
  *sum = 0.0;
  *square = 0.0;

  float smallest_weight = 0.0;
  auto it = class_weights_.find(smallest_weight_class_);
  if (it != class_weights_.end()) {
    smallest_weight = it->second;
  }

  float w;
  for (const auto& entry : class_weights_) {
    *sum += entry.second;
    w = get_weight(entry.first);
    *square += w * w;
  }

  w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_);
  *square += (num_classes_ - n_) * w * w;
}

void FixedSizeClassStats::ExtractFromProto(
    const decision_trees::SparseVector& sparse_vector) {
  for (const auto& it : sparse_vector.sparse_value()) {
    class_weights_[it.first] = it.second.float_value();
  }
  if (class_weights_.size() == n_) {
    smallest_weight_class_ = argmin(class_weights_);
  }
}

void FixedSizeClassStats::PackToProto(
    decision_trees::SparseVector* sparse_vector) const {
  for (const auto& it : class_weights_) {
    (*sparse_vector->mutable_sparse_value())[it.first].set_float_value(
        it.second);
  }
}

// --------------------- FixedSizeSparseClassificationGrowStats ------------- //

void FixedSizeSparseClassificationGrowStats::ExtractFromProto(
    const FertileSlot& slot) {
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();

  // Candidate counts and splits.
  int split_num = 0;
  left_counts_.clear();
  right_counts_.clear();
  for (const auto& cand : slot.candidates()) {
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& left_stats = cand.left_stats().classification().sparse_counts();
    left_counts_.emplace_back(params_.num_classes_to_track(),
                              params_.num_outputs());
    left_counts_[split_num].ExtractFromProto(left_stats);
    const auto& right_stats =
        cand.right_stats().classification().sparse_counts();
    right_counts_.emplace_back(params_.num_classes_to_track(),
                               params_.num_outputs());
    right_counts_[split_num].ExtractFromProto(right_stats);
    ++split_num;
  }
}

void FixedSizeSparseClassificationGrowStats::PackToProto(
    FertileSlot* slot) const {
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);

  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* left_stats = cand->mutable_left_stats()
                           ->mutable_classification()
                           ->mutable_sparse_counts();
    left_counts_[split_num].PackToProto(left_stats);
    auto* right_stats = cand->mutable_right_stats()
                            ->mutable_classification()
                            ->mutable_sparse_counts();
    right_counts_[split_num].PackToProto(right_stats);
  }
}

float FixedSizeSparseClassificationGrowStats::GiniScore(
    int split, float* left_sum, float* right_sum) const {
  float left_square, right_square;
  left_counts_[split].set_sum_and_square(left_sum, &left_square);
  right_counts_[split].set_sum_and_square(right_sum, &right_square);
  const int32 num_classes = params_.num_outputs();
  const float left_score =
      WeightedSmoothedGini(*left_sum, left_square, num_classes);
  const float right_score =
      WeightedSmoothedGini(*right_sum, right_square, num_classes);
  return left_score + right_score;
}

void FixedSizeSparseClassificationGrowStats::InitLeafClassStats(
    int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
  auto* left_class_stats = left_stats->mutable_classification();
  auto* left_counts = left_class_stats->mutable_sparse_counts();
  left_counts_[best_split_index].PackToProto(left_counts);

  auto* right_class_stats = right_stats->mutable_classification();
  auto* right_counts = right_class_stats->mutable_sparse_counts();
  right_counts_[best_split_index].PackToProto(right_counts);
}

// --------------------- Least Squares Regression --------------------------- //
void LeastSquaresRegressionGrowStats::ExtractFromProto(
    const FertileSlot& slot) {
  const int32 num_outputs = params_.num_outputs();
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
  const auto& total_sums =
      slot.post_init_leaf_stats().regression().mean_output();
  const auto& total_squares =
      slot.post_init_leaf_stats().regression().mean_output_squares();

  // Total counts.
  for (int i = 0; i < num_outputs; ++i) {
    total_sum_[i] = total_sums.value(i).float_value();
    total_sum_squares_[i] = total_squares.value(i).float_value();
  }

  // Candidate counts and splits.
  int split_num = 0;
  for (const auto& cand : slot.candidates()) {
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& sums = cand.left_stats().regression().mean_output();
    const auto& squares = cand.left_stats().regression().mean_output_squares();
    for (int i = 0; i < num_outputs; ++i) {
      left_sum(split_num, i) = sums.value(i).float_value();
      left_square(split_num, i) = squares.value(i).float_value();
    }
    left_counts_[split_num] = cand.left_stats().weight_sum();
    ++split_num;
  }
}

void LeastSquaresRegressionGrowStats::PackToProto(FertileSlot* slot) const {
  const int32 num_outputs = params_.num_outputs();
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);

  auto* total_sums = slot->mutable_post_init_leaf_stats()
                         ->mutable_regression()
                         ->mutable_mean_output();
  auto* total_squares = slot->mutable_post_init_leaf_stats()
                            ->mutable_regression()
                            ->mutable_mean_output_squares();

  for (int i = 0; i < total_sum_.size(); ++i) {
    total_sums->add_value()->set_float_value(total_sum_[i]);
    total_squares->add_value()->set_float_value(total_sum_squares_[i]);
  }

  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* sums =
        cand->mutable_left_stats()->mutable_regression()->mutable_mean_output();
    auto* squares = cand->mutable_left_stats()
                        ->mutable_regression()
                        ->mutable_mean_output_squares();
    for (int i = 0; i < num_outputs; ++i) {
      sums->add_value()->set_float_value(left_sum(split_num, i));
      squares->add_value()->set_float_value(left_square(split_num, i));
    }
    cand->mutable_left_stats()->set_weight_sum(left_counts_[split_num]);
  }
}

void LeastSquaresRegressionGrowStats::AddExample(
    const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
    int example) {
  const int32 num_outputs = params_.num_outputs();
  // Update splits.
  for (int i = 0; i < num_splits(); ++i) {
    auto& eval = evaluators_[i];
    if (eval->Decide(input_data, example) == LEFT_INDEX) {
      for (int j = 0; j < num_outputs; ++j) {
        const float output = target->GetTargetAsContinuous(example, j);
        left_sum(i, j) += output;
        left_square(i, j) += output * output;
      }
      ++left_counts_[i];
    }
  }

  // Update totals.
  for (int i = 0; i < num_outputs; ++i) {
    const float output = target->GetTargetAsContinuous(example, i);
    total_sum_[i] += output;
    total_sum_squares_[i] += output * output;
  }
  weight_sum_ += 1.0;
}

float LeastSquaresRegressionGrowStats::SplitVariance(int split) const {
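  // Computes Var(X) = E[X^2] - (E[X])^2 for each side of the split, summed
  // over all outputs; a lower total variance means a better regression split.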
  float total_variance = 0;
  for (int i = 0; i < params_.num_outputs(); ++i) {
    // Left side
    const float le_x = left_sum(split, i) / left_counts_[split];

    const float le_x2 = left_square(split, i) / left_counts_[split];
    total_variance += le_x2 - le_x * le_x;

    // Right side
    const float re_x = (total_sum_[i] - left_sum(split, i)) /
                       (weight_sum_ - left_counts_[split]);

    const float re_x2 = (total_sum_squares_[i] - left_square(split, i)) /
                        (weight_sum_ - left_counts_[split]);
    total_variance += re_x2 - re_x * re_x;
  }
  return total_variance;
}

bool LeastSquaresRegressionGrowStats::BestSplit(SplitCandidate* best) const {
  float min_score = FLT_MAX;
  int best_index = -1;
  const int32 num_outputs = params_.num_outputs();
  for (int i = 0; i < num_splits(); ++i) {
    if (left_counts_[i] > 0 && weight_sum_ - left_counts_[i] > 0) {
      const float split_score = SplitVariance(i);
      if (split_score < min_score) {
        min_score = split_score;
        best_index = i;
      }
    }
  }

  // This could happen if all the splits are useless.
  if (best_index < 0) {
    return false;
  }

  // Fill in the stats to be used for the leaf model.
  *best->mutable_split() = splits_[best_index];
  // Left
  auto* left = best->mutable_left_stats();
  auto* left_reg_stats = left->mutable_regression();
  left->set_weight_sum(left_counts_[best_index]);
  auto* left_output_sum = left_reg_stats->mutable_mean_output();
  for (int i = 0; i < num_outputs; ++i) {
    left_output_sum->add_value()->set_float_value(left_sum(best_index, i));
  }

  // Right
  auto* right = best->mutable_right_stats();
  auto* right_reg_stats = right->mutable_regression();
  right->set_weight_sum(weight_sum_ - left_counts_[best_index]);
  auto* right_output_sum = right_reg_stats->mutable_mean_output();
  for (int i = 0; i < num_outputs; ++i) {
    right_output_sum->add_value()->set_float_value(total_sum_[i] -
                                                   left_sum(best_index, i));
  }
  return true;
}

bool LeastSquaresRegressionGrowStats::IsFinished() const {
  return weight_sum_ >= split_after_samples_;
}

}  // namespace tensorforest
}  // namespace tensorflow
