fully_connected_reader.py revision 492e55ff2297dc13f5ed540056ad40603f2b1dd0
1# Copyright 2015 Google Inc. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Train and Eval the MNIST network.
17
18This version is like fully_connected_feed.py but uses data converted
19to a TFRecords file containing tf.train.Example protocol buffers.
20See tensorflow/g3doc/how_tos/reading_data.md#reading-from-files
21for context.
22
23YOU MUST run convert_to_records before running this (but you only need to
24run it once).
25"""
26from __future__ import absolute_import
27from __future__ import division
28from __future__ import print_function
29
30import os.path
31import time
32
33import tensorflow.python.platform
34import numpy
35import tensorflow as tf
36
37from tensorflow.examples.tutorials.mnist import mnist
38
39
# Command-line flags.  `flags` is the flag-definition module and `FLAGS`
# the object through which the values are read elsewhere in this file.
flags = tf.app.flags
FLAGS = flags.FLAGS

# Optimizer / schedule settings.
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
flags.DEFINE_integer('num_epochs', 2, 'Number of epochs to run trainer.')

# Sizes of the two hidden layers.
flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')

# Input pipeline settings.
flags.DEFINE_integer('batch_size', 100, 'Batch size.')
flags.DEFINE_string(
    'train_dir', '/tmp/data', 'Directory with the training data.')

# Names of the TFRecords files inside FLAGS.train_dir; these have to
# match the names written by convert_to_records.
TRAIN_FILE = 'train.tfrecords'
VALIDATION_FILE = 'validation.tfrecords'
54
55
def read_and_decode(filename_queue):
  """Reads a single serialized Example from the queue and decodes it.

  Args:
    filename_queue: Queue of TFRecords file names, passed straight to
      tf.TFRecordReader.read().

  Returns:
    A tuple (image, label), where:
    * image is a float32 tensor with static shape [mnist.IMAGE_PIXELS],
      rescaled from [0, 255] to [-0.5, 0.5].
    * label is the int64 'label' feature cast to a scalar int32 tensor.
  """
  record_reader = tf.TFRecordReader()
  _, serialized = record_reader.read(filename_queue)

  # Both keys are required, so no default values are specified.
  feature_spec = {
      'image_raw': tf.FixedLenFeature([], tf.string),
      'label': tf.FixedLenFeature([], tf.int64),
  }
  parsed = tf.parse_single_example(serialized, features=feature_spec)

  # 'image_raw' is a scalar string tensor; reinterpret its bytes as
  # uint8 values and declare the static shape [mnist.IMAGE_PIXELS].
  pixels = tf.decode_raw(parsed['image_raw'], tf.uint8)
  pixels.set_shape([mnist.IMAGE_PIXELS])

  # OPTIONAL: this is where the vector could be reshaped into a 28x28
  # image and distorted.  No distortions are applied in this example
  # and the next step expects a flat vector, so it is left as is.

  # Map pixel values from [0, 255] to floats in [-0.5, 0.5].
  image = tf.cast(pixels, tf.float32) * (1. / 255) - 0.5

  # The label is parsed as int64; narrow it to an int32 scalar.
  label = tf.cast(parsed['label'], tf.int32)

  return image, label
85
86
def inputs(train, batch_size, num_epochs):
  """Builds the shuffled, batched input pipeline over one TFRecords file.

  Args:
    train: True to read the training file, False to read the validation
      file (both are looked up inside FLAGS.train_dir).
    batch_size: Number of examples per returned batch.
    num_epochs: Number of times to read the input data, or 0/None to
      read it indefinitely.

  Returns:
    A tuple (images, labels), where:
    * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]
      in the range [-0.5, 0.5].
    * labels is an int32 tensor with shape [batch_size] with the true label,
      a number in the range [0, mnist.NUM_CLASSES).
    Note that a tf.train.QueueRunner is added to the graph, which
    must be run using e.g. tf.train.start_queue_runners().
  """
  # Any falsy epoch count (0 or None) is normalized to None, i.e. no limit.
  epoch_limit = num_epochs if num_epochs else None

  if train:
    basename = TRAIN_FILE
  else:
    basename = VALIDATION_FILE
  records_path = os.path.join(FLAGS.train_dir, basename)

  # Minimum number of examples kept in the shuffle queue after a
  # dequeue; this guarantees a minimum amount of mixing.
  shuffle_floor = 1000

  with tf.name_scope('input'):
    filename_queue = tf.train.string_input_producer(
        [records_path], num_epochs=epoch_limit)

    # A single filename queue is shared, even though batching below
    # uses more than one thread.
    image, label = read_and_decode(filename_queue)

    # Shuffle examples and group them into batches of batch_size
    # (internally a RandomShuffleQueue).  Two threads are used so that
    # this stage is not a bottleneck.
    images, sparse_labels = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=2,
        capacity=shuffle_floor + 3 * batch_size,
        min_after_dequeue=shuffle_floor)

    return images, sparse_labels
127
128
def run_training():
  """Trains the MNIST model until the input pipeline is exhausted.

  Builds the input pipeline, the inference model, the loss and the
  training op in a fresh graph, then runs training steps in a loop.  The
  loop ends when the input queues raise tf.errors.OutOfRangeError (the
  requested number of epochs has been read) or when the coordinator
  reports that a stop was requested.  Loss and step duration are printed
  every 100 steps.

  The session is closed on every exit path, including when a step raises
  an unexpected exception; such exceptions still propagate to the caller.
  """
  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    # Input images and labels.
    images, labels = inputs(train=True, batch_size=FLAGS.batch_size,
                            num_epochs=FLAGS.num_epochs)

    # Build a Graph that computes predictions from the inference model.
    logits = mnist.inference(images,
                             FLAGS.hidden1,
                             FLAGS.hidden2)

    # Add to the Graph the loss calculation.
    loss = mnist.loss(logits, labels)

    # Add to the Graph operations that train the model.
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # The op for initializing the variables.
    init_op = tf.initialize_all_variables()

    # Create a session for running operations in the Graph.
    sess = tf.Session()

    # Everything that uses the session lives inside this try so that the
    # session is released even if initialization or a step fails.
    try:
      # Initialize the variables (the trained variables and the
      # epoch counter).
      sess.run(init_op)

      # Start input enqueue threads.
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess=sess, coord=coord)

      step = 0
      try:
        while not coord.should_stop():
          start_time = time.time()

          # Run one step of the model.  The value returned for
          # `train_op` is discarded; the value of `loss` is kept for
          # reporting.  Any other op or tensor whose value should be
          # inspected can be added to the list passed to sess.run().
          _, loss_value = sess.run([train_op, loss])

          duration = time.time() - start_time

          # Print an overview fairly often.
          if step % 100 == 0:
            print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
                                                       duration))
          step += 1
      except tf.errors.OutOfRangeError:
        # Raised once the input queues are exhausted; this is the normal
        # way for training to finish.
        print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs,
                                                          step))
      finally:
        # When done, ask the threads to stop.
        coord.request_stop()

      # Wait for threads to finish.  Only done on the non-error path, as
      # before, so an unexpected exception is not delayed or masked.
      coord.join(threads)
    finally:
      sess.close()
192
193
def main(_):
  """Runs training; the single positional argument is ignored."""
  run_training()
196
197
# Only start the app when executed as a script, not when imported.
if __name__ == '__main__':
  tf.app.run()
200