1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16package org.tensorflow; 17 18import static org.junit.Assert.assertArrayEquals; 19import static org.junit.Assert.assertEquals; 20import static org.junit.Assert.assertTrue; 21import static org.junit.Assert.fail; 22 23import java.util.ArrayList; 24import java.util.Collection; 25import org.junit.Test; 26import org.junit.runner.RunWith; 27import org.junit.runners.JUnit4; 28 29/** Unit tests for {@link org.tensorflow.Session}. */ 30@RunWith(JUnit4.class) 31public class SessionTest { 32 33 @Test 34 public void runUsingOperationNames() { 35 try (Graph g = new Graph(); 36 Session s = new Session(g)) { 37 TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}}); 38 try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}}); 39 AutoCloseableList<Tensor<?>> outputs = 40 new AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) { 41 assertEquals(1, outputs.size()); 42 final int[][] expected = {{31}}; 43 assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); 44 } 45 } 46 } 47 48 @Test 49 public void runUsingOperationHandles() { 50 try (Graph g = new Graph(); 51 Session s = new Session(g)) { 52 TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}}); 53 Output<Integer> feed = g.operation("X").output(0); 54 Output<Integer> fetch = g.operation("Y").output(0); 55 try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}}); 56 AutoCloseableList<Tensor<?>> outputs = 57 new AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) { 58 assertEquals(1, outputs.size()); 59 final int[][] expected = {{31}}; 60 assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); 61 } 62 } 63 } 64 65 @Test 66 public void runUsingColonSeparatedNames() { 67 try (Graph g = new Graph(); 68 Session s = new Session(g)) { 69 Operation split = 70 g.opBuilder("Split", "Split") 71 .addInput(TestUtil.constant(g, "split_dim", 0)) 72 .addInput(TestUtil.constant(g, "value", new int[] {1, 2, 3, 4})) 73 .setAttr("num_split", 2) 74 .build(); 75 g.opBuilder("Add", "Add") 76 .addInput(split.output(0)) 77 .addInput(split.output(1)) 78 .build() 79 .output(0); 80 // Fetch using colon separated names. 81 try (Tensor<Integer> fetched = 82 s.runner().fetch("Split:1").run().get(0).expect(Integer.class)) { 83 final int[] expected = {3, 4}; 84 assertArrayEquals(expected, fetched.copyTo(new int[2])); 85 } 86 // Feed using colon separated names. 87 try (Tensor<Integer> fed = Tensors.create(new int[] {4, 3, 2, 1}); 88 Tensor<Integer> fetched = 89 s.runner() 90 .feed("Split:0", fed) 91 .feed("Split:1", fed) 92 .fetch("Add") 93 .run() 94 .get(0) 95 .expect(Integer.class)) { 96 final int[] expected = {8, 6, 4, 2}; 97 assertArrayEquals(expected, fetched.copyTo(new int[4])); 98 } 99 } 100 } 101 102 @Test 103 public void runWithMetadata() { 104 try (Graph g = new Graph(); 105 Session s = new Session(g)) { 106 TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}}); 107 try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}})) { 108 Session.Run result = 109 s.runner() 110 .feed("X", x) 111 .fetch("Y") 112 .setOptions(fullTraceRunOptions()) 113 .runAndFetchMetadata(); 114 // Sanity check on outputs. 115 AutoCloseableList<Tensor<?>> outputs = new AutoCloseableList<Tensor<?>>(result.outputs); 116 assertEquals(1, outputs.size()); 117 final int[][] expected = {{31}}; 118 assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); 119 // Sanity check on metadata 120 // See comments in fullTraceRunOptions() for an explanation about 121 // why this check is really silly. Ideally, this would be: 122 /* 123 RunMetadata md = RunMetadata.parseFrom(result.metadata); 124 assertTrue(md.toString(), md.hasStepStats()); 125 */ 126 assertTrue(result.metadata.length > 0); 127 outputs.close(); 128 } 129 } 130 } 131 132 @Test 133 public void runMultipleOutputs() { 134 try (Graph g = new Graph(); 135 Session s = new Session(g)) { 136 TestUtil.constant(g, "c1", 2718); 137 TestUtil.constant(g, "c2", 31415); 138 AutoCloseableList<Tensor<?>> outputs = 139 new AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run()); 140 assertEquals(2, outputs.size()); 141 assertEquals(31415, outputs.get(0).intValue()); 142 assertEquals(2718, outputs.get(1).intValue()); 143 outputs.close(); 144 } 145 } 146 147 @Test 148 public void failOnUseAfterClose() { 149 try (Graph g = new Graph()) { 150 Session s = new Session(g); 151 s.close(); 152 try { 153 s.runner().run(); 154 fail("methods on a session should fail after close() is called"); 155 } catch (IllegalStateException e) { 156 // expected exception 157 } 158 } 159 } 160 161 @Test 162 public void createWithConfigProto() { 163 try (Graph g = new Graph(); 164 Session s = new Session(g, singleThreadConfigProto())) {} 165 } 166 167 private static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E> 168 implements AutoCloseable { 169 AutoCloseableList(Collection<? extends E> c) { 170 super(c); 171 } 172 173 @Override 174 public void close() { 175 Exception toThrow = null; 176 for (AutoCloseable c : this) { 177 try { 178 c.close(); 179 } catch (Exception e) { 180 toThrow = e; 181 } 182 } 183 if (toThrow != null) { 184 throw new RuntimeException(toThrow); 185 } 186 } 187 } 188 189 private static byte[] fullTraceRunOptions() { 190 // Ideally this would use the generated Java sources for protocol buffers 191 // and end up with something like the snippet below. However, generating 192 // the Java files for the .proto files in tensorflow/core:protos_all is 193 // a bit cumbersome in bazel until the proto_library rule is setup. 194 // 195 // See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866 196 // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362 197 // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558 198 // 199 // For this test, for now, the use of specific bytes suffices. 200 return new byte[] {0x08, 0x03}; 201 /* 202 return org.tensorflow.framework.RunOptions.newBuilder() 203 .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE) 204 .build() 205 .toByteArray(); 206 */ 207 } 208 209 public static byte[] singleThreadConfigProto() { 210 // Ideally this would use the generated Java sources for protocol buffers 211 // and end up with something like the snippet below. However, generating 212 // the Java files for the .proto files in tensorflow/core:protos_all is 213 // a bit cumbersome in bazel until the proto_library rule is setup. 214 // 215 // See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866 216 // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362 217 // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558 218 // 219 // For this test, for now, the use of specific bytes suffices. 220 return new byte[] {0x10, 0x01, 0x28, 0x01}; 221 /* 222 return org.tensorflow.framework.ConfigProto.newBuilder() 223 .setInterOpParallelismThreads(1) 224 .setIntraOpParallelismThreads(1) 225 .build() 226 .toByteArray(); 227 */ 228 } 229} 230