1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for ops.gmm.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.contrib.factorization.python.ops import gmm as gmm_lib 24from tensorflow.contrib.learn.python.learn.estimators import kmeans 25from tensorflow.contrib.learn.python.learn.estimators import run_config 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import random_seed as random_seed_lib 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import data_flow_ops 31from tensorflow.python.ops import random_ops 32from tensorflow.python.platform import test 33from tensorflow.python.training import queue_runner 34 35 36class GMMTest(test.TestCase): 37 38 def input_fn(self, batch_size=None, points=None): 39 batch_size = batch_size or self.batch_size 40 points = points if points is not None else self.points 41 num_points = points.shape[0] 42 43 def _fn(): 44 x = constant_op.constant(points) 45 if batch_size == num_points: 46 return x, None 47 indices = random_ops.random_uniform(constant_op.constant([batch_size]), 48 minval=0, maxval=num_points-1, 49 dtype=dtypes.int32, 50 seed=10) 51 return array_ops.gather(x, indices), None 52 return _fn 53 54 def setUp(self): 55 np.random.seed(3) 56 random_seed_lib.set_random_seed(2) 57 self.num_centers = 2 58 self.num_dims = 2 59 self.num_points = 4000 60 self.batch_size = self.num_points 61 self.true_centers = self.make_random_centers(self.num_centers, 62 self.num_dims) 63 self.points, self.assignments = self.make_random_points( 64 self.true_centers, self.num_points) 65 66 # Use initial means from kmeans (just like scikit-learn does). 67 clusterer = kmeans.KMeansClustering(num_clusters=self.num_centers) 68 clusterer.fit(input_fn=lambda: (constant_op.constant(self.points), None), 69 steps=30) 70 self.initial_means = clusterer.clusters() 71 72 @staticmethod 73 def make_random_centers(num_centers, num_dims): 74 return np.round( 75 np.random.rand(num_centers, num_dims).astype(np.float32) * 500) 76 77 @staticmethod 78 def make_random_points(centers, num_points): 79 num_centers, num_dims = centers.shape 80 assignments = np.random.choice(num_centers, num_points) 81 offsets = np.round( 82 np.random.randn(num_points, num_dims).astype(np.float32) * 20) 83 points = centers[assignments] + offsets 84 return (points, assignments) 85 86 def test_weights(self): 87 """Tests the shape of the weights.""" 88 gmm = gmm_lib.GMM(self.num_centers, 89 initial_clusters=self.initial_means, 90 random_seed=4, 91 config=run_config.RunConfig(tf_random_seed=2)) 92 gmm.fit(input_fn=self.input_fn(), steps=0) 93 weights = gmm.weights() 94 self.assertAllEqual(list(weights.shape), [self.num_centers]) 95 96 def test_clusters(self): 97 """Tests the shape of the clusters.""" 98 gmm = gmm_lib.GMM(self.num_centers, 99 initial_clusters=self.initial_means, 100 random_seed=4, 101 config=run_config.RunConfig(tf_random_seed=2)) 102 gmm.fit(input_fn=self.input_fn(), steps=0) 103 clusters = gmm.clusters() 104 self.assertAllEqual(list(clusters.shape), [self.num_centers, self.num_dims]) 105 106 def test_fit(self): 107 gmm = gmm_lib.GMM(self.num_centers, 108 initial_clusters='random', 109 random_seed=4, 110 config=run_config.RunConfig(tf_random_seed=2)) 111 gmm.fit(input_fn=self.input_fn(), steps=1) 112 score1 = gmm.score(input_fn=self.input_fn(batch_size=self.num_points), 113 steps=1) 114 gmm.fit(input_fn=self.input_fn(), steps=10) 115 score2 = gmm.score(input_fn=self.input_fn(batch_size=self.num_points), 116 steps=1) 117 self.assertLess(score1, score2) 118 119 def test_infer(self): 120 gmm = gmm_lib.GMM(self.num_centers, 121 initial_clusters=self.initial_means, 122 random_seed=4, 123 config=run_config.RunConfig(tf_random_seed=2)) 124 gmm.fit(input_fn=self.input_fn(), steps=60) 125 clusters = gmm.clusters() 126 127 # Make a small test set 128 num_points = 40 129 points, true_assignments = self.make_random_points(clusters, num_points) 130 131 assignments = [] 132 for item in gmm.predict_assignments( 133 input_fn=self.input_fn(points=points, batch_size=num_points)): 134 assignments.append(item) 135 assignments = np.ravel(assignments) 136 self.assertAllEqual(true_assignments, assignments) 137 138 def _compare_with_sklearn(self, cov_type): 139 # sklearn version. 140 iterations = 40 141 np.random.seed(5) 142 sklearn_assignments = np.asarray([0, 0, 1, 0, 0, 0, 1, 0, 0, 1]) 143 sklearn_means = np.asarray([[144.83417719, 254.20130341], 144 [274.38754816, 353.16074346]]) 145 sklearn_covs = np.asarray([[[395.0081194, -4.50389512], 146 [-4.50389512, 408.27543989]], 147 [[385.17484203, -31.27834935], 148 [-31.27834935, 391.74249925]]]) 149 150 # skflow version. 151 gmm = gmm_lib.GMM(self.num_centers, 152 initial_clusters=self.initial_means, 153 covariance_type=cov_type, 154 config=run_config.RunConfig(tf_random_seed=2)) 155 gmm.fit(input_fn=self.input_fn(), steps=iterations) 156 points = self.points[:10, :] 157 skflow_assignments = [] 158 for item in gmm.predict_assignments( 159 input_fn=self.input_fn(points=points, batch_size=10)): 160 skflow_assignments.append(item) 161 self.assertAllClose(sklearn_assignments, 162 np.ravel(skflow_assignments).astype(int)) 163 self.assertAllClose(sklearn_means, gmm.clusters()) 164 if cov_type == 'full': 165 self.assertAllClose(sklearn_covs, gmm.covariances(), rtol=0.01) 166 else: 167 for d in [0, 1]: 168 self.assertAllClose( 169 np.diag(sklearn_covs[d]), gmm.covariances()[d, :], rtol=0.01) 170 171 def test_compare_full(self): 172 self._compare_with_sklearn('full') 173 174 def test_compare_diag(self): 175 self._compare_with_sklearn('diag') 176 177 def test_random_input_large(self): 178 # sklearn version. 179 iterations = 5 # that should be enough to know whether this diverges 180 np.random.seed(5) 181 num_classes = 20 182 x = np.array([[np.random.random() for _ in range(100)] 183 for _ in range(num_classes)], dtype=np.float32) 184 185 # skflow version. 186 gmm = gmm_lib.GMM(num_classes, 187 covariance_type='full', 188 config=run_config.RunConfig(tf_random_seed=2)) 189 190 def get_input_fn(x): 191 def input_fn(): 192 return constant_op.constant(x.astype(np.float32)), None 193 return input_fn 194 195 gmm.fit(input_fn=get_input_fn(x), steps=iterations) 196 self.assertFalse(np.isnan(gmm.clusters()).any()) 197 198 199class GMMTestQueues(test.TestCase): 200 201 def input_fn(self): 202 def _fn(): 203 queue = data_flow_ops.FIFOQueue(capacity=10, 204 dtypes=dtypes.float32, 205 shapes=[10, 3]) 206 enqueue_op = queue.enqueue(array_ops.zeros([10, 3], dtype=dtypes.float32)) 207 queue_runner.add_queue_runner(queue_runner.QueueRunner(queue, 208 [enqueue_op])) 209 return queue.dequeue(), None 210 return _fn 211 212 # This test makes sure that there are no deadlocks when using a QueueRunner. 213 # Note that since cluster initialization is dependendent on inputs, if input 214 # is generated using a QueueRunner, one has to make sure that these runners 215 # are started before the initialization. 216 def test_queues(self): 217 gmm = gmm_lib.GMM(2, covariance_type='diag') 218 gmm.fit(input_fn=self.input_fn(), steps=1) 219 220 221if __name__ == '__main__': 222 test.main() 223