check_ops_test.py revision 8043a27ed77f59bb68409070f2bfa01df0e04b89
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.ops.check_ops."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import sparse_tensor
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import check_ops
29from tensorflow.python.platform import test
30
31
class AssertProperIterableTest(test.TestCase):
  """Tests for check_ops.assert_proper_iterable.

  Single values that merely support iteration (a Tensor, SparseTensor, numpy
  ndarray or string) must be rejected with a TypeError mentioning "proper";
  non-iterables must be rejected with a TypeError mentioning "to be iterable";
  lists, tuples and generators must be accepted.
  """

  def test_single_tensor_raises(self):
    tensor = constant_op.constant(1)
    with self.assertRaisesRegexp(TypeError, "proper"):
      check_ops.assert_proper_iterable(tensor)

  def test_single_sparse_tensor_raises(self):
    ten = sparse_tensor.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
    with self.assertRaisesRegexp(TypeError, "proper"):
      check_ops.assert_proper_iterable(ten)

  def test_single_ndarray_raises(self):
    array = np.array([1, 2, 3])
    with self.assertRaisesRegexp(TypeError, "proper"):
      check_ops.assert_proper_iterable(array)

  def test_single_string_raises(self):
    mystr = "hello"
    with self.assertRaisesRegexp(TypeError, "proper"):
      check_ops.assert_proper_iterable(mystr)

  def test_non_iterable_object_raises(self):
    non_iterable = 1234
    with self.assertRaisesRegexp(TypeError, "to be iterable"):
      check_ops.assert_proper_iterable(non_iterable)

  def test_list_does_not_raise(self):
    list_of_stuff = [
        constant_op.constant([11, 22]), constant_op.constant([1, 2])
    ]
    check_ops.assert_proper_iterable(list_of_stuff)

  def test_tuple_does_not_raise(self):
    tuple_of_stuff = (constant_op.constant([11, 22]),
                      constant_op.constant([1, 2]))
    check_ops.assert_proper_iterable(tuple_of_stuff)

  def test_generator_does_not_raise(self):
    # A parenthesized pair of values is a tuple, not a generator; build a real
    # generator expression so that generator inputs are actually exercised.
    tensors = [constant_op.constant([11, 22]), constant_op.constant([1, 2])]
    generator_of_stuff = (t for t in tensors)
    check_ops.assert_proper_iterable(generator_of_stuff)
70
71
class AssertEqualTest(test.TestCase):
  """Tests for check_ops.assert_equal, which requires x == y element-wise."""

  def test_doesnt_raise_when_equal(self):
    with self.test_session():
      small = constant_op.constant([1, 2], name="small")
      with ops.control_dependencies([check_ops.assert_equal(small, small)]):
        out = array_ops.identity(small)
      out.eval()

  def test_raises_when_greater(self):
    with self.test_session():
      small = constant_op.constant([1, 2], name="small")
      big = constant_op.constant([3, 4], name="big")
      with ops.control_dependencies(
          [check_ops.assert_equal(
              big, small, message="fail")]):
        out = array_ops.identity(small)
      with self.assertRaisesOpError("fail.*big.*small"):
        out.eval()

  def test_raises_when_less(self):
    with self.test_session():
      small = constant_op.constant([3, 1], name="small")
      big = constant_op.constant([4, 2], name="big")
      with ops.control_dependencies([check_ops.assert_equal(small, big)]):
        out = array_ops.identity(small)
      with self.assertRaisesOpError("small.*big"):
        out.eval()

  def test_doesnt_raise_when_equal_and_broadcastable_shapes(self):
    with self.test_session():
      # The shapes differ ([2, 2] vs. [2]) but broadcast together, and every
      # broadcast pair of elements is equal, so the assertion must pass.
      small = constant_op.constant([[1, 2], [1, 2]], name="small")
      small_2 = constant_op.constant([1, 2], name="small_2")
      with ops.control_dependencies([check_ops.assert_equal(small, small_2)]):
        out = array_ops.identity(small)
      out.eval()

  def test_raises_when_equal_but_non_broadcastable_shapes(self):
    with self.test_session():
      # Shapes [3] and [2] cannot be broadcast together; this mismatch is
      # expected to surface as a ValueError.
      small = constant_op.constant([1, 1, 1], name="small")
      small_2 = constant_op.constant([1, 1], name="small_2")
      with self.assertRaisesRegexp(ValueError, "must be"):
        with ops.control_dependencies([check_ops.assert_equal(small, small_2)]):
          out = array_ops.identity(small)
        out.eval()

  def test_doesnt_raise_when_both_empty(self):
    with self.test_session():
      larry = constant_op.constant([])
      curly = constant_op.constant([])
      with ops.control_dependencies([check_ops.assert_equal(larry, curly)]):
        out = array_ops.identity(larry)
      out.eval()
