check_ops_test.py revision 8043a27ed77f59bb68409070f2bfa01df0e04b89
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.ops.check_ops.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import sparse_tensor 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import check_ops 29from tensorflow.python.platform import test 30 31 32class AssertProperIterableTest(test.TestCase): 33 34 def test_single_tensor_raises(self): 35 tensor = constant_op.constant(1) 36 with self.assertRaisesRegexp(TypeError, "proper"): 37 check_ops.assert_proper_iterable(tensor) 38 39 def test_single_sparse_tensor_raises(self): 40 ten = sparse_tensor.SparseTensor( 41 indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) 42 with self.assertRaisesRegexp(TypeError, "proper"): 43 check_ops.assert_proper_iterable(ten) 44 45 def test_single_ndarray_raises(self): 46 array = np.array([1, 2, 3]) 47 with self.assertRaisesRegexp(TypeError, "proper"): 48 check_ops.assert_proper_iterable(array) 49 50 def test_single_string_raises(self): 51 mystr = "hello" 52 with self.assertRaisesRegexp(TypeError, 
"proper"): 53 check_ops.assert_proper_iterable(mystr) 54 55 def test_non_iterable_object_raises(self): 56 non_iterable = 1234 57 with self.assertRaisesRegexp(TypeError, "to be iterable"): 58 check_ops.assert_proper_iterable(non_iterable) 59 60 def test_list_does_not_raise(self): 61 list_of_stuff = [ 62 constant_op.constant([11, 22]), constant_op.constant([1, 2]) 63 ] 64 check_ops.assert_proper_iterable(list_of_stuff) 65 66 def test_generator_does_not_raise(self): 67 generator_of_stuff = (constant_op.constant([11, 22]), constant_op.constant( 68 [1, 2])) 69 check_ops.assert_proper_iterable(generator_of_stuff) 70 71 72class AssertEqualTest(test.TestCase): 73 74 def test_doesnt_raise_when_equal(self): 75 with self.test_session(): 76 small = constant_op.constant([1, 2], name="small") 77 with ops.control_dependencies([check_ops.assert_equal(small, small)]): 78 out = array_ops.identity(small) 79 out.eval() 80 81 def test_raises_when_greater(self): 82 with self.test_session(): 83 small = constant_op.constant([1, 2], name="small") 84 big = constant_op.constant([3, 4], name="big") 85 with ops.control_dependencies( 86 [check_ops.assert_equal( 87 big, small, message="fail")]): 88 out = array_ops.identity(small) 89 with self.assertRaisesOpError("fail.*big.*small"): 90 out.eval() 91 92 def test_raises_when_less(self): 93 with self.test_session(): 94 small = constant_op.constant([3, 1], name="small") 95 big = constant_op.constant([4, 2], name="big") 96 with ops.control_dependencies([check_ops.assert_equal(small, big)]): 97 out = array_ops.identity(small) 98 with self.assertRaisesOpError("small.*big"): 99 out.eval() 100 101 def test_doesnt_raise_when_equal_and_broadcastable_shapes(self): 102 with self.test_session(): 103 small = constant_op.constant([1, 2], name="small") 104 small_2 = constant_op.constant([1, 2], name="small_2") 105 with ops.control_dependencies([check_ops.assert_equal(small, small_2)]): 106 out = array_ops.identity(small) 107 out.eval() 108 109 def 
test_raises_when_equal_but_non_broadcastable_shapes(self): 110 with self.test_session(): 111 small = constant_op.constant([1, 1, 1], name="small") 112 small_2 = constant_op.constant([1, 1], name="small_2") 113 with self.assertRaisesRegexp(ValueError, "must be"): 114 with ops.control_dependencies([check_ops.assert_equal(small, small_2)]): 115 out = array_ops.identity(small) 116 out.eval() 117 118 def test_doesnt_raise_when_both_empty(self): 119 with self.test_session(): 120 larry = constant_op.constant([]) 121 curly = constant_op.constant([]) 122 with ops.control_dependencies([check_ops.assert_equal(larry, curly)]): 123 out = array_ops.identity(larry) 124 out.eval() 125 126 127class AssertNoneEqualTest(test.TestCase): 128 129 def test_doesnt_raise_when_not_equal(self): 130 with self.test_session(): 131 small = constant_op.constant([1, 2], name="small") 132 big = constant_op.constant([10, 20], name="small") 133 with ops.control_dependencies( 134 [check_ops.assert_none_equal(big, small)]): 135 out = array_ops.identity(small) 136 out.eval() 137 138 def test_raises_when_equal(self): 139 with self.test_session(): 140 small = constant_op.constant([3, 1], name="small") 141 with ops.control_dependencies( 142 [check_ops.assert_none_equal(small, small)]): 143 out = array_ops.identity(small) 144 with self.assertRaisesOpError("x != y did not hold"): 145 out.eval() 146 147 def test_doesnt_raise_when_not_equal_and_broadcastable_shapes(self): 148 with self.test_session(): 149 small = constant_op.constant([1, 2], name="small") 150 big = constant_op.constant([3], name="big") 151 with ops.control_dependencies( 152 [check_ops.assert_none_equal(small, big)]): 153 out = array_ops.identity(small) 154 out.eval() 155 156 def test_raises_when_not_equal_but_non_broadcastable_shapes(self): 157 with self.test_session(): 158 small = constant_op.constant([1, 1, 1], name="small") 159 big = constant_op.constant([10, 10], name="big") 160 with self.assertRaisesRegexp(ValueError, "must be"): 161 with 
ops.control_dependencies( 162 [check_ops.assert_none_equal(small, big)]): 163 out = array_ops.identity(small) 164 out.eval() 165 166 def test_doesnt_raise_when_both_empty(self): 167 with self.test_session(): 168 larry = constant_op.constant([]) 169 curly = constant_op.constant([]) 170 with ops.control_dependencies( 171 [check_ops.assert_none_equal(larry, curly)]): 172 out = array_ops.identity(larry) 173 out.eval() 174 175 176class AssertLessTest(test.TestCase): 177 178 def test_raises_when_equal(self): 179 with self.test_session(): 180 small = constant_op.constant([1, 2], name="small") 181 with ops.control_dependencies( 182 [check_ops.assert_less( 183 small, small, message="fail")]): 184 out = array_ops.identity(small) 185 with self.assertRaisesOpError("fail.*small.*small"): 186 out.eval() 187 188 def test_raises_when_greater(self): 189 with self.test_session(): 190 small = constant_op.constant([1, 2], name="small") 191 big = constant_op.constant([3, 4], name="big") 192 with ops.control_dependencies([check_ops.assert_less(big, small)]): 193 out = array_ops.identity(small) 194 with self.assertRaisesOpError("big.*small"): 195 out.eval() 196 197 def test_doesnt_raise_when_less(self): 198 with self.test_session(): 199 small = constant_op.constant([3, 1], name="small") 200 big = constant_op.constant([4, 2], name="big") 201 with ops.control_dependencies([check_ops.assert_less(small, big)]): 202 out = array_ops.identity(small) 203 out.eval() 204 205 def test_doesnt_raise_when_less_and_broadcastable_shapes(self): 206 with self.test_session(): 207 small = constant_op.constant([1], name="small") 208 big = constant_op.constant([3, 2], name="big") 209 with ops.control_dependencies([check_ops.assert_less(small, big)]): 210 out = array_ops.identity(small) 211 out.eval() 212 213 def test_raises_when_less_but_non_broadcastable_shapes(self): 214 with self.test_session(): 215 small = constant_op.constant([1, 1, 1], name="small") 216 big = constant_op.constant([3, 2], name="big") 217 
with self.assertRaisesRegexp(ValueError, "must be"): 218 with ops.control_dependencies([check_ops.assert_less(small, big)]): 219 out = array_ops.identity(small) 220 out.eval() 221 222 def test_doesnt_raise_when_both_empty(self): 223 with self.test_session(): 224 larry = constant_op.constant([]) 225 curly = constant_op.constant([]) 226 with ops.control_dependencies([check_ops.assert_less(larry, curly)]): 227 out = array_ops.identity(larry) 228 out.eval() 229 230 231class AssertLessEqualTest(test.TestCase): 232 233 def test_doesnt_raise_when_equal(self): 234 with self.test_session(): 235 small = constant_op.constant([1, 2], name="small") 236 with ops.control_dependencies( 237 [check_ops.assert_less_equal(small, small)]): 238 out = array_ops.identity(small) 239 out.eval() 240 241 def test_raises_when_greater(self): 242 with self.test_session(): 243 small = constant_op.constant([1, 2], name="small") 244 big = constant_op.constant([3, 4], name="big") 245 with ops.control_dependencies( 246 [check_ops.assert_less_equal( 247 big, small, message="fail")]): 248 out = array_ops.identity(small) 249 with self.assertRaisesOpError("fail.*big.*small"): 250 out.eval() 251 252 def test_doesnt_raise_when_less_equal(self): 253 with self.test_session(): 254 small = constant_op.constant([1, 2], name="small") 255 big = constant_op.constant([3, 2], name="big") 256 with ops.control_dependencies([check_ops.assert_less_equal(small, big)]): 257 out = array_ops.identity(small) 258 out.eval() 259 260 def test_doesnt_raise_when_less_equal_and_broadcastable_shapes(self): 261 with self.test_session(): 262 small = constant_op.constant([1], name="small") 263 big = constant_op.constant([3, 1], name="big") 264 with ops.control_dependencies([check_ops.assert_less_equal(small, big)]): 265 out = array_ops.identity(small) 266 out.eval() 267 268 def test_raises_when_less_equal_but_non_broadcastable_shapes(self): 269 with self.test_session(): 270 small = constant_op.constant([1, 1, 1], name="small") 271 big 
= constant_op.constant([3, 1], name="big") 272 with self.assertRaisesRegexp(ValueError, "must be"): 273 with ops.control_dependencies( 274 [check_ops.assert_less_equal(small, big)]): 275 out = array_ops.identity(small) 276 out.eval() 277 278 def test_doesnt_raise_when_both_empty(self): 279 with self.test_session(): 280 larry = constant_op.constant([]) 281 curly = constant_op.constant([]) 282 with ops.control_dependencies( 283 [check_ops.assert_less_equal(larry, curly)]): 284 out = array_ops.identity(larry) 285 out.eval() 286 287 288class AssertGreaterTest(test.TestCase): 289 290 def test_raises_when_equal(self): 291 with self.test_session(): 292 small = constant_op.constant([1, 2], name="small") 293 with ops.control_dependencies( 294 [check_ops.assert_greater( 295 small, small, message="fail")]): 296 out = array_ops.identity(small) 297 with self.assertRaisesOpError("fail.*small.*small"): 298 out.eval() 299 300 def test_raises_when_less(self): 301 with self.test_session(): 302 small = constant_op.constant([1, 2], name="small") 303 big = constant_op.constant([3, 4], name="big") 304 with ops.control_dependencies([check_ops.assert_greater(small, big)]): 305 out = array_ops.identity(big) 306 with self.assertRaisesOpError("small.*big"): 307 out.eval() 308 309 def test_doesnt_raise_when_greater(self): 310 with self.test_session(): 311 small = constant_op.constant([3, 1], name="small") 312 big = constant_op.constant([4, 2], name="big") 313 with ops.control_dependencies([check_ops.assert_greater(big, small)]): 314 out = array_ops.identity(small) 315 out.eval() 316 317 def test_doesnt_raise_when_greater_and_broadcastable_shapes(self): 318 with self.test_session(): 319 small = constant_op.constant([1], name="small") 320 big = constant_op.constant([3, 2], name="big") 321 with ops.control_dependencies([check_ops.assert_greater(big, small)]): 322 out = array_ops.identity(small) 323 out.eval() 324 325 def test_raises_when_greater_but_non_broadcastable_shapes(self): 326 with 
self.test_session(): 327 small = constant_op.constant([1, 1, 1], name="small") 328 big = constant_op.constant([3, 2], name="big") 329 with self.assertRaisesRegexp(ValueError, "must be"): 330 with ops.control_dependencies([check_ops.assert_greater(big, small)]): 331 out = array_ops.identity(small) 332 out.eval() 333 334 def test_doesnt_raise_when_both_empty(self): 335 with self.test_session(): 336 larry = constant_op.constant([]) 337 curly = constant_op.constant([]) 338 with ops.control_dependencies([check_ops.assert_greater(larry, curly)]): 339 out = array_ops.identity(larry) 340 out.eval() 341 342 343class AssertGreaterEqualTest(test.TestCase): 344 345 def test_doesnt_raise_when_equal(self): 346 with self.test_session(): 347 small = constant_op.constant([1, 2], name="small") 348 with ops.control_dependencies( 349 [check_ops.assert_greater_equal(small, small)]): 350 out = array_ops.identity(small) 351 out.eval() 352 353 def test_raises_when_less(self): 354 with self.test_session(): 355 small = constant_op.constant([1, 2], name="small") 356 big = constant_op.constant([3, 4], name="big") 357 with ops.control_dependencies( 358 [check_ops.assert_greater_equal( 359 small, big, message="fail")]): 360 out = array_ops.identity(small) 361 with self.assertRaisesOpError("fail.*small.*big"): 362 out.eval() 363 364 def test_doesnt_raise_when_greater_equal(self): 365 with self.test_session(): 366 small = constant_op.constant([1, 2], name="small") 367 big = constant_op.constant([3, 2], name="big") 368 with ops.control_dependencies( 369 [check_ops.assert_greater_equal(big, small)]): 370 out = array_ops.identity(small) 371 out.eval() 372 373 def test_doesnt_raise_when_greater_equal_and_broadcastable_shapes(self): 374 with self.test_session(): 375 small = constant_op.constant([1], name="small") 376 big = constant_op.constant([3, 1], name="big") 377 with ops.control_dependencies( 378 [check_ops.assert_greater_equal(big, small)]): 379 out = array_ops.identity(small) 380 out.eval() 381 
382 def test_raises_when_less_equal_but_non_broadcastable_shapes(self): 383 with self.test_session(): 384 small = constant_op.constant([1, 1, 1], name="big") 385 big = constant_op.constant([3, 1], name="small") 386 with self.assertRaisesRegexp(ValueError, "Dimensions must be equal"): 387 with ops.control_dependencies( 388 [check_ops.assert_greater_equal(big, small)]): 389 out = array_ops.identity(small) 390 out.eval() 391 392 def test_doesnt_raise_when_both_empty(self): 393 with self.test_session(): 394 larry = constant_op.constant([]) 395 curly = constant_op.constant([]) 396 with ops.control_dependencies( 397 [check_ops.assert_greater_equal(larry, curly)]): 398 out = array_ops.identity(larry) 399 out.eval() 400 401 402class AssertNegativeTest(test.TestCase): 403 404 def test_doesnt_raise_when_negative(self): 405 with self.test_session(): 406 frank = constant_op.constant([-1, -2], name="frank") 407 with ops.control_dependencies([check_ops.assert_negative(frank)]): 408 out = array_ops.identity(frank) 409 out.eval() 410 411 def test_raises_when_positive(self): 412 with self.test_session(): 413 doug = constant_op.constant([1, 2], name="doug") 414 with ops.control_dependencies( 415 [check_ops.assert_negative( 416 doug, message="fail")]): 417 out = array_ops.identity(doug) 418 with self.assertRaisesOpError("fail.*doug"): 419 out.eval() 420 421 def test_raises_when_zero(self): 422 with self.test_session(): 423 claire = constant_op.constant([0], name="claire") 424 with ops.control_dependencies([check_ops.assert_negative(claire)]): 425 out = array_ops.identity(claire) 426 with self.assertRaisesOpError("claire"): 427 out.eval() 428 429 def test_empty_tensor_doesnt_raise(self): 430 # A tensor is negative when it satisfies: 431 # For every element x_i in x, x_i < 0 432 # and an empty tensor has no elements, so this is trivially satisfied. 433 # This is standard set theory. 
434 with self.test_session(): 435 empty = constant_op.constant([], name="empty") 436 with ops.control_dependencies([check_ops.assert_negative(empty)]): 437 out = array_ops.identity(empty) 438 out.eval() 439 440 441class AssertPositiveTest(test.TestCase): 442 443 def test_raises_when_negative(self): 444 with self.test_session(): 445 freddie = constant_op.constant([-1, -2], name="freddie") 446 with ops.control_dependencies( 447 [check_ops.assert_positive( 448 freddie, message="fail")]): 449 out = array_ops.identity(freddie) 450 with self.assertRaisesOpError("fail.*freddie"): 451 out.eval() 452 453 def test_doesnt_raise_when_positive(self): 454 with self.test_session(): 455 remmy = constant_op.constant([1, 2], name="remmy") 456 with ops.control_dependencies([check_ops.assert_positive(remmy)]): 457 out = array_ops.identity(remmy) 458 out.eval() 459 460 def test_raises_when_zero(self): 461 with self.test_session(): 462 meechum = constant_op.constant([0], name="meechum") 463 with ops.control_dependencies([check_ops.assert_positive(meechum)]): 464 out = array_ops.identity(meechum) 465 with self.assertRaisesOpError("meechum"): 466 out.eval() 467 468 def test_empty_tensor_doesnt_raise(self): 469 # A tensor is positive when it satisfies: 470 # For every element x_i in x, x_i > 0 471 # and an empty tensor has no elements, so this is trivially satisfied. 472 # This is standard set theory. 
473 with self.test_session(): 474 empty = constant_op.constant([], name="empty") 475 with ops.control_dependencies([check_ops.assert_positive(empty)]): 476 out = array_ops.identity(empty) 477 out.eval() 478 479 480class AssertRankTest(test.TestCase): 481 482 def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self): 483 with self.test_session(): 484 tensor = constant_op.constant(1, name="my_tensor") 485 desired_rank = 1 486 with self.assertRaisesRegexp(ValueError, 487 "fail.*my_tensor.*must have rank 1"): 488 with ops.control_dependencies( 489 [check_ops.assert_rank( 490 tensor, desired_rank, message="fail")]): 491 array_ops.identity(tensor).eval() 492 493 def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self): 494 with self.test_session(): 495 tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") 496 desired_rank = 1 497 with ops.control_dependencies( 498 [check_ops.assert_rank( 499 tensor, desired_rank, message="fail")]): 500 with self.assertRaisesOpError("fail.*my_tensor.*rank"): 501 array_ops.identity(tensor).eval(feed_dict={tensor: 0}) 502 503 def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_static_rank(self): 504 with self.test_session(): 505 tensor = constant_op.constant(1, name="my_tensor") 506 desired_rank = 0 507 with ops.control_dependencies( 508 [check_ops.assert_rank(tensor, desired_rank)]): 509 array_ops.identity(tensor).eval() 510 511 def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self): 512 with self.test_session(): 513 tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") 514 desired_rank = 0 515 with ops.control_dependencies( 516 [check_ops.assert_rank(tensor, desired_rank)]): 517 array_ops.identity(tensor).eval(feed_dict={tensor: 0}) 518 519 def test_rank_one_tensor_raises_if_rank_too_large_static_rank(self): 520 with self.test_session(): 521 tensor = constant_op.constant([1, 2], name="my_tensor") 522 desired_rank = 0 523 with self.assertRaisesRegexp(ValueError, 
"my_tensor.*rank"): 524 with ops.control_dependencies( 525 [check_ops.assert_rank(tensor, desired_rank)]): 526 array_ops.identity(tensor).eval() 527 528 def test_rank_one_tensor_raises_if_rank_too_large_dynamic_rank(self): 529 with self.test_session(): 530 tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") 531 desired_rank = 0 532 with ops.control_dependencies( 533 [check_ops.assert_rank(tensor, desired_rank)]): 534 with self.assertRaisesOpError("my_tensor.*rank"): 535 array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) 536 537 def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self): 538 with self.test_session(): 539 tensor = constant_op.constant([1, 2], name="my_tensor") 540 desired_rank = 1 541 with ops.control_dependencies( 542 [check_ops.assert_rank(tensor, desired_rank)]): 543 array_ops.identity(tensor).eval() 544 545 def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self): 546 with self.test_session(): 547 tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") 548 desired_rank = 1 549 with ops.control_dependencies( 550 [check_ops.assert_rank(tensor, desired_rank)]): 551 array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) 552 553 def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self): 554 with self.test_session(): 555 tensor = constant_op.constant([1, 2], name="my_tensor") 556 desired_rank = 2 557 with self.assertRaisesRegexp(ValueError, "my_tensor.*rank"): 558 with ops.control_dependencies( 559 [check_ops.assert_rank(tensor, desired_rank)]): 560 array_ops.identity(tensor).eval() 561 562 def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self): 563 with self.test_session(): 564 tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") 565 desired_rank = 2 566 with ops.control_dependencies( 567 [check_ops.assert_rank(tensor, desired_rank)]): 568 with self.assertRaisesOpError("my_tensor.*rank"): 569 array_ops.identity(tensor).eval(feed_dict={tensor: [1, 
2]}) 570 571 def test_raises_if_rank_is_not_scalar_static(self): 572 with self.test_session(): 573 tensor = constant_op.constant([1, 2], name="my_tensor") 574 with self.assertRaisesRegexp(ValueError, "Rank must be a scalar"): 575 check_ops.assert_rank(tensor, np.array([], dtype=np.int32)) 576 577 def test_raises_if_rank_is_not_scalar_dynamic(self): 578 with self.test_session(): 579 tensor = constant_op.constant( 580 [1, 2], dtype=dtypes.float32, name="my_tensor") 581 rank_tensor = array_ops.placeholder(dtypes.int32, name="rank_tensor") 582 with self.assertRaisesOpError("Rank must be a scalar"): 583 with ops.control_dependencies( 584 [check_ops.assert_rank(tensor, rank_tensor)]): 585 array_ops.identity(tensor).eval(feed_dict={rank_tensor: [1, 2]}) 586 587 def test_raises_if_rank_is_not_integer_static(self): 588 with self.test_session(): 589 tensor = constant_op.constant([1, 2], name="my_tensor") 590 with self.assertRaisesRegexp(TypeError, 591 "must be of type <dtype: 'int32'>"): 592 check_ops.assert_rank(tensor, .5) 593 594 def test_raises_if_rank_is_not_integer_dynamic(self): 595 with self.test_session(): 596 tensor = constant_op.constant( 597 [1, 2], dtype=dtypes.float32, name="my_tensor") 598 rank_tensor = array_ops.placeholder(dtypes.float32, name="rank_tensor") 599 with self.assertRaisesRegexp(TypeError, 600 "must be of type <dtype: 'int32'>"): 601 with ops.control_dependencies( 602 [check_ops.assert_rank(tensor, rank_tensor)]): 603 array_ops.identity(tensor).eval(feed_dict={rank_tensor: .5}) 604 605 606class AssertRankInTest(test.TestCase): 607 608 def test_rank_zero_tensor_raises_if_rank_mismatch_static_rank(self): 609 with self.test_session(): 610 tensor_rank0 = constant_op.constant(42, name="my_tensor") 611 with self.assertRaisesRegexp( 612 ValueError, "fail.*my_tensor.*must have rank.*in.*1.*2"): 613 with ops.control_dependencies([ 614 check_ops.assert_rank_in(tensor_rank0, (1, 2), message="fail")]): 615 array_ops.identity(tensor_rank0).eval() 616 617 def 
test_rank_zero_tensor_raises_if_rank_mismatch_dynamic_rank(self): 618 with self.test_session(): 619 tensor_rank0 = array_ops.placeholder(dtypes.float32, name="my_tensor") 620 with ops.control_dependencies([ 621 check_ops.assert_rank_in(tensor_rank0, (1, 2), message="fail")]): 622 with self.assertRaisesOpError("fail.*my_tensor.*rank"): 623 array_ops.identity(tensor_rank0).eval(feed_dict={tensor_rank0: 42.0}) 624 625 def test_rank_zero_tensor_doesnt_raise_if_rank_matches_static_rank(self): 626 with self.test_session(): 627 tensor_rank0 = constant_op.constant(42, name="my_tensor") 628 for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)): 629 with ops.control_dependencies([ 630 check_ops.assert_rank_in(tensor_rank0, desired_ranks)]): 631 array_ops.identity(tensor_rank0).eval() 632 633 def test_rank_zero_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self): 634 with self.test_session(): 635 tensor_rank0 = array_ops.placeholder(dtypes.float32, name="my_tensor") 636 for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)): 637 with ops.control_dependencies([ 638 check_ops.assert_rank_in(tensor_rank0, desired_ranks)]): 639 array_ops.identity(tensor_rank0).eval(feed_dict={tensor_rank0: 42.0}) 640 641 def test_rank_one_tensor_doesnt_raise_if_rank_matches_static_rank(self): 642 with self.test_session(): 643 tensor_rank1 = constant_op.constant([42, 43], name="my_tensor") 644 for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)): 645 with ops.control_dependencies([ 646 check_ops.assert_rank_in(tensor_rank1, desired_ranks)]): 647 array_ops.identity(tensor_rank1).eval() 648 649 def test_rank_one_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self): 650 with self.test_session(): 651 tensor_rank1 = array_ops.placeholder(dtypes.float32, name="my_tensor") 652 for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)): 653 with ops.control_dependencies([ 654 check_ops.assert_rank_in(tensor_rank1, desired_ranks)]): 655 array_ops.identity(tensor_rank1).eval(feed_dict={ 656 
tensor_rank1: (42.0, 43.0) 657 }) 658 659 def test_rank_one_tensor_raises_if_rank_mismatches_static_rank(self): 660 with self.test_session(): 661 tensor_rank1 = constant_op.constant((42, 43), name="my_tensor") 662 with self.assertRaisesRegexp(ValueError, "my_tensor.*rank"): 663 with ops.control_dependencies([ 664 check_ops.assert_rank_in(tensor_rank1, (0, 2))]): 665 array_ops.identity(tensor_rank1).eval() 666 667 def test_rank_one_tensor_raises_if_rank_mismatches_dynamic_rank(self): 668 with self.test_session(): 669 tensor_rank1 = array_ops.placeholder(dtypes.float32, name="my_tensor") 670 with ops.control_dependencies([ 671 check_ops.assert_rank_in(tensor_rank1, (0, 2))]): 672 with self.assertRaisesOpError("my_tensor.*rank"): 673 array_ops.identity(tensor_rank1).eval(feed_dict={ 674 tensor_rank1: (42.0, 43.0) 675 }) 676 677 def test_raises_if_rank_is_not_scalar_static(self): 678 with self.test_session(): 679 tensor = constant_op.constant((42, 43), name="my_tensor") 680 desired_ranks = ( 681 np.array(1, dtype=np.int32), 682 np.array((2, 1), dtype=np.int32)) 683 with self.assertRaisesRegexp(ValueError, "Rank must be a scalar"): 684 check_ops.assert_rank_in(tensor, desired_ranks) 685 686 def test_raises_if_rank_is_not_scalar_dynamic(self): 687 with self.test_session(): 688 tensor = constant_op.constant( 689 (42, 43), dtype=dtypes.float32, name="my_tensor") 690 desired_ranks = ( 691 array_ops.placeholder(dtypes.int32, name="rank0_tensor"), 692 array_ops.placeholder(dtypes.int32, name="rank1_tensor")) 693 with self.assertRaisesOpError("Rank must be a scalar"): 694 with ops.control_dependencies( 695 (check_ops.assert_rank_in(tensor, desired_ranks),)): 696 array_ops.identity(tensor).eval(feed_dict={ 697 desired_ranks[0]: 1, 698 desired_ranks[1]: [2, 1], 699 }) 700 701 def test_raises_if_rank_is_not_integer_static(self): 702 with self.test_session(): 703 tensor = constant_op.constant((42, 43), name="my_tensor") 704 with self.assertRaisesRegexp(TypeError, 705 "must be of 
type <dtype: 'int32'>"): 706 check_ops.assert_rank_in(tensor, (1, .5,)) 707 708 def test_raises_if_rank_is_not_integer_dynamic(self): 709 with self.test_session(): 710 tensor = constant_op.constant( 711 (42, 43), dtype=dtypes.float32, name="my_tensor") 712 rank_tensor = array_ops.placeholder(dtypes.float32, name="rank_tensor") 713 with self.assertRaisesRegexp(TypeError, 714 "must be of type <dtype: 'int32'>"): 715 with ops.control_dependencies( 716 [check_ops.assert_rank_in(tensor, (1, rank_tensor))]): 717 array_ops.identity(tensor).eval(feed_dict={rank_tensor: .5}) 718 719 720class AssertRankAtLeastTest(test.TestCase): 721 722 def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self): 723 with self.test_session(): 724 tensor = constant_op.constant(1, name="my_tensor") 725 desired_rank = 1 726 with self.assertRaisesRegexp(ValueError, "my_tensor.*rank at least 1"): 727 with ops.control_dependencies( 728 [check_ops.assert_rank_at_least(tensor, desired_rank)]): 729 array_ops.identity(tensor).eval() 730 731 def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self): 732 with self.test_session(): 733 tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") 734 desired_rank = 1 735 with ops.control_dependencies( 736 [check_ops.assert_rank_at_least(tensor, desired_rank)]): 737 with self.assertRaisesOpError("my_tensor.*rank"): 738 array_ops.identity(tensor).eval(feed_dict={tensor: 0}) 739 740 def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_static_rank(self): 741 with self.test_session(): 742 tensor = constant_op.constant(1, name="my_tensor") 743 desired_rank = 0 744 with ops.control_dependencies( 745 [check_ops.assert_rank_at_least(tensor, desired_rank)]): 746 array_ops.identity(tensor).eval() 747 748 def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self): 749 with self.test_session(): 750 tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") 751 desired_rank = 0 752 with ops.control_dependencies( 753 
[check_ops.assert_rank_at_least(tensor, desired_rank)]): 754 array_ops.identity(tensor).eval(feed_dict={tensor: 0}) 755 756 def test_rank_one_ten_doesnt_raise_raise_if_rank_too_large_static_rank(self): 757 with self.test_session(): 758 tensor = constant_op.constant([1, 2], name="my_tensor") 759 desired_rank = 0 760 with ops.control_dependencies( 761 [check_ops.assert_rank_at_least(tensor, desired_rank)]): 762 array_ops.identity(tensor).eval() 763 764 def test_rank_one_ten_doesnt_raise_if_rank_too_large_dynamic_rank(self): 765 with self.test_session(): 766 tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") 767 desired_rank = 0 768 with ops.control_dependencies( 769 [check_ops.assert_rank_at_least(tensor, desired_rank)]): 770 array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) 771 772 def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self): 773 with self.test_session(): 774 tensor = constant_op.constant([1, 2], name="my_tensor") 775 desired_rank = 1 776 with ops.control_dependencies( 777 [check_ops.assert_rank_at_least(tensor, desired_rank)]): 778 array_ops.identity(tensor).eval() 779 780 def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self): 781 with self.test_session(): 782 tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") 783 desired_rank = 1 784 with ops.control_dependencies( 785 [check_ops.assert_rank_at_least(tensor, desired_rank)]): 786 array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) 787 788 def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self): 789 with self.test_session(): 790 tensor = constant_op.constant([1, 2], name="my_tensor") 791 desired_rank = 2 792 with self.assertRaisesRegexp(ValueError, "my_tensor.*rank"): 793 with ops.control_dependencies( 794 [check_ops.assert_rank_at_least(tensor, desired_rank)]): 795 array_ops.identity(tensor).eval() 796 797 def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self): 798 with self.test_session(): 
799 tensor = array_ops.placeholder(dtypes.float32, name="my_tensor") 800 desired_rank = 2 801 with ops.control_dependencies( 802 [check_ops.assert_rank_at_least(tensor, desired_rank)]): 803 with self.assertRaisesOpError("my_tensor.*rank"): 804 array_ops.identity(tensor).eval(feed_dict={tensor: [1, 2]}) 805 806 807class AssertNonNegativeTest(test.TestCase): 808 809 def test_raises_when_negative(self): 810 with self.test_session(): 811 zoe = constant_op.constant([-1, -2], name="zoe") 812 with ops.control_dependencies([check_ops.assert_non_negative(zoe)]): 813 out = array_ops.identity(zoe) 814 with self.assertRaisesOpError("zoe"): 815 out.eval() 816 817 def test_doesnt_raise_when_zero_and_positive(self): 818 with self.test_session(): 819 lucas = constant_op.constant([0, 2], name="lucas") 820 with ops.control_dependencies([check_ops.assert_non_negative(lucas)]): 821 out = array_ops.identity(lucas) 822 out.eval() 823 824 def test_empty_tensor_doesnt_raise(self): 825 # A tensor is non-negative when it satisfies: 826 # For every element x_i in x, x_i >= 0 827 # and an empty tensor has no elements, so this is trivially satisfied. 828 # This is standard set theory. 
829 with self.test_session(): 830 empty = constant_op.constant([], name="empty") 831 with ops.control_dependencies([check_ops.assert_non_negative(empty)]): 832 out = array_ops.identity(empty) 833 out.eval() 834 835 836class AssertNonPositiveTest(test.TestCase): 837 838 def test_doesnt_raise_when_zero_and_negative(self): 839 with self.test_session(): 840 tom = constant_op.constant([0, -2], name="tom") 841 with ops.control_dependencies([check_ops.assert_non_positive(tom)]): 842 out = array_ops.identity(tom) 843 out.eval() 844 845 def test_raises_when_positive(self): 846 with self.test_session(): 847 rachel = constant_op.constant([0, 2], name="rachel") 848 with ops.control_dependencies([check_ops.assert_non_positive(rachel)]): 849 out = array_ops.identity(rachel) 850 with self.assertRaisesOpError("rachel"): 851 out.eval() 852 853 def test_empty_tensor_doesnt_raise(self): 854 # A tensor is non-positive when it satisfies: 855 # For every element x_i in x, x_i <= 0 856 # and an empty tensor has no elements, so this is trivially satisfied. 857 # This is standard set theory. 
858 with self.test_session(): 859 empty = constant_op.constant([], name="empty") 860 with ops.control_dependencies([check_ops.assert_non_positive(empty)]): 861 out = array_ops.identity(empty) 862 out.eval() 863 864 865class AssertIntegerTest(test.TestCase): 866 867 def test_doesnt_raise_when_integer(self): 868 with self.test_session(): 869 integers = constant_op.constant([1, 2], name="integers") 870 with ops.control_dependencies([check_ops.assert_integer(integers)]): 871 out = array_ops.identity(integers) 872 out.eval() 873 874 def test_raises_when_float(self): 875 with self.test_session(): 876 floats = constant_op.constant([1.0, 2.0], name="floats") 877 with self.assertRaisesRegexp(TypeError, "Expected.*integer"): 878 check_ops.assert_integer(floats) 879 880 881class IsStrictlyIncreasingTest(test.TestCase): 882 883 def test_constant_tensor_is_not_strictly_increasing(self): 884 with self.test_session(): 885 self.assertFalse(check_ops.is_strictly_increasing([1, 1, 1]).eval()) 886 887 def test_decreasing_tensor_is_not_strictly_increasing(self): 888 with self.test_session(): 889 self.assertFalse(check_ops.is_strictly_increasing([1, 0, -1]).eval()) 890 891 def test_2d_decreasing_tensor_is_not_strictly_increasing(self): 892 with self.test_session(): 893 self.assertFalse( 894 check_ops.is_strictly_increasing([[1, 3], [2, 4]]).eval()) 895 896 def test_increasing_tensor_is_increasing(self): 897 with self.test_session(): 898 self.assertTrue(check_ops.is_strictly_increasing([1, 2, 3]).eval()) 899 900 def test_increasing_rank_two_tensor(self): 901 with self.test_session(): 902 self.assertTrue( 903 check_ops.is_strictly_increasing([[-1, 2], [3, 4]]).eval()) 904 905 def test_tensor_with_one_element_is_strictly_increasing(self): 906 with self.test_session(): 907 self.assertTrue(check_ops.is_strictly_increasing([1]).eval()) 908 909 def test_empty_tensor_is_strictly_increasing(self): 910 with self.test_session(): 911 self.assertTrue(check_ops.is_strictly_increasing([]).eval()) 912 
class IsNonDecreasingTest(test.TestCase):
  """Checks `check_ops.is_non_decreasing` on monotone and edge-case inputs.

  Each case builds the op from a Python list, evaluates it inside a test
  session and asserts on the resulting boolean.
  """

  def _evaluate(self, values):
    """Returns the evaluated result of `check_ops.is_non_decreasing(values)`."""
    with self.test_session():
      return check_ops.is_non_decreasing(values).eval()

  def test_constant_tensor_is_non_decreasing(self):
    # Equal neighbours are allowed: non-decreasing, not strictly increasing.
    self.assertTrue(self._evaluate([1, 1, 1]))

  def test_decreasing_tensor_is_not_non_decreasing(self):
    self.assertFalse(self._evaluate([3, 2, 1]))

  def test_2d_decreasing_tensor_is_not_non_decreasing(self):
    # Rows and columns both increase; only the row-major flattened order
    # [1, 3, 2, 4] has a drop (3 -> 2).
    self.assertFalse(self._evaluate([[1, 3], [2, 4]]))

  def test_increasing_rank_one_tensor_is_non_decreasing(self):
    self.assertTrue(self._evaluate([1, 2, 3]))

  def test_increasing_rank_two_tensor(self):
    self.assertTrue(self._evaluate([[-1, 2], [3, 3]]))

  def test_tensor_with_one_element_is_non_decreasing(self):
    self.assertTrue(self._evaluate([1]))

  def test_empty_tensor_is_non_decreasing(self):
    self.assertTrue(self._evaluate([]))


if __name__ == "__main__":
  test.main()