1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" 17 18#include "llvm/Support/raw_ostream.h" 19#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" 20#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" 21 22namespace xla { 23namespace cpu { 24VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type, 25 int64 vector_size, 26 llvm::IRBuilder<>* ir_builder, 27 std::string name) 28 : vector_size_(vector_size), 29 primitive_type_(primitive_type), 30 ir_builder_(ir_builder), 31 name_(std::move(name)) { 32 scalar_type_ = llvm_ir::PrimitiveTypeToIrType( 33 primitive_type, ir_builder_->GetInsertBlock()->getModule()); 34 scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_); 35 vector_type_ = llvm::VectorType::get(scalar_type_, vector_size); 36 vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_); 37} 38 39static string TypeToString(llvm::Type* type) { 40 std::string o; 41 llvm::raw_string_ostream ostream(o); 42 type->print(ostream); 43 return ostream.str(); 44} 45 46void VectorSupportLibrary::AssertCorrectTypes( 47 std::initializer_list<llvm::Value*> values) { 48 for (llvm::Value* v : values) { 49 llvm::Type* type = v->getType(); 50 if (type != scalar_type() && type != vector_type()) { 51 LOG(FATAL) << "Expected either " << TypeToString(scalar_type()) << " or " 52 << TypeToString(vector_type()) << " but got " 53 << TypeToString(type); 54 } 55 } 56} 57 58llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) { 59 AssertCorrectTypes({lhs, rhs}); 60 return MulInternal(lhs, rhs); 61} 62 63llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs, 64 llvm::Value* rhs) { 65 if (scalar_type_->isFloatingPointTy()) { 66 return ir_builder()->CreateFMul(lhs, rhs, name()); 67 } else { 68 return ir_builder()->CreateMul(lhs, rhs, name()); 69 } 70} 71 72llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) { 73 AssertCorrectTypes({lhs, rhs}); 74 return AddInternal(lhs, rhs); 75} 76 77llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) { 78 AssertCorrectTypes({lhs, rhs}); 79 return ir_builder()->CreateFSub(lhs, rhs); 80} 81 82llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) { 83 AssertCorrectTypes({lhs, rhs}); 84 if (scalar_type_->isFloatingPointTy()) { 85 return llvm_ir::EmitFloatMax(lhs, rhs, ir_builder_); 86 } else { 87 LOG(FATAL) << "Max for integers is unimplemented"; 88 } 89} 90 91llvm::Value* VectorSupportLibrary::Floor(llvm::Value* a) { 92 AssertCorrectTypes({a}); 93 return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {a}, 94 {a->getType()}, ir_builder()); 95} 96 97llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) { 98 AssertCorrectTypes({lhs, rhs}); 99 if (scalar_type_->isFloatingPointTy()) { 100 return ir_builder()->CreateFDiv(lhs, rhs, name()); 101 } else { 102 LOG(FATAL) << "Division for integers is unimplemented"; 103 } 104} 105 106llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a, 107 const llvm::APFloat& low, 108 const llvm::APFloat& high) { 109 AssertCorrectTypes({a}); 110 llvm::Type* type = a->getType(); 111 CHECK(low.compare(high) == llvm::APFloat::cmpLessThan); 112 CHECK(scalar_type_->isFloatingPointTy()); 113 return llvm_ir::EmitFloatMin( 114 llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), ir_builder_), 115 GetConstantFloat(type, high), ir_builder_); 116} 117 118llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs, 119 llvm::Value* rhs) { 120 AssertCorrectTypes({lhs, rhs}); 121 return I1ToFloat(ir_builder()->CreateFCmpOEQ(lhs, rhs, name())); 122} 123 124llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs, 125 llvm::Value* rhs) { 126 AssertCorrectTypes({lhs, rhs}); 127 return I1ToFloat(ir_builder()->CreateFCmpOLT(lhs, rhs, name())); 128} 129 130llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs, 131 llvm::Value* rhs) { 132 AssertCorrectTypes({lhs, rhs}); 133 return I1ToFloat(ir_builder()->CreateFCmpULE(lhs, rhs, name())); 134} 135 136llvm::Value* VectorSupportLibrary::I1ToFloat(llvm::Value* i1) { 137 bool is_vector = llvm::isa<llvm::VectorType>(i1->getType()); 138 llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector); 139 return ir_builder()->CreateBitCast( 140 ir_builder()->CreateSExt(i1, integer_type, name()), 141 is_vector ? vector_type() : scalar_type(), name()); 142} 143 144llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) { 145 CHECK(scalar_type()->isFloatingPointTy()); 146 const llvm::DataLayout& data_layout = 147 ir_builder()->GetInsertBlock()->getModule()->getDataLayout(); 148 int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type()); 149 llvm::Type* scalar_int_type = ir_builder()->getIntNTy(float_size_bits); 150 if (vector) { 151 return llvm::VectorType::get(scalar_int_type, vector_size()); 152 } else { 153 return scalar_int_type; 154 } 155} 156 157llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) { 158 CHECK_EQ(x->getType(), scalar_type()); 159 return ir_builder()->CreateVectorSplat(vector_size(), x, name()); 160} 161 162llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs, 163 llvm::Value* rhs) { 164 AssertCorrectTypes({lhs, rhs}); 165 llvm::Type* int_type = 166 IntegerTypeForFloatSize(lhs->getType() == vector_type()); 167 return ir_builder()->CreateBitCast( 168 ir_builder()->CreateAnd( 169 ir_builder()->CreateBitCast(lhs, int_type, name()), 170 ir_builder()->CreateBitCast(rhs, int_type, name()), name()), 171 vector_type()); 172} 173 174llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) { 175 AssertCorrectTypes({lhs}); 176 llvm::Type* int_type = 177 IntegerTypeForFloatSize(lhs->getType() == vector_type()); 178 return ir_builder()->CreateBitCast( 179 ir_builder()->CreateNot( 180 ir_builder()->CreateBitCast(lhs, int_type, name()), name()), 181 vector_type()); 182} 183 184llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) { 185 AssertCorrectTypes({lhs, rhs}); 186 llvm::Type* int_type = 187 IntegerTypeForFloatSize(lhs->getType() == vector_type()); 188 return ir_builder()->CreateBitCast( 189 ir_builder()->CreateOr(ir_builder()->CreateBitCast(lhs, int_type, name()), 190 ir_builder()->CreateBitCast(rhs, int_type, name()), 191 name()), 192 vector_type(), name()); 193} 194 195llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs, 196 llvm::Value* rhs) { 197 if (scalar_type_->isFloatingPointTy()) { 198 return ir_builder()->CreateFAdd(lhs, rhs, name()); 199 } else { 200 return ir_builder()->CreateAdd(lhs, rhs, name()); 201 } 202} 203 204llvm::Value* VectorSupportLibrary::ComputeOffsetPointer( 205 llvm::Value* base_pointer, llvm::Value* offset_elements) { 206 if (base_pointer->getType() != scalar_pointer_type()) { 207 base_pointer = ir_builder()->CreateBitCast(base_pointer, 208 scalar_pointer_type(), name()); 209 } 210 return ir_builder()->CreateInBoundsGEP(base_pointer, {offset_elements}, 211 name()); 212} 213 214llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) { 215 if (pointer->getType() != vector_pointer_type()) { 216 pointer = 217 ir_builder()->CreateBitCast(pointer, vector_pointer_type(), name()); 218 } 219 return ir_builder()->CreateAlignedLoad( 220 pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name()); 221} 222 223llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) { 224 if (pointer->getType() != scalar_pointer_type()) { 225 pointer = 226 ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); 227 } 228 return ir_builder()->CreateAlignedLoad( 229 pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name()); 230} 231 232void VectorSupportLibrary::StoreVector(llvm::Value* value, 233 llvm::Value* pointer) { 234 AssertCorrectTypes({value}); 235 if (pointer->getType() != vector_pointer_type()) { 236 pointer = ir_builder()->CreateBitCast(pointer, vector_pointer_type()); 237 } 238 ir_builder()->CreateAlignedStore( 239 value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); 240} 241 242void VectorSupportLibrary::StoreScalar(llvm::Value* value, 243 llvm::Value* pointer) { 244 AssertCorrectTypes({value}); 245 if (pointer->getType() != scalar_pointer_type()) { 246 pointer = 247 ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); 248 } 249 ir_builder()->CreateAlignedStore( 250 value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); 251} 252 253llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) { 254 if (pointer->getType() != scalar_pointer_type()) { 255 pointer = 256 ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); 257 } 258 return ir_builder()->CreateVectorSplat( 259 vector_size(), ir_builder()->CreateLoad(pointer), name()); 260} 261 262llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) { 263 llvm::SmallVector<llvm::Constant*, 32> mask(vector_size(), nullptr); 264 for (unsigned i = vector_size(); i != 1; i >>= 1) { 265 // On every iteration, we shuffle half of the remaining lanes to the top 266 // half of shuffle, and add two old and the new vector. 267 268 for (unsigned j = 0; j < vector_size(); ++j) { 269 if (j < (i / 2)) { 270 mask[j] = ir_builder()->getInt32(i / 2 + j); 271 } else { 272 mask[j] = llvm::UndefValue::get(ir_builder()->getInt32Ty()); 273 } 274 } 275 276 llvm::Value* half_remaining_lanes = ir_builder()->CreateShuffleVector( 277 vector, llvm::UndefValue::get(vector_type()), 278 llvm::ConstantVector::get(mask), ""); 279 vector = Add(vector, half_remaining_lanes); 280 } 281 282 return ir_builder()->CreateExtractElement(vector, ir_builder()->getInt32(0), 283 name()); 284} 285 286llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs, 287 llvm::Value* rhs) { 288 CHECK_EQ(lhs->getType(), vector_type()); 289 CHECK_EQ(rhs->getType(), vector_type()); 290 CHECK_EQ(vector_size() % 2, 0); 291 292 llvm::SmallVector<llvm::Constant*, 32> mask_a, mask_b; 293 294 // Adding the values shuffled using mask_a and mask_b gives us the 295 // AVX-style horizontal add we want. The masks work as documented 296 // in https://llvm.org/docs/LangRef.html#shufflevector-instruction 297 // 298 // Here are the masks for vector_width() == 8: 299 // 300 // index: |0 |1 |2 | 3 |4 |5 | 6 | 7 301 // --------+--+--+--+---+--+--+---+--- 302 // mask_a: |0 |2 |8 |10 |4 |6 |12 |14 303 // mask_b: |1 |3 |9 |11 |5 |7 |13 |16 304 // 305 // So, as an example, the value at lane 3 of the result vector is 306 // the result of adding lane 10 and lane 11 in the combined lhs++rhs 307 // vector, which are the lanes 2 and 3 in the rhs vector. 308 for (int i = 0; i < vector_size(); i += 2) { 309 int increment = i < vector_size() / 2 ? 0 : (vector_size() / 2); 310 mask_a.push_back(ir_builder()->getInt32(increment + i)); 311 mask_b.push_back(ir_builder()->getInt32(increment + i + 1)); 312 } 313 for (int i = 0; i < vector_size(); i += 2) { 314 int increment = i < vector_size() / 2 ? (vector_size() / 2) : vector_size(); 315 mask_a.push_back(ir_builder()->getInt32(increment + i)); 316 mask_b.push_back(ir_builder()->getInt32(increment + i + 1)); 317 } 318 319 llvm::Value* shuffle_0 = ir_builder()->CreateShuffleVector( 320 lhs, rhs, llvm::ConstantVector::get(mask_a)); 321 llvm::Value* shuffle_1 = ir_builder()->CreateShuffleVector( 322 lhs, rhs, llvm::ConstantVector::get(mask_b)); 323 324 return Add(shuffle_0, shuffle_1); 325} 326 327llvm::Value* VectorSupportLibrary::ExtractLowHalf(llvm::Value* vector) { 328 llvm::SmallVector<llvm::Constant*, 32> mask; 329 for (int i = 0; i < vector_size() / 2; i++) { 330 mask.push_back(ir_builder()->getInt32(i)); 331 } 332 333 return ir_builder()->CreateShuffleVector(vector, 334 llvm::UndefValue::get(vector_type()), 335 llvm::ConstantVector::get(mask)); 336} 337 338llvm::Value* VectorSupportLibrary::ExtractHighHalf(llvm::Value* vector) { 339 llvm::SmallVector<llvm::Constant*, 32> mask; 340 for (int i = 0; i < vector_size() / 2; i++) { 341 mask.push_back(ir_builder()->getInt32(i + vector_size() / 2)); 342 } 343 344 return ir_builder()->CreateShuffleVector(vector, 345 llvm::UndefValue::get(vector_type()), 346 llvm::ConstantVector::get(mask)); 347} 348 349std::vector<llvm::Value*> VectorSupportLibrary::ComputeHorizontalSums( 350 std::vector<llvm::Value*> vectors, llvm::Value* init_values) { 351 const int x86_avx_vector_elements = 352 TargetMachineFeatures::kX86AvxVectorByteSize / scalar_byte_size(); 353 if (vector_size() == x86_avx_vector_elements && 354 vectors.size() == x86_avx_vector_elements) { 355 return ComputeAvxOptimizedHorizontalSums(std::move(vectors), init_values); 356 } 357 358 std::vector<llvm::Value*> result; 359 std::transform(vectors.begin(), vectors.end(), std::back_inserter(result), 360 [this](llvm::Value* vector) { return AddReduce(vector); }); 361 if (init_values) { 362 for (int64 i = 0, e = result.size(); i < e; i++) { 363 result[i] = Add(result[i], ir_builder()->CreateExtractElement( 364 init_values, ir_builder()->getInt32(i))); 365 } 366 } 367 return result; 368} 369 370std::vector<llvm::Value*> 371VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums( 372 std::vector<llvm::Value*> vectors, llvm::Value* init_values) { 373 while (vectors.size() != 2) { 374 std::vector<llvm::Value*> new_vectors; 375 for (int i = 0; i < vectors.size(); i += 2) { 376 new_vectors.push_back(AvxStyleHorizontalAdd(vectors[i], vectors[i + 1])); 377 } 378 379 vectors = std::move(new_vectors); 380 } 381 382 llvm::Value* low = 383 AddInternal(ExtractLowHalf(vectors[0]), ExtractHighHalf(vectors[0])); 384 if (init_values) { 385 low = AddInternal(ExtractLowHalf(init_values), low); 386 } 387 llvm::Value* high = 388 AddInternal(ExtractLowHalf(vectors[1]), ExtractHighHalf(vectors[1])); 389 if (init_values) { 390 high = AddInternal(ExtractHighHalf(init_values), high); 391 } 392 393 std::vector<llvm::Value*> results; 394 for (int i = 0; i < 8; i++) { 395 llvm::Value* scalar_result = ir_builder()->CreateExtractElement( 396 i < 4 ? low : high, ir_builder()->getInt32(i % 4), name()); 397 results.push_back(scalar_result); 398 } 399 400 return results; 401} 402 403llvm::Value* VectorSupportLibrary::GetZeroVector() { 404 return llvm::Constant::getNullValue(vector_type()); 405} 406 407llvm::Value* VectorSupportLibrary::GetZeroScalar() { 408 return llvm::Constant::getNullValue(scalar_type()); 409} 410 411LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* ir_builder) 412 : ir_builder_(ir_builder) { 413 alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", ir_builder_); 414} 415 416llvm::Value* LlvmVariable::Get() const { 417 return ir_builder_->CreateLoad(alloca_); 418} 419 420void LlvmVariable::Set(llvm::Value* new_value) { 421 ir_builder_->CreateStore(new_value, alloca_); 422} 423} // namespace cpu 424} // namespace xla 425