1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Train and Eval the MNIST network. 16 17This version is like fully_connected_feed.py but uses data converted 18to a TFRecords file containing tf.train.Example protocol buffers. 19See: 20https://www.tensorflow.org/programmers_guide/reading_data#reading_from_files 21for context. 22 23YOU MUST run convert_to_records before running this (but you only need to 24run it once). 25""" 26from __future__ import absolute_import 27from __future__ import division 28from __future__ import print_function 29 30import argparse 31import os.path 32import sys 33import time 34 35import tensorflow as tf 36 37from tensorflow.examples.tutorials.mnist import mnist 38 39# Basic model parameters as external flags. 40FLAGS = None 41 42# Constants used for dealing with the files, matches convert_to_records. 43TRAIN_FILE = 'train.tfrecords' 44VALIDATION_FILE = 'validation.tfrecords' 45 46 47def decode(serialized_example): 48 features = tf.parse_single_example( 49 serialized_example, 50 # Defaults are not specified since both keys are required. 51 features={ 52 'image_raw': tf.FixedLenFeature([], tf.string), 53 'label': tf.FixedLenFeature([], tf.int64), 54 }) 55 56 # Convert from a scalar string tensor (whose single string has 57 # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape 58 # [mnist.IMAGE_PIXELS]. 59 image = tf.decode_raw(features['image_raw'], tf.uint8) 60 image.set_shape((mnist.IMAGE_PIXELS)) 61 62 # Convert label from a scalar uint8 tensor to an int32 scalar. 63 label = tf.cast(features['label'], tf.int32) 64 65 return image, label 66 67 68def augment(image, label): 69 # OPTIONAL: Could reshape into a 28x28 image and apply distortions 70 # here. Since we are not applying any distortions in this 71 # example, and the next step expects the image to be flattened 72 # into a vector, we don't bother. 73 return image, label 74 75 76def normalize(image, label): 77 # Convert from [0, 255] -> [-0.5, 0.5] floats. 78 image = tf.cast(image, tf.float32) * (1. / 255) - 0.5 79 80 return image, label 81 82 83def inputs(train, batch_size, num_epochs): 84 """Reads input data num_epochs times. 85 86 Args: 87 train: Selects between the training (True) and validation (False) data. 88 batch_size: Number of examples per returned batch. 89 num_epochs: Number of times to read the input data, or 0/None to 90 train forever. 91 92 Returns: 93 A tuple (images, labels), where: 94 * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS] 95 in the range [-0.5, 0.5]. 96 * labels is an int32 tensor with shape [batch_size] with the true label, 97 a number in the range [0, mnist.NUM_CLASSES). 98 99 This function creates a one_shot_iterator, meaning that it will only iterate 100 over the dataset once. On the other hand there is no special initialization 101 required. 102 """ 103 if not num_epochs: 104 num_epochs = None 105 filename = os.path.join(FLAGS.train_dir, TRAIN_FILE 106 if train else VALIDATION_FILE) 107 108 with tf.name_scope('input'): 109 # TFRecordDataset opens a protobuf and reads entries line by line 110 # could also be [list, of, filenames] 111 dataset = tf.data.TFRecordDataset(filename) 112 dataset = dataset.repeat(num_epochs) 113 114 # map takes a python function and applies it to every sample 115 dataset = dataset.map(decode) 116 dataset = dataset.map(augment) 117 dataset = dataset.map(normalize) 118 119 #the parameter is the queue size 120 dataset = dataset.shuffle(1000 + 3 * batch_size) 121 dataset = dataset.batch(batch_size) 122 123 iterator = dataset.make_one_shot_iterator() 124 return iterator.get_next() 125 126 127def run_training(): 128 """Train MNIST for a number of steps.""" 129 130 # Tell TensorFlow that the model will be built into the default Graph. 131 with tf.Graph().as_default(): 132 # Input images and labels. 133 image_batch, label_batch = inputs( 134 train=True, batch_size=FLAGS.batch_size, num_epochs=FLAGS.num_epochs) 135 136 # Build a Graph that computes predictions from the inference model. 137 logits = mnist.inference(image_batch, FLAGS.hidden1, FLAGS.hidden2) 138 139 # Add to the Graph the loss calculation. 140 loss = mnist.loss(logits, label_batch) 141 142 # Add to the Graph operations that train the model. 143 train_op = mnist.training(loss, FLAGS.learning_rate) 144 145 # The op for initializing the variables. 146 init_op = tf.group(tf.global_variables_initializer(), 147 tf.local_variables_initializer()) 148 149 # Create a session for running operations in the Graph. 150 with tf.Session() as sess: 151 # Initialize the variables (the trained variables and the 152 # epoch counter). 153 sess.run(init_op) 154 try: 155 step = 0 156 while True: #train until OutOfRangeError 157 start_time = time.time() 158 159 # Run one step of the model. The return values are 160 # the activations from the `train_op` (which is 161 # discarded) and the `loss` op. To inspect the values 162 # of your ops or variables, you may include them in 163 # the list passed to sess.run() and the value tensors 164 # will be returned in the tuple from the call. 165 _, loss_value = sess.run([train_op, loss]) 166 167 duration = time.time() - start_time 168 169 # Print an overview fairly often. 170 if step % 100 == 0: 171 print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, 172 duration)) 173 step += 1 174 except tf.errors.OutOfRangeError: 175 print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, 176 step)) 177 178 179def main(_): 180 run_training() 181 182 183if __name__ == '__main__': 184 parser = argparse.ArgumentParser() 185 parser.add_argument( 186 '--learning_rate', 187 type=float, 188 default=0.01, 189 help='Initial learning rate.') 190 parser.add_argument( 191 '--num_epochs', 192 type=int, 193 default=2, 194 help='Number of epochs to run trainer.') 195 parser.add_argument( 196 '--hidden1', 197 type=int, 198 default=128, 199 help='Number of units in hidden layer 1.') 200 parser.add_argument( 201 '--hidden2', 202 type=int, 203 default=32, 204 help='Number of units in hidden layer 2.') 205 parser.add_argument('--batch_size', type=int, default=100, help='Batch size.') 206 parser.add_argument( 207 '--train_dir', 208 type=str, 209 default='/tmp/data', 210 help='Directory with the training data.') 211 FLAGS, unparsed = parser.parse_known_args() 212 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 213