125
126
class AssertNoneEqualTest(test.TestCase):
  """Tests for check_ops.assert_none_equal, which requires x != y everywhere."""

  def test_doesnt_raise_when_not_equal(self):
    with self.test_session():
      small = constant_op.constant([1, 2], name="small")
      big = constant_op.constant([10, 20], name="big")
      with ops.control_dependencies(
          [check_ops.assert_none_equal(big, small)]):
        out = array_ops.identity(small)
      out.eval()

  def test_raises_when_equal(self):
    with self.test_session():
      small = constant_op.constant([3, 1], name="small")
      with ops.control_dependencies(
          [check_ops.assert_none_equal(small, small)]):
        out = array_ops.identity(small)
      with self.assertRaisesOpError("x != y did not hold"):
        out.eval()

  def test_doesnt_raise_when_not_equal_and_broadcastable_shapes(self):
    with self.test_session():
      small = constant_op.constant([1, 2], name="small")
      big = constant_op.constant([3], name="big")
      with ops.control_dependencies(
          [check_ops.assert_none_equal(small, big)]):
        out = array_ops.identity(small)
      out.eval()

  def test_raises_when_not_equal_but_non_broadcastable_shapes(self):
    with self.test_session():
      # Shapes [3] and [2] cannot be broadcast together.
      small = constant_op.constant([1, 1, 1], name="small")
      big = constant_op.constant([10, 10], name="big")
      with self.assertRaisesRegexp(ValueError, "must be"):
        with ops.control_dependencies(
            [check_ops.assert_none_equal(small, big)]):
          out = array_ops.identity(small)
        out.eval()

  def test_doesnt_raise_when_both_empty(self):
    with self.test_session():
      larry = constant_op.constant([])
      curly = constant_op.constant([])
      with ops.control_dependencies(
          [check_ops.assert_none_equal(larry, curly)]):
        out = array_ops.identity(larry)
      out.eval()
174
175
class AssertLessTest(test.TestCase):
  """Tests for check_ops.assert_less, which requires x < y element-wise."""

  def _checked_identity(self, assertion, tensor):
    """Returns identity(tensor) carrying a control dependency on `assertion`."""
    with ops.control_dependencies([assertion]):
      return array_ops.identity(tensor)

  def test_raises_when_equal(self):
    with self.test_session():
      lo = constant_op.constant([1, 2], name="small")
      checked = self._checked_identity(
          check_ops.assert_less(lo, lo, message="fail"), lo)
      with self.assertRaisesOpError("fail.*small.*small"):
        checked.eval()

  def test_raises_when_greater(self):
    with self.test_session():
      lo = constant_op.constant([1, 2], name="small")
      hi = constant_op.constant([3, 4], name="big")
      checked = self._checked_identity(check_ops.assert_less(hi, lo), lo)
      with self.assertRaisesOpError("big.*small"):
        checked.eval()

  def test_doesnt_raise_when_less(self):
    with self.test_session():
      lo = constant_op.constant([3, 1], name="small")
      hi = constant_op.constant([4, 2], name="big")
      self._checked_identity(check_ops.assert_less(lo, hi), lo).eval()

  def test_doesnt_raise_when_less_and_broadcastable_shapes(self):
    with self.test_session():
      lo = constant_op.constant([1], name="small")
      hi = constant_op.constant([3, 2], name="big")
      self._checked_identity(check_ops.assert_less(lo, hi), lo).eval()

  def test_raises_when_less_but_non_broadcastable_shapes(self):
    with self.test_session():
      lo = constant_op.constant([1, 1, 1], name="small")
      hi = constant_op.constant([3, 2], name="big")
      # Shapes [3] and [2] cannot be broadcast together.
      with self.assertRaisesRegexp(ValueError, "must be"):
        self._checked_identity(check_ops.assert_less(lo, hi), lo).eval()

  def test_doesnt_raise_when_both_empty(self):
    with self.test_session():
      empty_x = constant_op.constant([])
      empty_y = constant_op.constant([])
      checked = self._checked_identity(
          check_ops.assert_less(empty_x, empty_y), empty_x)
      checked.eval()
