fully_connected_reader.py revision e4d6e8a3ff414a2688d93d0d2e1b234165a2703e
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Train and Eval the MNIST network.
17
18This version is like fully_connected_feed.py but uses data converted
19to a TFRecords file containing tf.train.Example protocol buffers.
20See:
21https://www.tensorflow.org/programmers_guide/reading_data#reading_from_files
22for context.
23
24YOU MUST run convert_to_records before running this (but you only need to
25run it once).
26"""
27from __future__ import absolute_import
28from __future__ import division
29from __future__ import print_function
30
31import argparse
32import os.path
33import sys
34import time
35
36import tensorflow as tf
37
38from tensorflow.examples.tutorials.mnist import mnist
39
# Basic model parameters as external flags. This stays None until the
# command-line flags are parsed in the `__main__` section of this file.
FLAGS = None

# Names of the TFRecords files read by inputs(); they are looked up inside
# FLAGS.train_dir and must match the names used by convert_to_records.
TRAIN_FILE = 'train.tfrecords'
VALIDATION_FILE = 'validation.tfrecords'
46
47
def read_and_decode(filename_queue):
  """Reads one serialized Example from the queue and decodes it.

  Args:
    filename_queue: Queue of TFRecords filenames, handed to
      tf.TFRecordReader.read().

  Returns:
    A tuple (image, label), where:
    * image is a float32 tensor with shape [mnist.IMAGE_PIXELS] and values
      in the range [-0.5, 0.5].
    * label is an int32 scalar tensor.
  """
  # Both keys are required, so no default values are specified.
  feature_spec = {
      'image_raw': tf.FixedLenFeature([], tf.string),
      'label': tf.FixedLenFeature([], tf.int64),
  }

  record_reader = tf.TFRecordReader()
  _, serialized_example = record_reader.read(filename_queue)
  parsed = tf.parse_single_example(serialized_example, features=feature_spec)

  # The image is stored as a scalar string of raw bytes; decode it into a
  # uint8 vector and pin its static shape to [mnist.IMAGE_PIXELS].
  raw_pixels = tf.decode_raw(parsed['image_raw'], tf.uint8)
  raw_pixels.set_shape([mnist.IMAGE_PIXELS])

  # OPTIONAL: This is where the vector could be reshaped into a 28x28 image
  # and distorted.  No distortions are applied in this example and the next
  # step expects a flattened vector, so it is left as is.

  # Rescale from [0, 255] to floats in [-0.5, 0.5].
  image = tf.cast(raw_pixels, tf.float32) * (1. / 255) - 0.5

  # The label is parsed as an int64 scalar; narrow it to int32.
  label = tf.cast(parsed['label'], tf.int32)

  return image, label
77
78
def inputs(train, batch_size, num_epochs):
  """Builds the shuffled, batched input pipeline for one data split.

  Args:
    train: True to read the training data, False to read the validation
      data.
    batch_size: Number of examples per returned batch.
    num_epochs: Number of times to read the input data, or 0/None to
      train forever.

  Returns:
    A tuple (images, labels), where:
    * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]
      in the range [-0.5, 0.5].
    * labels is an int32 tensor with shape [batch_size] with the true label,
      a number in the range [0, mnist.NUM_CLASSES).
    Note that a tf.train.QueueRunner is added to the graph, which
    must be run using e.g. tf.train.start_queue_runners().
  """
  # Any falsy epoch count (0 or None) is normalized to None.
  epoch_limit = num_epochs if num_epochs else None

  if train:
    basename = TRAIN_FILE
  else:
    basename = VALIDATION_FILE
  record_path = os.path.join(FLAGS.train_dir, basename)

  # Minimum number of examples kept in the shuffle queue after a dequeue;
  # this ensures a minimum amount of shuffling of examples.
  shuffle_floor = 1000

  with tf.name_scope('input'):
    filename_queue = tf.train.string_input_producer(
        [record_path], num_epochs=epoch_limit)

    # A single filename queue is shared, even when reading in multiple
    # threads.
    example, target = read_and_decode(filename_queue)

    # Shuffle the examples and group them into batches of batch_size.
    # (Internally uses a RandomShuffleQueue.)  Two threads are used so that
    # this step does not become a bottleneck.
    images, sparse_labels = tf.train.shuffle_batch(
        [example, target],
        batch_size=batch_size,
        num_threads=2,
        capacity=shuffle_floor + 3 * batch_size,
        min_after_dequeue=shuffle_floor)

  return images, sparse_labels
119
120
def run_training():
  """Train MNIST for a number of steps.

  Builds the input pipeline, the inference model, the loss and the training
  op in a fresh graph, then runs training steps until the input queue raises
  tf.errors.OutOfRangeError or the coordinator asks to stop.  All settings
  (batch size, epochs, layer sizes, learning rate) are read from the
  module-level FLAGS.
  """

  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    # Input images and labels.
    images, labels = inputs(train=True, batch_size=FLAGS.batch_size,
                            num_epochs=FLAGS.num_epochs)

    # Build a Graph that computes predictions from the inference model.
    logits = mnist.inference(images,
                             FLAGS.hidden1,
                             FLAGS.hidden2)

    # Add to the Graph the loss calculation.
    loss = mnist.loss(logits, labels)

    # Add to the Graph operations that train the model.
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # The op for initializing the variables.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # Run the session as a context manager so that it is closed on every
    # exit path: normal completion, an unexpected error inside the training
    # loop, or an exception re-raised by coord.join() below.
    with tf.Session() as sess:
      # Initialize the variables (the trained variables and the
      # epoch counter).
      sess.run(init_op)

      # Start input enqueue threads.
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess=sess, coord=coord)

      try:
        step = 0
        while not coord.should_stop():
          start_time = time.time()

          # Run one step of the model.  The return values are
          # the activations from the `train_op` (which is
          # discarded) and the `loss` op.  To inspect the values
          # of your ops or variables, you may include them in
          # the list passed to sess.run() and the value tensors
          # will be returned in the tuple from the call.
          _, loss_value = sess.run([train_op, loss])

          duration = time.time() - start_time

          # Print an overview fairly often.
          if step % 100 == 0:
            print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
                                                       duration))
          step += 1
      except tf.errors.OutOfRangeError:
        print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs,
                                                          step))
      finally:
        # When done, ask the threads to stop.
        coord.request_stop()

      # Wait for threads to finish.
      coord.join(threads)
185
186
def main(_):
  """Runs training; passed as the `main` callable to tf.app.run().

  Args:
    _: Ignored positional argument.
  """
  run_training()
189
190
if __name__ == '__main__':
  # One row per command-line flag: (name, type, default, help text).  The
  # flags are registered with the parser in this order.
  flag_specs = (
      ('--learning_rate', float, 0.01, 'Initial learning rate.'),
      ('--num_epochs', int, 2, 'Number of epochs to run trainer.'),
      ('--hidden1', int, 128, 'Number of units in hidden layer 1.'),
      ('--hidden2', int, 32, 'Number of units in hidden layer 2.'),
      ('--batch_size', int, 100, 'Batch size.'),
      ('--train_dir', str, '/tmp/data', 'Directory with the training data.'),
  )

  parser = argparse.ArgumentParser()
  for flag_name, flag_type, flag_default, flag_help in flag_specs:
    parser.add_argument(
        flag_name, type=flag_type, default=flag_default, help=flag_help)

  # Recognized flags populate the module-level FLAGS; everything else is
  # forwarded, after the program name, as argv to tf.app.run().
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
231