1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16package org.tensorflow; 17 18import static java.nio.charset.StandardCharsets.UTF_8; 19import static org.junit.Assert.assertArrayEquals; 20import static org.junit.Assert.assertEquals; 21import static org.junit.Assert.assertTrue; 22import static org.junit.Assert.fail; 23 24import java.nio.ByteBuffer; 25import java.nio.ByteOrder; 26import java.nio.DoubleBuffer; 27import java.nio.FloatBuffer; 28import java.nio.IntBuffer; 29import java.nio.LongBuffer; 30import org.junit.Test; 31import org.junit.runner.RunWith; 32import org.junit.runners.JUnit4; 33import org.tensorflow.types.UInt8; 34 35/** Unit tests for {@link org.tensorflow.Tensor}. */ 36@RunWith(JUnit4.class) 37public class TensorTest { 38 private static final double EPSILON = 1e-7; 39 private static final float EPSILON_F = 1e-7f; 40 41 @Test 42 public void createWithByteBuffer() { 43 double[] doubles = {1d, 2d, 3d, 4d}; 44 long[] doubles_shape = {4}; 45 boolean[] bools = {true, false, true, false}; 46 long[] bools_shape = {4}; 47 byte[] bools_ = TestUtil.bool2byte(bools); 48 byte[] strings = "test".getBytes(UTF_8); 49 long[] strings_shape = {}; 50 byte[] strings_; // raw TF_STRING 51 try (Tensor<String> t = Tensors.create(strings)) { 52 ByteBuffer to = ByteBuffer.allocate(t.numBytes()); 53 t.writeTo(to); 54 strings_ = to.array(); 55 } 56 57 // validate creating a tensor using a byte buffer 58 { 59 try (Tensor<Boolean> t = Tensor.create(Boolean.class, bools_shape, ByteBuffer.wrap(bools_))) { 60 boolean[] actual = t.copyTo(new boolean[bools_.length]); 61 for (int i = 0; i < bools.length; ++i) { 62 assertEquals("" + i, bools[i], actual[i]); 63 } 64 } 65 66 // note: the buffer is expected to contain raw TF_STRING (as per C API) 67 try (Tensor<String> t = 68 Tensor.create(String.class, strings_shape, ByteBuffer.wrap(strings_))) { 69 assertArrayEquals(strings, t.bytesValue()); 70 } 71 } 72 73 // validate creating a tensor using a direct byte buffer (in host order) 74 { 75 ByteBuffer buf = ByteBuffer.allocateDirect(8 * doubles.length).order(ByteOrder.nativeOrder()); 76 buf.asDoubleBuffer().put(doubles); 77 try (Tensor<Double> t = Tensor.create(Double.class, doubles_shape, buf)) { 78 double[] actual = new double[doubles.length]; 79 assertArrayEquals(doubles, t.copyTo(actual), EPSILON); 80 } 81 } 82 83 // validate shape checking 84 try (Tensor<Boolean> t = 85 Tensor.create(Boolean.class, new long[bools_.length * 2], ByteBuffer.wrap(bools_))) { 86 fail("should have failed on incompatible buffer"); 87 } catch (IllegalArgumentException e) { 88 // expected 89 } 90 } 91 92 @Test 93 public void createFromBufferWithNonNativeByteOrder() { 94 double[] doubles = {1d, 2d, 3d, 4d}; 95 DoubleBuffer buf = 96 ByteBuffer.allocate(8 * doubles.length) 97 .order( 98 ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN 99 ? ByteOrder.BIG_ENDIAN 100 : ByteOrder.LITTLE_ENDIAN) 101 .asDoubleBuffer() 102 .put(doubles); 103 buf.flip(); 104 try (Tensor<Double> t = Tensor.create(new long[] {doubles.length}, buf)) { 105 double[] actual = new double[doubles.length]; 106 assertArrayEquals(doubles, t.copyTo(actual), EPSILON); 107 } 108 } 109 110 @Test 111 public void createWithTypedBuffer() { 112 int[] ints = {1, 2, 3, 4}; 113 float[] floats = {1f, 2f, 3f, 4f}; 114 double[] doubles = {1d, 2d, 3d, 4d}; 115 long[] longs = {1L, 2L, 3L, 4L}; 116 long[] shape = {4}; 117 118 // validate creating a tensor using a typed buffer 119 { 120 try (Tensor<Double> t = Tensor.create(shape, DoubleBuffer.wrap(doubles))) { 121 double[] actual = new double[doubles.length]; 122 assertArrayEquals(doubles, t.copyTo(actual), EPSILON); 123 } 124 try (Tensor<Float> t = Tensor.create(shape, FloatBuffer.wrap(floats))) { 125 float[] actual = new float[floats.length]; 126 assertArrayEquals(floats, t.copyTo(actual), EPSILON_F); 127 } 128 try (Tensor<Integer> t = Tensor.create(shape, IntBuffer.wrap(ints))) { 129 int[] actual = new int[ints.length]; 130 assertArrayEquals(ints, t.copyTo(actual)); 131 } 132 try (Tensor<Long> t = Tensor.create(shape, LongBuffer.wrap(longs))) { 133 long[] actual = new long[longs.length]; 134 assertArrayEquals(longs, t.copyTo(actual)); 135 } 136 } 137 138 // validate shape-checking 139 { 140 try (Tensor<Double> t = 141 Tensor.create(new long[doubles.length + 1], DoubleBuffer.wrap(doubles))) { 142 fail("should have failed on incompatible buffer"); 143 } catch (IllegalArgumentException e) { 144 // expected 145 } 146 try (Tensor<Float> t = Tensor.create(new long[floats.length + 1], FloatBuffer.wrap(floats))) { 147 fail("should have failed on incompatible buffer"); 148 } catch (IllegalArgumentException e) { 149 // expected 150 } 151 try (Tensor<Integer> t = Tensor.create(new long[ints.length + 1], IntBuffer.wrap(ints))) { 152 fail("should have failed on incompatible buffer"); 153 } catch (IllegalArgumentException e) { 154 // expected 155 } 156 try (Tensor<Long> t = Tensor.create(new long[longs.length + 1], LongBuffer.wrap(longs))) { 157 fail("should have failed on incompatible buffer"); 158 } catch (IllegalArgumentException e) { 159 // expected 160 } 161 } 162 } 163 164 @Test 165 public void writeTo() { 166 int[] ints = {1, 2, 3}; 167 float[] floats = {1f, 2f, 3f}; 168 double[] doubles = {1d, 2d, 3d}; 169 long[] longs = {1L, 2L, 3L}; 170 boolean[] bools = {true, false, true}; 171 172 try (Tensor<Integer> tints = Tensors.create(ints); 173 Tensor<Float> tfloats = Tensors.create(floats); 174 Tensor<Double> tdoubles = Tensors.create(doubles); 175 Tensor<Long> tlongs = Tensors.create(longs); 176 Tensor<Boolean> tbools = Tensors.create(bools)) { 177 178 // validate that any datatype is readable with ByteBuffer (content, position) 179 { 180 ByteBuffer bbuf = ByteBuffer.allocate(1024).order(ByteOrder.nativeOrder()); 181 182 bbuf.clear(); // FLOAT 183 tfloats.writeTo(bbuf); 184 assertEquals(tfloats.numBytes(), bbuf.position()); 185 bbuf.flip(); 186 assertEquals(floats[0], bbuf.asFloatBuffer().get(0), EPSILON); 187 bbuf.clear(); // DOUBLE 188 tdoubles.writeTo(bbuf); 189 assertEquals(tdoubles.numBytes(), bbuf.position()); 190 bbuf.flip(); 191 assertEquals(doubles[0], bbuf.asDoubleBuffer().get(0), EPSILON); 192 bbuf.clear(); // INT32 193 tints.writeTo(bbuf); 194 assertEquals(tints.numBytes(), bbuf.position()); 195 bbuf.flip(); 196 assertEquals(ints[0], bbuf.asIntBuffer().get(0)); 197 bbuf.clear(); // INT64 198 tlongs.writeTo(bbuf); 199 assertEquals(tlongs.numBytes(), bbuf.position()); 200 bbuf.flip(); 201 assertEquals(longs[0], bbuf.asLongBuffer().get(0)); 202 bbuf.clear(); // BOOL 203 tbools.writeTo(bbuf); 204 assertEquals(tbools.numBytes(), bbuf.position()); 205 bbuf.flip(); 206 assertEquals(bools[0], bbuf.get(0) != 0); 207 } 208 209 // validate the use of direct buffers 210 { 211 DoubleBuffer buf = 212 ByteBuffer.allocateDirect(tdoubles.numBytes()) 213 .order(ByteOrder.nativeOrder()) 214 .asDoubleBuffer(); 215 tdoubles.writeTo(buf); 216 assertTrue(buf.isDirect()); 217 assertEquals(tdoubles.numElements(), buf.position()); 218 assertEquals(doubles[0], buf.get(0), EPSILON); 219 } 220 221 // validate typed buffers (content, position) 222 { 223 FloatBuffer buf = FloatBuffer.allocate(tfloats.numElements()); 224 tfloats.writeTo(buf); 225 assertEquals(tfloats.numElements(), buf.position()); 226 assertEquals(floats[0], buf.get(0), EPSILON); 227 } 228 { 229 DoubleBuffer buf = DoubleBuffer.allocate(tdoubles.numElements()); 230 tdoubles.writeTo(buf); 231 assertEquals(tdoubles.numElements(), buf.position()); 232 assertEquals(doubles[0], buf.get(0), EPSILON); 233 } 234 { 235 IntBuffer buf = IntBuffer.allocate(tints.numElements()); 236 tints.writeTo(buf); 237 assertEquals(tints.numElements(), buf.position()); 238 assertEquals(ints[0], buf.get(0)); 239 } 240 { 241 LongBuffer buf = LongBuffer.allocate(tlongs.numElements()); 242 tlongs.writeTo(buf); 243 assertEquals(tlongs.numElements(), buf.position()); 244 assertEquals(longs[0], buf.get(0)); 245 } 246 247 // validate byte order conversion 248 { 249 DoubleBuffer foreignBuf = 250 ByteBuffer.allocate(tdoubles.numBytes()) 251 .order( 252 ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN 253 ? ByteOrder.BIG_ENDIAN 254 : ByteOrder.LITTLE_ENDIAN) 255 .asDoubleBuffer(); 256 tdoubles.writeTo(foreignBuf); 257 foreignBuf.flip(); 258 double[] actual = new double[foreignBuf.remaining()]; 259 foreignBuf.get(actual); 260 assertArrayEquals(doubles, actual, EPSILON); 261 } 262 263 // validate that incompatible buffers are rejected 264 { 265 IntBuffer badbuf1 = IntBuffer.allocate(128); 266 try { 267 tbools.writeTo(badbuf1); 268 fail("should have failed on incompatible buffer"); 269 } catch (IllegalArgumentException e) { 270 // expected 271 } 272 FloatBuffer badbuf2 = FloatBuffer.allocate(128); 273 try { 274 tbools.writeTo(badbuf2); 275 fail("should have failed on incompatible buffer"); 276 } catch (IllegalArgumentException e) { 277 // expected 278 } 279 DoubleBuffer badbuf3 = DoubleBuffer.allocate(128); 280 try { 281 tbools.writeTo(badbuf3); 282 fail("should have failed on incompatible buffer"); 283 } catch (IllegalArgumentException e) { 284 // expected 285 } 286 LongBuffer badbuf4 = LongBuffer.allocate(128); 287 try { 288 tbools.writeTo(badbuf4); 289 fail("should have failed on incompatible buffer"); 290 } catch (IllegalArgumentException e) { 291 // expected 292 } 293 } 294 } 295 } 296 297 @Test 298 public void scalars() { 299 try (Tensor<Float> t = Tensors.create(2.718f)) { 300 assertEquals(DataType.FLOAT, t.dataType()); 301 assertEquals(0, t.numDimensions()); 302 assertEquals(0, t.shape().length); 303 assertEquals(2.718f, t.floatValue(), EPSILON_F); 304 } 305 306 try (Tensor<Double> t = Tensors.create(3.1415)) { 307 assertEquals(DataType.DOUBLE, t.dataType()); 308 assertEquals(0, t.numDimensions()); 309 assertEquals(0, t.shape().length); 310 assertEquals(3.1415, t.doubleValue(), EPSILON); 311 } 312 313 try (Tensor<Integer> t = Tensors.create(-33)) { 314 assertEquals(DataType.INT32, t.dataType()); 315 assertEquals(0, t.numDimensions()); 316 assertEquals(0, t.shape().length); 317 assertEquals(-33, t.intValue()); 318 } 319 320 try (Tensor<Long> t = Tensors.create(8589934592L)) { 321 assertEquals(DataType.INT64, t.dataType()); 322 assertEquals(0, t.numDimensions()); 323 assertEquals(0, t.shape().length); 324 assertEquals(8589934592L, t.longValue()); 325 } 326 327 try (Tensor<Boolean> t = Tensors.create(true)) { 328 assertEquals(DataType.BOOL, t.dataType()); 329 assertEquals(0, t.numDimensions()); 330 assertEquals(0, t.shape().length); 331 assertTrue(t.booleanValue()); 332 } 333 334 final byte[] bytes = {1, 2, 3, 4}; 335 try (Tensor<String> t = Tensors.create(bytes)) { 336 assertEquals(DataType.STRING, t.dataType()); 337 assertEquals(0, t.numDimensions()); 338 assertEquals(0, t.shape().length); 339 assertArrayEquals(bytes, t.bytesValue()); 340 } 341 } 342 343 @Test 344 public void nDimensional() { 345 double[] vector = {1.414, 2.718, 3.1415}; 346 try (Tensor<Double> t = Tensors.create(vector)) { 347 assertEquals(DataType.DOUBLE, t.dataType()); 348 assertEquals(1, t.numDimensions()); 349 assertArrayEquals(new long[] {3}, t.shape()); 350 351 double[] got = new double[3]; 352 assertArrayEquals(vector, t.copyTo(got), EPSILON); 353 } 354 355 int[][] matrix = {{1, 2, 3}, {4, 5, 6}}; 356 try (Tensor<Integer> t = Tensors.create(matrix)) { 357 assertEquals(DataType.INT32, t.dataType()); 358 assertEquals(2, t.numDimensions()); 359 assertArrayEquals(new long[] {2, 3}, t.shape()); 360 361 int[][] got = new int[2][3]; 362 assertArrayEquals(matrix, t.copyTo(got)); 363 } 364 365 long[][][] threeD = { 366 {{1}, {3}, {5}, {7}, {9}}, {{2}, {4}, {6}, {8}, {0}}, 367 }; 368 try (Tensor<Long> t = Tensors.create(threeD)) { 369 assertEquals(DataType.INT64, t.dataType()); 370 assertEquals(3, t.numDimensions()); 371 assertArrayEquals(new long[] {2, 5, 1}, t.shape()); 372 373 long[][][] got = new long[2][5][1]; 374 assertArrayEquals(threeD, t.copyTo(got)); 375 } 376 377 boolean[][][][] fourD = { 378 {{{false, false, false, true}, {false, false, true, false}}}, 379 {{{false, false, true, true}, {false, true, false, false}}}, 380 {{{false, true, false, true}, {false, true, true, false}}}, 381 }; 382 try (Tensor<Boolean> t = Tensors.create(fourD)) { 383 assertEquals(DataType.BOOL, t.dataType()); 384 assertEquals(4, t.numDimensions()); 385 assertArrayEquals(new long[] {3, 1, 2, 4}, t.shape()); 386 387 boolean[][][][] got = new boolean[3][1][2][4]; 388 assertArrayEquals(fourD, t.copyTo(got)); 389 } 390 } 391 392 @Test 393 public void testNDimensionalStringTensor() { 394 byte[][][] matrix = new byte[4][3][]; 395 for (int i = 0; i < 4; ++i) { 396 for (int j = 0; j < 3; ++j) { 397 matrix[i][j] = String.format("(%d, %d) = %d", i, j, i << j).getBytes(UTF_8); 398 } 399 } 400 try (Tensor<String> t = Tensors.create(matrix)) { 401 assertEquals(DataType.STRING, t.dataType()); 402 assertEquals(2, t.numDimensions()); 403 assertArrayEquals(new long[] {4, 3}, t.shape()); 404 405 byte[][][] got = t.copyTo(new byte[4][3][]); 406 assertEquals(4, got.length); 407 for (int i = 0; i < 4; ++i) { 408 assertEquals(String.format("%d", i), 3, got[i].length); 409 for (int j = 0; j < 3; ++j) { 410 assertArrayEquals(String.format("(%d, %d)", i, j), matrix[i][j], got[i][j]); 411 } 412 } 413 } 414 } 415 416 @Test 417 public void testUInt8Tensor() { 418 byte[] vector = new byte[] {1, 2, 3, 4}; 419 try (Tensor<UInt8> t = Tensor.create(vector, UInt8.class)) { 420 assertEquals(DataType.UINT8, t.dataType()); 421 assertEquals(1, t.numDimensions()); 422 assertArrayEquals(new long[] {4}, t.shape()); 423 424 byte[] got = t.copyTo(new byte[4]); 425 assertArrayEquals(vector, got); 426 } 427 } 428 429 @Test 430 public void testCreateFromArrayOfBoxed() { 431 Integer[] vector = new Integer[] {1, 2, 3, 4}; 432 try (Tensor<Integer> t = Tensor.create(vector, Integer.class)) { 433 fail("Tensor.create() should fail because it was given an array of boxed values"); 434 } catch (IllegalArgumentException e) { 435 // The expected exception 436 } 437 } 438 439 @Test 440 public void failCreateOnMismatchedDimensions() { 441 int[][][] invalid = new int[3][1][]; 442 for (int x = 0; x < invalid.length; ++x) { 443 for (int y = 0; y < invalid[x].length; ++y) { 444 invalid[x][y] = new int[x + y + 1]; 445 } 446 } 447 try (Tensor<?> t = Tensor.create(invalid)) { 448 fail("Tensor.create() should fail because of differing sizes in the 3rd dimension"); 449 } catch (IllegalArgumentException e) { 450 // The expected exception. 451 } 452 } 453 454 @Test 455 public void failCopyToOnIncompatibleDestination() { 456 try (final Tensor<Integer> matrix = Tensors.create(new int[][] {{1, 2}, {3, 4}})) { 457 try { 458 matrix.copyTo(new int[2]); 459 fail("should have failed on dimension mismatch"); 460 } catch (IllegalArgumentException e) { 461 // The expected exception. 462 } 463 464 try { 465 matrix.copyTo(new float[2][2]); 466 fail("should have failed on DataType mismatch"); 467 } catch (IllegalArgumentException e) { 468 // The expected exception. 469 } 470 471 try { 472 matrix.copyTo(new int[2][3]); 473 fail("should have failed on shape mismatch"); 474 } catch (IllegalArgumentException e) { 475 // The expected exception. 476 } 477 } 478 } 479 480 @Test 481 public void failCopyToOnScalar() { 482 try (final Tensor<Integer> scalar = Tensors.create(3)) { 483 try { 484 scalar.copyTo(3); 485 fail("copyTo should fail on scalar tensors, suggesting use of primitive accessors instead"); 486 } catch (IllegalArgumentException e) { 487 // The expected exception. 488 } 489 } 490 } 491 492 @Test 493 public void failOnArbitraryObject() { 494 try (Tensor<?> t = Tensor.create(new Object())) { 495 fail("should fail on creating a Tensor with a Java object that has no equivalent DataType"); 496 } catch (IllegalArgumentException e) { 497 // The expected exception. 498 } 499 } 500 501 @Test 502 public void failOnZeroDimension() { 503 try (Tensor<Integer> t = Tensors.create(new int[3][0][1])) { 504 fail("should fail on creating a Tensor where one of the dimensions is 0"); 505 } catch (IllegalArgumentException e) { 506 // The expected exception. 507 } 508 } 509 510 @Test 511 public void useAfterClose() { 512 int n = 4; 513 Tensor<?> t = Tensor.create(n); 514 t.close(); 515 try { 516 t.intValue(); 517 } catch (NullPointerException e) { 518 // The expected exception. 519 } 520 } 521 522 @Test 523 public void fromHandle() { 524 // fromHandle is a package-visible method intended for use when the C TF_Tensor object has been 525 // created independently of the Java code. In practice, two Tensor instances MUST NOT have the 526 // same native handle. 527 // 528 // An exception is made for this test, where the pitfalls of this is avoided by not calling 529 // close() on both Tensors. 530 final float[][] matrix = {{1, 2, 3}, {4, 5, 6}}; 531 try (Tensor<Float> src = Tensors.create(matrix)) { 532 Tensor<Float> cpy = Tensor.fromHandle(src.getNativeHandle()).expect(Float.class); 533 assertEquals(src.dataType(), cpy.dataType()); 534 assertEquals(src.numDimensions(), cpy.numDimensions()); 535 assertArrayEquals(src.shape(), cpy.shape()); 536 assertArrayEquals(matrix, cpy.copyTo(new float[2][3])); 537 } 538 } 539} 540