1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Converts MNIST data to TFRecords file format with Example protos.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import argparse 22import os 23import sys 24 25import tensorflow as tf 26 27from tensorflow.contrib.learn.python.learn.datasets import mnist 28 29FLAGS = None 30 31 32def _int64_feature(value): 33 return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 34 35 36def _bytes_feature(value): 37 return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 38 39 40def convert_to(data_set, name): 41 """Converts a dataset to tfrecords.""" 42 images = data_set.images 43 labels = data_set.labels 44 num_examples = data_set.num_examples 45 46 if images.shape[0] != num_examples: 47 raise ValueError('Images size %d does not match label size %d.' % 48 (images.shape[0], num_examples)) 49 rows = images.shape[1] 50 cols = images.shape[2] 51 depth = images.shape[3] 52 53 filename = os.path.join(FLAGS.directory, name + '.tfrecords') 54 print('Writing', filename) 55 with tf.python_io.TFRecordWriter(filename) as writer: 56 for index in range(num_examples): 57 image_raw = images[index].tostring() 58 example = tf.train.Example( 59 features=tf.train.Features( 60 feature={ 61 'height': _int64_feature(rows), 62 'width': _int64_feature(cols), 63 'depth': _int64_feature(depth), 64 'label': _int64_feature(int(labels[index])), 65 'image_raw': _bytes_feature(image_raw) 66 })) 67 writer.write(example.SerializeToString()) 68 69 70def main(unused_argv): 71 # Get the data. 72 data_sets = mnist.read_data_sets(FLAGS.directory, 73 dtype=tf.uint8, 74 reshape=False, 75 validation_size=FLAGS.validation_size) 76 77 # Convert to Examples and write the result to TFRecords. 78 convert_to(data_sets.train, 'train') 79 convert_to(data_sets.validation, 'validation') 80 convert_to(data_sets.test, 'test') 81 82 83if __name__ == '__main__': 84 parser = argparse.ArgumentParser() 85 parser.add_argument( 86 '--directory', 87 type=str, 88 default='/tmp/data', 89 help='Directory to download data files and write the converted result' 90 ) 91 parser.add_argument( 92 '--validation_size', 93 type=int, 94 default=5000, 95 help="""\ 96 Number of examples to separate from the training data for the validation 97 set.\ 98 """ 99 ) 100 FLAGS, unparsed = parser.parse_known_args() 101 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 102