1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" 17 18#include <algorithm> 19#include <memory> 20#include <string> 21#include <vector> 22 23// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" 24#include "llvm/IR/BasicBlock.h" 25#include "llvm/IR/Instructions.h" 26#include "llvm/IR/Intrinsics.h" 27#include "llvm/Transforms/Utils/BasicBlockUtils.h" 28#include "tensorflow/compiler/xla/primitive_util.h" 29#include "tensorflow/compiler/xla/service/hlo_module.h" 30#include "tensorflow/compiler/xla/service/hlo_opcode.h" 31#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" 32#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" 33#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" 34#include "tensorflow/compiler/xla/shape_util.h" 35#include "tensorflow/compiler/xla/status_macros.h" 36#include "tensorflow/compiler/xla/statusor.h" 37#include "tensorflow/compiler/xla/types.h" 38#include "tensorflow/compiler/xla/util.h" 39#include "tensorflow/compiler/xla/xla_data.pb.h" 40#include "tensorflow/core/lib/random/random.h" 41#include "tensorflow/core/lib/strings/strcat.h" 42#include "tensorflow/core/platform/logging.h" 43#include "tensorflow/core/platform/types.h" 44 45namespace xla { 46 47using llvm_ir::AsStringRef; 48using llvm_ir::IrArray; 49using 
llvm_ir::IrName;
using llvm_ir::SetToFirstInsertPoint;
using tensorflow::strings::StrCat;

namespace {

// Emits IR that rounds and clamps an f32 value `x` as if it had only
// `mantissa_bits` mantissa bits and `exponent_bits` exponent bits, returning
// the result as a full-width f32. Rounding is round-to-nearest with ties to
// even; out-of-range exponents are flushed to signed zero/infinity.
llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
                                      int64 mantissa_bits,
                                      llvm::IRBuilder<>* ir_builder) {
  // Integer and float types for casting and constant generation.
  llvm::Type* float_type = x->getType();
  llvm::IntegerType* int_type = ir_builder->getInt32Ty();

  // Cast the input value to an integer for bitwise manipulation.
  llvm::Value* x_as_int = ir_builder->CreateBitCast(x, int_type);

  if (mantissa_bits < 23) {
    // Last remaining mantissa bit.
    const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);

    // Compute rounding bias for round-to-nearest with ties to even. This is
    // equal to a base value of 0111... plus one bit if the last remaining
    // mantissa bit is 1.
    const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
    llvm::Value* x_last_mantissa_bit = ir_builder->CreateLShr(
        ir_builder->CreateAnd(
            x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
        (23 - mantissa_bits));
    llvm::Value* x_rounding_bias = ir_builder->CreateAdd(
        x_last_mantissa_bit,
        llvm::ConstantInt::get(int_type, base_rounding_bias));

    // Add rounding bias, and mask out truncated bits. Note that the case
    // where adding the rounding bias overflows into the exponent bits is
    // correct; the non-masked mantissa bits will all be zero, and the
    // exponent will be incremented by one.
    const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
    x_as_int = ir_builder->CreateAdd(x_as_int, x_rounding_bias);
    x_as_int = ir_builder->CreateAnd(
        x_as_int, llvm::ConstantInt::get(int_type, truncation_mask));
  }

  if (exponent_bits < 8) {
    // Masks for f32 values.
    const uint32_t f32_sign_bit_mask = 1u << 31;
    const uint32_t f32_exp_bits_mask = 0xffu << 23;

    // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
    // significant bit -- is equal to 1.0f for all exponent sizes. Adding
    // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
    // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest
    // exponent (corresponding to 0.0f).
    //
    // Thus, the f32 exponent corresponding to the highest non-infinite
    // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
    // exponent corresponding to the lowest exponent for a bit size of n is
    // (2^7-1) - 2^(n-1)-1.
    //
    // Note that we have already checked that exponents_bits >= 1.
    const uint32_t f32_exponent_bias = (1 << 7) - 1;
    const uint32_t reduced_exponent_bias = (1 << (exponent_bits - 1)) - 1;
    const uint32_t reduced_max_exponent =
        f32_exponent_bias + reduced_exponent_bias;
    const uint32_t reduced_min_exponent =
        f32_exponent_bias - reduced_exponent_bias;

    // Do we overflow or underflow?
    llvm::Value* x_exponent = ir_builder->CreateAnd(
        x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
    llvm::Value* x_overflows = ir_builder->CreateICmpUGT(
        x_exponent,
        llvm::ConstantInt::get(int_type, reduced_max_exponent << 23));
    llvm::Value* x_underflows = ir_builder->CreateICmpULE(
        x_exponent,
        llvm::ConstantInt::get(int_type, reduced_min_exponent << 23));

    // Compute appropriately-signed values of zero and infinity.
    llvm::Value* x_signed_zero = ir_builder->CreateAnd(
        x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask));
    llvm::Value* x_signed_inf = ir_builder->CreateOr(
        x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));

    // Force to zero or infinity if overflow or underflow. (Note that this
    // truncates all denormal values to zero, rather than rounding them.)
    x_as_int = ir_builder->CreateSelect(x_overflows, x_signed_inf, x_as_int);
    x_as_int = ir_builder->CreateSelect(x_underflows, x_signed_zero, x_as_int);
  }

  // Cast the result back to a floating-point type.
  llvm::Value* result = ir_builder->CreateBitCast(x_as_int, float_type);

  // Correct result for NaN inputs.
  //
  // The exponent handling will "normalize" NaN values to infinities, which is
  // undesirable (except in the case with no mantissa bits, in which case it
  // is mandatory). This logic also handles cases where mantissa-rounding
  // causes a NaN's mantissa to overflow into the exponent bits, which would
  // otherwise create an erroneous zero value.
  //
  // If the fast-math flags are set to assume no NaNs, the comparison is likely
  // to be optimized away, so there's no point in even emitting it.
  if (!ir_builder->getFastMathFlags().noNaNs()) {
    llvm::Value* x_is_nan = ir_builder->CreateFCmpUNO(x, x);

    if (mantissa_bits > 0) {
      result = ir_builder->CreateSelect(x_is_nan, x, result);
    } else {
      result = ir_builder->CreateSelect(
          x_is_nan, llvm::ConstantFP::getInfinity(float_type), result);
    }
  }
  return result;
}

// Converts an f32 value to a bf16 bit pattern (held in an i16): reduces
// precision to bf16's exponent/mantissa widths, then keeps the high 16 bits
// of the f32 representation.
llvm::Value* EmitF32ToBF16(llvm::Value* f32_value,
                           llvm::IRBuilder<>* ir_builder) {
  auto reduced_precision = EmitReducePrecisionFloat(
      f32_value,
      /*exponent_bits=*/primitive_util::kBFloat16ExponentBits,
      /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, ir_builder);
  auto as_int32 =
      ir_builder->CreateBitCast(reduced_precision, ir_builder->getInt32Ty());
  auto shifted = ir_builder->CreateLShr(as_int32, 16);
  auto truncated = ir_builder->CreateTrunc(shifted, ir_builder->getInt16Ty());
  return ir_builder->CreateBitCast(truncated, ir_builder->getInt16Ty());
}

// Converts a bf16 bit pattern (an i16) to f32 by placing it in the high 16
// bits of an i32 and bitcasting to float. This conversion is exact.
llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value,
                           llvm::IRBuilder<>* ir_builder) {
  auto as_int16 =
      ir_builder->CreateBitCast(bf16_value, ir_builder->getInt16Ty());
  auto as_int32 = ir_builder->CreateZExt(as_int16, ir_builder->getInt32Ty());
  auto shifted = ir_builder->CreateShl(as_int32, 16);
  return ir_builder->CreateBitCast(shifted, ir_builder->getFloatTy());
}

// Emits an integer-to-floating-point conversion: signed conversion for signed
// integral types, unsigned conversion for unsigned integral types and PRED.
llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
                                    PrimitiveType from_type,
                                    PrimitiveType to_type, llvm::Module* module,
                                    llvm::IRBuilder<>* ir_builder) {
  if (primitive_util::IsSignedIntegralType(from_type)) {
    return ir_builder->CreateSIToFP(
        integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module));
  } else {
    CHECK(primitive_util::IsUnsignedIntegralType(from_type) ||
          from_type == PRED);
    return ir_builder->CreateUIToFP(
        integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module));
  }
}

}  // namespace

// Dispatches a unary op to the integer/float/complex emitter based on the
// operand's element type. kCopy passes the operand value through unchanged.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) const {
  if (op->opcode() == HloOpcode::kCopy) {
    return operand_value;
  } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
             op->operand(0)->shape().element_type() == PRED) {
    return EmitIntegerUnaryOp(op, operand_value);
  } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) {
    return EmitComplexUnaryOp(op, operand_value);
  } else {
    return EmitFloatUnaryOp(op, operand_value);
  }
}

