# fully_connected_reader.py revision 492e55ff2297dc13f5ed540056ad40603f2b1dd0
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Train and Eval the MNIST network.

This version is like fully_connected_feed.py but reads its input from a
TFRecords file of tf.train.Example protocol buffers instead of feeding
numpy arrays. See tensorflow/g3doc/how_tos/reading_data.md#reading-from-files
for context.

YOU MUST run convert_to_records before running this (it only needs to be
run once).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
import time

import tensorflow.python.platform
import numpy
import tensorflow as tf

from tensorflow.examples.tutorials.mnist import mnist


# Command-line configurable model parameters. The definition order below is
# also the order in which the flags are registered.
FLAGS = tf.app.flags.FLAGS
flags = tf.app.flags

# -- Optimization.
flags.DEFINE_float(
    'learning_rate', 0.01,
    'Initial learning rate.')
flags.DEFINE_integer(
    'num_epochs', 2,
    'Number of epochs to run trainer.')

# -- Network architecture.
flags.DEFINE_integer(
    'hidden1', 128,
    'Number of units in hidden layer 1.')
flags.DEFINE_integer(
    'hidden2', 32,
    'Number of units in hidden layer 2.')

# -- Input pipeline.
flags.DEFINE_integer(
    'batch_size', 100,
    'Batch size.')
flags.DEFINE_string(
    'train_dir', '/tmp/data',
    'Directory with the training data.')

# Constants used for dealing with the files, matches convert_to_records.
TRAIN_FILE = 'train.tfrecords'
VALIDATION_FILE = 'validation.tfrecords'


def read_and_decode(filename_queue):
  """Reads and decodes one MNIST example from a queue of TFRecords files.

  Args:
    filename_queue: Queue of TFRecords filenames to read from, e.g. as
      returned by tf.train.string_input_producer().

  Returns:
    A tuple (image, label), where:
    * image is a float32 tensor with shape [mnist.IMAGE_PIXELS], rescaled
      from raw bytes in [0, 255] to the range [-0.5, 0.5].
    * label is an int32 scalar tensor.
  """
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)
  features = tf.parse_single_example(
      serialized_example,
      # Defaults are not specified since both keys are required.
      features={
          'image_raw': tf.FixedLenFeature([], tf.string),
          'label': tf.FixedLenFeature([], tf.int64),
      })

  # Convert from a scalar string tensor (whose single string has
  # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape
  # [mnist.IMAGE_PIXELS].
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  image.set_shape([mnist.IMAGE_PIXELS])

  # OPTIONAL: Could reshape into a 28x28 image and apply distortions
  # here. Since we are not applying any distortions in this
  # example, and the next step expects the image to be flattened
  # into a vector, we don't bother.

  # Convert from [0, 255] -> [-0.5, 0.5] floats.
  image = tf.cast(image, tf.float32) * (1. / 255) - 0.5

  # Convert label from a scalar int64 tensor (as parsed above) to an
  # int32 scalar.
  label = tf.cast(features['label'], tf.int32)

  return image, label


def inputs(train, batch_size, num_epochs):
  """Reads input data num_epochs times.

  Args:
    train: Selects between the training (True) and validation (False) data.
    batch_size: Number of examples per returned batch.
    num_epochs: Number of times to read the input data, or 0/None to
       train forever.

  Returns:
    A tuple (images, labels), where:
    * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]
      in the range [-0.5, 0.5].
    * labels is an int32 tensor with shape [batch_size] with the true label,
      a number in the range [0, mnist.NUM_CLASSES).
    Note that a tf.train.QueueRunner is added to the graph, which
    must be run using e.g. tf.train.start_queue_runners().
  """
  # string_input_producer treats None as "no epoch limit"; map 0 to None too.
  if not num_epochs:
    num_epochs = None
  filename = os.path.join(FLAGS.train_dir,
                          TRAIN_FILE if train else VALIDATION_FILE)

  with tf.name_scope('input'):
    filename_queue = tf.train.string_input_producer(
        [filename], num_epochs=num_epochs)

    # Even when reading in multiple threads, share the filename
    # queue.
    image, label = read_and_decode(filename_queue)

    # Shuffle the examples and collect them into batch_size batches.
    # (Internally uses a RandomShuffleQueue.)
    # We run this in two threads to avoid being a bottleneck.
    images, sparse_labels = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, num_threads=2,
        capacity=1000 + 3 * batch_size,
        # Ensures a minimum amount of shuffling of examples.
        min_after_dequeue=1000)

    return images, sparse_labels


def run_training():
  """Train MNIST for FLAGS.num_epochs epochs over the training TFRecords.

  Training stops when the input pipeline raises tf.errors.OutOfRangeError
  after the requested number of epochs. The session is always closed, even
  if an unexpected error escapes.
  """

  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    # Input images and labels.
    images, labels = inputs(train=True, batch_size=FLAGS.batch_size,
                            num_epochs=FLAGS.num_epochs)

    # Build a Graph that computes predictions from the inference model.
    logits = mnist.inference(images,
                             FLAGS.hidden1,
                             FLAGS.hidden2)

    # Add to the Graph the loss calculation.
    loss = mnist.loss(logits, labels)

    # Add to the Graph operations that train the model.
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # The op for initializing the variables (the trained variables and the
    # epoch counter). Newer TensorFlow releases create the num_epochs
    # counter of string_input_producer as a *local* variable, which
    # initialize_all_variables() does not cover, so also initialize local
    # variables when this TensorFlow version supports them.
    init_ops = [tf.initialize_all_variables()]
    if hasattr(tf, 'initialize_local_variables'):
      init_ops.append(tf.initialize_local_variables())
    init_op = tf.group(*init_ops)

    # Create a session for running operations in the Graph.
    sess = tf.Session()
    try:
      # Initialize the variables (the trained variables and the
      # epoch counter).
      sess.run(init_op)

      # Start input enqueue threads.
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess=sess, coord=coord)

      # Defined before the try so the OutOfRangeError handler can report it.
      step = 0
      try:
        while not coord.should_stop():
          start_time = time.time()

          # Run one step of the model. The return values are
          # the activations from the `train_op` (which is
          # discarded) and the `loss` op. To inspect the values
          # of your ops or variables, you may include them in
          # the list passed to sess.run() and the value tensors
          # will be returned in the tuple from the call.
          _, loss_value = sess.run([train_op, loss])

          duration = time.time() - start_time

          # Print an overview fairly often.
          if step % 100 == 0:
            print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
                                                       duration))
          step += 1
      except tf.errors.OutOfRangeError:
        # Raised by the input queue once num_epochs have been consumed.
        print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs,
                                                          step))
      finally:
        # When done, ask the threads to stop.
        coord.request_stop()

      # Wait for threads to finish.
      coord.join(threads)
    finally:
      # Release the session's resources on every exit path.
      sess.close()


def main(_):
  run_training()


if __name__ == '__main__':
  tf.app.run()