12bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
22bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower//
32bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// Licensed under the Apache License, Version 2.0 (the "License");
42bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// you may not use this file except in compliance with the License.
52bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// You may obtain a copy of the License at
62bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower//
72bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower//     http://www.apache.org/licenses/LICENSE-2.0
82bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower//
92bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// Unless required by applicable law or agreed to in writing, software
102bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// distributed under the License is distributed on an "AS IS" BASIS,
112bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
122bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// See the License for the specific language governing permissions and
132bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// limitations under the License.
142bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// =============================================================================
152bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"
162bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
172bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include <sys/types.h>
182bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include <algorithm>
192bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include <cstdlib>
202bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include <functional>
212bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include <iterator>
222bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include <unordered_set>
232bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include <utility>
242bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
252bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"  // NOLINT
262bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include "tensorflow/core/lib/core/status_test_util.h"
27f92eef788dff0e629cb4c408ce4b00530f152d4fA. Unique TensorFlower#include "tensorflow/core/platform/env.h"
282bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
294463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlowerusing std::unordered_set;
302bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowerusing tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig;
312bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowerusing tensorflow::boosted_trees::trees::DecisionTreeEnsembleConfig;
322bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
332bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowernamespace tensorflow {
342bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowernamespace boosted_trees {
352bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowernamespace utils {
362bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowernamespace {
372bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
382bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowerconst uint32 kSeed = 123;
392bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowerconst int32 kNumTrees = 1000;
402bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
412bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowerclass DropoutUtilsTest : public ::testing::Test {
422bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower public:
432bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  void SetUp() override {
442bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    // Fill an weights.
452bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    for (int i = 0; i < kNumTrees; ++i) {
462bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      weights_.push_back(1.1 + 0.4 * i);
472bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    }
482bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  }
492bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
502bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower protected:
512bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  std::vector<float> weights_;
522bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower};
532bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
542bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowerTEST_F(DropoutUtilsTest, DropoutProbabilityTest) {
552bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  std::vector<int32> dropped_trees;
562bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  std::vector<float> original_weights;
579de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  std::unordered_set<int32> trees_not_to_drop;
582bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
592bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  // Do not drop any trees
602bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  {
612bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    LearningRateDropoutDrivenConfig config;
622bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    config.set_dropout_probability(0.0);
632bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    config.set_learning_rate(1.0);
642bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
659de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    TF_EXPECT_OK(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop,
669de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            weights_, &dropped_trees,
679de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            &original_weights));
682bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
692bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    // Nothing changed
702bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    EXPECT_TRUE(dropped_trees.empty());
712bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    EXPECT_TRUE(original_weights.empty());
722bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  }
732bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  // Drop out all trees
742bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  {
752bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    LearningRateDropoutDrivenConfig config;
762bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    config.set_dropout_probability(1.0);
772bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    config.set_learning_rate(1.0);
782bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
799de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    TF_EXPECT_OK(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop,
809de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            weights_, &dropped_trees,
819de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            &original_weights));
822bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
832bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    // No trees left
842bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    EXPECT_EQ(kNumTrees, dropped_trees.size());
852bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    EXPECT_EQ(kNumTrees, original_weights.size());
862bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    EXPECT_EQ(original_weights, weights_);
872bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  }
882bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  // 50% probability of dropping a tree
892bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  {
902bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    const int32 kNumRuns = 1000;
912bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    LearningRateDropoutDrivenConfig config;
922bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    config.set_dropout_probability(0.5);
932bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    config.set_learning_rate(1.0);
942bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
952bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    int32 total_num_trees = 0;
962bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    for (int i = 0; i < kNumRuns; ++i) {
972bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      // draw random seeds
98f92eef788dff0e629cb4c408ce4b00530f152d4fA. Unique TensorFlower      uint random_generator_seed =
99f92eef788dff0e629cb4c408ce4b00530f152d4fA. Unique TensorFlower          static_cast<uint>(Env::Default()->NowMicros());
1002bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      uint32 seed = rand_r(&random_generator_seed) % 100 + i;
1019de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower      TF_EXPECT_OK(DropoutUtils::DropOutTrees(seed, config, trees_not_to_drop,
1029de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                              weights_, &dropped_trees,
1039de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                              &original_weights));
1042bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
1052bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      // We would expect 400-600 trees left
1062bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      EXPECT_NEAR(500, kNumTrees - dropped_trees.size(), 100);
1072bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      total_num_trees += kNumTrees - dropped_trees.size();
1082bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
1092bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      // Trees dropped are unique
1102bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      unordered_set<int32> ids;
1112bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      for (const auto& tree : dropped_trees) {
1122bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower        ids.insert(tree);
1132bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      }
1142bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      EXPECT_EQ(ids.size(), dropped_trees.size());
1152bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    }
1162bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    EXPECT_NEAR(500, total_num_trees / kNumRuns, 5);
1172bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  }
1182bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower}
1192bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
1209de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlowerTEST_F(DropoutUtilsTest, DropoutIgnoresNotToDropTest) {
1219de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  std::vector<int32> dropped_trees;
1229de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  std::vector<float> original_weights;
1239de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower
1249de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  // Empty do not drop set.
1259de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  {
1269de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    std::unordered_set<int32> trees_not_to_drop;
1279de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower
1289de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    LearningRateDropoutDrivenConfig config;
1299de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    config.set_dropout_probability(1.0);
1309de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    config.set_learning_rate(1.0);
1319de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower
1329de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    TF_EXPECT_OK(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop,
1339de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            weights_, &dropped_trees,
1349de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            &original_weights));
1359de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower
1369de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    // No trees left
1379de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_EQ(kNumTrees, dropped_trees.size());
1389de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_EQ(kNumTrees, original_weights.size());
1399de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_EQ(original_weights, weights_);
1409de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  }
1419de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower
1429de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  // Do not drop any trees
1439de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  {
1449de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    std::unordered_set<int32> trees_not_to_drop;
1459de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    for (int i = 0; i < kNumTrees; ++i) {
1469de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower      trees_not_to_drop.insert(i);
1479de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    }
1489de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower
1499de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    LearningRateDropoutDrivenConfig config;
1509de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    config.set_dropout_probability(1.0);
1519de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    config.set_learning_rate(1.0);
1529de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower
1539de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    TF_EXPECT_OK(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop,
1549de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            weights_, &dropped_trees,
1559de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            &original_weights));
1569de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower
1579de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    // No trees were dropped - they all were in do not drop set.
1589de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_EQ(0, dropped_trees.size());
1599de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_EQ(0, original_weights.size());
1609de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  }
1619de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  // Do not drop some trees
1629de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  {
1639de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    std::unordered_set<int32> trees_not_to_drop;
1649de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    trees_not_to_drop.insert(0);
1659de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    trees_not_to_drop.insert(34);
1669de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower
1679de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    LearningRateDropoutDrivenConfig config;
1689de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    config.set_dropout_probability(1.0);
1699de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    config.set_learning_rate(1.0);
1709de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower
1719de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    TF_EXPECT_OK(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop,
1729de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            weights_, &dropped_trees,
1739de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            &original_weights));
1749de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower
1759de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    // No trees were dropped - they all were in do not drop set.
1769de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_EQ(kNumTrees - 2, dropped_trees.size());
1779de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_EQ(kNumTrees - 2, original_weights.size());
1789de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_TRUE(std::find(dropped_trees.begin(), dropped_trees.end(), 0) ==
1799de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                dropped_trees.end());
1809de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_TRUE(std::find(dropped_trees.begin(), dropped_trees.end(), 34) ==
1819de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                dropped_trees.end());
1829de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  }
1839de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower}
1849de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower
1852bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowerTEST_F(DropoutUtilsTest, DropoutSeedTest) {
1869de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  std::unordered_set<int32> trees_not_to_drop;
1872bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  // Different seeds remove different trees
1882bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  {
1892bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    LearningRateDropoutDrivenConfig config;
1902bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    config.set_dropout_probability(0.5);
1912bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    config.set_learning_rate(1.0);
1922bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
1932bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    std::vector<int32> dropped_trees_1;
1942bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    std::vector<float> original_weights_1;
1952bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    std::vector<int32> dropped_trees_2;
1962bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    std::vector<float> original_weights_2;
1972bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
1982bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    DecisionTreeEnsembleConfig new_ensemble_1;
1992bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    DecisionTreeEnsembleConfig new_ensemble_2;
2002bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
2012bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    TF_EXPECT_OK(DropoutUtils::DropOutTrees(
2029de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower        kSeed + 1, config, trees_not_to_drop, weights_, &dropped_trees_1,
2039de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower        &original_weights_1));
2042bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    TF_EXPECT_OK(DropoutUtils::DropOutTrees(
2059de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower        kSeed + 2, config, trees_not_to_drop, weights_, &dropped_trees_2,
2069de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower        &original_weights_2));
2072bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
2082bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    EXPECT_FALSE(dropped_trees_1 == dropped_trees_2);
2092bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    EXPECT_FALSE(original_weights_1 == original_weights_2);
2102bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  }
2112bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  //  The same seed produces the same result
2122bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  {
2132bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    LearningRateDropoutDrivenConfig config;
2142bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    config.set_dropout_probability(0.5);
2152bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    config.set_learning_rate(1.0);
2162bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
2172bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    std::vector<int32> dropped_trees_1;
2182bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    std::vector<float> original_weights_1;
2192bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    std::vector<int32> dropped_trees_2;
2202bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    std::vector<float> original_weights_2;
2212bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
2222bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    DecisionTreeEnsembleConfig new_ensemble_1;
2232bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    DecisionTreeEnsembleConfig new_ensemble_2;
2242bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
2259de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    TF_EXPECT_OK(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop,
2269de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            weights_, &dropped_trees_1,
2279de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            &original_weights_1));
2289de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    TF_EXPECT_OK(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop,
2299de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            weights_, &dropped_trees_2,
2309de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            &original_weights_2));
2312bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
2322bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    EXPECT_TRUE(dropped_trees_1 == dropped_trees_2);
2332bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    EXPECT_TRUE(original_weights_1 == original_weights_2);
2342bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  }
2352bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower}
2362bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
2372bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowerTEST_F(DropoutUtilsTest, InvalidConfigTest) {
2382bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  std::vector<int32> dropped_trees;
2392bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  std::vector<float> original_weights;
2409de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  std::unordered_set<int32> trees_not_to_drop;
2412bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  // Negative prob
2422bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  {
2432bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    LearningRateDropoutDrivenConfig config;
2442bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    config.set_dropout_probability(-1.34);
2452bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
2469de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_FALSE(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop,
2479de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            weights_, &dropped_trees,
2489de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            &original_weights)
2492bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower                     .ok());
2502bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  }
2512bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  // Larger than 1 prob of dropping a tree.
2522bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  {
2532bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    LearningRateDropoutDrivenConfig config;
2542bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    config.set_dropout_probability(1.34);
2552bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
2569de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_FALSE(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop,
2579de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            weights_, &dropped_trees,
2589de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            &original_weights)
2592bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower                     .ok());
2602bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  }
2612bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  // Negative probability of skipping dropout.
2622bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  {
2632bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    LearningRateDropoutDrivenConfig config;
2642bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    config.set_dropout_probability(0.5);
2652bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    config.set_probability_of_skipping_dropout(-10);
2662bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
2672bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    DecisionTreeEnsembleConfig new_ensemble;
2689de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_FALSE(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop,
2699de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            weights_, &dropped_trees,
2709de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            &original_weights)
2712bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower                     .ok());
2722bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  }
2732bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  // Larger than 1 probability of skipping dropout.
2742bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  {
2752bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    LearningRateDropoutDrivenConfig config;
2762bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    config.set_dropout_probability(0.5);
2772bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    config.set_probability_of_skipping_dropout(1.2);
2782bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
2792bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    DecisionTreeEnsembleConfig new_ensemble;
2809de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_FALSE(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop,
2819de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            weights_, &dropped_trees,
2829de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower                                            &original_weights)
2832bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower                     .ok());
2842bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  }
2852bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower}
2862bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowernamespace {
2872bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
2882bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowervoid ExpectVecsEquiv(const std::vector<float>& vec1,
2892bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower                     const std::vector<float>& vec2) {
2902bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  EXPECT_EQ(vec1.size(), vec2.size());
2912bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  for (int i = 0; i < vec1.size(); ++i) {
2922bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    EXPECT_NEAR(vec1[i], vec2[i], 1e-3);
2932bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  }
2942bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower}
2952bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
// Gathers weights[k] for every k in `indices`, preserving the order of
// `indices`. Indices are not bounds-checked; callers must pass valid ones.
std::vector<float> GetWeightsByIndex(const std::vector<float>& weights,
                                     const std::vector<int>& indices) {
  std::vector<float> selected;
  selected.reserve(indices.size());
  std::transform(indices.begin(), indices.end(), std::back_inserter(selected),
                 [&weights](const int idx) { return weights[idx]; });
  return selected;
}
3052bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
3062bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowervoid MergeLastElements(const int32 last_n, std::vector<float>* weights) {
3072bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  float sum = 0.0;
3082bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  for (int i = 0; i < last_n; ++i) {
3092bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    sum += weights->back();
3102bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    weights->pop_back();
3112bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  }
3122bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  weights->push_back(sum);
3132bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower}
3142bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
3152bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower}  // namespace
3162bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
3172bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowerTEST_F(DropoutUtilsTest, GetTreesWeightsForAddingTreesTest) {
3182bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  // Adding trees should give the same res in any order
3192bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  {
3202bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    std::vector<float> weights = {1.0, 1.0, 1.0, 1.0, 1.0};
3212bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    std::vector<int32> dropped_1 = {0, 3};
3222bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
3232bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    std::vector<int32> dropped_2 = {0};
3242bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
3252bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    std::vector<float> res_1;
3262bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    std::vector<float> res_2;
3272bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    // Do one order
3282bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    {
3292bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      std::vector<float> current_weights = weights;
3302bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      std::vector<int32> num_updates =
3312bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower          std::vector<int32>(current_weights.size(), 1);
3322bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      DropoutUtils::GetTreesWeightsForAddingTrees(
3339de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower          dropped_1, GetWeightsByIndex(current_weights, dropped_1),
3349de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower          current_weights.size(), 1, &current_weights, &num_updates);
3352bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      DropoutUtils::GetTreesWeightsForAddingTrees(
3369de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower          dropped_2, GetWeightsByIndex(current_weights, dropped_2),
3379de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower          current_weights.size(), 1, &current_weights, &num_updates);
3382bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      res_1 = current_weights;
3392bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    }
3402bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    // Do another order
3412bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    {
3422bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      std::vector<float> current_weights = weights;
3432bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      std::vector<int32> num_updates =
3442bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower          std::vector<int32>(current_weights.size(), 1);
3452bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
3462bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      DropoutUtils::GetTreesWeightsForAddingTrees(
3479de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower          dropped_2, GetWeightsByIndex(current_weights, dropped_2),
3489de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower          current_weights.size(), 1, &current_weights, &num_updates);
3492bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      DropoutUtils::GetTreesWeightsForAddingTrees(
3509de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower          dropped_1, GetWeightsByIndex(current_weights, dropped_1),
3519de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower          current_weights.size(), 1, &current_weights, &num_updates);
3522bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      res_2 = current_weights;
3532bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    }
3542bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    // The vectors are the same, but the last two elements have the same sum.
3552bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    EXPECT_EQ(res_1.size(), 7);
3562bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    EXPECT_EQ(res_2.size(), 7);
3572bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
3582bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    MergeLastElements(2, &res_1);
3592bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    MergeLastElements(2, &res_2);
3602bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
3612bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    EXPECT_EQ(res_1, res_2);
3622bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  }
3632bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  // Now when the weights are not all 1s
3642bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  {
3652bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    std::vector<float> weights = {1.1, 2.1, 3.1, 4.1, 5.1};
3662bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    std::vector<int32> dropped_1 = {0, 3};
3672bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
3682bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    std::vector<int32> dropped_2 = {0};
3692bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
3702bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    std::vector<float> res_1;
3712bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    std::vector<float> res_2;
3722bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    // Do one order
3732bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    {
3742bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      std::vector<float> current_weights = weights;
3752bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      std::vector<int32> num_updates =
3762bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower          std::vector<int32>(current_weights.size(), 1);
3772bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      DropoutUtils::GetTreesWeightsForAddingTrees(
3789de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower          dropped_1, GetWeightsByIndex(current_weights, dropped_1),
3799de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower          current_weights.size(), 1, &current_weights, &num_updates);
3802bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      DropoutUtils::GetTreesWeightsForAddingTrees(
3819de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower          dropped_2, GetWeightsByIndex(current_weights, dropped_2),
3829de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower          current_weights.size(), 1, &current_weights, &num_updates);
3832bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      res_1 = current_weights;
3842bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    }
3852bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    // Do another order
3862bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    {
3872bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      std::vector<float> current_weights = weights;
3882bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      std::vector<int32> num_updates =
3892bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower          std::vector<int32>(current_weights.size(), 1);
3902bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      DropoutUtils::GetTreesWeightsForAddingTrees(
3919de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower          dropped_2, GetWeightsByIndex(current_weights, dropped_2),
3929de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower          current_weights.size(), 1, &current_weights, &num_updates);
3932bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      DropoutUtils::GetTreesWeightsForAddingTrees(
3949de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower          dropped_1, GetWeightsByIndex(current_weights, dropped_1),
3959de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower          current_weights.size(), 1, &current_weights, &num_updates);
3962bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower      res_2 = current_weights;
3972bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    }
3982bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    EXPECT_EQ(res_1.size(), 7);
3992bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    EXPECT_EQ(res_2.size(), 7);
4002bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
4012bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    // The vectors are the same, but the last two elements have the same sum.
4022bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    MergeLastElements(2, &res_1);
4032bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    MergeLastElements(2, &res_2);
4042bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
4052bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower    ExpectVecsEquiv(res_1, res_2);
4062bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower  }
4072bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower}
4082bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower
4099de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlowerTEST_F(DropoutUtilsTest, GetTreesWeightsForAddingTreesIndexTest) {
4109de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  std::vector<float> weights = {1.0, 1.0, 1.0, 1.0, 1.0};
4119de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  std::vector<int32> dropped = {0, 3};
4129de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower
4139de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  std::vector<float> res;
4149de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  std::vector<float> res_2;
4159de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower
4169de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  // The tree that is added does not yet have an entry in weights vector.
4179de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  {
4189de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    std::vector<float> current_weights = weights;
4199de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    std::vector<int32> num_updates =
4209de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower        std::vector<int32>(current_weights.size(), 1);
4219de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    DropoutUtils::GetTreesWeightsForAddingTrees(
4229de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower        dropped, GetWeightsByIndex(current_weights, dropped),
4239de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower        current_weights.size(), 1, &current_weights, &num_updates);
4249de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_EQ(current_weights.size(), weights.size() + 1);
4259de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_EQ(num_updates.size(), weights.size() + 1);
4269de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower
4279de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    std::vector<int32> expected_num_updates = {2, 1, 1, 2, 1, 1};
4289de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    std::vector<float> expected_weights = {2.0 / 3, 1, 1, 2.0 / 3, 1, 2.0 / 3};
4299de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_EQ(expected_weights, current_weights);
4309de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_EQ(expected_num_updates, num_updates);
4319de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  }
4329de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  // The tree that is added has already an entry in weights and updates (batch
4339de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  // mode).
4349de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  {
4359de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    std::vector<float> current_weights = weights;
4369de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    std::vector<int32> num_updates =
4379de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower        std::vector<int32>(current_weights.size(), 1);
4389de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    DropoutUtils::GetTreesWeightsForAddingTrees(
4399de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower        dropped, GetWeightsByIndex(current_weights, dropped),
4409de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower        current_weights.size() - 1, 1, &current_weights, &num_updates);
4419de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_EQ(current_weights.size(), weights.size());
4429de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_EQ(num_updates.size(), weights.size());
4439de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower
4449de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    std::vector<int32> expected_num_updates = {2, 1, 1, 2, 2};
4459de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    std::vector<float> expected_weights = {2.0 / 3, 1, 1, 2.0 / 3, 2.0 / 3};
4469de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_EQ(expected_weights, current_weights);
4479de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower    EXPECT_EQ(expected_num_updates, num_updates);
4489de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower  }
4499de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower}
4509de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower
4512bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower}  // namespace
4522bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower}  // namespace utils
4532bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower}  // namespace boosted_trees
4542bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower}  // namespace tensorflow
455