1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for CandidateSamplerOp."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import candidate_sampling_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.platform import test
29
30
class RangeSamplerOpsTest(test.TestCase):
  """Tests for the candidate sampling ops on a small, fixed batch of labels.

  The fixture is a batch of `BATCH_SIZE` examples with `NUM_TRUE` true labels
  each, drawn from `RANGE` classes. `NUM_SAMPLED` equals `RANGE`, so the
  "all candidates" sampler is asked for exactly as many candidates as there
  are classes.
  """

  BATCH_SIZE = 3
  NUM_TRUE = 2
  RANGE = 5
  NUM_SAMPLED = RANGE

  # True labels, one row per example; shape is [BATCH_SIZE, NUM_TRUE].
  TRUE_LABELS = [[1, 2], [0, 4], [3, 3]]

  def _true_classes(self):
    """Returns `TRUE_LABELS` as an int64 constant tensor.

    The tests call this inside their `test_session()` context so the constant
    is created next to the ops under test.
    """
    return constant_op.constant(self.TRUE_LABELS, dtype=dtypes.int64)

  def _sample_all(self, true_classes):
    """Runs `all_candidate_sampler` with the fixture's sizes.

    Args:
      true_classes: int64 tensor of true labels, as built by `_true_classes`.

    Returns:
      The sampler's three outputs: sampled candidates, true expected count
      and sampled expected count.
    """
    return candidate_sampling_ops.all_candidate_sampler(
        true_classes, self.NUM_TRUE, self.NUM_SAMPLED, unique=True)

  def testTrueCandidates(self):
    """Reshaping the flat label vector yields the `TRUE_LABELS` matrix."""
    with self.test_session() as sess:
      indices = constant_op.constant([0, 0, 1, 1, 2, 2])
      true_candidates_vec = constant_op.constant([1, 2, 0, 4, 3, 3])
      true_candidates_matrix = array_ops.reshape(
          true_candidates_vec, [self.BATCH_SIZE, self.NUM_TRUE])
      indices_val, true_candidates_val = sess.run(
          [indices, true_candidates_matrix])

    self.assertAllEqual(indices_val, [0, 0, 1, 1, 2, 2])
    self.assertAllEqual(true_candidates_val, self.TRUE_LABELS)

  def testSampledCandidates(self):
    """The all-candidate sampler returns every class id, in order."""
    with self.test_session():
      sampled_candidates, _, _ = self._sample_all(self._true_classes())
      result = sampled_candidates.eval()

    expected_ids = [0, 1, 2, 3, 4]
    self.assertAllEqual(result, expected_ids)
    self.assertEqual(sampled_candidates.get_shape(), [self.NUM_SAMPLED])

  def testTrueLogExpectedCount(self):
    """Log of the true expected count is 0 for every true label."""
    with self.test_session():
      _, true_expected_count, _ = self._sample_all(self._true_classes())
      true_log_expected_count = math_ops.log(true_expected_count)
      result = true_log_expected_count.eval()

    expected_shape = [self.BATCH_SIZE, self.NUM_TRUE]
    self.assertAllEqual(result, [[0.0] * self.NUM_TRUE] * self.BATCH_SIZE)
    self.assertEqual(true_expected_count.get_shape(), expected_shape)
    self.assertEqual(true_log_expected_count.get_shape(), expected_shape)

  def testSampledLogExpectedCount(self):
    """Log of the sampled expected count is 0 for every sampled candidate."""
    with self.test_session():
      _, _, sampled_expected_count = self._sample_all(self._true_classes())
      sampled_log_expected_count = math_ops.log(sampled_expected_count)
      result = sampled_log_expected_count.eval()

    expected_shape = [self.NUM_SAMPLED]
    self.assertAllEqual(result, [0.0] * self.NUM_SAMPLED)
    self.assertEqual(sampled_expected_count.get_shape(), expected_shape)
    self.assertEqual(sampled_log_expected_count.get_shape(), expected_shape)

  def testAccidentalHits(self):
    """Accidental hits point at true labels and carry huge negative weights."""
    with self.test_session() as sess:
      true_classes = self._true_classes()
      sampled_candidates, _, _ = self._sample_all(true_classes)
      accidental_hits = candidate_sampling_ops.compute_accidental_hits(
          true_classes, sampled_candidates, self.NUM_TRUE)
      indices, ids, weights = sess.run(accidental_hits)

    # All three outputs (indices, ids, weights) must be rank-1 tensors.
    for output in accidental_hits:
      self.assertEqual(1, output.get_shape().ndims)
    for index, id_, weight in zip(indices, ids, weights):
      # Each hit names a row of the batch and an id that is one of that row's
      # true labels.
      self.assertIn(id_, self.TRUE_LABELS[index])
      # The weight must be a very large negative number.
      self.assertLess(weight, -1.0e37)

  def testSeed(self):
    """An explicit seed makes draws repeatable; without one they vary."""

    def draw(seed):
      """Evaluates one log-uniform sample of candidates for `seed`."""
      with self.test_session():
        sampled, _, _ = candidate_sampling_ops.log_uniform_candidate_sampler(
            self._true_classes(), self.NUM_TRUE, self.NUM_SAMPLED,
            unique=True, range_max=self.RANGE, seed=seed)
        return sampled.eval()

    # Non-zero seed. Repeatable.
    for seed in [1, 12, 123, 1234]:
      self.assertAllEqual(draw(seed), draw(seed))
    # seed=None supplies no explicit seed, so two draws should normally differ.
    num_same = 0
    for _ in range(10):
      if np.allclose(draw(None), draw(None)):
        num_same += 1
    # Accounts for the fact that the same random seed may be picked
    # twice very rarely.
    self.assertLessEqual(num_same, 2)
130
131
if __name__ == "__main__":
  # Executed as a script: delegate to the `main()` entry point of
  # tensorflow.python.platform.test.
  test.main()
134