1// Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_ 16#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_ 17 18#include <algorithm> 19#include <unordered_map> 20#include <vector> 21 22#include "tensorflow/core/platform/logging.h" 23#include "tensorflow/core/platform/types.h" 24 25namespace tensorflow { 26namespace boosted_trees { 27namespace quantiles { 28 29// Buffering container ideally suited for scenarios where we need 30// to sort and dedupe/compact fixed chunks of a stream of weighted elements. 31template <typename ValueType, typename WeightType, 32 typename CompareFn = std::less<ValueType>> 33class WeightedQuantilesBuffer { 34 public: 35 struct BufferEntry { 36 BufferEntry(ValueType v, WeightType w) 37 : value(std::move(v)), weight(std::move(w)) {} 38 BufferEntry() : value(), weight(0) {} 39 40 bool operator<(const BufferEntry& other) const { 41 return kCompFn(value, other.value); 42 } 43 bool operator==(const BufferEntry& other) const { 44 return value == other.value && weight == other.weight; 45 } 46 friend std::ostream& operator<<(std::ostream& strm, 47 const BufferEntry& entry) { 48 return strm << "{" << entry.value << ", " << entry.weight << "}"; 49 } 50 ValueType value; 51 WeightType weight; 52 }; 53 54 explicit WeightedQuantilesBuffer(int64 block_size, int64 max_elements) 55 : max_size_(std::min(block_size << 1, max_elements)) { 56 QCHECK(max_size_ > 0) << "Invalid buffer specification: (" << block_size 57 << ", " << max_elements << ")"; 58 vec_.reserve(max_size_); 59 } 60 61 // Disallow copying as it's semantically non-sensical in the Squawd algorithm 62 // but enable move semantics. 63 WeightedQuantilesBuffer(const WeightedQuantilesBuffer& other) = delete; 64 WeightedQuantilesBuffer& operator=(const WeightedQuantilesBuffer&) = delete; 65 WeightedQuantilesBuffer(WeightedQuantilesBuffer&& other) = default; 66 WeightedQuantilesBuffer& operator=(WeightedQuantilesBuffer&& other) = default; 67 68 // Push entry to buffer and maintain a compact representation within 69 // pre-defined size limit. 70 void PushEntry(ValueType value, WeightType weight) { 71 // Callers are expected to act on a full compacted buffer after the 72 // PushEntry call returns. 73 QCHECK(!IsFull()) << "Buffer already full: " << max_size_; 74 75 // Ignore zero and negative weight entries. 76 if (weight <= 0) { 77 return; 78 } 79 80 // Push back the entry to the buffer. 81 vec_.push_back(BufferEntry(std::move(value), std::move(weight))); 82 } 83 84 // Returns a sorted vector view of the base buffer and clears the buffer. 85 // Callers should minimize how often this is called, ideally only right after 86 // the buffer becomes full. 87 std::vector<BufferEntry> GenerateEntryList() { 88 std::vector<BufferEntry> ret; 89 if (vec_.size() == 0) { 90 return ret; 91 } 92 ret.swap(vec_); 93 vec_.reserve(max_size_); 94 std::sort(ret.begin(), ret.end()); 95 size_t num_entries = 0; 96 for (size_t i = 1; i < ret.size(); ++i) { 97 if (ret[i].value != ret[i - 1].value) { 98 BufferEntry tmp = ret[i]; 99 ++num_entries; 100 ret[num_entries] = tmp; 101 } else { 102 ret[num_entries].weight += ret[i].weight; 103 } 104 } 105 ret.resize(num_entries + 1); 106 return ret; 107 } 108 109 int64 Size() const { return vec_.size(); } 110 bool IsFull() const { return vec_.size() >= max_size_; } 111 void Clear() { vec_.clear(); } 112 113 private: 114 using BufferVector = typename std::vector<BufferEntry>; 115 116 // Comparison function. 117 static constexpr decltype(CompareFn()) kCompFn = CompareFn(); 118 119 // Base buffer. 120 size_t max_size_; 121 BufferVector vec_; 122}; 123 124template <typename ValueType, typename WeightType, typename CompareFn> 125constexpr decltype(CompareFn()) 126 WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>::kCompFn; 127 128} // namespace quantiles 129} // namespace boosted_trees 130} // namespace tensorflow 131 132#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_ 133