229
230
class AssertLessEqualTest(test.TestCase):
  """Tests for check_ops.assert_less_equal, which requires x <= y."""

  def _checked_identity(self, assertion, tensor):
    """Returns identity(tensor) carrying a control dependency on `assertion`."""
    with ops.control_dependencies([assertion]):
      return array_ops.identity(tensor)

  def test_doesnt_raise_when_equal(self):
    with self.test_session():
      lo = constant_op.constant([1, 2], name="small")
      checked = self._checked_identity(
          check_ops.assert_less_equal(lo, lo), lo)
      checked.eval()

  def test_raises_when_greater(self):
    with self.test_session():
      lo = constant_op.constant([1, 2], name="small")
      hi = constant_op.constant([3, 4], name="big")
      checked = self._checked_identity(
          check_ops.assert_less_equal(hi, lo, message="fail"), lo)
      with self.assertRaisesOpError("fail.*big.*small"):
        checked.eval()

  def test_doesnt_raise_when_less_equal(self):
    with self.test_session():
      lo = constant_op.constant([1, 2], name="small")
      hi = constant_op.constant([3, 2], name="big")
      self._checked_identity(check_ops.assert_less_equal(lo, hi), lo).eval()

  def test_doesnt_raise_when_less_equal_and_broadcastable_shapes(self):
    with self.test_session():
      lo = constant_op.constant([1], name="small")
      hi = constant_op.constant([3, 1], name="big")
      self._checked_identity(check_ops.assert_less_equal(lo, hi), lo).eval()

  def test_raises_when_less_equal_but_non_broadcastable_shapes(self):
    with self.test_session():
      lo = constant_op.constant([1, 1, 1], name="small")
      hi = constant_op.constant([3, 1], name="big")
      # Shapes [3] and [2] cannot be broadcast together.
      with self.assertRaisesRegexp(ValueError, "must be"):
        checked = self._checked_identity(
            check_ops.assert_less_equal(lo, hi), lo)
        checked.eval()

  def test_doesnt_raise_when_both_empty(self):
    with self.test_session():
      empty_x = constant_op.constant([])
      empty_y = constant_op.constant([])
      checked = self._checked_identity(
          check_ops.assert_less_equal(empty_x, empty_y), empty_x)
      checked.eval()
286
287
class AssertGreaterTest(test.TestCase):
  """Tests for check_ops.assert_greater, which requires x > y element-wise."""

  def _checked_identity(self, assertion, tensor):
    """Returns identity(tensor) carrying a control dependency on `assertion`."""
    with ops.control_dependencies([assertion]):
      return array_ops.identity(tensor)

  def test_raises_when_equal(self):
    with self.test_session():
      lo = constant_op.constant([1, 2], name="small")
      checked = self._checked_identity(
          check_ops.assert_greater(lo, lo, message="fail"), lo)
      with self.assertRaisesOpError("fail.*small.*small"):
        checked.eval()

  def test_raises_when_less(self):
    with self.test_session():
      lo = constant_op.constant([1, 2], name="small")
      hi = constant_op.constant([3, 4], name="big")
      checked = self._checked_identity(check_ops.assert_greater(lo, hi), hi)
      with self.assertRaisesOpError("small.*big"):
        checked.eval()

  def test_doesnt_raise_when_greater(self):
    with self.test_session():
      lo = constant_op.constant([3, 1], name="small")
      hi = constant_op.constant([4, 2], name="big")
      self._checked_identity(check_ops.assert_greater(hi, lo), lo).eval()

  def test_doesnt_raise_when_greater_and_broadcastable_shapes(self):
    with self.test_session():
      lo = constant_op.constant([1], name="small")
      hi = constant_op.constant([3, 2], name="big")
      self._checked_identity(check_ops.assert_greater(hi, lo), lo).eval()

  def test_raises_when_greater_but_non_broadcastable_shapes(self):
    with self.test_session():
      lo = constant_op.constant([1, 1, 1], name="small")
      hi = constant_op.constant([3, 2], name="big")
      # Shapes [2] and [3] cannot be broadcast together.
      with self.assertRaisesRegexp(ValueError, "must be"):
        self._checked_identity(check_ops.assert_greater(hi, lo), lo).eval()

  def test_doesnt_raise_when_both_empty(self):
    with self.test_session():
      empty_x = constant_op.constant([])
      empty_y = constant_op.constant([])
      checked = self._checked_identity(
          check_ops.assert_greater(empty_x, empty_y), empty_x)
      checked.eval()
341
342
class AssertGreaterEqualTest(test.TestCase):
  """Tests for check_ops.assert_greater_equal, which requires x >= y."""

  def test_doesnt_raise_when_equal(self):
    with self.test_session():
      small = constant_op.constant([1, 2], name="small")
      with ops.control_dependencies(
          [check_ops.assert_greater_equal(small, small)]):
        out = array_ops.identity(small)
      out.eval()

  def test_raises_when_less(self):
    with self.test_session():
      small = constant_op.constant([1, 2], name="small")
      big = constant_op.constant([3, 4], name="big")
      with ops.control_dependencies(
          [check_ops.assert_greater_equal(
              small, big, message="fail")]):
        out = array_ops.identity(small)
      with self.assertRaisesOpError("fail.*small.*big"):
        out.eval()

  def test_doesnt_raise_when_greater_equal(self):
    with self.test_session():
      small = constant_op.constant([1, 2], name="small")
      big = constant_op.constant([3, 2], name="big")
      with ops.control_dependencies(
          [check_ops.assert_greater_equal(big, small)]):
        out = array_ops.identity(small)
      out.eval()

  def test_doesnt_raise_when_greater_equal_and_broadcastable_shapes(self):
    with self.test_session():
      small = constant_op.constant([1], name="small")
      big = constant_op.constant([3, 1], name="big")
      with ops.control_dependencies(
          [check_ops.assert_greater_equal(big, small)]):
        out = array_ops.identity(small)
      out.eval()

  def test_raises_when_less_equal_but_non_broadcastable_shapes(self):
    # Despite "less_equal" in its name, this checks that assert_greater_equal
    # rejects operands whose shapes ([2] and [3]) cannot be broadcast together.
    with self.test_session():
      small = constant_op.constant([1, 1, 1], name="small")
      big = constant_op.constant([3, 1], name="big")
      with self.assertRaisesRegexp(ValueError, "Dimensions must be equal"):
        with ops.control_dependencies(
            [check_ops.assert_greater_equal(big, small)]):
          out = array_ops.identity(small)
        out.eval()

  def test_doesnt_raise_when_both_empty(self):
    with self.test_session():
      larry = constant_op.constant([])
      curly = constant_op.constant([])
      with ops.control_dependencies(
          [check_ops.assert_greater_equal(larry, curly)]):
        out = array_ops.identity(larry)
      out.eval()
