# fully_connected_reader.py revision e4d6e8a3ff414a2688d93d0d2e1b234165a2703e
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Train and Eval the MNIST network.

This version is like fully_connected_feed.py but uses data converted
to a TFRecords file containing tf.train.Example protocol buffers.
See:
https://www.tensorflow.org/programmers_guide/reading_data#reading_from_files
for context.

YOU MUST run convert_to_records before running this (but you only need to
run it once).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import sys
import time

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import mnist

# Basic model parameters as external flags. Populated by the argparse
# section at the bottom of the file before any function here is called.
FLAGS = None

# Constants used for dealing with the files, matches convert_to_records.
TRAIN_FILE = 'train.tfrecords'
VALIDATION_FILE = 'validation.tfrecords'


def read_and_decode(filename_queue):
  """Reads one serialized example from the queue and decodes it.

  Args:
    filename_queue: A queue of TFRecords filenames, e.g. the one returned by
      tf.train.string_input_producer().

  Returns:
    A tuple (image, label), where:
    * image is a float32 tensor with shape [mnist.IMAGE_PIXELS] in the
      range [-0.5, 0.5].
    * label is an int32 scalar tensor.
  """
  reader = tf.TFRecordReader()
  _, serialized_example = reader.read(filename_queue)
  features = tf.parse_single_example(
      serialized_example,
      # Defaults are not specified since both keys are required.
      features={
          'image_raw': tf.FixedLenFeature([], tf.string),
          'label': tf.FixedLenFeature([], tf.int64),
      })

  # Convert from a scalar string tensor (whose single string has
  # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape
  # [mnist.IMAGE_PIXELS].
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  image.set_shape([mnist.IMAGE_PIXELS])

  # OPTIONAL: Could reshape into a 28x28 image and apply distortions
  # here.  Since we are not applying any distortions in this
  # example, and the next step expects the image to be flattened
  # into a vector, we don't bother.

  # Convert from [0, 255] -> [-0.5, 0.5] floats.
  image = tf.cast(image, tf.float32) * (1. / 255) - 0.5

  # Convert label from a scalar int64 tensor (the dtype it is parsed as
  # above) to an int32 scalar.
  label = tf.cast(features['label'], tf.int32)

  return image, label


def inputs(train, batch_size, num_epochs):
  """Reads input data num_epochs times.

  Args:
    train: Selects between the training (True) and validation (False) data.
    batch_size: Number of examples per returned batch.
    num_epochs: Number of times to read the input data, or 0/None to
       train forever.

  Returns:
    A tuple (images, labels), where:
    * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]
      in the range [-0.5, 0.5].
    * labels is an int32 tensor with shape [batch_size] with the true label,
      a number in the range [0, mnist.NUM_CLASSES).
    Note that a tf.train.QueueRunner is added to the graph, which
    must be run using e.g. tf.train.start_queue_runners().
  """
  # Normalize 0 to None: string_input_producer treats None as "no limit".
  if not num_epochs:
    num_epochs = None
  filename = os.path.join(FLAGS.train_dir,
                          TRAIN_FILE if train else VALIDATION_FILE)

  with tf.name_scope('input'):
    filename_queue = tf.train.string_input_producer(
        [filename], num_epochs=num_epochs)

    # Even when reading in multiple threads, share the filename
    # queue.
    image, label = read_and_decode(filename_queue)

    # Shuffle the examples and collect them into batch_size batches.
    # (Internally uses a RandomShuffleQueue.)
    # We run this in two threads to avoid being a bottleneck.
    images, sparse_labels = tf.train.shuffle_batch(
        [image, label], batch_size=batch_size, num_threads=2,
        capacity=1000 + 3 * batch_size,
        # Ensures a minimum amount of shuffling of examples.
        min_after_dequeue=1000)

    return images, sparse_labels


def run_training():
  """Train MNIST for a number of steps.

  Builds the input pipeline, model, loss and training op in a fresh graph,
  then runs training steps until the input queue is exhausted (signalled by
  tf.errors.OutOfRangeError) or the coordinator requests a stop. All
  hyperparameters are read from the module-level FLAGS.
  """

  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    # Input images and labels.
    images, labels = inputs(train=True, batch_size=FLAGS.batch_size,
                            num_epochs=FLAGS.num_epochs)

    # Build a Graph that computes predictions from the inference model.
    logits = mnist.inference(images,
                             FLAGS.hidden1,
                             FLAGS.hidden2)

    # Add to the Graph the loss calculation.
    loss = mnist.loss(logits, labels)

    # Add to the Graph operations that train the model.
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # The op for initializing the variables.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # Create a session for running operations in the Graph. Using it as a
    # context manager guarantees the session is closed even when
    # initialization or a training step raises something other than
    # OutOfRangeError.
    with tf.Session() as sess:
      # Initialize the variables (the trained variables and the
      # epoch counter).
      sess.run(init_op)

      # Start input enqueue threads.
      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess=sess, coord=coord)

      try:
        step = 0
        while not coord.should_stop():
          start_time = time.time()

          # Run one step of the model.  The return values are
          # the activations from the `train_op` (which is
          # discarded) and the `loss` op.  To inspect the values
          # of your ops or variables, you may include them in
          # the list passed to sess.run() and the value tensors
          # will be returned in the tuple from the call.
          _, loss_value = sess.run([train_op, loss])

          duration = time.time() - start_time

          # Print an overview fairly often.
          if step % 100 == 0:
            print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
                                                       duration))
          step += 1
      except tf.errors.OutOfRangeError:
        print('Done training for %d epochs, %d steps.' %
              (FLAGS.num_epochs, step))
      finally:
        # When done, ask the threads to stop.
        coord.request_stop()

      # Wait for threads to finish.
      coord.join(threads)


def main(_):
  """Entry point for tf.app.run(); the argv argument is unused."""
  run_training()


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--learning_rate',
      type=float,
      default=0.01,
      help='Initial learning rate.'
  )
  parser.add_argument(
      '--num_epochs',
      type=int,
      default=2,
      help='Number of epochs to run trainer.'
  )
  parser.add_argument(
      '--hidden1',
      type=int,
      default=128,
      help='Number of units in hidden layer 1.'
  )
  parser.add_argument(
      '--hidden2',
      type=int,
      default=32,
      help='Number of units in hidden layer 2.'
  )
  parser.add_argument(
      '--batch_size',
      type=int,
      default=100,
      help='Batch size.'
  )
  parser.add_argument(
      '--train_dir',
      type=str,
      default='/tmp/data',
      help='Directory with the training data.'
  )
  # Flags argparse does not recognize are forwarded to tf.app.run().
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)