1150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower//
3150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// Licensed under the Apache License, Version 2.0 (the "License");
4150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// you may not use this file except in compliance with the License.
5150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// You may obtain a copy of the License at
6150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower//
7150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower//     http://www.apache.org/licenses/LICENSE-2.0
8150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower//
9150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// Unless required by applicable law or agreed to in writing, software
10150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// distributed under the License is distributed on an "AS IS" BASIS,
11150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// See the License for the specific language governing permissions and
13150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// limitations under the License.
14150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// =============================================================================
15f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
16f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
17150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower
#include <algorithm>
#include <functional>
#include <ostream>
#include <unordered_map>
#include <utility>
#include <vector>
21150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower
22150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower#include "tensorflow/core/platform/logging.h"
23150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower#include "tensorflow/core/platform/types.h"
24150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower
25150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlowernamespace tensorflow {
26150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlowernamespace boosted_trees {
27150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlowernamespace quantiles {
28150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower
29150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// Buffering container ideally suited for scenarios where we need
30150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower// to sort and dedupe/compact fixed chunks of a stream of weighted elements.
31150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlowertemplate <typename ValueType, typename WeightType,
32150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower          typename CompareFn = std::less<ValueType>>
33150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlowerclass WeightedQuantilesBuffer {
34150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower public:
35150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower  struct BufferEntry {
360dbb1ad1d53976050180fc2e2289d768e78e300fA. Unique TensorFlower    BufferEntry(ValueType v, WeightType w)
370dbb1ad1d53976050180fc2e2289d768e78e300fA. Unique TensorFlower        : value(std::move(v)), weight(std::move(w)) {}
380dbb1ad1d53976050180fc2e2289d768e78e300fA. Unique TensorFlower    BufferEntry() : value(), weight(0) {}
39150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower
40150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower    bool operator<(const BufferEntry& other) const {
41150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower      return kCompFn(value, other.value);
42150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower    }
43150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower    bool operator==(const BufferEntry& other) const {
44150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower      return value == other.value && weight == other.weight;
45150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower    }
46150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower    friend std::ostream& operator<<(std::ostream& strm,
47150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower                                    const BufferEntry& entry) {
48150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower      return strm << "{" << entry.value << ", " << entry.weight << "}";
49150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower    }
50150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower    ValueType value;
51150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower    WeightType weight;
52150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower  };
53150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower
54150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower  explicit WeightedQuantilesBuffer(int64 block_size, int64 max_elements)
55150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower      : max_size_(std::min(block_size << 1, max_elements)) {
56150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower    QCHECK(max_size_ > 0) << "Invalid buffer specification: (" << block_size
57150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower                          << ", " << max_elements << ")";
583e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower    vec_.reserve(max_size_);
59150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower  }
60150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower
61150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower  // Disallow copying as it's semantically non-sensical in the Squawd algorithm
62150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower  // but enable move semantics.
63150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower  WeightedQuantilesBuffer(const WeightedQuantilesBuffer& other) = delete;
64150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower  WeightedQuantilesBuffer& operator=(const WeightedQuantilesBuffer&) = delete;
65150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower  WeightedQuantilesBuffer(WeightedQuantilesBuffer&& other) = default;
66150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower  WeightedQuantilesBuffer& operator=(WeightedQuantilesBuffer&& other) = default;
67150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower
68150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower  // Push entry to buffer and maintain a compact representation within
69150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower  // pre-defined size limit.
700dbb1ad1d53976050180fc2e2289d768e78e300fA. Unique TensorFlower  void PushEntry(ValueType value, WeightType weight) {
71150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower    // Callers are expected to act on a full compacted buffer after the
72150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower    // PushEntry call returns.
73150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower    QCHECK(!IsFull()) << "Buffer already full: " << max_size_;
74150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower
75150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower    // Ignore zero and negative weight entries.
76150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower    if (weight <= 0) {
77150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower      return;
78150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower    }
79150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower
803e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower    // Push back the entry to the buffer.
810dbb1ad1d53976050180fc2e2289d768e78e300fA. Unique TensorFlower    vec_.push_back(BufferEntry(std::move(value), std::move(weight)));
82150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower  }
83150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower
843e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower  // Returns a sorted vector view of the base buffer and clears the buffer.
853e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower  // Callers should minimize how often this is called, ideally only right after
863e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower  // the buffer becomes full.
873e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower  std::vector<BufferEntry> GenerateEntryList() {
88150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower    std::vector<BufferEntry> ret;
893e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower    if (vec_.size() == 0) {
903e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower      return ret;
913e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower    }
923e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower    ret.swap(vec_);
933e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower    vec_.reserve(max_size_);
94150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower    std::sort(ret.begin(), ret.end());
953e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower    size_t num_entries = 0;
963e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower    for (size_t i = 1; i < ret.size(); ++i) {
973e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower      if (ret[i].value != ret[i - 1].value) {
983e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower        BufferEntry tmp = ret[i];
993e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower        ++num_entries;
1003e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower        ret[num_entries] = tmp;
1013e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower      } else {
1023e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower        ret[num_entries].weight += ret[i].weight;
1033e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower      }
1043e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower    }
1053e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower    ret.resize(num_entries + 1);
106150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower    return ret;
107150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower  }
108150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower
1093e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower  int64 Size() const { return vec_.size(); }
1103e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower  bool IsFull() const { return vec_.size() >= max_size_; }
1113e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower  void Clear() { vec_.clear(); }
112150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower
113150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower private:
1143e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower  using BufferVector = typename std::vector<BufferEntry>;
115150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower
116150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower  // Comparison function.
117150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower  static constexpr decltype(CompareFn()) kCompFn = CompareFn();
118150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower
119150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower  // Base buffer.
120150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower  size_t max_size_;
1213e8f868135b67c7e532003b22ebb7c90f67ca3d9A. Unique TensorFlower  BufferVector vec_;
122150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower};
123150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower
// Namespace-scope definition of the in-class static constexpr member kCompFn.
// Before C++17, static constexpr data members are not implicitly inline, so a
// definition is required whenever the member is odr-used. Here it is odr-used
// because its operator() is invoked in BufferEntry::operator<. From C++17 on,
// this redeclaration is redundant but still permitted.
template <typename ValueType, typename WeightType, typename CompareFn>
constexpr decltype(CompareFn())
    WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>::kCompFn;
127150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower
128150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower}  // namespace quantiles
129150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower}  // namespace boosted_trees
130150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower}  // namespace tensorflow
131150a0aec64acc75d710fda67f802acebfe496cb8A. Unique TensorFlower
132f8347ceebbad0e06552633fcdf8e63f52246ba62Sanjoy Das#endif  // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
133