400
401
class AssertNegativeTest(test.TestCase):
  """Tests for check_ops.assert_negative (every element must be < 0)."""

  def _checked_identity(self, assertion, tensor):
    """Returns identity(tensor) carrying a control dependency on `assertion`."""
    with ops.control_dependencies([assertion]):
      return array_ops.identity(tensor)

  def test_doesnt_raise_when_negative(self):
    with self.test_session():
      negatives = constant_op.constant([-1, -2], name="frank")
      checked = self._checked_identity(
          check_ops.assert_negative(negatives), negatives)
      checked.eval()

  def test_raises_when_positive(self):
    with self.test_session():
      positives = constant_op.constant([1, 2], name="doug")
      checked = self._checked_identity(
          check_ops.assert_negative(positives, message="fail"), positives)
      with self.assertRaisesOpError("fail.*doug"):
        checked.eval()

  def test_raises_when_zero(self):
    with self.test_session():
      zeros = constant_op.constant([0], name="claire")
      checked = self._checked_identity(check_ops.assert_negative(zeros), zeros)
      with self.assertRaisesOpError("claire"):
        checked.eval()

  def test_empty_tensor_doesnt_raise(self):
    # "Every element x_i satisfies x_i < 0" is vacuously true when there are
    # no elements, so an empty tensor must pass the check.
    with self.test_session():
      empty = constant_op.constant([], name="empty")
      self._checked_identity(check_ops.assert_negative(empty), empty).eval()
439
440
class AssertPositiveTest(test.TestCase):
  """Tests for check_ops.assert_positive (every element must be > 0)."""

  def _checked_identity(self, assertion, tensor):
    """Returns identity(tensor) carrying a control dependency on `assertion`."""
    with ops.control_dependencies([assertion]):
      return array_ops.identity(tensor)

  def test_raises_when_negative(self):
    with self.test_session():
      negatives = constant_op.constant([-1, -2], name="freddie")
      checked = self._checked_identity(
          check_ops.assert_positive(negatives, message="fail"), negatives)
      with self.assertRaisesOpError("fail.*freddie"):
        checked.eval()

  def test_doesnt_raise_when_positive(self):
    with self.test_session():
      positives = constant_op.constant([1, 2], name="remmy")
      checked = self._checked_identity(
          check_ops.assert_positive(positives), positives)
      checked.eval()

  def test_raises_when_zero(self):
    with self.test_session():
      zeros = constant_op.constant([0], name="meechum")
      checked = self._checked_identity(check_ops.assert_positive(zeros), zeros)
      with self.assertRaisesOpError("meechum"):
        checked.eval()

  def test_empty_tensor_doesnt_raise(self):
    # "Every element x_i satisfies x_i > 0" is vacuously true when there are
    # no elements, so an empty tensor must pass the check.
    with self.test_session():
      empty = constant_op.constant([], name="empty")
      self._checked_identity(check_ops.assert_positive(empty), empty).eval()
