1// Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h" 16 17#include "tensorflow/core/lib/random/philox_random.h" 18#include "tensorflow/core/lib/random/simple_philox.h" 19#include "tensorflow/core/platform/test.h" 20#include "tensorflow/core/platform/test_benchmark.h" 21 22namespace tensorflow { 23namespace { 24 25using Buffer = boosted_trees::quantiles::WeightedQuantilesBuffer<float, float>; 26using BufferEntry = 27 boosted_trees::quantiles::WeightedQuantilesBuffer<float, 28 float>::BufferEntry; 29using Summary = 30 boosted_trees::quantiles::WeightedQuantilesSummary<float, float>; 31using SummaryEntry = 32 boosted_trees::quantiles::WeightedQuantilesSummary<float, 33 float>::SummaryEntry; 34 35class WeightedQuantilesSummaryTest : public ::testing::Test { 36 protected: 37 void SetUp() override { 38 // Constructs a buffer of 10 weighted unique entries. 39 buffer1_.reset(new Buffer(10, 1000)); 40 buffer1_->PushEntry(5, 9); 41 buffer1_->PushEntry(2, 3); 42 buffer1_->PushEntry(-1, 7); 43 buffer1_->PushEntry(-7, 1); 44 buffer1_->PushEntry(3, 2); 45 buffer1_->PushEntry(-2, 3); 46 buffer1_->PushEntry(21, 8); 47 buffer1_->PushEntry(-13, 4); 48 buffer1_->PushEntry(8, 2); 49 buffer1_->PushEntry(-5, 6); 50 51 // Constructs a buffer of 7 weighted unique entries. 52 buffer2_.reset(new Buffer(7, 1000)); 53 buffer2_->PushEntry(9, 2); 54 buffer2_->PushEntry(-7, 3); 55 buffer2_->PushEntry(2, 1); 56 buffer2_->PushEntry(4, 13); 57 buffer2_->PushEntry(0, 5); 58 buffer2_->PushEntry(-5, 3); 59 buffer2_->PushEntry(11, 3); 60 } 61 62 void TearDown() override { buffer1_->Clear(); } 63 64 std::unique_ptr<Buffer> buffer1_; 65 std::unique_ptr<Buffer> buffer2_; 66 const double buffer1_min_value_ = -13; 67 const double buffer1_max_value_ = 21; 68 const double buffer1_total_weight_ = 45; 69 const double buffer2_min_value_ = -7; 70 const double buffer2_max_value_ = 11; 71 const double buffer2_total_weight_ = 30; 72}; 73 74TEST_F(WeightedQuantilesSummaryTest, BuildFromBuffer) { 75 Summary summary; 76 summary.BuildFromBufferEntries(buffer1_->GenerateEntryList()); 77 78 // We expect no approximation error because no compress operation occurred. 79 EXPECT_EQ(summary.ApproximationError(), 0); 80 81 // Check first and last elements in the summary. 82 const auto& entries = summary.GetEntryList(); 83 // First element's rmin should be zero. 84 EXPECT_EQ(summary.MinValue(), buffer1_min_value_); 85 EXPECT_EQ(entries.front(), SummaryEntry(-13, 4, 0, 4)); 86 // Last element's rmax should be cumulative weight. 87 EXPECT_EQ(summary.MaxValue(), buffer1_max_value_); 88 EXPECT_EQ(entries.back(), SummaryEntry(21, 8, 37, 45)); 89 // Check total weight. 90 EXPECT_EQ(summary.TotalWeight(), buffer1_total_weight_); 91} 92 93TEST_F(WeightedQuantilesSummaryTest, CompressSeparately) { 94 const auto entry_list = buffer1_->GenerateEntryList(); 95 for (int new_size = 9; new_size >= 2; --new_size) { 96 Summary summary; 97 summary.BuildFromBufferEntries(entry_list); 98 summary.Compress(new_size); 99 100 // Expect a max approximation error of 1 / n 101 // ie. eps0 + 1/n but eps0 = 0. 102 EXPECT_TRUE(summary.Size() >= new_size && summary.Size() <= new_size + 2); 103 EXPECT_LE(summary.ApproximationError(), 1.0 / new_size); 104 105 // Min/Max elements and total weight should not change. 106 EXPECT_EQ(summary.MinValue(), buffer1_min_value_); 107 EXPECT_EQ(summary.MaxValue(), buffer1_max_value_); 108 EXPECT_EQ(summary.TotalWeight(), buffer1_total_weight_); 109 } 110} 111 112TEST_F(WeightedQuantilesSummaryTest, CompressSequentially) { 113 Summary summary; 114 summary.BuildFromBufferEntries(buffer1_->GenerateEntryList()); 115 for (int new_size = 9; new_size >= 2; new_size -= 2) { 116 double prev_eps = summary.ApproximationError(); 117 summary.Compress(new_size); 118 119 // Expect a max approximation error of prev_eps + 1 / n. 120 EXPECT_TRUE(summary.Size() >= new_size && summary.Size() <= new_size + 2); 121 EXPECT_LE(summary.ApproximationError(), prev_eps + 1.0 / new_size); 122 123 // Min/Max elements and total weight should not change. 124 EXPECT_EQ(summary.MinValue(), buffer1_min_value_); 125 EXPECT_EQ(summary.MaxValue(), buffer1_max_value_); 126 EXPECT_EQ(summary.TotalWeight(), buffer1_total_weight_); 127 } 128} 129 130TEST_F(WeightedQuantilesSummaryTest, CompressRandomized) { 131 // Check multiple size compressions and ensure approximation bounds 132 // are always respected. 133 int prev_size = 1; 134 int size = 2; 135 float max_value = 1 << 20; 136 while (size < (1 << 16)) { 137 // Create buffer of size from uniform random elements. 138 Buffer buffer(size, size << 4); 139 random::PhiloxRandom philox(13); 140 random::SimplePhilox rand(&philox); 141 for (int i = 0; i < size; ++i) { 142 buffer.PushEntry(rand.RandFloat() * max_value, 143 rand.RandFloat() * max_value); 144 } 145 146 // Create summary and compress. 147 Summary summary; 148 summary.BuildFromBufferEntries(buffer.GenerateEntryList()); 149 int new_size = std::max(rand.Uniform(size), 2u); 150 summary.Compress(new_size); 151 152 // Ensure approximation error is acceptable. 153 EXPECT_TRUE(summary.Size() >= new_size && summary.Size() <= new_size + 2); 154 EXPECT_LE(summary.ApproximationError(), 1.0 / new_size); 155 156 // Update size to next fib number. 157 size_t last_size = size; 158 size += prev_size; 159 prev_size = last_size; 160 } 161} 162 163TEST_F(WeightedQuantilesSummaryTest, MergeSymmetry) { 164 // Create two separate summaries and merge. 165 const auto list_1 = buffer1_->GenerateEntryList(); 166 const auto list_2 = buffer2_->GenerateEntryList(); 167 Summary summary1; 168 summary1.BuildFromBufferEntries(list_1); 169 Summary summary2; 170 summary2.BuildFromBufferEntries(list_2); 171 172 // Merge summary 2 into 1 and verify. 173 summary1.Merge(summary2); 174 EXPECT_EQ(summary1.ApproximationError(), 0.0); 175 EXPECT_EQ(summary1.MinValue(), 176 std::min(buffer1_min_value_, buffer2_min_value_)); 177 EXPECT_EQ(summary1.MaxValue(), 178 std::max(buffer1_max_value_, buffer2_max_value_)); 179 EXPECT_EQ(summary1.TotalWeight(), 180 buffer1_total_weight_ + buffer2_total_weight_); 181 EXPECT_EQ(summary1.Size(), 14); // 14 unique values. 182 183 // Merge summary 1 into 2 and verify same result. 184 summary1.BuildFromBufferEntries(list_1); 185 summary2.Merge(summary1); 186 EXPECT_EQ(summary2.ApproximationError(), 0.0); 187 EXPECT_EQ(summary2.MinValue(), 188 std::min(buffer1_min_value_, buffer2_min_value_)); 189 EXPECT_EQ(summary2.MaxValue(), 190 std::max(buffer1_max_value_, buffer2_max_value_)); 191 EXPECT_EQ(summary2.TotalWeight(), 192 buffer1_total_weight_ + buffer2_total_weight_); 193 EXPECT_EQ(summary2.Size(), 14); // 14 unique values. 194} 195 196TEST_F(WeightedQuantilesSummaryTest, CompressThenMerge) { 197 // Create two separate summaries and merge. 198 Summary summary1; 199 summary1.BuildFromBufferEntries(buffer1_->GenerateEntryList()); 200 Summary summary2; 201 summary2.BuildFromBufferEntries(buffer2_->GenerateEntryList()); 202 203 // Compress summaries. 204 summary1.Compress(5); // max error is 1/5. 205 const auto eps1 = 1.0 / 5; 206 EXPECT_LE(summary1.ApproximationError(), eps1); 207 summary2.Compress(3); // max error is 1/3. 208 const auto eps2 = 1.0 / 3; 209 EXPECT_LE(summary2.ApproximationError(), eps2); 210 211 // Merge guarantees an approximation error of max(eps1, eps2). 212 // Merge summary 2 into 1 and verify. 213 summary1.Merge(summary2); 214 EXPECT_LE(summary1.ApproximationError(), std::max(eps1, eps2)); 215 EXPECT_EQ(summary1.MinValue(), 216 std::min(buffer1_min_value_, buffer2_min_value_)); 217 EXPECT_EQ(summary1.MaxValue(), 218 std::max(buffer1_max_value_, buffer2_max_value_)); 219 EXPECT_EQ(summary1.TotalWeight(), 220 buffer1_total_weight_ + buffer2_total_weight_); 221} 222 223} // namespace 224} // namespace tensorflow 225