1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Train and Eval the MNIST network.
16
17This version is like fully_connected_feed.py but uses data converted
18to a TFRecords file containing tf.train.Example protocol buffers.
19See:
20https://www.tensorflow.org/programmers_guide/reading_data#reading_from_files
21for context.
22
23YOU MUST run convert_to_records before running this (but you only need to
24run it once).
25"""
26from __future__ import absolute_import
27from __future__ import division
28from __future__ import print_function
29
30import argparse
31import os.path
32import sys
33import time
34
35import tensorflow as tf
36
37from tensorflow.examples.tutorials.mnist import mnist
38
# Parsed command-line options; assigned in the __main__ block at the bottom
# of this file and read by inputs() and run_training().
FLAGS = None

# TFRecords file names inside --train_dir. These must stay in sync with the
# names written by convert_to_records.
TRAIN_FILE, VALIDATION_FILE = 'train.tfrecords', 'validation.tfrecords'
45
46
def decode(serialized_example):
  """Parses one serialized tf.train.Example into an (image, label) pair.

  Args:
    serialized_example: A scalar string tensor holding one serialized
      tf.train.Example with an 'image_raw' bytes feature and an int64
      'label' feature.

  Returns:
    A tuple (image, label) where image is a uint8 tensor of static shape
    [mnist.IMAGE_PIXELS] and label is an int32 scalar tensor.
  """
  features = tf.parse_single_example(
      serialized_example,
      # Defaults are not specified since both keys are required.
      features={
          'image_raw': tf.FixedLenFeature([], tf.string),
          'label': tf.FixedLenFeature([], tf.int64),
      })

  # Convert from a scalar string tensor (whose single string has
  # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape
  # [mnist.IMAGE_PIXELS].
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  # Use an explicit list so the rank-1 shape is unambiguous; writing
  # `(mnist.IMAGE_PIXELS)` is only a parenthesised int, not a tuple.
  image.set_shape([mnist.IMAGE_PIXELS])

  # Convert label from a scalar int64 tensor (as parsed above) to an int32
  # scalar.
  label = tf.cast(features['label'], tf.int32)

  return image, label
66
67
def augment(image, label):
  """Data-augmentation hook; currently returns its inputs unchanged.

  OPTIONAL: the flat image could be reshaped to 28x28 and distorted here.
  No distortions are applied in this example, and the next pipeline step
  expects a flattened vector, so the pair is passed through as-is.
  """
  augmented = (image, label)
  return augmented
74
75
def normalize(image, label):
  """Rescales raw pixel values from [0, 255] to floats in [-0.5, 0.5].

  The image is cast to tf.float32 before scaling; the label is returned
  untouched.
  """
  scale = 1. / 255
  pixels = tf.cast(image, tf.float32)
  return pixels * scale - 0.5, label
81
82
def inputs(train, batch_size, num_epochs):
  """Reads input data num_epochs times.

  Args:
    train: Selects between the training (True) and validation (False) data.
    batch_size: Number of examples per returned batch.
    num_epochs: Number of times to read the input data, or 0/None to
       train forever.

  Returns:
    A tuple (images, labels), where:
    * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]
      in the range [-0.5, 0.5].
    * labels is an int32 tensor with shape [batch_size] with the true label,
      a number in the range [0, mnist.NUM_CLASSES).

    When num_epochs is finite, the last batch may hold fewer than batch_size
    examples if the total number of examples read is not a multiple of
    batch_size.

    This function creates a one_shot_iterator, meaning that it will only iterate
    over the dataset once. On the other hand there is no special initialization
    required.
  """
  if not num_epochs:
    # Dataset.repeat(None) repeats the input indefinitely.
    num_epochs = None
  data_file = TRAIN_FILE if train else VALIDATION_FILE
  filename = os.path.join(FLAGS.train_dir, data_file)

  with tf.name_scope('input'):
    # TFRecordDataset reads serialized records (here tf.train.Example protos)
    # from the file one record at a time; it also accepts a list of filenames.
    dataset = tf.data.TFRecordDataset(filename)
    dataset = dataset.repeat(num_epochs)

    # map() traces each Python function into the graph once and applies the
    # resulting ops to every element of the dataset.
    dataset = dataset.map(decode)
    dataset = dataset.map(augment)
    dataset = dataset.map(normalize)

    # The argument is the shuffle buffer size: each element is drawn at random
    # from a buffer of this many upcoming examples. Because shuffling happens
    # after repeat(), examples from consecutive epochs can be interleaved.
    dataset = dataset.shuffle(1000 + 3 * batch_size)
    dataset = dataset.batch(batch_size)

    iterator = dataset.make_one_shot_iterator()
  return iterator.get_next()
125
126
def run_training():
  """Builds the MNIST training graph and trains until the input runs out.

  Training stops when the input iterator raises tf.errors.OutOfRangeError,
  which happens after FLAGS.num_epochs passes over the training data. If
  num_epochs is 0, inputs() repeats the data indefinitely.
  """
  # All ops below are created in a fresh graph that is the default here.
  with tf.Graph().as_default():
    images, labels = inputs(
        train=True, batch_size=FLAGS.batch_size, num_epochs=FLAGS.num_epochs)

    # Forward pass, loss, and the optimizer step, in that order.
    logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2)
    loss = mnist.loss(logits, labels)
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # A single op that initializes both global and local variables.
    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())

    with tf.Session() as sess:
      sess.run(init)
      step = 0
      try:
        while True:
          t0 = time.time()
          # Fetch train_op only for its side effect; keep the loss value.
          _, loss_value = sess.run([train_op, loss])
          elapsed = time.time() - t0

          # Report progress every 100 steps.
          if step % 100 == 0:
            print('Step %d: loss = %.2f (%.3f sec)' %
                  (step, loss_value, elapsed))
          step += 1
      except tf.errors.OutOfRangeError:
        # The one-shot input iterator is exhausted.
        print('Done training for %d epochs, %d steps.' %
              (FLAGS.num_epochs, step))
177
178
def main(_):
  """Entry point invoked by tf.app.run; the argv it receives is ignored."""
  del _  # Options were already parsed into FLAGS by argparse.
  run_training()
181
182
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  # (flag, type, default, help) for each command-line option, in the order
  # they are registered with the parser.
  flag_specs = [
      ('--learning_rate', float, 0.01, 'Initial learning rate.'),
      ('--num_epochs', int, 2, 'Number of epochs to run trainer.'),
      ('--hidden1', int, 128, 'Number of units in hidden layer 1.'),
      ('--hidden2', int, 32, 'Number of units in hidden layer 2.'),
      ('--batch_size', int, 100, 'Batch size.'),
      ('--train_dir', str, '/tmp/data', 'Directory with the training data.'),
  ]
  for flag, flag_type, flag_default, flag_help in flag_specs:
    parser.add_argument(
        flag, type=flag_type, default=flag_default, help=flag_help)
  # Options argparse does not recognize are forwarded to tf.app.run.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
213