# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for Reader ops from io_ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import gzip
23import os
24import shutil
25import threading
26import zlib
27
28import six
29
30from tensorflow.core.protobuf import config_pb2
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import errors_impl
33from tensorflow.python.lib.io import tf_record
34from tensorflow.python.ops import data_flow_ops
35from tensorflow.python.ops import io_ops
36from tensorflow.python.ops import variables
37from tensorflow.python.platform import test
38from tensorflow.python.training import coordinator
39from tensorflow.python.training import input as input_lib
40from tensorflow.python.training import queue_runner_impl
41from tensorflow.python.util import compat
42
43prefix_path = "tensorflow/core/lib"
44
45# pylint: disable=invalid-name
46TFRecordCompressionType = tf_record.TFRecordCompressionType
47# pylint: enable=invalid-name
48
# Edgar Allan Poe's 'Eldorado'
_TEXT = b"""Gaily bedight,
    A gallant knight,
    In sunshine and in shadow,
    Had journeyed long,
    Singing a song,
    In search of Eldorado.

    But he grew old
    This knight so bold
    And o'er his heart a shadow
    Fell as he found
    No spot of ground
    That looked like Eldorado.

   And, as his strength
   Failed him at length,
   He met a pilgrim shadow
   'Shadow,' said he,
   'Where can it be
   This land of Eldorado?'

   'Over the Mountains
    Of the Moon'
    Down the Valley of the Shadow,
    Ride, boldly ride,'
    The shade replied,
    'If you seek for Eldorado!'
    """


class IdentityReaderTest(test.TestCase):

  def _ExpectRead(self, sess, key, value, expected):
    k, v = sess.run([key, value])
    self.assertAllEqual(expected, k)
    self.assertAllEqual(expected, v)

  def testOneEpoch(self):
    with self.test_session() as sess:
      reader = io_ops.IdentityReader("test_reader")
      work_completed = reader.num_work_units_completed()
      produced = reader.num_records_produced()
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      queued_length = queue.size()
      key, value = reader.read(queue)

      self.assertAllEqual(0, work_completed.eval())
      self.assertAllEqual(0, produced.eval())
      self.assertAllEqual(0, queued_length.eval())

      queue.enqueue_many([["A", "B", "C"]]).run()
      queue.close().run()
      self.assertAllEqual(3, queued_length.eval())

      self._ExpectRead(sess, key, value, b"A")
      self.assertAllEqual(1, produced.eval())

      self._ExpectRead(sess, key, value, b"B")

      self._ExpectRead(sess, key, value, b"C")
      self.assertAllEqual(3, produced.eval())
      self.assertAllEqual(0, queued_length.eval())

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])

      self.assertAllEqual(3, work_completed.eval())
      self.assertAllEqual(3, produced.eval())
      self.assertAllEqual(0, queued_length.eval())

  def testMultipleEpochs(self):
    with self.test_session() as sess:
      reader = io_ops.IdentityReader("test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      enqueue = queue.enqueue_many([["DD", "EE"]])
      key, value = reader.read(queue)

      enqueue.run()
      self._ExpectRead(sess, key, value, b"DD")
      self._ExpectRead(sess, key, value, b"EE")
      enqueue.run()
      self._ExpectRead(sess, key, value, b"DD")
      self._ExpectRead(sess, key, value, b"EE")
      enqueue.run()
      self._ExpectRead(sess, key, value, b"DD")
      self._ExpectRead(sess, key, value, b"EE")
      queue.close().run()
      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])

  def testSerializeRestore(self):
    with self.test_session() as sess:
      reader = io_ops.IdentityReader("test_reader")
      produced = reader.num_records_produced()
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      queue.enqueue_many([["X", "Y", "Z"]]).run()
      key, value = reader.read(queue)

      self._ExpectRead(sess, key, value, b"X")
      self.assertAllEqual(1, produced.eval())
      state = reader.serialize_state().eval()

      self._ExpectRead(sess, key, value, b"Y")
      self._ExpectRead(sess, key, value, b"Z")
      self.assertAllEqual(3, produced.eval())

      queue.enqueue_many([["Y", "Z"]]).run()
      queue.close().run()
      reader.restore_state(state).run()
      self.assertAllEqual(1, produced.eval())
      self._ExpectRead(sess, key, value, b"Y")
      self._ExpectRead(sess, key, value, b"Z")
      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])
      self.assertAllEqual(3, produced.eval())

      self.assertEqual(bytes, type(state))

      with self.assertRaises(ValueError):
        reader.restore_state([])

      with self.assertRaises(ValueError):
        reader.restore_state([state, state])

      with self.assertRaisesOpError(
          "Could not parse state for IdentityReader 'test_reader'"):
        reader.restore_state(state[1:]).run()

      with self.assertRaisesOpError(
          "Could not parse state for IdentityReader 'test_reader'"):
        reader.restore_state(state[:-1]).run()

      with self.assertRaisesOpError(
          "Could not parse state for IdentityReader 'test_reader'"):
        reader.restore_state(state + b"ExtraJunk").run()

      with self.assertRaisesOpError(
          "Could not parse state for IdentityReader 'test_reader'"):
        reader.restore_state(b"PREFIX" + state).run()

      with self.assertRaisesOpError(
          "Could not parse state for IdentityReader 'test_reader'"):
        reader.restore_state(b"BOGUS" + state[5:]).run()

  def testReset(self):
    with self.test_session() as sess:
      reader = io_ops.IdentityReader("test_reader")
      work_completed = reader.num_work_units_completed()
      produced = reader.num_records_produced()
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      queued_length = queue.size()
      key, value = reader.read(queue)

      queue.enqueue_many([["X", "Y", "Z"]]).run()
      self._ExpectRead(sess, key, value, b"X")
      self.assertLess(0, queued_length.eval())
      self.assertAllEqual(1, produced.eval())

      self._ExpectRead(sess, key, value, b"Y")
      self.assertLess(0, work_completed.eval())
      self.assertAllEqual(2, produced.eval())

      reader.reset().run()
      self.assertAllEqual(0, work_completed.eval())
      self.assertAllEqual(0, produced.eval())
      self.assertAllEqual(1, queued_length.eval())
      self._ExpectRead(sess, key, value, b"Z")

      queue.enqueue_many([["K", "L"]]).run()
      self._ExpectRead(sess, key, value, b"K")


class WholeFileReaderTest(test.TestCase):

  def setUp(self):
    super(WholeFileReaderTest, self).setUp()
    self._filenames = [
        os.path.join(self.get_temp_dir(), "whole_file.%d.txt" % i)
        for i in range(3)
    ]
    self._content = [b"One\na\nb\n", b"Two\nC\nD", b"Three x, y, z"]
    for fn, c in zip(self._filenames, self._content):
      with open(fn, "wb") as h:
        h.write(c)

  def tearDown(self):
    for fn in self._filenames:
      os.remove(fn)
    super(WholeFileReaderTest, self).tearDown()

  def _ExpectRead(self, sess, key, value, index):
    k, v = sess.run([key, value])
    self.assertAllEqual(compat.as_bytes(self._filenames[index]), k)
    self.assertAllEqual(self._content[index], v)

  def testOneEpoch(self):
    with self.test_session() as sess:
      reader = io_ops.WholeFileReader("test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      queue.enqueue_many([self._filenames]).run()
      queue.close().run()
      key, value = reader.read(queue)

      self._ExpectRead(sess, key, value, 0)
      self._ExpectRead(sess, key, value, 1)
      self._ExpectRead(sess, key, value, 2)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])

  def testInfiniteEpochs(self):
    with self.test_session() as sess:
      reader = io_ops.WholeFileReader("test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      enqueue = queue.enqueue_many([self._filenames])
      key, value = reader.read(queue)

      enqueue.run()
      self._ExpectRead(sess, key, value, 0)
      self._ExpectRead(sess, key, value, 1)
      enqueue.run()
      self._ExpectRead(sess, key, value, 2)
      self._ExpectRead(sess, key, value, 0)
      self._ExpectRead(sess, key, value, 1)
      enqueue.run()
      self._ExpectRead(sess, key, value, 2)
      self._ExpectRead(sess, key, value, 0)


class TextLineReaderTest(test.TestCase):

  def setUp(self):
    super(TextLineReaderTest, self).setUp()
    self._num_files = 2
    self._num_lines = 5

  def _LineText(self, f, l):
    return compat.as_bytes("%d: %d" % (f, l))

  def _CreateFiles(self, crlf=False):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        for j in range(self._num_lines):
          f.write(self._LineText(i, j))
          # Always include a newline after the record unless it is the last
          # record in the file, in which case we only include it for the
          # first file (so both cases are exercised).
          if j + 1 != self._num_lines or i == 0:
            f.write(b"\r\n" if crlf else b"\n")
    return filenames

  def _testOneEpoch(self, files):
    with self.test_session() as sess:
      reader = io_ops.TextLineReader(name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_lines):
          k, v = sess.run([key, value])
          self.assertAllEqual("%s:%d" % (files[i], j + 1), compat.as_text(k))
          self.assertAllEqual(self._LineText(i, j), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def testOneEpochLF(self):
    self._testOneEpoch(self._CreateFiles(crlf=False))

  def testOneEpochCRLF(self):
    self._testOneEpoch(self._CreateFiles(crlf=True))

  def testSkipHeaderLines(self):
    files = self._CreateFiles()
    with self.test_session() as sess:
      reader = io_ops.TextLineReader(skip_header_lines=1, name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_lines - 1):
          k, v = sess.run([key, value])
          self.assertAllEqual("%s:%d" % (files[i], j + 2), compat.as_text(k))
          self.assertAllEqual(self._LineText(i, j + 1), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])


class FixedLengthRecordReaderTest(test.TestCase):

  def setUp(self):
    super(FixedLengthRecordReaderTest, self).setUp()
    self._num_files = 2
    self._header_bytes = 5
    self._record_bytes = 3
    self._footer_bytes = 2

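    # hop_bytes is the offset between the starts of consecutive records; with
    # hop_bytes (2) smaller than record_bytes (3), the "overlapped" records
    # used below share bytes with their neighbours.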
    self._hop_bytes = 2

  def _Record(self, f, r):
    return compat.as_bytes(str(f * 2 + r) * self._record_bytes)

  def _OverlappedRecord(self, f, r):
    record_str = "".join([
        str(i)[0]
        for i in range(r * self._hop_bytes,
                       r * self._hop_bytes + self._record_bytes)
    ])
    return compat.as_bytes(record_str)

  # gap_bytes=hop_bytes-record_bytes
  def _CreateFiles(self, num_records, gap_bytes):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_records > 0:
          f.write(self._Record(i, 0))
        for j in range(1, num_records):
          if gap_bytes > 0:
            f.write(b"G" * gap_bytes)
          f.write(self._Record(i, j))
        f.write(b"F" * self._footer_bytes)
    return filenames

  def _CreateOverlappedRecordFiles(self, num_overlapped_records):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(),
                        "fixed_length_overlapped_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_overlapped_records > 0:
          all_records_str = "".join([
              str(i)[0]
              for i in range(self._record_bytes + self._hop_bytes *
                             (num_overlapped_records - 1))
          ])
          f.write(compat.as_bytes(all_records_str))
        f.write(b"F" * self._footer_bytes)
    return filenames

  # gap_bytes=hop_bytes-record_bytes
  def _CreateGzipFiles(self, num_records, gap_bytes):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with gzip.GzipFile(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_records > 0:
          f.write(self._Record(i, 0))
        for j in range(1, num_records):
          if gap_bytes > 0:
            f.write(b"G" * gap_bytes)
          f.write(self._Record(i, j))
        f.write(b"F" * self._footer_bytes)
    return filenames

  # gap_bytes=hop_bytes-record_bytes
  def _CreateZlibFiles(self, num_records, gap_bytes):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn + ".tmp", "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_records > 0:
          f.write(self._Record(i, 0))
        for j in range(1, num_records):
          if gap_bytes > 0:
            f.write(b"G" * gap_bytes)
          f.write(self._Record(i, j))
        f.write(b"F" * self._footer_bytes)
      with open(fn + ".tmp", "rb") as f:
        cdata = zlib.compress(f.read())
        with open(fn, "wb") as zf:
          zf.write(cdata)
    return filenames

  def _CreateGzipOverlappedRecordFiles(self, num_overlapped_records):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(),
                        "fixed_length_overlapped_record.%d.txt" % i)
      filenames.append(fn)
      with gzip.GzipFile(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_overlapped_records > 0:
          all_records_str = "".join([
              str(i)[0]
              for i in range(self._record_bytes + self._hop_bytes *
                             (num_overlapped_records - 1))
          ])
          f.write(compat.as_bytes(all_records_str))
        f.write(b"F" * self._footer_bytes)
    return filenames

  def _CreateZlibOverlappedRecordFiles(self, num_overlapped_records):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(),
                        "fixed_length_overlapped_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn + ".tmp", "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_overlapped_records > 0:
          all_records_str = "".join([
              str(i)[0]
              for i in range(self._record_bytes + self._hop_bytes *
                             (num_overlapped_records - 1))
          ])
          f.write(compat.as_bytes(all_records_str))
        f.write(b"F" * self._footer_bytes)
      with open(fn + ".tmp", "rb") as f:
        cdata = zlib.compress(f.read())
        with open(fn, "wb") as zf:
          zf.write(cdata)
    return filenames

  # gap_bytes=hop_bytes-record_bytes
  def _TestOneEpoch(self, files, num_records, gap_bytes, encoding=None):
    hop_bytes = 0 if gap_bytes == 0 else self._record_bytes + gap_bytes
    with self.test_session() as sess:
      reader = io_ops.FixedLengthRecordReader(
          header_bytes=self._header_bytes,
          record_bytes=self._record_bytes,
          footer_bytes=self._footer_bytes,
          hop_bytes=hop_bytes,
          encoding=encoding,
          name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(num_records):
          k, v = sess.run([key, value])
          self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
          self.assertAllEqual(self._Record(i, j), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def _TestOneEpochWithHopBytes(self,
                                files,
                                num_overlapped_records,
                                encoding=None):
    with self.test_session() as sess:
      reader = io_ops.FixedLengthRecordReader(
          header_bytes=self._header_bytes,
          record_bytes=self._record_bytes,
          footer_bytes=self._footer_bytes,
          hop_bytes=self._hop_bytes,
          encoding=encoding,
          name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(num_overlapped_records):
          k, v = sess.run([key, value])
          self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
          self.assertAllEqual(self._OverlappedRecord(i, j), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def testOneEpoch(self):
    for num_records in [0, 7]:
      # gap_bytes=0: hop_bytes=0
      # gap_bytes=1: hop_bytes=record_bytes+1
      for gap_bytes in [0, 1]:
        files = self._CreateFiles(num_records, gap_bytes)
        self._TestOneEpoch(files, num_records, gap_bytes)

  def testGzipOneEpoch(self):
    for num_records in [0, 7]:
      # gap_bytes=0: hop_bytes=0
      # gap_bytes=1: hop_bytes=record_bytes+1
      for gap_bytes in [0, 1]:
        files = self._CreateGzipFiles(num_records, gap_bytes)
        self._TestOneEpoch(files, num_records, gap_bytes, encoding="GZIP")

  def testZlibOneEpoch(self):
    for num_records in [0, 7]:
      # gap_bytes=0: hop_bytes=0
      # gap_bytes=1: hop_bytes=record_bytes+1
      for gap_bytes in [0, 1]:
        files = self._CreateZlibFiles(num_records, gap_bytes)
        self._TestOneEpoch(files, num_records, gap_bytes, encoding="ZLIB")

  def testOneEpochWithHopBytes(self):
    for num_overlapped_records in [0, 2]:
      files = self._CreateOverlappedRecordFiles(num_overlapped_records)
      self._TestOneEpochWithHopBytes(files, num_overlapped_records)

  def testGzipOneEpochWithHopBytes(self):
    for num_overlapped_records in [0, 2]:
      files = self._CreateGzipOverlappedRecordFiles(num_overlapped_records)
      self._TestOneEpochWithHopBytes(
          files, num_overlapped_records, encoding="GZIP")

  def testZlibOneEpochWithHopBytes(self):
    for num_overlapped_records in [0, 2]:
      files = self._CreateZlibOverlappedRecordFiles(num_overlapped_records)
      self._TestOneEpochWithHopBytes(
          files, num_overlapped_records, encoding="ZLIB")


class TFRecordReaderTest(test.TestCase):

  def setUp(self):
    super(TFRecordReaderTest, self).setUp()
    self._num_files = 2
    self._num_records = 7

  def _Record(self, f, r):
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _CreateFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = tf_record.TFRecordWriter(fn)
      for j in range(self._num_records):
        writer.write(self._Record(i, j))
      # Close the writer explicitly so all records are flushed to disk
      # before the tests read the files back.
      writer.close()
    return filenames

  def testOneEpoch(self):
    files = self._CreateFiles()
    with self.test_session() as sess:
      reader = io_ops.TFRecordReader(name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_records):
          k, v = sess.run([key, value])
          self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
          self.assertAllEqual(self._Record(i, j), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def testReadUpTo(self):
    files = self._CreateFiles()
    with self.test_session() as sess:
      reader = io_ops.TFRecordReader(name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      batch_size = 3
      key, value = reader.read_up_to(queue, batch_size)

      queue.enqueue_many([files]).run()
      queue.close().run()
      num_k = 0
      num_v = 0

      while True:
        try:
          k, v = sess.run([key, value])
          # Test reading *up to* batch_size records
          self.assertLessEqual(len(k), batch_size)
          self.assertLessEqual(len(v), batch_size)
          num_k += len(k)
          num_v += len(v)
        except errors_impl.OutOfRangeError:
          break

      # Test that we have read everything
      self.assertEqual(self._num_files * self._num_records, num_k)
      self.assertEqual(self._num_files * self._num_records, num_v)

  def testReadZlibFiles(self):
    files = self._CreateFiles()
    zlib_files = []
    for i, fn in enumerate(files):
      with open(fn, "rb") as f:
        cdata = zlib.compress(f.read())

        zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
        with open(zfn, "wb") as f:
          f.write(cdata)
        zlib_files.append(zfn)

    with self.test_session() as sess:
      options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
      reader = io_ops.TFRecordReader(name="test_reader", options=options)
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([zlib_files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_records):
          k, v = sess.run([key, value])
          self.assertTrue(compat.as_text(k).startswith("%s:" % zlib_files[i]))
          self.assertAllEqual(self._Record(i, j), v)

  def testReadGzipFiles(self):
    files = self._CreateFiles()
    gzip_files = []
    for i, fn in enumerate(files):
      with open(fn, "rb") as f:
        cdata = f.read()

        zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
        with gzip.GzipFile(zfn, "wb") as f:
          f.write(cdata)
        gzip_files.append(zfn)

    with self.test_session() as sess:
      options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
      reader = io_ops.TFRecordReader(name="test_reader", options=options)
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([gzip_files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_records):
          k, v = sess.run([key, value])
          self.assertTrue(compat.as_text(k).startswith("%s:" % gzip_files[i]))
          self.assertAllEqual(self._Record(i, j), v)


class TFRecordWriterZlibTest(test.TestCase):

  def setUp(self):
    super(TFRecordWriterZlibTest, self).setUp()
    self._num_files = 2
    self._num_records = 7

  def _Record(self, f, r):
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _CreateFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      options = tf_record.TFRecordOptions(
          compression_type=TFRecordCompressionType.ZLIB)
      writer = tf_record.TFRecordWriter(fn, options=options)
      for j in range(self._num_records):
        writer.write(self._Record(i, j))
      writer.close()
      del writer

    return filenames

  def _WriteRecordsToFile(self, records, name="tf_record"):
    fn = os.path.join(self.get_temp_dir(), name)
    writer = tf_record.TFRecordWriter(fn, options=None)
    for r in records:
      writer.write(r)
    writer.close()
    del writer
    return fn

  def _ZlibCompressFile(self, infile, name="tfrecord.z"):
    # zlib compress the file and write compressed contents to file.
    with open(infile, "rb") as f:
      cdata = zlib.compress(f.read())

    zfn = os.path.join(self.get_temp_dir(), name)
    with open(zfn, "wb") as f:
      f.write(cdata)
    return zfn

  def testOneEpoch(self):
    files = self._CreateFiles()
    with self.test_session() as sess:
      options = tf_record.TFRecordOptions(
          compression_type=TFRecordCompressionType.ZLIB)
      reader = io_ops.TFRecordReader(name="test_reader", options=options)
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_records):
          k, v = sess.run([key, value])
          self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
          self.assertAllEqual(self._Record(i, j), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def testZLibFlushRecord(self):
    fn = self._WriteRecordsToFile([b"small record"], "small_record")
    with open(fn, "rb") as h:
      buff = h.read()

    # creating more blocks and trailing blocks shouldn't break reads
    compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS)

    output = b""
    for c in buff:
      if isinstance(c, int):
        c = six.int2byte(c)
      output += compressor.compress(c)
      output += compressor.flush(zlib.Z_FULL_FLUSH)

    output += compressor.flush(zlib.Z_FULL_FLUSH)
    output += compressor.flush(zlib.Z_FULL_FLUSH)
    output += compressor.flush(zlib.Z_FINISH)

    # overwrite the original file with the compressed data
    with open(fn, "wb") as h:
      h.write(output)

    with self.test_session() as sess:
      options = tf_record.TFRecordOptions(
          compression_type=TFRecordCompressionType.ZLIB)
      reader = io_ops.TFRecordReader(name="test_reader", options=options)
      queue = data_flow_ops.FIFOQueue(1, [dtypes.string], shapes=())
      key, value = reader.read(queue)
      queue.enqueue(fn).run()
      queue.close().run()
      k, v = sess.run([key, value])
      self.assertTrue(compat.as_text(k).startswith("%s:" % fn))
      self.assertAllEqual(b"small record", v)

  def testZlibReadWrite(self):
    """Verify that files produced are zlib compatible."""
    original = [b"foo", b"bar"]
    fn = self._WriteRecordsToFile(original, "zlib_read_write.tfrecord")
    zfn = self._ZlibCompressFile(fn, "zlib_read_write.tfrecord.z")

    # read the compressed contents and verify.
    actual = []
    for r in tf_record.tf_record_iterator(
        zfn,
        options=tf_record.TFRecordOptions(
            tf_record.TFRecordCompressionType.ZLIB)):
      actual.append(r)
    self.assertEqual(actual, original)

  def testZlibReadWriteLarge(self):
    """Verify that writing large contents also works."""

    # Make it large (about 5MB)
    original = [_TEXT * 10240]
    fn = self._WriteRecordsToFile(original, "zlib_read_write_large.tfrecord")
    zfn = self._ZlibCompressFile(fn, "zlib_read_write_large.tfrecord.z")

    # read the compressed contents and verify.
    actual = []
    for r in tf_record.tf_record_iterator(
        zfn,
        options=tf_record.TFRecordOptions(
            tf_record.TFRecordCompressionType.ZLIB)):
      actual.append(r)
    self.assertEqual(actual, original)

  def testGzipReadWrite(self):
    """Verify that files produced are gzip compatible."""
    original = [b"foo", b"bar"]
    fn = self._WriteRecordsToFile(original, "gzip_read_write.tfrecord")

    # gzip compress the file and write compressed contents to file.
    with open(fn, "rb") as f:
      cdata = f.read()
    gzfn = os.path.join(self.get_temp_dir(), "tf_record.gz")
    with gzip.GzipFile(gzfn, "wb") as f:
      f.write(cdata)

    actual = []
    for r in tf_record.tf_record_iterator(
        gzfn, options=tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)):
      actual.append(r)
    self.assertEqual(actual, original)


class TFRecordIteratorTest(test.TestCase):

  def setUp(self):
    super(TFRecordIteratorTest, self).setUp()
    self._num_records = 7

  def _Record(self, r):
    return compat.as_bytes("Record %d" % r)

  def _WriteCompressedRecordsToFile(
      self,
      records,
      name="tfrecord.z",
      compression_type=tf_record.TFRecordCompressionType.ZLIB):
    fn = os.path.join(self.get_temp_dir(), name)
    options = tf_record.TFRecordOptions(compression_type=compression_type)
    writer = tf_record.TFRecordWriter(fn, options=options)
    for r in records:
      writer.write(r)
    writer.close()
    del writer
    return fn

  def _ZlibDecompressFile(self, infile, name="tfrecord", wbits=zlib.MAX_WBITS):
    with open(infile, "rb") as f:
      cdata = zlib.decompress(f.read(), wbits)
    zfn = os.path.join(self.get_temp_dir(), name)
    with open(zfn, "wb") as f:
      f.write(cdata)
    return zfn

  def testIterator(self):
    fn = self._WriteCompressedRecordsToFile(
        [self._Record(i) for i in range(self._num_records)],
        "compressed_records")
    options = tf_record.TFRecordOptions(
        compression_type=TFRecordCompressionType.ZLIB)
    reader = tf_record.tf_record_iterator(fn, options)
    for i in range(self._num_records):
      record = next(reader)
      self.assertAllEqual(self._Record(i), record)
    with self.assertRaises(StopIteration):
      record = next(reader)

  def testWriteZlibRead(self):
    """Verify compression with TFRecordWriter is zlib library compatible."""
    original = [b"foo", b"bar"]
    fn = self._WriteCompressedRecordsToFile(original,
                                            "write_zlib_read.tfrecord.z")
    zfn = self._ZlibDecompressFile(fn, "write_zlib_read.tfrecord")
    actual = []
    for r in tf_record.tf_record_iterator(zfn):
      actual.append(r)
    self.assertEqual(actual, original)

  def testWriteZlibReadLarge(self):
    """Verify compression for large records is zlib library compatible."""
    # Make it large (about 5MB)
    original = [_TEXT * 10240]
    fn = self._WriteCompressedRecordsToFile(original,
                                            "write_zlib_read_large.tfrecord.z")
    zfn = self._ZlibDecompressFile(fn, "write_zlib_read_large.tf_record")
    actual = []
    for r in tf_record.tf_record_iterator(zfn):
      actual.append(r)
    self.assertEqual(actual, original)

  def testWriteGzipRead(self):
    original = [b"foo", b"bar"]
    fn = self._WriteCompressedRecordsToFile(
        original,
        "write_gzip_read.tfrecord.gz",
        compression_type=TFRecordCompressionType.GZIP)

    with gzip.GzipFile(fn, "rb") as f:
      cdata = f.read()
    zfn = os.path.join(self.get_temp_dir(), "tf_record")
    with open(zfn, "wb") as f:
      f.write(cdata)

    actual = []
    for r in tf_record.tf_record_iterator(zfn):
      actual.append(r)
    self.assertEqual(actual, original)

  def testBadFile(self):
    """Verify that tf_record_iterator throws an exception on bad TFRecords."""
    fn = os.path.join(self.get_temp_dir(), "bad_file")
    with tf_record.TFRecordWriter(fn) as writer:
      writer.write(b"123")
    fn_truncated = os.path.join(self.get_temp_dir(), "bad_file_truncated")
    with open(fn, "rb") as f:
      with open(fn_truncated, "wb") as f2:
        # DataLossError requires that we've written the header, so this must
        # be at least 12 bytes.
        f2.write(f.read(14))
    with self.assertRaises(errors_impl.DataLossError):
      for _ in tf_record.tf_record_iterator(fn_truncated):
        pass


class AsyncReaderTest(test.TestCase):

  def testNoDeadlockFromQueue(self):
    """Tests that reading does not block main execution threads."""
    config = config_pb2.ConfigProto(
        inter_op_parallelism_threads=1, intra_op_parallelism_threads=1)
    with self.test_session(config=config) as sess:
      thread_data_t = collections.namedtuple("thread_data_t",
                                             ["thread", "queue", "output"])
      thread_data = []

      # Create different readers, each with its own queue.
      for i in range(3):
        queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
        reader = io_ops.TextLineReader()
        _, line = reader.read(queue)
        output = []
        t = threading.Thread(
            target=AsyncReaderTest._RunSessionAndSave,
            args=(sess, [line], output))
        thread_data.append(thread_data_t(t, queue, output))

      # Start all readers. They are all blocked waiting for queue entries.
      sess.run(variables.global_variables_initializer())
      for d in thread_data:
        d.thread.start()

      # Unblock the readers.
      for i, d in enumerate(reversed(thread_data)):
        fname = os.path.join(self.get_temp_dir(), "deadlock.%s.txt" % i)
        with open(fname, "wb") as f:
          f.write(("file-%s" % i).encode())
        d.queue.enqueue_many([[fname]]).run()
        d.thread.join()
        self.assertEqual([[("file-%s" % i).encode()]], d.output)

  @staticmethod
  def _RunSessionAndSave(sess, args, output):
    output.append(sess.run(args))


class LMDBReaderTest(test.TestCase):

  def setUp(self):
    super(LMDBReaderTest, self).setUp()
    # Copy database out because we need the path to be writable to use locks.
    path = os.path.join(prefix_path, "lmdb", "testdata", "data.mdb")
    self.db_path = os.path.join(self.get_temp_dir(), "data.mdb")
    shutil.copy(path, self.db_path)
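    # The copied test database holds ten records with keys "0".."9" and
    # values "a".."j", which the tests below iterate over.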

  def testReadFromFile(self):
    with self.test_session() as sess:
      reader = io_ops.LMDBReader(name="test_read_from_file")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue([self.db_path]).run()
      queue.close().run()
      for i in range(10):
        k, v = sess.run([key, value])
        self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(i)))
        self.assertAllEqual(
            compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + i))))

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def testReadFromSameFile(self):
    with self.test_session() as sess:
      reader1 = io_ops.LMDBReader(name="test_read_from_same_file1")
      reader2 = io_ops.LMDBReader(name="test_read_from_same_file2")
      filename_queue = input_lib.string_input_producer(
          [self.db_path], num_epochs=None)
      key1, value1 = reader1.read(filename_queue)
      key2, value2 = reader2.read(filename_queue)

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
      for _ in range(3):
        for _ in range(10):
          k1, v1, k2, v2 = sess.run([key1, value1, key2, value2])
          self.assertAllEqual(compat.as_bytes(k1), compat.as_bytes(k2))
          self.assertAllEqual(compat.as_bytes(v1), compat.as_bytes(v2))
      coord.request_stop()
      coord.join(threads)

  def testReadFromFolder(self):
    with self.test_session() as sess:
      reader = io_ops.LMDBReader(name="test_read_from_folder")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue([self.db_path]).run()
      queue.close().run()
      for i in range(10):
        k, v = sess.run([key, value])
        self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(i)))
        self.assertAllEqual(
            compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + i))))

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def testReadFromFileRepeatedly(self):
    with self.test_session() as sess:
      reader = io_ops.LMDBReader(name="test_read_from_file_repeated")
      filename_queue = input_lib.string_input_producer(
          [self.db_path], num_epochs=None)
      key, value = reader.read(filename_queue)

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
      # Iterate over the lmdb 3 times.
      for _ in range(3):
        # Go over all 10 records each time.
        for j in range(10):
          k, v = sess.run([key, value])
          self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(j)))
          self.assertAllEqual(
              compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + j))))
      coord.request_stop()
      coord.join(threads)


if __name__ == "__main__":
  test.main()
