1// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h"
16
17#include "tensorflow/core/lib/random/philox_random.h"
18#include "tensorflow/core/lib/random/simple_philox.h"
19#include "tensorflow/core/platform/test.h"
20#include "tensorflow/core/platform/test_benchmark.h"
21
22namespace tensorflow {
23namespace {
24
25using Buffer = boosted_trees::quantiles::WeightedQuantilesBuffer<float, float>;
26using BufferEntry =
27    boosted_trees::quantiles::WeightedQuantilesBuffer<float,
28                                                      float>::BufferEntry;
29using Summary =
30    boosted_trees::quantiles::WeightedQuantilesSummary<float, float>;
31using SummaryEntry =
32    boosted_trees::quantiles::WeightedQuantilesSummary<float,
33                                                       float>::SummaryEntry;
34
35class WeightedQuantilesSummaryTest : public ::testing::Test {
36 protected:
37  void SetUp() override {
38    // Constructs a buffer of 10 weighted unique entries.
39    buffer1_.reset(new Buffer(10, 1000));
40    buffer1_->PushEntry(5, 9);
41    buffer1_->PushEntry(2, 3);
42    buffer1_->PushEntry(-1, 7);
43    buffer1_->PushEntry(-7, 1);
44    buffer1_->PushEntry(3, 2);
45    buffer1_->PushEntry(-2, 3);
46    buffer1_->PushEntry(21, 8);
47    buffer1_->PushEntry(-13, 4);
48    buffer1_->PushEntry(8, 2);
49    buffer1_->PushEntry(-5, 6);
50
51    // Constructs a buffer of 7 weighted unique entries.
52    buffer2_.reset(new Buffer(7, 1000));
53    buffer2_->PushEntry(9, 2);
54    buffer2_->PushEntry(-7, 3);
55    buffer2_->PushEntry(2, 1);
56    buffer2_->PushEntry(4, 13);
57    buffer2_->PushEntry(0, 5);
58    buffer2_->PushEntry(-5, 3);
59    buffer2_->PushEntry(11, 3);
60  }
61
62  void TearDown() override { buffer1_->Clear(); }
63
64  std::unique_ptr<Buffer> buffer1_;
65  std::unique_ptr<Buffer> buffer2_;
66  const double buffer1_min_value_ = -13;
67  const double buffer1_max_value_ = 21;
68  const double buffer1_total_weight_ = 45;
69  const double buffer2_min_value_ = -7;
70  const double buffer2_max_value_ = 11;
71  const double buffer2_total_weight_ = 30;
72};
73
74TEST_F(WeightedQuantilesSummaryTest, BuildFromBuffer) {
75  Summary summary;
76  summary.BuildFromBufferEntries(buffer1_->GenerateEntryList());
77
78  // We expect no approximation error because no compress operation occurred.
79  EXPECT_EQ(summary.ApproximationError(), 0);
80
81  // Check first and last elements in the summary.
82  const auto& entries = summary.GetEntryList();
83  // First element's rmin should be zero.
84  EXPECT_EQ(summary.MinValue(), buffer1_min_value_);
85  EXPECT_EQ(entries.front(), SummaryEntry(-13, 4, 0, 4));
86  // Last element's rmax should be cumulative weight.
87  EXPECT_EQ(summary.MaxValue(), buffer1_max_value_);
88  EXPECT_EQ(entries.back(), SummaryEntry(21, 8, 37, 45));
89  // Check total weight.
90  EXPECT_EQ(summary.TotalWeight(), buffer1_total_weight_);
91}
92
93TEST_F(WeightedQuantilesSummaryTest, CompressSeparately) {
94  const auto entry_list = buffer1_->GenerateEntryList();
95  for (int new_size = 9; new_size >= 2; --new_size) {
96    Summary summary;
97    summary.BuildFromBufferEntries(entry_list);
98    summary.Compress(new_size);
99
100    // Expect a max approximation error of 1 / n
101    // ie. eps0 + 1/n but eps0 = 0.
102    EXPECT_TRUE(summary.Size() >= new_size && summary.Size() <= new_size + 2);
103    EXPECT_LE(summary.ApproximationError(), 1.0 / new_size);
104
105    // Min/Max elements and total weight should not change.
106    EXPECT_EQ(summary.MinValue(), buffer1_min_value_);
107    EXPECT_EQ(summary.MaxValue(), buffer1_max_value_);
108    EXPECT_EQ(summary.TotalWeight(), buffer1_total_weight_);
109  }
110}
111
112TEST_F(WeightedQuantilesSummaryTest, CompressSequentially) {
113  Summary summary;
114  summary.BuildFromBufferEntries(buffer1_->GenerateEntryList());
115  for (int new_size = 9; new_size >= 2; new_size -= 2) {
116    double prev_eps = summary.ApproximationError();
117    summary.Compress(new_size);
118
119    // Expect a max approximation error of prev_eps + 1 / n.
120    EXPECT_TRUE(summary.Size() >= new_size && summary.Size() <= new_size + 2);
121    EXPECT_LE(summary.ApproximationError(), prev_eps + 1.0 / new_size);
122
123    // Min/Max elements and total weight should not change.
124    EXPECT_EQ(summary.MinValue(), buffer1_min_value_);
125    EXPECT_EQ(summary.MaxValue(), buffer1_max_value_);
126    EXPECT_EQ(summary.TotalWeight(), buffer1_total_weight_);
127  }
128}
129
130TEST_F(WeightedQuantilesSummaryTest, CompressRandomized) {
131  // Check multiple size compressions and ensure approximation bounds
132  // are always respected.
133  int prev_size = 1;
134  int size = 2;
135  float max_value = 1 << 20;
136  while (size < (1 << 16)) {
137    // Create buffer of size from uniform random elements.
138    Buffer buffer(size, size << 4);
139    random::PhiloxRandom philox(13);
140    random::SimplePhilox rand(&philox);
141    for (int i = 0; i < size; ++i) {
142      buffer.PushEntry(rand.RandFloat() * max_value,
143                       rand.RandFloat() * max_value);
144    }
145
146    // Create summary and compress.
147    Summary summary;
148    summary.BuildFromBufferEntries(buffer.GenerateEntryList());
149    int new_size = std::max(rand.Uniform(size), 2u);
150    summary.Compress(new_size);
151
152    // Ensure approximation error is acceptable.
153    EXPECT_TRUE(summary.Size() >= new_size && summary.Size() <= new_size + 2);
154    EXPECT_LE(summary.ApproximationError(), 1.0 / new_size);
155
156    // Update size to next fib number.
157    size_t last_size = size;
158    size += prev_size;
159    prev_size = last_size;
160  }
161}
162
163TEST_F(WeightedQuantilesSummaryTest, MergeSymmetry) {
164  // Create two separate summaries and merge.
165  const auto list_1 = buffer1_->GenerateEntryList();
166  const auto list_2 = buffer2_->GenerateEntryList();
167  Summary summary1;
168  summary1.BuildFromBufferEntries(list_1);
169  Summary summary2;
170  summary2.BuildFromBufferEntries(list_2);
171
172  // Merge summary 2 into 1 and verify.
173  summary1.Merge(summary2);
174  EXPECT_EQ(summary1.ApproximationError(), 0.0);
175  EXPECT_EQ(summary1.MinValue(),
176            std::min(buffer1_min_value_, buffer2_min_value_));
177  EXPECT_EQ(summary1.MaxValue(),
178            std::max(buffer1_max_value_, buffer2_max_value_));
179  EXPECT_EQ(summary1.TotalWeight(),
180            buffer1_total_weight_ + buffer2_total_weight_);
181  EXPECT_EQ(summary1.Size(), 14);  // 14 unique values.
182
183  // Merge summary 1 into 2 and verify same result.
184  summary1.BuildFromBufferEntries(list_1);
185  summary2.Merge(summary1);
186  EXPECT_EQ(summary2.ApproximationError(), 0.0);
187  EXPECT_EQ(summary2.MinValue(),
188            std::min(buffer1_min_value_, buffer2_min_value_));
189  EXPECT_EQ(summary2.MaxValue(),
190            std::max(buffer1_max_value_, buffer2_max_value_));
191  EXPECT_EQ(summary2.TotalWeight(),
192            buffer1_total_weight_ + buffer2_total_weight_);
193  EXPECT_EQ(summary2.Size(), 14);  // 14 unique values.
194}
195
196TEST_F(WeightedQuantilesSummaryTest, CompressThenMerge) {
197  // Create two separate summaries and merge.
198  Summary summary1;
199  summary1.BuildFromBufferEntries(buffer1_->GenerateEntryList());
200  Summary summary2;
201  summary2.BuildFromBufferEntries(buffer2_->GenerateEntryList());
202
203  // Compress summaries.
204  summary1.Compress(5);  // max error is 1/5.
205  const auto eps1 = 1.0 / 5;
206  EXPECT_LE(summary1.ApproximationError(), eps1);
207  summary2.Compress(3);  // max error is 1/3.
208  const auto eps2 = 1.0 / 3;
209  EXPECT_LE(summary2.ApproximationError(), eps2);
210
211  // Merge guarantees an approximation error of max(eps1, eps2).
212  // Merge summary 2 into 1 and verify.
213  summary1.Merge(summary2);
214  EXPECT_LE(summary1.ApproximationError(), std::max(eps1, eps2));
215  EXPECT_EQ(summary1.MinValue(),
216            std::min(buffer1_min_value_, buffer2_min_value_));
217  EXPECT_EQ(summary1.MaxValue(),
218            std::max(buffer1_max_value_, buffer2_max_value_));
219  EXPECT_EQ(summary1.TotalWeight(),
220            buffer1_total_weight_ + buffer2_total_weight_);
221}
222
223}  // namespace
224}  // namespace tensorflow
225