1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <memory>
#include <set>
#include <vector>

#include "tensorflow/core/kernels/range_sampler.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
25
26namespace tensorflow {
27namespace {
28
29using gtl::ArraySlice;
30using gtl::MutableArraySlice;
31
32class RangeSamplerTest : public ::testing::Test {
33 protected:
34  void CheckProbabilitiesSumToOne() {
35    double sum = 0;
36    for (int i = 0; i < sampler_->range(); i++) {
37      sum += sampler_->Probability(i);
38    }
39    EXPECT_NEAR(sum, 1.0, 1e-4);
40  }
41  void CheckHistogram(int num_samples, float tolerance) {
42    const int range = sampler_->range();
43    std::vector<int> h(range);
44    std::vector<int64> a(num_samples);
45    // Using a fixed random seed to make the test deterministic.
46    random::PhiloxRandom philox(123, 17);
47    random::SimplePhilox rnd(&philox);
48    sampler_->SampleBatch(&rnd, false, &a);
49    for (int i = 0; i < num_samples; i++) {
50      int64 val = a[i];
51      ASSERT_GE(val, 0);
52      ASSERT_LT(val, range);
53      h[val]++;
54    }
55    for (int val = 0; val < range; val++) {
56      EXPECT_NEAR((h[val] + 0.0) / num_samples, sampler_->Probability(val),
57                  tolerance);
58    }
59  }
60  void Update1() {
61    // Add the value 3 ten times.
62    std::vector<int64> a(10);
63    for (int i = 0; i < 10; i++) {
64      a[i] = 3;
65    }
66    sampler_->Update(a);
67  }
68  void Update2() {
69    // Add the value n times.
70    int64 a[10];
71    for (int i = 0; i < 10; i++) {
72      a[i] = i;
73    }
74    for (int64 i = 1; i < 10; i++) {
75      sampler_->Update(ArraySlice<int64>(a + i, 10 - i));
76    }
77  }
78  std::unique_ptr<RangeSampler> sampler_;
79};
80
81TEST_F(RangeSamplerTest, UniformProbabilities) {
82  sampler_.reset(new UniformSampler(10));
83  for (int i = 0; i < 10; i++) {
84    CHECK_EQ(sampler_->Probability(i), sampler_->Probability(0));
85  }
86}
87
88TEST_F(RangeSamplerTest, UniformChecksum) {
89  sampler_.reset(new UniformSampler(10));
90  CheckProbabilitiesSumToOne();
91}
92
93TEST_F(RangeSamplerTest, UniformHistogram) {
94  sampler_.reset(new UniformSampler(10));
95  CheckHistogram(1000, 0.05);
96}
97
98TEST_F(RangeSamplerTest, LogUniformProbabilities) {
99  int range = 1000000;
100  sampler_.reset(new LogUniformSampler(range));
101  for (int i = 100; i < range; i *= 2) {
102    float ratio = sampler_->Probability(i) / sampler_->Probability(i / 2);
103    EXPECT_NEAR(ratio, 0.5, 0.1);
104  }
105}
106
107TEST_F(RangeSamplerTest, LogUniformChecksum) {
108  sampler_.reset(new LogUniformSampler(10));
109  CheckProbabilitiesSumToOne();
110}
111
112TEST_F(RangeSamplerTest, LogUniformHistogram) {
113  sampler_.reset(new LogUniformSampler(10));
114  CheckHistogram(1000, 0.05);
115}
116
117TEST_F(RangeSamplerTest, UnigramProbabilities1) {
118  sampler_.reset(new UnigramSampler(10));
119  Update1();
120  EXPECT_NEAR(sampler_->Probability(3), 0.55, 1e-4);
121  for (int i = 0; i < 10; i++) {
122    if (i != 3) {
123      ASSERT_NEAR(sampler_->Probability(i), 0.05, 1e-4);
124    }
125  }
126}
127TEST_F(RangeSamplerTest, UnigramProbabilities2) {
128  sampler_.reset(new UnigramSampler(10));
129  Update2();
130  for (int i = 0; i < 10; i++) {
131    ASSERT_NEAR(sampler_->Probability(i), (i + 1) / 55.0, 1e-4);
132  }
133}
134TEST_F(RangeSamplerTest, UnigramChecksum) {
135  sampler_.reset(new UnigramSampler(10));
136  Update1();
137  CheckProbabilitiesSumToOne();
138}
139
140TEST_F(RangeSamplerTest, UnigramHistogram) {
141  sampler_.reset(new UnigramSampler(10));
142  Update1();
143  CheckHistogram(1000, 0.05);
144}
145
// Vocabulary file contents: nine "word,count" lines w1..w9 whose counts are
// the powers of two 1, 2, 4, ..., 256. The final line has no trailing newline.
static constexpr char kVocabContent[] =
    "w1,1\nw2,2\nw3,4\nw4,8\nw5,16\n"
    "w6,32\nw7,64\nw8,128\nw9,256";
156TEST_F(RangeSamplerTest, FixedUnigramProbabilities) {
157  Env* env = Env::Default();
158  string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
159  TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
160  sampler_.reset(new FixedUnigramSampler(env, 9, fname, 0.8, 0, 1, 0));
161  // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
162  for (int i = 0; i < 9; i++) {
163    ASSERT_NEAR(sampler_->Probability(i), pow(2, i * 0.8) / 197.05, 1e-4);
164  }
165}
166TEST_F(RangeSamplerTest, FixedUnigramChecksum) {
167  Env* env = Env::Default();
168  string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
169  TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
170  sampler_.reset(new FixedUnigramSampler(env, 9, fname, 0.8, 0, 1, 0));
171  CheckProbabilitiesSumToOne();
172}
173
174TEST_F(RangeSamplerTest, FixedUnigramHistogram) {
175  Env* env = Env::Default();
176  string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
177  TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
178  sampler_.reset(new FixedUnigramSampler(env, 9, fname, 0.8, 0, 1, 0));
179  CheckHistogram(1000, 0.05);
180}
181TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve1) {
182  Env* env = Env::Default();
183  string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
184  TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
185  sampler_.reset(new FixedUnigramSampler(env, 10, fname, 0.8, 1, 1, 0));
186  ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
187  // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
188  for (int i = 1; i < 10; i++) {
189    ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 1) * 0.8) / 197.05, 1e-4);
190  }
191}
192TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve2) {
193  Env* env = Env::Default();
194  string fname = io::JoinPath(testing::TmpDir(), "vocab_file");
195  TF_CHECK_OK(WriteStringToFile(env, fname, kVocabContent));
196  sampler_.reset(new FixedUnigramSampler(env, 11, fname, 0.8, 2, 1, 0));
197  ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
198  ASSERT_NEAR(sampler_->Probability(1), 0, 1e-4);
199  // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
200  for (int i = 2; i < 11; i++) {
201    ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 2) * 0.8) / 197.05, 1e-4);
202  }
203}
204TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesFromVector) {
205  std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
206  sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0));
207  // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
208  for (int i = 0; i < 9; i++) {
209    ASSERT_NEAR(sampler_->Probability(i), pow(2, i * 0.8) / 197.05, 1e-4);
210  }
211}
212TEST_F(RangeSamplerTest, FixedUnigramChecksumFromVector) {
213  std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
214  sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0));
215  CheckProbabilitiesSumToOne();
216}
217TEST_F(RangeSamplerTest, FixedUnigramHistogramFromVector) {
218  std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
219  sampler_.reset(new FixedUnigramSampler(9, weights, 0.8, 0, 1, 0));
220  CheckHistogram(1000, 0.05);
221}
222TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve1FromVector) {
223  std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
224  sampler_.reset(new FixedUnigramSampler(10, weights, 0.8, 1, 1, 0));
225  ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
226  // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
227  for (int i = 1; i < 10; i++) {
228    ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 1) * 0.8) / 197.05, 1e-4);
229  }
230}
231TEST_F(RangeSamplerTest, FixedUnigramProbabilitiesReserve2FromVector) {
232  std::vector<float> weights = {1, 2, 4, 8, 16, 32, 64, 128, 256};
233  sampler_.reset(new FixedUnigramSampler(11, weights, 0.8, 2, 1, 0));
234  ASSERT_NEAR(sampler_->Probability(0), 0, 1e-4);
235  ASSERT_NEAR(sampler_->Probability(1), 0, 1e-4);
236  // 1^0.8+2^0.8+4^0.8+...+256^0.8=197.05
237  for (int i = 2; i < 11; i++) {
238    ASSERT_NEAR(sampler_->Probability(i), pow(2, (i - 2) * 0.8) / 197.05, 1e-4);
239  }
240}
241
242// AllSampler cannot call Sample or Probability directly.
243// We will test SampleBatchGetExpectedCount instead.
244TEST_F(RangeSamplerTest, All) {
245  int batch_size = 10;
246  sampler_.reset(new AllSampler(10));
247  std::vector<int64> batch(batch_size);
248  std::vector<float> batch_expected(batch_size);
249  std::vector<int64> extras(2);
250  std::vector<float> extras_expected(2);
251  extras[0] = 0;
252  extras[1] = batch_size - 1;
253  sampler_->SampleBatchGetExpectedCount(nullptr,  // no random numbers needed
254                                        false, &batch, &batch_expected, extras,
255                                        &extras_expected);
256  for (int i = 0; i < batch_size; i++) {
257    EXPECT_EQ(i, batch[i]);
258    EXPECT_EQ(1, batch_expected[i]);
259  }
260  EXPECT_EQ(1, extras_expected[0]);
261  EXPECT_EQ(1, extras_expected[1]);
262}
263
264TEST_F(RangeSamplerTest, Unique) {
265  // We sample num_batches batches, each without replacement.
266  //
267  // We check that the returned expected counts roughly agree with each other
268  // and with the average observed frequencies over the set of batches.
269  random::PhiloxRandom philox(123, 17);
270  random::SimplePhilox rnd(&philox);
271  const int range = 100;
272  const int batch_size = 50;
273  const int num_batches = 100;
274  sampler_.reset(new LogUniformSampler(range));
275  std::vector<int> histogram(range);
276  std::vector<int64> batch(batch_size);
277  std::vector<int64> all_values(range);
278  for (int i = 0; i < range; i++) {
279    all_values[i] = i;
280  }
281  std::vector<float> expected(range);
282
283  // Sample one batch and get the expected counts of all values
284  sampler_->SampleBatchGetExpectedCount(
285      &rnd, true, &batch, MutableArraySlice<float>(), all_values, &expected);
286  // Check that all elements are unique
287  std::set<int64> s(batch.begin(), batch.end());
288  CHECK_EQ(batch_size, s.size());
289
290  for (int trial = 0; trial < num_batches; trial++) {
291    std::vector<float> trial_expected(range);
292    sampler_->SampleBatchGetExpectedCount(&rnd, true, &batch,
293                                          MutableArraySlice<float>(),
294                                          all_values, &trial_expected);
295    for (int i = 0; i < range; i++) {
296      EXPECT_NEAR(expected[i], trial_expected[i], expected[i] * 0.5);
297    }
298    for (int i = 0; i < batch_size; i++) {
299      histogram[batch[i]]++;
300    }
301  }
302  for (int i = 0; i < range; i++) {
303    // Check that the computed expected count agrees with the average observed
304    // count.
305    const float average_count = static_cast<float>(histogram[i]) / num_batches;
306    EXPECT_NEAR(expected[i], average_count, 0.2);
307  }
308}
309
310TEST_F(RangeSamplerTest, Avoid) {
311  random::PhiloxRandom philox(123, 17);
312  random::SimplePhilox rnd(&philox);
313  sampler_.reset(new LogUniformSampler(100));
314  std::vector<int64> avoided(2);
315  avoided[0] = 17;
316  avoided[1] = 23;
317  std::vector<int64> batch(98);
318
319  // We expect to pick all elements of [0, 100) except the avoided two.
320  sampler_->SampleBatchGetExpectedCountAvoid(
321      &rnd, true, &batch, MutableArraySlice<float>(), ArraySlice<int64>(),
322      MutableArraySlice<float>(), avoided);
323
324  int sum = 0;
325  for (auto val : batch) {
326    sum += val;
327  }
328  const int expected_sum = 100 * 99 / 2 - avoided[0] - avoided[1];
329  EXPECT_EQ(expected_sum, sum);
330}
331
332}  // namespace
333
334}  // namespace tensorflow
335