1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Unit tests for linear regression example under TensorFlow eager execution."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import glob
22import os
23import shutil
24import tempfile
25import time
26
27import tensorflow as tf
28
29import tensorflow.contrib.eager as tfe
30from tensorflow.contrib.eager.python.examples.linear_regression import linear_regression
31
32
def device():
  """Returns the device string to run on.

  Returns:
    "/device:GPU:0" when `tfe.num_gpus()` reports at least one GPU,
    otherwise "/device:CPU:0".
  """
  if tfe.num_gpus() > 0:
    return "/device:GPU:0"
  return "/device:CPU:0"
35
36
class LinearRegressionTest(tf.test.TestCase):
  """Tests for the synthetic dataset and the `fit` loop of the example."""

  def setUp(self):
    super(LinearRegressionTest, self).setUp()
    # Scratch directory passed to `fit` as `logdir`; removed in tearDown.
    self._tmp_logdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self._tmp_logdir)
    super(LinearRegressionTest, self).tearDown()

  def testSyntheticDataset(self):
    """Checks batch shapes/dtypes and that exactly `num_batches` are produced."""
    true_w = tf.random_uniform([3, 1])
    true_b = [1.0]
    batch_size = 10
    num_batches = 2
    noise_level = 0.
    dataset = linear_regression.synthetic_dataset(true_w, true_b, noise_level,
                                                  batch_size, num_batches)

    it = tfe.Iterator(dataset)
    # Loop over `num_batches` rather than a hard-coded count so the
    # StopIteration check below stays consistent with the dataset size.
    for _ in range(num_batches):
      xs, ys = it.next()
      self.assertEqual((batch_size, 3), xs.shape)
      self.assertEqual((batch_size, 1), ys.shape)
      self.assertEqual(tf.float32, xs.dtype)
      self.assertEqual(tf.float32, ys.dtype)
    # The dataset must be exhausted after exactly `num_batches` batches.
    with self.assertRaises(StopIteration):
      it.next()

  def testLinearRegression(self):
    """Trains on noise-free data and checks learned parameters and summaries."""
    true_w = [[1.0], [-0.5], [2.0]]
    true_b = [1.0]

    model = linear_regression.LinearModel()
    dataset = linear_regression.synthetic_dataset(
        true_w, true_b, noise_level=0., batch_size=64, num_batches=40)

    with tf.device(device()):
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
      linear_regression.fit(model, dataset, optimizer, logdir=self._tmp_logdir)

      # NOTE(review): assumes model.variables is ordered [weights, bias];
      # confirm against LinearModel.
      self.assertAllClose(true_w, model.variables[0].numpy(), rtol=1e-2)
      self.assertAllClose(true_b, model.variables[1].numpy(), rtol=1e-2)
      # `fit` was given `logdir`, so event files are expected there.
      self.assertTrue(glob.glob(os.path.join(self._tmp_logdir, "events.out.*")))
81
82
class EagerLinearRegressionBenchmark(tf.test.Benchmark):
  """Measures eager-mode training throughput of the linear regression model."""

  def benchmarkEagerLinearRegression(self):
    """Times repeated `fit` passes over a synthetic dataset.

    A short warm-up pass runs first and is excluded from the timing. The
    report includes wall time, the number of batches processed as `iters`,
    and examples per second in `extras`.
    """
    epochs = 10
    batches_per_epoch = 200
    examples_per_batch = 64

    # Random ground-truth parameters; weights are drawn before the bias.
    weights = tf.random_uniform([3, 1])
    bias = tf.random_uniform([1])
    dataset = linear_regression.synthetic_dataset(
        w=weights,
        b=bias,
        noise_level=0.01,
        batch_size=examples_per_batch,
        num_batches=batches_per_epoch)
    # First 10 elements of the dataset, used only for warm-up.
    warmup_dataset = dataset.take(10)

    model = linear_regression.LinearModel()

    with tf.device(device()):
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)

      # Warm-up run; not included in the measured wall time.
      linear_regression.fit(model, warmup_dataset, optimizer)

      started = time.time()
      for _epoch in range(epochs):
        linear_regression.fit(model, dataset, optimizer)
      elapsed = time.time() - started

      total_steps = epochs * batches_per_epoch
      throughput = total_steps * examples_per_batch / elapsed
      platform = "gpu" if tfe.num_gpus() > 0 else "cpu"
      self.report_benchmark(
          name="eager_train_%s" % platform,
          iters=total_steps,
          extras={"examples_per_sec": throughput},
          wall_time=elapsed)
117
118
if __name__ == "__main__":
  # Eager execution is enabled first, before tf.test.main() runs the test and
  # benchmark classes defined above; keep this ordering.
  tfe.enable_eager_execution()
  tf.test.main()
122