478
479
class AssertRankTest(test.TestCase):
  """Tests for check_ops.assert_rank (the tensor must have an exact rank).

  Most scenarios appear twice: once with a constant whose rank is known while
  the graph is built (failures are expected as ValueError), and once with an
  unknown-rank placeholder (failures are expected as op errors when the
  guarded tensor is evaluated).
  """

  def _checked_identity(self, assertion, tensor):
    """Returns identity(tensor) carrying a control dependency on `assertion`."""
    with ops.control_dependencies([assertion]):
      return array_ops.identity(tensor)

  def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self):
    with self.test_session():
      scalar = constant_op.constant(1, name="my_tensor")
      with self.assertRaisesRegexp(ValueError,
                                   "fail.*my_tensor.*must have rank 1"):
        rank_check = check_ops.assert_rank(scalar, 1, message="fail")
        self._checked_identity(rank_check, scalar).eval()

  def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self):
    with self.test_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      rank_check = check_ops.assert_rank(fed, 1, message="fail")
      checked = self._checked_identity(rank_check, fed)
      with self.assertRaisesOpError("fail.*my_tensor.*rank"):
        checked.eval(feed_dict={fed: 0})

  def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
    with self.test_session():
      scalar = constant_op.constant(1, name="my_tensor")
      rank_check = check_ops.assert_rank(scalar, 0)
      self._checked_identity(rank_check, scalar).eval()

  def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
    with self.test_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      rank_check = check_ops.assert_rank(fed, 0)
      self._checked_identity(rank_check, fed).eval(feed_dict={fed: 0})

  def test_rank_one_tensor_raises_if_rank_too_large_static_rank(self):
    with self.test_session():
      vector = constant_op.constant([1, 2], name="my_tensor")
      with self.assertRaisesRegexp(ValueError, "my_tensor.*rank"):
        rank_check = check_ops.assert_rank(vector, 0)
        self._checked_identity(rank_check, vector).eval()

  def test_rank_one_tensor_raises_if_rank_too_large_dynamic_rank(self):
    with self.test_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      rank_check = check_ops.assert_rank(fed, 0)
      checked = self._checked_identity(rank_check, fed)
      with self.assertRaisesOpError("my_tensor.*rank"):
        checked.eval(feed_dict={fed: [1, 2]})

  def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
    with self.test_session():
      vector = constant_op.constant([1, 2], name="my_tensor")
      rank_check = check_ops.assert_rank(vector, 1)
      self._checked_identity(rank_check, vector).eval()

  def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
    with self.test_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      rank_check = check_ops.assert_rank(fed, 1)
      self._checked_identity(rank_check, fed).eval(feed_dict={fed: [1, 2]})

  def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self):
    with self.test_session():
      vector = constant_op.constant([1, 2], name="my_tensor")
      with self.assertRaisesRegexp(ValueError, "my_tensor.*rank"):
        rank_check = check_ops.assert_rank(vector, 2)
        self._checked_identity(rank_check, vector).eval()

  def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self):
    with self.test_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      rank_check = check_ops.assert_rank(fed, 2)
      checked = self._checked_identity(rank_check, fed)
      with self.assertRaisesOpError("my_tensor.*rank"):
        checked.eval(feed_dict={fed: [1, 2]})

  def test_raises_if_rank_is_not_scalar_static(self):
    with self.test_session():
      vector = constant_op.constant([1, 2], name="my_tensor")
      with self.assertRaisesRegexp(ValueError, "Rank must be a scalar"):
        check_ops.assert_rank(vector, np.array([], dtype=np.int32))

  def test_raises_if_rank_is_not_scalar_dynamic(self):
    with self.test_session():
      vector = constant_op.constant(
          [1, 2], dtype=dtypes.float32, name="my_tensor")
      rank_fed = array_ops.placeholder(dtypes.int32, name="rank_tensor")
      with self.assertRaisesOpError("Rank must be a scalar"):
        rank_check = check_ops.assert_rank(vector, rank_fed)
        self._checked_identity(rank_check, vector).eval(
            feed_dict={rank_fed: [1, 2]})

  def test_raises_if_rank_is_not_integer_static(self):
    with self.test_session():
      vector = constant_op.constant([1, 2], name="my_tensor")
      with self.assertRaisesRegexp(TypeError,
                                   "must be of type <dtype: 'int32'>"):
        check_ops.assert_rank(vector, .5)

  def test_raises_if_rank_is_not_integer_dynamic(self):
    with self.test_session():
      vector = constant_op.constant(
          [1, 2], dtype=dtypes.float32, name="my_tensor")
      rank_fed = array_ops.placeholder(dtypes.float32, name="rank_tensor")
      with self.assertRaisesRegexp(TypeError,
                                   "must be of type <dtype: 'int32'>"):
        rank_check = check_ops.assert_rank(vector, rank_fed)
        self._checked_identity(rank_check, vector).eval(
            feed_dict={rank_fed: .5})
