linear_regression.py revision a6a61884396ef1d51b01f8e13df21becb23fd0c8
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""TensorFlow Eager Execution Example: Linear Regression.

This example shows how to use TensorFlow Eager Execution to fit a simple linear
regression model using some synthesized data. Specifically, it illustrates how
to define the forward pass of the linear model and the loss function, as well
as how to obtain the gradients of the loss function with respect to the
variables and update the variables with the gradients.
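
Example usage (the flag value below is just an illustrative path):
  python linear_regression.py --logdir=/tmp/linear_regression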
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import tensorflow as tf

import tensorflow.contrib.eager as tfe


class LinearModel(tfe.Network):
  """A TensorFlow linear regression model.

  Uses TensorFlow's eager execution.

  For those familiar with TensorFlow graphs, notice the absence of
  `tf.Session`. The `call()` method here executes immediately and
  returns output values, and the loss function defined in `fit()`
  immediately compares those outputs with the targets and returns the
  MSE loss value. The module-level `fit()` function performs
  gradient-descent training on the model's weights and bias.
  """

  def __init__(self):
    """Constructs a LinearModel object."""
    super(LinearModel, self).__init__()
    self._hidden_layer = self.track_layer(tf.layers.Dense(1))

  def call(self, xs):
    """Invoke the linear model.

    Args:
      xs: input features, as a tensor of size [batch_size, ndims].

    Returns:
      ys: the predictions of the linear model, as a tensor of size
        [batch_size, 1].
    """
    return self._hidden_layer(xs)


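# A minimal usage sketch for `LinearModel` (illustrative only, not executed by
# this script; assumes eager execution has been enabled, as main() does below):
#
#   model = LinearModel()
#   xs = tf.random_normal([4, 3])  # A batch of 4 examples with 3 features.
#   ys = model(xs)                 # Runs immediately; no tf.Session needed.
#   print(ys.numpy())              # Concrete predictions of shape [4, 1].
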
def fit(model, dataset, optimizer, verbose=False, logdir=None):
  """Fit the linear-regression model.

  Args:
    model: The LinearModel to fit.
    dataset: The tf.data.Dataset to use for training data.
    optimizer: The TensorFlow Optimizer object to be used.
    verbose: If true, will print out loss values at every iteration.
    logdir: The directory in which summaries will be written for TensorBoard
      (optional).
  """

  # The loss function to optimize.
  def mean_square_loss(xs, ys):
    return tf.reduce_mean(tf.square(model(xs) - ys))

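  # `implicit_value_and_gradients` wraps the loss function: calling
  # `loss_and_grads(xs, ys)` below returns both the loss value and a list of
  # (gradient, variable) pairs for the trainable variables the loss depends on.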
  loss_and_grads = tfe.implicit_value_and_gradients(mean_square_loss)

  tf.train.get_or_create_global_step()
  if logdir:
    # Support for TensorBoard summaries. Once training has started, use:
    #   tensorboard --logdir=<logdir>
    summary_writer = tf.contrib.summary.create_summary_file_writer(logdir)

  # Training loop.
  for i, (xs, ys) in enumerate(tfe.Iterator(dataset)):
    loss, grads = loss_and_grads(xs, ys)
    if verbose:
      print("Iteration %d: loss = %s" % (i, loss.numpy()))

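    # `grads` holds (gradient, variable) pairs; apply_gradients performs the
    # update v <- v - learning_rate * g for each pair and advances the global
    # step used by the summaries below.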
    optimizer.apply_gradients(grads, global_step=tf.train.get_global_step())

    if logdir:
      with summary_writer.as_default():
        with tf.contrib.summary.always_record_summaries():
          tf.contrib.summary.scalar("loss", loss)


def synthetic_dataset(w, b, noise_level, batch_size, num_batches):
  """tf.data.Dataset that yields synthetic data for linear regression."""

  # w is a matrix with shape [N, M]
  # b is a vector with shape [M]
  # So:
  # - Generate a batch of x's with shape [batch_size, N]
  # - y = tf.matmul(x, w) + b + noise
  def batch(_):
    x = tf.random_normal([batch_size, tf.shape(w)[0]])
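    # Note: the noise term below is a single scalar draw shared by the whole
    # batch, not independent per-element noise.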
    y = tf.matmul(x, w) + b + noise_level * tf.random_normal([])
    return x, y

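  # Keep the input pipeline on the CPU: tf.data ops are executed on the host,
  # independently of which device the training itself runs on.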
  with tf.device("/device:CPU:0"):
    return tf.data.Dataset.range(num_batches).map(batch)


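# Example (illustrative sketch, not executed by this script): drawing a single
# batch from the synthetic dataset under eager execution.
#
#   w = tf.constant([[3.0]])
#   b = tf.constant([2.0])
#   dataset = synthetic_dataset(w, b, 0.1, batch_size=4, num_batches=1)
#   xs, ys = next(tfe.Iterator(dataset))  # xs and ys each have shape [4, 1].
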
def main(_):
  tfe.enable_eager_execution()
  # Ground-truth constants.
  true_w = [[-2.0], [4.0], [1.0]]
  true_b = [0.5]
  noise_level = 0.01

  # Training constants.
  batch_size = 64
  learning_rate = 0.1

  print("True w: %s" % true_w)
  print("True b: %s\n" % true_b)

  model = LinearModel()
  dataset = synthetic_dataset(true_w, true_b, noise_level, batch_size, 20)

  device = "gpu:0" if tfe.num_gpus() else "cpu:0"
  print("Using device: %s" % device)
  with tf.device(device):
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    fit(model, dataset, optimizer, verbose=True, logdir=FLAGS.logdir)

  print("\nAfter training: w = %s" % model.variables[0].numpy())
  print("\nAfter training: b = %s" % model.variables[1].numpy())


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--logdir",
      type=str,
      default=None,
      help="logdir in which TensorBoard summaries will be written (optional).")
  FLAGS, unparsed = parser.parse_known_args()

  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)