1// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h"
16
17#include "tensorflow/core/lib/random/philox_random.h"
18#include "tensorflow/core/lib/random/simple_philox.h"
19#include "tensorflow/core/platform/test.h"
20#include "tensorflow/core/platform/test_benchmark.h"
21
22namespace tensorflow {
23namespace {
24
25using Buffer =
26    boosted_trees::quantiles::WeightedQuantilesBuffer<double, double>;
27using BufferEntry =
28    boosted_trees::quantiles::WeightedQuantilesBuffer<double,
29                                                      double>::BufferEntry;
30
31class WeightedQuantilesBufferTest : public ::testing::Test {};
32
33TEST_F(WeightedQuantilesBufferTest, Invalid) {
34  EXPECT_DEATH(
35      ({
36        boosted_trees::quantiles::WeightedQuantilesBuffer<double, double>
37            buffer(2, 0);
38      }),
39      "Invalid buffer specification");
40  EXPECT_DEATH(
41      ({
42        boosted_trees::quantiles::WeightedQuantilesBuffer<double, double>
43            buffer(0, 2);
44      }),
45      "Invalid buffer specification");
46}
47
48TEST_F(WeightedQuantilesBufferTest, PushEntryNotFull) {
49  Buffer buffer(20, 100);
50  buffer.PushEntry(5, 9);
51  buffer.PushEntry(2, 3);
52  buffer.PushEntry(-1, 7);
53  buffer.PushEntry(3, 0);  // This entry will be ignored.
54
55  EXPECT_FALSE(buffer.IsFull());
56  EXPECT_EQ(buffer.Size(), 3);
57}
58
59TEST_F(WeightedQuantilesBufferTest, PushEntryFull) {
60  // buffer capacity is 4.
61  Buffer buffer(2, 100);
62  buffer.PushEntry(5, 9);
63  buffer.PushEntry(2, 3);
64  buffer.PushEntry(-1, 7);
65  buffer.PushEntry(2, 1);
66
67  std::vector<BufferEntry> expected;
68  expected.emplace_back(-1, 7);
69  expected.emplace_back(2, 4);
70  expected.emplace_back(5, 9);
71
72  // At this point, we have pushed 4 entries and we expect the buffer to be
73  // full.
74  EXPECT_TRUE(buffer.IsFull());
75  EXPECT_EQ(buffer.GenerateEntryList(), expected);
76  EXPECT_FALSE(buffer.IsFull());
77}
78
79TEST_F(WeightedQuantilesBufferTest, PushEntryFullDeath) {
80  // buffer capacity is 4.
81  Buffer buffer(2, 100);
82  buffer.PushEntry(5, 9);
83  buffer.PushEntry(2, 3);
84  buffer.PushEntry(-1, 7);
85  buffer.PushEntry(2, 1);
86
87  std::vector<BufferEntry> expected;
88  expected.emplace_back(-1, 7);
89  expected.emplace_back(2, 4);
90  expected.emplace_back(5, 9);
91
92  // At this point, we have pushed 4 entries and we expect the buffer to be
93  // full.
94  EXPECT_TRUE(buffer.IsFull());
95  // Can't push any more entries before clearing.
96  EXPECT_DEATH(({ buffer.PushEntry(6, 6); }), "Buffer already full");
97}
98
99}  // namespace
100}  // namespace tensorflow
101