# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Reader ops from io_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import gzip
import os
import shutil
import threading
import zlib

import six

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.lib.io import tf_record
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.util import compat

prefix_path = "tensorflow/core/lib"

# pylint: disable=invalid-name
TFRecordCompressionType = tf_record.TFRecordCompressionType
# pylint: enable=invalid-name

# Edgar Allan Poe's 'Eldorado'
_TEXT = b"""Gaily bedight,
    A gallant knight,
    In sunshine and in shadow,
    Had journeyed long,
    Singing a song,
    In search of Eldorado.

    But he grew old
    This knight so bold
    And o'er his heart a shadow
    Fell as he found
    No spot of ground
    That looked like Eldorado.

    And, as his strength
    Failed him at length,
    He met a pilgrim shadow
    'Shadow,' said he,
    'Where can it be
    This land of Eldorado?'

    'Over the Mountains
    Of the Moon'
    Down the Valley of the Shadow,
    Ride, boldly ride,'
    The shade replied,
    'If you seek for Eldorado!'
    """


class IdentityReaderTest(test.TestCase):

  def _ExpectRead(self, sess, key, value, expected):
    k, v = sess.run([key, value])
    self.assertAllEqual(expected, k)
    self.assertAllEqual(expected, v)

  def testOneEpoch(self):
    with self.test_session() as sess:
      reader = io_ops.IdentityReader("test_reader")
      work_completed = reader.num_work_units_completed()
      produced = reader.num_records_produced()
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      queued_length = queue.size()
      key, value = reader.read(queue)

      self.assertAllEqual(0, work_completed.eval())
      self.assertAllEqual(0, produced.eval())
      self.assertAllEqual(0, queued_length.eval())

      queue.enqueue_many([["A", "B", "C"]]).run()
      queue.close().run()
      self.assertAllEqual(3, queued_length.eval())

      self._ExpectRead(sess, key, value, b"A")
      self.assertAllEqual(1, produced.eval())

      self._ExpectRead(sess, key, value, b"B")

      self._ExpectRead(sess, key, value, b"C")
      self.assertAllEqual(3, produced.eval())
      self.assertAllEqual(0, queued_length.eval())

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])

      self.assertAllEqual(3, work_completed.eval())
      self.assertAllEqual(3, produced.eval())
      self.assertAllEqual(0, queued_length.eval())

  def testMultipleEpochs(self):
    with self.test_session() as sess:
      reader = io_ops.IdentityReader("test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      enqueue = queue.enqueue_many([["DD", "EE"]])
      key, value = reader.read(queue)

      enqueue.run()
      self._ExpectRead(sess, key, value, b"DD")
      self._ExpectRead(sess, key, value, b"EE")
      enqueue.run()
      self._ExpectRead(sess, key, value, b"DD")
      self._ExpectRead(sess, key, value, b"EE")
      enqueue.run()
      self._ExpectRead(sess, key, value, b"DD")
      self._ExpectRead(sess, key, value, b"EE")
      queue.close().run()
      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])

  def testSerializeRestore(self):
    with self.test_session() as sess:
      reader = io_ops.IdentityReader("test_reader")
      produced = reader.num_records_produced()
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      queue.enqueue_many([["X", "Y", "Z"]]).run()
      key, value = reader.read(queue)

      self._ExpectRead(sess, key, value, b"X")
      self.assertAllEqual(1, produced.eval())
      state = reader.serialize_state().eval()

      self._ExpectRead(sess, key, value, b"Y")
      self._ExpectRead(sess, key, value, b"Z")
      self.assertAllEqual(3, produced.eval())

      queue.enqueue_many([["Y", "Z"]]).run()
      queue.close().run()
      reader.restore_state(state).run()
      self.assertAllEqual(1, produced.eval())
      self._ExpectRead(sess, key, value, b"Y")
      self._ExpectRead(sess, key, value, b"Z")
      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])
      self.assertAllEqual(3, produced.eval())

      self.assertEqual(bytes, type(state))

      with self.assertRaises(ValueError):
        reader.restore_state([])

      with self.assertRaises(ValueError):
        reader.restore_state([state, state])

      with self.assertRaisesOpError(
          "Could not parse state for IdentityReader 'test_reader'"):
        reader.restore_state(state[1:]).run()

      with self.assertRaisesOpError(
          "Could not parse state for IdentityReader 'test_reader'"):
        reader.restore_state(state[:-1]).run()

      with self.assertRaisesOpError(
          "Could not parse state for IdentityReader 'test_reader'"):
        reader.restore_state(state + b"ExtraJunk").run()

      with self.assertRaisesOpError(
          "Could not parse state for IdentityReader 'test_reader'"):
        reader.restore_state(b"PREFIX" + state).run()

      with self.assertRaisesOpError(
          "Could not parse state for IdentityReader 'test_reader'"):
        reader.restore_state(b"BOGUS" + state[5:]).run()

  def testReset(self):
    with self.test_session() as sess:
      reader = io_ops.IdentityReader("test_reader")
      work_completed = reader.num_work_units_completed()
      produced = reader.num_records_produced()
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      queued_length = queue.size()
      key, value = reader.read(queue)

      queue.enqueue_many([["X", "Y", "Z"]]).run()
      self._ExpectRead(sess, key, value, b"X")
      self.assertLess(0, queued_length.eval())
      self.assertAllEqual(1, produced.eval())

      self._ExpectRead(sess, key, value, b"Y")
      self.assertLess(0, work_completed.eval())
      self.assertAllEqual(2, produced.eval())

      reader.reset().run()
      self.assertAllEqual(0, work_completed.eval())
      self.assertAllEqual(0, produced.eval())
      self.assertAllEqual(1, queued_length.eval())
      self._ExpectRead(sess, key, value, b"Z")

      queue.enqueue_many([["K", "L"]]).run()
      self._ExpectRead(sess, key, value, b"K")


class WholeFileReaderTest(test.TestCase):

  def setUp(self):
    super(WholeFileReaderTest, self).setUp()
    self._filenames = [
        os.path.join(self.get_temp_dir(), "whole_file.%d.txt" % i)
        for i in range(3)
    ]
    self._content = [b"One\na\nb\n", b"Two\nC\nD", b"Three x, y, z"]
    for fn, c in zip(self._filenames, self._content):
      with open(fn, "wb") as h:
        h.write(c)

  def tearDown(self):
    for fn in self._filenames:
      os.remove(fn)
    super(WholeFileReaderTest, self).tearDown()

  def _ExpectRead(self, sess, key, value, index):
    k, v = sess.run([key, value])
    self.assertAllEqual(compat.as_bytes(self._filenames[index]), k)
    self.assertAllEqual(self._content[index], v)

  def testOneEpoch(self):
    with self.test_session() as sess:
      reader = io_ops.WholeFileReader("test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      queue.enqueue_many([self._filenames]).run()
      queue.close().run()
      key, value = reader.read(queue)

      self._ExpectRead(sess, key, value, 0)
      self._ExpectRead(sess, key, value, 1)
      self._ExpectRead(sess, key, value, 2)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])

  def testInfiniteEpochs(self):
    with self.test_session() as sess:
      reader = io_ops.WholeFileReader("test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      enqueue = queue.enqueue_many([self._filenames])
      key, value = reader.read(queue)

      enqueue.run()
      self._ExpectRead(sess, key, value, 0)
      self._ExpectRead(sess, key, value, 1)
      enqueue.run()
      self._ExpectRead(sess, key, value, 2)
      self._ExpectRead(sess, key, value, 0)
      self._ExpectRead(sess, key, value, 1)
      enqueue.run()
      self._ExpectRead(sess, key, value, 2)
      self._ExpectRead(sess, key, value, 0)


class TextLineReaderTest(test.TestCase):

  def setUp(self):
    super(TextLineReaderTest, self).setUp()
    self._num_files = 2
    self._num_lines = 5

  def _LineText(self, f, l):
    return compat.as_bytes("%d: %d" % (f, l))

  def _CreateFiles(self, crlf=False):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        for j in range(self._num_lines):
          f.write(self._LineText(i, j))
          # Always include a newline after the record unless it is
          # at the end of the file, in which case we include it sometimes.
          if j + 1 != self._num_lines or i == 0:
            f.write(b"\r\n" if crlf else b"\n")
    return filenames

  def _testOneEpoch(self, files):
    with self.test_session() as sess:
      reader = io_ops.TextLineReader(name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_lines):
          k, v = sess.run([key, value])
          self.assertAllEqual("%s:%d" % (files[i], j + 1), compat.as_text(k))
          self.assertAllEqual(self._LineText(i, j), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def testOneEpochLF(self):
    self._testOneEpoch(self._CreateFiles(crlf=False))

  def testOneEpochCRLF(self):
    self._testOneEpoch(self._CreateFiles(crlf=True))

  def testSkipHeaderLines(self):
    files = self._CreateFiles()
    with self.test_session() as sess:
      reader = io_ops.TextLineReader(skip_header_lines=1, name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_lines - 1):
          k, v = sess.run([key, value])
          self.assertAllEqual("%s:%d" % (files[i], j + 2), compat.as_text(k))
          self.assertAllEqual(self._LineText(i, j + 1), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])


class FixedLengthRecordReaderTest(test.TestCase):

  def setUp(self):
    super(FixedLengthRecordReaderTest, self).setUp()
    self._num_files = 2
    self._header_bytes = 5
    self._record_bytes = 3
    self._footer_bytes = 2

    self._hop_bytes = 2

  def _Record(self, f, r):
    return compat.as_bytes(str(f * 2 + r) * self._record_bytes)

  def _OverlappedRecord(self, f, r):
    record_str = "".join([
        str(i)[0]
        for i in range(r * self._hop_bytes,
                       r * self._hop_bytes + self._record_bytes)
    ])
    return compat.as_bytes(record_str)

  # gap_bytes=hop_bytes-record_bytes
  def _CreateFiles(self, num_records, gap_bytes):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_records > 0:
          f.write(self._Record(i, 0))
        for j in range(1, num_records):
          if gap_bytes > 0:
            f.write(b"G" * gap_bytes)
          f.write(self._Record(i, j))
        f.write(b"F" * self._footer_bytes)
    return filenames

  def _CreateOverlappedRecordFiles(self, num_overlapped_records):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(),
                        "fixed_length_overlapped_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_overlapped_records > 0:
          all_records_str = "".join([
              str(i)[0]
              for i in range(self._record_bytes + self._hop_bytes *
                             (num_overlapped_records - 1))
          ])
          f.write(compat.as_bytes(all_records_str))
        f.write(b"F" * self._footer_bytes)
    return filenames

  # gap_bytes=hop_bytes-record_bytes
  def _CreateGzipFiles(self, num_records, gap_bytes):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with gzip.GzipFile(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_records > 0:
          f.write(self._Record(i, 0))
        for j in range(1, num_records):
          if gap_bytes > 0:
            f.write(b"G" * gap_bytes)
          f.write(self._Record(i, j))
        f.write(b"F" * self._footer_bytes)
    return filenames

  # gap_bytes=hop_bytes-record_bytes
  def _CreateZlibFiles(self, num_records, gap_bytes):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn + ".tmp", "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_records > 0:
          f.write(self._Record(i, 0))
        for j in range(1, num_records):
          if gap_bytes > 0:
            f.write(b"G" * gap_bytes)
          f.write(self._Record(i, j))
        f.write(b"F" * self._footer_bytes)
      with open(fn + ".tmp", "rb") as f:
        cdata = zlib.compress(f.read())
        with open(fn, "wb") as zf:
          zf.write(cdata)
    return filenames

  def _CreateGzipOverlappedRecordFiles(self, num_overlapped_records):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(),
                        "fixed_length_overlapped_record.%d.txt" % i)
      filenames.append(fn)
      with gzip.GzipFile(fn, "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_overlapped_records > 0:
          all_records_str = "".join([
              str(i)[0]
              for i in range(self._record_bytes + self._hop_bytes *
                             (num_overlapped_records - 1))
          ])
          f.write(compat.as_bytes(all_records_str))
        f.write(b"F" * self._footer_bytes)
    return filenames

  def _CreateZlibOverlappedRecordFiles(self, num_overlapped_records):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(),
                        "fixed_length_overlapped_record.%d.txt" % i)
      filenames.append(fn)
      with open(fn + ".tmp", "wb") as f:
        f.write(b"H" * self._header_bytes)
        if num_overlapped_records > 0:
          all_records_str = "".join([
              str(i)[0]
              for i in range(self._record_bytes + self._hop_bytes *
                             (num_overlapped_records - 1))
          ])
          f.write(compat.as_bytes(all_records_str))
        f.write(b"F" * self._footer_bytes)
      with open(fn + ".tmp", "rb") as f:
        cdata = zlib.compress(f.read())
        with open(fn, "wb") as zf:
          zf.write(cdata)
    return filenames

  # gap_bytes=hop_bytes-record_bytes
  def _TestOneEpoch(self, files, num_records, gap_bytes, encoding=None):
    hop_bytes = 0 if gap_bytes == 0 else self._record_bytes + gap_bytes
    with self.test_session() as sess:
      reader = io_ops.FixedLengthRecordReader(
          header_bytes=self._header_bytes,
          record_bytes=self._record_bytes,
          footer_bytes=self._footer_bytes,
          hop_bytes=hop_bytes,
          encoding=encoding,
          name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(num_records):
          k, v = sess.run([key, value])
          self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
          self.assertAllEqual(self._Record(i, j), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def _TestOneEpochWithHopBytes(self,
                                files,
                                num_overlapped_records,
                                encoding=None):
    with self.test_session() as sess:
      reader = io_ops.FixedLengthRecordReader(
          header_bytes=self._header_bytes,
          record_bytes=self._record_bytes,
          footer_bytes=self._footer_bytes,
          hop_bytes=self._hop_bytes,
          encoding=encoding,
          name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(num_overlapped_records):
          k, v = sess.run([key, value])
          print(v)
          self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
          self.assertAllEqual(self._OverlappedRecord(i, j), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def testOneEpoch(self):
    for num_records in [0, 7]:
      # gap_bytes=0: hop_bytes=0
      # gap_bytes=1: hop_bytes=record_bytes+1
      for gap_bytes in [0, 1]:
        files = self._CreateFiles(num_records, gap_bytes)
        self._TestOneEpoch(files, num_records, gap_bytes)

  def testGzipOneEpoch(self):
    for num_records in [0, 7]:
      # gap_bytes=0: hop_bytes=0
      # gap_bytes=1: hop_bytes=record_bytes+1
      for gap_bytes in [0, 1]:
        files = self._CreateGzipFiles(num_records, gap_bytes)
        self._TestOneEpoch(files, num_records, gap_bytes, encoding="GZIP")

  def testZlibOneEpoch(self):
    for num_records in [0, 7]:
      # gap_bytes=0: hop_bytes=0
      # gap_bytes=1: hop_bytes=record_bytes+1
      for gap_bytes in [0, 1]:
        files = self._CreateZlibFiles(num_records, gap_bytes)
        self._TestOneEpoch(files, num_records, gap_bytes, encoding="ZLIB")

  def testOneEpochWithHopBytes(self):
    for num_overlapped_records in [0, 2]:
      files = self._CreateOverlappedRecordFiles(num_overlapped_records)
      self._TestOneEpochWithHopBytes(files, num_overlapped_records)

  def testGzipOneEpochWithHopBytes(self):
    for num_overlapped_records in [0, 2]:
      files = self._CreateGzipOverlappedRecordFiles(num_overlapped_records)
      self._TestOneEpochWithHopBytes(
          files, num_overlapped_records, encoding="GZIP")

  def testZlibOneEpochWithHopBytes(self):
    for num_overlapped_records in [0, 2]:
      files = self._CreateZlibOverlappedRecordFiles(num_overlapped_records)
      self._TestOneEpochWithHopBytes(
          files, num_overlapped_records, encoding="ZLIB")


class TFRecordReaderTest(test.TestCase):

  def setUp(self):
    super(TFRecordReaderTest, self).setUp()
    self._num_files = 2
    self._num_records = 7

  def _Record(self, f, r):
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _CreateFiles(self):
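    # Each file holds self._num_records TFRecord entries written with the
    # default (uncompressed) writer options.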
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      writer = tf_record.TFRecordWriter(fn)
      for j in range(self._num_records):
        writer.write(self._Record(i, j))
    return filenames

  def testOneEpoch(self):
    files = self._CreateFiles()
    with self.test_session() as sess:
      reader = io_ops.TFRecordReader(name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_records):
          k, v = sess.run([key, value])
          self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
          self.assertAllEqual(self._Record(i, j), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def testReadUpTo(self):
    files = self._CreateFiles()
    with self.test_session() as sess:
      reader = io_ops.TFRecordReader(name="test_reader")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      batch_size = 3
      key, value = reader.read_up_to(queue, batch_size)

      queue.enqueue_many([files]).run()
      queue.close().run()
      num_k = 0
      num_v = 0

      while True:
        try:
          k, v = sess.run([key, value])
          # Test reading *up to* batch_size records
          self.assertLessEqual(len(k), batch_size)
          self.assertLessEqual(len(v), batch_size)
          num_k += len(k)
          num_v += len(v)
        except errors_impl.OutOfRangeError:
          break

      # Test that we have read everything
      self.assertEqual(self._num_files * self._num_records, num_k)
      self.assertEqual(self._num_files * self._num_records, num_v)

  def testReadZlibFiles(self):
    files = self._CreateFiles()
    zlib_files = []
    for i, fn in enumerate(files):
      with open(fn, "rb") as f:
        cdata = zlib.compress(f.read())

      zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
      with open(zfn, "wb") as f:
        f.write(cdata)
      zlib_files.append(zfn)

    with self.test_session() as sess:
      options = tf_record.TFRecordOptions(TFRecordCompressionType.ZLIB)
      reader = io_ops.TFRecordReader(name="test_reader", options=options)
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([zlib_files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_records):
          k, v = sess.run([key, value])
          self.assertTrue(compat.as_text(k).startswith("%s:" % zlib_files[i]))
          self.assertAllEqual(self._Record(i, j), v)

  def testReadGzipFiles(self):
    files = self._CreateFiles()
    gzip_files = []
    for i, fn in enumerate(files):
      with open(fn, "rb") as f:
        cdata = f.read()

      zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
      with gzip.GzipFile(zfn, "wb") as f:
        f.write(cdata)
      gzip_files.append(zfn)

    with self.test_session() as sess:
      options = tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)
      reader = io_ops.TFRecordReader(name="test_reader", options=options)
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([gzip_files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_records):
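          # The reader should transparently decompress each gzip file and
          # yield the original records in order.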
          k, v = sess.run([key, value])
          self.assertTrue(compat.as_text(k).startswith("%s:" % gzip_files[i]))
          self.assertAllEqual(self._Record(i, j), v)


class TFRecordWriterZlibTest(test.TestCase):

  def setUp(self):
    super(TFRecordWriterZlibTest, self).setUp()
    self._num_files = 2
    self._num_records = 7

  def _Record(self, f, r):
    return compat.as_bytes("Record %d of file %d" % (r, f))

  def _CreateFiles(self):
    filenames = []
    for i in range(self._num_files):
      fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
      filenames.append(fn)
      options = tf_record.TFRecordOptions(
          compression_type=TFRecordCompressionType.ZLIB)
      writer = tf_record.TFRecordWriter(fn, options=options)
      for j in range(self._num_records):
        writer.write(self._Record(i, j))
      writer.close()
      del writer

    return filenames

  def _WriteRecordsToFile(self, records, name="tf_record"):
    fn = os.path.join(self.get_temp_dir(), name)
    writer = tf_record.TFRecordWriter(fn, options=None)
    for r in records:
      writer.write(r)
    writer.close()
    del writer
    return fn

  def _ZlibCompressFile(self, infile, name="tfrecord.z"):
    # zlib compress the file and write compressed contents to file.
    with open(infile, "rb") as f:
      cdata = zlib.compress(f.read())

    zfn = os.path.join(self.get_temp_dir(), name)
    with open(zfn, "wb") as f:
      f.write(cdata)
    return zfn

  def testOneEpoch(self):
    files = self._CreateFiles()
    with self.test_session() as sess:
      options = tf_record.TFRecordOptions(
          compression_type=TFRecordCompressionType.ZLIB)
      reader = io_ops.TFRecordReader(name="test_reader", options=options)
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue_many([files]).run()
      queue.close().run()
      for i in range(self._num_files):
        for j in range(self._num_records):
          k, v = sess.run([key, value])
          self.assertTrue(compat.as_text(k).startswith("%s:" % files[i]))
          self.assertAllEqual(self._Record(i, j), v)

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def testZLibFlushRecord(self):
    fn = self._WriteRecordsToFile([b"small record"], "small_record")
    with open(fn, "rb") as h:
      buff = h.read()

    # creating more blocks and trailing blocks shouldn't break reads
    compressor = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS)

    output = b""
    for c in buff:
      if isinstance(c, int):
        c = six.int2byte(c)
      output += compressor.compress(c)
      output += compressor.flush(zlib.Z_FULL_FLUSH)

    output += compressor.flush(zlib.Z_FULL_FLUSH)
    output += compressor.flush(zlib.Z_FULL_FLUSH)
    output += compressor.flush(zlib.Z_FINISH)

    # overwrite the original file with the compressed data
    with open(fn, "wb") as h:
      h.write(output)

    with self.test_session() as sess:
      options = tf_record.TFRecordOptions(
          compression_type=TFRecordCompressionType.ZLIB)
      reader = io_ops.TFRecordReader(name="test_reader", options=options)
      queue = data_flow_ops.FIFOQueue(1, [dtypes.string], shapes=())
      key, value = reader.read(queue)
      queue.enqueue(fn).run()
      queue.close().run()
      k, v = sess.run([key, value])
      self.assertTrue(compat.as_text(k).startswith("%s:" % fn))
      self.assertAllEqual(b"small record", v)

  def testZlibReadWrite(self):
    """Verify that files produced are zlib compatible."""
    original = [b"foo", b"bar"]
    fn = self._WriteRecordsToFile(original, "zlib_read_write.tfrecord")
    zfn = self._ZlibCompressFile(fn, "zlib_read_write.tfrecord.z")

    # read the compressed contents and verify.
    actual = []
    for r in tf_record.tf_record_iterator(
        zfn,
        options=tf_record.TFRecordOptions(
            tf_record.TFRecordCompressionType.ZLIB)):
      actual.append(r)
    self.assertEqual(actual, original)

  def testZlibReadWriteLarge(self):
    """Verify that writing large contents also works."""

    # Make it large (about 5MB)
    original = [_TEXT * 10240]
    fn = self._WriteRecordsToFile(original, "zlib_read_write_large.tfrecord")
    zfn = self._ZlibCompressFile(fn, "zlib_read_write_large.tfrecord.z")

    # read the compressed contents and verify.
    actual = []
    for r in tf_record.tf_record_iterator(
        zfn,
        options=tf_record.TFRecordOptions(
            tf_record.TFRecordCompressionType.ZLIB)):
      actual.append(r)
    self.assertEqual(actual, original)

  def testGzipReadWrite(self):
    """Verify that files produced are gzip compatible."""
    original = [b"foo", b"bar"]
    fn = self._WriteRecordsToFile(original, "gzip_read_write.tfrecord")

    # gzip compress the file and write compressed contents to file.
    with open(fn, "rb") as f:
      cdata = f.read()
    gzfn = os.path.join(self.get_temp_dir(), "tf_record.gz")
    with gzip.GzipFile(gzfn, "wb") as f:
      f.write(cdata)

    actual = []
    for r in tf_record.tf_record_iterator(
        gzfn, options=tf_record.TFRecordOptions(TFRecordCompressionType.GZIP)):
      actual.append(r)
    self.assertEqual(actual, original)


class TFRecordIteratorTest(test.TestCase):

  def setUp(self):
    super(TFRecordIteratorTest, self).setUp()
    self._num_records = 7

  def _Record(self, r):
    return compat.as_bytes("Record %d" % r)

  def _WriteCompressedRecordsToFile(
      self,
      records,
      name="tfrecord.z",
      compression_type=tf_record.TFRecordCompressionType.ZLIB):
    fn = os.path.join(self.get_temp_dir(), name)
    options = tf_record.TFRecordOptions(compression_type=compression_type)
    writer = tf_record.TFRecordWriter(fn, options=options)
    for r in records:
      writer.write(r)
    writer.close()
    del writer
    return fn

  def _ZlibDecompressFile(self, infile, name="tfrecord", wbits=zlib.MAX_WBITS):
    with open(infile, "rb") as f:
      cdata = zlib.decompress(f.read(), wbits)
    zfn = os.path.join(self.get_temp_dir(), name)
    with open(zfn, "wb") as f:
      f.write(cdata)
    return zfn

  def testIterator(self):
    fn = self._WriteCompressedRecordsToFile(
        [self._Record(i) for i in range(self._num_records)],
        "compressed_records")
    options = tf_record.TFRecordOptions(
        compression_type=TFRecordCompressionType.ZLIB)
    reader = tf_record.tf_record_iterator(fn, options)
    for i in range(self._num_records):
      record = next(reader)
      self.assertAllEqual(self._Record(i), record)
    with self.assertRaises(StopIteration):
      record = next(reader)

  def testWriteZlibRead(self):
    """Verify compression with TFRecordWriter is zlib library compatible."""
    original = [b"foo", b"bar"]
    fn = self._WriteCompressedRecordsToFile(original,
                                            "write_zlib_read.tfrecord.z")
    zfn = self._ZlibDecompressFile(fn, "write_zlib_read.tfrecord")
    actual = []
    for r in tf_record.tf_record_iterator(zfn):
      actual.append(r)
    self.assertEqual(actual, original)

  def testWriteZlibReadLarge(self):
    """Verify compression for large records is zlib library compatible."""
    # Make it large (about 5MB)
    original = [_TEXT * 10240]
    fn = self._WriteCompressedRecordsToFile(original,
                                            "write_zlib_read_large.tfrecord.z")
    zfn = self._ZlibDecompressFile(fn, "write_zlib_read_large.tf_record")
    actual = []
    for r in tf_record.tf_record_iterator(zfn):
      actual.append(r)
    self.assertEqual(actual, original)

  def testWriteGzipRead(self):
    original = [b"foo", b"bar"]
    fn = self._WriteCompressedRecordsToFile(
        original,
        "write_gzip_read.tfrecord.gz",
        compression_type=TFRecordCompressionType.GZIP)

    with gzip.GzipFile(fn, "rb") as f:
      cdata = f.read()
    zfn = os.path.join(self.get_temp_dir(), "tf_record")
    with open(zfn, "wb") as f:
      f.write(cdata)

    actual = []
    for r in tf_record.tf_record_iterator(zfn):
      actual.append(r)
    self.assertEqual(actual, original)

  def testBadFile(self):
    """Verify that tf_record_iterator throws an exception on bad TFRecords."""
    fn = os.path.join(self.get_temp_dir(), "bad_file")
    with tf_record.TFRecordWriter(fn) as writer:
      writer.write(b"123")
    fn_truncated = os.path.join(self.get_temp_dir(), "bad_file_truncated")
    with open(fn, "rb") as f:
      with open(fn_truncated, "wb") as f2:
        # DataLossError requires that we've written the header, so this must
        # be at least 12 bytes.
        f2.write(f.read(14))
    with self.assertRaises(errors_impl.DataLossError):
      for _ in tf_record.tf_record_iterator(fn_truncated):
        pass


class AsyncReaderTest(test.TestCase):

  def testNoDeadlockFromQueue(self):
    """Tests that reading does not block main execution threads."""
    config = config_pb2.ConfigProto(
        inter_op_parallelism_threads=1, intra_op_parallelism_threads=1)
    with self.test_session(config=config) as sess:
      thread_data_t = collections.namedtuple("thread_data_t",
                                             ["thread", "queue", "output"])
      thread_data = []

      # Create different readers, each with its own queue.
      for i in range(3):
        queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
        reader = io_ops.TextLineReader()
        _, line = reader.read(queue)
        output = []
        t = threading.Thread(
            target=AsyncReaderTest._RunSessionAndSave,
            args=(sess, [line], output))
        thread_data.append(thread_data_t(t, queue, output))

      # Start all readers. They are all blocked waiting for queue entries.
      sess.run(variables.global_variables_initializer())
      for d in thread_data:
        d.thread.start()

      # Unblock the readers.
      for i, d in enumerate(reversed(thread_data)):
        fname = os.path.join(self.get_temp_dir(), "deadlock.%s.txt" % i)
        with open(fname, "wb") as f:
          f.write(("file-%s" % i).encode())
        d.queue.enqueue_many([[fname]]).run()
        d.thread.join()
        self.assertEqual([[("file-%s" % i).encode()]], d.output)

  @staticmethod
  def _RunSessionAndSave(sess, args, output):
    output.append(sess.run(args))


class LMDBReaderTest(test.TestCase):

  def setUp(self):
    super(LMDBReaderTest, self).setUp()
    # Copy database out because we need the path to be writable to use locks.
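    # (LMDB creates a lock file alongside the database when it is opened.)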
    path = os.path.join(prefix_path, "lmdb", "testdata", "data.mdb")
    self.db_path = os.path.join(self.get_temp_dir(), "data.mdb")
    shutil.copy(path, self.db_path)

  def testReadFromFile(self):
    with self.test_session() as sess:
      reader = io_ops.LMDBReader(name="test_read_from_file")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue([self.db_path]).run()
      queue.close().run()
      for i in range(10):
        k, v = sess.run([key, value])
        self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(i)))
        self.assertAllEqual(
            compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + i))))

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def testReadFromSameFile(self):
    with self.test_session() as sess:
      reader1 = io_ops.LMDBReader(name="test_read_from_same_file1")
      reader2 = io_ops.LMDBReader(name="test_read_from_same_file2")
      filename_queue = input_lib.string_input_producer(
          [self.db_path], num_epochs=None)
      key1, value1 = reader1.read(filename_queue)
      key2, value2 = reader2.read(filename_queue)

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
      for _ in range(3):
        for _ in range(10):
          k1, v1, k2, v2 = sess.run([key1, value1, key2, value2])
          self.assertAllEqual(compat.as_bytes(k1), compat.as_bytes(k2))
          self.assertAllEqual(compat.as_bytes(v1), compat.as_bytes(v2))
      coord.request_stop()
      coord.join(threads)

  def testReadFromFolder(self):
    with self.test_session() as sess:
      reader = io_ops.LMDBReader(name="test_read_from_folder")
      queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
      key, value = reader.read(queue)

      queue.enqueue([self.db_path]).run()
      queue.close().run()
      for i in range(10):
        k, v = sess.run([key, value])
        self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(i)))
        self.assertAllEqual(
            compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + i))))

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        k, v = sess.run([key, value])

  def testReadFromFileRepeatedly(self):
    with self.test_session() as sess:
      reader = io_ops.LMDBReader(name="test_read_from_file_repeated")
      filename_queue = input_lib.string_input_producer(
          [self.db_path], num_epochs=None)
      key, value = reader.read(filename_queue)

      coord = coordinator.Coordinator()
      threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
      # Iterate over the lmdb 3 times.
      for _ in range(3):
        # Go over all 10 records each time.
        for j in range(10):
          k, v = sess.run([key, value])
          self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(j)))
          self.assertAllEqual(
              compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + j))))
      coord.request_stop()
      coord.join(threads)


if __name__ == "__main__":
  test.main()