1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16package org.tensorflow; 17 18import java.util.Arrays; 19 20/** The possibly partially known shape of a tensor produced by an operation. */ 21public final class Shape { 22 23 /** Create a Shape representing an unknown number of dimensions. */ 24 public static Shape unknown() { 25 return new Shape(null); 26 } 27 28 /** Create a Shape representing a scalar value. */ 29 public static Shape scalar() { 30 return new Shape(new long[0]); 31 } 32 33 /** 34 * Create a Shape representing an N-dimensional value. 35 * 36 * <p>Creates a Shape representing an N-dimensional value (N being at least 1), with the provided 37 * size for each dimension. A -1 indicates that the size of the corresponding dimension is 38 * unknown. For example: 39 * 40 * <pre>{@code 41 * // A 2-element vector. 42 * Shape vector = Shape.create(2); 43 * 44 * // A 2x3 matrix. 45 * Shape matrix = Shape.create(2, 3); 46 * 47 * // A matrix with 4 columns but an unknown number of rows. 48 * // This is typically used to indicate the shape of tensors that represent 49 * // a variable-sized batch of values. The Shape below might represent a 50 * // variable-sized batch of 4-element vectors. 51 * Shape batch = Shape.create(-1, 4); 52 * }</pre> 53 */ 54 public static Shape make(long firstDimensionSize, long... otherDimensionSizes) { 55 long[] shape = new long[otherDimensionSizes.length + 1]; 56 shape[0] = firstDimensionSize; 57 System.arraycopy(otherDimensionSizes, 0, shape, 1, otherDimensionSizes.length); 58 return new Shape(shape); 59 } 60 61 /** 62 * Number of dimensions represented by this shape. 63 * 64 * @return -1 if the number of dimensions is unknown, 0 if the shape represents a scalar, 1 for a 65 * vector, 2 for a matrix etc. 66 */ 67 public int numDimensions() { 68 return shape == null ? -1 : shape.length; 69 } 70 71 /** 72 * The size of the i-th dimension. 73 * 74 * @return The size of the requested dimension or -1 if it is unknown. 75 */ 76 public long size(int i) { 77 return shape[i]; 78 } 79 80 @Override 81 public int hashCode() { 82 return Arrays.hashCode(shape); 83 } 84 85 @Override 86 public boolean equals(Object obj) { 87 if (this == obj) { 88 return true; 89 } 90 91 if (obj instanceof Shape && Arrays.equals(this.shape, ((Shape) obj).shape)) { 92 return !hasUnknownDimension(); 93 } 94 95 return super.equals(obj); 96 } 97 98 /** Succinct description of the shape meant for debugging. */ 99 @Override 100 public String toString() { 101 if (shape == null) { 102 return "<unknown>"; 103 } 104 return Arrays.toString(shape).replace("-1", "?"); 105 } 106 107 // Package-private constructor. 108 Shape(long[] shape) { 109 this.shape = shape; 110 } 111 112 // Package-private accessor. 113 // The idea is that the public API does not expose the internal array. 114 long[] asArray() { 115 return shape; 116 } 117 118 private long[] shape; 119 120 private boolean hasUnknownDimension() { 121 if (shape == null) { 122 return true; 123 } 124 125 for (long dimension : shape) { 126 if (dimension == -1) { 127 return true; 128 } 129 } 130 131 return false; 132 } 133} 134