range_sampler.h revision 3ede5506acf6a026f09eda33277d46e34ac7ed10
19c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur/* Copyright 2015 Google Inc. All Rights Reserved.
29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurLicensed under the Apache License, Version 2.0 (the "License");
49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudluryou may not use this file except in compliance with the License.
59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurYou may obtain a copy of the License at
69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur    http://www.apache.org/licenses/LICENSE-2.0
89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurUnless required by applicable law or agreed to in writing, software
109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurdistributed under the License is distributed on an "AS IS" BASIS,
119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurSee the License for the specific language governing permissions and
139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurlimitations under the License.
149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur==============================================================================*/
159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
16f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#ifndef TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
17f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
18f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
#include <memory>
#include <vector>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/random/distribution_sampler.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/lib/random/weighted_picker.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
30f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
31f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace tensorflow {
32f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
33f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass Env;
34f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
35f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Abstract subclass for sampling from the set of non-negative integers
36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// [0, range)
37f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass RangeSampler {
38f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
39f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  explicit RangeSampler(int range) : range_(range) { CHECK_GT(range_, 0); }
40f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual ~RangeSampler();
41f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
42f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Sample a single value
43f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual int64 Sample(random::SimplePhilox* rnd) const = 0;
44f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
45f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // The probability that a single call to Sample() returns the given value.
46f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Assumes that value is in [0, range).  No range checking is done.
47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual float Probability(int64 value) const = 0;
48f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
49f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Fill "batch" with samples from the distribution.
50f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // If unique=true, then we re-pick each element until we get a
51f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // value distinct from all previously picked values in the batch.
52f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void SampleBatch(random::SimplePhilox* rnd, bool unique,
53f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                   gtl::MutableArraySlice<int64> batch) const;
54f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
55f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Fill "batch" with samples from the distribution, and report
56f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // "expected counts".
57f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
58f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // The "expected count" of a value is an estimate of the expected
59f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // number of occurrences of the value in the batch returned by a
60f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // call to this function with the given parameters.  If unique=true,
61f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // the expected count is an inclusion probability.  For details on
62f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // this estimation, see the comment to "ExpectedCountHelper" in the
63f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // .cc file.
64f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
65f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Expected counts for the elements of the returned "batch" are reported
66f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // in the aligned array "batch_expected_count".
67f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
68f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // The user can optionally provide "extras", containg values in the range.
69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // The expected counts for the extras are reported in the aligned array
70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // "extras_expected_count".
71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // "batch_expected_count" must have size equal to 0 or to the size of "batch".
73f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // "extras" and "extras_expected_count" must have equal size.
74f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void SampleBatchGetExpectedCount(
75f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      random::SimplePhilox* rnd, bool unique,
76f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<int64> batch,
77f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<float> batch_expected_count,
78f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::ArraySlice<int64> extras,
79f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<float> extras_expected_count) const;
80f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
81f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Same as SampleBatchGetExpectedCount (see above), but with avoided values.
82f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // We repick to avoid all of the values in "avoided_values".
83f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // "avoided_values" is only supported with unique=true.  If
84f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // unique=false, then avoided_values must be empty.
85f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual void SampleBatchGetExpectedCountAvoid(
86f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      random::SimplePhilox* rnd, bool unique,
87f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<int64> batch,
88f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<float> batch_expected_count,
89f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::ArraySlice<int64> extras,
90f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<float> extras_expected_count,
91f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::ArraySlice<int64> avoided_values) const;
92f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
93f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Does this sampler need to be updated with values, e.g. UnigramSampler
94f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual bool NeedsUpdates() const { return false; }
95f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
96f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Updates the underlying distribution
97f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual void Update(gtl::ArraySlice<int64> values) {
98f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    LOG(FATAL) << "Update not supported for this sampler type.";
99f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
100f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
101f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int64 range() { return range_; }
102f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
103f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur protected:
104f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  const int64 range_;
105f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
106f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
107f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// An AllSampler only samples batches of size equal to range.
108f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// It returns the entire range.
109f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// It cannot sample single values.
110f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass AllSampler : public RangeSampler {
111f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
112f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  explicit AllSampler(int64 range);
113f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
114f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  ~AllSampler() override {}
115f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
116f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int64 Sample(random::SimplePhilox* rnd) const override {
117f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    LOG(FATAL) << "Should not be called";
118f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
119f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
120f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  float Probability(int64 value) const override {
121f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    LOG(FATAL) << "Should not be called";
122f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
123f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
124f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void SampleBatchGetExpectedCountAvoid(
125f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      random::SimplePhilox* rnd, bool unique,
126f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<int64> batch,
127f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<float> batch_expected_count,
128f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::ArraySlice<int64> extras,
129f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<float> extras_expected_count,
130f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::ArraySlice<int64> avoided_values) const override;
131f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
132f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private:
133f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  const float inv_range_;
134f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
135f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
136f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass UniformSampler : public RangeSampler {
137f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
138f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  explicit UniformSampler(int64 range);
139f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
140f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  ~UniformSampler() override {}
141f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
142f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int64 Sample(random::SimplePhilox* rnd) const override;
143f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
144f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  float Probability(int64 value) const override;
145f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
146f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private:
147f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  const float inv_range_;
148f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
149f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
150f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass LogUniformSampler : public RangeSampler {
151f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
152f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  explicit LogUniformSampler(int64 range);
153f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
154f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  ~LogUniformSampler() override {}
155f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
156f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int64 Sample(random::SimplePhilox* rnd) const override;
157f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
158f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  float Probability(int64 value) const override;
159f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
160f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private:
161f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  const double log_range_;
162f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
163f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
164f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Thread-unsafe unigram sampler
165f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass ThreadUnsafeUnigramSampler : public RangeSampler {
166f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
167f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  explicit ThreadUnsafeUnigramSampler(int64 range);
168f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  ~ThreadUnsafeUnigramSampler() override {}
169f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
170f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int64 Sample(random::SimplePhilox* rnd) const override;
171f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
172f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  float Probability(int64 value) const override;
173f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
174f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  bool NeedsUpdates() const override { return true; }
175f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void Update(gtl::ArraySlice<int64> values) override;
176f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
177f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private:
178f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  random::WeightedPicker picker_;
179f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
180f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
181f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Thread-safe unigram sampler
182f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass UnigramSampler : public RangeSampler {
183f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
184f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  explicit UnigramSampler(int64 range);
185f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  ~UnigramSampler() override {}
186f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
187f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int64 Sample(random::SimplePhilox* rnd) const override;
188f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
189f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  float Probability(int64 value) const override;
190f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
191f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Overriding at a high level results in far fewer lock aquisitions.
192f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void SampleBatchGetExpectedCountAvoid(
193f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      random::SimplePhilox* rnd, bool unique,
194f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<int64> batch,
195f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<float> batch_expected_count,
196f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::ArraySlice<int64> extras,
197f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<float> extras_expected_count,
198f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::ArraySlice<int64> avoided_values) const override;
199f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
200f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  bool NeedsUpdates() const override { return true; }
201f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void Update(gtl::ArraySlice<int64> values) override;
202f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
203f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private:
204f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  ThreadUnsafeUnigramSampler unsafe_sampler_ GUARDED_BY(mu_);
205f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  mutable mutex mu_;
206f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
207f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
208f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// A unigram sampler that uses a fixed unigram distribution read from a
209f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// file or passed in as an in-memory array instead of building up the
210f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// distribution from data on the fly. There is also an option to skew the
211f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// distribution by applying a distortion power to the weights.
212f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass FixedUnigramSampler : public RangeSampler {
213f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
214f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // The vocab_file is assumed to be a CSV, with the last entry of each row a
215f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // value representing the counts or probabilities for the corresponding ID.
216f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  FixedUnigramSampler(Env* env, int64 range, const string& vocab_file,
217f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                      float distortion, int32 num_reserved_ids,
218f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                      int32 num_shards, int32 shard);
219f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
220f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  FixedUnigramSampler(int64 range, const std::vector<float>& unigrams,
221f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                      float distortion, int32 num_reserved_ids,
222f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                      int32 num_shards, int32 shard);
223f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
224f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  float Probability(int64 value) const override;
225f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
226f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int64 Sample(random::SimplePhilox* rnd) const override;
227f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
228f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private:
229f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Underlying distribution sampler.
230f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  std::unique_ptr<random::DistributionSampler> dist_sampler_;
231f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Weights for individual samples. The probability of a sample i is defined
232f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // as weights_.at(i) / total_weight_.
233f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  std::vector<float> weights_;
234f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // The total weights of all samples.
235f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  float total_weight_;
236f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Sharding information of the sampler. The whole vocabulary is sharded
237f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // into num_shards_ smaller ranges and each sampler is responsible for one
238f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // such smaller range, identified by the shard number.
239f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int32 num_shards_;
240f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int32 shard_;
241f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
242f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Fill the sampler with the appropriate number of reserved IDs.
243f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void FillReservedIds(int32 num_reserved_ids);
244f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Load IDs to sample from a CSV file. It is assumed that the last item of
245f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // each row contains a count or probability for the corresponding ID.
246f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  Status LoadFromFile(Env* env, const string& vocab_file, float distortion);
247f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Load from an in-memory array.
248f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void LoadFromUnigrams(const std::vector<float>& unigrams, float distortion);
249f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
250f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
251f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // namespace tensorflow
252f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
253f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#endif  // TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
254