1// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// =============================================================================
15#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
16#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
17
18#include <algorithm>
19#include <unordered_map>
20#include <vector>
21
22#include "tensorflow/core/platform/logging.h"
23#include "tensorflow/core/platform/types.h"
24
25namespace tensorflow {
26namespace boosted_trees {
27namespace quantiles {
28
29// Buffering container ideally suited for scenarios where we need
30// to sort and dedupe/compact fixed chunks of a stream of weighted elements.
31template <typename ValueType, typename WeightType,
32          typename CompareFn = std::less<ValueType>>
33class WeightedQuantilesBuffer {
34 public:
35  struct BufferEntry {
36    BufferEntry(ValueType v, WeightType w)
37        : value(std::move(v)), weight(std::move(w)) {}
38    BufferEntry() : value(), weight(0) {}
39
40    bool operator<(const BufferEntry& other) const {
41      return kCompFn(value, other.value);
42    }
43    bool operator==(const BufferEntry& other) const {
44      return value == other.value && weight == other.weight;
45    }
46    friend std::ostream& operator<<(std::ostream& strm,
47                                    const BufferEntry& entry) {
48      return strm << "{" << entry.value << ", " << entry.weight << "}";
49    }
50    ValueType value;
51    WeightType weight;
52  };
53
54  explicit WeightedQuantilesBuffer(int64 block_size, int64 max_elements)
55      : max_size_(std::min(block_size << 1, max_elements)) {
56    QCHECK(max_size_ > 0) << "Invalid buffer specification: (" << block_size
57                          << ", " << max_elements << ")";
58    vec_.reserve(max_size_);
59  }
60
61  // Disallow copying as it's semantically non-sensical in the Squawd algorithm
62  // but enable move semantics.
63  WeightedQuantilesBuffer(const WeightedQuantilesBuffer& other) = delete;
64  WeightedQuantilesBuffer& operator=(const WeightedQuantilesBuffer&) = delete;
65  WeightedQuantilesBuffer(WeightedQuantilesBuffer&& other) = default;
66  WeightedQuantilesBuffer& operator=(WeightedQuantilesBuffer&& other) = default;
67
68  // Push entry to buffer and maintain a compact representation within
69  // pre-defined size limit.
70  void PushEntry(ValueType value, WeightType weight) {
71    // Callers are expected to act on a full compacted buffer after the
72    // PushEntry call returns.
73    QCHECK(!IsFull()) << "Buffer already full: " << max_size_;
74
75    // Ignore zero and negative weight entries.
76    if (weight <= 0) {
77      return;
78    }
79
80    // Push back the entry to the buffer.
81    vec_.push_back(BufferEntry(std::move(value), std::move(weight)));
82  }
83
84  // Returns a sorted vector view of the base buffer and clears the buffer.
85  // Callers should minimize how often this is called, ideally only right after
86  // the buffer becomes full.
87  std::vector<BufferEntry> GenerateEntryList() {
88    std::vector<BufferEntry> ret;
89    if (vec_.size() == 0) {
90      return ret;
91    }
92    ret.swap(vec_);
93    vec_.reserve(max_size_);
94    std::sort(ret.begin(), ret.end());
95    size_t num_entries = 0;
96    for (size_t i = 1; i < ret.size(); ++i) {
97      if (ret[i].value != ret[i - 1].value) {
98        BufferEntry tmp = ret[i];
99        ++num_entries;
100        ret[num_entries] = tmp;
101      } else {
102        ret[num_entries].weight += ret[i].weight;
103      }
104    }
105    ret.resize(num_entries + 1);
106    return ret;
107  }
108
109  int64 Size() const { return vec_.size(); }
110  bool IsFull() const { return vec_.size() >= max_size_; }
111  void Clear() { vec_.clear(); }
112
113 private:
114  using BufferVector = typename std::vector<BufferEntry>;
115
116  // Comparison function.
117  static constexpr decltype(CompareFn()) kCompFn = CompareFn();
118
119  // Base buffer.
120  size_t max_size_;
121  BufferVector vec_;
122};
123
// Namespace-scope definition of the static constexpr comparator declared in
// WeightedQuantilesBuffer. The class calls kCompFn's operator(), which
// odr-uses it, so before C++17 this definition is required for linking; from
// C++17 on the member is implicitly inline and this redeclaration is
// redundant but still valid.
template <typename ValueType, typename WeightType, typename CompareFn>
constexpr decltype(CompareFn())
    WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>::kCompFn;
127
128}  // namespace quantiles
129}  // namespace boosted_trees
130}  // namespace tensorflow
131
132#endif  // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
133