// Emits a unary op whose operand has an integral (or PRED) element type.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) const {
  switch (op->opcode()) {
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED);
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::IsIntegralType(to_type)) {
        return ir_builder_->CreateIntCast(
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_),
            primitive_util::IsSignedIntegralType(to_type));
      }
      if (primitive_util::IsFloatingPointType(to_type)) {
        // Integer -> BF16 is emitted as integer -> F32 -> BF16.
        if (to_type == BF16) {
          return EmitF32ToBF16(
              EmitIntegralToFloating(operand_value, from_type, F32, module_,
                                     ir_builder_),
              ir_builder_);
        }
        return EmitIntegralToFloating(operand_value, from_type, to_type,
                                      module_, ir_builder_);
      }
      if (primitive_util::IsComplexType(to_type)) {
        auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
            primitive_util::ComplexComponentType(to_type), module_);
        if (primitive_util::IsSignedIntegralType(from_type)) {
          return EmitComposeComplex(
              op,
              ir_builder_->CreateSIToFP(operand_value, to_ir_component_type),
              nullptr);
        }
        if (primitive_util::IsUnsignedIntegralType(from_type) ||
            from_type == PRED) {
          return EmitComposeComplex(
              op,
              ir_builder_->CreateUIToFP(operand_value, to_ir_component_type),
              nullptr);
        }
      }
      return Unimplemented("conversion from primitive type %s to %s",
                           PrimitiveType_Name(from_type).c_str(),
                           PrimitiveType_Name(to_type).c_str());
    }
    case HloOpcode::kBitcastConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsIntegralType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      // A bitcast-convert is only defined between types of equal bit width.
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return ir_builder_->CreateBitCast(
            operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type).c_str(),
          PrimitiveType_Name(to_type).c_str(),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    case HloOpcode::kAbs: {
      bool is_signed =
          primitive_util::IsSignedIntegralType(op->shape().element_type());
      if (is_signed) {
        // abs(x) = x >= 0 ? x : -x.
        auto type =
            llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
        auto zero = llvm::ConstantInt::get(type, 0);
        auto cmp = ir_builder_->CreateICmpSGE(operand_value, zero);
        return ir_builder_->CreateSelect(cmp, operand_value,
                                         ir_builder_->CreateNeg(operand_value));
      } else {
        // Unsigned values are their own absolute value.
        return operand_value;
      }
    }
    case HloOpcode::kSign: {
      bool is_signed =
          primitive_util::IsSignedIntegralType(op->shape().element_type());
      auto type =
          llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
      auto zero = llvm::ConstantInt::get(type, 0);
      auto cmp = ir_builder_->CreateICmpEQ(operand_value, zero);
      if (is_signed) {
        // For nonzero x, (x >>arith (bitwidth-1)) | 1 yields -1 or +1.
        auto ashr = ir_builder_->CreateAShr(operand_value,
                                            type->getIntegerBitWidth() - 1);
        return ir_builder_->CreateSelect(cmp, zero,
                                         ir_builder_->CreateOr(ashr, 1));
      } else {
        // Unsigned sign is 0 for zero and 1 otherwise.
        return ir_builder_->CreateSelect(cmp, zero,
                                         llvm::ConstantInt::get(type, 1));
      }
    }
    case HloOpcode::kNegate:
      return ir_builder_->CreateNeg(operand_value);
    case HloOpcode::kNot: {
      auto type = op->shape().element_type();
      if (type == PRED) {
        // It is not sufficient to just call CreateNot() here because a PRED
        // is represented as an i8 and the truth value is stored only in the
        // bottom bit.
        return ir_builder_->CreateZExt(
            ir_builder_->CreateNot(ir_builder_->CreateTrunc(
                operand_value, ir_builder_->getInt1Ty())),
            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
      } else if (primitive_util::IsIntegralType(type)) {
        return ir_builder_->CreateNot(operand_value);
      }
      return Unimplemented("unary op Not is not defined for type '%d'", type);
    }
    default:
      return Unimplemented("unary integer op '%s'",
                           HloOpcodeString(op->opcode()).c_str());
  }
}

// Emits a unary op whose operand has a floating-point element type.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) const {
  switch (op->opcode()) {
    case HloOpcode::kConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsFloatingPointType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      if (primitive_util::IsComplexType(to_type)) {
        PrimitiveType to_component_type =
            primitive_util::ComplexComponentType(to_type);
        if (from_type == to_component_type) {
          return EmitComposeComplex(op, operand_value, nullptr);
        }
        return EmitComposeComplex(
            op,
            ir_builder_->CreateFPCast(
                operand_value,
                llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
            nullptr);
      }
      // Conversions out of BF16 go via F32.
      if (from_type == BF16) {
        TF_RET_CHECK(to_type != BF16);
        operand_value = EmitBF16ToF32(operand_value, ir_builder_);
        from_type = F32;
        if (from_type == to_type) {
          return operand_value;
        }
      }
      if (from_type == F32 && to_type == BF16) {
        return EmitF32ToBF16(operand_value, ir_builder_);
      }
      if (primitive_util::IsFloatingPointType(to_type)) {
        return ir_builder_->CreateFPCast(
            operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      if (primitive_util::IsSignedIntegralType(to_type)) {
        return ir_builder_->CreateFPToSI(
            operand_value,
llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      if (primitive_util::IsUnsignedIntegralType(to_type)) {
        return ir_builder_->CreateFPToUI(
            operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return Unimplemented("unhandled conversion operation: %s => %s",
                           PrimitiveType_Name(from_type).c_str(),
                           PrimitiveType_Name(to_type).c_str());
    }
    case HloOpcode::kBitcastConvert: {
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      PrimitiveType to_type = op->shape().element_type();
      CHECK(primitive_util::IsFloatingPointType(from_type));
      if (from_type == to_type) {
        return operand_value;
      }
      // A bitcast-convert is only defined between types of equal bit width.
      if (primitive_util::BitWidth(from_type) ==
          primitive_util::BitWidth(to_type)) {
        return ir_builder_->CreateBitCast(
            operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
      }
      return InvalidArgument(
          "bitcast conversion from primitive type %s to %s with unequal "
          "bit-widths (%u versus %u) ",
          PrimitiveType_Name(from_type).c_str(),
          PrimitiveType_Name(to_type).c_str(),
          primitive_util::BitWidth(from_type),
          primitive_util::BitWidth(to_type));
    }
    case HloOpcode::kExp:
      return EmitExp(op->shape().element_type(), operand_value);
    case HloOpcode::kLog:
      return EmitLog(op->shape().element_type(), operand_value);
    case HloOpcode::kCos:
      return EmitCos(op->shape().element_type(), operand_value);
    case HloOpcode::kSin:
      return EmitSin(op->shape().element_type(), operand_value);
    case HloOpcode::kFloor:
      return llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::floor, {operand_value}, {operand_value->getType()},
          ir_builder_);
    case HloOpcode::kCeil:
      return llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::ceil, {operand_value}, {operand_value->getType()},
          ir_builder_);
    case HloOpcode::kAbs:
      return llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::fabs, {operand_value}, {operand_value->getType()},
          ir_builder_);
    case HloOpcode::kRoundNearestAfz:
      return llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::round, {operand_value}, {operand_value->getType()},
          ir_builder_);
    case HloOpcode::kSign: {
      // sign(x) = 0 for x == 0, -1 for x < 0, +1 otherwise.
      // TODO(b/32151903): Ensure consistent sign behavior for -0.0.
      auto type = operand_value->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto oeq = ir_builder_->CreateFCmpOEQ(operand_value, zero);
      auto olt = ir_builder_->CreateFCmpOLT(operand_value, zero);
      return ir_builder_->CreateSelect(
          oeq, zero,
          ir_builder_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0),
                                    llvm::ConstantFP::get(type, 1.0)));
    }
    case HloOpcode::kIsFinite: {
      // (x == x) && abs(x) != inf
      auto type = operand_value->getType();
      auto equal_self =
          ir_builder_->CreateFCmpOEQ(operand_value, operand_value);
      auto abs_value = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::fabs, {operand_value}, {type}, ir_builder_);
      auto infinity = llvm::ConstantFP::getInfinity(type);
      auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity);
      auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite);
      // The result is a PRED (i8), so widen the i1 comparison result.
      return ir_builder_->CreateZExt(
          result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, module_));
    }
    case HloOpcode::kNegate:
      return ir_builder_->CreateFNeg(operand_value);
    default:
      return Unimplemented("unary floating-point op '%s'",
                           HloOpcodeString(op->opcode()).c_str());
  }
}

