# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for linear regression example under TensorFlow eager execution."""
16a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankar
17a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankarfrom __future__ import absolute_import
18a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankarfrom __future__ import division
19a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankarfrom __future__ import print_function
20a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankar
21a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankarimport glob
22a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankarimport os
23a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankarimport shutil
24a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankarimport tempfile
25a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankarimport time
26a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankar
27a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankarimport tensorflow as tf
28a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankar
29a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankarimport tensorflow.contrib.eager as tfe
30a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankarfrom tensorflow.contrib.eager.python.examples.linear_regression import linear_regression
31a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankar
32a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankar
def device():
  """Returns the device string to place ops on.

  Returns:
    "/device:GPU:0" when `tfe.num_gpus()` reports at least one GPU,
    otherwise "/device:CPU:0".
  """
  if tfe.num_gpus() > 0:
    return "/device:GPU:0"
  return "/device:CPU:0"
35a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankar
36a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankar
class LinearRegressionTest(tf.test.TestCase):
  """Tests for the synthetic dataset and `fit` in the linear_regression example."""

  def setUp(self):
    super(LinearRegressionTest, self).setUp()
    # Scratch directory passed to fit() as `logdir`; removed again in tearDown.
    self._tmp_logdir = tempfile.mkdtemp()

  def tearDown(self):
    shutil.rmtree(self._tmp_logdir)
    super(LinearRegressionTest, self).tearDown()

  def testSyntheticDataset(self):
    """Checks the number, shape and dtype of the batches the dataset yields."""
    weights = tf.random_uniform([3, 1])
    bias = [1.0]
    batch_size = 10
    num_batches = 2
    noise_level = 0.
    dataset = linear_regression.synthetic_dataset(
        weights, bias, noise_level, batch_size, num_batches)

    batches = tfe.Iterator(dataset)
    for _ in range(num_batches):
      features, labels = batches.next()
      self.assertEqual((batch_size, 3), features.shape)
      self.assertEqual((batch_size, 1), labels.shape)
      self.assertEqual(tf.float32, features.dtype)
      self.assertEqual(tf.float32, labels.dtype)

    # After `num_batches` batches the iterator must be exhausted.
    with self.assertRaises(StopIteration):
      batches.next()

  def testLinearRegression(self):
    """Fits on noise-free data and compares the model variables to the truth."""
    true_w = [[1.0], [-0.5], [2.0]]
    true_b = [1.0]

    model = linear_regression.LinearModel()
    dataset = linear_regression.synthetic_dataset(
        true_w,
        true_b,
        noise_level=0.,
        batch_size=64,
        num_batches=40)

    with tf.device(device()):
      sgd = tf.train.GradientDescentOptimizer(learning_rate=0.1)
      linear_regression.fit(model, dataset, sgd, logdir=self._tmp_logdir)

      # model.variables[0] is compared against true_w, [1] against true_b.
      learned_w = model.variables[0].numpy()
      self.assertAllClose(true_w, learned_w, rtol=1e-2)
      learned_b = model.variables[1].numpy()
      self.assertAllClose(true_b, learned_b, rtol=1e-2)

      # fit() was given a logdir, so at least one events file must exist there.
      event_files = glob.glob(os.path.join(self._tmp_logdir, "events.out.*"))
      self.assertTrue(event_files)
81a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankar
82a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankar
class EagerLinearRegressionBenchmark(tf.test.Benchmark):
  """Benchmark of `linear_regression.fit` under eager execution."""

  def benchmarkEagerLinearRegression(self):
    """Times `num_epochs` passes of fit() over a synthetic dataset.

    A 10-batch slice of the dataset is fitted first as burn-in and is excluded
    from the timing. The total wall time, the number of training steps and the
    derived examples-per-second figure are passed to `report_benchmark`.
    """
    num_epochs = 10
    num_batches = 200
    batch_size = 64
    total_steps = num_epochs * num_batches

    dataset = linear_regression.synthetic_dataset(
        w=tf.random_uniform([3, 1]),
        b=tf.random_uniform([1]),
        noise_level=0.01,
        batch_size=batch_size,
        num_batches=num_batches)
    warmup_dataset = dataset.take(10)

    model = linear_regression.LinearModel()

    with tf.device(device()):
      sgd = tf.train.GradientDescentOptimizer(learning_rate=0.1)

      # Burn-in: run a few batches before the clock starts.
      linear_regression.fit(model, warmup_dataset, sgd)

      started_at = time.time()
      for _ in range(num_epochs):
        linear_regression.fit(model, dataset, sgd)
      elapsed = time.time() - started_at

      throughput = total_steps * batch_size / elapsed
      if tfe.num_gpus() > 0:
        device_kind = "gpu"
      else:
        device_kind = "cpu"
      self.report_benchmark(
          name="eager_train_%s" % device_kind,
          iters=total_steps,
          extras={"examples_per_sec": throughput},
          wall_time=elapsed)
117a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankar
118a6a61884396ef1d51b01f8e13df21becb23fd0c8Asim Shankar
if __name__ == "__main__":
  # Eager execution is switched on first, so that it is already active when
  # the test runner starts executing the test and benchmark classes above.
  tfe.enable_eager_execution()
  tf.test.main()
122