1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// An abstraction to pick from one of N elements with a specified
17// weight per element.
18//
// The weight for a given element can be changed in O(lg N) time.
// An element can be picked in O(lg N) time.
21//
22// Uses O(N) bytes of memory.
23//
24// Alternative: distribution-sampler.h allows O(1) time picking, but no weight
25// adjustment after construction.
26
27#ifndef TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
28#define TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
29
30#include <assert.h>
31
32#include "tensorflow/core/platform/logging.h"
33#include "tensorflow/core/platform/macros.h"
34#include "tensorflow/core/platform/types.h"
35
36namespace tensorflow {
37namespace random {
38
39class SimplePhilox;
40
41class WeightedPicker {
42 public:
43  // REQUIRES   N >= 0
44  // Initializes the elements with a weight of one per element
45  explicit WeightedPicker(int N);
46
47  // Releases all resources
48  ~WeightedPicker();
49
50  // Pick a random element with probability proportional to its weight.
51  // If total weight is zero, returns -1.
52  int Pick(SimplePhilox* rnd) const;
53
54  // Deterministically pick element x whose weight covers the
55  // specified weight_index.
56  // Returns -1 if weight_index is not in the range [ 0 .. total_weight()-1 ]
57  int PickAt(int32 weight_index) const;
58
59  // Get the weight associated with an element
60  // REQUIRES 0 <= index < N
61  int32 get_weight(int index) const;
62
63  // Set the weight associated with an element
64  // REQUIRES weight >= 0.0f
65  // REQUIRES 0 <= index < N
66  void set_weight(int index, int32 weight);
67
68  // Get the total combined weight of all elements
69  int32 total_weight() const;
70
71  // Get the number of elements in the picker
72  int num_elements() const;
73
74  // Set weight of each element to "weight"
75  void SetAllWeights(int32 weight);
76
77  // Resizes the picker to N and
78  // sets the weight of each element i to weight[i].
79  // The sum of the weights should not exceed 2^31 - 2
80  // Complexity O(N).
81  void SetWeightsFromArray(int N, const int32* weights);
82
83  // REQUIRES   N >= 0
84  //
85  // Resize the weighted picker so that it has "N" elements.
86  // Any newly added entries have zero weight.
87  //
88  // Note: Resizing to a smaller size than num_elements() will
89  // not reclaim any memory.  If you wish to reduce memory usage,
90  // allocate a new WeightedPicker of the appropriate size.
91  //
92  // It is efficient to use repeated calls to Resize(num_elements() + 1)
93  // to grow the picker to size X (takes total time O(X)).
94  void Resize(int N);
95
96  // Grow the picker by one and set the weight of the new entry to "weight".
97  //
98  // Repeated calls to Append() in order to grow the
99  // picker to size X takes a total time of O(X lg(X)).
100  // Consider using SetWeightsFromArray instead.
101  void Append(int32 weight);
102
103 private:
104  // We keep a binary tree with N leaves.  The "i"th leaf contains
105  // the weight of the "i"th element.  An internal node contains
106  // the sum of the weights of its children.
107  int N_;           // Number of elements
108  int num_levels_;  // Number of levels in tree (level-0 is root)
109  int32** level_;   // Array that holds nodes per level
110
111  // Size of each level
112  static int LevelSize(int level) { return 1 << level; }
113
114  // Rebuild the tree weights using the leaf weights
115  void RebuildTreeWeights();
116
117  TF_DISALLOW_COPY_AND_ASSIGN(WeightedPicker);
118};
119
120inline int32 WeightedPicker::get_weight(int index) const {
121  DCHECK_GE(index, 0);
122  DCHECK_LT(index, N_);
123  return level_[num_levels_ - 1][index];
124}
125
126inline int32 WeightedPicker::total_weight() const { return level_[0][0]; }
127
128inline int WeightedPicker::num_elements() const { return N_; }
129
130}  // namespace random
131}  // namespace tensorflow
132
133#endif  // TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
134