1// Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14// ============================================================================= 15#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ 16#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ 17 18#include <cstring> 19#include <vector> 20 21#include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h" 22 23namespace tensorflow { 24namespace boosted_trees { 25namespace quantiles { 26 27// Summary holding a sorted block of entries with upper bound guarantees 28// over the approximation error. 29template <typename ValueType, typename WeightType, 30 typename CompareFn = std::less<ValueType>> 31class WeightedQuantilesSummary { 32 public: 33 using Buffer = WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>; 34 using BufferEntry = typename Buffer::BufferEntry; 35 36 struct SummaryEntry { 37 SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min, 38 const WeightType& max) { 39 // Explicitly initialize all of memory (including padding from memory 40 // alignment) to allow the struct to be msan-resistant "plain old data". 41 // 42 // POD = http://en.cppreference.com/w/cpp/concept/PODType 43 memset(this, 0, sizeof(*this)); 44 45 value = v; 46 weight = w; 47 min_rank = min; 48 max_rank = max; 49 } 50 51 SummaryEntry() { 52 memset(this, 0, sizeof(*this)); 53 54 value = 0; 55 weight = 0; 56 min_rank = 0; 57 max_rank = 0; 58 } 59 60 bool operator==(const SummaryEntry& other) const { 61 return value == other.value && weight == other.weight && 62 min_rank == other.min_rank && max_rank == other.max_rank; 63 } 64 friend std::ostream& operator<<(std::ostream& strm, 65 const SummaryEntry& entry) { 66 return strm << "{" << entry.value << ", " << entry.weight << ", " 67 << entry.min_rank << ", " << entry.max_rank << "}"; 68 } 69 70 // Max rank estimate for previous smaller value. 71 WeightType PrevMaxRank() const { return max_rank - weight; } 72 73 // Min rank estimate for next larger value. 74 WeightType NextMinRank() const { return min_rank + weight; } 75 76 ValueType value; 77 WeightType weight; 78 WeightType min_rank; 79 WeightType max_rank; 80 }; 81 82 // Re-construct summary from the specified buffer. 83 void BuildFromBufferEntries(const std::vector<BufferEntry>& buffer_entries) { 84 entries_.clear(); 85 entries_.reserve(buffer_entries.size()); 86 WeightType cumulative_weight = 0; 87 for (const auto& entry : buffer_entries) { 88 WeightType current_weight = entry.weight; 89 entries_.emplace_back(entry.value, entry.weight, cumulative_weight, 90 cumulative_weight + current_weight); 91 cumulative_weight += current_weight; 92 } 93 } 94 95 // Re-construct summary from the specified summary entries. 96 void BuildFromSummaryEntries( 97 const std::vector<SummaryEntry>& summary_entries) { 98 entries_.clear(); 99 entries_.reserve(summary_entries.size()); 100 entries_.insert(entries_.begin(), summary_entries.begin(), 101 summary_entries.end()); 102 } 103 104 // Merges two summaries through an algorithm that's derived from MergeSort 105 // for summary entries while guaranteeing that the max approximation error 106 // of the final merged summary is no greater than the approximation errors 107 // of each individual summary. 108 // For example consider summaries where each entry is of the form 109 // (element, weight, min rank, max rank): 110 // summary entries 1: (1, 3, 0, 3), (4, 2, 3, 5) 111 // summary entries 2: (3, 1, 0, 1), (4, 1, 1, 2) 112 // merged: (1, 3, 0, 3), (3, 1, 3, 4), (4, 3, 4, 7). 113 void Merge(const WeightedQuantilesSummary& other_summary) { 114 // Make sure we have something to merge. 115 const auto& other_entries = other_summary.entries_; 116 if (other_entries.empty()) { 117 return; 118 } 119 if (entries_.empty()) { 120 BuildFromSummaryEntries(other_summary.entries_); 121 return; 122 } 123 124 // Move current entries to make room for a new buffer. 125 std::vector<SummaryEntry> base_entries(std::move(entries_)); 126 entries_.clear(); 127 entries_.reserve(base_entries.size() + other_entries.size()); 128 129 // Merge entries maintaining ranks. The idea is to stack values 130 // in order which we can do in linear time as the two summaries are 131 // already sorted. We keep track of the next lower rank from either 132 // summary and update it as we pop elements from the summaries. 133 // We handle the special case when the next two elements from either 134 // summary are equal, in which case we just merge the two elements 135 // and simultaneously update both ranks. 136 auto it1 = base_entries.cbegin(); 137 auto it2 = other_entries.cbegin(); 138 WeightType next_min_rank1 = 0; 139 WeightType next_min_rank2 = 0; 140 while (it1 != base_entries.cend() && it2 != other_entries.cend()) { 141 if (kCompFn(it1->value, it2->value)) { // value1 < value2 142 // Take value1 and use the last added value2 to compute 143 // the min rank and the current value2 to compute the max rank. 144 entries_.emplace_back(it1->value, it1->weight, 145 it1->min_rank + next_min_rank2, 146 it1->max_rank + it2->PrevMaxRank()); 147 // Update next min rank 1. 148 next_min_rank1 = it1->NextMinRank(); 149 ++it1; 150 } else if (kCompFn(it2->value, it1->value)) { // value1 > value2 151 // Take value2 and use the last added value1 to compute 152 // the min rank and the current value1 to compute the max rank. 153 entries_.emplace_back(it2->value, it2->weight, 154 it2->min_rank + next_min_rank1, 155 it2->max_rank + it1->PrevMaxRank()); 156 // Update next min rank 2. 157 next_min_rank2 = it2->NextMinRank(); 158 ++it2; 159 } else { // value1 == value2 160 // Straight additive merger of the two entries into one. 161 entries_.emplace_back(it1->value, it1->weight + it2->weight, 162 it1->min_rank + it2->min_rank, 163 it1->max_rank + it2->max_rank); 164 // Update next min ranks for both. 165 next_min_rank1 = it1->NextMinRank(); 166 next_min_rank2 = it2->NextMinRank(); 167 ++it1; 168 ++it2; 169 } 170 } 171 172 // Fill in any residual. 173 while (it1 != base_entries.cend()) { 174 entries_.emplace_back(it1->value, it1->weight, 175 it1->min_rank + next_min_rank2, 176 it1->max_rank + other_entries.back().max_rank); 177 ++it1; 178 } 179 while (it2 != other_entries.cend()) { 180 entries_.emplace_back(it2->value, it2->weight, 181 it2->min_rank + next_min_rank1, 182 it2->max_rank + base_entries.back().max_rank); 183 ++it2; 184 } 185 } 186 187 // Compresses buffer into desired size. The size specification is 188 // considered a hint as we always keep the first and last elements and 189 // maintain strict approximation error bounds. 190 // The approximation error delta is taken as the max of either the requested 191 // min error or 1 / size_hint. 192 // After compression, the approximation error is guaranteed to increase 193 // by no more than that error delta. 194 // This algorithm is linear in the original size of the summary and is 195 // designed to be cache-friendly. 196 void Compress(int64 size_hint, double min_eps = 0) { 197 // No-op if we're already within the size requirement. 198 size_hint = std::max(size_hint, 2LL); 199 if (entries_.size() <= size_hint) { 200 return; 201 } 202 203 // First compute the max error bound delta resulting from this compression. 204 double eps_delta = TotalWeight() * std::max(1.0 / size_hint, min_eps); 205 206 // Compress elements ensuring approximation bounds and elements diversity 207 // are both maintained. 208 int64 add_accumulator = 0, add_step = entries_.size(); 209 auto write_it = entries_.begin() + 1, last_it = write_it; 210 for (auto read_it = entries_.begin(); read_it + 1 != entries_.end();) { 211 auto next_it = read_it + 1; 212 while (next_it != entries_.end() && add_accumulator < add_step && 213 next_it->PrevMaxRank() - read_it->NextMinRank() <= eps_delta) { 214 add_accumulator += size_hint; 215 ++next_it; 216 } 217 if (read_it == next_it - 1) { 218 ++read_it; 219 } else { 220 read_it = next_it - 1; 221 } 222 (*write_it++) = (*read_it); 223 last_it = read_it; 224 add_accumulator -= add_step; 225 } 226 // Write last element and resize. 227 if (last_it + 1 != entries_.end()) { 228 (*write_it++) = entries_.back(); 229 } 230 entries_.resize(write_it - entries_.begin()); 231 } 232 233 // To construct the boundaries we first run a soft compress over a copy 234 // of the summary and retrieve the values. 235 // The resulting boundaries are guaranteed to both contain at least 236 // num_boundaries unique elements and maintain approximation bounds. 237 std::vector<ValueType> GenerateBoundaries(int64 num_boundaries) const { 238 // Generate soft compressed summary. 239 WeightedQuantilesSummary<ValueType, WeightType, CompareFn> 240 compressed_summary; 241 compressed_summary.BuildFromSummaryEntries(entries_); 242 // Set an epsilon for compression that's at most 1.0 / num_boundaries 243 // more than epsilon of original our summary since the compression operation 244 // adds ~1.0/num_boundaries to final approximation error. 245 float compression_eps = ApproximationError() + (1.0 / num_boundaries); 246 compressed_summary.Compress(num_boundaries, compression_eps); 247 248 // Return boundaries. 249 std::vector<ValueType> output; 250 output.reserve(compressed_summary.entries_.size()); 251 for (const auto& entry : compressed_summary.entries_) { 252 output.push_back(entry.value); 253 } 254 return output; 255 } 256 257 // To construct the desired n-quantiles we repetitively query n ranks from the 258 // original summary. The following algorithm is an efficient cache-friendly 259 // O(n) implementation of that idea which avoids the cost of the repetitive 260 // full rank queries O(nlogn). 261 std::vector<ValueType> GenerateQuantiles(int64 num_quantiles) const { 262 std::vector<ValueType> output; 263 num_quantiles = std::max(num_quantiles, 2LL); 264 output.reserve(num_quantiles + 1); 265 266 // Make successive rank queries to get boundaries. 267 // We always keep the first (min) and last (max) entries. 268 for (size_t cur_idx = 0, rank = 0; rank <= num_quantiles; ++rank) { 269 // This step boils down to finding the next element sub-range defined by 270 // r = (rmax[i + 1] + rmin[i + 1]) / 2 where the desired rank d < r. 271 WeightType d_2 = 2 * (rank * entries_.back().max_rank / num_quantiles); 272 size_t next_idx = cur_idx + 1; 273 while (next_idx < entries_.size() && 274 d_2 >= entries_[next_idx].min_rank + entries_[next_idx].max_rank) { 275 ++next_idx; 276 } 277 cur_idx = next_idx - 1; 278 279 // Determine insertion order. 280 if (next_idx == entries_.size() || 281 d_2 < entries_[cur_idx].NextMinRank() + 282 entries_[next_idx].PrevMaxRank()) { 283 output.push_back(entries_[cur_idx].value); 284 } else { 285 output.push_back(entries_[next_idx].value); 286 } 287 } 288 return output; 289 } 290 291 // Calculates current approximation error which should always be <= eps. 292 double ApproximationError() const { 293 if (entries_.empty()) { 294 return 0; 295 } 296 297 WeightType max_gap = 0; 298 for (auto it = entries_.cbegin() + 1; it < entries_.end(); ++it) { 299 max_gap = std::max(max_gap, 300 std::max(it->max_rank - it->min_rank - it->weight, 301 it->PrevMaxRank() - (it - 1)->NextMinRank())); 302 } 303 return static_cast<double>(max_gap) / TotalWeight(); 304 } 305 306 ValueType MinValue() const { 307 return !entries_.empty() ? entries_.front().value 308 : std::numeric_limits<ValueType>::max(); 309 } 310 ValueType MaxValue() const { 311 return !entries_.empty() ? entries_.back().value 312 : std::numeric_limits<ValueType>::max(); 313 } 314 WeightType TotalWeight() const { 315 return !entries_.empty() ? entries_.back().max_rank : 0; 316 } 317 int64 Size() const { return entries_.size(); } 318 void Clear() { entries_.clear(); } 319 const std::vector<SummaryEntry>& GetEntryList() const { return entries_; } 320 321 private: 322 // Comparison function. 323 static constexpr decltype(CompareFn()) kCompFn = CompareFn(); 324 325 // Summary entries. 326 std::vector<SummaryEntry> entries_; 327}; 328 329template <typename ValueType, typename WeightType, typename CompareFn> 330constexpr decltype(CompareFn()) 331 WeightedQuantilesSummary<ValueType, WeightType, CompareFn>::kCompFn; 332 333} // namespace quantiles 334} // namespace boosted_trees 335} // namespace tensorflow 336 337#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ 338