1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16package org.tensorflow;
17
18import static java.nio.charset.StandardCharsets.UTF_8;
19import static org.junit.Assert.assertArrayEquals;
20import static org.junit.Assert.assertEquals;
21import static org.junit.Assert.assertTrue;
22import static org.junit.Assert.fail;
23
24import java.nio.ByteBuffer;
25import java.nio.ByteOrder;
26import java.nio.DoubleBuffer;
27import java.nio.FloatBuffer;
28import java.nio.IntBuffer;
29import java.nio.LongBuffer;
30import org.junit.Test;
31import org.junit.runner.RunWith;
32import org.junit.runners.JUnit4;
33import org.tensorflow.types.UInt8;
34
35/** Unit tests for {@link org.tensorflow.Tensor}. */
36@RunWith(JUnit4.class)
37public class TensorTest {
38  private static final double EPSILON = 1e-7;
39  private static final float EPSILON_F = 1e-7f;
40
41  @Test
42  public void createWithByteBuffer() {
43    double[] doubles = {1d, 2d, 3d, 4d};
44    long[] doubles_shape = {4};
45    boolean[] bools = {true, false, true, false};
46    long[] bools_shape = {4};
47    byte[] bools_ = TestUtil.bool2byte(bools);
48    byte[] strings = "test".getBytes(UTF_8);
49    long[] strings_shape = {};
50    byte[] strings_; // raw TF_STRING
51    try (Tensor<String> t = Tensors.create(strings)) {
52      ByteBuffer to = ByteBuffer.allocate(t.numBytes());
53      t.writeTo(to);
54      strings_ = to.array();
55    }
56
57    // validate creating a tensor using a byte buffer
58    {
59      try (Tensor<Boolean> t = Tensor.create(Boolean.class, bools_shape, ByteBuffer.wrap(bools_))) {
60        boolean[] actual = t.copyTo(new boolean[bools_.length]);
61        for (int i = 0; i < bools.length; ++i) {
62          assertEquals("" + i, bools[i], actual[i]);
63        }
64      }
65
66      // note: the buffer is expected to contain raw TF_STRING (as per C API)
67      try (Tensor<String> t =
68          Tensor.create(String.class, strings_shape, ByteBuffer.wrap(strings_))) {
69        assertArrayEquals(strings, t.bytesValue());
70      }
71    }
72
73    // validate creating a tensor using a direct byte buffer (in host order)
74    {
75      ByteBuffer buf = ByteBuffer.allocateDirect(8 * doubles.length).order(ByteOrder.nativeOrder());
76      buf.asDoubleBuffer().put(doubles);
77      try (Tensor<Double> t = Tensor.create(Double.class, doubles_shape, buf)) {
78        double[] actual = new double[doubles.length];
79        assertArrayEquals(doubles, t.copyTo(actual), EPSILON);
80      }
81    }
82
83    // validate shape checking
84    try (Tensor<Boolean> t =
85        Tensor.create(Boolean.class, new long[bools_.length * 2], ByteBuffer.wrap(bools_))) {
86      fail("should have failed on incompatible buffer");
87    } catch (IllegalArgumentException e) {
88      // expected
89    }
90  }
91
92  @Test
93  public void createFromBufferWithNonNativeByteOrder() {
94    double[] doubles = {1d, 2d, 3d, 4d};
95    DoubleBuffer buf =
96        ByteBuffer.allocate(8 * doubles.length)
97            .order(
98                ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN
99                    ? ByteOrder.BIG_ENDIAN
100                    : ByteOrder.LITTLE_ENDIAN)
101            .asDoubleBuffer()
102            .put(doubles);
103    buf.flip();
104    try (Tensor<Double> t = Tensor.create(new long[] {doubles.length}, buf)) {
105      double[] actual = new double[doubles.length];
106      assertArrayEquals(doubles, t.copyTo(actual), EPSILON);
107    }
108  }
109
110  @Test
111  public void createWithTypedBuffer() {
112    int[] ints = {1, 2, 3, 4};
113    float[] floats = {1f, 2f, 3f, 4f};
114    double[] doubles = {1d, 2d, 3d, 4d};
115    long[] longs = {1L, 2L, 3L, 4L};
116    long[] shape = {4};
117
118    // validate creating a tensor using a typed buffer
119    {
120      try (Tensor<Double> t = Tensor.create(shape, DoubleBuffer.wrap(doubles))) {
121        double[] actual = new double[doubles.length];
122        assertArrayEquals(doubles, t.copyTo(actual), EPSILON);
123      }
124      try (Tensor<Float> t = Tensor.create(shape, FloatBuffer.wrap(floats))) {
125        float[] actual = new float[floats.length];
126        assertArrayEquals(floats, t.copyTo(actual), EPSILON_F);
127      }
128      try (Tensor<Integer> t = Tensor.create(shape, IntBuffer.wrap(ints))) {
129        int[] actual = new int[ints.length];
130        assertArrayEquals(ints, t.copyTo(actual));
131      }
132      try (Tensor<Long> t = Tensor.create(shape, LongBuffer.wrap(longs))) {
133        long[] actual = new long[longs.length];
134        assertArrayEquals(longs, t.copyTo(actual));
135      }
136    }
137
138    // validate shape-checking
139    {
140      try (Tensor<Double> t =
141          Tensor.create(new long[doubles.length + 1], DoubleBuffer.wrap(doubles))) {
142        fail("should have failed on incompatible buffer");
143      } catch (IllegalArgumentException e) {
144        // expected
145      }
146      try (Tensor<Float> t = Tensor.create(new long[floats.length + 1], FloatBuffer.wrap(floats))) {
147        fail("should have failed on incompatible buffer");
148      } catch (IllegalArgumentException e) {
149        // expected
150      }
151      try (Tensor<Integer> t = Tensor.create(new long[ints.length + 1], IntBuffer.wrap(ints))) {
152        fail("should have failed on incompatible buffer");
153      } catch (IllegalArgumentException e) {
154        // expected
155      }
156      try (Tensor<Long> t = Tensor.create(new long[longs.length + 1], LongBuffer.wrap(longs))) {
157        fail("should have failed on incompatible buffer");
158      } catch (IllegalArgumentException e) {
159        // expected
160      }
161    }
162  }
163
164  @Test
165  public void writeTo() {
166    int[] ints = {1, 2, 3};
167    float[] floats = {1f, 2f, 3f};
168    double[] doubles = {1d, 2d, 3d};
169    long[] longs = {1L, 2L, 3L};
170    boolean[] bools = {true, false, true};
171
172    try (Tensor<Integer> tints = Tensors.create(ints);
173        Tensor<Float> tfloats = Tensors.create(floats);
174        Tensor<Double> tdoubles = Tensors.create(doubles);
175        Tensor<Long> tlongs = Tensors.create(longs);
176        Tensor<Boolean> tbools = Tensors.create(bools)) {
177
178      // validate that any datatype is readable with ByteBuffer (content, position)
179      {
180        ByteBuffer bbuf = ByteBuffer.allocate(1024).order(ByteOrder.nativeOrder());
181
182        bbuf.clear(); // FLOAT
183        tfloats.writeTo(bbuf);
184        assertEquals(tfloats.numBytes(), bbuf.position());
185        bbuf.flip();
186        assertEquals(floats[0], bbuf.asFloatBuffer().get(0), EPSILON);
187        bbuf.clear(); // DOUBLE
188        tdoubles.writeTo(bbuf);
189        assertEquals(tdoubles.numBytes(), bbuf.position());
190        bbuf.flip();
191        assertEquals(doubles[0], bbuf.asDoubleBuffer().get(0), EPSILON);
192        bbuf.clear(); // INT32
193        tints.writeTo(bbuf);
194        assertEquals(tints.numBytes(), bbuf.position());
195        bbuf.flip();
196        assertEquals(ints[0], bbuf.asIntBuffer().get(0));
197        bbuf.clear(); // INT64
198        tlongs.writeTo(bbuf);
199        assertEquals(tlongs.numBytes(), bbuf.position());
200        bbuf.flip();
201        assertEquals(longs[0], bbuf.asLongBuffer().get(0));
202        bbuf.clear(); // BOOL
203        tbools.writeTo(bbuf);
204        assertEquals(tbools.numBytes(), bbuf.position());
205        bbuf.flip();
206        assertEquals(bools[0], bbuf.get(0) != 0);
207      }
208
209      // validate the use of direct buffers
210      {
211        DoubleBuffer buf =
212            ByteBuffer.allocateDirect(tdoubles.numBytes())
213                .order(ByteOrder.nativeOrder())
214                .asDoubleBuffer();
215        tdoubles.writeTo(buf);
216        assertTrue(buf.isDirect());
217        assertEquals(tdoubles.numElements(), buf.position());
218        assertEquals(doubles[0], buf.get(0), EPSILON);
219      }
220
221      // validate typed buffers (content, position)
222      {
223        FloatBuffer buf = FloatBuffer.allocate(tfloats.numElements());
224        tfloats.writeTo(buf);
225        assertEquals(tfloats.numElements(), buf.position());
226        assertEquals(floats[0], buf.get(0), EPSILON);
227      }
228      {
229        DoubleBuffer buf = DoubleBuffer.allocate(tdoubles.numElements());
230        tdoubles.writeTo(buf);
231        assertEquals(tdoubles.numElements(), buf.position());
232        assertEquals(doubles[0], buf.get(0), EPSILON);
233      }
234      {
235        IntBuffer buf = IntBuffer.allocate(tints.numElements());
236        tints.writeTo(buf);
237        assertEquals(tints.numElements(), buf.position());
238        assertEquals(ints[0], buf.get(0));
239      }
240      {
241        LongBuffer buf = LongBuffer.allocate(tlongs.numElements());
242        tlongs.writeTo(buf);
243        assertEquals(tlongs.numElements(), buf.position());
244        assertEquals(longs[0], buf.get(0));
245      }
246
247      // validate byte order conversion
248      {
249        DoubleBuffer foreignBuf =
250            ByteBuffer.allocate(tdoubles.numBytes())
251                .order(
252                    ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN
253                        ? ByteOrder.BIG_ENDIAN
254                        : ByteOrder.LITTLE_ENDIAN)
255                .asDoubleBuffer();
256        tdoubles.writeTo(foreignBuf);
257        foreignBuf.flip();
258        double[] actual = new double[foreignBuf.remaining()];
259        foreignBuf.get(actual);
260        assertArrayEquals(doubles, actual, EPSILON);
261      }
262
263      // validate that incompatible buffers are rejected
264      {
265        IntBuffer badbuf1 = IntBuffer.allocate(128);
266        try {
267          tbools.writeTo(badbuf1);
268          fail("should have failed on incompatible buffer");
269        } catch (IllegalArgumentException e) {
270          // expected
271        }
272        FloatBuffer badbuf2 = FloatBuffer.allocate(128);
273        try {
274          tbools.writeTo(badbuf2);
275          fail("should have failed on incompatible buffer");
276        } catch (IllegalArgumentException e) {
277          // expected
278        }
279        DoubleBuffer badbuf3 = DoubleBuffer.allocate(128);
280        try {
281          tbools.writeTo(badbuf3);
282          fail("should have failed on incompatible buffer");
283        } catch (IllegalArgumentException e) {
284          // expected
285        }
286        LongBuffer badbuf4 = LongBuffer.allocate(128);
287        try {
288          tbools.writeTo(badbuf4);
289          fail("should have failed on incompatible buffer");
290        } catch (IllegalArgumentException e) {
291          // expected
292        }
293      }
294    }
295  }
296
297  @Test
298  public void scalars() {
299    try (Tensor<Float> t = Tensors.create(2.718f)) {
300      assertEquals(DataType.FLOAT, t.dataType());
301      assertEquals(0, t.numDimensions());
302      assertEquals(0, t.shape().length);
303      assertEquals(2.718f, t.floatValue(), EPSILON_F);
304    }
305
306    try (Tensor<Double> t = Tensors.create(3.1415)) {
307      assertEquals(DataType.DOUBLE, t.dataType());
308      assertEquals(0, t.numDimensions());
309      assertEquals(0, t.shape().length);
310      assertEquals(3.1415, t.doubleValue(), EPSILON);
311    }
312
313    try (Tensor<Integer> t = Tensors.create(-33)) {
314      assertEquals(DataType.INT32, t.dataType());
315      assertEquals(0, t.numDimensions());
316      assertEquals(0, t.shape().length);
317      assertEquals(-33, t.intValue());
318    }
319
320    try (Tensor<Long> t = Tensors.create(8589934592L)) {
321      assertEquals(DataType.INT64, t.dataType());
322      assertEquals(0, t.numDimensions());
323      assertEquals(0, t.shape().length);
324      assertEquals(8589934592L, t.longValue());
325    }
326
327    try (Tensor<Boolean> t = Tensors.create(true)) {
328      assertEquals(DataType.BOOL, t.dataType());
329      assertEquals(0, t.numDimensions());
330      assertEquals(0, t.shape().length);
331      assertTrue(t.booleanValue());
332    }
333
334    final byte[] bytes = {1, 2, 3, 4};
335    try (Tensor<String> t = Tensors.create(bytes)) {
336      assertEquals(DataType.STRING, t.dataType());
337      assertEquals(0, t.numDimensions());
338      assertEquals(0, t.shape().length);
339      assertArrayEquals(bytes, t.bytesValue());
340    }
341  }
342
343  @Test
344  public void nDimensional() {
345    double[] vector = {1.414, 2.718, 3.1415};
346    try (Tensor<Double> t = Tensors.create(vector)) {
347      assertEquals(DataType.DOUBLE, t.dataType());
348      assertEquals(1, t.numDimensions());
349      assertArrayEquals(new long[] {3}, t.shape());
350
351      double[] got = new double[3];
352      assertArrayEquals(vector, t.copyTo(got), EPSILON);
353    }
354
355    int[][] matrix = {{1, 2, 3}, {4, 5, 6}};
356    try (Tensor<Integer> t = Tensors.create(matrix)) {
357      assertEquals(DataType.INT32, t.dataType());
358      assertEquals(2, t.numDimensions());
359      assertArrayEquals(new long[] {2, 3}, t.shape());
360
361      int[][] got = new int[2][3];
362      assertArrayEquals(matrix, t.copyTo(got));
363    }
364
365    long[][][] threeD = {
366      {{1}, {3}, {5}, {7}, {9}}, {{2}, {4}, {6}, {8}, {0}},
367    };
368    try (Tensor<Long> t = Tensors.create(threeD)) {
369      assertEquals(DataType.INT64, t.dataType());
370      assertEquals(3, t.numDimensions());
371      assertArrayEquals(new long[] {2, 5, 1}, t.shape());
372
373      long[][][] got = new long[2][5][1];
374      assertArrayEquals(threeD, t.copyTo(got));
375    }
376
377    boolean[][][][] fourD = {
378      {{{false, false, false, true}, {false, false, true, false}}},
379      {{{false, false, true, true}, {false, true, false, false}}},
380      {{{false, true, false, true}, {false, true, true, false}}},
381    };
382    try (Tensor<Boolean> t = Tensors.create(fourD)) {
383      assertEquals(DataType.BOOL, t.dataType());
384      assertEquals(4, t.numDimensions());
385      assertArrayEquals(new long[] {3, 1, 2, 4}, t.shape());
386
387      boolean[][][][] got = new boolean[3][1][2][4];
388      assertArrayEquals(fourD, t.copyTo(got));
389    }
390  }
391
392  @Test
393  public void testNDimensionalStringTensor() {
394    byte[][][] matrix = new byte[4][3][];
395    for (int i = 0; i < 4; ++i) {
396      for (int j = 0; j < 3; ++j) {
397        matrix[i][j] = String.format("(%d, %d) = %d", i, j, i << j).getBytes(UTF_8);
398      }
399    }
400    try (Tensor<String> t = Tensors.create(matrix)) {
401      assertEquals(DataType.STRING, t.dataType());
402      assertEquals(2, t.numDimensions());
403      assertArrayEquals(new long[] {4, 3}, t.shape());
404
405      byte[][][] got = t.copyTo(new byte[4][3][]);
406      assertEquals(4, got.length);
407      for (int i = 0; i < 4; ++i) {
408        assertEquals(String.format("%d", i), 3, got[i].length);
409        for (int j = 0; j < 3; ++j) {
410          assertArrayEquals(String.format("(%d, %d)", i, j), matrix[i][j], got[i][j]);
411        }
412      }
413    }
414  }
415
416  @Test
417  public void testUInt8Tensor() {
418    byte[] vector = new byte[] {1, 2, 3, 4};
419    try (Tensor<UInt8> t = Tensor.create(vector, UInt8.class)) {
420      assertEquals(DataType.UINT8, t.dataType());
421      assertEquals(1, t.numDimensions());
422      assertArrayEquals(new long[] {4}, t.shape());
423
424      byte[] got = t.copyTo(new byte[4]);
425      assertArrayEquals(vector, got);
426    }
427  }
428
429  @Test
430  public void testCreateFromArrayOfBoxed() {
431    Integer[] vector = new Integer[] {1, 2, 3, 4};
432    try (Tensor<Integer> t = Tensor.create(vector, Integer.class)) {
433      fail("Tensor.create() should fail because it was given an array of boxed values");
434    } catch (IllegalArgumentException e) {
435        // The expected exception
436    }
437  }
438
439  @Test
440  public void failCreateOnMismatchedDimensions() {
441    int[][][] invalid = new int[3][1][];
442    for (int x = 0; x < invalid.length; ++x) {
443      for (int y = 0; y < invalid[x].length; ++y) {
444        invalid[x][y] = new int[x + y + 1];
445      }
446    }
447    try (Tensor<?> t = Tensor.create(invalid)) {
448      fail("Tensor.create() should fail because of differing sizes in the 3rd dimension");
449    } catch (IllegalArgumentException e) {
450      // The expected exception.
451    }
452  }
453
454  @Test
455  public void failCopyToOnIncompatibleDestination() {
456    try (final Tensor<Integer> matrix = Tensors.create(new int[][] {{1, 2}, {3, 4}})) {
457      try {
458        matrix.copyTo(new int[2]);
459        fail("should have failed on dimension mismatch");
460      } catch (IllegalArgumentException e) {
461        // The expected exception.
462      }
463
464      try {
465        matrix.copyTo(new float[2][2]);
466        fail("should have failed on DataType mismatch");
467      } catch (IllegalArgumentException e) {
468        // The expected exception.
469      }
470
471      try {
472        matrix.copyTo(new int[2][3]);
473        fail("should have failed on shape mismatch");
474      } catch (IllegalArgumentException e) {
475        // The expected exception.
476      }
477    }
478  }
479
480  @Test
481  public void failCopyToOnScalar() {
482    try (final Tensor<Integer> scalar = Tensors.create(3)) {
483      try {
484        scalar.copyTo(3);
485        fail("copyTo should fail on scalar tensors, suggesting use of primitive accessors instead");
486      } catch (IllegalArgumentException e) {
487        // The expected exception.
488      }
489    }
490  }
491
492  @Test
493  public void failOnArbitraryObject() {
494    try (Tensor<?> t = Tensor.create(new Object())) {
495      fail("should fail on creating a Tensor with a Java object that has no equivalent DataType");
496    } catch (IllegalArgumentException e) {
497      // The expected exception.
498    }
499  }
500
501  @Test
502  public void failOnZeroDimension() {
503    try (Tensor<Integer> t = Tensors.create(new int[3][0][1])) {
504      fail("should fail on creating a Tensor where one of the dimensions is 0");
505    } catch (IllegalArgumentException e) {
506      // The expected exception.
507    }
508  }
509
510  @Test
511  public void useAfterClose() {
512    int n = 4;
513    Tensor<?> t = Tensor.create(n);
514    t.close();
515    try {
516      t.intValue();
517    } catch (NullPointerException e) {
518      // The expected exception.
519    }
520  }
521
522  @Test
523  public void fromHandle() {
524    // fromHandle is a package-visible method intended for use when the C TF_Tensor object has been
525    // created independently of the Java code. In practice, two Tensor instances MUST NOT have the
526    // same native handle.
527    //
528    // An exception is made for this test, where the pitfalls of this is avoided by not calling
529    // close() on both Tensors.
530    final float[][] matrix = {{1, 2, 3}, {4, 5, 6}};
531    try (Tensor<Float> src = Tensors.create(matrix)) {
532      Tensor<Float> cpy = Tensor.fromHandle(src.getNativeHandle()).expect(Float.class);
533      assertEquals(src.dataType(), cpy.dataType());
534      assertEquals(src.numDimensions(), cpy.numDimensions());
535      assertArrayEquals(src.shape(), cpy.shape());
536      assertArrayEquals(matrix, cpy.copyTo(new float[2][3]));
537    }
538  }
539}
540