// weighted_picker.h revision c8b59c046895fa5b6d79f73e0b5817330fcfbfc1
1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16// An abstraction to pick from one of N elements with a specified 17// weight per element. 18// 19// The weight for a given element can be changed in O(lg N) time 20// An element can be picked in O(lg N) time. 21// 22// Uses O(N) bytes of memory. 23// 24// Alternative: distribution-sampler.h allows O(1) time picking, but no weight 25// adjustment after construction. 26 27#ifndef TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_ 28#define TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_ 29 30#include <assert.h> 31 32#include "tensorflow/core/platform/logging.h" 33#include "tensorflow/core/platform/macros.h" 34#include "tensorflow/core/platform/types.h" 35 36namespace tensorflow { 37namespace random { 38 39class SimplePhilox; 40 41class WeightedPicker { 42 public: 43 // REQUIRES N >= 0 44 // Initializes the elements with a weight of one per element 45 explicit WeightedPicker(int N); 46 47 // Releases all resources 48 ~WeightedPicker(); 49 50 // Pick a random element with probability proportional to its weight. 51 // If total weight is zero, returns -1. 52 int Pick(SimplePhilox* rnd) const; 53 54 // Deterministically pick element x whose weight covers the 55 // specified weight_index. 56 // Returns -1 if weight_index is not in the range [ 0 .. 
total_weight()-1 ] 57 int PickAt(int32 weight_index) const; 58 59 // Get the weight associated with an element 60 // REQUIRES 0 <= index < N 61 int32 get_weight(int index) const; 62 63 // Set the weight associated with an element 64 // REQUIRES weight >= 0.0f 65 // REQUIRES 0 <= index < N 66 void set_weight(int index, int32 weight); 67 68 // Get the total combined weight of all elements 69 int32 total_weight() const; 70 71 // Get the number of elements in the picker 72 int num_elements() const; 73 74 // Set weight of each element to "weight" 75 void SetAllWeights(int32 weight); 76 77 // Resizes the picker to N and 78 // sets the weight of each element i to weight[i]. 79 // The sum of the weights should not exceed 2^31 - 2 80 // Complexity O(N). 81 void SetWeightsFromArray(int N, const int32* weights); 82 83 // REQUIRES N >= 0 84 // 85 // Resize the weighted picker so that it has "N" elements. 86 // Any newly added entries have zero weight. 87 // 88 // Note: Resizing to a smaller size than num_elements() will 89 // not reclaim any memory. If you wish to reduce memory usage, 90 // allocate a new WeightedPicker of the appropriate size. 91 // 92 // It is efficient to use repeated calls to Resize(num_elements() + 1) 93 // to grow the picker to size X (takes total time O(X)). 94 void Resize(int N); 95 96 // Grow the picker by one and set the weight of the new entry to "weight". 97 // 98 // Repeated calls to Append() in order to grow the 99 // picker to size X takes a total time of O(X lg(X)). 100 // Consider using SetWeightsFromArray instead. 101 void Append(int32 weight); 102 103 private: 104 // We keep a binary tree with N leaves. The "i"th leaf contains 105 // the weight of the "i"th element. An internal node contains 106 // the sum of the weights of its children. 
107 int N_; // Number of elements 108 int num_levels_; // Number of levels in tree (level-0 is root) 109 int32** level_; // Array that holds nodes per level 110 111 // Size of each level 112 static int LevelSize(int level) { return 1 << level; } 113 114 // Rebuild the tree weights using the leaf weights 115 void RebuildTreeWeights(); 116 117 TF_DISALLOW_COPY_AND_ASSIGN(WeightedPicker); 118}; 119 120inline int32 WeightedPicker::get_weight(int index) const { 121 DCHECK_GE(index, 0); 122 DCHECK_LT(index, N_); 123 return level_[num_levels_ - 1][index]; 124} 125 126inline int32 WeightedPicker::total_weight() const { return level_[0][0]; } 127 128inline int WeightedPicker::num_elements() const { return N_; } 129 130} // namespace random 131} // namespace tensorflow 132 133#endif // TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_ 134