1c8b59c046895fa5b6d79f73e0b5817330fcfbfc1A. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurLicensed under the Apache License, Version 2.0 (the "License");
49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudluryou may not use this file except in compliance with the License.
59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurYou may obtain a copy of the License at
69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur    http://www.apache.org/licenses/LICENSE-2.0
89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurUnless required by applicable law or agreed to in writing, software
109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurdistributed under the License is distributed on an "AS IS" BASIS,
119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurSee the License for the specific language governing permissions and
139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurlimitations under the License.
149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur==============================================================================*/
159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur
16f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#ifndef TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
17f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
18f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
19f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <vector>
20f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
213ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/lib/core/status.h"
22f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/gtl/array_slice.h"
23f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/random/distribution_sampler.h"
24f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/random/random_distributions.h"
25f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/random/weighted_picker.h"
26f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/platform/logging.h"
27af9b56881008f8b804cf93c0d99cb96357b19748Josh Levenberg#include "tensorflow/core/platform/mutex.h"
28f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/platform/thread_annotations.h"
293ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/platform/types.h"
30f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
31f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace tensorflow {
32f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
33f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass Env;
34f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
35f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Abstract subclass for sampling from the set of non-negative integers
36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// [0, range)
37f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass RangeSampler {
38f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
3920a8d1ff2c320dfa46df2bab05334b080aff2444David G. Andersen  explicit RangeSampler(int64 range) : range_(range) { CHECK_GT(range_, 0); }
40f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual ~RangeSampler();
41f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
42f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Sample a single value
43f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual int64 Sample(random::SimplePhilox* rnd) const = 0;
44f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
45f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // The probability that a single call to Sample() returns the given value.
46f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Assumes that value is in [0, range).  No range checking is done.
47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual float Probability(int64 value) const = 0;
48f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
49f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Fill "batch" with samples from the distribution.
50f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // If unique=true, then we re-pick each element until we get a
51f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // value distinct from all previously picked values in the batch.
52f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void SampleBatch(random::SimplePhilox* rnd, bool unique,
53f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                   gtl::MutableArraySlice<int64> batch) const;
54f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
55f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Fill "batch" with samples from the distribution, and report
56f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // "expected counts".
57f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
58f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // The "expected count" of a value is an estimate of the expected
59f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // number of occurrences of the value in the batch returned by a
60f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // call to this function with the given parameters.  If unique=true,
61f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // the expected count is an inclusion probability.  For details on
62f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // this estimation, see the comment to "ExpectedCountHelper" in the
63f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // .cc file.
64f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
65f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Expected counts for the elements of the returned "batch" are reported
66f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // in the aligned array "batch_expected_count".
67f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
68fe056f0b5e52db86766761f5e6446a89c1aa3938Vijay Vasudevan  // The user can optionally provide "extras", containing values in the range.
69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // The expected counts for the extras are reported in the aligned array
70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // "extras_expected_count".
71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  //
72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // "batch_expected_count" must have size equal to 0 or to the size of "batch".
73f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // "extras" and "extras_expected_count" must have equal size.
74f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void SampleBatchGetExpectedCount(
75f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      random::SimplePhilox* rnd, bool unique,
76f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<int64> batch,
77f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<float> batch_expected_count,
78f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::ArraySlice<int64> extras,
79f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<float> extras_expected_count) const;
80f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
81f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Same as SampleBatchGetExpectedCount (see above), but with avoided values.
82f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // We repick to avoid all of the values in "avoided_values".
83f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // "avoided_values" is only supported with unique=true.  If
84f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // unique=false, then avoided_values must be empty.
85f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual void SampleBatchGetExpectedCountAvoid(
86f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      random::SimplePhilox* rnd, bool unique,
87f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<int64> batch,
88f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<float> batch_expected_count,
89f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::ArraySlice<int64> extras,
90f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<float> extras_expected_count,
91f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::ArraySlice<int64> avoided_values) const;
92f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
93f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Does this sampler need to be updated with values, e.g. UnigramSampler
94f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual bool NeedsUpdates() const { return false; }
95f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
96f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Updates the underlying distribution
97f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  virtual void Update(gtl::ArraySlice<int64> values) {
98f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    LOG(FATAL) << "Update not supported for this sampler type.";
99f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
100f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
101f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int64 range() { return range_; }
102f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
103f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur protected:
104f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  const int64 range_;
105f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
106f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
107f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// An AllSampler only samples batches of size equal to range.
108f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// It returns the entire range.
109f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// It cannot sample single values.
110f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass AllSampler : public RangeSampler {
111f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
112f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  explicit AllSampler(int64 range);
113f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
114f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  ~AllSampler() override {}
115f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
116f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int64 Sample(random::SimplePhilox* rnd) const override {
117f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    LOG(FATAL) << "Should not be called";
1187e8050ce34fa639af93af8ec239f826e8dd68077guschmue    return 0;
119f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
120f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
121f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  float Probability(int64 value) const override {
122f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur    LOG(FATAL) << "Should not be called";
1237e8050ce34fa639af93af8ec239f826e8dd68077guschmue    return 0;
124f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  }
125f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
126f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void SampleBatchGetExpectedCountAvoid(
127f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      random::SimplePhilox* rnd, bool unique,
128f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<int64> batch,
129f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<float> batch_expected_count,
130f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::ArraySlice<int64> extras,
131f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<float> extras_expected_count,
132f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::ArraySlice<int64> avoided_values) const override;
133f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
134f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
135f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass UniformSampler : public RangeSampler {
136f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
137f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  explicit UniformSampler(int64 range);
138f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
139f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  ~UniformSampler() override {}
140f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
141f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int64 Sample(random::SimplePhilox* rnd) const override;
142f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
143f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  float Probability(int64 value) const override;
144f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
145f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private:
146f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  const float inv_range_;
147f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
148f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
149f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass LogUniformSampler : public RangeSampler {
150f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
151f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  explicit LogUniformSampler(int64 range);
152f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
153f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  ~LogUniformSampler() override {}
154f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
155f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int64 Sample(random::SimplePhilox* rnd) const override;
156f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
157f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  float Probability(int64 value) const override;
158f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
159f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private:
160f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  const double log_range_;
161f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
162f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
163f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Thread-unsafe unigram sampler
164f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass ThreadUnsafeUnigramSampler : public RangeSampler {
165f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
166f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  explicit ThreadUnsafeUnigramSampler(int64 range);
167f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  ~ThreadUnsafeUnigramSampler() override {}
168f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
169f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int64 Sample(random::SimplePhilox* rnd) const override;
170f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
171f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  float Probability(int64 value) const override;
172f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
173f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  bool NeedsUpdates() const override { return true; }
174f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void Update(gtl::ArraySlice<int64> values) override;
175f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
176f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private:
177f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  random::WeightedPicker picker_;
178f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
179f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
180f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Thread-safe unigram sampler
181f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass UnigramSampler : public RangeSampler {
182f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
183f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  explicit UnigramSampler(int64 range);
184f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  ~UnigramSampler() override {}
185f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
186f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int64 Sample(random::SimplePhilox* rnd) const override;
187f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
188f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  float Probability(int64 value) const override;
189f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
19059f1eba5fb94506a205fa2e81145667754739da5Martin Wicke  // Overriding at a high level results in far fewer lock acquisitions.
191f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void SampleBatchGetExpectedCountAvoid(
192f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      random::SimplePhilox* rnd, bool unique,
193f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<int64> batch,
194f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<float> batch_expected_count,
195f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::ArraySlice<int64> extras,
196f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::MutableArraySlice<float> extras_expected_count,
197f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur      gtl::ArraySlice<int64> avoided_values) const override;
198f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
199f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  bool NeedsUpdates() const override { return true; }
200f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void Update(gtl::ArraySlice<int64> values) override;
201f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
202f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private:
203f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  ThreadUnsafeUnigramSampler unsafe_sampler_ GUARDED_BY(mu_);
204f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  mutable mutex mu_;
205f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
206f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
207f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// A unigram sampler that uses a fixed unigram distribution read from a
208f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// file or passed in as an in-memory array instead of building up the
209f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// distribution from data on the fly. There is also an option to skew the
210f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// distribution by applying a distortion power to the weights.
211f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass FixedUnigramSampler : public RangeSampler {
212f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public:
213f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // The vocab_file is assumed to be a CSV, with the last entry of each row a
214f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // value representing the counts or probabilities for the corresponding ID.
215f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  FixedUnigramSampler(Env* env, int64 range, const string& vocab_file,
216f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                      float distortion, int32 num_reserved_ids,
217f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                      int32 num_shards, int32 shard);
218f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
219f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  FixedUnigramSampler(int64 range, const std::vector<float>& unigrams,
220f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                      float distortion, int32 num_reserved_ids,
221f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur                      int32 num_shards, int32 shard);
222f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
223f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  float Probability(int64 value) const override;
224f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
225f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int64 Sample(random::SimplePhilox* rnd) const override;
226f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
227f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private:
228f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Underlying distribution sampler.
229f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  std::unique_ptr<random::DistributionSampler> dist_sampler_;
230f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Weights for individual samples. The probability of a sample i is defined
231f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // as weights_.at(i) / total_weight_.
232f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  std::vector<float> weights_;
233f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // The total weights of all samples.
234f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  float total_weight_;
235f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Sharding information of the sampler. The whole vocabulary is sharded
236f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // into num_shards_ smaller ranges and each sampler is responsible for one
237f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // such smaller range, identified by the shard number.
238f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int32 num_shards_;
239f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  int32 shard_;
240f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
241f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Fill the sampler with the appropriate number of reserved IDs.
242f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void FillReservedIds(int32 num_reserved_ids);
243f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Load IDs to sample from a CSV file. It is assumed that the last item of
244f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // each row contains a count or probability for the corresponding ID.
245f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  Status LoadFromFile(Env* env, const string& vocab_file, float distortion);
246f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  // Load from an in-memory array.
247f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur  void LoadFromUnigrams(const std::vector<float>& unigrams, float distortion);
248f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur};
249f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
250f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}  // namespace tensorflow
251f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur
252f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#endif  // TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
253