11588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 21588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// 31588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Licensed under the Apache License, Version 2.0 (the "License"); 41588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// you may not use this file except in compliance with the License. 51588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// You may obtain a copy of the License at 61588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// 71588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// http://www.apache.org/licenses/LICENSE-2.0 81588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// 91588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Unless required by applicable law or agreed to in writing, software 101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// distributed under the License is distributed on an "AS IS" BASIS, 111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// See the License for the specific language governing permissions and 131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// limitations under the License. 141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// ============================================================================= 15f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_ 16f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_ 171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include <unordered_map> 181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include <vector> 191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" 211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h" 221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h" 231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h" 241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/kernels/v4/params.h" 251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h" 261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h" 271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/core/lib/random/philox_random.h" 281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower#include "tensorflow/core/lib/random/simple_philox.h" 291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorflow { 311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowernamespace tensorforest { 321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Base class for tracking stats necessary to split a leaf. 341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Holds and tracks stats for every candidate split. 351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerclass GrowStats { 361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower public: 371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual ~GrowStats() {} 381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Perform any initialization. 391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual void Initialize() = 0; 401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Add an example to any stats being collected. 421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual void AddExample(const std::unique_ptr<TensorDataSet>& input_data, 431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const InputTarget* target, int example) = 0; 441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Fill in the best split, return false if none were valid. 461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual bool BestSplit(SplitCandidate* best) const = 0; 471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Return true if this leaf is finished splitting. 491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual bool IsFinished() const = 0; 501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Get the split_num BinaryNode. 521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const decision_trees::BinaryNode& Split(int split_num) const { 531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return splits_[split_num]; 541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Clear all state. 571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual void Clear() { 581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower weight_sum_ = 0; 591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower splits_.clear(); 601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower evaluators_.clear(); 611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ClearInternal(); 621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual void ExtractFromProto(const FertileSlot& slot) = 0; 651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual void PackToProto(FertileSlot* slot) const = 0; 661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Add split to the list of candidate splits. 6875f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower void AddSplit(const decision_trees::BinaryNode& split, 6975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower const std::unique_ptr<TensorDataSet>& input_data, 7075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower const InputTarget* target, int example); 7175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower virtual void AdditionalInitializationExample( 7275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower const std::unique_ptr<TensorDataSet>& input_data, 7375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower const InputTarget* target, int example) {} 741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void RemoveSplit(int split_num); 751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 764463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower int num_splits() const { return splits_.size(); } 771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 784463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower float weight_sum() const { return weight_sum_; } 791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 8075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower virtual bool IsInitialized() const { 811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return weight_sum_ > 0 || splits_.size() == num_splits_to_consider_; 821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 844463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower int32 depth() const { return depth_; } 851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower protected: 871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower GrowStats(const TensorForestParams& params, int32 depth); 881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Function called by AddSplit for subclasses to initialize stats for a split. 9075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower virtual void AddSplitStats(const InputTarget* target, int example) = 0; 911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual void RemoveSplitStats(int split_num) = 0; 931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Function called by Clear for subclasses to clear their state. 951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual void ClearInternal() = 0; 961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<decision_trees::BinaryNode> splits_; 981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<std::unique_ptr<DecisionNodeEvaluator>> evaluators_; 991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float weight_sum_; 1011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int32 depth_; 1031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const TensorForestParams& params_; 1051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 10690d6421c5e0898fb840197d9533c2f8ba1a7c651Shanqing Cai // We cache these because they're used often. 1071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int split_after_samples_; 1081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int num_splits_to_consider_; 1091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const int32 num_outputs_; 1111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}; 1121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Don't track anything, useful for systems that want to track split 1141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// candidates but train the model in some other way. 1151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerclass SimpleStats : public GrowStats { 1161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower public: 1171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower SimpleStats(const TensorForestParams& params, int32 depth) 1181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower : GrowStats(params, depth) {} 1191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void Initialize() override {} 1201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void ExtractFromProto(const FertileSlot& slot) override {} 1221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void PackToProto(FertileSlot* slot) const override {} 1231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void AddExample(const std::unique_ptr<TensorDataSet>& input_data, 1251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const InputTarget* target, int example) override { 1261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower weight_sum_ += target->GetTargetWeight(example); 1271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 1281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower bool BestSplit(SplitCandidate* best) const override { return false; } 1301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower bool IsFinished() const override { 1321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return weight_sum_ >= split_after_samples_; 1331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 1341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower protected: 13675f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower void AddSplitStats(const InputTarget* target, int example) override {} 1371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void RemoveSplitStats(int split_num) override {} 1381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void ClearInternal() override {} 1391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}; 1401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Tracks the sum and square of one side of a split for each Gini calculation. 1421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerclass RunningGiniScores { 1431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower public: 1441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float sum(int split) const { return sum_[split]; } 1451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float square(int split) const { return square_[split]; } 1461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void update(int split, float old_val, float weight) { 1481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower sum_[split] += weight; 1491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const float new_val = old_val + weight; 1501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower square_[split] = square_[split] - old_val * old_val + new_val * new_val; 1511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 1521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void add_split() { 1541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower sum_.push_back(0); 1551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower square_.push_back(0); 1561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 1571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void remove_split(int i) { 1591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower sum_.erase(sum_.begin() + i); 1601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower square_.erase(square_.begin() + i); 1611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 1621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower private: 1641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<float> sum_; 1651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<float> square_; 1661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}; 1671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerclass ClassificationStats : public GrowStats { 1691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower public: 1701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ClassificationStats(const TensorForestParams& params, int32 depth); 1711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower bool IsFinished() const override; 1731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 1741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void AddExample(const std::unique_ptr<TensorDataSet>& input_data, 1751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const InputTarget* target, int example) override; 1761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 17775f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower void AdditionalInitializationExample( 17875f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower const std::unique_ptr<TensorDataSet>& input_data, 17975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower const InputTarget* target, int example) override; 18075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower 18175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower bool IsInitialized() const override { 18275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower return weight_sum_ > 0 || (splits_.size() == num_splits_to_consider_ && 18375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower half_initialized_splits_.empty()); 18475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower } 18575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower 186054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower bool BestSplit(SplitCandidate* best) const override; 187054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // When best_split_index has been chosen as the best split, 188054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // InitLeafClassStats is used to initialize the LeafStat's of the two 189054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // new leaves. 190054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower virtual void InitLeafClassStats(int best_split_index, LeafStat* left_stats, 191054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower LeafStat* right_stats) const = 0; 192054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 1931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower protected: 1941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual float GiniScore(int split, float* left_sum, 1951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float* right_sum) const = 0; 196054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 197054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // is_pure should return true if at most one class label has been seen 198054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // at the node, and false if two or more have been seen. 199054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower virtual bool is_pure() const = 0; 2001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual float left_count(int split, int class_num) const = 0; 2011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual float right_count(int split, int class_num) const = 0; 2021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2034463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower virtual void ClassificationAddLeftExample(int split, int64 int_label, 2044463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower float weight) = 0; 205054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower virtual void ClassificationAddRightExample(int split, int64 int_label, 206054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower float weight) { 207054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // Does nothing by default, but sub-classes can override. 208054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 2091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual void ClassificationAddTotalExample(int64 int_label, float weight) = 0; 2101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual void ClassificationAddSplitStats() = 0; 2121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual void ClassificationRemoveSplitStats(int split) = 0; 2131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 21475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower void AddSplitStats(const InputTarget* target, int example) override { 2151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (left_gini_ != nullptr) { 2161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_gini_->add_split(); 2171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right_gini_->add_split(); 2181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 21975f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower if (params_.initialize_average_splits()) { 22075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower if (splits_[splits_.size() - 1].has_inequality_left_child_test()) { 22175f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower half_initialized_splits_[splits_.size() - 1] = 22275f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower target->GetTargetAsClassIndex(example, 0); 22375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower } 22475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower } 2251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ClassificationAddSplitStats(); 2261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void RemoveSplitStats(int split) override { 2281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (left_gini_ != nullptr) { 2291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_gini_->remove_split(split); 2301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right_gini_->remove_split(split); 2311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower ClassificationRemoveSplitStats(split); 2331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Virtual so we can override these to test. 2361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual void CheckFinishEarly(); 2371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual void CheckFinishEarlyHoeffding(); 2381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual void CheckFinishEarlyBootstrap(); 2391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual void CheckPrune(); 2411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Implement SplitPruningStrategyType::SPLIT_PRUNE_HOEFFDING. 2431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void CheckPruneHoeffding(); 2441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Return the gini score, possibly being calculated from sums and squares 2461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // saved in left_gini_ and right_gini_, otherwise calculated from raw counts. 2471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float MaybeCachedGiniScore(int split, float* left_sum, 2481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float* right_sum) const; 2491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Initialize the sum and squares of left_gini_ and right_gini_ for given 2511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // split and value (being extracted from a proto), if left_gini_ isn't null. 2521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void MaybeInitializeRunningCount(int split, float val) { 2531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower if (left_gini_ != nullptr) { 2541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_gini_->update(split, 0, val); 2551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower right_gini_->update(split, 0, val); 2561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int NumBootstrapSamples() const; 2601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Populate *weights with the smoothed per-class frequencies needed to 2621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // initialize a DistributionSampler. 2631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void MakeBootstrapWeights(int index, std::vector<float>* weights); 2641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Accessors for RunningGiniScores objects, for testing. 2661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual const std::unique_ptr<RunningGiniScores>& get_left_gini() const { 2671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return left_gini_; 2681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower virtual const std::unique_ptr<RunningGiniScores>& get_right_gini() const { 2701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return right_gini_; 2711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 2721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower private: 2741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Tracks how many check_every_samples epochs we've seen go by in weight_sum. 2751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int32 finish_sample_epoch_; 2761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int32 finish_check_every_; 2771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int32 prune_sample_epoch_; 2781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int32 prune_check_every_; 2791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower bool finish_early_; 2801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int32 min_split_samples_; 2811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float dominate_fraction_; 2821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float prune_fraction_; 2831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // When using SPLIT_PRUNE_HOEFFDING, we precompute and store 2851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // 0.5 * ln(1 / (1.0 - dominate_fraction_)). 2861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float half_ln_dominate_frac_; 2871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::unique_ptr<random::PhiloxRandom> single_rand_; 2891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::unique_ptr<random::SimplePhilox> rng_; 2901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::unique_ptr<RunningGiniScores> left_gini_; 2921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::unique_ptr<RunningGiniScores> right_gini_; 29375f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower 29475f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower // Stores split number -> class that was first seen. 29575f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower std::unordered_map<int, int32> half_initialized_splits_; 2961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}; 2971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 2981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Tracks classification stats by storing class counts densely. 2991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerclass DenseClassificationGrowStats : public ClassificationStats { 3001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower public: 3011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower DenseClassificationGrowStats(const TensorForestParams& params, int32 depth) 3021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower : ClassificationStats(params, depth) {} 3031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void Initialize() override { 3051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower Clear(); 3061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_counts_.resize(num_outputs_); 3071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void ExtractFromProto(const FertileSlot& slot) override; 3101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void PackToProto(FertileSlot* slot) const override; 3111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 312054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower void InitLeafClassStats(int best_split_index, LeafStat* left_stats, 3137578785dff668c63ba6b5423a6bf2a5984c7b409A. Unique TensorFlower LeafStat* right_stats) const override; 3141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower protected: 3161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void ClassificationAddSplitStats() override { 3171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_counts_.resize(num_outputs_ * num_splits()); 3181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void ClassificationRemoveSplitStats(int split_num) override { 3201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_counts_.erase(left_counts_.begin() + num_outputs_ * split_num, 3211588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_counts_.begin() + num_outputs_ * (split_num + 1)); 3221588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3231588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void ClearInternal() override { 3241588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_counts_.clear(); 3251588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_counts_.clear(); 3261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower num_outputs_seen_ = 0; 3271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 329054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower bool is_pure() const override { return num_outputs_seen_ <= 1; } 3301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void ClassificationAddLeftExample(int split, int64 int_label, 3321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float weight) override { 3331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower mutable_left_count(split, int_label) += weight; 3341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void ClassificationAddTotalExample(int64 int_label, float weight) override { 3361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower num_outputs_seen_ += total_counts_[int_label] == 0 && weight > 0; 3371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_counts_[int_label] += weight; 3381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float GiniScore(int split, float* left_sum, float* right_sum) const override; 3411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float left_count(int split, int class_num) const override { 3431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return left_counts_[split * num_outputs_ + class_num]; 3441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float right_count(int split, int class_num) const override { 3461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return total_counts_[class_num] - 3471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_counts_[split * num_outputs_ + class_num]; 3481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3501588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower private: 3511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower inline float& mutable_left_count(int split, int class_num) { 3521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return left_counts_[split * num_outputs_ + class_num]; 3531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Total class counts seen at this leaf 3551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<float> total_counts_; 3561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3571588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Also track the number of classes seen for not splitting pure leaves. 3581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower int num_outputs_seen_; 3591588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Left-branch taken class counts at this leaf for each split. 3611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // This is a flat vector for memory-performance reasons. 3621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // left_counts_[i * num_outputs_ + j] has the j-th class count for split i. 3631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<float> left_counts_; 3641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}; 3651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Tracks classification stats by storing class counts sparsely. 3671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerclass SparseClassificationGrowStats : public ClassificationStats { 3681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower public: 3691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower SparseClassificationGrowStats(const TensorForestParams& params, int32 depth) 3701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower : ClassificationStats(params, depth) {} 3711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3724463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower void Initialize() override { Clear(); } 3731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void ExtractFromProto(const FertileSlot& slot) override; 3751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void PackToProto(FertileSlot* slot) const override; 3761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 377054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower void InitLeafClassStats(int best_split_index, LeafStat* left_stats, 3787578785dff668c63ba6b5423a6bf2a5984c7b409A. Unique TensorFlower LeafStat* right_stats) const override; 3791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower protected: 3811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void ClassificationAddSplitStats() override { 3821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_counts_.resize(num_splits()); 3831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void ClassificationRemoveSplitStats(int split_num) override { 3851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_counts_.erase(left_counts_.begin() + split_num, 3861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_counts_.begin() + (split_num + 1)); 3871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void ClearInternal() override { 3891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_counts_.clear(); 3901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_counts_.clear(); 3911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 393054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower bool is_pure() const override { return total_counts_.size() <= 1; } 3941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 3951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void ClassificationAddLeftExample(int split, int64 int_label, 3961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float weight) override { 3971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_counts_[split][int_label] += weight; 3981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 3991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void ClassificationAddTotalExample(int64 int_label, float weight) override { 4001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_counts_[int_label] += weight; 4011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4031588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float GiniScore(int split, float* left_sum, float* right_sum) const override; 4041588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4051588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float left_count(int split, int class_num) const override { 4061588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return left_counts_[split].at(class_num); 4071588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4081588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float right_count(int split, int class_num) const override { 4091588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return total_counts_.at(class_num) - left_counts_[split].at(class_num); 4101588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 4111588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4121588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower private: 4131588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Total class counts seen at this leaf 4141588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::unordered_map<int, float> total_counts_; 4151588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 4161588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Left-branch taken class counts at this leaf for each split. 4171588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // left_counts_[i][j] has the j-th class count for split i. 4181588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<std::unordered_map<int, float>> left_counts_; 4191588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}; 4201588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 421054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower// Accumulates weights for the most popular classes while only using a 422054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower// fixed amount of space. 423054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowerclass FixedSizeClassStats { 424054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower public: 425054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // n specifies how many classes are tracked. 426054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower FixedSizeClassStats(int n, int num_classes) 427054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower : n_(n), num_classes_(num_classes), smallest_weight_class_(-1) {} 428054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 429054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // Add weight w to the class c. 430054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower void accumulate(int c, float w); 431054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 432054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // Return the approximate accumulated weight for class c. If c isn't one 433054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // of the n-most popular classes, this can be 0 even if c has accumulated 434054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // some weight. 435054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower float get_weight(int c) const; 436054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 437054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // Put the sum of all weights seen into *sum, and 438054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // \sum_c get_weight(c)^2 439054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // into *square. *sum will be exact, but *square will be approximate. 440054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower void set_sum_and_square(float* sum, float* square) const; 441054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 442054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower void ExtractFromProto(const decision_trees::SparseVector& sparse_vector); 443054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower void PackToProto(decision_trees::SparseVector* sparse_vector) const; 444054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 445054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower private: 446054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // For our typical use cases, n_ is between 10 and 100, so there's no 447054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // need to track the smallest weight with a min_heap or the like. 448054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower int n_; 449054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower int num_classes_; 450054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 451054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // This tracks the class of the smallest weight, but isn't set until 452054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // class_weights_.size() == n_. 453054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower int smallest_weight_class_; 454054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 455054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower std::unordered_map<int, float> class_weights_; 456054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower}; 457054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 458054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower// Tracks classification stats sparsely in a fixed amount of space. 459054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlowerclass FixedSizeSparseClassificationGrowStats : public ClassificationStats { 460054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower public: 461054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower FixedSizeSparseClassificationGrowStats(const TensorForestParams& params, 462054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower int32 depth) 463054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower : ClassificationStats(params, depth) {} 464054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 465054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower void Initialize() override { Clear(); } 466054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 467054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower void ExtractFromProto(const FertileSlot& slot) override; 468054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower void PackToProto(FertileSlot* slot) const override; 469054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 470054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower void InitLeafClassStats(int best_split_index, LeafStat* left_stats, 471d100729c309cb22baf1630d9f39cf60516c58cdfDaniel Trebbien LeafStat* right_stats) const override; 472054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 473054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower protected: 474054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower void ClassificationAddSplitStats() override { 475054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower FixedSizeClassStats stats(params_.num_classes_to_track(), 476054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower params_.num_outputs()); 477054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower left_counts_.resize(num_splits(), stats); 478054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower right_counts_.resize(num_splits(), stats); 479054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 480054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower void ClassificationRemoveSplitStats(int split_num) override { 481054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower left_counts_.erase(left_counts_.begin() + split_num, 482054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower left_counts_.begin() + (split_num + 1)); 483054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower right_counts_.erase(right_counts_.begin() + split_num, 484054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower right_counts_.begin() + (split_num + 1)); 485054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 486054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower void ClearInternal() override { 487054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower left_counts_.clear(); 488054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower right_counts_.clear(); 489054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 490054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 491054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower bool is_pure() const override { return first_two_classes_seen_.size() <= 1; } 492054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 493054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower void ClassificationAddLeftExample(int split, int64 int_label, 494054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower float weight) override { 495054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower left_counts_[split].accumulate(int_label, weight); 496054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 497054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower void ClassificationAddRightExample(int split, int64 int_label, 498054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower float weight) override { 499054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower right_counts_[split].accumulate(int_label, weight); 500054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 501054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower void ClassificationAddTotalExample(int64 int_label, float weight) override { 502054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower if (is_pure()) { 503054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower first_two_classes_seen_.insert(int_label); 504054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 505054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 506054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 507054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower float GiniScore(int split, float* left_sum, float* right_sum) const override; 508054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 509054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower float left_count(int split, int class_num) const override { 510054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower return left_counts_[split].get_weight(class_num); 511054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 512054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 513054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower float right_count(int split, int class_num) const override { 514054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower return right_counts_[split].get_weight(class_num); 515054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower } 516054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 517054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower private: 518054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower std::vector<FixedSizeClassStats> left_counts_; 519054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower std::vector<FixedSizeClassStats> right_counts_; 520054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 521054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // We keep track of the first two class labels seen, so we can tell if 522054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower // the node is pure (= all of one class) or not. 523054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower std::set<int> first_two_classes_seen_; 524054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower}; 525054b88233bf6d6bc5b953fca50dbb01d108b2d18A. Unique TensorFlower 5261588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower// Tracks regression stats using least-squares minimization. 5271588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlowerclass LeastSquaresRegressionGrowStats : public GrowStats { 5281588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower public: 5291588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower LeastSquaresRegressionGrowStats(const TensorForestParams& params, int32 depth) 5301588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower : GrowStats(params, depth) {} 5311588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5321588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void Initialize() override { 5331588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower Clear(); 5341588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_sum_.resize(num_outputs_); 5351588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_sum_squares_.resize(num_outputs_); 5361588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5371588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5381588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void ExtractFromProto(const FertileSlot& slot) override; 5391588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void PackToProto(FertileSlot* slot) const override; 5401588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5411588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void AddExample(const std::unique_ptr<TensorDataSet>& input_data, 5421588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower const InputTarget* target, int example) override; 5431588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower bool BestSplit(SplitCandidate* best) const override; 5441588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower bool IsFinished() const override; 5451588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5461588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower protected: 5471588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Returns the variance of split. 5481588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower float SplitVariance(int split) const; 5491588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 55075f03e2d509d016021f8508555f9ab96af2c7cfeA. Unique TensorFlower void AddSplitStats(const InputTarget* target, int example) override { 5511588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_sums_.resize(num_outputs_ * num_splits()); 5521588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_squares_.resize(num_outputs_ * num_splits()); 5531588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_counts_.push_back(0); 5541588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5551588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void RemoveSplitStats(int split_num) override { 5561588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_sums_.erase(left_sums_.begin() + num_outputs_ * split_num, 5574463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower left_sums_.begin() + num_outputs_ * (split_num + 1)); 5581588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_squares_.erase(left_squares_.begin() + num_outputs_ * split_num, 5594463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlower left_squares_.begin() + num_outputs_ * (split_num + 1)); 5601588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_counts_.erase(left_counts_.begin() + split_num, 5611588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_counts_.begin() + (split_num + 1)); 5621588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5631588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5641588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower void ClearInternal() override { 5651588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_sum_.clear(); 5661588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower total_sum_squares_.clear(); 5671588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_sums_.clear(); 5681588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower left_squares_.clear(); 5691588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5701588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5711588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower private: 5721588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Convenience methods for accessing the flat count vectors. 5731588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower inline const float& left_sum(int split, int output_num) const { 5741588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return left_sums_[split * num_outputs_ + output_num]; 5751588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5761588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower inline float& left_sum(int split, int output_num) { 5771588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return left_sums_[split * num_outputs_ + output_num]; 5781588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5791588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower inline const float& left_square(int split, int output_num) const { 5801588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return left_squares_[split * num_outputs_ + output_num]; 5811588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5821588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower inline float& left_square(int split, int output_num) { 5831588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower return left_squares_[split * num_outputs_ + output_num]; 5841588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower } 5851588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5861588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Total sums and squares seen at this leaf. 5871588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // sum[i] is the sum of the i-th output. 5881588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<float> total_sum_; 5891588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<float> total_sum_squares_; 5901588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5911588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // Per-split sums and squares, stored flat for performance. 5921588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // left_sums_[i * num_outputs_ + j] has the j-th sum for split i. 5931588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<float> left_sums_; 5941588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<float> left_squares_; 5951588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 5961588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower // The number of example seen at each split. 5971588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower std::vector<int64> left_counts_; 5981588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower}; 5991588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 6001588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} // namespace tensorforest 6011588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower} // namespace tensorflow 6021588d3790d759ee8a50e2bcf17b1eeadca69ba30A. Unique TensorFlower 603f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_ 604