// tensor_slice_test.cc revision c8eaac926c929e07ac8db69f67803a2223ff2d93
1/* Copyright 2015 Google Inc. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/core/framework/tensor_slice.h" 17 18#include "tensorflow/core/lib/core/status_test_util.h" 19#include "tensorflow/core/platform/logging.h" 20#include "tensorflow/core/platform/protobuf.h" 21#include "tensorflow/core/platform/test.h" 22 23namespace tensorflow { 24namespace { 25 26// Basic tests 27TEST(TensorSliceTest, Basic) { 28 { 29 // Repeatedly setting FullSlice should work. 30 TensorSlice s(3); 31 EXPECT_EQ("-:-:-", s.DebugString()); 32 33 s.SetFullSlice(4); 34 EXPECT_EQ("-:-:-:-", s.DebugString()); 35 } 36} 37 38// Testing for serialization and parsing for the string format of slices. 39TEST(TensorSliceTest, Serialization) { 40 // Serialization 41 { 42 TensorSlice s({{0, -1}, {0, 10}, {14, 1}, {0, -1}}); 43 EXPECT_EQ("-:0,10:14,1:-", s.DebugString()); 44 } 45 46 { 47 TensorSliceProto proto; 48 // Define ptxt outside ASSERT_TRUE call to work around bug in some 49 // versions of gcc that breaks when you use raw string literals 50 // inside macro expansions. 
51 // See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55971 52 const char* ptxt = R"PROTO( 53 extent { } 54 extent { start: 0 length: 10 } 55 extent { start: 14 length: 1 } 56 extent { } 57 )PROTO"; 58 ASSERT_TRUE(protobuf::TextFormat::ParseFromString(ptxt, &proto)); 59 TensorSlice s(proto); 60 EXPECT_EQ("-:0,10:14,1:-", s.DebugString()); 61 } 62 63 // Parsing 64 { 65 TensorSlice s = TensorSlice::ParseOrDie("-:-:1,3:4,5"); 66 TensorSliceProto proto; 67 s.AsProto(&proto); 68 EXPECT_EQ( 69 "extent { } " 70 "extent { } " 71 "extent { start: 1 length: 3 } " 72 "extent { start: 4 length: 5 }", 73 proto.ShortDebugString()); 74 } 75 76 // Failed parsing 77 { 78 TensorSlice slice; 79 Status s = TensorSlice::Parse("-:-:1,3:4:5", &slice); 80 EXPECT_EQ( 81 "Invalid argument: " 82 "Expected a pair of numbers or '-' but got '4': " 83 "string = -:-:1,3:4:5", 84 s.ToString()); 85 } 86 { 87 TensorSlice slice; 88 Status s = TensorSlice::Parse("-:-1,3", &slice); 89 EXPECT_EQ( 90 "Invalid argument: " 91 "Expected non-negative start and positive length but got " 92 "start = -1, length = 3: string = -:-1,3", 93 s.ToString()); 94 } 95} 96 97// Testing the slice intersection 98TEST(TensorSliceTest, Intersection) { 99 // "EVERYTHING" intersects with everything 100 { 101 TensorSlice a = TensorSlice::ParseOrDie("-:-"); 102 TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4"); 103 TensorSlice c; 104 EXPECT_TRUE(a.Intersect(b, &c)); 105 EXPECT_EQ("1,2:3,4", c.DebugString()); 106 } 107 108 { 109 TensorSlice a = TensorSlice::ParseOrDie("-:-"); 110 TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4"); 111 TensorSlice c; 112 EXPECT_TRUE(b.Intersect(a, &c)); 113 EXPECT_EQ("1,2:3,4", c.DebugString()); 114 } 115 116 // Overlap at all dimensions 117 { 118 TensorSlice a = TensorSlice::ParseOrDie("1,5:2,6:3,7:5,10"); 119 TensorSlice b = TensorSlice::ParseOrDie("1,2:3,4:9,10:12,1"); 120 TensorSlice c; 121 EXPECT_TRUE(a.Intersect(b, &c)); 122 EXPECT_EQ("1,2:3,4:9,1:12,1", c.DebugString()); 123 } 124 125 // 
A mixture of everything and non-trivial slices 126 { 127 TensorSlice a = TensorSlice::ParseOrDie("-:1,1"); 128 TensorSlice b = TensorSlice::ParseOrDie("-:0,2"); 129 TensorSlice c; 130 EXPECT_TRUE(a.Intersect(b, &c)); 131 EXPECT_EQ("-:1,1", c.DebugString()); 132 } 133 134 // No overlap on dimension 3: "3,1" and "4,5" don't intersect 135 { 136 TensorSlice a = TensorSlice::ParseOrDie("1,2:3,1:5,6"); 137 TensorSlice b = TensorSlice::ParseOrDie("1,3:4,5:1,6"); 138 TensorSlice c; 139 EXPECT_FALSE(a.Intersect(b, &c)); 140 EXPECT_EQ("", c.DebugString()); 141 } 142 // No intersection when there are different dimensions 143 { 144 TensorSlice a = TensorSlice::ParseOrDie("1,2:3,1:-"); 145 TensorSlice b = TensorSlice::ParseOrDie("-:-"); 146 TensorSlice c; 147 EXPECT_FALSE(a.Intersect(b, &c)); 148 EXPECT_EQ("", c.DebugString()); 149 } 150} 151 152// Testing applying a slice to a tensor shape 153TEST(TensorSliceTest, SliceTensorShape) { 154 // A proper application 155 { 156 TensorSlice a = TensorSlice::ParseOrDie("1,1:-:4,1:2,6"); 157 TensorShape x({2, 4, 5, 8}); 158 TensorShape y; 159 EXPECT_OK(a.SliceTensorShape(x, &y)); 160 EXPECT_EQ( 161 "dim { size: 1 } " 162 "dim { size: 4 } " 163 "dim { size: 1 } " 164 "dim { size: 6 }", 165 y.DebugString()); 166 } 167 168 // An invalid application -- dimension 2 is out of bound 169 { 170 TensorSlice a = TensorSlice::ParseOrDie("1,1:1,4:-:-"); 171 TensorShape x({2, 4, 5, 8}); 172 TensorShape y; 173 EXPECT_EQ( 174 "Internal: " 175 "Extent in dimension 1 out of bounds: " 176 "shape = dim { size: 2 } " 177 "dim { size: 4 } " 178 "dim { size: 5 } " 179 "dim { size: 8 }, " 180 "slice = 1,1:1,4:-:-", 181 a.SliceTensorShape(x, &y).ToString()); 182 EXPECT_EQ("", y.DebugString()); 183 } 184} 185 186// Testing the computation of relative slices. 
187TEST(TensorSliceTest, ComputeRelative) { 188 // Easy case: base is "everything" 189 { 190 TensorSlice base = TensorSlice::ParseOrDie("-:-:-:-"); 191 TensorSlice sub = TensorSlice::ParseOrDie("-:1,2:-:3,4"); 192 TensorSlice relative; 193 base.ComputeRelative(sub, &relative); 194 EXPECT_EQ("-:1,2:-:3,4", relative.DebugString()); 195 } 196 197 // A slightly more complicated case 198 { 199 TensorSlice base = TensorSlice::ParseOrDie("1,2:3,4:-:5,1"); 200 TensorSlice sub = TensorSlice::ParseOrDie("1,1:4,2:3,3:5,1"); 201 TensorSlice relative; 202 base.ComputeRelative(sub, &relative); 203 EXPECT_EQ("0,1:1,2:3,3:0,1", relative.DebugString()); 204 } 205} 206 207TEST(TensorSliceTest, ExtentLength) { 208 TensorSliceProto proto; 209 // Define ptxt outside ASSERT_TRUE call to work around bug in some 210 // versions of gcc that breaks when you use raw string literals 211 // inside macro expansions. 212 // See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=55971 213 const char* ptxt = R"PROTO( 214 extent { } 215 extent { start: 0 length: 10 } 216 extent { start: 14 length: 1 } 217 extent { } 218 )PROTO"; 219 ASSERT_TRUE(protobuf::TextFormat::ParseFromString(ptxt, &proto)); 220 EXPECT_FALSE(TensorSlice::HasExtentLength(proto.extent(0))); 221 EXPECT_TRUE(TensorSlice::HasExtentLength(proto.extent(1))); 222 EXPECT_TRUE(TensorSlice::HasExtentLength(proto.extent(2))); 223 EXPECT_FALSE(TensorSlice::HasExtentLength(proto.extent(3))); 224 EXPECT_EQ(-1, TensorSlice::GetExtentLength(proto.extent(0))); 225 EXPECT_EQ(10, TensorSlice::GetExtentLength(proto.extent(1))); 226 EXPECT_EQ(1, TensorSlice::GetExtentLength(proto.extent(2))); 227 EXPECT_EQ(-1, TensorSlice::GetExtentLength(proto.extent(3))); 228} 229 230TEST(TensorSliceTest, Deserialization) { 231 // Serialization of 232 // extent { length: 5 } 233 // extent { start: 0 length: 10 } 234 // extent { start: 14 length: 1 } 235 // extent { start: 1 } 236 // extent { } 237 // in proto2 and proto3: 238 const char pb2[] = 239 
"\x0A\x02\x10\x05\x0A\x04\x08\x00" 240 "\x10\x0A\x0A\x04\x08\x0E\x10\x01\x0A\x02\x08\x01\x0A\x00"; 241 const char pb3[] = 242 "\x0A\x02\x10\x05\x0A\x02" 243 "\x10\x0A\x0A\x04\x08\x0E\x10\x01\x0A\x02\x08\x01\x0A\x00"; 244 // (The difference is that in the proto3 version, "start: 0" isn't included 245 // since 0 is start's default value.) 246 247 TensorSliceProto proto2; 248 ASSERT_TRUE(proto2.ParseFromArray(pb2, sizeof(pb2) - 1)); 249 TensorSlice ts2(proto2); 250 251 TensorSliceProto proto3; 252 ASSERT_TRUE(proto3.ParseFromArray(pb3, sizeof(pb3) - 1)); 253 TensorSlice ts3(proto3); 254 255 // Both serializations should be interpreted the same. 256 EXPECT_EQ("0,5:0,10:14,1:-:-", ts2.DebugString()); 257 EXPECT_EQ("0,5:0,10:14,1:-:-", ts3.DebugString()); 258} 259 260} // namespace 261} // namespace tensorflow 262