1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16package org.tensorflow; 17 18import static org.junit.Assert.assertEquals; 19import static org.junit.Assert.assertTrue; 20import static org.junit.Assert.fail; 21 22import org.junit.Ignore; 23import org.junit.Test; 24import org.junit.runner.RunWith; 25import org.junit.runners.JUnit4; 26 27/** Unit tests for {@link org.tensorflow.OperationBuilder}. */ 28@RunWith(JUnit4.class) 29public class OperationBuilderTest { 30 // TODO(ashankar): Restore this test once the C API gracefully handles mixing graphs and 31 // operations instead of segfaulting. 32 @Test 33 @Ignore 34 public void failWhenMixingOperationsOnDifferentGraphs() { 35 try (Graph g1 = new Graph(); 36 Graph g2 = new Graph()) { 37 Output<Integer> c1 = TestUtil.constant(g1, "C1", 3); 38 Output<Integer> c2 = TestUtil.constant(g2, "C2", 3); 39 TestUtil.addN(g1, c1, c1); 40 try { 41 TestUtil.addN(g2, c1, c2); 42 } catch (Exception e) { 43 fail(e.toString()); 44 } 45 } 46 } 47 48 @Test 49 public void failOnUseAfterBuild() { 50 try (Graph g = new Graph(); 51 Tensor<Integer> t = Tensors.create(1)) { 52 OperationBuilder b = 53 g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t); 54 b.build(); 55 try { 56 b.setAttr("dtype", t.dataType()); 57 } catch (IllegalStateException e) { 58 // expected exception. 59 } 60 } 61 } 62 63 @Test 64 public void failOnUseAfterGraphClose() { 65 OperationBuilder b = null; 66 try (Graph g = new Graph(); 67 Tensor<Integer> t = Tensors.create(1)) { 68 b = g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t); 69 } 70 try { 71 b.build(); 72 } catch (IllegalStateException e) { 73 // expected exception. 74 } 75 } 76 77 @Test 78 public void setAttr() { 79 // The effect of setting an attribute may not easily be visible from the other parts of this 80 // package's API. Thus, for now, the test simply executes the various setAttr variants to see 81 // that there are no exceptions. If an attribute is "visible", test for that in a separate test 82 // (like setAttrShape). 83 // 84 // This is a bit of an awkward test since it has to find operations with attributes of specific 85 // types that aren't inferred from the input arguments. 86 try (Graph g = new Graph()) { 87 // dtype, tensor attributes. 88 try (Tensor<Integer> t = Tensors.create(1)) { 89 g.opBuilder("Const", "DataTypeAndTensor") 90 .setAttr("dtype", DataType.INT32) 91 .setAttr("value", t) 92 .build() 93 .output(0); 94 assertTrue(hasNode(g, "DataTypeAndTensor")); 95 } 96 // string, bool attributes. 97 g.opBuilder("Abort", "StringAndBool") 98 .setAttr("error_msg", "SomeErrorMessage") 99 .setAttr("exit_without_error", false) 100 .build(); 101 assertTrue(hasNode(g, "StringAndBool")); 102 // int (TF "int" attributes are 64-bit signed, so a Java long). 103 g.opBuilder("RandomUniform", "Int") 104 .addInput(TestUtil.constant(g, "RandomUniformShape", new int[] {1})) 105 .setAttr("seed", 10) 106 .setAttr("dtype", DataType.FLOAT) 107 .build(); 108 assertTrue(hasNode(g, "Int")); 109 // list(int) 110 g.opBuilder("MaxPool", "IntList") 111 .addInput(TestUtil.constant(g, "MaxPoolInput", new float[2][2][2][2])) 112 .setAttr("ksize", new long[] {1, 1, 1, 1}) 113 .setAttr("strides", new long[] {1, 1, 1, 1}) 114 .setAttr("padding", "SAME") 115 .build(); 116 assertTrue(hasNode(g, "IntList")); 117 // list(float) 118 g.opBuilder("FractionalMaxPool", "FloatList") 119 .addInput(TestUtil.constant(g, "FractionalMaxPoolInput", new float[2][2][2][2])) 120 .setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f}) 121 .build(); 122 assertTrue(hasNode(g, "FloatList")); 123 // Missing tests: float, list(dtype), list(tensor), list(string), list(bool) 124 } 125 } 126 127 @Test 128 public void setAttrShape() { 129 try (Graph g = new Graph()) { 130 Output<?> n = 131 g.opBuilder("Placeholder", "unknown") 132 .setAttr("dtype", DataType.FLOAT) 133 .setAttr("shape", Shape.unknown()) 134 .build() 135 .output(0); 136 assertEquals(-1, n.shape().numDimensions()); 137 assertEquals(DataType.FLOAT, n.dataType()); 138 139 n = g.opBuilder("Placeholder", "batch_of_vectors") 140 .setAttr("dtype", DataType.FLOAT) 141 .setAttr("shape", Shape.make(-1, 784)) 142 .build() 143 .output(0); 144 assertEquals(2, n.shape().numDimensions()); 145 assertEquals(-1, n.shape().size(0)); 146 assertEquals(784, n.shape().size(1)); 147 assertEquals(DataType.FLOAT, n.dataType()); 148 } 149 } 150 151 @Test 152 public void setAttrShapeList() { 153 // Those shapes match tensors ones, so no exception is thrown 154 testSetAttrShapeList(new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2)}); 155 try { 156 // Those shapes do not match tensors ones, exception is thrown 157 testSetAttrShapeList(new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2, 2)}); 158 fail("Shapes are incompatible and an exception was expected"); 159 } catch (IllegalArgumentException e) { 160 // expected 161 } 162 } 163 164 @Test 165 public void addControlInput() { 166 try (Graph g = new Graph(); 167 Session s = new Session(g); 168 Tensor<Boolean> yes = Tensors.create(true); 169 Tensor<Boolean> no = Tensors.create(false)) { 170 Output<Boolean> placeholder = TestUtil.placeholder(g, "boolean", Boolean.class); 171 Operation check = 172 g.opBuilder("Assert", "assert") 173 .addInput(placeholder) 174 .addInputList(new Output<?>[] {placeholder}) 175 .build(); 176 Operation noop = g.opBuilder("NoOp", "noop").addControlInput(check).build(); 177 178 // No problems when the Assert check succeeds 179 s.runner().feed(placeholder, yes).addTarget(noop).run(); 180 181 // Exception thrown by the execution of the Assert node 182 try { 183 s.runner().feed(placeholder, no).addTarget(noop).run(); 184 fail("Did not run control operation."); 185 } catch (IllegalArgumentException e) { 186 // expected 187 } 188 } 189 } 190 191 private static void testSetAttrShapeList(Shape[] shapes) { 192 try (Graph g = new Graph(); 193 Session s = new Session(g)) { 194 int[][] matrix = new int[][] {{0, 0}, {0, 0}}; 195 Output<?> queue = 196 g.opBuilder("FIFOQueue", "queue") 197 .setAttr("component_types", new DataType[] {DataType.INT32, DataType.INT32}) 198 .setAttr("shapes", shapes) 199 .build() 200 .output(0); 201 assertTrue(hasNode(g, "queue")); 202 Output<Integer> c1 = TestUtil.constant(g, "const1", matrix); 203 Output<Integer> c2 = TestUtil.constant(g, "const2", new int[][][] {matrix, matrix}); 204 Operation enqueue = 205 g.opBuilder("QueueEnqueue", "enqueue") 206 .addInput(queue) 207 .addInputList(new Output<?>[] {c1, c2}) 208 .build(); 209 assertTrue(hasNode(g, "enqueue")); 210 211 s.runner().addTarget(enqueue).run(); 212 } 213 } 214 215 private static boolean hasNode(Graph g, String name) { 216 return g.operation(name) != null; 217 } 218} 219