// Emits a unary op whose operand has a complex element type. Note that
// component_type falls back to input_type when the operand is not complex.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
    const HloInstruction* op, llvm::Value* operand_value) const {
  PrimitiveType input_type = op->operand(0)->shape().element_type();
  PrimitiveType component_type =
      primitive_util::IsComplexType(input_type)
          ?
primitive_util::ComplexComponentType(input_type)
          : input_type;
  switch (op->opcode()) {
    case HloOpcode::kLog: {
      // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      llvm::Type* llvm_ty = a->getType();
      auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a),
                                            ir_builder_->CreateFMul(b, b));
      TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
      TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a));
      auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
      return EmitComposeComplex(
          op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle);
    }
    case HloOpcode::kConvert: {
      // Complex -> complex conversion: FP-cast each component.
      PrimitiveType from_type = op->operand(0)->shape().element_type();
      TF_RET_CHECK(primitive_util::IsComplexType(from_type));
      PrimitiveType to_type = op->shape().element_type();
      TF_RET_CHECK(primitive_util::IsComplexType(to_type));
      if (from_type == to_type) {
        return operand_value;
      }
      PrimitiveType to_component_type =
          primitive_util::ComplexComponentType(to_type);
      auto to_ir_component_type =
          llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
      return EmitComposeComplex(
          op,
          ir_builder_->CreateFPCast(EmitExtractReal(operand_value),
                                    to_ir_component_type),
          ir_builder_->CreateFPCast(EmitExtractImag(operand_value),
                                    to_ir_component_type));
    }
    case HloOpcode::kExp: {
      // e^(a+bi) = e^a*(cos(b)+sin(b)i)
      TF_ASSIGN_OR_RETURN(
          auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
      TF_ASSIGN_OR_RETURN(
          auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
      return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b),
                                ir_builder_->CreateFMul(exp_a, sin_b));
    }
    case HloOpcode::kCos: {
      // cos(z) = .5(e^(iz) + e^(-iz))
      // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai))
      // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
      // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(-a)+sin(-a)i))
      // cos(-x) = cos(x) and sin(-x) = -sin(x), so
      // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i))
      //           = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      auto type = a->getType();
      TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
      auto half_exp_b =
          ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
      auto half_exp_neg_b =
          ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
      TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
      TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
      return EmitComposeComplex(
          op,
          ir_builder_->CreateFMul(
              cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)),
          ir_builder_->CreateFMul(
              sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b)));
    }
    case HloOpcode::kSin: {
      // sin(z) = .5i(e^(-iz) - e^(iz))
      // sin(a+bi) = .5i(e^(-i(a+bi)) - e^(i(a+bi)))
      //           = .5i(e^(b-ai) - e^(-b+ai))
      // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
      // sin(a+bi) = 0.5i(e^b*(cos(-a)+sin(-a)i) - e^-b*(cos(a)+sin(a)i))
      //           = 0.5(e^b*(cos(-a)i-sin(-a)) - e^-b*(cos(a)i-sin(a)))
      // cos(-x) = cos(x) and sin(-x) = -sin(x), so
      //           = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a)))
      //           = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b)
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      auto type = a->getType();
      TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
      auto half_exp_b =
          ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
      auto half_exp_neg_b =
          ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
      TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
      TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
      return EmitComposeComplex(
          op,
          ir_builder_->CreateFMul(
              sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)),
          ir_builder_->CreateFMul(
              cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b)));
    }
    case HloOpcode::kTanh: {
      /*
      tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x))
      e^(a+bi) = e^a*(cos(b)+sin(b)i)
      so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) /
              (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a))
      cos(b)=cos(-b), sin(-b)=-sin(b)
      so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) /
              (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a))
             =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) /
              (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a))
             =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) /
              (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a))
      This is a complex division, so we can multiply by denom_conj/denom_conj
             =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) *
              (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) /
              ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
             =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) +
               i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) /
              ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
      */
      auto a = EmitExtractReal(operand_value);
      auto b = EmitExtractImag(operand_value);
      TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a));
      TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b));
      TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b));
      auto exp_neg_a = ir_builder_->CreateFDiv(
          llvm::ConstantFP::get(exp_a->getType(), 1), exp_a);
      auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub(
          ir_builder_->CreateFMul(exp_a, exp_a),
          ir_builder_->CreateFMul(exp_neg_a, exp_neg_a));
      auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b);
      auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b);
      auto real_num = ir_builder_->CreateFAdd(
          ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a),
          ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a));
      auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b);
      auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a);
      auto exp_a_plus_exp_neg_a_sq =
          ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a);
      auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a);
      auto exp_a_minus_exp_neg_a_sq =
          ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a);
      auto imag_num = ir_builder_->CreateFMul(
          cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq,
                                               exp_a_minus_exp_neg_a_sq));
      auto denom = ir_builder_->CreateFAdd(
          ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq),
          ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq));
      return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom),
                                ir_builder_->CreateFDiv(imag_num, denom));
    }
    case HloOpcode::kAbs: {
      // |a+bi| = sqrt(a^2 + b^2)
      auto sum_sq = ir_builder_->CreateFAdd(
          ir_builder_->CreateFMul(EmitExtractReal(operand_value),
                                  EmitExtractReal(operand_value)),
          ir_builder_->CreateFMul(EmitExtractImag(operand_value),
                                  EmitExtractImag(operand_value)));
      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq},
                                          {sum_sq->getType()}, ir_builder_);
    }
    case HloOpcode::kSign: {  // Sign(c) = c / |c|
      auto sum_sq = ir_builder_->CreateFAdd(
          ir_builder_->CreateFMul(EmitExtractReal(operand_value),
                                  EmitExtractReal(operand_value)),
          ir_builder_->CreateFMul(EmitExtractImag(operand_value),
                                  EmitExtractImag(operand_value)));
      auto cplx_abs = llvm_ir::EmitCallToIntrinsic(
          llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_);
      auto type = cplx_abs->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto oeq = ir_builder_->CreateFCmpOEQ(cplx_abs, zero);
      // Zero magnitude maps to (0, 0) rather than dividing by zero.
      return ir_builder_->CreateSelect(
          oeq, EmitComposeComplex(op, zero, zero),
          EmitComposeComplex(
              op,
              ir_builder_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs),
              ir_builder_->CreateFDiv(EmitExtractImag(operand_value),
                                      cplx_abs)));
    }
    case HloOpcode::kNegate:
      return EmitComposeComplex(
          op, ir_builder_->CreateFNeg(EmitExtractReal(operand_value)),
          ir_builder_->CreateFNeg(EmitExtractImag(operand_value)));
    case HloOpcode::kReal:
      return EmitExtractReal(operand_value);
    case HloOpcode::kImag:
      return EmitExtractImag(operand_value);
    default:
      return Unimplemented("unary complex op '%s'",
                           HloOpcodeString(op->opcode()).c_str());
  }
}

// Dispatches a binary op to the integer/float/complex emitter based on the
// first operand's element type.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value,
    llvm::Value* rhs_value) const {
  PrimitiveType operand_type = op->operand(0)->shape().element_type();
  if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
      operand_type == PRED) {
    return EmitIntegerBinaryOp(
        op, lhs_value, rhs_value,
        primitive_util::IsSignedIntegralType(operand_type));
  } else if (primitive_util::IsComplexType(operand_type)) {
    return EmitComplexBinaryOp(op, lhs_value, rhs_value);
  } else {
    return EmitFloatBinaryOp(op, lhs_value, rhs_value);
  }
}

// Emits a binary op whose operands have a floating-point element type.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value,
    llvm::Value* rhs_value) const {
  switch (op->opcode()) {
    case HloOpcode::kComplex:
      return EmitComposeComplex(op, lhs_value, rhs_value);
    case HloOpcode::kAdd:
      return ir_builder_->CreateFAdd(lhs_value, rhs_value);
    case HloOpcode::kSubtract:
      return ir_builder_->CreateFSub(lhs_value, rhs_value);
    case HloOpcode::kMultiply:
      return ir_builder_->CreateFMul(lhs_value, rhs_value);
    case HloOpcode::kDivide:
      return ir_builder_->CreateFDiv(lhs_value, rhs_value);
    case
HloOpcode::kRemainder:
      return ir_builder_->CreateFRem(lhs_value, rhs_value);
    // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
    // comparisons always return false when one of the operands is NaN, whereas
    // unordered comparisons return true.
    //
    // We use ordered comparisons for everything except kNe, where we use an
    // unordered comparison. This makes x != y equivalent to !(x == y), and
    // matches C++'s semantics.
    case HloOpcode::kEq:
      return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
                                     rhs_value, ir_builder_);
    case HloOpcode::kNe:
      return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
                                     rhs_value, ir_builder_);
    case HloOpcode::kLt:
      return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
                                     rhs_value, ir_builder_);
    case HloOpcode::kGt:
      return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
                                     rhs_value, ir_builder_);
    case HloOpcode::kLe:
      return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
                                     rhs_value, ir_builder_);
    case HloOpcode::kGe:
      return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
                                     rhs_value, ir_builder_);

    case HloOpcode::kMaximum:
      return EmitFloatMax(lhs_value, rhs_value);
    case HloOpcode::kMinimum:
      return EmitFloatMin(lhs_value, rhs_value);
    case HloOpcode::kPower:
      return EmitPow(op->shape().element_type(), lhs_value, rhs_value);
    case HloOpcode::kAtan2:
      return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value);
    default:
      return Unimplemented("binary floating point op '%s'",
                           HloOpcodeString(op->opcode()).c_str());
  }
}

