1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16package org.tensorflow;
17
18import static org.junit.Assert.assertArrayEquals;
19import static org.junit.Assert.assertEquals;
20import static org.junit.Assert.assertTrue;
21import static org.junit.Assert.fail;
22
23import java.util.ArrayList;
24import java.util.Collection;
25import org.junit.Test;
26import org.junit.runner.RunWith;
27import org.junit.runners.JUnit4;
28
29/** Unit tests for {@link org.tensorflow.Session}. */
30@RunWith(JUnit4.class)
31public class SessionTest {
32
33  @Test
34  public void runUsingOperationNames() {
35    try (Graph g = new Graph();
36        Session s = new Session(g)) {
37      TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}});
38      try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}});
39          AutoCloseableList<Tensor<?>> outputs =
40              new AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) {
41        assertEquals(1, outputs.size());
42        final int[][] expected = {{31}};
43        assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
44      }
45    }
46  }
47
48  @Test
49  public void runUsingOperationHandles() {
50    try (Graph g = new Graph();
51        Session s = new Session(g)) {
52      TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}});
53      Output<Integer> feed = g.operation("X").output(0);
54      Output<Integer> fetch = g.operation("Y").output(0);
55      try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}});
56          AutoCloseableList<Tensor<?>> outputs =
57              new AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) {
58        assertEquals(1, outputs.size());
59        final int[][] expected = {{31}};
60        assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
61      }
62    }
63  }
64
65  @Test
66  public void runUsingColonSeparatedNames() {
67    try (Graph g = new Graph();
68        Session s = new Session(g)) {
69      Operation split =
70          g.opBuilder("Split", "Split")
71              .addInput(TestUtil.constant(g, "split_dim", 0))
72              .addInput(TestUtil.constant(g, "value", new int[] {1, 2, 3, 4}))
73              .setAttr("num_split", 2)
74              .build();
75      g.opBuilder("Add", "Add")
76          .addInput(split.output(0))
77          .addInput(split.output(1))
78          .build()
79          .output(0);
80      // Fetch using colon separated names.
81      try (Tensor<Integer> fetched =
82          s.runner().fetch("Split:1").run().get(0).expect(Integer.class)) {
83        final int[] expected = {3, 4};
84        assertArrayEquals(expected, fetched.copyTo(new int[2]));
85      }
86      // Feed using colon separated names.
87      try (Tensor<Integer> fed = Tensors.create(new int[] {4, 3, 2, 1});
88          Tensor<Integer> fetched =
89              s.runner()
90                  .feed("Split:0", fed)
91                  .feed("Split:1", fed)
92                  .fetch("Add")
93                  .run()
94                  .get(0)
95                  .expect(Integer.class)) {
96        final int[] expected = {8, 6, 4, 2};
97        assertArrayEquals(expected, fetched.copyTo(new int[4]));
98      }
99    }
100  }
101
102  @Test
103  public void runWithMetadata() {
104    try (Graph g = new Graph();
105        Session s = new Session(g)) {
106      TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}});
107      try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}})) {
108        Session.Run result =
109            s.runner()
110                .feed("X", x)
111                .fetch("Y")
112                .setOptions(fullTraceRunOptions())
113                .runAndFetchMetadata();
114        // Sanity check on outputs.
115        AutoCloseableList<Tensor<?>> outputs = new AutoCloseableList<Tensor<?>>(result.outputs);
116        assertEquals(1, outputs.size());
117        final int[][] expected = {{31}};
118        assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
119        // Sanity check on metadata
120        // See comments in fullTraceRunOptions() for an explanation about
121        // why this check is really silly. Ideally, this would be:
122        /*
123            RunMetadata md = RunMetadata.parseFrom(result.metadata);
124            assertTrue(md.toString(), md.hasStepStats());
125        */
126        assertTrue(result.metadata.length > 0);
127        outputs.close();
128      }
129    }
130  }
131
132  @Test
133  public void runMultipleOutputs() {
134    try (Graph g = new Graph();
135        Session s = new Session(g)) {
136      TestUtil.constant(g, "c1", 2718);
137      TestUtil.constant(g, "c2", 31415);
138      AutoCloseableList<Tensor<?>> outputs =
139          new AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run());
140      assertEquals(2, outputs.size());
141      assertEquals(31415, outputs.get(0).intValue());
142      assertEquals(2718, outputs.get(1).intValue());
143      outputs.close();
144    }
145  }
146
147  @Test
148  public void failOnUseAfterClose() {
149    try (Graph g = new Graph()) {
150      Session s = new Session(g);
151      s.close();
152      try {
153        s.runner().run();
154        fail("methods on a session should fail after close() is called");
155      } catch (IllegalStateException e) {
156        // expected exception
157      }
158    }
159  }
160
161  @Test
162  public void createWithConfigProto() {
163    try (Graph g = new Graph();
164        Session s = new Session(g, singleThreadConfigProto())) {}
165  }
166
167  private static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E>
168      implements AutoCloseable {
169    AutoCloseableList(Collection<? extends E> c) {
170      super(c);
171    }
172
173    @Override
174    public void close() {
175      Exception toThrow = null;
176      for (AutoCloseable c : this) {
177        try {
178          c.close();
179        } catch (Exception e) {
180          toThrow = e;
181        }
182      }
183      if (toThrow != null) {
184        throw new RuntimeException(toThrow);
185      }
186    }
187  }
188
189  private static byte[] fullTraceRunOptions() {
190    // Ideally this would use the generated Java sources for protocol buffers
191    // and end up with something like the snippet below. However, generating
192    // the Java files for the .proto files in tensorflow/core:protos_all is
193    // a bit cumbersome in bazel until the proto_library rule is setup.
194    //
195    // See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866
196    // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362
197    // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558
198    //
199    // For this test, for now, the use of specific bytes suffices.
200    return new byte[] {0x08, 0x03};
201    /*
202    return org.tensorflow.framework.RunOptions.newBuilder()
203        .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE)
204        .build()
205        .toByteArray();
206    */
207  }
208
209  public static byte[] singleThreadConfigProto() {
210    // Ideally this would use the generated Java sources for protocol buffers
211    // and end up with something like the snippet below. However, generating
212    // the Java files for the .proto files in tensorflow/core:protos_all is
213    // a bit cumbersome in bazel until the proto_library rule is setup.
214    //
215    // See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866
216    // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362
217    // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558
218    //
219    // For this test, for now, the use of specific bytes suffices.
220    return new byte[] {0x10, 0x01, 0x28, 0x01};
221    /*
222    return org.tensorflow.framework.ConfigProto.newBuilder()
223        .setInterOpParallelismThreads(1)
224        .setIntraOpParallelismThreads(1)
225        .build()
226        .toByteArray();
227     */
228  }
229}
230