1// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
16#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
17
18#include <cstring>
19#include <vector>
20
21#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h"
22
23namespace tensorflow {
24namespace boosted_trees {
25namespace quantiles {
26
27// Summary holding a sorted block of entries with upper bound guarantees
28// over the approximation error.
29template <typename ValueType, typename WeightType,
30          typename CompareFn = std::less<ValueType>>
31class WeightedQuantilesSummary {
32 public:
33  using Buffer = WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>;
34  using BufferEntry = typename Buffer::BufferEntry;
35
36  struct SummaryEntry {
37    SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min,
38                 const WeightType& max) {
39      // Explicitly initialize all of memory (including padding from memory
40      // alignment) to allow the struct to be msan-resistant "plain old data".
41      //
42      // POD = http://en.cppreference.com/w/cpp/concept/PODType
43      memset(this, 0, sizeof(*this));
44
45      value = v;
46      weight = w;
47      min_rank = min;
48      max_rank = max;
49    }
50
51    SummaryEntry() {
52      memset(this, 0, sizeof(*this));
53
54      value = 0;
55      weight = 0;
56      min_rank = 0;
57      max_rank = 0;
58    }
59
60    bool operator==(const SummaryEntry& other) const {
61      return value == other.value && weight == other.weight &&
62             min_rank == other.min_rank && max_rank == other.max_rank;
63    }
64    friend std::ostream& operator<<(std::ostream& strm,
65                                    const SummaryEntry& entry) {
66      return strm << "{" << entry.value << ", " << entry.weight << ", "
67                  << entry.min_rank << ", " << entry.max_rank << "}";
68    }
69
70    // Max rank estimate for previous smaller value.
71    WeightType PrevMaxRank() const { return max_rank - weight; }
72
73    // Min rank estimate for next larger value.
74    WeightType NextMinRank() const { return min_rank + weight; }
75
76    ValueType value;
77    WeightType weight;
78    WeightType min_rank;
79    WeightType max_rank;
80  };
81
82  // Re-construct summary from the specified buffer.
83  void BuildFromBufferEntries(const std::vector<BufferEntry>& buffer_entries) {
84    entries_.clear();
85    entries_.reserve(buffer_entries.size());
86    WeightType cumulative_weight = 0;
87    for (const auto& entry : buffer_entries) {
88      WeightType current_weight = entry.weight;
89      entries_.emplace_back(entry.value, entry.weight, cumulative_weight,
90                            cumulative_weight + current_weight);
91      cumulative_weight += current_weight;
92    }
93  }
94
95  // Re-construct summary from the specified summary entries.
96  void BuildFromSummaryEntries(
97      const std::vector<SummaryEntry>& summary_entries) {
98    entries_.clear();
99    entries_.reserve(summary_entries.size());
100    entries_.insert(entries_.begin(), summary_entries.begin(),
101                    summary_entries.end());
102  }
103
104  // Merges two summaries through an algorithm that's derived from MergeSort
105  // for summary entries while guaranteeing that the max approximation error
106  // of the final merged summary is no greater than the approximation errors
107  // of each individual summary.
108  // For example consider summaries where each entry is of the form
109  // (element, weight, min rank, max rank):
110  // summary entries 1: (1, 3, 0, 3), (4, 2, 3, 5)
111  // summary entries 2: (3, 1, 0, 1), (4, 1, 1, 2)
112  // merged: (1, 3, 0, 3), (3, 1, 3, 4), (4, 3, 4, 7).
113  void Merge(const WeightedQuantilesSummary& other_summary) {
114    // Make sure we have something to merge.
115    const auto& other_entries = other_summary.entries_;
116    if (other_entries.empty()) {
117      return;
118    }
119    if (entries_.empty()) {
120      BuildFromSummaryEntries(other_summary.entries_);
121      return;
122    }
123
124    // Move current entries to make room for a new buffer.
125    std::vector<SummaryEntry> base_entries(std::move(entries_));
126    entries_.clear();
127    entries_.reserve(base_entries.size() + other_entries.size());
128
129    // Merge entries maintaining ranks. The idea is to stack values
130    // in order which we can do in linear time as the two summaries are
131    // already sorted. We keep track of the next lower rank from either
132    // summary and update it as we pop elements from the summaries.
133    // We handle the special case when the next two elements from either
134    // summary are equal, in which case we just merge the two elements
135    // and simultaneously update both ranks.
136    auto it1 = base_entries.cbegin();
137    auto it2 = other_entries.cbegin();
138    WeightType next_min_rank1 = 0;
139    WeightType next_min_rank2 = 0;
140    while (it1 != base_entries.cend() && it2 != other_entries.cend()) {
141      if (kCompFn(it1->value, it2->value)) {  // value1 < value2
142        // Take value1 and use the last added value2 to compute
143        // the min rank and the current value2 to compute the max rank.
144        entries_.emplace_back(it1->value, it1->weight,
145                              it1->min_rank + next_min_rank2,
146                              it1->max_rank + it2->PrevMaxRank());
147        // Update next min rank 1.
148        next_min_rank1 = it1->NextMinRank();
149        ++it1;
150      } else if (kCompFn(it2->value, it1->value)) {  // value1 > value2
151        // Take value2 and use the last added value1 to compute
152        // the min rank and the current value1 to compute the max rank.
153        entries_.emplace_back(it2->value, it2->weight,
154                              it2->min_rank + next_min_rank1,
155                              it2->max_rank + it1->PrevMaxRank());
156        // Update next min rank 2.
157        next_min_rank2 = it2->NextMinRank();
158        ++it2;
159      } else {  // value1 == value2
160        // Straight additive merger of the two entries into one.
161        entries_.emplace_back(it1->value, it1->weight + it2->weight,
162                              it1->min_rank + it2->min_rank,
163                              it1->max_rank + it2->max_rank);
164        // Update next min ranks for both.
165        next_min_rank1 = it1->NextMinRank();
166        next_min_rank2 = it2->NextMinRank();
167        ++it1;
168        ++it2;
169      }
170    }
171
172    // Fill in any residual.
173    while (it1 != base_entries.cend()) {
174      entries_.emplace_back(it1->value, it1->weight,
175                            it1->min_rank + next_min_rank2,
176                            it1->max_rank + other_entries.back().max_rank);
177      ++it1;
178    }
179    while (it2 != other_entries.cend()) {
180      entries_.emplace_back(it2->value, it2->weight,
181                            it2->min_rank + next_min_rank1,
182                            it2->max_rank + base_entries.back().max_rank);
183      ++it2;
184    }
185  }
186
187  // Compresses buffer into desired size. The size specification is
188  // considered a hint as we always keep the first and last elements and
189  // maintain strict approximation error bounds.
190  // The approximation error delta is taken as the max of either the requested
191  // min error or 1 / size_hint.
192  // After compression, the approximation error is guaranteed to increase
193  // by no more than that error delta.
194  // This algorithm is linear in the original size of the summary and is
195  // designed to be cache-friendly.
196  void Compress(int64 size_hint, double min_eps = 0) {
197    // No-op if we're already within the size requirement.
198    size_hint = std::max(size_hint, 2LL);
199    if (entries_.size() <= size_hint) {
200      return;
201    }
202
203    // First compute the max error bound delta resulting from this compression.
204    double eps_delta = TotalWeight() * std::max(1.0 / size_hint, min_eps);
205
206    // Compress elements ensuring approximation bounds and elements diversity
207    // are both maintained.
208    int64 add_accumulator = 0, add_step = entries_.size();
209    auto write_it = entries_.begin() + 1, last_it = write_it;
210    for (auto read_it = entries_.begin(); read_it + 1 != entries_.end();) {
211      auto next_it = read_it + 1;
212      while (next_it != entries_.end() && add_accumulator < add_step &&
213             next_it->PrevMaxRank() - read_it->NextMinRank() <= eps_delta) {
214        add_accumulator += size_hint;
215        ++next_it;
216      }
217      if (read_it == next_it - 1) {
218        ++read_it;
219      } else {
220        read_it = next_it - 1;
221      }
222      (*write_it++) = (*read_it);
223      last_it = read_it;
224      add_accumulator -= add_step;
225    }
226    // Write last element and resize.
227    if (last_it + 1 != entries_.end()) {
228      (*write_it++) = entries_.back();
229    }
230    entries_.resize(write_it - entries_.begin());
231  }
232
233  // To construct the boundaries we first run a soft compress over a copy
234  // of the summary and retrieve the values.
235  // The resulting boundaries are guaranteed to both contain at least
236  // num_boundaries unique elements and maintain approximation bounds.
237  std::vector<ValueType> GenerateBoundaries(int64 num_boundaries) const {
238    // Generate soft compressed summary.
239    WeightedQuantilesSummary<ValueType, WeightType, CompareFn>
240        compressed_summary;
241    compressed_summary.BuildFromSummaryEntries(entries_);
242    // Set an epsilon for compression that's at most 1.0 / num_boundaries
243    // more than epsilon of original our summary since the compression operation
244    // adds ~1.0/num_boundaries to final approximation error.
245    float compression_eps = ApproximationError() + (1.0 / num_boundaries);
246    compressed_summary.Compress(num_boundaries, compression_eps);
247
248    // Return boundaries.
249    std::vector<ValueType> output;
250    output.reserve(compressed_summary.entries_.size());
251    for (const auto& entry : compressed_summary.entries_) {
252      output.push_back(entry.value);
253    }
254    return output;
255  }
256
257  // To construct the desired n-quantiles we repetitively query n ranks from the
258  // original summary. The following algorithm is an efficient cache-friendly
259  // O(n) implementation of that idea which avoids the cost of the repetitive
260  // full rank queries O(nlogn).
261  std::vector<ValueType> GenerateQuantiles(int64 num_quantiles) const {
262    std::vector<ValueType> output;
263    num_quantiles = std::max(num_quantiles, 2LL);
264    output.reserve(num_quantiles + 1);
265
266    // Make successive rank queries to get boundaries.
267    // We always keep the first (min) and last (max) entries.
268    for (size_t cur_idx = 0, rank = 0; rank <= num_quantiles; ++rank) {
269      // This step boils down to finding the next element sub-range defined by
270      // r = (rmax[i + 1] + rmin[i + 1]) / 2 where the desired rank d < r.
271      WeightType d_2 = 2 * (rank * entries_.back().max_rank / num_quantiles);
272      size_t next_idx = cur_idx + 1;
273      while (next_idx < entries_.size() &&
274             d_2 >= entries_[next_idx].min_rank + entries_[next_idx].max_rank) {
275        ++next_idx;
276      }
277      cur_idx = next_idx - 1;
278
279      // Determine insertion order.
280      if (next_idx == entries_.size() ||
281          d_2 < entries_[cur_idx].NextMinRank() +
282                    entries_[next_idx].PrevMaxRank()) {
283        output.push_back(entries_[cur_idx].value);
284      } else {
285        output.push_back(entries_[next_idx].value);
286      }
287    }
288    return output;
289  }
290
291  // Calculates current approximation error which should always be <= eps.
292  double ApproximationError() const {
293    if (entries_.empty()) {
294      return 0;
295    }
296
297    WeightType max_gap = 0;
298    for (auto it = entries_.cbegin() + 1; it < entries_.end(); ++it) {
299      max_gap = std::max(max_gap,
300                         std::max(it->max_rank - it->min_rank - it->weight,
301                                  it->PrevMaxRank() - (it - 1)->NextMinRank()));
302    }
303    return static_cast<double>(max_gap) / TotalWeight();
304  }
305
306  ValueType MinValue() const {
307    return !entries_.empty() ? entries_.front().value
308                             : std::numeric_limits<ValueType>::max();
309  }
310  ValueType MaxValue() const {
311    return !entries_.empty() ? entries_.back().value
312                             : std::numeric_limits<ValueType>::max();
313  }
314  WeightType TotalWeight() const {
315    return !entries_.empty() ? entries_.back().max_rank : 0;
316  }
317  int64 Size() const { return entries_.size(); }
318  void Clear() { entries_.clear(); }
319  const std::vector<SummaryEntry>& GetEntryList() const { return entries_; }
320
321 private:
322  // Comparison function.
323  static constexpr decltype(CompareFn()) kCompFn = CompareFn();
324
325  // Summary entries.
326  std::vector<SummaryEntry> entries_;
327};
328
// Out-of-line definition of the static constexpr comparator member.
// Required for ODR-used static constexpr data members before C++17
// (where such members became implicitly inline).
template <typename ValueType, typename WeightType, typename CompareFn>
constexpr decltype(CompareFn())
    WeightedQuantilesSummary<ValueType, WeightType, CompareFn>::kCompFn;
332
333}  // namespace quantiles
334}  // namespace boosted_trees
335}  // namespace tensorflow
336
337#endif  // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
338