// Emits a binary op whose operands have a complex element type.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value,
    llvm::Value* rhs_value) const {
  switch (op->opcode()) {
    case HloOpcode::kAdd:
      // (a+bi) + (c+di) = (a+c) + (b+d)i
      return EmitComposeComplex(
          op,
          ir_builder_->CreateFAdd(EmitExtractReal(lhs_value),
                                  EmitExtractReal(rhs_value)),
          ir_builder_->CreateFAdd(EmitExtractImag(lhs_value),
                                  EmitExtractImag(rhs_value)));
    case HloOpcode::kSubtract:
      // (a+bi) - (c+di) = (a-c) + (b-d)i
      return EmitComposeComplex(
          op,
          ir_builder_->CreateFSub(EmitExtractReal(lhs_value),
                                  EmitExtractReal(rhs_value)),
          ir_builder_->CreateFSub(EmitExtractImag(lhs_value),
                                  EmitExtractImag(rhs_value)));
    case HloOpcode::kMultiply:
      // (a+bi) * (c+di) = (ac-bd) + (ad+bc)i
      return EmitComposeComplex(
          op,
          ir_builder_->CreateFSub(
              ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
                                      EmitExtractReal(rhs_value)),
              ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
                                      EmitExtractImag(rhs_value))),
          ir_builder_->CreateFAdd(
              ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
                                      EmitExtractImag(rhs_value)),
              ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
                                      EmitExtractReal(rhs_value))));
    case HloOpcode::kDivide: {
      // (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di))
      //                 = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
      auto rhs_sum_sq = ir_builder_->CreateFAdd(
          ir_builder_->CreateFMul(EmitExtractReal(rhs_value),
                                  EmitExtractReal(rhs_value)),
          ir_builder_->CreateFMul(EmitExtractImag(rhs_value),
                                  EmitExtractImag(rhs_value)));
      auto type = rhs_sum_sq->getType();
      auto zero = llvm::ConstantFP::get(type, 0.0);
      auto oeq = ir_builder_->CreateFCmpOEQ(rhs_sum_sq, zero);
      // For a zero-magnitude denominator, divide the numerator components by
      // zero so ordinary float semantics produce +/-inf or NaN.
      auto real_inf_or_nan =
          ir_builder_->CreateFDiv(EmitExtractReal(lhs_value), zero);
      auto imag_inf_or_nan =
          ir_builder_->CreateFDiv(EmitExtractImag(lhs_value), zero);
      return ir_builder_->CreateSelect(
          oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan),
          EmitComposeComplex(
              op,
              ir_builder_->CreateFDiv(
                  ir_builder_->CreateFAdd(
                      ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
                                              EmitExtractReal(rhs_value)),
                      ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
                                              EmitExtractImag(rhs_value))),
                  rhs_sum_sq),
              ir_builder_->CreateFDiv(
                  ir_builder_->CreateFSub(
                      ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
                                              EmitExtractReal(rhs_value)),
                      ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
                                              EmitExtractImag(rhs_value))),
                  rhs_sum_sq)));
    }
    // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
    // comparisons always return false when one of the operands is NaN, whereas
    // unordered comparisons return true.
    //
    // We use ordered comparisons for everything except kNe, where we use an
    // unordered comparison. This makes x != y equivalent to !(x == y), and
    // matches C++'s semantics.
    case HloOpcode::kEq:
      // Complex equality requires both components equal.
      return ir_builder_->CreateAnd(
          llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
                                  EmitExtractReal(lhs_value),
                                  EmitExtractReal(rhs_value), ir_builder_),
          llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
                                  EmitExtractImag(lhs_value),
                                  EmitExtractImag(rhs_value), ir_builder_));
    case HloOpcode::kNe:
      // Complex inequality: either component differs.
      return ir_builder_->CreateOr(
          llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
                                  EmitExtractReal(lhs_value),
                                  EmitExtractReal(rhs_value), ir_builder_),
          llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
                                  EmitExtractImag(lhs_value),
                                  EmitExtractImag(rhs_value), ir_builder_));

    case HloOpcode::kPower: {
      // (a+bi)^(c+di) =
      //    (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
      //    where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
      PrimitiveType component_type =
          primitive_util::ComplexComponentType(op->shape().element_type());
      auto a = EmitExtractReal(lhs_value);
      auto b = EmitExtractImag(lhs_value);
      auto c = EmitExtractReal(rhs_value);
      auto d = EmitExtractImag(rhs_value);
      auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a),
                                             ir_builder_->CreateFMul(b, b));
      auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
      auto half_c = ir_builder_->CreateFMul(one_half, c);

      TF_ASSIGN_OR_RETURN(auto
aa_p_bb_to_half_c, 836 EmitPow(component_type, aa_p_bb, half_c)); 837 auto neg_d = ir_builder_->CreateFNeg(d); 838 TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a)); 839 auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs); 840 TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs, 841 EmitExp(component_type, neg_d_arg_lhs)); 842 auto coeff = 843 ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); 844 TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb)); 845 auto half_d = ir_builder_->CreateFMul(one_half, d); 846 auto q = 847 ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs), 848 ir_builder_->CreateFMul(half_d, ln_aa_p_bb)); 849 TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); 850 TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); 851 return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q), 852 ir_builder_->CreateFMul(coeff, sin_q)); 853 } 854 default: 855 return Unimplemented("binary complex op '%s'", 856 HloOpcodeString(op->opcode()).c_str()); 857 } 858} 859 860llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value, 861 llvm::Value* rhs_value) const { 862 return llvm_ir::EmitFloatMax(lhs_value, rhs_value, ir_builder_); 863} 864 865llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, 866 llvm::Value* rhs_value) const { 867 return llvm_ir::EmitFloatMin(lhs_value, rhs_value, ir_builder_); 868} 869 870StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, 871 llvm::Value* x) const { 872 if (prim_type != F32) { 873 // TODO(b/34339814): Implement inverse erf for F64. 
    return Unimplemented(
        "Inverse erf is only implemented for element "
        "type F32.");
  }
  // Helper producing an f32 constant of the given value.
  auto getFloat = [&](const float f) {
    return llvm::ConstantFP::get(ir_builder_->getFloatTy(), f);
  };
  // Horner-scheme polynomial evaluation: p = ((c0*w + c1)*w + c2)*w + ...
  auto multiply_add = [&](tensorflow::gtl::ArraySlice<float> coefficients,
                          llvm::Value* w) {
    llvm::Value* p = getFloat(coefficients.front());
    coefficients.pop_front();
    for (float coefficient : coefficients) {
      p = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(p, w),
                                  getFloat(coefficient));
    }
    return p;
  };

  // Approximation for inverse error function from
  //   Giles, M., "Approximating the erfinv function".
  // The approximation has the form:
  //   w = log((1-x)*(1+x))
  //   if ( w < 5 ) {
  //     w = w - 2.5
  //     p = sum_{i=1}^n lq[i]*w^i
  //   } else {
  //     w = sqrt(w) - 3
  //     p = sum_{i=1}^n gq[i]*w^i
  //   }
  //   return p*x
  llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration(
      module_, llvm::Intrinsic::log, {ir_builder_->getFloatTy()});

  // w = -log((1-x)*(1+x))
  llvm::Value* w = ir_builder_->CreateFNeg(ir_builder_->CreateCall(
      logf_fn,
      {ir_builder_->CreateFMul(ir_builder_->CreateFSub(getFloat(1.0f), x),
                               ir_builder_->CreateFAdd(getFloat(1.0f), x))}));

  // Both branches store their polynomial result into this stack slot; it is
  // loaded back after the if/else merges.
  llvm::Value* p_addr = llvm_ir::EmitAllocaAtFunctionEntry(
      ir_builder_->getFloatTy(), "p.addr", ir_builder_);

  llvm_ir::LlvmIfData if_data =
      llvm_ir::EmitIfThenElse(ir_builder_->CreateFCmpOLT(w, getFloat(5.0f)),
                              "w_less_than_five", ir_builder_);
  // Handle true BB (w < 5): central-region polynomial.
  SetToFirstInsertPoint(if_data.true_block, ir_builder_);
  {
    llvm::Value* lw = ir_builder_->CreateFSub(w, getFloat(2.5f));
    tensorflow::gtl::ArraySlice<float> lq{
        2.81022636e-08f,  3.43273939e-07f, -3.5233877e-06f,
        -4.39150654e-06f, 0.00021858087f,  -0.00125372503f,
        -0.00417768164f,  0.246640727f,    1.50140941f};
    llvm::Value* p = multiply_add(lq, lw);
    ir_builder_->CreateStore(p, p_addr);
  }

  // Handle false BB (w >= 5): tail-region polynomial on sqrt(w) - 3.
  SetToFirstInsertPoint(if_data.false_block, ir_builder_);
  {
    llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
        module_, llvm::Intrinsic::sqrt, {ir_builder_->getFloatTy()});

    llvm::Value* gw = ir_builder_->CreateFSub(
        ir_builder_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f));
    tensorflow::gtl::ArraySlice<float> gq{
        -0.000200214257f, 0.000100950558f, 0.00134934322f,
        -0.00367342844f,  0.00573950773f,  -0.0076224613f,
        0.00943887047f,   1.00167406f,     2.83297682f};
    llvm::Value* p = multiply_add(gq, gw);
    ir_builder_->CreateStore(p, p_addr);
  }

  SetToFirstInsertPoint(if_data.after_block, ir_builder_);
  llvm::Value* p = ir_builder_->CreateLoad(p_addr);
  return ir_builder_->CreateFMul(p, x);
}

// erfcinv(value) == erfinv(1 - value); shares EmitErfInv's F32-only
// restriction.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(
    PrimitiveType prim_type, llvm::Value* value) const {
  // Compute erfcinv(value) by calculating erfinv(1.0 - value).
  auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
  auto one = llvm::ConstantFP::get(type, 1.0);
  return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value));
}

// The transcendental helpers below lower directly to the corresponding LLVM
// intrinsics, specialized on the operand's IR type.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
                                                   llvm::Value* value) const {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value},
                                      {value->getType()}, ir_builder_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type,
                                                   llvm::Value* value) const {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value},
                                      {value->getType()}, ir_builder_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
                                                   llvm::Value* value) const {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value},
                                      {value->getType()}, ir_builder_);
}

StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
                                                   llvm::Value* value) const {
  return
      llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
                                   {value->getType()}, ir_builder_);
}

// Lowers pow to the llvm.pow intrinsic on the lhs operand's IR type.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
                                                   llvm::Value* lhs,
                                                   llvm::Value* rhs) const {
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
                                      {lhs->getType()}, ir_builder_);
}

// atan2 has no LLVM intrinsic; the base emitter reports it unimplemented
// (backend subclasses may override).
StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
                                                     llvm::Value* lhs,
                                                     llvm::Value* rhs) const {
  return Unimplemented("atan2");
}

// Emits the reduce-precision op by delegating to the file-local
// EmitReducePrecisionFloat helper; only F32 inputs are supported.
StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
    const HloInstruction* hlo, llvm::Value* x) const {
  if (hlo->operand(0)->shape().element_type() != F32) {
    return Unimplemented("reduce-precision only implemented for F32");
  }
  return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(),
                                  /*mantissa_bits=*/hlo->mantissa_bits(),
                                  ir_builder_);
}

