11588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower//
31588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Licensed under the Apache License, Version 2.0 (the "License");
41588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// you may not use this file except in compliance with the License.
51588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// You may obtain a copy of the License at
61588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower//
71588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower//     http://www.apache.org/licenses/LICENSE-2.0
81588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower//
91588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Unless required by applicable law or agreed to in writing, software
101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// distributed under the License is distributed on an "AS IS" BASIS,
111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// See the License for the specific language governing permissions and
131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// limitations under the License.
141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// =============================================================================
15f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_
16f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_
171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include <unordered_map>
181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include <vector>
191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h"
221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h"
241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/v4/params.h"
251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h"
261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h"
271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/core/lib/random/philox_random.h"
281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/core/lib/random/simple_philox.h"
291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorflow {
311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorforest {
321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Base class for tracking stats necessary to split a leaf.
341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Holds and tracks stats for every candidate split.
351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerclass GrowStats {
361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower public:
371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual ~GrowStats() {}
381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Perform any initialization.
391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual void Initialize() = 0;
401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Add an example to any stats being collected.
421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual void AddExample(const std::unique_ptr<TensorDataSet>& input_data,
431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                          const InputTarget* target, int example) = 0;
441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Fill in the best split, return false if none were valid.
461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual bool BestSplit(SplitCandidate* best) const = 0;
471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Return true if this leaf is finished splitting.
491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual bool IsFinished() const = 0;
501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Get the split_num BinaryNode.
521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const decision_trees::BinaryNode& Split(int split_num) const {
531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return splits_[split_num];
541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Clear all state.
571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual void Clear() {
581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    weight_sum_ = 0;
591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    splits_.clear();
601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    evaluators_.clear();
611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    ClearInternal();
621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual void ExtractFromProto(const FertileSlot& slot) = 0;
651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual void PackToProto(FertileSlot* slot) const = 0;
661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Add split to the list of candidate splits.
6875f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  void AddSplit(const decision_trees::BinaryNode& split,
6975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower                const std::unique_ptr<TensorDataSet>& input_data,
7075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower                const InputTarget* target, int example);
7175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  virtual void AdditionalInitializationExample(
7275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower      const std::unique_ptr<TensorDataSet>& input_data,
7375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower      const InputTarget* target, int example) {}
741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void RemoveSplit(int split_num);
751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
764463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower  int num_splits() const { return splits_.size(); }
771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
784463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower  float weight_sum() const { return weight_sum_; }
791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
8075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  virtual bool IsInitialized() const {
811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return weight_sum_ > 0 || splits_.size() == num_splits_to_consider_;
821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
844463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower  int32 depth() const { return depth_; }
851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower protected:
871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  GrowStats(const TensorForestParams& params, int32 depth);
881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Function called by AddSplit for subclasses to initialize stats for a split.
9075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  virtual void AddSplitStats(const InputTarget* target, int example) = 0;
911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual void RemoveSplitStats(int split_num) = 0;
931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Function called by Clear for subclasses to clear their state.
951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual void ClearInternal() = 0;
961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::vector<decision_trees::BinaryNode> splits_;
981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::vector<std::unique_ptr<DecisionNodeEvaluator>> evaluators_;
991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float weight_sum_;
1011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int32 depth_;
1031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const TensorForestParams& params_;
1051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
10690d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai  // We cache these because they're used often.
1071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int split_after_samples_;
1081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int num_splits_to_consider_;
1091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  const int32 num_outputs_;
1111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower};
1121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Don't track anything, useful for systems that want to track split
1141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// candidates but train the model in some other way.
1151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerclass SimpleStats : public GrowStats {
1161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower public:
1171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  SimpleStats(const TensorForestParams& params, int32 depth)
1181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      : GrowStats(params, depth) {}
1191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void Initialize() override {}
1201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void ExtractFromProto(const FertileSlot& slot) override {}
1221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void PackToProto(FertileSlot* slot) const override {}
1231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void AddExample(const std::unique_ptr<TensorDataSet>& input_data,
1251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                  const InputTarget* target, int example) override {
1261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    weight_sum_ += target->GetTargetWeight(example);
1271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
1281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  bool BestSplit(SplitCandidate* best) const override { return false; }
1301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  bool IsFinished() const override {
1321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return weight_sum_ >= split_after_samples_;
1331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
1341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower protected:
13675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  void AddSplitStats(const InputTarget* target, int example) override {}
1371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void RemoveSplitStats(int split_num) override {}
1381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void ClearInternal() override {}
1391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower};
1401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
// Keeps, for each candidate split, a running sum and a running sum of squares
// for one side of that split, so a Gini calculation can be updated
// incrementally rather than recomputed from scratch.
class RunningGiniScores {
 public:
  // Running sum for `split`.
  float sum(int split) const { return sum_[split]; }
  // Running sum of squares for `split`.
  float square(int split) const { return square_[split]; }

  // Adds `weight` to the sum and replaces old_val^2 with
  // (old_val + weight)^2 in the sum of squares. old_val is the tracked
  // quantity's previous value (presumably one class's count on this side).
  void update(int split, float old_val, float weight) {
    const float new_val = old_val + weight;
    sum_[split] += weight;
    float& running_square = square_[split];
    running_square = running_square - old_val * old_val + new_val * new_val;
  }

  // Appends a fresh, zeroed entry for a newly added split.
  void add_split() {
    sum_.push_back(0);
    square_.push_back(0);
  }

  // Drops split i; every later split moves down one index.
  void remove_split(int i) {
    EraseAt(&sum_, i);
    EraseAt(&square_, i);
  }

 private:
  // Erases values[index], shifting the following elements down.
  static void EraseAt(std::vector<float>* values, int index) {
    values->erase(values->begin() + index);
  }

  std::vector<float> sum_;
  std::vector<float> square_;
};
1671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerclass ClassificationStats : public GrowStats {
1691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower public:
1701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  ClassificationStats(const TensorForestParams& params, int32 depth);
1711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  bool IsFinished() const override;
1731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
1741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void AddExample(const std::unique_ptr<TensorDataSet>& input_data,
1751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                  const InputTarget* target, int example) override;
1761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
17775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  void AdditionalInitializationExample(
17875f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower      const std::unique_ptr<TensorDataSet>& input_data,
17975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower      const InputTarget* target, int example) override;
18075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower
18175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  bool IsInitialized() const override {
18275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    return weight_sum_ > 0 || (splits_.size() == num_splits_to_consider_ &&
18375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower                               half_initialized_splits_.empty());
18475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  }
18575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower
186054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  bool BestSplit(SplitCandidate* best) const override;
187054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // When best_split_index has been chosen as the best split,
188054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // InitLeafClassStats is used to initialize the LeafStat's of the two
189054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // new leaves.
190054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  virtual void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
191054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower                                  LeafStat* right_stats) const = 0;
192054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
1931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower protected:
1941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual float GiniScore(int split, float* left_sum,
1951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                          float* right_sum) const = 0;
196054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
197054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // is_pure should return true if at most one class label has been seen
198054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // at the node, and false if two or more have been seen.
199054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  virtual bool is_pure() const = 0;
2001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual float left_count(int split, int class_num) const = 0;
2011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual float right_count(int split, int class_num) const = 0;
2021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2034463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower  virtual void ClassificationAddLeftExample(int split, int64 int_label,
2044463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower                                            float weight) = 0;
205054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  virtual void ClassificationAddRightExample(int split, int64 int_label,
206054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower                                             float weight) {
207054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    // Does nothing by default, but sub-classes can override.
208054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
2091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual void ClassificationAddTotalExample(int64 int_label, float weight) = 0;
2101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual void ClassificationAddSplitStats() = 0;
2121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual void ClassificationRemoveSplitStats(int split) = 0;
2131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
21475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  void AddSplitStats(const InputTarget* target, int example) override {
2151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (left_gini_ != nullptr) {
2161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      left_gini_->add_split();
2171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      right_gini_->add_split();
2181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
21975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    if (params_.initialize_average_splits()) {
22075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower      if (splits_[splits_.size() - 1].has_inequality_left_child_test()) {
22175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower        half_initialized_splits_[splits_.size() - 1] =
22275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower            target->GetTargetAsClassIndex(example, 0);
22375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower      }
22475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower    }
2251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    ClassificationAddSplitStats();
2261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void RemoveSplitStats(int split) override {
2281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (left_gini_ != nullptr) {
2291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      left_gini_->remove_split(split);
2301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      right_gini_->remove_split(split);
2311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
2321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    ClassificationRemoveSplitStats(split);
2331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Virtual so we can override these to test.
2361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual void CheckFinishEarly();
2371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual void CheckFinishEarlyHoeffding();
2381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual void CheckFinishEarlyBootstrap();
2391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual void CheckPrune();
2411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Implement SplitPruningStrategyType::SPLIT_PRUNE_HOEFFDING.
2431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void CheckPruneHoeffding();
2441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Return the gini score, possibly being calculated from sums and squares
2461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // saved in left_gini_ and right_gini_, otherwise calculated from raw counts.
2471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float MaybeCachedGiniScore(int split, float* left_sum,
2481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                             float* right_sum) const;
2491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Initialize the sum and squares of left_gini_ and right_gini_ for given
2511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // split and value (being extracted from a proto), if left_gini_ isn't null.
2521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void MaybeInitializeRunningCount(int split, float val) {
2531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    if (left_gini_ != nullptr) {
2541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      left_gini_->update(split, 0, val);
2551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      right_gini_->update(split, 0, val);
2561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    }
2571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int NumBootstrapSamples() const;
2601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Populate *weights with the smoothed per-class frequencies needed to
2621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // initialize a DistributionSampler.
2631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void MakeBootstrapWeights(int index, std::vector<float>* weights);
2641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Accessors for RunningGiniScores objects, for testing.
2661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual const std::unique_ptr<RunningGiniScores>& get_left_gini() const {
2671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return left_gini_;
2681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  virtual const std::unique_ptr<RunningGiniScores>& get_right_gini() const {
2701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return right_gini_;
2711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
2721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower private:
2741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Tracks how many check_every_samples epochs we've seen go by in weight_sum.
2751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int32 finish_sample_epoch_;
2761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int32 finish_check_every_;
2771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int32 prune_sample_epoch_;
2781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int32 prune_check_every_;
2791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  bool finish_early_;
2801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  int32 min_split_samples_;
2811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float dominate_fraction_;
2821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float prune_fraction_;
2831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // When using SPLIT_PRUNE_HOEFFDING, we precompute and store
2851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // 0.5 * ln(1 / (1.0 - dominate_fraction_)).
2861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float half_ln_dominate_frac_;
2871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::unique_ptr<random::PhiloxRandom> single_rand_;
2891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::unique_ptr<random::SimplePhilox> rng_;
2901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::unique_ptr<RunningGiniScores> left_gini_;
2921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::unique_ptr<RunningGiniScores> right_gini_;
29375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower
29475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  // Stores split number -> class that was first seen.
29575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  std::unordered_map<int, int32> half_initialized_splits_;
2961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower};
2971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
2981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Tracks classification stats by storing class counts densely.
2991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerclass DenseClassificationGrowStats : public ClassificationStats {
3001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower public:
3011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  DenseClassificationGrowStats(const TensorForestParams& params, int32 depth)
3021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      : ClassificationStats(params, depth) {}
3031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void Initialize() override {
3051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    Clear();
3061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_counts_.resize(num_outputs_);
3071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void ExtractFromProto(const FertileSlot& slot) override;
3101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void PackToProto(FertileSlot* slot) const override;
3111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
312054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
3137578785dff668c63ba6b5423a6bf2a5984c7b409A. Unique TensorFlower                          LeafStat* right_stats) const override;
3141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower protected:
3161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void ClassificationAddSplitStats() override {
3171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_counts_.resize(num_outputs_ * num_splits());
3181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void ClassificationRemoveSplitStats(int split_num) override {
3201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_counts_.erase(left_counts_.begin() + num_outputs_ * split_num,
3211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                       left_counts_.begin() + num_outputs_ * (split_num + 1));
3221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void ClearInternal() override {
3241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_counts_.clear();
3251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_counts_.clear();
3261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    num_outputs_seen_ = 0;
3271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
329054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  bool is_pure() const override { return num_outputs_seen_ <= 1; }
3301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void ClassificationAddLeftExample(int split, int64 int_label,
3321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                                    float weight) override {
3331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    mutable_left_count(split, int_label) += weight;
3341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void ClassificationAddTotalExample(int64 int_label, float weight) override {
3361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    num_outputs_seen_ += total_counts_[int_label] == 0 && weight > 0;
3371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_counts_[int_label] += weight;
3381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float GiniScore(int split, float* left_sum, float* right_sum) const override;
3411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float left_count(int split, int class_num) const override {
3431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return left_counts_[split * num_outputs_ + class_num];
3441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float right_count(int split, int class_num) const override {
3461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return total_counts_[class_num] -
3471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower           left_counts_[split * num_outputs_ + class_num];
3481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower private:
3511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  inline float& mutable_left_count(int split, int class_num) {
3521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return left_counts_[split * num_outputs_ + class_num];
3531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
// Total class counts seen at this leaf, indexed by class id.
std::vector<float> total_counts_;

// Number of distinct classes seen, tracked so that pure leaves are not split.
// Default-initialized to 0: the constructor only initializes the base class,
// so without this is_pure() would read an indeterminate value if called
// before Initialize()/Clear().
int num_outputs_seen_ = 0;

// Left-branch class counts at this leaf for each split candidate.
// This is a flat vector for memory-performance reasons:
// left_counts_[i * num_outputs_ + j] holds the j-th class count for split i.
std::vector<float> left_counts_;
3641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower};
3651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Tracks classification stats by storing class counts sparsely.
3671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerclass SparseClassificationGrowStats : public ClassificationStats {
3681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower public:
3691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  SparseClassificationGrowStats(const TensorForestParams& params, int32 depth)
3701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      : ClassificationStats(params, depth) {}
3711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3724463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower  void Initialize() override { Clear(); }
3731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void ExtractFromProto(const FertileSlot& slot) override;
3751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void PackToProto(FertileSlot* slot) const override;
3761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
377054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
3787578785dff668c63ba6b5423a6bf2a5984c7b409A. Unique TensorFlower                          LeafStat* right_stats) const override;
3791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower protected:
3811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void ClassificationAddSplitStats() override {
3821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_counts_.resize(num_splits());
3831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void ClassificationRemoveSplitStats(int split_num) override {
3851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_counts_.erase(left_counts_.begin() + split_num,
3861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                       left_counts_.begin() + (split_num + 1));
3871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void ClearInternal() override {
3891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_counts_.clear();
3901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_counts_.clear();
3911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
393054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  bool is_pure() const override { return total_counts_.size() <= 1; }
3941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
3951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void ClassificationAddLeftExample(int split, int64 int_label,
3961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                                    float weight) override {
3971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_counts_[split][int_label] += weight;
3981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
3991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void ClassificationAddTotalExample(int64 int_label, float weight) override {
4001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_counts_[int_label] += weight;
4011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
4021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float GiniScore(int split, float* left_sum, float* right_sum) const override;
4041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float left_count(int split, int class_num) const override {
4061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return left_counts_[split].at(class_num);
4071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
4081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float right_count(int split, int class_num) const override {
4091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return total_counts_.at(class_num) - left_counts_[split].at(class_num);
4101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
4111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower private:
4131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Total class counts seen at this leaf
4141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::unordered_map<int, float> total_counts_;
4151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
4161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Left-branch taken class counts at this leaf for each split.
4171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // left_counts_[i][j] has the j-th class count for split i.
4181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::vector<std::unordered_map<int, float>> left_counts_;
4191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower};
4201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
421054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower// Accumulates weights for the most popular classes while only using a
422054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower// fixed amount of space.
423054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowerclass FixedSizeClassStats {
424054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower public:
425054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // n specifies how many classes are tracked.
426054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  FixedSizeClassStats(int n, int num_classes)
427054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower      : n_(n), num_classes_(num_classes), smallest_weight_class_(-1) {}
428054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
429054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // Add weight w to the class c.
430054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  void accumulate(int c, float w);
431054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
432054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // Return the approximate accumulated weight for class c.  If c isn't one
433054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // of the n-most popular classes, this can be 0 even if c has accumulated
434054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // some weight.
435054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  float get_weight(int c) const;
436054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
437054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // Put the sum of all weights seen into *sum, and
438054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // \sum_c get_weight(c)^2
439054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // into *square.  *sum will be exact, but *square will be approximate.
440054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  void set_sum_and_square(float* sum, float* square) const;
441054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
442054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  void ExtractFromProto(const decision_trees::SparseVector& sparse_vector);
443054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  void PackToProto(decision_trees::SparseVector* sparse_vector) const;
444054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
445054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower private:
446054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // For our typical use cases, n_ is between 10 and 100, so there's no
447054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // need to track the smallest weight with a min_heap or the like.
448054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  int n_;
449054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  int num_classes_;
450054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
451054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // This tracks the class of the smallest weight, but isn't set until
452054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // class_weights_.size() == n_.
453054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  int smallest_weight_class_;
454054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
455054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  std::unordered_map<int, float> class_weights_;
456054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower};
457054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
458054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower// Tracks classification stats sparsely in a fixed amount of space.
459054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowerclass FixedSizeSparseClassificationGrowStats : public ClassificationStats {
460054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower public:
461054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  FixedSizeSparseClassificationGrowStats(const TensorForestParams& params,
462054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower                                         int32 depth)
463054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower      : ClassificationStats(params, depth) {}
464054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
465054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  void Initialize() override { Clear(); }
466054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
467054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  void ExtractFromProto(const FertileSlot& slot) override;
468054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  void PackToProto(FertileSlot* slot) const override;
469054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
470054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  void InitLeafClassStats(int best_split_index, LeafStat* left_stats,
471d100729c309cb22baf1630d9f39cf60516c58cdfDaniel Trebbien                          LeafStat* right_stats) const override;
472054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
473054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower protected:
474054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  void ClassificationAddSplitStats() override {
475054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    FixedSizeClassStats stats(params_.num_classes_to_track(),
476054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower                              params_.num_outputs());
477054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    left_counts_.resize(num_splits(), stats);
478054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    right_counts_.resize(num_splits(), stats);
479054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
480054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  void ClassificationRemoveSplitStats(int split_num) override {
481054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    left_counts_.erase(left_counts_.begin() + split_num,
482054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower                       left_counts_.begin() + (split_num + 1));
483054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    right_counts_.erase(right_counts_.begin() + split_num,
484054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower                        right_counts_.begin() + (split_num + 1));
485054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
486054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  void ClearInternal() override {
487054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    left_counts_.clear();
488054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    right_counts_.clear();
489054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
490054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
491054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  bool is_pure() const override { return first_two_classes_seen_.size() <= 1; }
492054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
493054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  void ClassificationAddLeftExample(int split, int64 int_label,
494054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower                                    float weight) override {
495054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    left_counts_[split].accumulate(int_label, weight);
496054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
497054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  void ClassificationAddRightExample(int split, int64 int_label,
498054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower                                     float weight) override {
499054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    right_counts_[split].accumulate(int_label, weight);
500054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
501054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  void ClassificationAddTotalExample(int64 int_label, float weight) override {
502054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    if (is_pure()) {
503054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower      first_two_classes_seen_.insert(int_label);
504054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    }
505054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
506054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
507054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  float GiniScore(int split, float* left_sum, float* right_sum) const override;
508054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
509054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  float left_count(int split, int class_num) const override {
510054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    return left_counts_[split].get_weight(class_num);
511054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
512054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
513054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  float right_count(int split, int class_num) const override {
514054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower    return right_counts_[split].get_weight(class_num);
515054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  }
516054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
517054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower private:
518054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  std::vector<FixedSizeClassStats> left_counts_;
519054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  std::vector<FixedSizeClassStats> right_counts_;
520054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
521054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // We keep track of the first two class labels seen, so we can tell if
522054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  // the node is pure (= all of one class) or not.
523054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower  std::set<int> first_two_classes_seen_;
524054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower};
525054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower
5261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Tracks regression stats using least-squares minimization.
5271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerclass LeastSquaresRegressionGrowStats : public GrowStats {
5281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower public:
5291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  LeastSquaresRegressionGrowStats(const TensorForestParams& params, int32 depth)
5301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower      : GrowStats(params, depth) {}
5311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void Initialize() override {
5331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    Clear();
5341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_sum_.resize(num_outputs_);
5351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_sum_squares_.resize(num_outputs_);
5361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void ExtractFromProto(const FertileSlot& slot) override;
5391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void PackToProto(FertileSlot* slot) const override;
5401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void AddExample(const std::unique_ptr<TensorDataSet>& input_data,
5421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                  const InputTarget* target, int example) override;
5431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  bool BestSplit(SplitCandidate* best) const override;
5441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  bool IsFinished() const override;
5451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower protected:
5471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Returns the variance of split.
5481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  float SplitVariance(int split) const;
5491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
55075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower  void AddSplitStats(const InputTarget* target, int example) override {
5511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_sums_.resize(num_outputs_ * num_splits());
5521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_squares_.resize(num_outputs_ * num_splits());
5531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_counts_.push_back(0);
5541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void RemoveSplitStats(int split_num) override {
5561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_sums_.erase(left_sums_.begin() + num_outputs_ * split_num,
5574463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower                     left_sums_.begin() + num_outputs_ * (split_num + 1));
5581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_squares_.erase(left_squares_.begin() + num_outputs_ * split_num,
5594463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower                        left_squares_.begin() + num_outputs_ * (split_num + 1));
5601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_counts_.erase(left_counts_.begin() + split_num,
5611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower                       left_counts_.begin() + (split_num + 1));
5621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  void ClearInternal() override {
5651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_sum_.clear();
5661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    total_sum_squares_.clear();
5671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_sums_.clear();
5681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    left_squares_.clear();
5691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower private:
5721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Convenience methods for accessing the flat count vectors.
5731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  inline const float& left_sum(int split, int output_num) const {
5741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return left_sums_[split * num_outputs_ + output_num];
5751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  inline float& left_sum(int split, int output_num) {
5771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return left_sums_[split * num_outputs_ + output_num];
5781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  inline const float& left_square(int split, int output_num) const {
5801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return left_squares_[split * num_outputs_ + output_num];
5811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  inline float& left_square(int split, int output_num) {
5831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower    return left_squares_[split * num_outputs_ + output_num];
5841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  }
5851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Total sums and squares seen at this leaf.
5871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // sum[i] is the sum of the i-th output.
5881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::vector<float> total_sum_;
5891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::vector<float> total_sum_squares_;
5901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // Per-split sums and squares, stored flat for performance.
5921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // left_sums_[i * num_outputs_ + j] has the j-th sum for split i.
5931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::vector<float> left_sums_;
5941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::vector<float> left_squares_;
5951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
5961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  // The number of example seen at each split.
5971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower  std::vector<int64> left_counts_;
5981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower};
5991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
6001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}  // namespace tensorforest
6011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}  // namespace tensorflow
6021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower
603f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#endif  // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_
604