1c8b59c046895fa5b6d79f73e0b5817330fcfbfc1A. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurLicensed under the Apache License, Version 2.0 (the "License");
49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudluryou may not use this file except in compliance with the License.
59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurYou may obtain a copy of the License at
69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur    http://www.apache.org/licenses/LICENSE-2.0
89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurUnless required by applicable law or agreed to in writing, software
109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurdistributed under the License is distributed on an "AS IS" BASIS,
119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurSee the License for the specific language governing permissions and
139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurlimitations under the License.
149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur==============================================================================*/
15f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
// An abstraction to pick from one of N elements with a specified
// weight per element.
//
// The weight for a given element can be changed in O(lg N) time.
// An element can be picked in O(lg N) time.
//
// Uses O(N) bytes of memory.
//
// Alternative: distribution-sampler.h allows O(1) time picking, but no weight
// adjustment after construction.
26f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
27f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#ifndef TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
28f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
29f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
30f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <assert.h>
31f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
32f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/platform/logging.h"
33252bb90ca2dd412ca2fd2908faf1a25d6ef618cfJosh Levenberg#include "tensorflow/core/platform/macros.h"
345a24d3a2514698b0ae11563b2ea21e368de48a4fJosh Levenberg#include "tensorflow/core/platform/types.h"
35f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace tensorflow {
37f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace random {
38f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
39f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass SimplePhilox;
40f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
41f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass WeightedPicker {
42f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
43f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // REQUIRES   N >= 0
44f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Initializes the elements with a weight of one per element
45f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  explicit WeightedPicker(int N);
46f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Releases all resources
48f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  ~WeightedPicker();
49f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
50f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Pick a random element with probability proportional to its weight.
51f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // If total weight is zero, returns -1.
52f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int Pick(SimplePhilox* rnd) const;
53f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
54f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Deterministically pick element x whose weight covers the
55f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // specified weight_index.
56f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Returns -1 if weight_index is not in the range [ 0 .. total_weight()-1 ]
57f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int PickAt(int32 weight_index) const;
58f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
59f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Get the weight associated with an element
60f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // REQUIRES 0 <= index < N
61f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int32 get_weight(int index) const;
62f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
63f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Set the weight associated with an element
64f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // REQUIRES weight >= 0.0f
65f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // REQUIRES 0 <= index < N
66f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void set_weight(int index, int32 weight);
67f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
68f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Get the total combined weight of all elements
69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int32 total_weight() const;
70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Get the number of elements in the picker
72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int num_elements() const;
73f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
74f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Set weight of each element to "weight"
75f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void SetAllWeights(int32 weight);
76f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
77f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Resizes the picker to N and
78f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // sets the weight of each element i to weight[i].
79f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // The sum of the weights should not exceed 2^31 - 2
80f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Complexity O(N).
81f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void SetWeightsFromArray(int N, const int32* weights);
82f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
83f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // REQUIRES   N >= 0
84f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
85f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Resize the weighted picker so that it has "N" elements.
86f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Any newly added entries have zero weight.
87f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
88f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Note: Resizing to a smaller size than num_elements() will
89f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // not reclaim any memory.  If you wish to reduce memory usage,
90f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // allocate a new WeightedPicker of the appropriate size.
91f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
92f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // It is efficient to use repeated calls to Resize(num_elements() + 1)
93f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // to grow the picker to size X (takes total time O(X)).
94f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void Resize(int N);
95f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
96f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Grow the picker by one and set the weight of the new entry to "weight".
97f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
98f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Repeated calls to Append() in order to grow the
99f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // picker to size X takes a total time of O(X lg(X)).
100f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Consider using SetWeightsFromArray instead.
101f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void Append(int32 weight);
102f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
103f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private:
104f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // We keep a binary tree with N leaves.  The "i"th leaf contains
105f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // the weight of the "i"th element.  An internal node contains
106f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // the sum of the weights of its children.
107f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int N_;           // Number of elements
108f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int num_levels_;  // Number of levels in tree (level-0 is root)
109f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int32** level_;   // Array that holds nodes per level
110f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
111f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Size of each level
112f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  static int LevelSize(int level) { return 1 << level; }
113f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
114f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Rebuild the tree weights using the leaf weights
115f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void RebuildTreeWeights();
116f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
117f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  TF_DISALLOW_COPY_AND_ASSIGN(WeightedPicker);
118f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
119f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
120f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurinline int32 WeightedPicker::get_weight(int index) const {
121f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  DCHECK_GE(index, 0);
122f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  DCHECK_LT(index, N_);
123f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  return level_[num_levels_ - 1][index];
124f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}
125f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
126f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurinline int32 WeightedPicker::total_weight() const { return level_[0][0]; }
127f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
128f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurinline int WeightedPicker::num_elements() const { return N_; }
129f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
130f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // namespace random
131f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // namespace tensorflow
132f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
133f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#endif  // TENSORFLOW_LIB_RANDOM_WEIGHTED_PICKER_H_
134