1// Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h" 16 17#include "tensorflow/core/lib/random/philox_random.h" 18#include "tensorflow/core/lib/random/simple_philox.h" 19#include "tensorflow/core/platform/test.h" 20#include "tensorflow/core/platform/test_benchmark.h" 21 22namespace tensorflow { 23namespace { 24 25using Buffer = 26 boosted_trees::quantiles::WeightedQuantilesBuffer<double, double>; 27using BufferEntry = 28 boosted_trees::quantiles::WeightedQuantilesBuffer<double, 29 double>::BufferEntry; 30 31class WeightedQuantilesBufferTest : public ::testing::Test {}; 32 33TEST_F(WeightedQuantilesBufferTest, Invalid) { 34 EXPECT_DEATH( 35 ({ 36 boosted_trees::quantiles::WeightedQuantilesBuffer<double, double> 37 buffer(2, 0); 38 }), 39 "Invalid buffer specification"); 40 EXPECT_DEATH( 41 ({ 42 boosted_trees::quantiles::WeightedQuantilesBuffer<double, double> 43 buffer(0, 2); 44 }), 45 "Invalid buffer specification"); 46} 47 48TEST_F(WeightedQuantilesBufferTest, PushEntryNotFull) { 49 Buffer buffer(20, 100); 50 buffer.PushEntry(5, 9); 51 buffer.PushEntry(2, 3); 52 buffer.PushEntry(-1, 7); 53 buffer.PushEntry(3, 0); // This entry will be ignored. 54 55 EXPECT_FALSE(buffer.IsFull()); 56 EXPECT_EQ(buffer.Size(), 3); 57} 58 59TEST_F(WeightedQuantilesBufferTest, PushEntryFull) { 60 // buffer capacity is 4. 61 Buffer buffer(2, 100); 62 buffer.PushEntry(5, 9); 63 buffer.PushEntry(2, 3); 64 buffer.PushEntry(-1, 7); 65 buffer.PushEntry(2, 1); 66 67 std::vector<BufferEntry> expected; 68 expected.emplace_back(-1, 7); 69 expected.emplace_back(2, 4); 70 expected.emplace_back(5, 9); 71 72 // At this point, we have pushed 4 entries and we expect the buffer to be 73 // full. 74 EXPECT_TRUE(buffer.IsFull()); 75 EXPECT_EQ(buffer.GenerateEntryList(), expected); 76 EXPECT_FALSE(buffer.IsFull()); 77} 78 79TEST_F(WeightedQuantilesBufferTest, PushEntryFullDeath) { 80 // buffer capacity is 4. 81 Buffer buffer(2, 100); 82 buffer.PushEntry(5, 9); 83 buffer.PushEntry(2, 3); 84 buffer.PushEntry(-1, 7); 85 buffer.PushEntry(2, 1); 86 87 std::vector<BufferEntry> expected; 88 expected.emplace_back(-1, 7); 89 expected.emplace_back(2, 4); 90 expected.emplace_back(5, 9); 91 92 // At this point, we have pushed 4 entries and we expect the buffer to be 93 // full. 94 EXPECT_TRUE(buffer.IsFull()); 95 // Can't push any more entries before clearing. 96 EXPECT_DEATH(({ buffer.PushEntry(6, 6); }), "Buffer already full"); 97} 98 99} // namespace 100} // namespace tensorflow 101