// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_

#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h"
#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace boosted_trees {
namespace quantiles {

// Class to compute approximate quantiles with error bound guarantees for
// weighted data sets.
// This implementation is an adaptation of techniques from the following papers:
// * (2001) Space-efficient online computation of quantile summaries.
// * (2004) Power-conserving computation of order-statistics over
//   sensor networks.
// * (2007) A fast algorithm for approximate quantiles in high speed
//   data streams.
// * (2016) XGBoost: A Scalable Tree Boosting System.
39// 40// The key ideas at play are the following: 41// - Maintain an in-memory multi-level quantile summary in a way to guarantee 42// a maximum approximation error of eps * W per bucket where W is the total 43// weight across all points in the input dataset. 44// - Two base operations are defined: MERGE and COMPRESS. MERGE combines two 45// summaries guaranteeing a epsNew = max(eps1, eps2). COMPRESS compresses 46// a summary to b + 1 elements guaranteeing epsNew = epsOld + 1/b. 47// - b * sizeof(summary entry) must ideally be small enough to fit in an 48// average CPU L2 cache. 49// - To distribute this algorithm with maintaining error bounds, we need 50// the worker-computed summaries to have no more than eps / h error 51// where h is the height of the distributed computation graph which 52// is 2 for an MR with no combiner. 53// 54// We mainly want to max out IO bw by ensuring we're not compute-bound and 55// using a reasonable amount of RAM. 56// 57// Complexity: 58// Compute: O(n * log(1/eps * log(eps * n))). 59// Memory: O(1/eps * log^2(eps * n)) <- for one worker streaming through the 60// entire dataset. 61template <typename ValueType, typename WeightType, 62 typename CompareFn = std::less<ValueType>> 63class WeightedQuantilesStream { 64 public: 65 using Buffer = WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>; 66 using BufferEntry = typename Buffer::BufferEntry; 67 using Summary = WeightedQuantilesSummary<ValueType, WeightType, CompareFn>; 68 using SummaryEntry = typename Summary::SummaryEntry; 69 70 explicit WeightedQuantilesStream(double eps, int64 max_elements) 71 : eps_(eps), buffer_(1LL, 2LL), finalized_(false) { 72 std::tie(max_levels_, block_size_) = GetQuantileSpecs(eps, max_elements); 73 buffer_ = Buffer(block_size_, max_elements); 74 summary_levels_.reserve(max_levels_); 75 } 76 77 // Disallow copy and assign but enable move semantics for the stream. 
78 WeightedQuantilesStream(const WeightedQuantilesStream& other) = delete; 79 WeightedQuantilesStream& operator=(const WeightedQuantilesStream&) = delete; 80 WeightedQuantilesStream(WeightedQuantilesStream&& other) = default; 81 WeightedQuantilesStream& operator=(WeightedQuantilesStream&& other) = default; 82 83 // Pushes one entry while maintaining approximation error invariants. 84 void PushEntry(const ValueType& value, const WeightType& weight) { 85 // Validate state. 86 QCHECK(!finalized_) << "Finalize() already called."; 87 88 // Push element to base buffer. 89 buffer_.PushEntry(value, weight); 90 91 // When compacted buffer is full we need to compress 92 // and push weighted quantile summary up the level chain. 93 if (buffer_.IsFull()) { 94 PushBuffer(buffer_); 95 } 96 } 97 98 // Pushes full buffer while maintaining approximation error invariants. 99 void PushBuffer(Buffer& buffer) { 100 // Validate state. 101 QCHECK(!finalized_) << "Finalize() already called."; 102 103 // Create local compressed summary and propagate. 104 local_summary_.BuildFromBufferEntries(buffer.GenerateEntryList()); 105 local_summary_.Compress(block_size_, eps_); 106 PropagateLocalSummary(); 107 } 108 109 // Pushes full summary while maintaining approximation error invariants. 110 void PushSummary(const std::vector<SummaryEntry>& summary) { 111 // Validate state. 112 QCHECK(!finalized_) << "Finalize() already called."; 113 114 // Create local compressed summary and propagate. 115 local_summary_.BuildFromSummaryEntries(summary); 116 local_summary_.Compress(block_size_, eps_); 117 PropagateLocalSummary(); 118 } 119 120 // Flushes approximator and finalizes state. 121 void Finalize() { 122 // Validate state. 123 QCHECK(!finalized_) << "Finalize() may only be called once."; 124 125 // Flush any remaining buffer elements. 126 PushBuffer(buffer_); 127 128 // Create final merged summary. 
129 local_summary_.Clear(); 130 for (auto& summary : summary_levels_) { 131 local_summary_.Merge(summary); 132 summary.Clear(); 133 } 134 summary_levels_.clear(); 135 summary_levels_.shrink_to_fit(); 136 finalized_ = true; 137 } 138 139 // Generates requested number of quantiles after finalizing stream. 140 // The returned quantiles can be queried using std::lower_bound to get 141 // the bucket for a given value. 142 std::vector<ValueType> GenerateQuantiles(int64 num_quantiles) const { 143 // Validate state. 144 QCHECK(finalized_) 145 << "Finalize() must be called before generating quantiles."; 146 return local_summary_.GenerateQuantiles(num_quantiles); 147 } 148 149 // Generates requested number of boundaries after finalizing stream. 150 // The returned boundaries can be queried using std::lower_bound to get 151 // the bucket for a given value. 152 // The boundaries, while still guaranteeing approximation bounds, don't 153 // necessarily represent the actual quantiles of the distribution. 154 // Boundaries are preferable over quantiles when the caller is less 155 // interested in the actual quantiles distribution and more interested in 156 // getting a representative sample of boundary values. 157 std::vector<ValueType> GenerateBoundaries(int64 num_boundaries) const { 158 // Validate state. 159 QCHECK(finalized_) 160 << "Finalize() must be called before generating boundaries."; 161 return local_summary_.GenerateBoundaries(num_boundaries); 162 } 163 164 // Calculates approximation error for the specified level. 165 // If the passed level is negative, the approximation error for the entire 166 // summary is returned. Note that after Finalize is called, only the overall 167 // error is available. 
168 WeightType ApproximationError(int64 level = -1) const { 169 if (finalized_) { 170 QCHECK(level <= 0) << "Only overall error is available after Finalize()"; 171 return local_summary_.ApproximationError(); 172 } 173 174 if (summary_levels_.empty()) { 175 // No error even if base buffer isn't empty. 176 return 0; 177 } 178 179 // If level is negative, we get the approximation error 180 // for the top-most level which is the max approximation error 181 // in all summaries by construction. 182 if (level < 0) { 183 level = summary_levels_.size() - 1; 184 } 185 QCHECK(level < summary_levels_.size()) << "Invalid level."; 186 return summary_levels_[level].ApproximationError(); 187 } 188 189 size_t MaxDepth() const { return summary_levels_.size(); } 190 191 // Generates requested number of quantiles after finalizing stream. 192 const Summary& GetFinalSummary() const { 193 // Validate state. 194 QCHECK(finalized_) 195 << "Finalize() must be called before requesting final summary."; 196 return local_summary_; 197 } 198 199 // Helper method which, given the desired approximation error 200 // and an upper bound on the number of elements, computes the optimal 201 // number of levels and block size and returns them in the tuple. 202 static std::tuple<int64, int64> GetQuantileSpecs(double eps, 203 int64 max_elements); 204 205 // Serializes the internal state of the stream. 206 std::vector<Summary> SerializeInternalSummaries() const { 207 // The buffer should be empty for serialize to work. 208 QCHECK_EQ(buffer_.Size(), 0); 209 std::vector<Summary> result; 210 result.reserve(summary_levels_.size() + 1); 211 for (const Summary& summary : summary_levels_) { 212 result.push_back(summary); 213 } 214 result.push_back(local_summary_); 215 return result; 216 } 217 218 // Resets the state of the stream with a serialized state. 219 void DeserializeInternalSummaries(const std::vector<Summary>& summaries) { 220 // Clear the state before deserializing. 
221 buffer_.Clear(); 222 summary_levels_.clear(); 223 local_summary_.Clear(); 224 QCHECK_GT(max_levels_, summaries.size() - 1); 225 for (int i = 0; i < summaries.size() - 1; ++i) { 226 summary_levels_.push_back(summaries[i]); 227 } 228 local_summary_ = summaries[summaries.size() - 1]; 229 } 230 231 private: 232 // Propagates local summary through summary levels while maintaining 233 // approximation error invariants. 234 void PropagateLocalSummary() { 235 // Validate state. 236 QCHECK(!finalized_) << "Finalize() already called."; 237 238 // No-op if there's nothing to add. 239 if (local_summary_.Size() <= 0) { 240 return; 241 } 242 243 // Propagate summary through levels. 244 size_t level = 0; 245 for (bool settled = false; !settled; ++level) { 246 // Ensure we have enough depth. 247 if (summary_levels_.size() <= level) { 248 summary_levels_.emplace_back(); 249 } 250 251 // Merge summaries. 252 Summary& current_summary = summary_levels_[level]; 253 local_summary_.Merge(current_summary); 254 255 // Check if we need to compress and propagate summary higher. 256 if (current_summary.Size() == 0 || 257 local_summary_.Size() <= block_size_ + 1) { 258 current_summary = std::move(local_summary_); 259 settled = true; 260 } else { 261 // Compress, empty current level and propagate. 262 local_summary_.Compress(block_size_, eps_); 263 current_summary.Clear(); 264 } 265 } 266 } 267 268 // Desired approximation precision. 269 double eps_; 270 // Maximum number of levels. 271 int64 max_levels_; 272 // Max block size per level. 273 int64 block_size_; 274 // Base buffer. 275 Buffer buffer_; 276 // Local summary used to minimize memory allocation and cache misses. 277 // After the stream is finalized, this summary holds the final quantile 278 // estimates. 279 Summary local_summary_; 280 // Summary levels; 281 std::vector<Summary> summary_levels_; 282 // Flag indicating whether the stream is finalized. 
283 bool finalized_; 284}; 285 286template <typename ValueType, typename WeightType, typename CompareFn> 287inline std::tuple<int64, int64> 288WeightedQuantilesStream<ValueType, WeightType, CompareFn>::GetQuantileSpecs( 289 double eps, int64 max_elements) { 290 int64 max_level = 1LL; 291 int64 block_size = 2LL; 292 QCHECK(eps >= 0 && eps < 1); 293 QCHECK_GT(max_elements, 0); 294 295 if (eps <= std::numeric_limits<double>::epsilon()) { 296 // Exact quantile computation at the expense of RAM. 297 max_level = 1; 298 block_size = std::max(max_elements, 2LL); 299 } else { 300 // The bottom-most level will become full at most 301 // (max_elements / block_size) times, the level above will become full 302 // (max_elements / 2 * block_size) times and generally level l becomes 303 // full (max_elements / 2^l * block_size) times until the last 304 // level max_level becomes full at most once meaning when the inequality 305 // (2^max_level * block_size >= max_elements) is satisfied. 306 // In what follows, we jointly solve for max_level and block_size by 307 // gradually increasing the level until the inequality above is satisfied. 308 // We could alternatively set max_level = ceil(log2(eps * max_elements)); 309 // and block_size = ceil(max_level / eps) + 1 but that tends to give more 310 // pessimistic bounds and wastes RAM needlessly. 311 for (max_level = 1, block_size = 2; 312 (1LL << max_level) * block_size < max_elements; ++max_level) { 313 // Update upper bound on block size at current level, we always 314 // increase the estimate by 2 to hold the min/max elements seen so far. 315 block_size = static_cast<size_t>(ceil(max_level / eps)) + 1; 316 } 317 } 318 return std::make_tuple(max_level, std::max(block_size, 2LL)); 319} 320 321} // namespace quantiles 322} // namespace boosted_trees 323} // namespace tensorflow 324 325#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_ 326