/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// An abstraction to pick from one of N elements with a specified
// weight per element.
18f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// 19f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// The weight for a given element can be changed in O(lg N) time 20f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// An element can be picked in O(lg N) time. 21f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// 22f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Uses O(N) bytes of memory. 23f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// 24f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Alternative: distribution-sampler.h allows O(1) time picking, but no weight 25f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// adjustment after construction. 26f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 27f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#ifndef TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_ 28f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_ 29f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 30f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <assert.h> 31f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 32f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/platform/logging.h" 33252bb90ca2dd412ca2fd2908faf1a25d6ef618cfJosh Levenberg#include "tensorflow/core/platform/macros.h" 345a24d3a2514698b0ae11563b2ea21e368de48a4fJosh Levenberg#include "tensorflow/core/platform/types.h" 35f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace tensorflow { 37f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace random { 38f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 39f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass SimplePhilox; 40f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 41f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass WeightedPicker { 
42f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public: 43f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // REQUIRES N >= 0 44f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Initializes the elements with a weight of one per element 45f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur explicit WeightedPicker(int N); 46f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Releases all resources 48f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur ~WeightedPicker(); 49f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 50f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Pick a random element with probability proportional to its weight. 51f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // If total weight is zero, returns -1. 52f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int Pick(SimplePhilox* rnd) const; 53f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 54f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Deterministically pick element x whose weight covers the 55f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // specified weight_index. 56f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Returns -1 if weight_index is not in the range [ 0 .. 
total_weight()-1 ] 57f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int PickAt(int32 weight_index) const; 58f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 59f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Get the weight associated with an element 60f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // REQUIRES 0 <= index < N 61f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int32 get_weight(int index) const; 62f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 63f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Set the weight associated with an element 64f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // REQUIRES weight >= 0.0f 65f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // REQUIRES 0 <= index < N 66f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void set_weight(int index, int32 weight); 67f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 68f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Get the total combined weight of all elements 69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int32 total_weight() const; 70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Get the number of elements in the picker 72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int num_elements() const; 73f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 74f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Set weight of each element to "weight" 75f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void SetAllWeights(int32 weight); 76f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 77f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Resizes the picker to N and 78f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // sets the weight of each element i to weight[i]. 
79f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // The sum of the weights should not exceed 2^31 - 2 80f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Complexity O(N). 81f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void SetWeightsFromArray(int N, const int32* weights); 82f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 83f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // REQUIRES N >= 0 84f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 85f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Resize the weighted picker so that it has "N" elements. 86f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Any newly added entries have zero weight. 87f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 88f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Note: Resizing to a smaller size than num_elements() will 89f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // not reclaim any memory. If you wish to reduce memory usage, 90f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // allocate a new WeightedPicker of the appropriate size. 91f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 92f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // It is efficient to use repeated calls to Resize(num_elements() + 1) 93f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // to grow the picker to size X (takes total time O(X)). 94f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void Resize(int N); 95f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 96f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Grow the picker by one and set the weight of the new entry to "weight". 97f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 98f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Repeated calls to Append() in order to grow the 99f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // picker to size X takes a total time of O(X lg(X)). 
100f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Consider using SetWeightsFromArray instead. 101f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void Append(int32 weight); 102f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 103f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private: 104f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // We keep a binary tree with N leaves. The "i"th leaf contains 105f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // the weight of the "i"th element. An internal node contains 106f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // the sum of the weights of its children. 107f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int N_; // Number of elements 108f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int num_levels_; // Number of levels in tree (level-0 is root) 109f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int32** level_; // Array that holds nodes per level 110f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 111f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Size of each level 112f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur static int LevelSize(int level) { return 1 << level; } 113f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 114f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Rebuild the tree weights using the leaf weights 115f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void RebuildTreeWeights(); 116f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 117f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur TF_DISALLOW_COPY_AND_ASSIGN(WeightedPicker); 118f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 119f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 120f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurinline int32 WeightedPicker::get_weight(int index) const { 121f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DCHECK_GE(index, 0); 
122f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur DCHECK_LT(index, N_); 123f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur return level_[num_levels_ - 1][index]; 124f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} 125f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 126f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurinline int32 WeightedPicker::total_weight() const { return level_[0][0]; } 127f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 128f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurinline int WeightedPicker::num_elements() const { return N_; } 129f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 130f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // namespace random 131f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // namespace tensorflow 132f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 133f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#endif // TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_ 134