# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for ops.gmm."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.contrib.factorization.python.ops import gmm as gmm_lib
from tensorflow.contrib.learn.python.learn.estimators import kmeans
from tensorflow.contrib.learn.python.learn.estimators import run_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed as random_seed_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
from tensorflow.python.training import queue_runner


class GMMTest(test.TestCase):

  def input_fn(self, batch_size=None, points=None):
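    """Returns an input_fn yielding all points or a random batch of them."""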
    batch_size = batch_size or self.batch_size
    points = points if points is not None else self.points
    num_points = points.shape[0]

    def _fn():
      x = constant_op.constant(points)
      if batch_size == num_points:
        return x, None
      indices = random_ops.random_uniform(constant_op.constant([batch_size]),
                                          minval=0, maxval=num_points-1,
                                          dtype=dtypes.int32,
                                          seed=10)
      return array_ops.gather(x, indices), None
    return _fn

  def setUp(self):
    np.random.seed(3)
    random_seed_lib.set_random_seed(2)
    self.num_centers = 2
    self.num_dims = 2
    self.num_points = 4000
    self.batch_size = self.num_points
    self.true_centers = self.make_random_centers(self.num_centers,
                                                 self.num_dims)
    self.points, self.assignments = self.make_random_points(
        self.true_centers, self.num_points)

    # Use initial means from kmeans (just like scikit-learn does).
    clusterer = kmeans.KMeansClustering(num_clusters=self.num_centers)
    clusterer.fit(input_fn=lambda: (constant_op.constant(self.points), None),
                  steps=30)
    self.initial_means = clusterer.clusters()

  @staticmethod
  def make_random_centers(num_centers, num_dims):
    return np.round(
        np.random.rand(num_centers, num_dims).astype(np.float32) * 500)

  @staticmethod
  def make_random_points(centers, num_points):
    num_centers, num_dims = centers.shape
    assignments = np.random.choice(num_centers, num_points)
    offsets = np.round(
        np.random.randn(num_points, num_dims).astype(np.float32) * 20)
    points = centers[assignments] + offsets
    return (points, assignments)

  def test_weights(self):
    """Tests the shape of the weights."""
    gmm = gmm_lib.GMM(self.num_centers,
                      initial_clusters=self.initial_means,
                      random_seed=4,
                      config=run_config.RunConfig(tf_random_seed=2))
    gmm.fit(input_fn=self.input_fn(), steps=0)
    weights = gmm.weights()
    self.assertAllEqual(list(weights.shape), [self.num_centers])

  def test_clusters(self):
    """Tests the shape of the clusters."""
    gmm = gmm_lib.GMM(self.num_centers,
                      initial_clusters=self.initial_means,
                      random_seed=4,
                      config=run_config.RunConfig(tf_random_seed=2))
    gmm.fit(input_fn=self.input_fn(), steps=0)
    clusters = gmm.clusters()
    self.assertAllEqual(list(clusters.shape), [self.num_centers, self.num_dims])

  def test_fit(self):
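    """Tests that the score improves with additional training steps."""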
    gmm = gmm_lib.GMM(self.num_centers,
                      initial_clusters='random',
                      random_seed=4,
                      config=run_config.RunConfig(tf_random_seed=2))
    gmm.fit(input_fn=self.input_fn(), steps=1)
    score1 = gmm.score(input_fn=self.input_fn(batch_size=self.num_points),
                       steps=1)
    gmm.fit(input_fn=self.input_fn(), steps=10)
    score2 = gmm.score(input_fn=self.input_fn(batch_size=self.num_points),
                       steps=1)
    self.assertLess(score1, score2)

  def test_infer(self):
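    """Tests that inferred assignments match the generating clusters."""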
    gmm = gmm_lib.GMM(self.num_centers,
                      initial_clusters=self.initial_means,
                      random_seed=4,
                      config=run_config.RunConfig(tf_random_seed=2))
    gmm.fit(input_fn=self.input_fn(), steps=60)
    clusters = gmm.clusters()

    # Make a small test set
    num_points = 40
    points, true_assignments = self.make_random_points(clusters, num_points)

    assignments = []
    for item in gmm.predict_assignments(
        input_fn=self.input_fn(points=points, batch_size=num_points)):
      assignments.append(item)
    assignments = np.ravel(assignments)
    self.assertAllEqual(true_assignments, assignments)

  def _compare_with_sklearn(self, cov_type):
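    """Compares assignments, means, and covariances with sklearn results."""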
    # sklearn version.
    iterations = 40
    np.random.seed(5)
    sklearn_assignments = np.asarray([0, 0, 1, 0, 0, 0, 1, 0, 0, 1])
    sklearn_means = np.asarray([[144.83417719, 254.20130341],
                                [274.38754816, 353.16074346]])
    sklearn_covs = np.asarray([[[395.0081194, -4.50389512],
                                [-4.50389512, 408.27543989]],
                               [[385.17484203, -31.27834935],
                                [-31.27834935, 391.74249925]]])

    # skflow version.
    gmm = gmm_lib.GMM(self.num_centers,
                      initial_clusters=self.initial_means,
                      covariance_type=cov_type,
                      config=run_config.RunConfig(tf_random_seed=2))
    gmm.fit(input_fn=self.input_fn(), steps=iterations)
    points = self.points[:10, :]
    skflow_assignments = []
    for item in gmm.predict_assignments(
        input_fn=self.input_fn(points=points, batch_size=10)):
      skflow_assignments.append(item)
    self.assertAllClose(sklearn_assignments,
                        np.ravel(skflow_assignments).astype(int))
    self.assertAllClose(sklearn_means, gmm.clusters())
    if cov_type == 'full':
      self.assertAllClose(sklearn_covs, gmm.covariances(), rtol=0.01)
    else:
      for d in [0, 1]:
        self.assertAllClose(
            np.diag(sklearn_covs[d]), gmm.covariances()[d, :], rtol=0.01)

  def test_compare_full(self):
    self._compare_with_sklearn('full')

  def test_compare_diag(self):
    self._compare_with_sklearn('diag')

  def test_random_input_large(self):
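    """Tests that fitting on random input does not produce NaN clusters."""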
    # sklearn version.
    iterations = 5  # that should be enough to know whether this diverges
    np.random.seed(5)
    num_classes = 20
    x = np.array([[np.random.random() for _ in range(100)]
                  for _ in range(num_classes)], dtype=np.float32)

    # skflow version.
    gmm = gmm_lib.GMM(num_classes,
                      covariance_type='full',
                      config=run_config.RunConfig(tf_random_seed=2))

    def get_input_fn(x):
      def input_fn():
        return constant_op.constant(x.astype(np.float32)), None
      return input_fn

    gmm.fit(input_fn=get_input_fn(x), steps=iterations)
    self.assertFalse(np.isnan(gmm.clusters()).any())


class GMMTestQueues(test.TestCase):

  def input_fn(self):
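    """Returns an input_fn that dequeues batches fed by a QueueRunner."""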
    def _fn():
      queue = data_flow_ops.FIFOQueue(capacity=10,
                                      dtypes=dtypes.float32,
                                      shapes=[10, 3])
      enqueue_op = queue.enqueue(array_ops.zeros([10, 3], dtype=dtypes.float32))
      queue_runner.add_queue_runner(queue_runner.QueueRunner(queue,
                                                             [enqueue_op]))
      return queue.dequeue(), None
    return _fn

  # This test makes sure that there are no deadlocks when using a QueueRunner.
  # Note that since cluster initialization is dependent on inputs, if input
  # is generated using a QueueRunner, one has to make sure that these runners
  # are started before the initialization.
  def test_queues(self):
    gmm = gmm_lib.GMM(2, covariance_type='diag')
    gmm.fit(input_fn=self.input_fn(), steps=1)


if __name__ == '__main__':
  test.main()