604
605
class AssertRankInTest(test.TestCase):
  """Tests for check_ops.assert_rank_in (rank must be one of several values)."""

  # Each ordering places the matching rank (0 or 1) at a different position,
  # so membership is checked regardless of where the match occurs.
  _RANK_ORDERINGS = ((0, 1, 2), (1, 0, 2), (1, 2, 0))

  def _checked_identity(self, assertion, tensor):
    """Returns identity(tensor) carrying a control dependency on `assertion`."""
    with ops.control_dependencies([assertion]):
      return array_ops.identity(tensor)

  def test_rank_zero_tensor_raises_if_rank_mismatch_static_rank(self):
    with self.test_session():
      scalar = constant_op.constant(42, name="my_tensor")
      with self.assertRaisesRegexp(
          ValueError, "fail.*my_tensor.*must have rank.*in.*1.*2"):
        rank_check = check_ops.assert_rank_in(scalar, (1, 2), message="fail")
        self._checked_identity(rank_check, scalar).eval()

  def test_rank_zero_tensor_raises_if_rank_mismatch_dynamic_rank(self):
    with self.test_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      rank_check = check_ops.assert_rank_in(fed, (1, 2), message="fail")
      checked = self._checked_identity(rank_check, fed)
      with self.assertRaisesOpError("fail.*my_tensor.*rank"):
        checked.eval(feed_dict={fed: 42.0})

  def test_rank_zero_tensor_doesnt_raise_if_rank_matches_static_rank(self):
    with self.test_session():
      scalar = constant_op.constant(42, name="my_tensor")
      for allowed_ranks in self._RANK_ORDERINGS:
        rank_check = check_ops.assert_rank_in(scalar, allowed_ranks)
        self._checked_identity(rank_check, scalar).eval()

  def test_rank_zero_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self):
    with self.test_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      for allowed_ranks in self._RANK_ORDERINGS:
        rank_check = check_ops.assert_rank_in(fed, allowed_ranks)
        self._checked_identity(rank_check, fed).eval(feed_dict={fed: 42.0})

  def test_rank_one_tensor_doesnt_raise_if_rank_matches_static_rank(self):
    with self.test_session():
      vector = constant_op.constant([42, 43], name="my_tensor")
      for allowed_ranks in self._RANK_ORDERINGS:
        rank_check = check_ops.assert_rank_in(vector, allowed_ranks)
        self._checked_identity(rank_check, vector).eval()

  def test_rank_one_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self):
    with self.test_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      for allowed_ranks in self._RANK_ORDERINGS:
        rank_check = check_ops.assert_rank_in(fed, allowed_ranks)
        self._checked_identity(rank_check, fed).eval(feed_dict={
            fed: (42.0, 43.0)
        })

  def test_rank_one_tensor_raises_if_rank_mismatches_static_rank(self):
    with self.test_session():
      vector = constant_op.constant((42, 43), name="my_tensor")
      with self.assertRaisesRegexp(ValueError, "my_tensor.*rank"):
        rank_check = check_ops.assert_rank_in(vector, (0, 2))
        self._checked_identity(rank_check, vector).eval()

  def test_rank_one_tensor_raises_if_rank_mismatches_dynamic_rank(self):
    with self.test_session():
      fed = array_ops.placeholder(dtypes.float32, name="my_tensor")
      rank_check = check_ops.assert_rank_in(fed, (0, 2))
      checked = self._checked_identity(rank_check, fed)
      with self.assertRaisesOpError("my_tensor.*rank"):
        checked.eval(feed_dict={fed: (42.0, 43.0)})

  def test_raises_if_rank_is_not_scalar_static(self):
    with self.test_session():
      vector = constant_op.constant((42, 43), name="my_tensor")
      allowed_ranks = (np.array(1, dtype=np.int32),
                       np.array((2, 1), dtype=np.int32))
      with self.assertRaisesRegexp(ValueError, "Rank must be a scalar"):
        check_ops.assert_rank_in(vector, allowed_ranks)

  def test_raises_if_rank_is_not_scalar_dynamic(self):
    with self.test_session():
      vector = constant_op.constant(
          (42, 43), dtype=dtypes.float32, name="my_tensor")
      rank0_fed = array_ops.placeholder(dtypes.int32, name="rank0_tensor")
      rank1_fed = array_ops.placeholder(dtypes.int32, name="rank1_tensor")
      with self.assertRaisesOpError("Rank must be a scalar"):
        rank_check = check_ops.assert_rank_in(vector, (rank0_fed, rank1_fed))
        self._checked_identity(rank_check, vector).eval(feed_dict={
            rank0_fed: 1,
            rank1_fed: [2, 1],
        })

  def test_raises_if_rank_is_not_integer_static(self):
    with self.test_session():
      vector = constant_op.constant((42, 43), name="my_tensor")
      with self.assertRaisesRegexp(TypeError,
                                   "must be of type <dtype: 'int32'>"):
        check_ops.assert_rank_in(vector, (1, .5,))

  def test_raises_if_rank_is_not_integer_dynamic(self):
    with self.test_session():
      vector = constant_op.constant(
          (42, 43), dtype=dtypes.float32, name="my_tensor")
      rank_fed = array_ops.placeholder(dtypes.float32, name="rank_tensor")
      with self.assertRaisesRegexp(TypeError,
                                   "must be of type <dtype: 'int32'>"):
        rank_check = check_ops.assert_rank_in(vector, (1, rank_fed))
        self._checked_identity(rank_check, vector).eval(
            feed_dict={rank_fed: .5})
