// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h"

#include <cfloat>
#include <queue>
#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
#include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h"
#include "tensorflow/core/lib/random/distribution_sampler.h"

namespace tensorflow {
namespace tensorforest {

// When creating evaluators for the split candidates, use these
// for the left and right return values.
static const int32 LEFT_INDEX = 0;
static const int32 RIGHT_INDEX = 1;

GrowStats::GrowStats(const TensorForestParams& params, int32 depth)
    : weight_sum_(0),
      depth_(depth),
      params_(params),
      split_after_samples_(ResolveParam(params.split_after_samples(), depth)),
      num_splits_to_consider_(
          ResolveParam(params.num_splits_to_consider(), depth)),
      num_outputs_(params.num_outputs()) {}

void GrowStats::AddSplit(const decision_trees::BinaryNode& split,
                         const std::unique_ptr<TensorDataSet>& input_data,
                         const InputTarget* target, int example) {
  // It's possible that the split collection calls AddSplit, but we actually
  // have all the splits we need and are just waiting for them to be fully
  // initialized.
  if (splits_.size() < num_splits_to_consider_) {
    splits_.push_back(split);
    evaluators_.emplace_back(
        CreateBinaryDecisionNodeEvaluator(split, LEFT_INDEX, RIGHT_INDEX));
    AddSplitStats(target, example);
  }

  if (input_data != nullptr && target != nullptr &&
      params_.initialize_average_splits()) {
    AdditionalInitializationExample(input_data, target, example);
  }
}

void GrowStats::RemoveSplit(int split_num) {
  splits_.erase(splits_.begin() + split_num);
  evaluators_.erase(evaluators_.begin() + split_num);
  RemoveSplitStats(split_num);
}

// ------------------------ Classification --------------------------- //

ClassificationStats::ClassificationStats(const TensorForestParams& params,
                                         int32 depth)
    : GrowStats(params, depth), finish_early_(false) {
  // Early splitting params.
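  // With SPLIT_FINISH_BASIC we simply wait for split_after_samples_ examples;
  // the early-finish strategies handled below additionally require
  // dominate_fraction and min_split_samples to decide when the current best
  // split already dominates the alternatives.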
  if (params.finish_type().type() == SPLIT_FINISH_BASIC) {
    min_split_samples_ = split_after_samples_;
    finish_sample_epoch_ = 1;
    finish_check_every_ = split_after_samples_ * 2;
  } else {
    if (!params.has_dominate_fraction() || !params.has_min_split_samples()) {
      LOG(FATAL) << "dominate_fraction and min_split_samples "
                 << "required for early-finish strategy.";
    } else {
      min_split_samples_ = ResolveParam(params.min_split_samples(), depth);
      finish_check_every_ =
          ResolveParam(params.finish_type().check_every_steps(), depth);
      finish_sample_epoch_ = min_split_samples_ / finish_check_every_;

      dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_);
      if (dominate_fraction_ <= 0 || dominate_fraction_ > 1.0) {
        LOG(FATAL) << "Invalid dominate fraction " << dominate_fraction_;
      }
    }
  }

  // Pruning params.
  if (params.pruning_type().type() != SPLIT_PRUNE_NONE) {
    prune_check_every_ =
        ResolveParam(params.pruning_type().prune_every_samples(), depth);
    prune_sample_epoch_ = 1;
    prune_fraction_ = 0.0;
    switch (params_.pruning_type().type()) {
      case SPLIT_PRUNE_HALF:
        prune_fraction_ = 0.5;
        break;
      case SPLIT_PRUNE_QUARTER:
        prune_fraction_ = 0.25;
        break;
      case SPLIT_PRUNE_10_PERCENT:
        prune_fraction_ = 0.10;
        break;
      case SPLIT_PRUNE_HOEFFDING:
        dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_);
        half_ln_dominate_frac_ = 0.5 * log(1.0 / (1.0 - dominate_fraction_));
        break;
      default:
        LOG(WARNING) << "Unknown pruning type";
    }
  } else {
    prune_check_every_ = split_after_samples_ * 2;
    prune_sample_epoch_ = 1;
  }

  if (params.use_running_stats_method()) {
    left_gini_.reset(new RunningGiniScores());
    right_gini_.reset(new RunningGiniScores());
  }

  uint64 time_seed = static_cast<uint64>(std::clock());
  single_rand_ = std::unique_ptr<random::PhiloxRandom>(
      new random::PhiloxRandom(time_seed));
  rng_ = std::unique_ptr<random::SimplePhilox>(
      new random::SimplePhilox(single_rand_.get()));
}

void ClassificationStats::AdditionalInitializationExample(
    const std::unique_ptr<TensorDataSet>& input_data,
    const InputTarget* target, int example) {
  const int32 new_target = target->GetTargetAsClassIndex(example, 0);
  std::unordered_set<int> to_erase;
  for (auto it = half_initialized_splits_.begin();
       it != half_initialized_splits_.end(); ++it) {
    if (it->second != new_target) {
      auto& split = splits_[it->first];
      if (split.has_inequality_left_child_test()) {
        auto& test = split.inequality_left_child_test();
        auto* thresh =
            split.mutable_inequality_left_child_test()->mutable_threshold();
        if (test.has_feature_id()) {
          const float val =
              input_data->GetExampleValue(example, test.feature_id());
          thresh->set_float_value((thresh->float_value() + val) / 2);
        }
      }
      to_erase.insert(it->first);
    }
  }

  for (const int split_id : to_erase) {
    half_initialized_splits_.erase(split_id);
  }
}

bool ClassificationStats::IsFinished() const {
  bool basic = (weight_sum_ >= split_after_samples_) && !is_pure();
  return basic || finish_early_;
}

float ClassificationStats::MaybeCachedGiniScore(int split, float* left_sum,
                                                float* right_sum) const {
  if (left_gini_ == nullptr) {
    return GiniScore(split, left_sum, right_sum);
  } else {
    *left_sum = left_gini_->sum(split);
    const float left = WeightedSmoothedGini(
        *left_sum,
        left_gini_->square(split), num_outputs_);

    *right_sum = right_gini_->sum(split);
    const float right = WeightedSmoothedGini(
        *right_sum, right_gini_->square(split), num_outputs_);

    return left + right;
  }
}

void ClassificationStats::AddExample(
    const std::unique_ptr<TensorDataSet>& input_data,
    const InputTarget* target, int example) {
  const int64 int_label = target->GetTargetAsClassIndex(example, 0);
  const float weight = target->GetTargetWeight(example);

  for (int i = 0; i < num_splits(); ++i) {
    auto& eval = evaluators_[i];
    if (eval->Decide(input_data, example) == LEFT_INDEX) {
      if (left_gini_ != nullptr) {
        left_gini_->update(i, left_count(i, int_label), weight);
      }
      ClassificationAddLeftExample(i, int_label, weight);
    } else {
      if (right_gini_ != nullptr) {
        right_gini_->update(i, right_count(i, int_label), weight);
      }
      ClassificationAddRightExample(i, int_label, weight);
    }
  }

  ClassificationAddTotalExample(int_label, weight);

  weight_sum_ += weight;

  CheckFinishEarly();
  CheckPrune();
}

void ClassificationStats::CheckPrune() {
  if (params_.pruning_type().type() == SPLIT_PRUNE_NONE || IsFinished() ||
      weight_sum_ < prune_sample_epoch_ * prune_check_every_) {
    return;
  }
  ++prune_sample_epoch_;

  if (params_.pruning_type().type() == SPLIT_PRUNE_HOEFFDING) {
    CheckPruneHoeffding();
    return;
  }

  const int to_remove = num_splits() * prune_fraction_;
  if (to_remove <= 0) {
    return;
  }

  // pair ordering is first-then-second by default, no need for custom
  // comparison.  Use std::greater to make it a min-heap.
  std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int>>,
                      std::greater<std::pair<float, int>>>
      worst;

  // Track indices that are in the heap so we can iterate over them
  // by largest-first later.
  std::set<int> indices;

  for (int i = 0; i < num_splits(); ++i) {
    float left, right;
    const float split_score = MaybeCachedGiniScore(i, &left, &right);
    if (worst.size() < to_remove) {
      worst.push(std::pair<float, int>(split_score, i));
      indices.insert(i);
    } else if (worst.top().first < split_score) {
      indices.erase(worst.top().second);
      worst.pop();
      worst.push(std::pair<float, int>(split_score, i));
      indices.insert(i);
    }
  }

  // Traverse indices from the back so that they are removed correctly.
  for (auto it = indices.rbegin(); it != indices.rend(); ++it) {
    RemoveSplit(*it);
  }
}

void ClassificationStats::CheckPruneHoeffding() {
  std::vector<float> split_scores(num_splits());
  // Find best split score.
  float best_split_score = FLT_MAX;
  for (int i = 0; i < num_splits(); ++i) {
    float left, right;
    split_scores[i] = MaybeCachedGiniScore(i, &left, &right);
    if (split_scores[i] < best_split_score) {
      best_split_score = split_scores[i];
    }
  }

  // We apply the Hoeffding bound to the difference between the best split
  // score and the i-th split score.
  // Raw Gini ranges from 0 to 1 - (1/n), but our gini score is weighted.
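  // This yields epsilon = R * sqrt(ln(1 / (1 - dominate_fraction_)) / (2 * N))
  // with R = gini_diff_range and N = weight_sum_, since half_ln_dominate_frac_
  // is 0.5 * ln(1 / (1 - dominate_fraction_)).  Splits scoring worse than the
  // best by more than epsilon are unlikely to catch up, so they are removed.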
  const float num_classes = params_.num_outputs();
  const float gini_diff_range = weight_sum_ * (1.0 - 1.0 / num_classes);
  float epsilon = gini_diff_range * sqrt(half_ln_dominate_frac_ / weight_sum_);
  for (int i = num_splits() - 1; i >= 0; i--) {
    if (split_scores[i] - best_split_score > epsilon) {
      RemoveSplit(i);
    }
  }
}

void ClassificationStats::CheckFinishEarly() {
  if (weight_sum_ < min_split_samples_ ||
      weight_sum_ < finish_sample_epoch_ * finish_check_every_) {
    return;
  }
  ++finish_sample_epoch_;

  if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_HOEFFDING) {
    CheckFinishEarlyHoeffding();
  } else if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_BOOTSTRAP) {
    CheckFinishEarlyBootstrap();
  }
}

void ClassificationStats::CheckFinishEarlyHoeffding() {
  // Each term in the Gini impurity can range from 0 to 0.5 * 0.5.
  float range = 0.25 * static_cast<float>(params_.num_outputs()) * weight_sum_;

  float hoeffding_bound =
      range * sqrt(log(1.0 / (1.0 - dominate_fraction_)) / (2.0 * weight_sum_));

  float unused_left_sum, unused_right_sum;
  std::function<float(int)> score_fn =
      std::bind(&ClassificationStats::MaybeCachedGiniScore, this,
                std::placeholders::_1, &unused_left_sum, &unused_right_sum);

  float best_score;
  int32 best_index;
  float second_best_score;
  int32 second_best_index;
  GetTwoBest(num_splits(), score_fn, &best_score, &best_index,
             &second_best_score, &second_best_index);

  finish_early_ = (second_best_score - best_score) > hoeffding_bound;
}

void ClassificationStats::MakeBootstrapWeights(int index,
                                               std::vector<float>* weights) {
  int n = weight_sum_;
  float denom = static_cast<float>(n) + static_cast<float>(num_outputs_);
  for (int i = 0; i < num_outputs_; ++i) {
    // Use the Laplace smoothed per-class probabilities when generating the
    // bootstrap samples.
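    // Laplace smoothing: each class gets one pseudo-count, so the weight is
    // (count + 1) / (n + num_outputs_) and unseen classes still get sampled.
    // Slots [0, num_outputs_) hold the left branch, the rest the right branch.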
    (*weights)[i] = (left_count(index, i) + 1.0) / denom;
    (*weights)[num_outputs_ + i] = (right_count(index, i) + 1.0) / denom;
  }
}

int ClassificationStats::NumBootstrapSamples() const {
  float p = 1.0 - dominate_fraction_;
  int bootstrap_samples = 1;
  while (p < 1.0) {
    ++bootstrap_samples;
    p = p * 2;
  }
  return bootstrap_samples;
}

void ClassificationStats::CheckFinishEarlyBootstrap() {
  float unused_left_sum, unused_right_sum;
  std::function<float(int)> score_fn =
      std::bind(&ClassificationStats::MaybeCachedGiniScore, this,
                std::placeholders::_1, &unused_left_sum, &unused_right_sum);

  float best_score;
  int32 best_index;
  float second_best_score;
  int32 second_best_index;
  GetTwoBest(num_splits(), score_fn, &best_score, &best_index,
             &second_best_score, &second_best_index);

  std::vector<float> weights1(num_outputs_ * 2);
  MakeBootstrapWeights(best_index, &weights1);
  random::DistributionSampler ds1(weights1);

  std::vector<float> weights2(num_outputs_ * 2);
  MakeBootstrapWeights(second_best_index, &weights2);
  random::DistributionSampler ds2(weights2);

  const int bootstrap_samples = NumBootstrapSamples();

  int worst_g1 = 0;
  for (int i = 0; i < bootstrap_samples; i++) {
    int g1 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds1, rng_.get());
    worst_g1 = std::max(worst_g1, g1);
  }

  int best_g2 = 99;
  for (int i = 0; i < bootstrap_samples; i++) {
    int g2 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds2, rng_.get());
    best_g2 = std::min(best_g2, g2);
  }

  finish_early_ = worst_g1 < best_g2;
}

bool ClassificationStats::BestSplit(SplitCandidate* best) const {
  float min_score = FLT_MAX;
  int best_index = -1;
  float best_left_sum, best_right_sum;

  // Calculate sums.
  for (int i = 0; i < num_splits(); ++i) {
    float left_sum, right_sum;
    const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum);
    // Find the lowest gini.
    if (left_sum > 0 && right_sum > 0 &&
        split_score < min_score) {  // useless check
      min_score = split_score;
      best_index = i;
      best_left_sum = left_sum;
      best_right_sum = right_sum;
    }
  }

  // This could happen if all the splits are useless.
  if (best_index < 0) {
    return false;
  }

  // Fill in stats to be used for leaf model.
  *best->mutable_split() = splits_[best_index];
  auto* left = best->mutable_left_stats();
  left->set_weight_sum(best_left_sum);
  auto* right = best->mutable_right_stats();
  right->set_weight_sum(best_right_sum);
  InitLeafClassStats(best_index, left, right);

  return true;
}

// ------------------------ Dense Classification --------------------------- //
void DenseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  const int32 num_classes = params_.num_outputs();
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
  const auto& class_stats =
      slot.post_init_leaf_stats().classification().dense_counts();

  // Total counts.
  for (int i = 0; i < num_classes; ++i) {
    total_counts_[i] = class_stats.value(i).float_value();
    num_outputs_seen_ += total_counts_[i] != 0;
  }

  // Candidate counts and splits.
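  // Only the left-branch counts are serialized for each candidate; the
  // right-branch counts are recovered as total_counts_ minus the left counts.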
  int split_num = 0;
  for (const auto& cand : slot.candidates()) {
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& left_stats = cand.left_stats().classification().dense_counts();
    for (int i = 0; i < num_classes; ++i) {
      const float val = left_stats.value(i).float_value();
      mutable_left_count(split_num, i) = val;
      MaybeInitializeRunningCount(split_num, val);
    }
    ++split_num;
  }
}

void DenseClassificationGrowStats::PackToProto(FertileSlot* slot) const {
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);

  auto* class_stats = slot->mutable_post_init_leaf_stats()
                          ->mutable_classification()
                          ->mutable_dense_counts();
  for (int i = 0; i < num_outputs_; ++i) {
    class_stats->add_value()->set_float_value(total_counts_[i]);
  }

  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* left_stats = cand->mutable_left_stats()
                           ->mutable_classification()
                           ->mutable_dense_counts();
    for (int i = 0; i < num_outputs_; ++i) {
      left_stats->add_value()->set_float_value(left_count(split_num, i));
    }
  }
}

float DenseClassificationGrowStats::GiniScore(int split, float* left_sum,
                                              float* right_sum) const {
  float left_square = 0, right_square = 0;
  *left_sum = 0;
  *right_sum = 0;
  for (int j = 0; j < num_outputs_; ++j) {
    const float left = left_count(split, j);
    *left_sum += left;
    left_square += left * left;
    const float right = right_count(split, j);
    *right_sum += right;
    right_square += right * right;
  }

  const float left_score =
      WeightedSmoothedGini(*left_sum, left_square, num_outputs_);
  const float right_score =
      WeightedSmoothedGini(*right_sum, right_square, num_outputs_);
  return left_score + right_score;
}

void DenseClassificationGrowStats::InitLeafClassStats(
    int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
  auto* left_class_stats = left_stats->mutable_classification();
  auto* left_counts = left_class_stats->mutable_dense_counts();
  for (int i = 0; i < params_.num_outputs(); ++i) {
    left_counts->add_value()->set_float_value(left_count(best_split_index, i));
  }

  auto* right_class_stats = right_stats->mutable_classification();
  auto* right_counts = right_class_stats->mutable_dense_counts();
  for (int i = 0; i < params_.num_outputs(); ++i) {
    right_counts->add_value()->set_float_value(total_counts_[i] -
                                               left_count(best_split_index, i));
  }
}

// ------------------------ Sparse Classification --------------------------- //
void SparseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) {
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
  const auto& class_stats =
      slot.post_init_leaf_stats().classification().sparse_counts();

  // Total counts.
  for (auto const& entry : class_stats.sparse_value()) {
    total_counts_[entry.first] = entry.second.float_value();
  }

  // Candidate counts and splits.
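  // Sparse counts are label -> weight maps, so only classes actually seen by
  // a candidate's left branch take up space here.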
  int split_num = 0;
  for (const auto& cand : slot.candidates()) {
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& left_stats =
        cand.left_stats().classification().sparse_counts();
    for (auto const& entry : left_stats.sparse_value()) {
      const float val = entry.second.float_value();
      left_counts_[split_num][entry.first] = val;
      MaybeInitializeRunningCount(split_num, val);
    }
    ++split_num;
  }
}

void SparseClassificationGrowStats::PackToProto(FertileSlot* slot) const {
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);

  auto* class_stats = slot->mutable_post_init_leaf_stats()
                          ->mutable_classification()
                          ->mutable_sparse_counts()
                          ->mutable_sparse_value();
  for (const auto& entry : total_counts_) {
    decision_trees::Value val;
    val.set_float_value(entry.second);
    (*class_stats)[entry.first] = val;
  }

  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* left_stats = cand->mutable_left_stats()
                           ->mutable_classification()
                           ->mutable_sparse_counts()
                           ->mutable_sparse_value();
    for (const auto& entry : left_counts_[split_num]) {
      decision_trees::Value val;
      val.set_float_value(entry.second);
      (*left_stats)[entry.first] = val;
    }
  }
}

float SparseClassificationGrowStats::GiniScore(int split, float* left_sum,
                                               float* right_sum) const {
  float left_square = 0, right_square = 0;
  *left_sum = 0;
  *right_sum = 0;
  for (const auto& entry : total_counts_) {
    const int label = entry.first;
    float left = 0;
    float right = 0;
    auto it = left_counts_[split].find(label);
    if (it == left_counts_[split].end()) {
      right = entry.second;
    } else {
      left = it->second;
      right = entry.second - it->second;
    }
    *left_sum += left;
    left_square += left * left;
    *right_sum += right;
    right_square += right * right;
  }
  const int32 num_classes = params_.num_outputs();
  const float left_score =
      WeightedSmoothedGini(*left_sum, left_square, num_classes);
  const float right_score =
      WeightedSmoothedGini(*right_sum, right_square, num_classes);
  return left_score + right_score;
}

void SparseClassificationGrowStats::InitLeafClassStats(
    int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
  auto* left_class_stats = left_stats->mutable_classification();
  auto* left_counts =
      left_class_stats->mutable_sparse_counts()->mutable_sparse_value();
  auto* right_class_stats = right_stats->mutable_classification();
  auto* right_counts =
      right_class_stats->mutable_sparse_counts()->mutable_sparse_value();

  for (const auto& entry : total_counts_) {
    auto it = left_counts_[best_split_index].find(entry.first);
    if (it == left_counts_[best_split_index].end()) {
      (*right_counts)[entry.first].set_float_value(entry.second);
    } else {
      const float left = it->second;
      const float right = entry.second - it->second;
      (*left_counts)[entry.first].set_float_value(left);
      if (right > 0) {
        (*right_counts)[entry.first].set_float_value(right);
      }
    }
  }
}

// -------------------- FixedSizeClassStats --------------------------------- //

// FixedSizeClassStats implements the "SpaceSaving" algorithm by
// Ahmed Metwally, Divyakant Agrawal and Amr El Abbadi.
// See for example
// https://pdfs.semanticscholar.org/72f1/5aba2e67b1cc9cd1fb12c99e101c4c1aae4b.pdf

int argmin(const std::unordered_map<int, float>& m) {
  int c = -1;
  float f = FLT_MAX;
  for (const auto it : m) {
    if (it.second < f) {
      f = it.second;
      c = it.first;
    }
  }
  return c;
}

void FixedSizeClassStats::accumulate(int c, float w) {
  auto it = class_weights_.find(c);
  if (it != class_weights_.end()) {
    it->second += w;
    if (c == smallest_weight_class_) {
      smallest_weight_class_ = argmin(class_weights_);
    }
    return;
  }

  if (class_weights_.size() < n_) {
    class_weights_.insert(it, std::pair<int, float>(c, w));
    if (class_weights_.size() == n_) {
      // Can't assume the last class added has the smallest weight, because the
      // w's might be all different.
      smallest_weight_class_ = argmin(class_weights_);
    }
    return;
  }

  // This is the slightly unintuitive heart of the SpaceSaving algorithm:
  // if the map is full and we see a new class, we find the entry with the
  // smallest weight and "take it over": we add our weight to its weight,
  // and assign it all to the newly seen class.
  it = class_weights_.find(smallest_weight_class_);
  float new_weight = it->second + w;
  class_weights_.erase(it);
  class_weights_[c] = new_weight;
  smallest_weight_class_ = argmin(class_weights_);
}

float FixedSizeClassStats::get_weight(int c) const {
  // Every entry in class_weights_ might be overstated by as much as the
  // smallest_weight.  We therefore assume that each has been overstated
  // by smallest_weight / 2.0, and we re-distribute that mass over all
  // num_classes_ classes.
  float smallest_weight = 0.0;
  auto it = class_weights_.find(smallest_weight_class_);
  if (it != class_weights_.end()) {
    smallest_weight = it->second;
  }
  float w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_);
  it = class_weights_.find(c);
  if (it != class_weights_.end()) {
    w += it->second - smallest_weight / 2.0;
  }
  return w;
}

void FixedSizeClassStats::set_sum_and_square(float* sum, float* square) const {
  *sum = 0.0;
  *square = 0.0;

  float smallest_weight = 0.0;
  auto it = class_weights_.find(smallest_weight_class_);
  if (it != class_weights_.end()) {
    smallest_weight = it->second;
  }

  float w;
  for (const auto it : class_weights_) {
    *sum += it.second;
    w = get_weight(it.first);
    *square += w * w;
  }

  w = (smallest_weight / 2.0) * n_ / static_cast<float>(num_classes_);
  *square += (num_classes_ - n_) * w * w;
}

void FixedSizeClassStats::ExtractFromProto(
    const decision_trees::SparseVector& sparse_vector) {
  for (const auto& it : sparse_vector.sparse_value()) {
    class_weights_[it.first] = it.second.float_value();
  }
  if (class_weights_.size() == n_) {
    smallest_weight_class_ = argmin(class_weights_);
  }
}

void FixedSizeClassStats::PackToProto(
    decision_trees::SparseVector* sparse_vector) const {
  for (const auto it : class_weights_) {
    (*sparse_vector->mutable_sparse_value())[it.first].set_float_value(
        it.second);
  }
}

// --------------------- FixedSizeSparseClassificationGrowStats ------------- //

void FixedSizeSparseClassificationGrowStats::ExtractFromProto(
    const FertileSlot& slot) {
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  weight_sum_ =
      slot.post_init_leaf_stats().weight_sum();

  // Candidate counts and splits.
  int split_num = 0;
  left_counts_.clear();
  right_counts_.clear();
  for (const auto& cand : slot.candidates()) {
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& left_stats =
        cand.left_stats().classification().sparse_counts();
    left_counts_.emplace_back(params_.num_classes_to_track(),
                              params_.num_outputs());
    left_counts_[split_num].ExtractFromProto(left_stats);
    const auto& right_stats =
        cand.right_stats().classification().sparse_counts();
    right_counts_.emplace_back(params_.num_classes_to_track(),
                               params_.num_outputs());
    right_counts_[split_num].ExtractFromProto(right_stats);
    ++split_num;
  }
}

void FixedSizeSparseClassificationGrowStats::PackToProto(
    FertileSlot* slot) const {
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);

  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* left_stats = cand->mutable_left_stats()
                           ->mutable_classification()
                           ->mutable_sparse_counts();
    left_counts_[split_num].PackToProto(left_stats);
    auto* right_stats = cand->mutable_right_stats()
                            ->mutable_classification()
                            ->mutable_sparse_counts();
    right_counts_[split_num].PackToProto(right_stats);
  }
}

float FixedSizeSparseClassificationGrowStats::GiniScore(
    int split, float* left_sum, float* right_sum) const {
  float left_square, right_square;
  left_counts_[split].set_sum_and_square(left_sum, &left_square);
  right_counts_[split].set_sum_and_square(right_sum, &right_square);
  const int32 num_classes = params_.num_outputs();
  const float left_score =
      WeightedSmoothedGini(*left_sum, left_square, num_classes);
  const float right_score =
      WeightedSmoothedGini(*right_sum, right_square, num_classes);
  return left_score + right_score;
}

void FixedSizeSparseClassificationGrowStats::InitLeafClassStats(
    int best_split_index, LeafStat* left_stats, LeafStat* right_stats) const {
  auto* left_class_stats = left_stats->mutable_classification();
  auto* left_counts = left_class_stats->mutable_sparse_counts();
  left_counts_[best_split_index].PackToProto(left_counts);

  auto* right_class_stats = right_stats->mutable_classification();
  auto* right_counts = right_class_stats->mutable_sparse_counts();
  right_counts_[best_split_index].PackToProto(right_counts);
}

// --------------------- Least Squares Regression --------------------------- //
void LeastSquaresRegressionGrowStats::ExtractFromProto(
    const FertileSlot& slot) {
  const int32 num_outputs = params_.num_outputs();
  Initialize();
  if (!slot.has_post_init_leaf_stats()) {
    return;
  }
  weight_sum_ = slot.post_init_leaf_stats().weight_sum();
  const auto& total_sums =
      slot.post_init_leaf_stats().regression().mean_output();
  const auto& total_squares =
      slot.post_init_leaf_stats().regression().mean_output_squares();

  // Total counts.
  for (int i = 0; i < num_outputs; ++i) {
    total_sum_[i] = total_sums.value(i).float_value();
    total_sum_squares_[i] = total_squares.value(i).float_value();
  }

  // Candidate counts and splits.
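  // Each candidate stores per-output sums and sums of squares for its left
  // branch along with a left example count; right-branch statistics are
  // derived from the node totals when scoring.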
  int split_num = 0;
  for (const auto& cand : slot.candidates()) {
    AddSplit(cand.split(), nullptr, nullptr, -1);
    const auto& sums = cand.left_stats().regression().mean_output();
    const auto& squares = cand.left_stats().regression().mean_output_squares();
    for (int i = 0; i < num_outputs; ++i) {
      left_sum(split_num, i) = sums.value(i).float_value();
      left_square(split_num, i) = squares.value(i).float_value();
    }
    left_counts_[split_num] = cand.left_stats().weight_sum();
    ++split_num;
  }
}

void LeastSquaresRegressionGrowStats::PackToProto(FertileSlot* slot) const {
  const int32 num_outputs = params_.num_outputs();
  auto* slot_stats = slot->mutable_post_init_leaf_stats();
  slot_stats->set_weight_sum(weight_sum_);

  auto* total_sums = slot->mutable_post_init_leaf_stats()
                         ->mutable_regression()
                         ->mutable_mean_output();
  auto* total_squares = slot->mutable_post_init_leaf_stats()
                            ->mutable_regression()
                            ->mutable_mean_output_squares();

  for (int i = 0; i < total_sum_.size(); ++i) {
    total_sums->add_value()->set_float_value(total_sum_[i]);
    total_squares->add_value()->set_float_value(total_sum_squares_[i]);
  }

  for (int split_num = 0; split_num < num_splits(); ++split_num) {
    auto* cand = slot->add_candidates();
    *cand->mutable_split() = splits_[split_num];
    auto* sums =
        cand->mutable_left_stats()->mutable_regression()->mutable_mean_output();
    auto* squares = cand->mutable_left_stats()
                        ->mutable_regression()
                        ->mutable_mean_output_squares();
    for (int i = 0; i < num_outputs; ++i) {
      sums->add_value()->set_float_value(left_sum(split_num, i));
      squares->add_value()->set_float_value(left_square(split_num, i));
    }
    cand->mutable_left_stats()->set_weight_sum(left_counts_[split_num]);
  }
}

void LeastSquaresRegressionGrowStats::AddExample(
    const std::unique_ptr<TensorDataSet>& input_data,
    const InputTarget* target, int example) {
  const int32 num_outputs = params_.num_outputs();
  // Update splits.
  for (int i = 0; i < num_splits(); ++i) {
    auto& eval = evaluators_[i];
    if (eval->Decide(input_data, example) == LEFT_INDEX) {
      for (int j = 0; j < num_outputs; ++j) {
        const float output = target->GetTargetAsContinuous(example, j);
        left_sum(i, j) += output;
        left_square(i, j) += output * output;
      }
      ++left_counts_[i];
    }
  }

  // Update totals.
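  // Node totals accumulate the per-output sum and sum of squares so that
  // SplitVariance() can compute E[x^2] - E[x]^2 for each side of a split.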
  for (int i = 0; i < num_outputs; ++i) {
    const float output = target->GetTargetAsContinuous(example, i);
    total_sum_[i] += output;
    total_sum_squares_[i] += output * output;
  }
  weight_sum_ += 1.0;
}

float LeastSquaresRegressionGrowStats::SplitVariance(int split) const {
  float total_variance = 0;
  for (int i = 0; i < params_.num_outputs(); ++i) {
    // Left side.
    const float le_x = left_sum(split, i) / left_counts_[split];

    const float le_x2 = left_square(split, i) / left_counts_[split];
    total_variance += le_x2 - le_x * le_x;

    // Right side.
    const float re_x = (total_sum_[i] - left_sum(split, i)) /
                       (weight_sum_ - left_counts_[split]);

    const float re_x2 = (total_sum_squares_[i] - left_square(split, i)) /
                        (weight_sum_ - left_counts_[split]);
    total_variance += re_x2 - re_x * re_x;
  }
  return total_variance;
}

bool LeastSquaresRegressionGrowStats::BestSplit(SplitCandidate* best) const {
  float min_score = FLT_MAX;
  int best_index = -1;
  const int32 num_outputs = params_.num_outputs();
  for (int i = 0; i < num_splits(); ++i) {
    if (left_counts_[i] > 0 && weight_sum_ - left_counts_[i] > 0) {
      const float split_score = SplitVariance(i);
      if (split_score < min_score) {
        min_score = split_score;
        best_index = i;
      }
    }
  }

  // This could happen if all the splits are useless.
  if (best_index < 0) {
    return false;
  }

  // Fill in stats to be used for leaf model.
  *best->mutable_split() = splits_[best_index];
  // Left.
  auto* left = best->mutable_left_stats();
  auto* left_reg_stats = left->mutable_regression();
  left->set_weight_sum(left_counts_[best_index]);
  auto* left_output_sum = left_reg_stats->mutable_mean_output();
  for (int i = 0; i < num_outputs; ++i) {
    left_output_sum->add_value()->set_float_value(left_sum(best_index, i));
  }

  // Right.
  auto* right = best->mutable_right_stats();
  auto* right_reg_stats = right->mutable_regression();
  right->set_weight_sum(weight_sum_ - left_counts_[best_index]);
  auto* right_output_sum = right_reg_stats->mutable_mean_output();
  for (int i = 0; i < num_outputs; ++i) {
    right_output_sum->add_value()->set_float_value(total_sum_[i] -
                                                   left_sum(best_index, i));
  }
  return true;
}

bool LeastSquaresRegressionGrowStats::IsFinished() const {
  return weight_sum_ >= split_after_samples_;
}

}  // namespace tensorforest
}  // namespace tensorflow