1f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower#
3f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# Licensed under the Apache License, Version 2.0 (the "License");
4f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# you may not use this file except in compliance with the License.
5f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# You may obtain a copy of the License at
6f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower#
7f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower#     http://www.apache.org/licenses/LICENSE-2.0
8f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower#
9f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# Unless required by applicable law or agreed to in writing, software
10f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# distributed under the License is distributed on an "AS IS" BASIS,
11f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# See the License for the specific language governing permissions and
13f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# limitations under the License.
14f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower# ==============================================================================
15f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower"""Tests for gmm_ops."""
16f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower
17f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlowerfrom __future__ import absolute_import
18f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlowerfrom __future__ import division
19f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlowerfrom __future__ import print_function
20f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower
21f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlowerimport time
22f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower
23f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlowerimport numpy as np
24f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlowerfrom six.moves import xrange  # pylint: disable=redefined-builtin
25f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower
26f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlowerfrom tensorflow.contrib.factorization.python.ops import gmm_ops
27e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.framework import constant_op
28e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.framework import dtypes
29e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.framework import ops
30e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.framework import random_seed as random_seed_lib
31e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.ops import variables
32e121667dc609de978a223c56ee906368d2c4ceefJustine Tunneyfrom tensorflow.python.platform import test
33f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlowerfrom tensorflow.python.platform import tf_logging as logging
34f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower
35f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower
class GmmOpsTest(test.TestCase):
  """Tests for gmm_ops: covariance helper, end-to-end gmm(), GmmAlgorithm."""

  def setUp(self):
    self.num_examples = 1000
    self.iterations = 40
    self.seed = 4
    random_seed_lib.set_random_seed(self.seed)
    np.random.seed(self.seed * 2)
    # Two 2-D clusters around (2, 2) and (-1, -1). The hard-coded expected
    # covariances in testParams depend on this exact draw, so the order of
    # the random calls here must not change.
    self.data, self.true_assignments = self.make_data(self.num_examples)
    # Generate more complicated data.
    self.centers = [[1, 1], [-1, 0.5], [2, 1]]
    self.more_data, self.more_true_assignments = self.make_data_from_centers(
        self.num_examples, self.centers)

  @staticmethod
  def make_data(num_vectors):
    """Generates 2-dimensional data centered on (2,2), (-1,-1).

    Args:
      num_vectors: number of training examples.

    Returns:
      A tuple containing the data as a numpy array (one row per example) and
      the list of cluster ids (0 for the (2,2) cluster, 1 for (-1,-1)).
    """
    vectors = []
    classes = []
    for _ in xrange(num_vectors):
      if np.random.random() > 0.5:
        vectors.append([np.random.normal(2.0, 0.6), np.random.normal(2.0, 0.9)])
        classes.append(0)
      else:
        vectors.append(
            [np.random.normal(-1.0, 0.4), np.random.normal(-1.0, 0.5)])
        classes.append(1)
    return np.asarray(vectors), classes

  @staticmethod
  def make_data_from_centers(num_vectors, centers):
    """Generates 2-dimensional data around the given centers.

    Each example picks one of `centers` uniformly at random, then draws each
    coordinate from a normal distribution around that center whose standard
    deviation is itself drawn uniformly from [0, 1).

    Args:
      num_vectors: number of training examples.
      centers: a list of 2-dimensional centers.

    Returns:
      A tuple containing the data as a numpy array (one row per example) and
      the list of cluster ids (indices into `centers`), one per example.
    """
    vectors = []
    classes = []
    for _ in xrange(num_vectors):
      # randint's upper bound is exclusive: this is the documented replacement
      # for the deprecated np.random.random_integers(0, len(centers) - 1).
      current_class = np.random.randint(0, len(centers))
      center = centers[current_class]
      vectors.append([
          np.random.normal(center[0], np.random.random_sample()),
          np.random.normal(center[1], np.random.random_sample())
      ])
      classes.append(current_class)
    # Return the per-example cluster ids, as documented (previously this
    # returned len(centers) and discarded `classes`).
    return np.asarray(vectors), classes

  def test_covariance(self):
    """Checks gmm_ops._covariance against np.cov, full and diagonal."""
    start_time = time.time()
    # np.cov treats rows as variables, hence the transpose.
    np_cov = np.cov(self.data.T)
    logging.info('Numpy took %f', time.time() - start_time)

    start_time = time.time()
    with self.test_session() as sess:
      # _covariance is fed the data with one example per row.
      points = constant_op.constant(self.data, dtype=dtypes.float32)
      op = gmm_ops._covariance(points, False)
      op_diag = gmm_ops._covariance(points, True)
      variables.global_variables_initializer().run()
      tf_cov = sess.run(op)
      np.testing.assert_array_almost_equal(np_cov, tf_cov)
      logging.info('Tensorflow took %f', time.time() - start_time)
      # The diagonal variant must match the diagonal of the full covariance.
      tf_cov = sess.run(op_diag)
      np.testing.assert_array_almost_equal(
          np.diag(np_cov), np.ravel(tf_cov), decimal=5)

  def test_simple_cluster(self):
    """Tests that the clusters are correct."""
    num_classes = 2
    graph = ops.Graph()
    with graph.as_default() as g:
      g.seed = 5
      with self.test_session() as sess:
        data = constant_op.constant(self.data, dtype=dtypes.float32)
        loss_op, scores_op, assignments_op, training_op, init_op, _ = (
            gmm_ops.gmm(data, 'random', num_classes, random_seed=self.seed))

        variables.global_variables_initializer().run()
        sess.run(init_op)
        first_loss = sess.run(loss_op)
        for _ in xrange(self.iterations):
          sess.run(training_op)
        assignments = sess.run(assignments_op)
        end_loss = sess.run(loss_op)
        scores = sess.run(scores_op)
        self.assertEqual((self.num_examples, 1), scores.shape)
        accuracy = np.mean(
            np.asarray(self.true_assignments) == np.squeeze(assignments))
        logging.info('Accuracy: %f', accuracy)
        logging.info('First loss: %f, end loss: %f', first_loss, end_loss)
        # loss_op is expected to grow with training (presumably it is a
        # log-likelihood -- confirm in gmm_ops).
        self.assertGreater(end_loss, first_loss)
        self.assertGreater(accuracy, 0.98)

  def _train_gmm(self, sess, data, num_classes, initial_means, params):
    """Builds a GmmAlgorithm, initializes it and runs its training ops.

    Must be called inside the `with self.test_session()` block that owns
    `sess`, because the variable initializer runs in the default session.

    Args:
      sess: session used to run the ops.
      data: float32 tensor of training points.
      num_classes: number of mixture components.
      initial_means: initial cluster centers, one per component.
      params: which parameters to update; testParams uses 'w' (weights),
        'mc' (means and covariances) and 'c' (covariances).

    Returns:
      The trained GmmAlgorithm instance.
    """
    gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes, initial_means, params)
    training_ops = gmm_tool.training_ops()
    variables.global_variables_initializer().run()
    sess.run(gmm_tool.init_ops())
    for _ in xrange(self.iterations):
      sess.run(training_ops)
    return gmm_tool

  def testParams(self):
    """Tests that the params work as intended."""
    num_classes = 2
    with self.test_session() as sess:
      data = constant_op.constant(self.data, dtype=dtypes.float32)

      # Experiment 1. Update weights only.
      gmm_tool = self._train_gmm(sess, data, num_classes,
                                 [[3.0, 3.0], [0.0, 0.0]], 'w')
      # Only the probability to each class is updated.
      alphas = sess.run(gmm_tool.alphas())
      self.assertGreater(alphas[1], 0.6)
      # Means keep their initial values.
      means = sess.run(gmm_tool.clusters())
      np.testing.assert_almost_equal(
          np.expand_dims([[3.0, 3.0], [0.0, 0.0]], 1), means)
      # Covariances stay equal to each other.
      covs = sess.run(gmm_tool.covariances())
      np.testing.assert_almost_equal(covs[0], covs[1])

      # Experiment 2. Update means and covariances.
      gmm_tool = self._train_gmm(sess, data, num_classes,
                                 [[3.0, 3.0], [0.0, 0.0]], 'mc')
      # Weights are not updated, so they stay equal.
      alphas = sess.run(gmm_tool.alphas())
      self.assertAlmostEqual(alphas[0], alphas[1])
      # Means converge near the true centers used by make_data.
      means = sess.run(gmm_tool.clusters())
      np.testing.assert_almost_equal(
          np.expand_dims([[2.0, 2.0], [-1.0, -1.0]], 1), means, decimal=1)
      covs = sess.run(gmm_tool.covariances())
      np.testing.assert_almost_equal(
          [[0.371111, -0.0050774], [-0.0050774, 0.8651744]], covs[0], decimal=4)
      np.testing.assert_almost_equal(
          [[0.146976, 0.0259463], [0.0259463, 0.2543971]], covs[1], decimal=4)

      # Experiment 3. Update covariances only.
      gmm_tool = self._train_gmm(sess, data, num_classes,
                                 [[-1.0, -1.0], [1.0, 1.0]], 'c')
      alphas = sess.run(gmm_tool.alphas())
      self.assertAlmostEqual(alphas[0], alphas[1])
      # Means keep their initial values.
      means = sess.run(gmm_tool.clusters())
      np.testing.assert_almost_equal(
          np.expand_dims([[-1.0, -1.0], [1.0, 1.0]], 1), means)
      covs = sess.run(gmm_tool.covariances())
      np.testing.assert_almost_equal(
          [[0.1299582, 0.0435872], [0.0435872, 0.2558578]], covs[0], decimal=5)
      np.testing.assert_almost_equal(
          [[3.195385, 2.6989155], [2.6989155, 3.3881593]], covs[1], decimal=5)
205f73f79b6d6174c8e7edeca2d756ddcfe9ff51e84A. Unique TensorFlower
if __name__ == '__main__':
  # Run the test cases in this module via TensorFlow's test entry point
  # (tensorflow.python.platform.test).
  test.main()
208