1c8b59c046895fa5b6d79f73e0b5817330fcfbfc1A. Unique TensorFlower/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 29c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 39c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurLicensed under the Apache License, Version 2.0 (the "License"); 49c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudluryou may not use this file except in compliance with the License. 59c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurYou may obtain a copy of the License at 69c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 79c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur http://www.apache.org/licenses/LICENSE-2.0 89c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 99c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurUnless required by applicable law or agreed to in writing, software 109c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurdistributed under the License is distributed on an "AS IS" BASIS, 119c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 129c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath KudlurSee the License for the specific language governing permissions and 139c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlurlimitations under the License. 149c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur==============================================================================*/ 159c3043ff3bf31a6a81810b4ce9e87ef936f1f529Manjunath Kudlur 16f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#ifndef TENSORFLOW_KERNELS_RANGE_SAMPLER_H_ 17f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#define TENSORFLOW_KERNELS_RANGE_SAMPLER_H_ 18f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 19f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include <vector> 20f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 213ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/lib/core/status.h" 22f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/gtl/array_slice.h" 23f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/random/distribution_sampler.h" 24f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/random/random_distributions.h" 25f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/lib/random/weighted_picker.h" 26f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/platform/logging.h" 27af9b56881008f8b804cf93c0d99cb96357b19748Josh Levenberg#include "tensorflow/core/platform/mutex.h" 28f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#include "tensorflow/core/platform/thread_annotations.h" 293ede5506acf6a026f09eda33277d46e34ac7ed10Josh Levenberg#include "tensorflow/core/platform/types.h" 30f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 31f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurnamespace tensorflow { 32f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 33f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass Env; 34f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 35f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Abstract subclass for sampling from the set of non-negative integers 36f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// [0, range) 37f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass RangeSampler { 38f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public: 3920a8d1ff2c320dfa46df2bab05334b080aff2444David G. Andersen explicit RangeSampler(int64 range) : range_(range) { CHECK_GT(range_, 0); } 40f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur virtual ~RangeSampler(); 41f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 42f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Sample a single value 43f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur virtual int64 Sample(random::SimplePhilox* rnd) const = 0; 44f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 45f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // The probability that a single call to Sample() returns the given value. 46f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Assumes that value is in [0, range). No range checking is done. 47f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur virtual float Probability(int64 value) const = 0; 48f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 49f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Fill "batch" with samples from the distribution. 50f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // If unique=true, then we re-pick each element until we get a 51f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // value distinct from all previously picked values in the batch. 52f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void SampleBatch(random::SimplePhilox* rnd, bool unique, 53f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::MutableArraySlice<int64> batch) const; 54f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 55f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Fill "batch" with samples from the distribution, and report 56f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // "expected counts". 57f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 58f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // The "expected count" of a value is an estimate of the expected 59f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // number of occurrences of the value in the batch returned by a 60f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // call to this function with the given parameters. If unique=true, 61f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // the expected count is an inclusion probability. For details on 62f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // this estimation, see the comment to "ExpectedCountHelper" in the 63f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // .cc file. 64f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 65f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Expected counts for the elements of the returned "batch" are reported 66f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // in the aligned array "batch_expected_count". 67f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 68fe056f0b5e52db86766761f5e6446a89c1aa3938Vijay Vasudevan // The user can optionally provide "extras", containing values in the range. 69f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // The expected counts for the extras are reported in the aligned array 70f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // "extras_expected_count". 71f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // 72f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // "batch_expected_count" must have size equal to 0 or to the size of "batch". 73f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // "extras" and "extras_expected_count" must have equal size. 74f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void SampleBatchGetExpectedCount( 75f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur random::SimplePhilox* rnd, bool unique, 76f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::MutableArraySlice<int64> batch, 77f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::MutableArraySlice<float> batch_expected_count, 78f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::ArraySlice<int64> extras, 79f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::MutableArraySlice<float> extras_expected_count) const; 80f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 81f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Same as SampleBatchGetExpectedCount (see above), but with avoided values. 82f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // We repick to avoid all of the values in "avoided_values". 83f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // "avoided_values" is only supported with unique=true. If 84f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // unique=false, then avoided_values must be empty. 85f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur virtual void SampleBatchGetExpectedCountAvoid( 86f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur random::SimplePhilox* rnd, bool unique, 87f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::MutableArraySlice<int64> batch, 88f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::MutableArraySlice<float> batch_expected_count, 89f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::ArraySlice<int64> extras, 90f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::MutableArraySlice<float> extras_expected_count, 91f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::ArraySlice<int64> avoided_values) const; 92f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 93f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Does this sampler need to be updated with values, e.g. UnigramSampler 94f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur virtual bool NeedsUpdates() const { return false; } 95f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 96f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Updates the underlying distribution 97f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur virtual void Update(gtl::ArraySlice<int64> values) { 98f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur LOG(FATAL) << "Update not supported for this sampler type."; 99f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 100f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 101f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int64 range() { return range_; } 102f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 103f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur protected: 104f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const int64 range_; 105f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 106f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 107f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// An AllSampler only samples batches of size equal to range. 108f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// It returns the entire range. 109f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// It cannot sample single values. 110f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass AllSampler : public RangeSampler { 111f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public: 112f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur explicit AllSampler(int64 range); 113f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 114f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur ~AllSampler() override {} 115f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 116f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int64 Sample(random::SimplePhilox* rnd) const override { 117f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur LOG(FATAL) << "Should not be called"; 1187e8050ce34fa639af93af8ec239f826e8dd68077guschmue return 0; 119f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 120f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 121f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur float Probability(int64 value) const override { 122f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur LOG(FATAL) << "Should not be called"; 1237e8050ce34fa639af93af8ec239f826e8dd68077guschmue return 0; 124f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur } 125f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 126f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void SampleBatchGetExpectedCountAvoid( 127f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur random::SimplePhilox* rnd, bool unique, 128f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::MutableArraySlice<int64> batch, 129f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::MutableArraySlice<float> batch_expected_count, 130f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::ArraySlice<int64> extras, 131f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::MutableArraySlice<float> extras_expected_count, 132f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::ArraySlice<int64> avoided_values) const override; 133f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 134f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 135f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass UniformSampler : public RangeSampler { 136f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public: 137f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur explicit UniformSampler(int64 range); 138f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 139f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur ~UniformSampler() override {} 140f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 141f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int64 Sample(random::SimplePhilox* rnd) const override; 142f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 143f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur float Probability(int64 value) const override; 144f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 145f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private: 146f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const float inv_range_; 147f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 148f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 149f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass LogUniformSampler : public RangeSampler { 150f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public: 151f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur explicit LogUniformSampler(int64 range); 152f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 153f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur ~LogUniformSampler() override {} 154f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 155f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int64 Sample(random::SimplePhilox* rnd) const override; 156f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 157f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur float Probability(int64 value) const override; 158f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 159f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private: 160f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur const double log_range_; 161f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 162f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 163f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Thread-unsafe unigram sampler 164f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass ThreadUnsafeUnigramSampler : public RangeSampler { 165f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public: 166f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur explicit ThreadUnsafeUnigramSampler(int64 range); 167f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur ~ThreadUnsafeUnigramSampler() override {} 168f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 169f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int64 Sample(random::SimplePhilox* rnd) const override; 170f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 171f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur float Probability(int64 value) const override; 172f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 173f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur bool NeedsUpdates() const override { return true; } 174f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void Update(gtl::ArraySlice<int64> values) override; 175f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 176f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private: 177f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur random::WeightedPicker picker_; 178f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 179f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 180f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// Thread-safe unigram sampler 181f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass UnigramSampler : public RangeSampler { 182f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public: 183f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur explicit UnigramSampler(int64 range); 184f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur ~UnigramSampler() override {} 185f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 186f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int64 Sample(random::SimplePhilox* rnd) const override; 187f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 188f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur float Probability(int64 value) const override; 189f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 19059f1eba5fb94506a205fa2e81145667754739da5Martin Wicke // Overriding at a high level results in far fewer lock acquisitions. 191f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void SampleBatchGetExpectedCountAvoid( 192f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur random::SimplePhilox* rnd, bool unique, 193f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::MutableArraySlice<int64> batch, 194f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::MutableArraySlice<float> batch_expected_count, 195f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::ArraySlice<int64> extras, 196f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::MutableArraySlice<float> extras_expected_count, 197f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur gtl::ArraySlice<int64> avoided_values) const override; 198f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 199f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur bool NeedsUpdates() const override { return true; } 200f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void Update(gtl::ArraySlice<int64> values) override; 201f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 202f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private: 203f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur ThreadUnsafeUnigramSampler unsafe_sampler_ GUARDED_BY(mu_); 204f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur mutable mutex mu_; 205f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 206f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 207f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// A unigram sampler that uses a fixed unigram distribution read from a 208f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// file or passed in as an in-memory array instead of building up the 209f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// distribution from data on the fly. There is also an option to skew the 210f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur// distribution by applying a distortion power to the weights. 211f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlurclass FixedUnigramSampler : public RangeSampler { 212f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur public: 213f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // The vocab_file is assumed to be a CSV, with the last entry of each row a 214f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // value representing the counts or probabilities for the corresponding ID. 215f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur FixedUnigramSampler(Env* env, int64 range, const string& vocab_file, 216f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur float distortion, int32 num_reserved_ids, 217f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int32 num_shards, int32 shard); 218f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 219f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur FixedUnigramSampler(int64 range, const std::vector<float>& unigrams, 220f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur float distortion, int32 num_reserved_ids, 221f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int32 num_shards, int32 shard); 222f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 223f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur float Probability(int64 value) const override; 224f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 225f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int64 Sample(random::SimplePhilox* rnd) const override; 226f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 227f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur private: 228f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Underlying distribution sampler. 229f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur std::unique_ptr<random::DistributionSampler> dist_sampler_; 230f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Weights for individual samples. The probability of a sample i is defined 231f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // as weights_.at(i) / total_weight_. 232f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur std::vector<float> weights_; 233f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // The total weights of all samples. 234f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur float total_weight_; 235f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Sharding information of the sampler. The whole vocabulary is sharded 236f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // into num_shards_ smaller ranges and each sampler is responsible for one 237f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // such smaller range, identified by the shard number. 238f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int32 num_shards_; 239f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur int32 shard_; 240f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 241f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Fill the sampler with the appropriate number of reserved IDs. 242f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void FillReservedIds(int32 num_reserved_ids); 243f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Load IDs to sample from a CSV file. It is assumed that the last item of 244f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // each row contains a count or probability for the corresponding ID. 245f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur Status LoadFromFile(Env* env, const string& vocab_file, float distortion); 246f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur // Load from an in-memory array. 247f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur void LoadFromUnigrams(const std::vector<float>& unigrams, float distortion); 248f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur}; 249f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 250f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur} // namespace tensorflow 251f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur 252f41959ccb2d9d4c722fe8fc3351401d53bcf490Manjunath Kudlur#endif // TENSORFLOW_KERNELS_RANGE_SAMPLER_H_ 253