1f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# 3f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# Licensed under the Apache License, Version 2.0 (the "License"); 4f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# you may not use this file except in compliance with the License. 5f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# You may obtain a copy of the License at 6f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# 7f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# http://www.apache.org/licenses/LICENSE-2.0 8f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# 9f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# Unless required by applicable law or agreed to in writing, software 10f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# distributed under the License is distributed on an "AS IS" BASIS, 11f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# See the License for the specific language governing permissions and 13f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# limitations under the License. 14f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# ============================================================================== 15f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower"""Tests for gmm_ops.""" 16f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 17f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlowerfrom __future__ import absolute_import 18f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlowerfrom __future__ import division 19f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlowerfrom __future__ import print_function 20f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 21f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlowerimport time 22f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 23f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlowerimport numpy as np 24f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlowerfrom six.moves import xrange # pylint: disable=redefined-builtin 25f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 26f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlowerfrom tensorflow.contrib.factorization.python.ops import gmm_ops 27e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.framework import constant_op 28e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.framework import dtypes 29e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.framework import ops 30e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.framework import random_seed as random_seed_lib 31e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.ops import variables 32e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.platform import test 33f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlowerfrom tensorflow.python.platform import tf_logging as logging 34f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 35f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 36e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyclass GmmOpsTest(test.TestCase): 37f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 38f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower def setUp(self): 39f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower self.num_examples = 1000 40f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower self.iterations = 40 41f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower self.seed = 4 42e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney random_seed_lib.set_random_seed(self.seed) 43f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower np.random.seed(self.seed * 2) 44f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower self.data, self.true_assignments = self.make_data(self.num_examples) 45f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower # Generate more complicated data. 46f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower self.centers = [[1, 1], [-1, 0.5], [2, 1]] 47f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower self.more_data, self.more_true_assignments = self.make_data_from_centers( 48f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower self.num_examples, self.centers) 49f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 50f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower @staticmethod 51f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower def make_data(num_vectors): 52f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower """Generates 2-dimensional data centered on (2,2), (-1,-1). 53f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 54f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower Args: 55f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower num_vectors: number of training examples. 56f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 57f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower Returns: 58f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower A tuple containing the data as a numpy array and the cluster ids. 59f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower """ 60f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower vectors = [] 61f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower classes = [] 62f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower for _ in xrange(num_vectors): 63f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower if np.random.random() > 0.5: 64e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney vectors.append([np.random.normal(2.0, 0.6), np.random.normal(2.0, 0.9)]) 65f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower classes.append(0) 66f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower else: 67e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney vectors.append( 68e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [np.random.normal(-1.0, 0.4), np.random.normal(-1.0, 0.5)]) 69f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower classes.append(1) 70f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower return np.asarray(vectors), classes 71f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 72f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower @staticmethod 73f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower def make_data_from_centers(num_vectors, centers): 74f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower """Generates 2-dimensional data with random centers. 75f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 76f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower Args: 77f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower num_vectors: number of training examples. 78f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower centers: a list of random 2-dimensional centers. 79f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 80f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower Returns: 81f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower A tuple containing the data as a numpy array and the cluster ids. 82f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower """ 83f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower vectors = [] 84f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower classes = [] 85f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower for _ in xrange(num_vectors): 86f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower current_class = np.random.random_integers(0, len(centers) - 1) 87e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney vectors.append([ 88e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney np.random.normal(centers[current_class][0], 89e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney np.random.random_sample()), 90e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney np.random.normal(centers[current_class][1], np.random.random_sample()) 91e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney ]) 92f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower classes.append(current_class) 93f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower return np.asarray(vectors), len(centers) 94f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 95f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower def test_covariance(self): 96f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower start_time = time.time() 97f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower data = self.data.T 98f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower np_cov = np.cov(data) 99f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower logging.info('Numpy took %f', time.time() - start_time) 100f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 101f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower start_time = time.time() 102f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower with self.test_session() as sess: 103f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower op = gmm_ops._covariance( 104e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney constant_op.constant( 105e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney data.T, dtype=dtypes.float32), False) 106f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower op_diag = gmm_ops._covariance( 107e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney constant_op.constant( 108e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney data.T, dtype=dtypes.float32), True) 109e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney variables.global_variables_initializer().run() 110f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower tf_cov = sess.run(op) 111f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower np.testing.assert_array_almost_equal(np_cov, tf_cov) 112f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower logging.info('Tensorflow took %f', time.time() - start_time) 113f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower tf_cov = sess.run(op_diag) 114f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower np.testing.assert_array_almost_equal( 115f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower np.diag(np_cov), np.ravel(tf_cov), decimal=5) 116f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 117f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower def test_simple_cluster(self): 118f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower """Tests that the clusters are correct.""" 119f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower num_classes = 2 120e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney graph = ops.Graph() 121f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower with graph.as_default() as g: 122f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower g.seed = 5 123f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower with self.test_session() as sess: 124e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney data = constant_op.constant(self.data, dtype=dtypes.float32) 125840c30b0f38bcf0d94fe86e50a619465f935addaA. Unique TensorFlower loss_op, scores, assignments, training_op, init_op, _ = gmm_ops.gmm( 126a925c14596135b19bedb73e0f6ae3cd170180106A. Unique TensorFlower data, 'random', num_classes, random_seed=self.seed) 127f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 128e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney variables.global_variables_initializer().run() 129a925c14596135b19bedb73e0f6ae3cd170180106A. Unique TensorFlower sess.run(init_op) 130840c30b0f38bcf0d94fe86e50a619465f935addaA. Unique TensorFlower first_loss = sess.run(loss_op) 131f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower for _ in xrange(self.iterations): 132f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower sess.run(training_op) 133f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower assignments = sess.run(assignments) 134840c30b0f38bcf0d94fe86e50a619465f935addaA. Unique TensorFlower end_loss = sess.run(loss_op) 135840c30b0f38bcf0d94fe86e50a619465f935addaA. Unique TensorFlower scores = sess.run(scores) 136840c30b0f38bcf0d94fe86e50a619465f935addaA. Unique TensorFlower self.assertEqual((self.num_examples, 1), scores.shape) 137f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower accuracy = np.mean( 138f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower np.asarray(self.true_assignments) == np.squeeze(assignments)) 139f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower logging.info('Accuracy: %f', accuracy) 140840c30b0f38bcf0d94fe86e50a619465f935addaA. Unique TensorFlower logging.info('First loss: %f, end loss: %f', first_loss, end_loss) 141840c30b0f38bcf0d94fe86e50a619465f935addaA. Unique TensorFlower self.assertGreater(end_loss, first_loss) 142f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower self.assertGreater(accuracy, 0.98) 143f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 144f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower def testParams(self): 145f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower """Tests that the params work as intended.""" 146f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower num_classes = 2 147f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower with self.test_session() as sess: 148f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower # Experiment 1. Update weights only. 149e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney data = constant_op.constant(self.data, dtype=dtypes.float32) 150e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes, 151e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [[3.0, 3.0], [0.0, 0.0]], 'w') 152f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower training_ops = gmm_tool.training_ops() 153e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney variables.global_variables_initializer().run() 154a925c14596135b19bedb73e0f6ae3cd170180106A. Unique TensorFlower sess.run(gmm_tool.init_ops()) 155f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower for _ in xrange(self.iterations): 156f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower sess.run(training_ops) 157f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 158f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower # Only the probability to each class is updated. 159f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower alphas = sess.run(gmm_tool.alphas()) 160f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower self.assertGreater(alphas[1], 0.6) 161f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower means = sess.run(gmm_tool.clusters()) 162f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower np.testing.assert_almost_equal( 163f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower np.expand_dims([[3.0, 3.0], [0.0, 0.0]], 1), means) 164f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower covs = sess.run(gmm_tool.covariances()) 165f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower np.testing.assert_almost_equal(covs[0], covs[1]) 166f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 167f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower # Experiment 2. Update means and covariances. 168e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes, 169e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [[3.0, 3.0], [0.0, 0.0]], 'mc') 170f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower training_ops = gmm_tool.training_ops() 171e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney variables.global_variables_initializer().run() 172a925c14596135b19bedb73e0f6ae3cd170180106A. Unique TensorFlower sess.run(gmm_tool.init_ops()) 173f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower for _ in xrange(self.iterations): 174f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower sess.run(training_ops) 175f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower alphas = sess.run(gmm_tool.alphas()) 176f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower self.assertAlmostEqual(alphas[0], alphas[1]) 177f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower means = sess.run(gmm_tool.clusters()) 178f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower np.testing.assert_almost_equal( 179f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower np.expand_dims([[2.0, 2.0], [-1.0, -1.0]], 1), means, decimal=1) 180f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower covs = sess.run(gmm_tool.covariances()) 181f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower np.testing.assert_almost_equal( 182e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [[0.371111, -0.0050774], [-0.0050774, 0.8651744]], covs[0], decimal=4) 183f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower np.testing.assert_almost_equal( 184e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [[0.146976, 0.0259463], [0.0259463, 0.2543971]], covs[1], decimal=4) 185f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 186f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower # Experiment 3. Update covariances only. 187e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes, 188e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [[-1.0, -1.0], [1.0, 1.0]], 'c') 189f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower training_ops = gmm_tool.training_ops() 190e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney variables.global_variables_initializer().run() 191a925c14596135b19bedb73e0f6ae3cd170180106A. Unique TensorFlower sess.run(gmm_tool.init_ops()) 192f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower for _ in xrange(self.iterations): 193f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower sess.run(training_ops) 194f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower alphas = sess.run(gmm_tool.alphas()) 195f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower self.assertAlmostEqual(alphas[0], alphas[1]) 196f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower means = sess.run(gmm_tool.clusters()) 197f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower np.testing.assert_almost_equal( 198f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower np.expand_dims([[-1.0, -1.0], [1.0, 1.0]], 1), means) 199f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower covs = sess.run(gmm_tool.covariances()) 200f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower np.testing.assert_almost_equal( 201e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [[0.1299582, 0.0435872], [0.0435872, 0.2558578]], covs[0], decimal=5) 202f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower np.testing.assert_almost_equal( 203e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney [[3.195385, 2.6989155], [2.6989155, 3.3881593]], covs[1], decimal=5) 204f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 205f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower 206f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlowerif __name__ == '__main__': 207e121667dc609de978a223c56ee906368d2c4ceefJustine Tunney test.main() 208