// Emits IR for an elementwise binary op on integer operands. `is_signed`
// selects between the signed and unsigned LLVM instruction/predicate
// variants where they differ (division, remainder, ordering comparisons,
// min/max).
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
    bool is_signed) const {
  switch (op->opcode()) {
    // TODO(jingyue): add the "nsw" attribute for signed types.
    case HloOpcode::kAdd:
      return ir_builder_->CreateAdd(lhs_value, rhs_value);
    case HloOpcode::kSubtract:
      return ir_builder_->CreateSub(lhs_value, rhs_value);
    case HloOpcode::kMultiply:
      return ir_builder_->CreateMul(lhs_value, rhs_value);
    case HloOpcode::kDivide:
      return is_signed ? ir_builder_->CreateSDiv(lhs_value, rhs_value)
                       : ir_builder_->CreateUDiv(lhs_value, rhs_value);
    case HloOpcode::kRemainder:
      return is_signed ? ir_builder_->CreateSRem(lhs_value, rhs_value)
                       : ir_builder_->CreateURem(lhs_value, rhs_value);
    case HloOpcode::kEq:
      return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
                                     rhs_value, ir_builder_);
    case HloOpcode::kNe:
      return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
                                     rhs_value, ir_builder_);
    case HloOpcode::kLt:
      return llvm_ir::EmitComparison(
          is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
          lhs_value, rhs_value, ir_builder_);
    case HloOpcode::kGt:
      return llvm_ir::EmitComparison(
          is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
          lhs_value, rhs_value, ir_builder_);
    case HloOpcode::kLe:
      return llvm_ir::EmitComparison(
          is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
          lhs_value, rhs_value, ir_builder_);
    case HloOpcode::kGe:
      return llvm_ir::EmitComparison(
          is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
          lhs_value, rhs_value, ir_builder_);
    case HloOpcode::kMinimum:
      return EmitIntegralMin(lhs_value, rhs_value, is_signed);
    case HloOpcode::kMaximum:
      return EmitIntegralMax(lhs_value, rhs_value, is_signed);
    case HloOpcode::kAnd:
      return ir_builder_->CreateAnd(lhs_value, rhs_value);
    case HloOpcode::kOr:
      return ir_builder_->CreateOr(lhs_value, rhs_value);
    case HloOpcode::kShiftLeft:
      return ir_builder_->CreateShl(lhs_value, rhs_value);
    case HloOpcode::kShiftRightArithmetic:
      return ir_builder_->CreateAShr(lhs_value, rhs_value);
    case HloOpcode::kShiftRightLogical:
      return ir_builder_->CreateLShr(lhs_value, rhs_value);
    default:
      return Unimplemented("binary integer op '%s'",
                           HloOpcodeString(op->opcode()).c_str());
  }
}

// max(lhs, rhs) via compare+select, using the signed or unsigned predicate
// as requested.
llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
                                                 llvm::Value* rhs_value,
                                                 bool is_signed) const {
  return ir_builder_->CreateSelect(
      ir_builder_->CreateICmp(
          is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
          lhs_value, rhs_value),
      lhs_value, rhs_value);
}

// min(lhs, rhs) via compare+select, using the signed or unsigned predicate
// as requested.
llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
                                                 llvm::Value* rhs_value,
                                                 bool is_signed) const {
  return ir_builder_->CreateSelect(
      ir_builder_->CreateICmp(
          is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
          lhs_value, rhs_value),
      lhs_value, rhs_value);
}

// Maps an output-element index of an elementwise HLO to the corresponding
// index into operand `operand_no`, accounting for scalar operands and
// implicit (degenerate-dimension) broadcasts.
llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
    const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
    int64 operand_no) const {
  CHECK(hlo.IsElementwise())
      << "HLO " << hlo.ToString() << " is not elementwise.";

  const Shape& operand_shape = hlo.operand(operand_no)->shape();
  // If the operand is scalar, the source index is always {}.
  if (ShapeUtil::IsScalar(operand_shape)) {
    return llvm_ir::IrArray::Index();
  }

  // If no implicit broadcast is needed for this operand, returns the target
  // index as the source index.
  if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape())) {
    return target_index;
  }

  // If implicit broadcast is needed, the source dimensions that are broadcast
  // have index 0.
  CHECK_EQ(ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(hlo.shape()));
  llvm_ir::IrArray::Index source_index;
  for (int64 i = 0; i < ShapeUtil::Rank(hlo.shape()); ++i) {
    if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) {
      source_index.push_back(target_index[i]);
    } else {
      // Broadcast dimension: the operand must be degenerate (size 1) here.
      CHECK_EQ(1, operand_shape.dimensions(i));
      source_index.push_back(ir_builder_->getInt64(0));
    }
  }
  return source_index;
}

// Builds an element generator for a kRng HLO. The generated IR implements a
// PCG-based pseudo-random number generator whose 128-bit state is kept in
// module-level globals (set up below).
llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
    const {
  PrimitiveType param_prim_type = hlo->operand(0)->shape().element_type();
  llvm::Type* param_ir_type =
      llvm_ir::PrimitiveTypeToIrType(param_prim_type, module_);

  // Same values as PCG library
  // https://github.com/imneme/pcg-c/blob/master/include/pcg_variants.h
  llvm::Value* multiplier = ir_builder_->getInt(
      llvm::APInt(128, {0x4385DF649FCCF645, 0x2360ED051FC65DA4}));
  llvm::Value* increment = ir_builder_->getInt(
      llvm::APInt(128, {0x14057B7EF767814F, 0x5851F42D4C957F2D}));

  // Draws a fresh 64-bit seed from the (possibly fused) HLO's module.
  auto random_value = [hlo]() {
    const HloModule* module =
        hlo->IsFused() ? hlo->parent()->FusionInstruction()->parent()->parent()
                       : hlo->parent()->parent();
    return module->RandomNew64();
  };

  // Seed each RNG emitter with a new 64-bit seed from the HloModule. If the
  // compilation order is deterministic (i.e., RandomNew64 invocation order is
  // deterministic), then the order of RNG is deterministic for a given seed and
  // hence tests will be deterministic.
  // If the user provides a global seed instruction then we only use 64-bits of
  // the host's random number generator to seed the 128 bit value with the other
  // 64-bits is due to a user specified global seed instruction.
  // Create a GlobalVariable to maintain state between invocations. There is a
  // bug in NVPTX with GlobalVariable and 128 bit values, so using 2 64-bit
  // values.
  llvm::GlobalVariable* state_ptr0 = new llvm::GlobalVariable(
      /*M=*/*module_,
      /*Ty=*/ir_builder_->getInt64Ty(),
      /*isConstant=*/false,
      /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
      /*Initializer=*/ir_builder_->getInt64(random_value()),
      /*Name=*/"state_ptr0");
  // The user-provided seed (if nonzero) overrides the host RNG for the upper
  // half of the 128-bit state.
  uint64 graph_seed = hlo_module_config_.seed() != 0 ? hlo_module_config_.seed()
                                                     : random_value();
  llvm::GlobalVariable* state_ptr1 = new llvm::GlobalVariable(
      /*M=*/*module_,
      /*Ty=*/ir_builder_->getInt64Ty(),
      /*isConstant=*/false,
      /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
      /*Initializer=*/ir_builder_->getInt64(graph_seed),
      /*Name=*/"state_ptr1");

  // We want each thread to use its own stream, so we modify the increment per
  // thread. We want the increment to remain odd, so we shift the thread id left
  // 1 and add it to the increment.
  increment = ir_builder_->CreateAdd(increment,
                                     ir_builder_->CreateShl(EmitThreadId(), 1));

  // PCG-XSL-RR algorithm
  // http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf
  //   state = multiplier * state + increment
  //   return uint64_t(state ^ (state >> 64))) >>> (state >> 122)
  // where ">>>" is bitwise rotation
  auto get_next_i64 = [=]() {
    // Reassemble the 128-bit state from the two 64-bit globals.
    llvm::Value* state0 = ir_builder_->CreateZExtOrTrunc(
        ir_builder_->CreateLoad(state_ptr0, "state0"),
        ir_builder_->getInt128Ty());
    llvm::Value* state1 = ir_builder_->CreateShl(
        ir_builder_->CreateZExtOrTrunc(
            ir_builder_->CreateLoad(state_ptr1, "state1"),
            ir_builder_->getInt128Ty()),
        64);
    llvm::Value* state = ir_builder_->CreateOr(state0, state1);
    // Advance the LCG state and write the two halves back.
    llvm::Value* updated = ir_builder_->CreateAdd(
        ir_builder_->CreateMul(state, multiplier), increment);
    ir_builder_->CreateStore(
        ir_builder_->CreateTrunc(updated, ir_builder_->getInt64Ty()),
        state_ptr0);
    ir_builder_->CreateStore(
        ir_builder_->CreateTrunc(ir_builder_->CreateLShr(updated, 64),
                                 ir_builder_->getInt64Ty()),
        state_ptr1);

    // Output function: xor-shift-low then random rotation (XSL-RR).
    return llvm_ir::CreateRor(
        ir_builder_->CreateTrunc(
            ir_builder_->CreateXor(state, ir_builder_->CreateLShr(state, 64)),
            ir_builder_->getInt64Ty()),
        ir_builder_->CreateTrunc(ir_builder_->CreateLShr(state, 122),
                                 ir_builder_->getInt64Ty()),
        ir_builder_);
  };

  // Maps a fresh 64-bit integer into [0, 1) by dividing by 2^64.
  auto get_next_uniform_float = [=]() {
    return ir_builder_->CreateFDiv(
        ir_builder_->CreateUIToFP(get_next_i64(), param_ir_type),
        llvm::ConstantFP::get(param_ir_type, 0x1p64));
  };

  return [=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
    switch (hlo->random_distribution()) {
      case RNG_UNIFORM: {
        // Operands are the distribution bounds: p = low, q = high.
        TF_ASSIGN_OR_RETURN(llvm::Value * p,
                            operand_to_generator.at(hlo->operand(0))(index));
        TF_ASSIGN_OR_RETURN(llvm::Value * q,
                            operand_to_generator.at(hlo->operand(1))(index));
        if (primitive_util::IsFloatingPointType(param_prim_type)) {
          // p + (q - p) * u, with u uniform in [0, 1).
          return ir_builder_->CreateFAdd(
              ir_builder_->CreateFMul(ir_builder_->CreateFSub(q, p),
                                      get_next_uniform_float()),
              p);
        } else {
          // Integer case: rejection-sample a value in [0, q - p) by masking
          // random bits down to the bit-width of the range and looping until
          // the sample is in range.
          auto r = ir_builder_->CreateSub(q, p);
          auto leading_zeros = llvm_ir::EmitCallToIntrinsic(
              llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(true)},
              {param_ir_type}, ir_builder_);
          auto in_block = ir_builder_->GetInsertBlock();

          // A terminator should be present iff we're emitting code
          // into the middle (as opposed to the end) of a basic block.
          CHECK_EQ(ir_builder_->GetInsertPoint() == in_block->end(),
                   in_block->getTerminator() == nullptr);

          llvm::BasicBlock* body_block;
          llvm::BasicBlock* out_block;

          if (ir_builder_->GetInsertPoint() == in_block->end()) {
            body_block = llvm_ir::CreateBasicBlock(
                nullptr, IrName(hlo, "rng_body"), ir_builder_);
            out_block = llvm_ir::CreateBasicBlock(
                nullptr, IrName(hlo, "rng_out"), ir_builder_);
            llvm::BranchInst::Create(body_block, in_block);
          } else {
            // Split twice to carve out the loop body and exit blocks.
            body_block = in_block->splitBasicBlock(
                ir_builder_->GetInsertPoint(), "rng_body");
            out_block = body_block->splitBasicBlock(
                ir_builder_->GetInsertPoint(), "rng_out");
            body_block->getTerminator()->eraseFromParent();
          }

          SetToFirstInsertPoint(body_block, ir_builder_);
          auto random = ir_builder_->CreateAnd(
              ir_builder_->CreateZExtOrTrunc(get_next_i64(), param_ir_type),
              ir_builder_->CreateLShr(llvm::ConstantInt::get(param_ir_type, ~0),
                                      leading_zeros));
          // Loop back into the body while the sample is >= the range size.
          llvm::BranchInst::Create(out_block, body_block,
                                   ir_builder_->CreateICmpULT(random, r),
                                   body_block);
          SetToFirstInsertPoint(out_block, ir_builder_);
          // Degenerate range (p == q) yields p itself.
          return ir_builder_->CreateAdd(
              p, ir_builder_->CreateSelect(
                     ir_builder_->CreateICmpEQ(p, q),
                     llvm::ConstantInt::get(param_ir_type, 0),
                     random));
        }
      }
      case RNG_NORMAL: {
        // Operands are mean (m) and standard deviation (s); transform a
        // uniform sample through the inverse complementary error function.
        TF_ASSIGN_OR_RETURN(llvm::Value * m,
                            operand_to_generator.at(hlo->operand(0))(index));
        TF_ASSIGN_OR_RETURN(llvm::Value * s,
                            operand_to_generator.at(hlo->operand(1))(index));
        TF_ASSIGN_OR_RETURN(
            llvm::Value * r,
            EmitErfcInv(param_prim_type,
                        ir_builder_->CreateFMul(
                            llvm::ConstantFP::get(param_ir_type, 2.0),
                            get_next_uniform_float())));
        return ir_builder_->CreateFAdd(ir_builder_->CreateFMul(r, s), m);
      }
      default:
        return InvalidArgument(
            "unhandled distribution %s",
            RandomDistribution_Name(hlo->random_distribution()).c_str());
    }
  };
}

