1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16package org.tensorflow;
17
18import java.util.Arrays;
19
/**
 * The possibly partially known shape of a tensor produced by an operation.
 *
 * <p>A shape either has an unknown number of dimensions (see {@link #unknown()}), or a known number
 * of dimensions where each individual dimension size may still be unknown, represented as -1.
 *
 * <p>Equality is deliberately conservative: two distinct {@code Shape} instances compare equal only
 * when both have a known number of dimensions, contain no unknown (-1) dimension, and have identical
 * dimension sizes. A shape with unknown parts is only equal to itself.
 */
public final class Shape {

  /** Dimension sizes in order, or {@code null} when the number of dimensions is unknown. */
  private final long[] shape;

  /** Create a Shape representing an unknown number of dimensions. */
  public static Shape unknown() {
    return new Shape(null);
  }

  /** Create a Shape representing a scalar value. */
  public static Shape scalar() {
    return new Shape(new long[0]);
  }

  /**
   * Create a Shape representing an N-dimensional value.
   *
   * <p>Creates a Shape representing an N-dimensional value (N being at least 1), with the provided
   * size for each dimension. A -1 indicates that the size of the corresponding dimension is
   * unknown. For example:
   *
   * <pre>{@code
   * // A 2-element vector.
   * Shape vector = Shape.make(2);
   *
   * // A 2x3 matrix.
   * Shape matrix = Shape.make(2, 3);
   *
   * // A matrix with 4 columns but an unknown number of rows.
   * // This is typically used to indicate the shape of tensors that represent
   * // a variable-sized batch of values. The Shape below might represent a
   * // variable-sized batch of 4-element vectors.
   * Shape batch = Shape.make(-1, 4);
   * }</pre>
   *
   * @param firstDimensionSize size of the first dimension, or -1 if unknown
   * @param otherDimensionSizes sizes of the remaining dimensions in order; -1 marks an unknown size
   * @return a new Shape with {@code otherDimensionSizes.length + 1} dimensions
   */
  public static Shape make(long firstDimensionSize, long... otherDimensionSizes) {
    long[] dims = new long[otherDimensionSizes.length + 1];
    dims[0] = firstDimensionSize;
    System.arraycopy(otherDimensionSizes, 0, dims, 1, otherDimensionSizes.length);
    return new Shape(dims);
  }

  /**
   * Number of dimensions represented by this shape.
   *
   * @return -1 if the number of dimensions is unknown, 0 if the shape represents a scalar, 1 for a
   *     vector, 2 for a matrix etc.
   */
  public int numDimensions() {
    return shape == null ? -1 : shape.length;
  }

  /**
   * The size of the i-th dimension.
   *
   * @param i zero-based index of the dimension
   * @return The size of the requested dimension or -1 if it is unknown.
   * @throws NullPointerException if the number of dimensions is unknown ({@link #numDimensions()}
   *     returns -1)
   * @throws ArrayIndexOutOfBoundsException if {@code i} is negative or not less than {@link
   *     #numDimensions()}
   */
  public long size(int i) {
    return shape[i];
  }

  @Override
  public int hashCode() {
    // Arrays.hashCode(null) is 0, so unknown shapes hash consistently.
    return Arrays.hashCode(shape);
  }

  /**
   * Compares shapes. Distinct instances are equal only if both are fully known and have the same
   * dimension sizes; a shape with an unknown rank or an unknown dimension equals only itself,
   * because it cannot be proven to describe the same tensor shape as another instance.
   */
  @Override
  public boolean equals(Object obj) {
    if (this == obj) {
      return true;
    }
    if (!(obj instanceof Shape)) {
      return false;
    }
    Shape other = (Shape) obj;
    return Arrays.equals(this.shape, other.shape) && !hasUnknownDimension();
  }

  /**
   * Succinct description of the shape meant for debugging.
   *
   * @return {@code "<unknown>"} for an unknown number of dimensions, otherwise the dimension sizes
   *     in brackets separated by {@code ", "}, with unknown (-1) sizes printed as {@code ?}
   */
  @Override
  public String toString() {
    if (shape == null) {
      return "<unknown>";
    }
    // Format element-by-element: a textual replace of "-1" would also corrupt values such as -10.
    StringBuilder out = new StringBuilder("[");
    for (int i = 0; i < shape.length; i++) {
      if (i > 0) {
        out.append(", ");
      }
      if (shape[i] == -1) {
        out.append('?');
      } else {
        out.append(shape[i]);
      }
    }
    return out.append(']').toString();
  }

  // Package-private constructor. The array is stored without copying, so callers should not
  // modify it after passing it in; null means "unknown number of dimensions".
  Shape(long[] shape) {
    this.shape = shape;
  }

  // Package-private accessor.
  // The idea is that the public API does not expose the internal array.
  long[] asArray() {
    return shape;
  }

  /** Returns true if the rank is unknown or any dimension size is -1. */
  private boolean hasUnknownDimension() {
    if (shape == null) {
      return true;
    }
    for (long dimension : shape) {
      if (dimension == -1) {
        return true;
      }
    }
    return false;
  }
}
134