718
719
class AssertRankAtLeastTest(test.TestCase):
  """Tests for `check_ops.assert_rank_at_least`.

  Every scenario is covered twice. With a constant the rank is statically
  known, and a violation is reported as a `ValueError` naming the tensor.
  With an unshaped placeholder the rank is only known once data is fed, so a
  violation is reported as an op error when the gated tensor is evaluated.
  """

  def _eval_with_rank_at_least_check(self, tensor, desired_rank,
                                     feed_dict=None):
    """Evaluates `tensor` behind an `assert_rank_at_least` control dependency.

    Args:
      tensor: The tensor whose rank is checked.
      desired_rank: Minimum rank passed to `check_ops.assert_rank_at_least`.
      feed_dict: Optional feed for placeholders, forwarded to `eval`.
    """
    # Building the assertion may itself raise when the rank is static.
    rank_check = check_ops.assert_rank_at_least(tensor, desired_rank)
    # Evaluating an op that depends on the assertion forces the check to run.
    with ops.control_dependencies([rank_check]):
      gated = array_ops.identity(tensor)
    gated.eval(feed_dict=feed_dict)

  def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self):
    with self.test_session():
      tensor = constant_op.constant(1, name="my_tensor")
      desired_rank = 1
      with self.assertRaisesRegexp(ValueError, "my_tensor.*rank at least 1"):
        self._eval_with_rank_at_least_check(tensor, desired_rank)

  def test_rank_zero_tensor_raises_if_rank_too_small_dynamic_rank(self):
    with self.test_session():
      tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
      desired_rank = 1
      with self.assertRaisesOpError("my_tensor.*rank"):
        self._eval_with_rank_at_least_check(
            tensor, desired_rank, feed_dict={tensor: 0})

  def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
    with self.test_session():
      tensor = constant_op.constant(1, name="my_tensor")
      desired_rank = 0
      self._eval_with_rank_at_least_check(tensor, desired_rank)

  def test_rank_zero_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
    with self.test_session():
      tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
      desired_rank = 0
      self._eval_with_rank_at_least_check(
          tensor, desired_rank, feed_dict={tensor: 0})

  def test_rank_one_ten_doesnt_raise_raise_if_rank_too_large_static_rank(self):
    with self.test_session():
      tensor = constant_op.constant([1, 2], name="my_tensor")
      desired_rank = 0
      self._eval_with_rank_at_least_check(tensor, desired_rank)

  def test_rank_one_ten_doesnt_raise_if_rank_too_large_dynamic_rank(self):
    with self.test_session():
      tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
      desired_rank = 0
      self._eval_with_rank_at_least_check(
          tensor, desired_rank, feed_dict={tensor: [1, 2]})

  def test_rank_one_tensor_doesnt_raise_if_rank_just_right_static_rank(self):
    with self.test_session():
      tensor = constant_op.constant([1, 2], name="my_tensor")
      desired_rank = 1
      self._eval_with_rank_at_least_check(tensor, desired_rank)

  def test_rank_one_tensor_doesnt_raise_if_rank_just_right_dynamic_rank(self):
    with self.test_session():
      tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
      desired_rank = 1
      self._eval_with_rank_at_least_check(
          tensor, desired_rank, feed_dict={tensor: [1, 2]})

  def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self):
    with self.test_session():
      tensor = constant_op.constant([1, 2], name="my_tensor")
      desired_rank = 2
      # Pin the required rank in the message, matching the rank-zero static
      # case above, rather than accepting any error mentioning "rank".
      with self.assertRaisesRegexp(ValueError, "my_tensor.*rank at least 2"):
        self._eval_with_rank_at_least_check(tensor, desired_rank)

  def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self):
    with self.test_session():
      tensor = array_ops.placeholder(dtypes.float32, name="my_tensor")
      desired_rank = 2
      with self.assertRaisesOpError("my_tensor.*rank"):
        self._eval_with_rank_at_least_check(
            tensor, desired_rank, feed_dict={tensor: [1, 2]})
805
806
class AssertNonNegativeTest(test.TestCase):
  """Tests for `check_ops.assert_non_negative`."""

  def _checked_identity(self, x):
    """Returns `identity(x)` with a control dependency on the non-negative check."""
    with ops.control_dependencies([check_ops.assert_non_negative(x)]):
      return array_ops.identity(x)

  def test_raises_when_negative(self):
    with self.test_session():
      negatives = constant_op.constant([-1, -2], name="zoe")
      checked = self._checked_identity(negatives)
      # The violation surfaces when the graph runs; the error is expected to
      # mention the offending tensor's name.
      with self.assertRaisesOpError("zoe"):
        checked.eval()

  def test_doesnt_raise_when_zero_and_positive(self):
    with self.test_session():
      zero_and_positive = constant_op.constant([0, 2], name="lucas")
      self._checked_identity(zero_and_positive).eval()

  def test_empty_tensor_doesnt_raise(self):
    # Non-negativity is a universal claim ("every element x_i satisfies
    # x_i >= 0"). A tensor with no elements satisfies it vacuously.
    with self.test_session():
      no_elements = constant_op.constant([], name="empty")
      self._checked_identity(no_elements).eval()
