1// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
16#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
17
#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h"
#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h"
#include "tensorflow/core/platform/types.h"
25
26namespace tensorflow {
27namespace boosted_trees {
28namespace quantiles {
29
30// Class to compute approximate quantiles with error bound guarantees for
31// weighted data sets.
32// This implementation is an adaptation of techniques from the following papers:
33// * (2001) Space-efficient online computation of quantile summaries.
34// * (2004) Power-conserving computation of order-statistics over
35//          sensor networks.
36// * (2007) A fast algorithm for approximate quantiles in high speed
37//          data streams.
38// * (2016) XGBoost: A Scalable Tree Boosting System.
39//
40// The key ideas at play are the following:
41// - Maintain an in-memory multi-level quantile summary in a way to guarantee
42//   a maximum approximation error of eps * W per bucket where W is the total
43//   weight across all points in the input dataset.
44// - Two base operations are defined: MERGE and COMPRESS. MERGE combines two
45//   summaries guaranteeing a epsNew = max(eps1, eps2). COMPRESS compresses
46//   a summary to b + 1 elements guaranteeing epsNew = epsOld + 1/b.
47// - b * sizeof(summary entry) must ideally be small enough to fit in an
48//   average CPU L2 cache.
49// - To distribute this algorithm with maintaining error bounds, we need
50//   the worker-computed summaries to have no more than eps / h error
51//   where h is the height of the distributed computation graph which
52//   is 2 for an MR with no combiner.
53//
54// We mainly want to max out IO bw by ensuring we're not compute-bound and
55// using a reasonable amount of RAM.
56//
57// Complexity:
58// Compute: O(n * log(1/eps * log(eps * n))).
59// Memory: O(1/eps * log^2(eps * n)) <- for one worker streaming through the
60//                                      entire dataset.
61template <typename ValueType, typename WeightType,
62          typename CompareFn = std::less<ValueType>>
63class WeightedQuantilesStream {
64 public:
65  using Buffer = WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>;
66  using BufferEntry = typename Buffer::BufferEntry;
67  using Summary = WeightedQuantilesSummary<ValueType, WeightType, CompareFn>;
68  using SummaryEntry = typename Summary::SummaryEntry;
69
70  explicit WeightedQuantilesStream(double eps, int64 max_elements)
71      : eps_(eps), buffer_(1LL, 2LL), finalized_(false) {
72    std::tie(max_levels_, block_size_) = GetQuantileSpecs(eps, max_elements);
73    buffer_ = Buffer(block_size_, max_elements);
74    summary_levels_.reserve(max_levels_);
75  }
76
77  // Disallow copy and assign but enable move semantics for the stream.
78  WeightedQuantilesStream(const WeightedQuantilesStream& other) = delete;
79  WeightedQuantilesStream& operator=(const WeightedQuantilesStream&) = delete;
80  WeightedQuantilesStream(WeightedQuantilesStream&& other) = default;
81  WeightedQuantilesStream& operator=(WeightedQuantilesStream&& other) = default;
82
83  // Pushes one entry while maintaining approximation error invariants.
84  void PushEntry(const ValueType& value, const WeightType& weight) {
85    // Validate state.
86    QCHECK(!finalized_) << "Finalize() already called.";
87
88    // Push element to base buffer.
89    buffer_.PushEntry(value, weight);
90
91    // When compacted buffer is full we need to compress
92    // and push weighted quantile summary up the level chain.
93    if (buffer_.IsFull()) {
94      PushBuffer(buffer_);
95    }
96  }
97
98  // Pushes full buffer while maintaining approximation error invariants.
99  void PushBuffer(Buffer& buffer) {
100    // Validate state.
101    QCHECK(!finalized_) << "Finalize() already called.";
102
103    // Create local compressed summary and propagate.
104    local_summary_.BuildFromBufferEntries(buffer.GenerateEntryList());
105    local_summary_.Compress(block_size_, eps_);
106    PropagateLocalSummary();
107  }
108
109  // Pushes full summary while maintaining approximation error invariants.
110  void PushSummary(const std::vector<SummaryEntry>& summary) {
111    // Validate state.
112    QCHECK(!finalized_) << "Finalize() already called.";
113
114    // Create local compressed summary and propagate.
115    local_summary_.BuildFromSummaryEntries(summary);
116    local_summary_.Compress(block_size_, eps_);
117    PropagateLocalSummary();
118  }
119
120  // Flushes approximator and finalizes state.
121  void Finalize() {
122    // Validate state.
123    QCHECK(!finalized_) << "Finalize() may only be called once.";
124
125    // Flush any remaining buffer elements.
126    PushBuffer(buffer_);
127
128    // Create final merged summary.
129    local_summary_.Clear();
130    for (auto& summary : summary_levels_) {
131      local_summary_.Merge(summary);
132      summary.Clear();
133    }
134    summary_levels_.clear();
135    summary_levels_.shrink_to_fit();
136    finalized_ = true;
137  }
138
139  // Generates requested number of quantiles after finalizing stream.
140  // The returned quantiles can be queried using std::lower_bound to get
141  // the bucket for a given value.
142  std::vector<ValueType> GenerateQuantiles(int64 num_quantiles) const {
143    // Validate state.
144    QCHECK(finalized_)
145        << "Finalize() must be called before generating quantiles.";
146    return local_summary_.GenerateQuantiles(num_quantiles);
147  }
148
149  // Generates requested number of boundaries after finalizing stream.
150  // The returned boundaries can be queried using std::lower_bound to get
151  // the bucket for a given value.
152  // The boundaries, while still guaranteeing approximation bounds, don't
153  // necessarily represent the actual quantiles of the distribution.
154  // Boundaries are preferable over quantiles when the caller is less
155  // interested in the actual quantiles distribution and more interested in
156  // getting a representative sample of boundary values.
157  std::vector<ValueType> GenerateBoundaries(int64 num_boundaries) const {
158    // Validate state.
159    QCHECK(finalized_)
160        << "Finalize() must be called before generating boundaries.";
161    return local_summary_.GenerateBoundaries(num_boundaries);
162  }
163
164  // Calculates approximation error for the specified level.
165  // If the passed level is negative, the approximation error for the entire
166  // summary is returned. Note that after Finalize is called, only the overall
167  // error is available.
168  WeightType ApproximationError(int64 level = -1) const {
169    if (finalized_) {
170      QCHECK(level <= 0) << "Only overall error is available after Finalize()";
171      return local_summary_.ApproximationError();
172    }
173
174    if (summary_levels_.empty()) {
175      // No error even if base buffer isn't empty.
176      return 0;
177    }
178
179    // If level is negative, we get the approximation error
180    // for the top-most level which is the max approximation error
181    // in all summaries by construction.
182    if (level < 0) {
183      level = summary_levels_.size() - 1;
184    }
185    QCHECK(level < summary_levels_.size()) << "Invalid level.";
186    return summary_levels_[level].ApproximationError();
187  }
188
189  size_t MaxDepth() const { return summary_levels_.size(); }
190
191  // Generates requested number of quantiles after finalizing stream.
192  const Summary& GetFinalSummary() const {
193    // Validate state.
194    QCHECK(finalized_)
195        << "Finalize() must be called before requesting final summary.";
196    return local_summary_;
197  }
198
199  // Helper method which, given the desired approximation error
200  // and an upper bound on the number of elements, computes the optimal
201  // number of levels and block size and returns them in the tuple.
202  static std::tuple<int64, int64> GetQuantileSpecs(double eps,
203                                                   int64 max_elements);
204
205  // Serializes the internal state of the stream.
206  std::vector<Summary> SerializeInternalSummaries() const {
207    // The buffer should be empty for serialize to work.
208    QCHECK_EQ(buffer_.Size(), 0);
209    std::vector<Summary> result;
210    result.reserve(summary_levels_.size() + 1);
211    for (const Summary& summary : summary_levels_) {
212      result.push_back(summary);
213    }
214    result.push_back(local_summary_);
215    return result;
216  }
217
218  // Resets the state of the stream with a serialized state.
219  void DeserializeInternalSummaries(const std::vector<Summary>& summaries) {
220    // Clear the state before deserializing.
221    buffer_.Clear();
222    summary_levels_.clear();
223    local_summary_.Clear();
224    QCHECK_GT(max_levels_, summaries.size() - 1);
225    for (int i = 0; i < summaries.size() - 1; ++i) {
226      summary_levels_.push_back(summaries[i]);
227    }
228    local_summary_ = summaries[summaries.size() - 1];
229  }
230
231 private:
232  // Propagates local summary through summary levels while maintaining
233  // approximation error invariants.
234  void PropagateLocalSummary() {
235    // Validate state.
236    QCHECK(!finalized_) << "Finalize() already called.";
237
238    // No-op if there's nothing to add.
239    if (local_summary_.Size() <= 0) {
240      return;
241    }
242
243    // Propagate summary through levels.
244    size_t level = 0;
245    for (bool settled = false; !settled; ++level) {
246      // Ensure we have enough depth.
247      if (summary_levels_.size() <= level) {
248        summary_levels_.emplace_back();
249      }
250
251      // Merge summaries.
252      Summary& current_summary = summary_levels_[level];
253      local_summary_.Merge(current_summary);
254
255      // Check if we need to compress and propagate summary higher.
256      if (current_summary.Size() == 0 ||
257          local_summary_.Size() <= block_size_ + 1) {
258        current_summary = std::move(local_summary_);
259        settled = true;
260      } else {
261        // Compress, empty current level and propagate.
262        local_summary_.Compress(block_size_, eps_);
263        current_summary.Clear();
264      }
265    }
266  }
267
268  // Desired approximation precision.
269  double eps_;
270  // Maximum number of levels.
271  int64 max_levels_;
272  // Max block size per level.
273  int64 block_size_;
274  // Base buffer.
275  Buffer buffer_;
276  // Local summary used to minimize memory allocation and cache misses.
277  // After the stream is finalized, this summary holds the final quantile
278  // estimates.
279  Summary local_summary_;
280  // Summary levels;
281  std::vector<Summary> summary_levels_;
282  // Flag indicating whether the stream is finalized.
283  bool finalized_;
284};
285
286template <typename ValueType, typename WeightType, typename CompareFn>
287inline std::tuple<int64, int64>
288WeightedQuantilesStream<ValueType, WeightType, CompareFn>::GetQuantileSpecs(
289    double eps, int64 max_elements) {
290  int64 max_level = 1LL;
291  int64 block_size = 2LL;
292  QCHECK(eps >= 0 && eps < 1);
293  QCHECK_GT(max_elements, 0);
294
295  if (eps <= std::numeric_limits<double>::epsilon()) {
296    // Exact quantile computation at the expense of RAM.
297    max_level = 1;
298    block_size = std::max(max_elements, 2LL);
299  } else {
300    // The bottom-most level will become full at most
301    // (max_elements / block_size) times, the level above will become full
302    // (max_elements / 2 * block_size) times and generally level l becomes
303    // full (max_elements / 2^l * block_size) times until the last
304    // level max_level becomes full at most once meaning when the inequality
305    // (2^max_level * block_size >= max_elements) is satisfied.
306    // In what follows, we jointly solve for max_level and block_size by
307    // gradually increasing the level until the inequality above is satisfied.
308    // We could alternatively set max_level = ceil(log2(eps * max_elements));
309    // and block_size = ceil(max_level / eps) + 1 but that tends to give more
310    // pessimistic bounds and wastes RAM needlessly.
311    for (max_level = 1, block_size = 2;
312         (1LL << max_level) * block_size < max_elements; ++max_level) {
313      // Update upper bound on block size at current level, we always
314      // increase the estimate by 2 to hold the min/max elements seen so far.
315      block_size = static_cast<size_t>(ceil(max_level / eps)) + 1;
316    }
317  }
318  return std::make_tuple(max_level, std::max(block_size, 2LL));
319}
320
321}  // namespace quantiles
322}  // namespace boosted_trees
323}  // namespace tensorflow
324
325#endif  // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
326