12bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 22bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// 32bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// Licensed under the Apache License, Version 2.0 (the "License"); 42bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// you may not use this file except in compliance with the License. 52bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// You may obtain a copy of the License at 62bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// 72bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// http://www.apache.org/licenses/LICENSE-2.0 82bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// 92bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// Unless required by applicable law or agreed to in writing, software 102bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// distributed under the License is distributed on an "AS IS" BASIS, 112bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 122bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// See the License for the specific language governing permissions and 132bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// limitations under the License. 142bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower// ============================================================================= 152bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h" 162bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 172bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include <sys/types.h> 182bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include <algorithm> 192bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include <cstdlib> 202bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include <functional> 212bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include <iterator> 222bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include <unordered_set> 232bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include <utility> 242bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 252bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" // NOLINT 262bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower#include "tensorflow/core/lib/core/status_test_util.h" 27f92eef788dff0e629cb4c408ce4b00530f152d4fA. Unique TensorFlower#include "tensorflow/core/platform/env.h" 282bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 294463d105a8a4a83642b9709ba79310e8f4ddf577A. Unique TensorFlowerusing std::unordered_set; 302bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowerusing tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig; 312bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowerusing tensorflow::boosted_trees::trees::DecisionTreeEnsembleConfig; 322bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 332bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowernamespace tensorflow { 342bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowernamespace boosted_trees { 352bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowernamespace utils { 362bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowernamespace { 372bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 382bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowerconst uint32 kSeed = 123; 392bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowerconst int32 kNumTrees = 1000; 402bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 412bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowerclass DropoutUtilsTest : public ::testing::Test { 422bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower public: 432bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower void SetUp() override { 442bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // Fill an weights. 452bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower for (int i = 0; i < kNumTrees; ++i) { 462bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower weights_.push_back(1.1 + 0.4 * i); 472bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 482bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 492bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 502bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower protected: 512bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<float> weights_; 522bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower}; 532bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 542bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowerTEST_F(DropoutUtilsTest, DropoutProbabilityTest) { 552bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<int32> dropped_trees; 562bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<float> original_weights; 579de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::unordered_set<int32> trees_not_to_drop; 582bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 592bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // Do not drop any trees 602bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower { 612bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower LearningRateDropoutDrivenConfig config; 622bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower config.set_dropout_probability(0.0); 632bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower config.set_learning_rate(1.0); 642bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 659de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower TF_EXPECT_OK(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop, 669de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower weights_, &dropped_trees, 679de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower &original_weights)); 682bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 692bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // Nothing changed 702bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower EXPECT_TRUE(dropped_trees.empty()); 712bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower EXPECT_TRUE(original_weights.empty()); 722bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 732bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // Drop out all trees 742bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower { 752bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower LearningRateDropoutDrivenConfig config; 762bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower config.set_dropout_probability(1.0); 772bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower config.set_learning_rate(1.0); 782bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 799de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower TF_EXPECT_OK(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop, 809de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower weights_, &dropped_trees, 819de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower &original_weights)); 822bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 832bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // No trees left 842bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower EXPECT_EQ(kNumTrees, dropped_trees.size()); 852bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower EXPECT_EQ(kNumTrees, original_weights.size()); 862bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower EXPECT_EQ(original_weights, weights_); 872bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 882bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // 50% probability of dropping a tree 892bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower { 902bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower const int32 kNumRuns = 1000; 912bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower LearningRateDropoutDrivenConfig config; 922bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower config.set_dropout_probability(0.5); 932bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower config.set_learning_rate(1.0); 942bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 952bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower int32 total_num_trees = 0; 962bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower for (int i = 0; i < kNumRuns; ++i) { 972bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // draw random seeds 98f92eef788dff0e629cb4c408ce4b00530f152d4fA. Unique TensorFlower uint random_generator_seed = 99f92eef788dff0e629cb4c408ce4b00530f152d4fA. Unique TensorFlower static_cast<uint>(Env::Default()->NowMicros()); 1002bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower uint32 seed = rand_r(&random_generator_seed) % 100 + i; 1019de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower TF_EXPECT_OK(DropoutUtils::DropOutTrees(seed, config, trees_not_to_drop, 1029de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower weights_, &dropped_trees, 1039de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower &original_weights)); 1042bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 1052bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // We would expect 400-600 trees left 1062bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower EXPECT_NEAR(500, kNumTrees - dropped_trees.size(), 100); 1072bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower total_num_trees += kNumTrees - dropped_trees.size(); 1082bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 1092bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // Trees dropped are unique 1102bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower unordered_set<int32> ids; 1112bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower for (const auto& tree : dropped_trees) { 1122bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower ids.insert(tree); 1132bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 1142bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower EXPECT_EQ(ids.size(), dropped_trees.size()); 1152bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 1162bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower EXPECT_NEAR(500, total_num_trees / kNumRuns, 5); 1172bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 1182bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower} 1192bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 1209de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlowerTEST_F(DropoutUtilsTest, DropoutIgnoresNotToDropTest) { 1219de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::vector<int32> dropped_trees; 1229de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::vector<float> original_weights; 1239de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower 1249de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower // Empty do not drop set. 1259de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower { 1269de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::unordered_set<int32> trees_not_to_drop; 1279de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower 1289de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower LearningRateDropoutDrivenConfig config; 1299de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower config.set_dropout_probability(1.0); 1309de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower config.set_learning_rate(1.0); 1319de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower 1329de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower TF_EXPECT_OK(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop, 1339de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower weights_, &dropped_trees, 1349de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower &original_weights)); 1359de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower 1369de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower // No trees left 1379de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_EQ(kNumTrees, dropped_trees.size()); 1389de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_EQ(kNumTrees, original_weights.size()); 1399de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_EQ(original_weights, weights_); 1409de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower } 1419de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower 1429de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower // Do not drop any trees 1439de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower { 1449de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::unordered_set<int32> trees_not_to_drop; 1459de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower for (int i = 0; i < kNumTrees; ++i) { 1469de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower trees_not_to_drop.insert(i); 1479de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower } 1489de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower 1499de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower LearningRateDropoutDrivenConfig config; 1509de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower config.set_dropout_probability(1.0); 1519de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower config.set_learning_rate(1.0); 1529de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower 1539de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower TF_EXPECT_OK(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop, 1549de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower weights_, &dropped_trees, 1559de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower &original_weights)); 1569de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower 1579de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower // No trees were dropped - they all were in do not drop set. 1589de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_EQ(0, dropped_trees.size()); 1599de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_EQ(0, original_weights.size()); 1609de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower } 1619de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower // Do not drop some trees 1629de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower { 1639de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::unordered_set<int32> trees_not_to_drop; 1649de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower trees_not_to_drop.insert(0); 1659de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower trees_not_to_drop.insert(34); 1669de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower 1679de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower LearningRateDropoutDrivenConfig config; 1689de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower config.set_dropout_probability(1.0); 1699de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower config.set_learning_rate(1.0); 1709de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower 1719de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower TF_EXPECT_OK(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop, 1729de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower weights_, &dropped_trees, 1739de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower &original_weights)); 1749de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower 1759de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower // No trees were dropped - they all were in do not drop set. 1769de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_EQ(kNumTrees - 2, dropped_trees.size()); 1779de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_EQ(kNumTrees - 2, original_weights.size()); 1789de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_TRUE(std::find(dropped_trees.begin(), dropped_trees.end(), 0) == 1799de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower dropped_trees.end()); 1809de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_TRUE(std::find(dropped_trees.begin(), dropped_trees.end(), 34) == 1819de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower dropped_trees.end()); 1829de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower } 1839de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower} 1849de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower 1852bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowerTEST_F(DropoutUtilsTest, DropoutSeedTest) { 1869de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::unordered_set<int32> trees_not_to_drop; 1872bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // Different seeds remove different trees 1882bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower { 1892bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower LearningRateDropoutDrivenConfig config; 1902bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower config.set_dropout_probability(0.5); 1912bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower config.set_learning_rate(1.0); 1922bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 1932bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<int32> dropped_trees_1; 1942bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<float> original_weights_1; 1952bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<int32> dropped_trees_2; 1962bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<float> original_weights_2; 1972bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 1982bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower DecisionTreeEnsembleConfig new_ensemble_1; 1992bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower DecisionTreeEnsembleConfig new_ensemble_2; 2002bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 2012bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower TF_EXPECT_OK(DropoutUtils::DropOutTrees( 2029de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower kSeed + 1, config, trees_not_to_drop, weights_, &dropped_trees_1, 2039de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower &original_weights_1)); 2042bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower TF_EXPECT_OK(DropoutUtils::DropOutTrees( 2059de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower kSeed + 2, config, trees_not_to_drop, weights_, &dropped_trees_2, 2069de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower &original_weights_2)); 2072bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 2082bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower EXPECT_FALSE(dropped_trees_1 == dropped_trees_2); 2092bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower EXPECT_FALSE(original_weights_1 == original_weights_2); 2102bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 2112bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // The same seed produces the same result 2122bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower { 2132bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower LearningRateDropoutDrivenConfig config; 2142bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower config.set_dropout_probability(0.5); 2152bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower config.set_learning_rate(1.0); 2162bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 2172bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<int32> dropped_trees_1; 2182bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<float> original_weights_1; 2192bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<int32> dropped_trees_2; 2202bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<float> original_weights_2; 2212bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 2222bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower DecisionTreeEnsembleConfig new_ensemble_1; 2232bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower DecisionTreeEnsembleConfig new_ensemble_2; 2242bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 2259de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower TF_EXPECT_OK(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop, 2269de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower weights_, &dropped_trees_1, 2279de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower &original_weights_1)); 2289de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower TF_EXPECT_OK(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop, 2299de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower weights_, &dropped_trees_2, 2309de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower &original_weights_2)); 2312bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 2322bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower EXPECT_TRUE(dropped_trees_1 == dropped_trees_2); 2332bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower EXPECT_TRUE(original_weights_1 == original_weights_2); 2342bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 2352bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower} 2362bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 2372bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowerTEST_F(DropoutUtilsTest, InvalidConfigTest) { 2382bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<int32> dropped_trees; 2392bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<float> original_weights; 2409de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::unordered_set<int32> trees_not_to_drop; 2412bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // Negative prob 2422bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower { 2432bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower LearningRateDropoutDrivenConfig config; 2442bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower config.set_dropout_probability(-1.34); 2452bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 2469de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_FALSE(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop, 2479de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower weights_, &dropped_trees, 2489de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower &original_weights) 2492bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower .ok()); 2502bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 2512bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // Larger than 1 prob of dropping a tree. 2522bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower { 2532bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower LearningRateDropoutDrivenConfig config; 2542bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower config.set_dropout_probability(1.34); 2552bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 2569de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_FALSE(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop, 2579de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower weights_, &dropped_trees, 2589de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower &original_weights) 2592bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower .ok()); 2602bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 2612bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // Negative probability of skipping dropout. 2622bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower { 2632bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower LearningRateDropoutDrivenConfig config; 2642bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower config.set_dropout_probability(0.5); 2652bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower config.set_probability_of_skipping_dropout(-10); 2662bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 2672bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower DecisionTreeEnsembleConfig new_ensemble; 2689de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_FALSE(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop, 2699de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower weights_, &dropped_trees, 2709de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower &original_weights) 2712bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower .ok()); 2722bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 2732bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // Larger than 1 probability of skipping dropout. 2742bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower { 2752bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower LearningRateDropoutDrivenConfig config; 2762bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower config.set_dropout_probability(0.5); 2772bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower config.set_probability_of_skipping_dropout(1.2); 2782bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 2792bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower DecisionTreeEnsembleConfig new_ensemble; 2809de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_FALSE(DropoutUtils::DropOutTrees(kSeed, config, trees_not_to_drop, 2819de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower weights_, &dropped_trees, 2829de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower &original_weights) 2832bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower .ok()); 2842bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 2852bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower} 2862bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowernamespace { 2872bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 2882bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowervoid ExpectVecsEquiv(const std::vector<float>& vec1, 2892bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower const std::vector<float>& vec2) { 2902bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower EXPECT_EQ(vec1.size(), vec2.size()); 2912bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower for (int i = 0; i < vec1.size(); ++i) { 2922bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower EXPECT_NEAR(vec1[i], vec2[i], 1e-3); 2932bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 2942bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower} 2952bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 2962bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowerstd::vector<float> GetWeightsByIndex(const std::vector<float>& weights, 2972bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower const std::vector<int>& indices) { 2982bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<float> res; 299eb10a4c494d95e7c17ddc44ef35197d08f2f6b33A. Unique TensorFlower res.reserve(indices.size()); 3002bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower for (const int index : indices) { 3012bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower res.push_back(weights[index]); 3022bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 3032bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower return res; 3042bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower} 3052bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 3062bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowervoid MergeLastElements(const int32 last_n, std::vector<float>* weights) { 3072bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower float sum = 0.0; 3082bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower for (int i = 0; i < last_n; ++i) { 3092bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower sum += weights->back(); 3102bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower weights->pop_back(); 3112bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 3122bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower weights->push_back(sum); 3132bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower} 3142bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 3152bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower} // namespace 3162bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 3172bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlowerTEST_F(DropoutUtilsTest, GetTreesWeightsForAddingTreesTest) { 3182bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // Adding trees should give the same res in any order 3192bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower { 3202bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<float> weights = {1.0, 1.0, 1.0, 1.0, 1.0}; 3212bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<int32> dropped_1 = {0, 3}; 3222bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 3232bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<int32> dropped_2 = {0}; 3242bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 3252bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<float> res_1; 3262bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<float> res_2; 3272bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // Do one order 3282bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower { 3292bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<float> current_weights = weights; 3302bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<int32> num_updates = 3312bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<int32>(current_weights.size(), 1); 3322bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower DropoutUtils::GetTreesWeightsForAddingTrees( 3339de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower dropped_1, GetWeightsByIndex(current_weights, dropped_1), 3349de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower current_weights.size(), 1, ¤t_weights, &num_updates); 3352bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower DropoutUtils::GetTreesWeightsForAddingTrees( 3369de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower dropped_2, GetWeightsByIndex(current_weights, dropped_2), 3379de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower current_weights.size(), 1, ¤t_weights, &num_updates); 3382bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower res_1 = current_weights; 3392bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 3402bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // Do another order 3412bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower { 3422bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<float> current_weights = weights; 3432bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<int32> num_updates = 3442bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<int32>(current_weights.size(), 1); 3452bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 3462bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower DropoutUtils::GetTreesWeightsForAddingTrees( 3479de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower dropped_2, GetWeightsByIndex(current_weights, dropped_2), 3489de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower current_weights.size(), 1, ¤t_weights, &num_updates); 3492bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower DropoutUtils::GetTreesWeightsForAddingTrees( 3509de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower dropped_1, GetWeightsByIndex(current_weights, dropped_1), 3519de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower current_weights.size(), 1, ¤t_weights, &num_updates); 3522bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower res_2 = current_weights; 3532bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 3542bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // The vectors are the same, but the last two elements have the same sum. 3552bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower EXPECT_EQ(res_1.size(), 7); 3562bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower EXPECT_EQ(res_2.size(), 7); 3572bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 3582bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower MergeLastElements(2, &res_1); 3592bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower MergeLastElements(2, &res_2); 3602bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 3612bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower EXPECT_EQ(res_1, res_2); 3622bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 3632bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // Now when the weights are not all 1s 3642bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower { 3652bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<float> weights = {1.1, 2.1, 3.1, 4.1, 5.1}; 3662bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<int32> dropped_1 = {0, 3}; 3672bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 3682bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<int32> dropped_2 = {0}; 3692bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 3702bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<float> res_1; 3712bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<float> res_2; 3722bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // Do one order 3732bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower { 3742bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<float> current_weights = weights; 3752bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<int32> num_updates = 3762bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<int32>(current_weights.size(), 1); 3772bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower DropoutUtils::GetTreesWeightsForAddingTrees( 3789de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower dropped_1, GetWeightsByIndex(current_weights, dropped_1), 3799de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower current_weights.size(), 1, ¤t_weights, &num_updates); 3802bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower DropoutUtils::GetTreesWeightsForAddingTrees( 3819de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower dropped_2, GetWeightsByIndex(current_weights, dropped_2), 3829de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower current_weights.size(), 1, ¤t_weights, &num_updates); 3832bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower res_1 = current_weights; 3842bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 3852bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // Do another order 3862bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower { 3872bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<float> current_weights = weights; 3882bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<int32> num_updates = 3892bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower std::vector<int32>(current_weights.size(), 1); 3902bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower DropoutUtils::GetTreesWeightsForAddingTrees( 3919de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower dropped_2, GetWeightsByIndex(current_weights, dropped_2), 3929de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower current_weights.size(), 1, ¤t_weights, &num_updates); 3932bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower DropoutUtils::GetTreesWeightsForAddingTrees( 3949de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower dropped_1, GetWeightsByIndex(current_weights, dropped_1), 3959de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower current_weights.size(), 1, ¤t_weights, &num_updates); 3962bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower res_2 = current_weights; 3972bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 3982bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower EXPECT_EQ(res_1.size(), 7); 3992bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower EXPECT_EQ(res_2.size(), 7); 4002bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 4012bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower // The vectors are the same, but the last two elements have the same sum. 4022bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower MergeLastElements(2, &res_1); 4032bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower MergeLastElements(2, &res_2); 4042bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 4052bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower ExpectVecsEquiv(res_1, res_2); 4062bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower } 4072bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower} 4082bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower 4099de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlowerTEST_F(DropoutUtilsTest, GetTreesWeightsForAddingTreesIndexTest) { 4109de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::vector<float> weights = {1.0, 1.0, 1.0, 1.0, 1.0}; 4119de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::vector<int32> dropped = {0, 3}; 4129de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower 4139de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::vector<float> res; 4149de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::vector<float> res_2; 4159de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower 4169de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower // The tree that is added does not yet have an entry in weights vector. 4179de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower { 4189de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::vector<float> current_weights = weights; 4199de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::vector<int32> num_updates = 4209de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::vector<int32>(current_weights.size(), 1); 4219de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower DropoutUtils::GetTreesWeightsForAddingTrees( 4229de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower dropped, GetWeightsByIndex(current_weights, dropped), 4239de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower current_weights.size(), 1, ¤t_weights, &num_updates); 4249de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_EQ(current_weights.size(), weights.size() + 1); 4259de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_EQ(num_updates.size(), weights.size() + 1); 4269de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower 4279de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::vector<int32> expected_num_updates = {2, 1, 1, 2, 1, 1}; 4289de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::vector<float> expected_weights = {2.0 / 3, 1, 1, 2.0 / 3, 1, 2.0 / 3}; 4299de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_EQ(expected_weights, current_weights); 4309de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_EQ(expected_num_updates, num_updates); 4319de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower } 4329de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower // The tree that is added has already an entry in weights and updates (batch 4339de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower // mode). 4349de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower { 4359de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::vector<float> current_weights = weights; 4369de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::vector<int32> num_updates = 4379de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::vector<int32>(current_weights.size(), 1); 4389de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower DropoutUtils::GetTreesWeightsForAddingTrees( 4399de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower dropped, GetWeightsByIndex(current_weights, dropped), 4409de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower current_weights.size() - 1, 1, ¤t_weights, &num_updates); 4419de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_EQ(current_weights.size(), weights.size()); 4429de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_EQ(num_updates.size(), weights.size()); 4439de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower 4449de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::vector<int32> expected_num_updates = {2, 1, 1, 2, 2}; 4459de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower std::vector<float> expected_weights = {2.0 / 3, 1, 1, 2.0 / 3, 2.0 / 3}; 4469de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_EQ(expected_weights, current_weights); 4479de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower EXPECT_EQ(expected_num_updates, num_updates); 4489de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower } 4499de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower} 4509de2794fca8d207f7054c024eba2bc54ae5d38a1A. Unique TensorFlower 4512bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower} // namespace 4522bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower} // namespace utils 4532bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower} // namespace boosted_trees 4542bb0a625cd684f587daadae0252a78da3f14f4f9A. Unique TensorFlower} // namespace tensorflow 455