1150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// 3150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// Licensed under the Apache License, Version 2.0 (the "License"); 4150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// you may not use this file except in compliance with the License. 5150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// You may obtain a copy of the License at 6150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// 7150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// http://www.apache.org/licenses/LICENSE-2.0 8150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// 9150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// Unless required by applicable law or agreed to in writing, software 10150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// distributed under the License is distributed on an "AS IS" BASIS, 11150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// See the License for the specific language governing permissions and 13150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// limitations under the License. 14150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// ============================================================================= 15f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_ 16f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_ 17150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower 18150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower#include <algorithm> 19150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower#include <unordered_map> 20150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower#include <vector> 21150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower 22150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower#include "tensorflow/core/platform/logging.h" 23150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower#include "tensorflow/core/platform/types.h" 24150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower 25150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlowernamespace tensorflow { 26150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlowernamespace boosted_trees { 27150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlowernamespace quantiles { 28150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower 29150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// Buffering container ideally suited for scenarios where we need 30150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// to sort and dedupe/compact fixed chunks of a stream of weighted elements. 31150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlowertemplate <typename ValueType, typename WeightType, 32150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower typename CompareFn = std::less<ValueType>> 33150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlowerclass WeightedQuantilesBuffer { 34150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower public: 35150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower struct BufferEntry { 360dbb1ad1d53976050180fc2e2289d768e78e300fA. Unique TensorFlower BufferEntry(ValueType v, WeightType w) 370dbb1ad1d53976050180fc2e2289d768e78e300fA. Unique TensorFlower : value(std::move(v)), weight(std::move(w)) {} 380dbb1ad1d53976050180fc2e2289d768e78e300fA. Unique TensorFlower BufferEntry() : value(), weight(0) {} 39150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower 40150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower bool operator<(const BufferEntry& other) const { 41150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower return kCompFn(value, other.value); 42150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower } 43150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower bool operator==(const BufferEntry& other) const { 44150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower return value == other.value && weight == other.weight; 45150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower } 46150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower friend std::ostream& operator<<(std::ostream& strm, 47150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower const BufferEntry& entry) { 48150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower return strm << "{" << entry.value << ", " << entry.weight << "}"; 49150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower } 50150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower ValueType value; 51150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower WeightType weight; 52150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower }; 53150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower 54150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower explicit WeightedQuantilesBuffer(int64 block_size, int64 max_elements) 55150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower : max_size_(std::min(block_size << 1, max_elements)) { 56150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower QCHECK(max_size_ > 0) << "Invalid buffer specification: (" << block_size 57150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower << ", " << max_elements << ")"; 583e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower vec_.reserve(max_size_); 59150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower } 60150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower 61150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower // Disallow copying as it's semantically non-sensical in the Squawd algorithm 62150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower // but enable move semantics. 63150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower WeightedQuantilesBuffer(const WeightedQuantilesBuffer& other) = delete; 64150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower WeightedQuantilesBuffer& operator=(const WeightedQuantilesBuffer&) = delete; 65150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower WeightedQuantilesBuffer(WeightedQuantilesBuffer&& other) = default; 66150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower WeightedQuantilesBuffer& operator=(WeightedQuantilesBuffer&& other) = default; 67150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower 68150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower // Push entry to buffer and maintain a compact representation within 69150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower // pre-defined size limit. 700dbb1ad1d53976050180fc2e2289d768e78e300fA. Unique TensorFlower void PushEntry(ValueType value, WeightType weight) { 71150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower // Callers are expected to act on a full compacted buffer after the 72150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower // PushEntry call returns. 73150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower QCHECK(!IsFull()) << "Buffer already full: " << max_size_; 74150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower 75150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower // Ignore zero and negative weight entries. 76150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower if (weight <= 0) { 77150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower return; 78150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower } 79150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower 803e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower // Push back the entry to the buffer. 810dbb1ad1d53976050180fc2e2289d768e78e300fA. Unique TensorFlower vec_.push_back(BufferEntry(std::move(value), std::move(weight))); 82150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower } 83150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower 843e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower // Returns a sorted vector view of the base buffer and clears the buffer. 853e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower // Callers should minimize how often this is called, ideally only right after 863e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower // the buffer becomes full. 873e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower std::vector<BufferEntry> GenerateEntryList() { 88150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower std::vector<BufferEntry> ret; 893e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower if (vec_.size() == 0) { 903e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower return ret; 913e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower } 923e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower ret.swap(vec_); 933e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower vec_.reserve(max_size_); 94150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower std::sort(ret.begin(), ret.end()); 953e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower size_t num_entries = 0; 963e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower for (size_t i = 1; i < ret.size(); ++i) { 973e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower if (ret[i].value != ret[i - 1].value) { 983e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower BufferEntry tmp = ret[i]; 993e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower ++num_entries; 1003e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower ret[num_entries] = tmp; 1013e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower } else { 1023e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower ret[num_entries].weight += ret[i].weight; 1033e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower } 1043e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower } 1053e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower ret.resize(num_entries + 1); 106150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower return ret; 107150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower } 108150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower 1093e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower int64 Size() const { return vec_.size(); } 1103e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower bool IsFull() const { return vec_.size() >= max_size_; } 1113e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower void Clear() { vec_.clear(); } 112150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower 113150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower private: 1143e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower using BufferVector = typename std::vector<BufferEntry>; 115150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower 116150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower // Comparison function. 117150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower static constexpr decltype(CompareFn()) kCompFn = CompareFn(); 118150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower 119150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower // Base buffer. 120150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower size_t max_size_; 1213e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower BufferVector vec_; 122150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower}; 123150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower 124150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlowertemplate <typename ValueType, typename WeightType, typename CompareFn> 125150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlowerconstexpr decltype(CompareFn()) 126150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>::kCompFn; 127150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower 128150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower} // namespace quantiles 129150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower} // namespace boosted_trees 130150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower} // namespace tensorflow 131150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower 132f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_ 133