1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Converts MNIST data to TFRecords file format with Example protos."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import argparse
22import os
23import sys
24
25import tensorflow as tf
26
27from tensorflow.contrib.learn.python.learn.datasets import mnist
28
# Parsed command-line flags (argparse Namespace); assigned by
# parser.parse_known_args() in the __main__ block and read by convert_to/main.
FLAGS = None
30
31
def _int64_feature(value):
  """Wraps one integer in a `tf.train.Feature` backed by an Int64List.

  Args:
    value: the single integer to store.

  Returns:
    A `tf.train.Feature` whose `int64_list` holds exactly `[value]`.
  """
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)
34
35
def _bytes_feature(value):
  """Wraps one bytes value in a `tf.train.Feature` backed by a BytesList.

  Args:
    value: the single bytes object to store.

  Returns:
    A `tf.train.Feature` whose `bytes_list` holds exactly `[value]`.
  """
  bytes_list = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=bytes_list)
38
39
def convert_to(data_set, name, directory=None):
  """Converts a dataset to a TFRecords file of `tf.train.Example` protos.

  Writes `<directory>/<name>.tfrecords` containing one Example per image. Each
  Example has the int64 features 'height', 'width', 'depth' and 'label', and
  the bytes feature 'image_raw' (the image array's raw buffer).

  Args:
    data_set: object exposing `images`, `labels` and `num_examples`, such as
      one split returned by `mnist.read_data_sets`. `images` must be indexable
      as (num_examples, rows, cols, depth). Each label must be convertible
      with `int()`.
    name: basename of the output file; '.tfrecords' is appended.
    directory: output directory. When None (the default), `FLAGS.directory`
      is used.

  Raises:
    ValueError: if the number of images or the number of labels differs from
      `data_set.num_examples`.
  """
  images = data_set.images
  labels = data_set.labels
  num_examples = data_set.num_examples

  if images.shape[0] != num_examples:
    raise ValueError('Images size %d does not match num_examples %d.' %
                     (images.shape[0], num_examples))
  # Validate labels as well, so a mismatch fails before any record is written
  # rather than partway through the output file.
  if len(labels) != num_examples:
    raise ValueError('Labels size %d does not match num_examples %d.' %
                     (len(labels), num_examples))

  rows = images.shape[1]
  cols = images.shape[2]
  depth = images.shape[3]

  if directory is None:
    directory = FLAGS.directory
  filename = os.path.join(directory, name + '.tfrecords')
  print('Writing', filename)

  # The shape features are identical for every example, so build them once.
  # Proto construction copies map values, so sharing these is safe.
  height_feature = _int64_feature(rows)
  width_feature = _int64_feature(cols)
  depth_feature = _int64_feature(depth)

  with tf.python_io.TFRecordWriter(filename) as writer:
    # Lengths were validated above, so zip covers exactly num_examples items.
    for image, label in zip(images, labels):
      example = tf.train.Example(
          features=tf.train.Features(
              feature={
                  'height': height_feature,
                  'width': width_feature,
                  'depth': depth_feature,
                  'label': _int64_feature(int(label)),
                  # tobytes() replaces ndarray.tostring(), which NumPy
                  # deprecated and later removed.
                  'image_raw': _bytes_feature(image.tobytes()),
              }))
      writer.write(example.SerializeToString())
68
69
def main(unused_argv):
  """Loads MNIST and writes one TFRecords file per split.

  Splits are read with `mnist.read_data_sets` as uint8, non-flattened images,
  holding out `FLAGS.validation_size` training examples for validation. Each
  split is then serialized by `convert_to` under the name of the split.

  Args:
    unused_argv: argument list passed in by `tf.app.run`; ignored.
  """
  mnist_splits = mnist.read_data_sets(
      FLAGS.directory,
      dtype=tf.uint8,
      reshape=False,
      validation_size=FLAGS.validation_size)

  # Write train, then validation, then test: <split>.tfrecords for each.
  for split_name in ('train', 'validation', 'test'):
    convert_to(getattr(mnist_splits, split_name), split_name)
81
82
if __name__ == '__main__':
  cli = argparse.ArgumentParser()
  cli.add_argument(
      '--directory',
      type=str,
      default='/tmp/data',
      help=('Directory to download data files and write the converted '
            'result'))
  # The embedded newline/indentation in this help text is cosmetic:
  # argparse's default HelpFormatter collapses whitespace when printing help.
  cli.add_argument(
      '--validation_size',
      type=int,
      default=5000,
      help=('      Number of examples to separate from the training data '
            'for the validation\n      set.      '))
  # Arguments this parser does not recognize are forwarded to tf.app.run.
  FLAGS, unparsed = cli.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
102