834
835
class AssertNonPositiveTest(test.TestCase):
  """Tests for `check_ops.assert_non_positive`."""

  def _checked_identity(self, x):
    """Returns `identity(x)` with a control dependency on the non-positive check."""
    with ops.control_dependencies([check_ops.assert_non_positive(x)]):
      return array_ops.identity(x)

  def test_doesnt_raise_when_zero_and_negative(self):
    with self.test_session():
      zero_and_negative = constant_op.constant([0, -2], name="tom")
      self._checked_identity(zero_and_negative).eval()

  def test_raises_when_positive(self):
    with self.test_session():
      has_positive = constant_op.constant([0, 2], name="rachel")
      checked = self._checked_identity(has_positive)
      # The violation surfaces when the graph runs; the error is expected to
      # mention the offending tensor's name.
      with self.assertRaisesOpError("rachel"):
        checked.eval()

  def test_empty_tensor_doesnt_raise(self):
    # Non-positivity is a universal claim ("every element x_i satisfies
    # x_i <= 0"). A tensor with no elements satisfies it vacuously.
    with self.test_session():
      no_elements = constant_op.constant([], name="empty")
      self._checked_identity(no_elements).eval()
863
864
class AssertIntegerTest(test.TestCase):
  """Tests for `check_ops.assert_integer` on integer and float inputs."""

  def test_doesnt_raise_when_integer(self):
    with self.test_session():
      int_values = constant_op.constant([1, 2], name="integers")
      integer_check = check_ops.assert_integer(int_values)
      with ops.control_dependencies([integer_check]):
        passthrough = array_ops.identity(int_values)
      passthrough.eval()

  def test_raises_when_float(self):
    with self.test_session():
      float_values = constant_op.constant([1.0, 2.0], name="floats")
      # Nothing is evaluated here: the dtype is rejected while the check is
      # being built.
      with self.assertRaisesRegexp(TypeError, "Expected.*integer"):
        check_ops.assert_integer(float_values)
879
880
class IsStrictlyIncreasingTest(test.TestCase):
  """Tests for `check_ops.is_strictly_increasing`.

  As the rank-two cases below show, a 2-D input is expected to be judged in
  row-major order: [[-1, 2], [3, 4]] counts as increasing, [[1, 3], [2, 4]]
  does not.
  """

  def _strictly_increasing(self, values):
    """Evaluates `is_strictly_increasing(values)` inside a test session."""
    with self.test_session():
      return check_ops.is_strictly_increasing(values).eval()

  def test_constant_tensor_is_not_strictly_increasing(self):
    self.assertFalse(self._strictly_increasing([1, 1, 1]))

  def test_decreasing_tensor_is_not_strictly_increasing(self):
    self.assertFalse(self._strictly_increasing([1, 0, -1]))

  def test_2d_decreasing_tensor_is_not_strictly_increasing(self):
    self.assertFalse(self._strictly_increasing([[1, 3], [2, 4]]))

  def test_increasing_tensor_is_increasing(self):
    self.assertTrue(self._strictly_increasing([1, 2, 3]))

  def test_increasing_rank_two_tensor(self):
    self.assertTrue(self._strictly_increasing([[-1, 2], [3, 4]]))

  def test_tensor_with_one_element_is_strictly_increasing(self):
    self.assertTrue(self._strictly_increasing([1]))

  def test_empty_tensor_is_strictly_increasing(self):
    self.assertTrue(self._strictly_increasing([]))
912
913
class IsNonDecreasingTest(test.TestCase):
  """Tests for `check_ops.is_non_decreasing`.

  Repeated values are allowed, so a constant tensor counts as non-decreasing.
  A 2-D input is expected to be judged in row-major order, per the rank-two
  cases below.
  """

  def _non_decreasing(self, values):
    """Evaluates `is_non_decreasing(values)` inside a test session."""
    with self.test_session():
      return check_ops.is_non_decreasing(values).eval()

  def test_constant_tensor_is_non_decreasing(self):
    self.assertTrue(self._non_decreasing([1, 1, 1]))

  def test_decreasing_tensor_is_not_non_decreasing(self):
    self.assertFalse(self._non_decreasing([3, 2, 1]))

  def test_2d_decreasing_tensor_is_not_non_decreasing(self):
    self.assertFalse(self._non_decreasing([[1, 3], [2, 4]]))

  def test_increasing_rank_one_tensor_is_non_decreasing(self):
    self.assertTrue(self._non_decreasing([1, 2, 3]))

  def test_increasing_rank_two_tensor(self):
    self.assertTrue(self._non_decreasing([[-1, 2], [3, 3]]))

  def test_tensor_with_one_element_is_non_decreasing(self):
    self.assertTrue(self._non_decreasing([1]))

  def test_empty_tensor_is_non_decreasing(self):
    self.assertTrue(self._non_decreasing([]))
943
944
if __name__ == "__main__":
  # When this file is executed directly, hand control to the TensorFlow test
  # runner (`tensorflow.python.platform.test.main`).
  test.main()
947