tensor_slice_test.cc revision c8eaac926c929e07ac8db69f67803a2223ff2d93
1/* Copyright 2015 Google Inc. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/tensor_slice.h"
17
18#include "tensorflow/core/lib/core/status_test_util.h"
19#include "tensorflow/core/platform/logging.h"
20#include "tensorflow/core/platform/protobuf.h"
21#include "tensorflow/core/platform/test.h"
22
23namespace tensorflow {
24namespace {
25
26// Basic tests
27TEST(TensorSliceTest, Basic) {
28  {
29    // Repeatedly setting FullSlice should work.
30    TensorSlice s(3);
31    EXPECT_EQ("-:-:-", s.DebugString());
32
33    s.SetFullSlice(4);
34    EXPECT_EQ("-:-:-:-", s.DebugString());
35  }
36}
37
38// Testing for serialization and parsing for the string format of slices.
39TEST(TensorSliceTest, Serialization) {
40  // Serialization
41  {
42    TensorSlice s({{0, -1}, {0, 10}, {14, 1}, {0, -1}});
43    EXPECT_EQ("-:0,10:14,1:-", s.DebugString());
44  }
45
46  {
47    TensorSliceProto proto;
48    // Define ptxt outside ASSERT_TRUE call to work around bug in some
49    // versions of gcc that breaks when you use raw string literals
50    // inside macro expansions.
51    //   See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55971
52    const char* ptxt = R"PROTO(
53      extent { }
54      extent { start: 0 length: 10 }
55      extent { start: 14 length: 1 }
56      extent { }
57    )PROTO";
58    ASSERT_TRUE(protobuf::TextFormat::ParseFromString(ptxt, &proto));
59    TensorSlice s(proto);
60    EXPECT_EQ("-:0,10:14,1:-", s.DebugString());
61  }
62
63  // Parsing
64  {
65    TensorSlice s = TensorSlice::ParseOrDie("-:-:1,3:4,5");
66    TensorSliceProto proto;
67    s.AsProto(&proto);
68    EXPECT_EQ(
69        "extent { } "
70        "extent { } "
71        "extent { start: 1 length: 3 } "
72        "extent { start: 4 length: 5 }",
73        proto.ShortDebugString());
74  }
75
76  // Failed parsing
77  {
78    TensorSlice slice;
79    Status s = TensorSlice::Parse("-:-:1,3:4:5", &slice);
80    EXPECT_EQ(
81        "Invalid argument: "
82        "Expected a pair of numbers or '-' but got '4': "
83        "string = -:-:1,3:4:5",
84        s.ToString());
85  }
86  {
87    TensorSlice slice;
88    Status s = TensorSlice::Parse("-:-1,3", &slice);
89    EXPECT_EQ(
90        "Invalid argument: "
91        "Expected non-negative start and positive length but got "
92        "start = -1, length = 3: string = -:-1,3",
93        s.ToString());
94  }
95}
96
97// Testing the slice intersection
98TEST(TensorSliceTest, Intersection) {
99  // "EVERYTHING" intersects with everything
100  {
101    TensorSlice a = TensorSlice::ParseOrDie("-:-");
102    TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4");
103    TensorSlice c;
104    EXPECT_TRUE(a.Intersect(b, &c));
105    EXPECT_EQ("1,2:3,4", c.DebugString());
106  }
107
108  {
109    TensorSlice a = TensorSlice::ParseOrDie("-:-");
110    TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4");
111    TensorSlice c;
112    EXPECT_TRUE(b.Intersect(a, &c));
113    EXPECT_EQ("1,2:3,4", c.DebugString());
114  }
115
116  // Overlap at all dimensions
117  {
118    TensorSlice a = TensorSlice::ParseOrDie("1,5:2,6:3,7:5,10");
119    TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4:9,10:12,1");
120    TensorSlice c;
121    EXPECT_TRUE(a.Intersect(b, &c));
122    EXPECT_EQ("1,2:3,4:9,1:12,1", c.DebugString());
123  }
124
125  // A mixture of everything and non-trivial slices
126  {
127    TensorSlice a = TensorSlice::ParseOrDie("-:1,1");
128    TensorSlice b = TensorSlice::ParseOrDie("-:0,2");
129    TensorSlice c;
130    EXPECT_TRUE(a.Intersect(b, &c));
131    EXPECT_EQ("-:1,1", c.DebugString());
132  }
133
134  // No overlap on dimension 3: "3,1" and "4,5" don't intersect
135  {
136    TensorSlice a = TensorSlice::ParseOrDie("1,2:3,1:5,6");
137    TensorSlice b = TensorSlice::ParseOrDie("1,3:4,5:1,6");
138    TensorSlice c;
139    EXPECT_FALSE(a.Intersect(b, &c));
140    EXPECT_EQ("", c.DebugString());
141  }
142  // No intersection when there are different dimensions
143  {
144    TensorSlice a = TensorSlice::ParseOrDie("1,2:3,1:-");
145    TensorSlice b = TensorSlice::ParseOrDie("-:-");
146    TensorSlice c;
147    EXPECT_FALSE(a.Intersect(b, &c));
148    EXPECT_EQ("", c.DebugString());
149  }
150}
151
152// Testing applying a slice to a tensor shape
153TEST(TensorSliceTest, SliceTensorShape) {
154  // A proper application
155  {
156    TensorSlice a = TensorSlice::ParseOrDie("1,1:-:4,1:2,6");
157    TensorShape x({2, 4, 5, 8});
158    TensorShape y;
159    EXPECT_OK(a.SliceTensorShape(x, &y));
160    EXPECT_EQ(
161        "dim { size: 1 } "
162        "dim { size: 4 } "
163        "dim { size: 1 } "
164        "dim { size: 6 }",
165        y.DebugString());
166  }
167
168  // An invalid application -- dimension 2 is out of bound
169  {
170    TensorSlice a = TensorSlice::ParseOrDie("1,1:1,4:-:-");
171    TensorShape x({2, 4, 5, 8});
172    TensorShape y;
173    EXPECT_EQ(
174        "Internal: "
175        "Extent in dimension 1 out of bounds: "
176        "shape = dim { size: 2 } "
177        "dim { size: 4 } "
178        "dim { size: 5 } "
179        "dim { size: 8 }, "
180        "slice = 1,1:1,4:-:-",
181        a.SliceTensorShape(x, &y).ToString());
182    EXPECT_EQ("", y.DebugString());
183  }
184}
185
186// Testing the computation of relative slices.
187TEST(TensorSliceTest, ComputeRelative) {
188  // Easy case: base is "everything"
189  {
190    TensorSlice base = TensorSlice::ParseOrDie("-:-:-:-");
191    TensorSlice sub = TensorSlice::ParseOrDie("-:1,2:-:3,4");
192    TensorSlice relative;
193    base.ComputeRelative(sub, &relative);
194    EXPECT_EQ("-:1,2:-:3,4", relative.DebugString());
195  }
196
197  // A slightly more complicated case
198  {
199    TensorSlice base = TensorSlice::ParseOrDie("1,2:3,4:-:5,1");
200    TensorSlice sub = TensorSlice::ParseOrDie("1,1:4,2:3,3:5,1");
201    TensorSlice relative;
202    base.ComputeRelative(sub, &relative);
203    EXPECT_EQ("0,1:1,2:3,3:0,1", relative.DebugString());
204  }
205}
206
207TEST(TensorSliceTest, ExtentLength) {
208  TensorSliceProto proto;
209  // Define ptxt outside ASSERT_TRUE call to work around bug in some
210  // versions of gcc that breaks when you use raw string literals
211  // inside macro expansions.
212  //   See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55971
213  const char* ptxt = R"PROTO(
214    extent { }
215    extent { start: 0 length: 10 }
216    extent { start: 14 length: 1 }
217    extent { }
218  )PROTO";
219  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(ptxt, &proto));
220  EXPECT_FALSE(TensorSlice::HasExtentLength(proto.extent(0)));
221  EXPECT_TRUE(TensorSlice::HasExtentLength(proto.extent(1)));
222  EXPECT_TRUE(TensorSlice::HasExtentLength(proto.extent(2)));
223  EXPECT_FALSE(TensorSlice::HasExtentLength(proto.extent(3)));
224  EXPECT_EQ(-1, TensorSlice::GetExtentLength(proto.extent(0)));
225  EXPECT_EQ(10, TensorSlice::GetExtentLength(proto.extent(1)));
226  EXPECT_EQ(1, TensorSlice::GetExtentLength(proto.extent(2)));
227  EXPECT_EQ(-1, TensorSlice::GetExtentLength(proto.extent(3)));
228}
229
230TEST(TensorSliceTest, Deserialization) {
231  // Serialization of
232  //     extent { length: 5 }
233  //     extent { start: 0 length: 10 }
234  //     extent { start: 14 length: 1 }
235  //     extent { start: 1 }
236  //     extent { }
237  // in proto2 and proto3:
238  const char pb2[] =
239      "\x0A\x02\x10\x05\x0A\x04\x08\x00"
240      "\x10\x0A\x0A\x04\x08\x0E\x10\x01\x0A\x02\x08\x01\x0A\x00";
241  const char pb3[] =
242      "\x0A\x02\x10\x05\x0A\x02"
243      "\x10\x0A\x0A\x04\x08\x0E\x10\x01\x0A\x02\x08\x01\x0A\x00";
244  // (The difference is that in the proto3 version, "start: 0" isn't included
245  // since 0 is start's default value.)
246
247  TensorSliceProto proto2;
248  ASSERT_TRUE(proto2.ParseFromArray(pb2, sizeof(pb2) - 1));
249  TensorSlice ts2(proto2);
250
251  TensorSliceProto proto3;
252  ASSERT_TRUE(proto3.ParseFromArray(pb3, sizeof(pb3) - 1));
253  TensorSlice ts3(proto3);
254
255  // Both serializations should be interpreted the same.
256  EXPECT_EQ("0,5:0,10:14,1:-:-", ts2.DebugString());
257  EXPECT_EQ("0,5:0,10:14,1:-:-", ts3.DebugString());
258}
259
260}  // namespace
261}  // namespace tensorflow
262