// Returns a generator that, given an output-element index, emits IR computing
// that element of `hlo` from its operands' generators.
llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
    const HloInstruction* hlo,
    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
    const {
  switch (hlo->opcode()) {
    // Elementwise unary ops: generate the (broadcast-adjusted) operand
    // element, then defer to EmitUnaryOp.
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kCeil:
    case HloOpcode::kConvert:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kReal:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kTanh:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                            operand_to_generator.at(hlo->operand(0))(
                                ElementwiseSourceIndex(index, *hlo, 0)));
        return EmitUnaryOp(hlo, operand_value);
      };
    // Elementwise binary ops: generate both operand elements, then defer to
    // EmitBinaryOp.
    case HloOpcode::kAdd:
    case HloOpcode::kAnd:
    case HloOpcode::kAtan2:
    case HloOpcode::kComplex:
    case HloOpcode::kDivide:
    case HloOpcode::kEq:
    case HloOpcode::kGe:
    case HloOpcode::kGt:
    case HloOpcode::kLe:
    case HloOpcode::kLt:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNe:
    case HloOpcode::kOr:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSubtract:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        const HloInstruction* lhs = hlo->operand(0);
        const HloInstruction* rhs = hlo->operand(1);
        TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value,
                            operand_to_generator.at(lhs)(
                                ElementwiseSourceIndex(index, *hlo, 0)));
        TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value,
                            operand_to_generator.at(rhs)(
                                ElementwiseSourceIndex(index, *hlo, 1)));
        return EmitBinaryOp(hlo, lhs_value, rhs_value);
      };
    case HloOpcode::kSelect:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
                            operand_to_generator.at(hlo->operand(0))(
                                ElementwiseSourceIndex(index, *hlo, 0)));
        TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value,
                            operand_to_generator.at(hlo->operand(1))(
                                ElementwiseSourceIndex(index, *hlo, 1)));
        TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
                            operand_to_generator.at(hlo->operand(2))(
                                ElementwiseSourceIndex(index, *hlo, 2)));
        // The predicate is stored wider than i1; truncate before selecting.
        return ir_builder_->CreateSelect(
            ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()),
            on_true_value, on_false_value);
      };
    case HloOpcode::kClamp:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
                            operand_to_generator.at(hlo->operand(0))(
                                ElementwiseSourceIndex(index, *hlo, 0)));
        TF_ASSIGN_OR_RETURN(llvm::Value * arg_value,
                            operand_to_generator.at(hlo->operand(1))(
                                ElementwiseSourceIndex(index, *hlo, 1)));
        TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
                            operand_to_generator.at(hlo->operand(2))(
                                ElementwiseSourceIndex(index, *hlo, 2)));
        // clamp(min, x, max) == min(max, max(min, x)).
        PrimitiveType prim_type = hlo->shape().element_type();
        if (primitive_util::IsFloatingPointType(prim_type)) {
          return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
        } else if (primitive_util::IsIntegralType(prim_type)) {
          bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
          return EmitIntegralMin(
              max_value, EmitIntegralMax(min_value, arg_value, is_signed),
              is_signed);
        } else {
          return Unimplemented("Clamp unimplemented for %s",
                               PrimitiveType_Name(prim_type).c_str());
        }
      };
    case HloOpcode::kReducePrecision:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                            operand_to_generator.at(hlo->operand(0))(
                                ElementwiseSourceIndex(index, *hlo, 0)));
        return EmitReducePrecision(hlo, operand_value);
      };
    case HloOpcode::kConcatenate:
      // Selects the operand covering target_index along the concat dimension
      // by emitting a chain of conditional branches; each operand's generator
      // feeds one incoming edge of a PHI in the merge block.
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index target_index) -> StatusOr<llvm::Value*> {
        const int64 concat_dim = hlo->dimensions(0);
        auto source_index = target_index;

        llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock();

        // A terminator should be present iff we're emitting code
        // into the middle (as opposed to the end) of a basic block.
        CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(),
                 init_block->getTerminator() == nullptr);

        llvm::BasicBlock* exit_block;
        if (ir_builder_->GetInsertPoint() == init_block->end()) {
          exit_block = llvm_ir::CreateBasicBlock(
              /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_);
        } else {
          exit_block = init_block->splitBasicBlock(
              ir_builder_->GetInsertPoint(), AsStringRef(IrName(hlo, "merge")));
          init_block->getTerminator()->eraseFromParent();
        }

        llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_);
        llvm::PHINode* output =
            ir_builder_->CreatePHI(llvm_ir::PrimitiveTypeToIrType(
                                       hlo->shape().element_type(), module_),
                                   hlo->operands().size());
        auto prior_insert_point = ir_builder_->GetInsertPoint();

        ir_builder_->SetInsertPoint(init_block);

        for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
             ++operand_idx) {
          const HloInstruction* operand = hlo->operand(operand_idx);
          auto true_block = llvm_ir::CreateBasicBlock(
              exit_block, StrCat("concat_index_from_operand", operand_idx),
              ir_builder_);
          auto false_block = llvm_ir::CreateBasicBlock(
              exit_block, StrCat("concat_index_not_from_operand", operand_idx),
              ir_builder_);
          auto concat_dim_size =
              llvm::ConstantInt::get(source_index[concat_dim]->getType(),
                                     operand->shape().dimensions(concat_dim));
          ir_builder_->CreateCondBr(
              ir_builder_->CreateICmpULT(source_index[concat_dim],
                                         concat_dim_size),
              true_block, false_block);

          // Create the terminator of the true block before calling operand
          // generators, because they require non-degenerate basic blocks.
          ir_builder_->SetInsertPoint(
              llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block));
          TF_ASSIGN_OR_RETURN(llvm::Value * value,
                              operand_to_generator.at(operand)(source_index));
          output->addIncoming(value, ir_builder_->GetInsertBlock());

          // Subtract the size of the concat dimension of the current operand
          // from the source index.
          ir_builder_->SetInsertPoint(false_block);
          source_index[concat_dim] =
              ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size);
        }

        // The final false block is unreachable: some operand always covers a
        // valid index.
        ir_builder_->CreateUnreachable();
        ir_builder_->SetInsertPoint(exit_block, prior_insert_point);
        return output;
      };
    case HloOpcode::kReverse:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        const HloInstruction* operand = hlo->operand(0);
        auto source_index = target_index;
        // Mirror the index in every reversed dimension:
        // source[dim] = (size - 1) - target[dim].
        for (int64 dim : hlo->dimensions()) {
          source_index[dim] = ir_builder_->CreateSub(
              llvm::ConstantInt::get(target_index[dim]->getType(),
                                     hlo->shape().dimensions(dim) - 1),
              target_index[dim]);
        }
        return operand_to_generator.at(operand)(source_index);
      };
    case HloOpcode::kBroadcast:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
        // The `dimensions` member of the broadcast instruction maps from
        // input dimensions to output dimensions.
        const HloInstruction* operand = hlo->operand(0);
        int64 rank = ShapeUtil::Rank(operand->shape());
        IrArray::Index source_index(rank);
        for (int64 i = 0; i < rank; ++i) {
          source_index[i] = target_index[hlo->dimensions(i)];
        }
        return operand_to_generator.at(operand)(source_index);
      };
    case HloOpcode::kSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        IrArray::Index sliced_index = index.SourceIndexOfSlice(
            /*shape=*/hlo->shape(), /*starts=*/hlo->slice_starts(),
            /*strides=*/hlo->slice_strides(), /*builder=*/ir_builder_);
        return operand_to_generator.at(hlo->operand(0))(sliced_index);
      };
    case HloOpcode::kDynamicSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        // Emit IR to read dynamic start indices from hlo->operand(1).
        const HloInstruction* input_hlo = hlo->operand(0);
        const int64 rank = ShapeUtil::Rank(input_hlo->shape());
        llvm_ir::IrArray::Index slice_start_index(rank);
        for (int64 i = 0; i < rank; ++i) {
          llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
          TF_ASSIGN_OR_RETURN(
              llvm::Value * start_index_value,
              operand_to_generator.at(hlo->operand(1))(dim_index));
          start_index_value->setName(
              AsStringRef(IrName(hlo, StrCat("start_idx", i))));
          slice_start_index[i] = start_index_value;
        }

        llvm_ir::IrArray::Index input_index(rank);
        for (int64 i = 0; i < rank; ++i) {
          // Emit IR which computes:
          //   input_index = (start_index + offset_index) % dim_size
          // Security note: this is the code that keeps the indices in-bounds.
          llvm::Value* dim_size = llvm::ConstantInt::get(
              index[i]->getType(), input_hlo->shape().dimensions(i));
          llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast(
              slice_start_index[i], index[i]->getType());
          input_index[i] = ir_builder_->CreateURem(
              ir_builder_->CreateAdd(start_index, index[i]), dim_size);
        }
        return operand_to_generator.at(input_hlo)(input_index);
      };
    case HloOpcode::kDynamicUpdateSlice:
      return [this, hlo, &operand_to_generator](
                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
        const HloInstruction* input_hlo = hlo->operand(0);
        const HloInstruction* update_hlo = hlo->operand(1);
        const HloInstruction* start_hlo = hlo->operand(2);
        // Calculate slice start/end indices.
        const int64 rank = ShapeUtil::Rank(input_hlo->shape());
        llvm_ir::IrArray::Index slice_start_index(rank);
        llvm_ir::IrArray::Index slice_limit_index(rank);
        // Slice starts at update[index - slice_start_index_adjusted],
        // where adjusted value = slice_start_index when in bounds, and
        // adjusted value = slice_start_index - input_dim, when wrapping.
        llvm_ir::IrArray::Index slice_start_index_adjusted(rank);

        // Slice intersection gathers (ANDs) conditions on all ranks for which
        // 'input' is set to 'update'
        llvm::Value* slice_intersection = ir_builder_->getTrue();

        for (int64 i = 0; i < rank; ++i) {
          // Emit IR to read dynamic start indices from 'start_hlo'.
          llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
          TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
                              operand_to_generator.at(start_hlo)(dim_index));
          start_index_value->setName(
              AsStringRef(IrName(hlo, StrCat("start_idx", i))));
          slice_start_index[i] = ir_builder_->CreateZExtOrBitCast(
              start_index_value, index[i]->getType());

          llvm::Value* input_dim_size = llvm::ConstantInt::get(
              index[i]->getType(), input_hlo->shape().dimensions(i));
          llvm::Value* update_dim_size = llvm::ConstantInt::get(
              index[i]->getType(), update_hlo->shape().dimensions(i));

          // Generate code to handle wrapping semantics:
          // slice_start_index[i] = slice_start_index[i] % input_dim_size;
          // slice_limit_index[i] = slice_start_index[i] + update_dim_size.
          // slice_start_index[i] is updated in place and it will now be in
          // range. slice_limit_index[i] may be out of range, and it's being
          // URem-ed below if so.
          slice_start_index[i] =
              ir_builder_->CreateURem(slice_start_index[i], input_dim_size);
          slice_limit_index[i] =
              ir_builder_->CreateAdd(slice_start_index[i], update_dim_size);

          // Test if slice_limit_index[i] is in bounds
          llvm::Value* in_bounds =
              ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size);
          llvm_ir::LlvmIfData if_in_bounds =
              llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);

          // Handle true BB (slice_limit_index[i] <= input_dim_size).
          SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_);
          // Check that index[i] >= slice_start_index[i] &&
          //            index[i] < slice_limit_index[i]
          llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd(
              slice_intersection,
              ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
              "slice_intersection_in");
          slice_intersection_in_bounds = ir_builder_->CreateAnd(
              slice_intersection_in_bounds,
              ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]),
              "slice_intersection_in");

          // Handle false BB (slice_limit_index[i] > input_dim_size).
          SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_);
          // Check that index[i] >= slice_start_index[i] ||
          //            index[i] < slice_limit_index[i]%input_dim_size.
          llvm::Value* index_wraps = ir_builder_->CreateICmpSLT(
              index[i],
              ir_builder_->CreateURem(slice_limit_index[i], input_dim_size));
          llvm::Value* slice_intersection_or = ir_builder_->CreateOr(
              ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
              index_wraps, "slice_intersection_out");
          llvm::Value* slice_intersection_out_of_bounds =
              ir_builder_->CreateAnd(slice_intersection, slice_intersection_or,
                                     "slice_intersection_out");
          // Create value for slice_start_index_adjusted[i] when out of bounds.
          // If within out-of-bounds if.
          llvm_ir::LlvmIfData if_start_needs_adjustment =
              llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_);
          SetToFirstInsertPoint(if_start_needs_adjustment.true_block,
                                ir_builder_);
          llvm::Value* slice_start_index_adjusted_oob =
              ir_builder_->CreateSub(slice_start_index[i], input_dim_size);
          SetToFirstInsertPoint(if_start_needs_adjustment.after_block,
                                ir_builder_);
          llvm::PHINode* slice_start_index_adjusted_phi =
              ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(),
                                     2);
          slice_start_index_adjusted_phi->addIncoming(
              slice_start_index_adjusted_oob,
              if_start_needs_adjustment.true_block);
          slice_start_index_adjusted_phi->addIncoming(
              slice_start_index[i], if_start_needs_adjustment.false_block);
          // End of if within if.

          // After checking in/out of bounds.
          SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_);
          llvm::PHINode* phi_slice_intersection =
              ir_builder_->CreatePHI(slice_intersection->getType(), 2);
          phi_slice_intersection->addIncoming(slice_intersection_in_bounds,
                                              if_in_bounds.true_block);
          phi_slice_intersection->addIncoming(
              slice_intersection_out_of_bounds,
              if_start_needs_adjustment.after_block);
          slice_intersection = phi_slice_intersection;

          llvm::PHINode* phi_index =
              ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2);
          phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block);
          phi_index->addIncoming(slice_start_index_adjusted_phi,
                                 if_start_needs_adjustment.after_block);
          slice_start_index_adjusted[i] = phi_index;
        }

        // Emit:
        // if (slice_intersection) -> return data from 'update'.
        // else                    -> return data from 'input'.
