1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16package org.tensorflow;
17
18import static org.junit.Assert.assertEquals;
19import static org.junit.Assert.assertTrue;
20import static org.junit.Assert.fail;
21
22import org.junit.Ignore;
23import org.junit.Test;
24import org.junit.runner.RunWith;
25import org.junit.runners.JUnit4;
26
27/** Unit tests for {@link org.tensorflow.OperationBuilder}. */
28@RunWith(JUnit4.class)
29public class OperationBuilderTest {
30  // TODO(ashankar): Restore this test once the C API gracefully handles mixing graphs and
31  // operations instead of segfaulting.
32  @Test
33  @Ignore
34  public void failWhenMixingOperationsOnDifferentGraphs() {
35    try (Graph g1 = new Graph();
36        Graph g2 = new Graph()) {
37      Output<Integer> c1 = TestUtil.constant(g1, "C1", 3);
38      Output<Integer> c2 = TestUtil.constant(g2, "C2", 3);
39      TestUtil.addN(g1, c1, c1);
40      try {
41        TestUtil.addN(g2, c1, c2);
42      } catch (Exception e) {
43        fail(e.toString());
44      }
45    }
46  }
47
48  @Test
49  public void failOnUseAfterBuild() {
50    try (Graph g = new Graph();
51        Tensor<Integer> t = Tensors.create(1)) {
52      OperationBuilder b =
53          g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t);
54      b.build();
55      try {
56        b.setAttr("dtype", t.dataType());
57      } catch (IllegalStateException e) {
58        // expected exception.
59      }
60    }
61  }
62
63  @Test
64  public void failOnUseAfterGraphClose() {
65    OperationBuilder b = null;
66    try (Graph g = new Graph();
67        Tensor<Integer> t = Tensors.create(1)) {
68      b = g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t);
69    }
70    try {
71      b.build();
72    } catch (IllegalStateException e) {
73      // expected exception.
74    }
75  }
76
77  @Test
78  public void setAttr() {
79    // The effect of setting an attribute may not easily be visible from the other parts of this
80    // package's API. Thus, for now, the test simply executes the various setAttr variants to see
81    // that there are no exceptions. If an attribute is "visible", test for that in a separate test
82    // (like setAttrShape).
83    //
84    // This is a bit of an awkward test since it has to find operations with attributes of specific
85    // types that aren't inferred from the input arguments.
86    try (Graph g = new Graph()) {
87      // dtype, tensor attributes.
88      try (Tensor<Integer> t = Tensors.create(1)) {
89        g.opBuilder("Const", "DataTypeAndTensor")
90            .setAttr("dtype", DataType.INT32)
91            .setAttr("value", t)
92            .build()
93            .output(0);
94        assertTrue(hasNode(g, "DataTypeAndTensor"));
95      }
96      // string, bool attributes.
97      g.opBuilder("Abort", "StringAndBool")
98          .setAttr("error_msg", "SomeErrorMessage")
99          .setAttr("exit_without_error", false)
100          .build();
101      assertTrue(hasNode(g, "StringAndBool"));
102      // int (TF "int" attributes are 64-bit signed, so a Java long).
103      g.opBuilder("RandomUniform", "Int")
104          .addInput(TestUtil.constant(g, "RandomUniformShape", new int[] {1}))
105          .setAttr("seed", 10)
106          .setAttr("dtype", DataType.FLOAT)
107          .build();
108      assertTrue(hasNode(g, "Int"));
109      // list(int)
110      g.opBuilder("MaxPool", "IntList")
111          .addInput(TestUtil.constant(g, "MaxPoolInput", new float[2][2][2][2]))
112          .setAttr("ksize", new long[] {1, 1, 1, 1})
113          .setAttr("strides", new long[] {1, 1, 1, 1})
114          .setAttr("padding", "SAME")
115          .build();
116      assertTrue(hasNode(g, "IntList"));
117      // list(float)
118      g.opBuilder("FractionalMaxPool", "FloatList")
119          .addInput(TestUtil.constant(g, "FractionalMaxPoolInput", new float[2][2][2][2]))
120          .setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f})
121          .build();
122      assertTrue(hasNode(g, "FloatList"));
123      // Missing tests: float, list(dtype), list(tensor), list(string), list(bool)
124    }
125  }
126
127  @Test
128  public void setAttrShape() {
129    try (Graph g = new Graph()) {
130      Output<?> n =
131          g.opBuilder("Placeholder", "unknown")
132              .setAttr("dtype", DataType.FLOAT)
133              .setAttr("shape", Shape.unknown())
134              .build()
135              .output(0);
136      assertEquals(-1, n.shape().numDimensions());
137      assertEquals(DataType.FLOAT, n.dataType());
138
139      n = g.opBuilder("Placeholder", "batch_of_vectors")
140              .setAttr("dtype", DataType.FLOAT)
141              .setAttr("shape", Shape.make(-1, 784))
142              .build()
143              .output(0);
144      assertEquals(2, n.shape().numDimensions());
145      assertEquals(-1, n.shape().size(0));
146      assertEquals(784, n.shape().size(1));
147      assertEquals(DataType.FLOAT, n.dataType());
148    }
149  }
150
151  @Test
152  public void setAttrShapeList() {
153    // Those shapes match tensors ones, so no exception is thrown
154    testSetAttrShapeList(new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2)});
155    try {
156      // Those shapes do not match tensors ones, exception is thrown
157      testSetAttrShapeList(new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2, 2)});
158      fail("Shapes are incompatible and an exception was expected");
159    } catch (IllegalArgumentException e) {
160      // expected
161    }
162  }
163
164  @Test
165  public void addControlInput() {
166    try (Graph g = new Graph();
167        Session s = new Session(g);
168        Tensor<Boolean> yes = Tensors.create(true);
169        Tensor<Boolean> no = Tensors.create(false)) {
170      Output<Boolean> placeholder = TestUtil.placeholder(g, "boolean", Boolean.class);
171      Operation check =
172          g.opBuilder("Assert", "assert")
173              .addInput(placeholder)
174              .addInputList(new Output<?>[] {placeholder})
175              .build();
176      Operation noop = g.opBuilder("NoOp", "noop").addControlInput(check).build();
177
178      // No problems when the Assert check succeeds
179      s.runner().feed(placeholder, yes).addTarget(noop).run();
180
181      // Exception thrown by the execution of the Assert node
182      try {
183        s.runner().feed(placeholder, no).addTarget(noop).run();
184        fail("Did not run control operation.");
185      } catch (IllegalArgumentException e) {
186        // expected
187      }
188    }
189  }
190
191  private static void testSetAttrShapeList(Shape[] shapes) {
192    try (Graph g = new Graph();
193        Session s = new Session(g)) {
194      int[][] matrix = new int[][] {{0, 0}, {0, 0}};
195      Output<?> queue =
196          g.opBuilder("FIFOQueue", "queue")
197              .setAttr("component_types", new DataType[] {DataType.INT32, DataType.INT32})
198              .setAttr("shapes", shapes)
199              .build()
200              .output(0);
201      assertTrue(hasNode(g, "queue"));
202      Output<Integer> c1 = TestUtil.constant(g, "const1", matrix);
203      Output<Integer> c2 = TestUtil.constant(g, "const2", new int[][][] {matrix, matrix});
204      Operation enqueue =
205          g.opBuilder("QueueEnqueue", "enqueue")
206              .addInput(queue)
207              .addInputList(new Output<?>[] {c1, c2})
208              .build();
209      assertTrue(hasNode(g, "enqueue"));
210
211      s.runner().addTarget(enqueue).run();
212    }
213  }
214
215  private static boolean hasNode(Graph g, String name) {
216    return g.operation(name) != null;
217  }
218}
219