1654 llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( 1655 llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), 1656 module_), 1657 "ret_value_addr", ir_builder_); 1658 llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( 1659 slice_intersection, "slice_intersection", ir_builder_); 1660 1661 // Handle true BB (return data from 'update') 1662 SetToFirstInsertPoint(if_data.true_block, ir_builder_); 1663 // Compute update index for intersection case. 1664 llvm_ir::IrArray::Index update_index(rank); 1665 for (int64 i = 0; i < rank; ++i) { 1666 llvm::Value* update_dim_size = llvm::ConstantInt::get( 1667 index[i]->getType(), update_hlo->shape().dimensions(i)); 1668 // NOTE: Subtraction will be positive due to bounds checking above. 1669 update_index[i] = ir_builder_->CreateURem( 1670 ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]), 1671 update_dim_size); 1672 } 1673 TF_ASSIGN_OR_RETURN(llvm::Value * true_value, 1674 operand_to_generator.at(update_hlo)(update_index)); 1675 ir_builder_->CreateStore(true_value, ret_value_addr); 1676 1677 // Handle false BB (return data from 'input') 1678 SetToFirstInsertPoint(if_data.false_block, ir_builder_); 1679 TF_ASSIGN_OR_RETURN(llvm::Value * false_value, 1680 operand_to_generator.at(input_hlo)(index)); 1681 ir_builder_->CreateStore(false_value, ret_value_addr); 1682 1683 SetToFirstInsertPoint(if_data.after_block, ir_builder_); 1684 return ir_builder_->CreateLoad(ret_value_addr); 1685 }; 1686 case HloOpcode::kReshape: 1687 CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()), 1688 ShapeUtil::ElementsIn(hlo->operand(0)->shape())); 1689 return [this, hlo, &operand_to_generator](const IrArray::Index& index) { 1690 const HloInstruction* operand = hlo->operand(0); 1691 return operand_to_generator.at(operand)(index.SourceIndexOfReshape( 1692 hlo->shape(), operand->shape(), ir_builder_)); 1693 }; 1694 case HloOpcode::kTranspose: 1695 return [this, hlo, 1696 &operand_to_generator](const IrArray::Index& 
                                  target_index) {
        return operand_to_generator.at(hlo->operand(0))(
            target_index.SourceIndexOfTranspose(
                hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions(),
                ir_builder_));
      };
    case HloOpcode::kRng:
      return MakeRngElementGenerator(hlo, operand_to_generator);
    // Pad: map an index in the padded output back to an index in the source
    // operand, and pick the source element when the output position is a real
    // data position, or the padding value (operand 1) otherwise.
    case HloOpcode::kPad:
      return [=, &operand_to_generator](
                 const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
        auto index = padded_index;
        llvm::Value* in_bounds = ir_builder_->getTrue();
        for (size_t i = 0; i < index.size(); ++i) {
          auto index_typed_const = [=](int64 n) {
            return llvm::ConstantInt::get(index[i]->getType(), n);
          };
          const auto& pad_dim = hlo->padding_config().dimensions(i);
          // Remove the low edge padding.
          index[i] = ir_builder_->CreateSub(
              index[i], index_typed_const(pad_dim.edge_padding_low()));
          // In bounds only if we are past the low edge padding ...
          in_bounds = ir_builder_->CreateAnd(
              in_bounds,
              ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)),
              "in_bounds");
          // ... and the index lands exactly on an interior-padding stride ...
          in_bounds = ir_builder_->CreateAnd(
              in_bounds,
              ir_builder_->CreateICmpEQ(
                  index_typed_const(0),
                  ir_builder_->CreateURem(
                      index[i],
                      index_typed_const(pad_dim.interior_padding() + 1))),
              "in_bounds");
          index[i] = ir_builder_->CreateSDiv(
              index[i], index_typed_const(pad_dim.interior_padding() + 1));
          // ... and the resulting operand index is within the operand shape.
          in_bounds = ir_builder_->CreateAnd(
              in_bounds,
              ir_builder_->CreateICmpSLT(
                  index[i],
                  index_typed_const(hlo->operand(0)->shape().dimensions(i))),
              "in_bounds");
        }

        // if (in_bounds) {
        //   ret_value = operand0[index];  // source
        // } else {
        //   ret_value = *operand1;  // padding
        // }
        llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
            llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
                                           module_),
            "pad_result_addr", ir_builder_);
        llvm_ir::LlvmIfData if_data =
            llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);

        SetToFirstInsertPoint(if_data.true_block, ir_builder_);
        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
                            operand_to_generator.at(hlo->operand(0))(index));
        ir_builder_->CreateStore(operand_value, ret_value_addr);

        SetToFirstInsertPoint(if_data.false_block, ir_builder_);
        TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
                            operand_to_generator.at(hlo->operand(1))({}));
        ir_builder_->CreateStore(padding_value, ret_value_addr);

        SetToFirstInsertPoint(if_data.after_block, ir_builder_);
        // Don't create phi(operand_value, padding_value) here, because invoking
        // operand_to_generator may create new basic blocks, making the parent
        // of operand_value or padding_value no longer a predecessor of
        // if_data.after_block.
        return ir_builder_->CreateLoad(ret_value_addr);
      };

    case HloOpcode::kDot:
      return [=, &operand_to_generator](const IrArray::Index& dot_result_index)
                 -> StatusOr<llvm::Value*> {
        auto lhs_generator = operand_to_generator.at(hlo->operand(0));
        auto rhs_generator = operand_to_generator.at(hlo->operand(1));
        // The contracted dimension is the last dimension of the LHS.
        int64 contracted_dim_size = hlo->operand(0)->shape().dimensions(
            hlo->operand(0)->shape().dimensions_size() - 1);
        int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
        int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();

        // Emit a loop over the contracted dimension; the running sum lives in
        // an alloca so the loop body can load/store it across iterations.
        std::unique_ptr<llvm_ir::ForLoop> inner_loop =
            llvm_ir::ForLoop::EmitForLoop(
                IrName(hlo, "inner"), ir_builder_->getInt64(0),
                ir_builder_->getInt64(contracted_dim_size),
                ir_builder_->getInt64(1), ir_builder_);

        SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(),
                              ir_builder_);
        PrimitiveType primitive_type = hlo->shape().element_type();
        llvm::Type* primitive_type_llvm =
            llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
        llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry(
            primitive_type_llvm, "dot_acc", ir_builder_);
        // Zero-initialize the accumulator before entering the loop body.
        ir_builder_->CreateStore(
            llvm::Constant::getNullValue(primitive_type_llvm),
            accumulator_alloca);

        SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), ir_builder_);

        // This is the inner reduction loop for a dot operation that produces
        // one element in the output. If the operands to the dot operation have
        // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E].
        // Given an output index [a,b,c,d,e] in the result, we compute:
        //   sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T))

        IrArray::Index lhs_index, rhs_index;

        // LHS index: leading output dims, then the induction variable as the
        // contracted (last) dimension.
        for (int64 i = 0; i < lhs_dims - 1; i++) {
          lhs_index.push_back(dot_result_index[i]);
        }
        lhs_index.push_back(inner_loop->GetIndVarValue());

        // RHS index: remaining output dims, the induction variable as the
        // second-to-last dimension, then the last output dim.
        for (int64 i = 0; i < rhs_dims - 2; i++) {
          rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]);
        }
        rhs_index.push_back(inner_loop->GetIndVarValue());
        rhs_index.push_back(dot_result_index.back());

        llvm::Value* current_accumulator =
            ir_builder_->CreateLoad(accumulator_alloca);
        TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
        TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
        llvm::Value* next_accumulator;
        if (primitive_util::IsComplexType(primitive_type)) {
          // Complex multiply-accumulate:
          //   (a+bi)(c+di) = (ac - bd) + (ad + bc)i
          llvm::Value* product_real = ir_builder_->CreateFSub(
              ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
                                      EmitExtractReal(rhs_value)),
              ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
                                      EmitExtractImag(rhs_value)));
          llvm::Value* product_imag = ir_builder_->CreateFAdd(
              ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
                                      EmitExtractImag(rhs_value)),
              ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
                                      EmitExtractReal(rhs_value)));
          next_accumulator = ir_builder_->CreateInsertValue(
              current_accumulator,
              ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator),
                                      product_real),
              {0});
          next_accumulator = ir_builder_->CreateInsertValue(
              next_accumulator,
              ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator),
                                      product_imag),
              {1});
        } else if (primitive_util::IsFloatingPointType(primitive_type)) {
          next_accumulator = ir_builder_->CreateFAdd(
              current_accumulator,
              ir_builder_->CreateFMul(lhs_value, rhs_value));
        } else {
          // Integer types use plain integer multiply/add.
          next_accumulator = ir_builder_->CreateAdd(
              current_accumulator,
              ir_builder_->CreateMul(lhs_value, rhs_value));
        }
        ir_builder_->CreateStore(next_accumulator, accumulator_alloca);

        SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), ir_builder_);
        return ir_builder_->CreateLoad(accumulator_alloca);
      };
    default:
      // Opcodes without an elemental emitter lowering report Unimplemented at
      // generator-invocation time.
      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
        return Unimplemented("Unhandled opcode for elemental IR emission: %s",
                             HloOpcodeString(hlo->opcode()).c_str());
      };
  }
}

// Extracts the real component (struct element {0}) of a complex value, which
// is represented as an LLVM aggregate of {real, imag}.
llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const {
  return ir_builder_->CreateExtractValue(value, {0});
}

// Extracts the imaginary component (struct element {1}) of a complex value.
llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const {
  return ir_builder_->CreateExtractValue(value, {1});
}

// Builds a complex value {real, imag} of op's element type. If 'imag' is
// nullptr, the imaginary component is left at zero (the aggregate starts as
// ConstantAggregateZero).
llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
                                                    llvm::Value* real,
                                                    llvm::Value* imag) const {
  auto cplx_type =
      llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
  auto complex = ir_builder_->CreateInsertValue(
      llvm::ConstantAggregateZero::get(cplx_type), real, {0});
  if (imag != nullptr) {
    complex = ir_builder_->CreateInsertValue(complex, imag, {1});
  }
  return complex;
}

}  // namespace xla