1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
17
18#include <algorithm>
19#include <memory>
20#include <string>
21#include <vector>
22
23// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
24#include "llvm/IR/BasicBlock.h"
25#include "llvm/IR/Instructions.h"
26#include "llvm/IR/Intrinsics.h"
27#include "llvm/Transforms/Utils/BasicBlockUtils.h"
28#include "tensorflow/compiler/xla/primitive_util.h"
29#include "tensorflow/compiler/xla/service/hlo_module.h"
30#include "tensorflow/compiler/xla/service/hlo_opcode.h"
31#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
32#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
33#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
34#include "tensorflow/compiler/xla/shape_util.h"
35#include "tensorflow/compiler/xla/status_macros.h"
36#include "tensorflow/compiler/xla/statusor.h"
37#include "tensorflow/compiler/xla/types.h"
38#include "tensorflow/compiler/xla/util.h"
39#include "tensorflow/compiler/xla/xla_data.pb.h"
40#include "tensorflow/core/lib/random/random.h"
41#include "tensorflow/core/lib/strings/strcat.h"
42#include "tensorflow/core/platform/logging.h"
43#include "tensorflow/core/platform/types.h"
44
45namespace xla {
46
47using llvm_ir::AsStringRef;
48using llvm_ir::IrArray;
49using llvm_ir::IrName;
50using llvm_ir::SetToFirstInsertPoint;
51using tensorflow::strings::StrCat;
52
53namespace {
54
55llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
56                                      int64 mantissa_bits,
57                                      llvm::IRBuilder<>* ir_builder) {
58  // Integer and float types for casting and constant generation.
59  llvm::Type* float_type = x->getType();
60  llvm::IntegerType* int_type = ir_builder->getInt32Ty();
61
62  // Cast the input value to an integer for bitwise manipulation.
63  llvm::Value* x_as_int = ir_builder->CreateBitCast(x, int_type);
64
65  if (mantissa_bits < 23) {
66    // Last remaining mantissa bit.
67    const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits);
68
69    // Compute rounding bias for round-to-nearest with ties to even.  This is
70    // equal to a base value of 0111... plus one bit if the last remaining
71    // mantissa bit is 1.
72    const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
73    llvm::Value* x_last_mantissa_bit = ir_builder->CreateLShr(
74        ir_builder->CreateAnd(
75            x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
76        (23 - mantissa_bits));
77    llvm::Value* x_rounding_bias = ir_builder->CreateAdd(
78        x_last_mantissa_bit,
79        llvm::ConstantInt::get(int_type, base_rounding_bias));
80
81    // Add rounding bias, and mask out truncated bits.  Note that the case
82    // where adding the rounding bias overflows into the exponent bits is
83    // correct; the non-masked mantissa bits will all be zero, and the
84    // exponent will be incremented by one.
85    const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
86    x_as_int = ir_builder->CreateAdd(x_as_int, x_rounding_bias);
87    x_as_int = ir_builder->CreateAnd(
88        x_as_int, llvm::ConstantInt::get(int_type, truncation_mask));
89  }
90
91  if (exponent_bits < 8) {
92    // Masks for f32 values.
93    const uint32_t f32_sign_bit_mask = 1u << 31;
94    const uint32_t f32_exp_bits_mask = 0xffu << 23;
95
96    // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most-
97    // significant bit -- is equal to 1.0f for all exponent sizes.  Adding
98    // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit-
99    // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest'
100    // exponent (corresponding to 0.0f).
101    //
102    // Thus, the f32 exponent corresponding to the highest non-infinite
103    // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32
104    // exponent corresponding to the lowest exponent for a bit size of n is
105    // (2^7-1) - 2^(n-1)-1.
106    //
107    // Note that we have already checked that exponents_bits >= 1.
108    const uint32_t f32_exponent_bias = (1 << 7) - 1;
109    const uint32_t reduced_exponent_bias = (1 << (exponent_bits - 1)) - 1;
110    const uint32_t reduced_max_exponent =
111        f32_exponent_bias + reduced_exponent_bias;
112    const uint32_t reduced_min_exponent =
113        f32_exponent_bias - reduced_exponent_bias;
114
115    // Do we overflow or underflow?
116    llvm::Value* x_exponent = ir_builder->CreateAnd(
117        x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
118    llvm::Value* x_overflows = ir_builder->CreateICmpUGT(
119        x_exponent,
120        llvm::ConstantInt::get(int_type, reduced_max_exponent << 23));
121    llvm::Value* x_underflows = ir_builder->CreateICmpULE(
122        x_exponent,
123        llvm::ConstantInt::get(int_type, reduced_min_exponent << 23));
124
125    // Compute appropriately-signed values of zero and infinity.
126    llvm::Value* x_signed_zero = ir_builder->CreateAnd(
127        x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask));
128    llvm::Value* x_signed_inf = ir_builder->CreateOr(
129        x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
130
131    // Force to zero or infinity if overflow or underflow.  (Note that this
132    // truncates all denormal values to zero, rather than rounding them.)
133    x_as_int = ir_builder->CreateSelect(x_overflows, x_signed_inf, x_as_int);
134    x_as_int = ir_builder->CreateSelect(x_underflows, x_signed_zero, x_as_int);
135  }
136
137  // Cast the result back to a floating-point type.
138  llvm::Value* result = ir_builder->CreateBitCast(x_as_int, float_type);
139
140  // Correct result for NaN inputs.
141  //
142  // The exponent handling will "normalize" NaN values to infinities, which is
143  // undesirable (except in the case with no mantissa bits, in which case it
144  // is mandatory).  This logic also handles cases where mantissa-rounding
145  // causes a NaN's mantissa to overflow into the exponent bits, which would
146  // otherwise create an erroneous zero value.
147  //
148  // If the fast-math flags are set to assume no NaNs, the comparison is likely
149  // to be optimized away, so there's no point in even emitting it.
150  if (!ir_builder->getFastMathFlags().noNaNs()) {
151    llvm::Value* x_is_nan = ir_builder->CreateFCmpUNO(x, x);
152
153    if (mantissa_bits > 0) {
154      result = ir_builder->CreateSelect(x_is_nan, x, result);
155    } else {
156      result = ir_builder->CreateSelect(
157          x_is_nan, llvm::ConstantFP::getInfinity(float_type), result);
158    }
159  }
160  return result;
161}
162
163llvm::Value* EmitF32ToBF16(llvm::Value* f32_value,
164                           llvm::IRBuilder<>* ir_builder) {
165  auto reduced_precision = EmitReducePrecisionFloat(
166      f32_value,
167      /*exponent_bits=*/primitive_util::kBFloat16ExponentBits,
168      /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, ir_builder);
169  auto as_int32 =
170      ir_builder->CreateBitCast(reduced_precision, ir_builder->getInt32Ty());
171  auto shifted = ir_builder->CreateLShr(as_int32, 16);
172  auto truncated = ir_builder->CreateTrunc(shifted, ir_builder->getInt16Ty());
173  return ir_builder->CreateBitCast(truncated, ir_builder->getInt16Ty());
174}
175
176llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value,
177                           llvm::IRBuilder<>* ir_builder) {
178  auto as_int16 =
179      ir_builder->CreateBitCast(bf16_value, ir_builder->getInt16Ty());
180  auto as_int32 = ir_builder->CreateZExt(as_int16, ir_builder->getInt32Ty());
181  auto shifted = ir_builder->CreateShl(as_int32, 16);
182  return ir_builder->CreateBitCast(shifted, ir_builder->getFloatTy());
183}
184
185llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
186                                    PrimitiveType from_type,
187                                    PrimitiveType to_type, llvm::Module* module,
188                                    llvm::IRBuilder<>* ir_builder) {
189  if (primitive_util::IsSignedIntegralType(from_type)) {
190    return ir_builder->CreateSIToFP(
191        integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module));
192  } else {
193    CHECK(primitive_util::IsUnsignedIntegralType(from_type) ||
194          from_type == PRED);
195    return ir_builder->CreateUIToFP(
196        integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module));
197  }
198}
199
200}  // namespace
201
202StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
203    const HloInstruction* op, llvm::Value* operand_value) const {
204  if (op->opcode() == HloOpcode::kCopy) {
205    return operand_value;
206  } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
207             op->operand(0)->shape().element_type() == PRED) {
208    return EmitIntegerUnaryOp(op, operand_value);
209  } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) {
210    return EmitComplexUnaryOp(op, operand_value);
211  } else {
212    return EmitFloatUnaryOp(op, operand_value);
213  }
214}
215
216StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
217    const HloInstruction* op, llvm::Value* operand_value) const {
218  switch (op->opcode()) {
219    case HloOpcode::kConvert: {
220      PrimitiveType from_type = op->operand(0)->shape().element_type();
221      PrimitiveType to_type = op->shape().element_type();
222      CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED);
223      if (from_type == to_type) {
224        return operand_value;
225      }
226      if (primitive_util::IsIntegralType(to_type)) {
227        return ir_builder_->CreateIntCast(
228            operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_),
229            primitive_util::IsSignedIntegralType(to_type));
230      }
231      if (primitive_util::IsFloatingPointType(to_type)) {
232        if (to_type == BF16) {
233          return EmitF32ToBF16(
234              EmitIntegralToFloating(operand_value, from_type, F32, module_,
235                                     ir_builder_),
236              ir_builder_);
237        }
238        return EmitIntegralToFloating(operand_value, from_type, to_type,
239                                      module_, ir_builder_);
240      }
241      if (primitive_util::IsComplexType(to_type)) {
242        auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
243            primitive_util::ComplexComponentType(to_type), module_);
244        if (primitive_util::IsSignedIntegralType(from_type)) {
245          return EmitComposeComplex(
246              op,
247              ir_builder_->CreateSIToFP(operand_value, to_ir_component_type),
248              nullptr);
249        }
250        if (primitive_util::IsUnsignedIntegralType(from_type) ||
251            from_type == PRED) {
252          return EmitComposeComplex(
253              op,
254              ir_builder_->CreateUIToFP(operand_value, to_ir_component_type),
255              nullptr);
256        }
257      }
258      return Unimplemented("conversion from primitive type %s to %s",
259                           PrimitiveType_Name(from_type).c_str(),
260                           PrimitiveType_Name(to_type).c_str());
261    }
262    case HloOpcode::kBitcastConvert: {
263      PrimitiveType from_type = op->operand(0)->shape().element_type();
264      PrimitiveType to_type = op->shape().element_type();
265      CHECK(primitive_util::IsIntegralType(from_type));
266      if (from_type == to_type) {
267        return operand_value;
268      }
269      if (primitive_util::BitWidth(from_type) ==
270          primitive_util::BitWidth(to_type)) {
271        return ir_builder_->CreateBitCast(
272            operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
273      }
274      return InvalidArgument(
275          "bitcast conversion from primitive type %s to %s with unequal "
276          "bit-widths (%u versus %u) ",
277          PrimitiveType_Name(from_type).c_str(),
278          PrimitiveType_Name(to_type).c_str(),
279          primitive_util::BitWidth(from_type),
280          primitive_util::BitWidth(to_type));
281    }
282    case HloOpcode::kAbs: {
283      bool is_signed =
284          primitive_util::IsSignedIntegralType(op->shape().element_type());
285      if (is_signed) {
286        auto type =
287            llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
288        auto zero = llvm::ConstantInt::get(type, 0);
289        auto cmp = ir_builder_->CreateICmpSGE(operand_value, zero);
290        return ir_builder_->CreateSelect(cmp, operand_value,
291                                         ir_builder_->CreateNeg(operand_value));
292      } else {
293        return operand_value;
294      }
295    }
296    case HloOpcode::kSign: {
297      bool is_signed =
298          primitive_util::IsSignedIntegralType(op->shape().element_type());
299      auto type =
300          llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
301      auto zero = llvm::ConstantInt::get(type, 0);
302      auto cmp = ir_builder_->CreateICmpEQ(operand_value, zero);
303      if (is_signed) {
304        auto ashr = ir_builder_->CreateAShr(operand_value,
305                                            type->getIntegerBitWidth() - 1);
306        return ir_builder_->CreateSelect(cmp, zero,
307                                         ir_builder_->CreateOr(ashr, 1));
308      } else {
309        return ir_builder_->CreateSelect(cmp, zero,
310                                         llvm::ConstantInt::get(type, 1));
311      }
312    }
313    case HloOpcode::kNegate:
314      return ir_builder_->CreateNeg(operand_value);
315    case HloOpcode::kNot: {
316      auto type = op->shape().element_type();
317      if (type == PRED) {
318        // It is not sufficient to just call CreateNot() here because a PRED
319        // is represented as an i8 and the truth value is stored only in the
320        // bottom bit.
321        return ir_builder_->CreateZExt(
322            ir_builder_->CreateNot(ir_builder_->CreateTrunc(
323                operand_value, ir_builder_->getInt1Ty())),
324            llvm_ir::PrimitiveTypeToIrType(PRED, module_));
325      } else if (primitive_util::IsIntegralType(type)) {
326        return ir_builder_->CreateNot(operand_value);
327      }
328      return Unimplemented("unary op Not is not defined for type '%d'", type);
329    }
330    default:
331      return Unimplemented("unary integer op '%s'",
332                           HloOpcodeString(op->opcode()).c_str());
333  }
334}
335
336StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
337    const HloInstruction* op, llvm::Value* operand_value) const {
338  switch (op->opcode()) {
339    case HloOpcode::kConvert: {
340      PrimitiveType from_type = op->operand(0)->shape().element_type();
341      PrimitiveType to_type = op->shape().element_type();
342      CHECK(primitive_util::IsFloatingPointType(from_type));
343      if (from_type == to_type) {
344        return operand_value;
345      }
346      if (primitive_util::IsComplexType(to_type)) {
347        PrimitiveType to_component_type =
348            primitive_util::ComplexComponentType(to_type);
349        if (from_type == to_component_type) {
350          return EmitComposeComplex(op, operand_value, nullptr);
351        }
352        return EmitComposeComplex(
353            op,
354            ir_builder_->CreateFPCast(
355                operand_value,
356                llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
357            nullptr);
358      }
359      if (from_type == BF16) {
360        TF_RET_CHECK(to_type != BF16);
361        operand_value = EmitBF16ToF32(operand_value, ir_builder_);
362        from_type = F32;
363        if (from_type == to_type) {
364          return operand_value;
365        }
366      }
367      if (from_type == F32 && to_type == BF16) {
368        return EmitF32ToBF16(operand_value, ir_builder_);
369      }
370      if (primitive_util::IsFloatingPointType(to_type)) {
371        return ir_builder_->CreateFPCast(
372            operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
373      }
374      if (primitive_util::IsSignedIntegralType(to_type)) {
375        return ir_builder_->CreateFPToSI(
376            operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
377      }
378      if (primitive_util::IsUnsignedIntegralType(to_type)) {
379        return ir_builder_->CreateFPToUI(
380            operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
381      }
382      return Unimplemented("unhandled conversion operation: %s => %s",
383                           PrimitiveType_Name(from_type).c_str(),
384                           PrimitiveType_Name(to_type).c_str());
385    }
386    case HloOpcode::kBitcastConvert: {
387      PrimitiveType from_type = op->operand(0)->shape().element_type();
388      PrimitiveType to_type = op->shape().element_type();
389      CHECK(primitive_util::IsFloatingPointType(from_type));
390      if (from_type == to_type) {
391        return operand_value;
392      }
393      if (primitive_util::BitWidth(from_type) ==
394          primitive_util::BitWidth(to_type)) {
395        return ir_builder_->CreateBitCast(
396            operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
397      }
398      return InvalidArgument(
399          "bitcast conversion from primitive type %s to %s with unequal "
400          "bit-widths (%u versus %u) ",
401          PrimitiveType_Name(from_type).c_str(),
402          PrimitiveType_Name(to_type).c_str(),
403          primitive_util::BitWidth(from_type),
404          primitive_util::BitWidth(to_type));
405    }
406    case HloOpcode::kExp:
407      return EmitExp(op->shape().element_type(), operand_value);
408    case HloOpcode::kLog:
409      return EmitLog(op->shape().element_type(), operand_value);
410    case HloOpcode::kCos:
411      return EmitCos(op->shape().element_type(), operand_value);
412    case HloOpcode::kSin:
413      return EmitSin(op->shape().element_type(), operand_value);
414    case HloOpcode::kFloor:
415      return llvm_ir::EmitCallToIntrinsic(
416          llvm::Intrinsic::floor, {operand_value}, {operand_value->getType()},
417          ir_builder_);
418    case HloOpcode::kCeil:
419      return llvm_ir::EmitCallToIntrinsic(
420          llvm::Intrinsic::ceil, {operand_value}, {operand_value->getType()},
421          ir_builder_);
422    case HloOpcode::kAbs:
423      return llvm_ir::EmitCallToIntrinsic(
424          llvm::Intrinsic::fabs, {operand_value}, {operand_value->getType()},
425          ir_builder_);
426    case HloOpcode::kRoundNearestAfz:
427      return llvm_ir::EmitCallToIntrinsic(
428          llvm::Intrinsic::round, {operand_value}, {operand_value->getType()},
429          ir_builder_);
430    case HloOpcode::kSign: {
431      // TODO(b/32151903): Ensure consistent sign behavior for -0.0.
432      auto type = operand_value->getType();
433      auto zero = llvm::ConstantFP::get(type, 0.0);
434      auto oeq = ir_builder_->CreateFCmpOEQ(operand_value, zero);
435      auto olt = ir_builder_->CreateFCmpOLT(operand_value, zero);
436      return ir_builder_->CreateSelect(
437          oeq, zero,
438          ir_builder_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0),
439                                    llvm::ConstantFP::get(type, 1.0)));
440    }
441    case HloOpcode::kIsFinite: {
442      // (x == x) && abs(x) != inf
443      auto type = operand_value->getType();
444      auto equal_self =
445          ir_builder_->CreateFCmpOEQ(operand_value, operand_value);
446      auto abs_value = llvm_ir::EmitCallToIntrinsic(
447          llvm::Intrinsic::fabs, {operand_value}, {type}, ir_builder_);
448      auto infinity = llvm::ConstantFP::getInfinity(type);
449      auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity);
450      auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite);
451      return ir_builder_->CreateZExt(
452          result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, module_));
453    }
454    case HloOpcode::kNegate:
455      return ir_builder_->CreateFNeg(operand_value);
456    default:
457      return Unimplemented("unary floating-point op '%s'",
458                           HloOpcodeString(op->opcode()).c_str());
459  }
460}
461
462StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
463    const HloInstruction* op, llvm::Value* operand_value) const {
464  PrimitiveType input_type = op->operand(0)->shape().element_type();
465  PrimitiveType component_type =
466      primitive_util::IsComplexType(input_type)
467          ? primitive_util::ComplexComponentType(input_type)
468          : input_type;
469  switch (op->opcode()) {
470    case HloOpcode::kLog: {
471      // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a)
472      auto a = EmitExtractReal(operand_value);
473      auto b = EmitExtractImag(operand_value);
474      llvm::Type* llvm_ty = a->getType();
475      auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a),
476                                            ir_builder_->CreateFMul(b, b));
477      TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
478      TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a));
479      auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
480      return EmitComposeComplex(
481          op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle);
482    }
483    case HloOpcode::kConvert: {
484      PrimitiveType from_type = op->operand(0)->shape().element_type();
485      TF_RET_CHECK(primitive_util::IsComplexType(from_type));
486      PrimitiveType to_type = op->shape().element_type();
487      TF_RET_CHECK(primitive_util::IsComplexType(to_type));
488      if (from_type == to_type) {
489        return operand_value;
490      }
491      PrimitiveType to_component_type =
492          primitive_util::ComplexComponentType(to_type);
493      auto to_ir_component_type =
494          llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
495      return EmitComposeComplex(
496          op,
497          ir_builder_->CreateFPCast(EmitExtractReal(operand_value),
498                                    to_ir_component_type),
499          ir_builder_->CreateFPCast(EmitExtractImag(operand_value),
500                                    to_ir_component_type));
501    }
502    case HloOpcode::kExp: {
503      // e^(a+bi) = e^a*(cos(b)+sin(b)i)
504      TF_ASSIGN_OR_RETURN(
505          auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value)));
506      TF_ASSIGN_OR_RETURN(
507          auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
508      TF_ASSIGN_OR_RETURN(
509          auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
510      return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b),
511                                ir_builder_->CreateFMul(exp_a, sin_b));
512    }
513    case HloOpcode::kCos: {
514      // cos(z) = .5(e^(iz) + e^(-iz))
515      // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai))
516      // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
517      // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(-a)+sin(-a)i))
518      // cos(-x) = cos(x) and sin(-x) = -sin(x), so
519      // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i))
520      //           = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b))
521      auto a = EmitExtractReal(operand_value);
522      auto b = EmitExtractImag(operand_value);
523      auto type = a->getType();
524      TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
525      auto half_exp_b =
526          ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
527      auto half_exp_neg_b =
528          ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
529      TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
530      TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
531      return EmitComposeComplex(
532          op,
533          ir_builder_->CreateFMul(
534              cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)),
535          ir_builder_->CreateFMul(
536              sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b)));
537    }
538    case HloOpcode::kSin: {
539      // sin(z) = .5i(e^(-iz) - e^(iz))
540      // sin(a+bi) = .5i(e^(-i(a+bi)) - e^(i(a+bi)))
541      //           = .5i(e^(b-ai) - e^(-b+ai))
542      // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have
543      // sin(a+bi) = 0.5i(e^b*(cos(-a)+sin(-a)i) - e^-b*(cos(a)+sin(a)i))
544      //           = 0.5(e^b*(cos(-a)i-sin(-a)) - e^-b*(cos(a)i-sin(a)))
545      // cos(-x) = cos(x) and sin(-x) = -sin(x), so
546      //           = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a)))
547      //           = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b)
548      auto a = EmitExtractReal(operand_value);
549      auto b = EmitExtractImag(operand_value);
550      auto type = a->getType();
551      TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
552      auto half_exp_b =
553          ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
554      auto half_exp_neg_b =
555          ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
556      TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
557      TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
558      return EmitComposeComplex(
559          op,
560          ir_builder_->CreateFMul(
561              sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)),
562          ir_builder_->CreateFMul(
563              cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b)));
564    }
565    case HloOpcode::kTanh: {
566      /*
567      tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x))
568      e^(a+bi) = e^a*(cos(b)+sin(b)i)
569      so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) /
570              (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a))
571      cos(b)=cos(-b), sin(-b)=-sin(b)
572      so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) /
573              (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a))
574             =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) /
575              (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a))
576             =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) /
577              (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a))
578      This is a complex division, so we can multiply by denom_conj/denom_conj
579             =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) *
580              (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) /
581              ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
582             =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) +
583               i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) /
584              ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2)
585      */
586      auto a = EmitExtractReal(operand_value);
587      auto b = EmitExtractImag(operand_value);
588      TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a));
589      TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b));
590      TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b));
591      auto exp_neg_a = ir_builder_->CreateFDiv(
592          llvm::ConstantFP::get(exp_a->getType(), 1), exp_a);
593      auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub(
594          ir_builder_->CreateFMul(exp_a, exp_a),
595          ir_builder_->CreateFMul(exp_neg_a, exp_neg_a));
596      auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b);
597      auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b);
598      auto real_num = ir_builder_->CreateFAdd(
599          ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a),
600          ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a));
601      auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b);
602      auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a);
603      auto exp_a_plus_exp_neg_a_sq =
604          ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a);
605      auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a);
606      auto exp_a_minus_exp_neg_a_sq =
607          ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a);
608      auto imag_num = ir_builder_->CreateFMul(
609          cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq,
610                                               exp_a_minus_exp_neg_a_sq));
611      auto denom = ir_builder_->CreateFAdd(
612          ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq),
613          ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq));
614      return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom),
615                                ir_builder_->CreateFDiv(imag_num, denom));
616    }
617    case HloOpcode::kAbs: {
618      auto sum_sq = ir_builder_->CreateFAdd(
619          ir_builder_->CreateFMul(EmitExtractReal(operand_value),
620                                  EmitExtractReal(operand_value)),
621          ir_builder_->CreateFMul(EmitExtractImag(operand_value),
622                                  EmitExtractImag(operand_value)));
623      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq},
624                                          {sum_sq->getType()}, ir_builder_);
625    }
626    case HloOpcode::kSign: {  // Sign(c) = c / |c|
627      auto sum_sq = ir_builder_->CreateFAdd(
628          ir_builder_->CreateFMul(EmitExtractReal(operand_value),
629                                  EmitExtractReal(operand_value)),
630          ir_builder_->CreateFMul(EmitExtractImag(operand_value),
631                                  EmitExtractImag(operand_value)));
632      auto cplx_abs = llvm_ir::EmitCallToIntrinsic(
633          llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_);
634      auto type = cplx_abs->getType();
635      auto zero = llvm::ConstantFP::get(type, 0.0);
636      auto oeq = ir_builder_->CreateFCmpOEQ(cplx_abs, zero);
637      return ir_builder_->CreateSelect(
638          oeq, EmitComposeComplex(op, zero, zero),
639          EmitComposeComplex(
640              op,
641              ir_builder_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs),
642              ir_builder_->CreateFDiv(EmitExtractImag(operand_value),
643                                      cplx_abs)));
644    }
645    case HloOpcode::kNegate:
646      return EmitComposeComplex(
647          op, ir_builder_->CreateFNeg(EmitExtractReal(operand_value)),
648          ir_builder_->CreateFNeg(EmitExtractImag(operand_value)));
649    case HloOpcode::kReal:
650      return EmitExtractReal(operand_value);
651    case HloOpcode::kImag:
652      return EmitExtractImag(operand_value);
653    default:
654      return Unimplemented("unary complex op '%s'",
655                           HloOpcodeString(op->opcode()).c_str());
656  }
657}
658
659StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
660    const HloInstruction* op, llvm::Value* lhs_value,
661    llvm::Value* rhs_value) const {
662  PrimitiveType operand_type = op->operand(0)->shape().element_type();
663  if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) ||
664      operand_type == PRED) {
665    return EmitIntegerBinaryOp(
666        op, lhs_value, rhs_value,
667        primitive_util::IsSignedIntegralType(operand_type));
668  } else if (primitive_util::IsComplexType(operand_type)) {
669    return EmitComplexBinaryOp(op, lhs_value, rhs_value);
670  } else {
671    return EmitFloatBinaryOp(op, lhs_value, rhs_value);
672  }
673}
674
675StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
676    const HloInstruction* op, llvm::Value* lhs_value,
677    llvm::Value* rhs_value) const {
678  switch (op->opcode()) {
679    case HloOpcode::kComplex:
680      return EmitComposeComplex(op, lhs_value, rhs_value);
681    case HloOpcode::kAdd:
682      return ir_builder_->CreateFAdd(lhs_value, rhs_value);
683    case HloOpcode::kSubtract:
684      return ir_builder_->CreateFSub(lhs_value, rhs_value);
685    case HloOpcode::kMultiply:
686      return ir_builder_->CreateFMul(lhs_value, rhs_value);
687    case HloOpcode::kDivide:
688      return ir_builder_->CreateFDiv(lhs_value, rhs_value);
689    case HloOpcode::kRemainder:
690      return ir_builder_->CreateFRem(lhs_value, rhs_value);
691    // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
692    // comparisons always return false when one of the operands is NaN, whereas
693    // unordered comparisons return true.
694    //
695    // We use ordered comparisons for everything except kNe, where we use an
696    // unordered comparison.  This makes x != y equivalent to !(x == y), and
697    // matches C++'s semantics.
698    case HloOpcode::kEq:
699      return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
700                                     rhs_value, ir_builder_);
701    case HloOpcode::kNe:
702      return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
703                                     rhs_value, ir_builder_);
704    case HloOpcode::kLt:
705      return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
706                                     rhs_value, ir_builder_);
707    case HloOpcode::kGt:
708      return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
709                                     rhs_value, ir_builder_);
710    case HloOpcode::kLe:
711      return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
712                                     rhs_value, ir_builder_);
713    case HloOpcode::kGe:
714      return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
715                                     rhs_value, ir_builder_);
716
717    case HloOpcode::kMaximum:
718      return EmitFloatMax(lhs_value, rhs_value);
719    case HloOpcode::kMinimum:
720      return EmitFloatMin(lhs_value, rhs_value);
721    case HloOpcode::kPower:
722      return EmitPow(op->shape().element_type(), lhs_value, rhs_value);
723    case HloOpcode::kAtan2:
724      return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value);
725    default:
726      return Unimplemented("binary floating point op '%s'",
727                           HloOpcodeString(op->opcode()).c_str());
728  }
729}
730
731StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
732    const HloInstruction* op, llvm::Value* lhs_value,
733    llvm::Value* rhs_value) const {
734  switch (op->opcode()) {
735    case HloOpcode::kAdd:
736      return EmitComposeComplex(
737          op,
738          ir_builder_->CreateFAdd(EmitExtractReal(lhs_value),
739                                  EmitExtractReal(rhs_value)),
740          ir_builder_->CreateFAdd(EmitExtractImag(lhs_value),
741                                  EmitExtractImag(rhs_value)));
742    case HloOpcode::kSubtract:
743      return EmitComposeComplex(
744          op,
745          ir_builder_->CreateFSub(EmitExtractReal(lhs_value),
746                                  EmitExtractReal(rhs_value)),
747          ir_builder_->CreateFSub(EmitExtractImag(lhs_value),
748                                  EmitExtractImag(rhs_value)));
749    case HloOpcode::kMultiply:
750      return EmitComposeComplex(
751          op,
752          ir_builder_->CreateFSub(
753              ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
754                                      EmitExtractReal(rhs_value)),
755              ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
756                                      EmitExtractImag(rhs_value))),
757          ir_builder_->CreateFAdd(
758              ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
759                                      EmitExtractImag(rhs_value)),
760              ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
761                                      EmitExtractReal(rhs_value))));
762    case HloOpcode::kDivide: {
763      // (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di))
764      // = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
765      auto rhs_sum_sq = ir_builder_->CreateFAdd(
766          ir_builder_->CreateFMul(EmitExtractReal(rhs_value),
767                                  EmitExtractReal(rhs_value)),
768          ir_builder_->CreateFMul(EmitExtractImag(rhs_value),
769                                  EmitExtractImag(rhs_value)));
770      auto type = rhs_sum_sq->getType();
771      auto zero = llvm::ConstantFP::get(type, 0.0);
772      auto oeq = ir_builder_->CreateFCmpOEQ(rhs_sum_sq, zero);
773      auto real_inf_or_nan =
774          ir_builder_->CreateFDiv(EmitExtractReal(lhs_value), zero);
775      auto imag_inf_or_nan =
776          ir_builder_->CreateFDiv(EmitExtractImag(lhs_value), zero);
777      return ir_builder_->CreateSelect(
778          oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan),
779          EmitComposeComplex(
780              op,
781              ir_builder_->CreateFDiv(
782                  ir_builder_->CreateFAdd(
783                      ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
784                                              EmitExtractReal(rhs_value)),
785                      ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
786                                              EmitExtractImag(rhs_value))),
787                  rhs_sum_sq),
788              ir_builder_->CreateFDiv(
789                  ir_builder_->CreateFSub(
790                      ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
791                                              EmitExtractReal(rhs_value)),
792                      ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
793                                              EmitExtractImag(rhs_value))),
794                  rhs_sum_sq)));
795    }
796    // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
797    // comparisons always return false when one of the operands is NaN, whereas
798    // unordered comparisons return true.
799    //
800    // We use ordered comparisons for everything except kNe, where we use an
801    // unordered comparison.  This makes x != y equivalent to !(x == y), and
802    // matches C++'s semantics.
803    case HloOpcode::kEq:
804      return ir_builder_->CreateAnd(
805          llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
806                                  EmitExtractReal(lhs_value),
807                                  EmitExtractReal(rhs_value), ir_builder_),
808          llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
809                                  EmitExtractImag(lhs_value),
810                                  EmitExtractImag(rhs_value), ir_builder_));
811    case HloOpcode::kNe:
812      return ir_builder_->CreateOr(
813          llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
814                                  EmitExtractReal(lhs_value),
815                                  EmitExtractReal(rhs_value), ir_builder_),
816          llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
817                                  EmitExtractImag(lhs_value),
818                                  EmitExtractImag(rhs_value), ir_builder_));
819
820    case HloOpcode::kPower: {
821      // (a+bi)^(c+di) =
822      //    (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)),
823      //    where q = c*atan2(b,a)+0.5d*ln(a*a+b*b)
824      PrimitiveType component_type =
825          primitive_util::ComplexComponentType(op->shape().element_type());
826      auto a = EmitExtractReal(lhs_value);
827      auto b = EmitExtractImag(lhs_value);
828      auto c = EmitExtractReal(rhs_value);
829      auto d = EmitExtractImag(rhs_value);
830      auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a),
831                                             ir_builder_->CreateFMul(b, b));
832      auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
833      auto half_c = ir_builder_->CreateFMul(one_half, c);
834
835      TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c,
836                          EmitPow(component_type, aa_p_bb, half_c));
837      auto neg_d = ir_builder_->CreateFNeg(d);
838      TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a));
839      auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs);
840      TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs,
841                          EmitExp(component_type, neg_d_arg_lhs));
842      auto coeff =
843          ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
844      TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb));
845      auto half_d = ir_builder_->CreateFMul(one_half, d);
846      auto q =
847          ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs),
848                                  ir_builder_->CreateFMul(half_d, ln_aa_p_bb));
849      TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q));
850      TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q));
851      return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q),
852                                ir_builder_->CreateFMul(coeff, sin_q));
853    }
854    default:
855      return Unimplemented("binary complex op '%s'",
856                           HloOpcodeString(op->opcode()).c_str());
857  }
858}
859
860llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
861                                              llvm::Value* rhs_value) const {
862  return llvm_ir::EmitFloatMax(lhs_value, rhs_value, ir_builder_);
863}
864
865llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
866                                              llvm::Value* rhs_value) const {
867  return llvm_ir::EmitFloatMin(lhs_value, rhs_value, ir_builder_);
868}
869
870StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
871                                                      llvm::Value* x) const {
872  if (prim_type != F32) {
873    // TODO(b/34339814): Implement inverse erf for F64.
874    return Unimplemented(
875        "Inverse erf is only implemented for element "
876        "type F32.");
877  }
878  auto getFloat = [&](const float f) {
879    return llvm::ConstantFP::get(ir_builder_->getFloatTy(), f);
880  };
881  auto multiply_add = [&](tensorflow::gtl::ArraySlice<float> coefficients,
882                          llvm::Value* w) {
883    llvm::Value* p = getFloat(coefficients.front());
884    coefficients.pop_front();
885    for (float coefficient : coefficients) {
886      p = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(p, w),
887                                  getFloat(coefficient));
888    }
889    return p;
890  };
891
892  // Approximation for inverse error function from
893  //   Giles, M., "Approximating the erfinv function".
894  // The approximation has the form:
895  //   w = log((1-x)*(1+x))
896  //   if ( w < 5 ) {
897  //     w = w - 2.5
898  //     p = sum_{i=1}^n lq[i]*w^i
899  //   } else {
900  //     w = sqrt(w) - 3
901  //     p = sum_{i=1}^n gq[i]*w^i
902  //   }
903  //   return p*x
904  llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration(
905      module_, llvm::Intrinsic::log, {ir_builder_->getFloatTy()});
906
907  llvm::Value* w = ir_builder_->CreateFNeg(ir_builder_->CreateCall(
908      logf_fn,
909      {ir_builder_->CreateFMul(ir_builder_->CreateFSub(getFloat(1.0f), x),
910                               ir_builder_->CreateFAdd(getFloat(1.0f), x))}));
911
912  llvm::Value* p_addr = llvm_ir::EmitAllocaAtFunctionEntry(
913      ir_builder_->getFloatTy(), "p.addr", ir_builder_);
914
915  llvm_ir::LlvmIfData if_data =
916      llvm_ir::EmitIfThenElse(ir_builder_->CreateFCmpOLT(w, getFloat(5.0f)),
917                              "w_less_than_five", ir_builder_);
918  // Handle true BB.
919  SetToFirstInsertPoint(if_data.true_block, ir_builder_);
920  {
921    llvm::Value* lw = ir_builder_->CreateFSub(w, getFloat(2.5f));
922    tensorflow::gtl::ArraySlice<float> lq{
923        2.81022636e-08f,  3.43273939e-07f, -3.5233877e-06f,
924        -4.39150654e-06f, 0.00021858087f,  -0.00125372503f,
925        -0.00417768164f,  0.246640727f,    1.50140941f};
926    llvm::Value* p = multiply_add(lq, lw);
927    ir_builder_->CreateStore(p, p_addr);
928  }
929
930  // Handle false BB.
931  SetToFirstInsertPoint(if_data.false_block, ir_builder_);
932  {
933    llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
934        module_, llvm::Intrinsic::sqrt, {ir_builder_->getFloatTy()});
935
936    llvm::Value* gw = ir_builder_->CreateFSub(
937        ir_builder_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f));
938    tensorflow::gtl::ArraySlice<float> gq{
939        -0.000200214257f, 0.000100950558f, 0.00134934322f,
940        -0.00367342844f,  0.00573950773f,  -0.0076224613f,
941        0.00943887047f,   1.00167406f,     2.83297682f};
942    llvm::Value* p = multiply_add(gq, gw);
943    ir_builder_->CreateStore(p, p_addr);
944  }
945
946  SetToFirstInsertPoint(if_data.after_block, ir_builder_);
947  llvm::Value* p = ir_builder_->CreateLoad(p_addr);
948  return ir_builder_->CreateFMul(p, x);
949}
950
951StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(
952    PrimitiveType prim_type, llvm::Value* value) const {
953  // Compute erfcinv(value) by calculating erfinv(1.0 - value).
954  auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
955  auto one = llvm::ConstantFP::get(type, 1.0);
956  return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value));
957}
958
959StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
960                                                   llvm::Value* value) const {
961  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value},
962                                      {value->getType()}, ir_builder_);
963}
964
965StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type,
966                                                   llvm::Value* value) const {
967  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value},
968                                      {value->getType()}, ir_builder_);
969}
970
971StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
972                                                   llvm::Value* value) const {
973  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value},
974                                      {value->getType()}, ir_builder_);
975}
976
977StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
978                                                   llvm::Value* value) const {
979  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
980                                      {value->getType()}, ir_builder_);
981}
982
983StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
984                                                   llvm::Value* lhs,
985                                                   llvm::Value* rhs) const {
986  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
987                                      {lhs->getType()}, ir_builder_);
988}
989
990StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
991                                                     llvm::Value* lhs,
992                                                     llvm::Value* rhs) const {
993  return Unimplemented("atan2");
994}
995
996StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
997    const HloInstruction* hlo, llvm::Value* x) const {
998  if (hlo->operand(0)->shape().element_type() != F32) {
999    return Unimplemented("reduce-precision only implemented for F32");
1000  }
1001  return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(),
1002                                  /*mantissa_bits=*/hlo->mantissa_bits(),
1003                                  ir_builder_);
1004}
1005
1006StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
1007    const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
1008    bool is_signed) const {
1009  switch (op->opcode()) {
1010    // TODO(jingyue): add the "nsw" attribute for signed types.
1011    case HloOpcode::kAdd:
1012      return ir_builder_->CreateAdd(lhs_value, rhs_value);
1013    case HloOpcode::kSubtract:
1014      return ir_builder_->CreateSub(lhs_value, rhs_value);
1015    case HloOpcode::kMultiply:
1016      return ir_builder_->CreateMul(lhs_value, rhs_value);
1017    case HloOpcode::kDivide:
1018      return is_signed ? ir_builder_->CreateSDiv(lhs_value, rhs_value)
1019                       : ir_builder_->CreateUDiv(lhs_value, rhs_value);
1020    case HloOpcode::kRemainder:
1021      return is_signed ? ir_builder_->CreateSRem(lhs_value, rhs_value)
1022                       : ir_builder_->CreateURem(lhs_value, rhs_value);
1023    case HloOpcode::kEq:
1024      return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
1025                                     rhs_value, ir_builder_);
1026    case HloOpcode::kNe:
1027      return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
1028                                     rhs_value, ir_builder_);
1029    case HloOpcode::kLt:
1030      return llvm_ir::EmitComparison(
1031          is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
1032          lhs_value, rhs_value, ir_builder_);
1033    case HloOpcode::kGt:
1034      return llvm_ir::EmitComparison(
1035          is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
1036          lhs_value, rhs_value, ir_builder_);
1037    case HloOpcode::kLe:
1038      return llvm_ir::EmitComparison(
1039          is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
1040          lhs_value, rhs_value, ir_builder_);
1041    case HloOpcode::kGe:
1042      return llvm_ir::EmitComparison(
1043          is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
1044          lhs_value, rhs_value, ir_builder_);
1045    case HloOpcode::kMinimum:
1046      return EmitIntegralMin(lhs_value, rhs_value, is_signed);
1047    case HloOpcode::kMaximum:
1048      return EmitIntegralMax(lhs_value, rhs_value, is_signed);
1049    case HloOpcode::kAnd:
1050      return ir_builder_->CreateAnd(lhs_value, rhs_value);
1051    case HloOpcode::kOr:
1052      return ir_builder_->CreateOr(lhs_value, rhs_value);
1053    case HloOpcode::kShiftLeft:
1054      return ir_builder_->CreateShl(lhs_value, rhs_value);
1055    case HloOpcode::kShiftRightArithmetic:
1056      return ir_builder_->CreateAShr(lhs_value, rhs_value);
1057    case HloOpcode::kShiftRightLogical:
1058      return ir_builder_->CreateLShr(lhs_value, rhs_value);
1059    default:
1060      return Unimplemented("binary integer op '%s'",
1061                           HloOpcodeString(op->opcode()).c_str());
1062  }
1063}
1064
1065llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
1066                                                 llvm::Value* rhs_value,
1067                                                 bool is_signed) const {
1068  return ir_builder_->CreateSelect(
1069      ir_builder_->CreateICmp(
1070          is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
1071          lhs_value, rhs_value),
1072      lhs_value, rhs_value);
1073}
1074
1075llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
1076                                                 llvm::Value* rhs_value,
1077                                                 bool is_signed) const {
1078  return ir_builder_->CreateSelect(
1079      ir_builder_->CreateICmp(
1080          is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
1081          lhs_value, rhs_value),
1082      lhs_value, rhs_value);
1083}
1084
1085llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
1086    const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
1087    int64 operand_no) const {
1088  CHECK(hlo.IsElementwise())
1089      << "HLO " << hlo.ToString() << " is not elementwise.";
1090
1091  const Shape& operand_shape = hlo.operand(operand_no)->shape();
1092  // If the operand is scalar, the source index is always {}.
1093  if (ShapeUtil::IsScalar(operand_shape)) {
1094    return llvm_ir::IrArray::Index();
1095  }
1096
1097  // If no implicit broadcast is needed for this operand, returns the target
1098  // index as the source index.
1099  if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape())) {
1100    return target_index;
1101  }
1102
1103  // If implicit broadcast is needed, the source dimensions that are broadcast
1104  // have index 0.
1105  CHECK_EQ(ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(hlo.shape()));
1106  llvm_ir::IrArray::Index source_index;
1107  for (int64 i = 0; i < ShapeUtil::Rank(hlo.shape()); ++i) {
1108    if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) {
1109      source_index.push_back(target_index[i]);
1110    } else {
1111      CHECK_EQ(1, operand_shape.dimensions(i));
1112      source_index.push_back(ir_builder_->getInt64(0));
1113    }
1114  }
1115  return source_index;
1116}
1117
1118llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
1119    const HloInstruction* hlo,
1120    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
1121    const {
1122  PrimitiveType param_prim_type = hlo->operand(0)->shape().element_type();
1123  llvm::Type* param_ir_type =
1124      llvm_ir::PrimitiveTypeToIrType(param_prim_type, module_);
1125
1126  // Same values as PCG library
1127  // https://github.com/imneme/pcg-c/blob/master/include/pcg_variants.h
1128  llvm::Value* multiplier = ir_builder_->getInt(
1129      llvm::APInt(128, {0x4385DF649FCCF645, 0x2360ED051FC65DA4}));
1130  llvm::Value* increment = ir_builder_->getInt(
1131      llvm::APInt(128, {0x14057B7EF767814F, 0x5851F42D4C957F2D}));
1132
1133  auto random_value = [hlo]() {
1134    const HloModule* module =
1135        hlo->IsFused() ? hlo->parent()->FusionInstruction()->parent()->parent()
1136                       : hlo->parent()->parent();
1137    return module->RandomNew64();
1138  };
1139
1140  // Seed each RNG emitter with a new 64-bit seed from the HloModule. If the
1141  // compilation order is deterministic (i.e., RandomNew64 invocation order is
1142  // deterministic), then the order of RNG is deterministic for a given seed and
1143  // hence tests will be deterministic.
1144  // If the user provides a global seed instruction then we only use 64-bits of
1145  // the host's random number generator to seed the 128 bit value with the other
1146  // 64-bits is due to a user specified global seed instruction.
1147  // Create a GlobalVariable to maintain state between invocations. There is a
1148  // bug in NVPTX with GlobalVariable and 128 bit values, so using 2 64-bit
1149  // values.
1150  llvm::GlobalVariable* state_ptr0 = new llvm::GlobalVariable(
1151      /*M=*/*module_,
1152      /*Ty=*/ir_builder_->getInt64Ty(),
1153      /*isConstant=*/false,
1154      /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
1155      /*Initializer=*/ir_builder_->getInt64(random_value()),
1156      /*Name=*/"state_ptr0");
1157  uint64 graph_seed = hlo_module_config_.seed() != 0 ? hlo_module_config_.seed()
1158                                                     : random_value();
1159  llvm::GlobalVariable* state_ptr1 = new llvm::GlobalVariable(
1160      /*M=*/*module_,
1161      /*Ty=*/ir_builder_->getInt64Ty(),
1162      /*isConstant=*/false,
1163      /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
1164      /*Initializer=*/ir_builder_->getInt64(graph_seed),
1165      /*Name=*/"state_ptr1");
1166
1167  // We want each thread to use its own stream, so we modify the increment per
1168  // thread. We want the increment to remain odd, so we shift the thread id left
1169  // 1 and add it to the increment.
1170  increment = ir_builder_->CreateAdd(increment,
1171                                     ir_builder_->CreateShl(EmitThreadId(), 1));
1172
1173  // PCG-XSL-RR algorithm
1174  // http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf
1175  //   state = multiplier * state + increment
1176  //   return uint64_t(state ^ (state >> 64))) >>> (state >> 122)
1177  // where ">>>" is bitwise rotation
1178  auto get_next_i64 = [=]() {
1179    llvm::Value* state0 = ir_builder_->CreateZExtOrTrunc(
1180        ir_builder_->CreateLoad(state_ptr0, "state0"),
1181        ir_builder_->getInt128Ty());
1182    llvm::Value* state1 = ir_builder_->CreateShl(
1183        ir_builder_->CreateZExtOrTrunc(
1184            ir_builder_->CreateLoad(state_ptr1, "state1"),
1185            ir_builder_->getInt128Ty()),
1186        64);
1187    llvm::Value* state = ir_builder_->CreateOr(state0, state1);
1188    llvm::Value* updated = ir_builder_->CreateAdd(
1189        ir_builder_->CreateMul(state, multiplier), increment);
1190    ir_builder_->CreateStore(
1191        ir_builder_->CreateTrunc(updated, ir_builder_->getInt64Ty()),
1192        state_ptr0);
1193    ir_builder_->CreateStore(
1194        ir_builder_->CreateTrunc(ir_builder_->CreateLShr(updated, 64),
1195                                 ir_builder_->getInt64Ty()),
1196        state_ptr1);
1197
1198    return llvm_ir::CreateRor(
1199        ir_builder_->CreateTrunc(
1200            ir_builder_->CreateXor(state, ir_builder_->CreateLShr(state, 64)),
1201            ir_builder_->getInt64Ty()),
1202        ir_builder_->CreateTrunc(ir_builder_->CreateLShr(state, 122),
1203                                 ir_builder_->getInt64Ty()),
1204        ir_builder_);
1205  };
1206
1207  auto get_next_uniform_float = [=]() {
1208    return ir_builder_->CreateFDiv(
1209        ir_builder_->CreateUIToFP(get_next_i64(), param_ir_type),
1210        llvm::ConstantFP::get(param_ir_type, 0x1p64));
1211  };
1212
1213  return [=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
1214    switch (hlo->random_distribution()) {
1215      case RNG_UNIFORM: {
1216        TF_ASSIGN_OR_RETURN(llvm::Value * p,
1217                            operand_to_generator.at(hlo->operand(0))(index));
1218        TF_ASSIGN_OR_RETURN(llvm::Value * q,
1219                            operand_to_generator.at(hlo->operand(1))(index));
1220        if (primitive_util::IsFloatingPointType(param_prim_type)) {
1221          return ir_builder_->CreateFAdd(
1222              ir_builder_->CreateFMul(ir_builder_->CreateFSub(q, p),
1223                                      get_next_uniform_float()),
1224              p);
1225        } else {
1226          auto r = ir_builder_->CreateSub(q, p);
1227          auto leading_zeros = llvm_ir::EmitCallToIntrinsic(
1228              llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(true)},
1229              {param_ir_type}, ir_builder_);
1230          auto in_block = ir_builder_->GetInsertBlock();
1231
1232          // A terminator should be present iff we're emitting code
1233          // into the middle (as opposed to the end) of a basic block.
1234          CHECK_EQ(ir_builder_->GetInsertPoint() == in_block->end(),
1235                   in_block->getTerminator() == nullptr);
1236
1237          llvm::BasicBlock* body_block;
1238          llvm::BasicBlock* out_block;
1239
1240          if (ir_builder_->GetInsertPoint() == in_block->end()) {
1241            body_block = llvm_ir::CreateBasicBlock(
1242                nullptr, IrName(hlo, "rng_body"), ir_builder_);
1243            out_block = llvm_ir::CreateBasicBlock(
1244                nullptr, IrName(hlo, "rng_out"), ir_builder_);
1245            llvm::BranchInst::Create(body_block, in_block);
1246          } else {
1247            body_block = in_block->splitBasicBlock(
1248                ir_builder_->GetInsertPoint(), "rng_body");
1249            out_block = body_block->splitBasicBlock(
1250                ir_builder_->GetInsertPoint(), "rng_out");
1251            body_block->getTerminator()->eraseFromParent();
1252          }
1253
1254          SetToFirstInsertPoint(body_block, ir_builder_);
1255          auto random = ir_builder_->CreateAnd(
1256              ir_builder_->CreateZExtOrTrunc(get_next_i64(), param_ir_type),
1257              ir_builder_->CreateLShr(llvm::ConstantInt::get(param_ir_type, ~0),
1258                                      leading_zeros));
1259          llvm::BranchInst::Create(out_block, body_block,
1260                                   ir_builder_->CreateICmpULT(random, r),
1261                                   body_block);
1262          SetToFirstInsertPoint(out_block, ir_builder_);
1263          return ir_builder_->CreateAdd(
1264              p, ir_builder_->CreateSelect(
1265                     ir_builder_->CreateICmpEQ(p, q),
1266                     llvm::ConstantInt::get(param_ir_type, 0), random));
1267        }
1268      }
1269      case RNG_NORMAL: {
1270        TF_ASSIGN_OR_RETURN(llvm::Value * m,
1271                            operand_to_generator.at(hlo->operand(0))(index));
1272        TF_ASSIGN_OR_RETURN(llvm::Value * s,
1273                            operand_to_generator.at(hlo->operand(1))(index));
1274        TF_ASSIGN_OR_RETURN(
1275            llvm::Value * r,
1276            EmitErfcInv(param_prim_type,
1277                        ir_builder_->CreateFMul(
1278                            llvm::ConstantFP::get(param_ir_type, 2.0),
1279                            get_next_uniform_float())));
1280        return ir_builder_->CreateFAdd(ir_builder_->CreateFMul(r, s), m);
1281      }
1282      default:
1283        return InvalidArgument(
1284            "unhandled distribution %s",
1285            RandomDistribution_Name(hlo->random_distribution()).c_str());
1286    }
1287  };
1288}
1289
1290llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
1291    const HloInstruction* hlo,
1292    const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
1293    const {
1294  switch (hlo->opcode()) {
1295    case HloOpcode::kAbs:
1296    case HloOpcode::kRoundNearestAfz:
1297    case HloOpcode::kCeil:
1298    case HloOpcode::kConvert:
1299    case HloOpcode::kBitcastConvert:
1300    case HloOpcode::kCopy:
1301    case HloOpcode::kCos:
1302    case HloOpcode::kExp:
1303    case HloOpcode::kFloor:
1304    case HloOpcode::kImag:
1305    case HloOpcode::kIsFinite:
1306    case HloOpcode::kLog:
1307    case HloOpcode::kNegate:
1308    case HloOpcode::kNot:
1309    case HloOpcode::kReal:
1310    case HloOpcode::kSign:
1311    case HloOpcode::kSin:
1312    case HloOpcode::kTanh:
1313      return [this, hlo, &operand_to_generator](
1314                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
1315        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
1316                            operand_to_generator.at(hlo->operand(0))(
1317                                ElementwiseSourceIndex(index, *hlo, 0)));
1318        return EmitUnaryOp(hlo, operand_value);
1319      };
1320    case HloOpcode::kAdd:
1321    case HloOpcode::kAnd:
1322    case HloOpcode::kAtan2:
1323    case HloOpcode::kComplex:
1324    case HloOpcode::kDivide:
1325    case HloOpcode::kEq:
1326    case HloOpcode::kGe:
1327    case HloOpcode::kGt:
1328    case HloOpcode::kLe:
1329    case HloOpcode::kLt:
1330    case HloOpcode::kMaximum:
1331    case HloOpcode::kMinimum:
1332    case HloOpcode::kMultiply:
1333    case HloOpcode::kNe:
1334    case HloOpcode::kOr:
1335    case HloOpcode::kPower:
1336    case HloOpcode::kRemainder:
1337    case HloOpcode::kShiftLeft:
1338    case HloOpcode::kShiftRightArithmetic:
1339    case HloOpcode::kShiftRightLogical:
1340    case HloOpcode::kSubtract:
1341      return [this, hlo, &operand_to_generator](
1342                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
1343        const HloInstruction* lhs = hlo->operand(0);
1344        const HloInstruction* rhs = hlo->operand(1);
1345        TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value,
1346                            operand_to_generator.at(lhs)(
1347                                ElementwiseSourceIndex(index, *hlo, 0)));
1348        TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value,
1349                            operand_to_generator.at(rhs)(
1350                                ElementwiseSourceIndex(index, *hlo, 1)));
1351        return EmitBinaryOp(hlo, lhs_value, rhs_value);
1352      };
1353    case HloOpcode::kSelect:
1354      return [this, hlo, &operand_to_generator](
1355                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
1356        TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
1357                            operand_to_generator.at(hlo->operand(0))(
1358                                ElementwiseSourceIndex(index, *hlo, 0)));
1359        TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value,
1360                            operand_to_generator.at(hlo->operand(1))(
1361                                ElementwiseSourceIndex(index, *hlo, 1)));
1362        TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
1363                            operand_to_generator.at(hlo->operand(2))(
1364                                ElementwiseSourceIndex(index, *hlo, 2)));
1365        return ir_builder_->CreateSelect(
1366            ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()),
1367            on_true_value, on_false_value);
1368      };
1369    case HloOpcode::kClamp:
1370      return [this, hlo, &operand_to_generator](
1371                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
1372        TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
1373                            operand_to_generator.at(hlo->operand(0))(
1374                                ElementwiseSourceIndex(index, *hlo, 0)));
1375        TF_ASSIGN_OR_RETURN(llvm::Value * arg_value,
1376                            operand_to_generator.at(hlo->operand(1))(
1377                                ElementwiseSourceIndex(index, *hlo, 1)));
1378        TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
1379                            operand_to_generator.at(hlo->operand(2))(
1380                                ElementwiseSourceIndex(index, *hlo, 2)));
1381        PrimitiveType prim_type = hlo->shape().element_type();
1382        if (primitive_util::IsFloatingPointType(prim_type)) {
1383          return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
1384        } else if (primitive_util::IsIntegralType(prim_type)) {
1385          bool is_signed = primitive_util::IsSignedIntegralType(prim_type);
1386          return EmitIntegralMin(
1387              max_value, EmitIntegralMax(min_value, arg_value, is_signed),
1388              is_signed);
1389        } else {
1390          return Unimplemented("Clamp unimplemented for %s",
1391                               PrimitiveType_Name(prim_type).c_str());
1392        }
1393      };
1394    case HloOpcode::kReducePrecision:
1395      return [this, hlo, &operand_to_generator](
1396                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
1397        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
1398                            operand_to_generator.at(hlo->operand(0))(
1399                                ElementwiseSourceIndex(index, *hlo, 0)));
1400        return EmitReducePrecision(hlo, operand_value);
1401      };
1402    case HloOpcode::kConcatenate:
1403      return [this, hlo, &operand_to_generator](
1404                 const IrArray::Index target_index) -> StatusOr<llvm::Value*> {
1405        const int64 concat_dim = hlo->dimensions(0);
1406        auto source_index = target_index;
1407
1408        llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock();
1409
1410        // A terminator should be present iff we're emitting code
1411        // into the middle (as opposed to the end) of a basic block.
1412        CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(),
1413                 init_block->getTerminator() == nullptr);
1414
1415        llvm::BasicBlock* exit_block;
1416        if (ir_builder_->GetInsertPoint() == init_block->end()) {
1417          exit_block = llvm_ir::CreateBasicBlock(
1418              /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_);
1419        } else {
1420          exit_block = init_block->splitBasicBlock(
1421              ir_builder_->GetInsertPoint(), AsStringRef(IrName(hlo, "merge")));
1422          init_block->getTerminator()->eraseFromParent();
1423        }
1424
1425        llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_);
1426        llvm::PHINode* output =
1427            ir_builder_->CreatePHI(llvm_ir::PrimitiveTypeToIrType(
1428                                       hlo->shape().element_type(), module_),
1429                                   hlo->operands().size());
1430        auto prior_insert_point = ir_builder_->GetInsertPoint();
1431
1432        ir_builder_->SetInsertPoint(init_block);
1433
1434        for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
1435             ++operand_idx) {
1436          const HloInstruction* operand = hlo->operand(operand_idx);
1437          auto true_block = llvm_ir::CreateBasicBlock(
1438              exit_block, StrCat("concat_index_from_operand", operand_idx),
1439              ir_builder_);
1440          auto false_block = llvm_ir::CreateBasicBlock(
1441              exit_block, StrCat("concat_index_not_from_operand", operand_idx),
1442              ir_builder_);
1443          auto concat_dim_size =
1444              llvm::ConstantInt::get(source_index[concat_dim]->getType(),
1445                                     operand->shape().dimensions(concat_dim));
1446          ir_builder_->CreateCondBr(
1447              ir_builder_->CreateICmpULT(source_index[concat_dim],
1448                                         concat_dim_size),
1449              true_block, false_block);
1450
1451          // Create the terminator of the true block before calling operand
1452          // generators, because they require non-degenerate basic blocks.
1453          ir_builder_->SetInsertPoint(
1454              llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block));
1455          TF_ASSIGN_OR_RETURN(llvm::Value * value,
1456                              operand_to_generator.at(operand)(source_index));
1457          output->addIncoming(value, ir_builder_->GetInsertBlock());
1458
1459          // Subtract the size of the concat dimension of the current operand
1460          // from the source index.
1461          ir_builder_->SetInsertPoint(false_block);
1462          source_index[concat_dim] =
1463              ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size);
1464        }
1465
1466        ir_builder_->CreateUnreachable();
1467        ir_builder_->SetInsertPoint(exit_block, prior_insert_point);
1468        return output;
1469      };
1470    case HloOpcode::kReverse:
1471      return [this, hlo, &operand_to_generator](
1472                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
1473        const HloInstruction* operand = hlo->operand(0);
1474        auto source_index = target_index;
1475        for (int64 dim : hlo->dimensions()) {
1476          source_index[dim] = ir_builder_->CreateSub(
1477              llvm::ConstantInt::get(target_index[dim]->getType(),
1478                                     hlo->shape().dimensions(dim) - 1),
1479              target_index[dim]);
1480        }
1481        return operand_to_generator.at(operand)(source_index);
1482      };
1483    case HloOpcode::kBroadcast:
1484      return [this, hlo, &operand_to_generator](
1485                 const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
1486        // The `dimensions` member of the broadcast instruction maps from
1487        // input dimensions to output dimensions.
1488        const HloInstruction* operand = hlo->operand(0);
1489        int64 rank = ShapeUtil::Rank(operand->shape());
1490        IrArray::Index source_index(rank);
1491        for (int64 i = 0; i < rank; ++i) {
1492          source_index[i] = target_index[hlo->dimensions(i)];
1493        }
1494        return operand_to_generator.at(operand)(source_index);
1495      };
1496    case HloOpcode::kSlice:
1497      return [this, hlo, &operand_to_generator](
1498                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
1499        IrArray::Index sliced_index = index.SourceIndexOfSlice(
1500            /*shape=*/hlo->shape(), /*starts=*/hlo->slice_starts(),
1501            /*strides=*/hlo->slice_strides(), /*builder=*/ir_builder_);
1502        return operand_to_generator.at(hlo->operand(0))(sliced_index);
1503      };
1504    case HloOpcode::kDynamicSlice:
1505      return [this, hlo, &operand_to_generator](
1506                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
1507        // Emit IR to read dynamic start indices from hlo->operand(1).
1508        const HloInstruction* input_hlo = hlo->operand(0);
1509        const int64 rank = ShapeUtil::Rank(input_hlo->shape());
1510        llvm_ir::IrArray::Index slice_start_index(rank);
1511        for (int64 i = 0; i < rank; ++i) {
1512          llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
1513          TF_ASSIGN_OR_RETURN(
1514              llvm::Value * start_index_value,
1515              operand_to_generator.at(hlo->operand(1))(dim_index));
1516          start_index_value->setName(
1517              AsStringRef(IrName(hlo, StrCat("start_idx", i))));
1518          slice_start_index[i] = start_index_value;
1519        }
1520
1521        llvm_ir::IrArray::Index input_index(rank);
1522        for (int64 i = 0; i < rank; ++i) {
1523          // Emit IR which computes:
1524          //   input_index = (start_index + offset_index) % dim_size
1525          // Security note: this is the code that keeps the indices in-bounds.
1526          llvm::Value* dim_size = llvm::ConstantInt::get(
1527              index[i]->getType(), input_hlo->shape().dimensions(i));
1528          llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast(
1529              slice_start_index[i], index[i]->getType());
1530          input_index[i] = ir_builder_->CreateURem(
1531              ir_builder_->CreateAdd(start_index, index[i]), dim_size);
1532        }
1533        return operand_to_generator.at(input_hlo)(input_index);
1534      };
1535    case HloOpcode::kDynamicUpdateSlice:
1536      return [this, hlo, &operand_to_generator](
1537                 const IrArray::Index& index) -> StatusOr<llvm::Value*> {
1538        const HloInstruction* input_hlo = hlo->operand(0);
1539        const HloInstruction* update_hlo = hlo->operand(1);
1540        const HloInstruction* start_hlo = hlo->operand(2);
1541        // Calculate slice start/end indices.
1542        const int64 rank = ShapeUtil::Rank(input_hlo->shape());
1543        llvm_ir::IrArray::Index slice_start_index(rank);
1544        llvm_ir::IrArray::Index slice_limit_index(rank);
1545        // Slice starts at update[index - slice_start_index_adjusted],
1546        // where adjusted value = slice_start_index when in bounds, and
1547        // adjusted value = slice_start_index - input_dim, when wrapping.
1548        llvm_ir::IrArray::Index slice_start_index_adjusted(rank);
1549
1550        // Slice intersection gathers (ANDs) conditions on all ranks for which
1551        // 'input' is set to 'update'
1552        llvm::Value* slice_intersection = ir_builder_->getTrue();
1553
1554        for (int64 i = 0; i < rank; ++i) {
1555          // Emit IR to read dynamic start indices from 'start_hlo'.
1556          llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
1557          TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
1558                              operand_to_generator.at(start_hlo)(dim_index));
1559          start_index_value->setName(
1560              AsStringRef(IrName(hlo, StrCat("start_idx", i))));
1561          slice_start_index[i] = ir_builder_->CreateZExtOrBitCast(
1562              start_index_value, index[i]->getType());
1563
1564          llvm::Value* input_dim_size = llvm::ConstantInt::get(
1565              index[i]->getType(), input_hlo->shape().dimensions(i));
1566          llvm::Value* update_dim_size = llvm::ConstantInt::get(
1567              index[i]->getType(), update_hlo->shape().dimensions(i));
1568
1569          // Generate code to handle wrapping semantics:
1570          // slice_start_index[i] = slice_start_index[i] % input_dim_size;
1571          // slice_limit_index[i] = slice_start_index[i] + update_dim_size.
1572          // slice_start_index[i] is updated in place and it will now be in
1573          // range. slice_limit_index[i] may be out of range, and it's being
1574          // URem-ed below if so.
1575          slice_start_index[i] =
1576              ir_builder_->CreateURem(slice_start_index[i], input_dim_size);
1577          slice_limit_index[i] =
1578              ir_builder_->CreateAdd(slice_start_index[i], update_dim_size);
1579
1580          // Test if slice_limit_index[i] is in bounds
1581          llvm::Value* in_bounds =
1582              ir_builder_->CreateICmpULE(slice_limit_index[i], input_dim_size);
1583          llvm_ir::LlvmIfData if_in_bounds =
1584              llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
1585
1586          // Handle true BB (slice_limit_index[i] <= input_dim_size).
1587          SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_);
1588          // Check that index[i] >= slice_start_index[i] &&
1589          //            index[i] < slice_limit_index[i]
1590          llvm::Value* slice_intersection_in_bounds = ir_builder_->CreateAnd(
1591              slice_intersection,
1592              ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
1593              "slice_intersection_in");
1594          slice_intersection_in_bounds = ir_builder_->CreateAnd(
1595              slice_intersection_in_bounds,
1596              ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]),
1597              "slice_intersection_in");
1598
1599          // Handle false BB (slice_limit_index[i] > input_dim_size).
1600          SetToFirstInsertPoint(if_in_bounds.false_block, ir_builder_);
1601          // Check that index[i] >= slice_start_index[i] ||
1602          //            index[i] < slice_limit_index[i]%input_dim_size.
1603          llvm::Value* index_wraps = ir_builder_->CreateICmpSLT(
1604              index[i],
1605              ir_builder_->CreateURem(slice_limit_index[i], input_dim_size));
1606          llvm::Value* slice_intersection_or = ir_builder_->CreateOr(
1607              ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
1608              index_wraps, "slice_intersection_out");
1609          llvm::Value* slice_intersection_out_of_bounds =
1610              ir_builder_->CreateAnd(slice_intersection, slice_intersection_or,
1611                                     "slice_intersection_out");
1612          // Create value for slice_start_index_adjusted[i] when out of bounds.
1613          // If within out-of-bounds if.
1614          llvm_ir::LlvmIfData if_start_needs_adjustment =
1615              llvm_ir::EmitIfThenElse(index_wraps, "adjust_start", ir_builder_);
1616          SetToFirstInsertPoint(if_start_needs_adjustment.true_block,
1617                                ir_builder_);
1618          llvm::Value* slice_start_index_adjusted_oob =
1619              ir_builder_->CreateSub(slice_start_index[i], input_dim_size);
1620          SetToFirstInsertPoint(if_start_needs_adjustment.after_block,
1621                                ir_builder_);
1622          llvm::PHINode* slice_start_index_adjusted_phi =
1623              ir_builder_->CreatePHI(slice_start_index_adjusted_oob->getType(),
1624                                     2);
1625          slice_start_index_adjusted_phi->addIncoming(
1626              slice_start_index_adjusted_oob,
1627              if_start_needs_adjustment.true_block);
1628          slice_start_index_adjusted_phi->addIncoming(
1629              slice_start_index[i], if_start_needs_adjustment.false_block);
1630          // End of if within if.
1631
1632          // After checking in/out of bounds.
1633          SetToFirstInsertPoint(if_in_bounds.after_block, ir_builder_);
1634          llvm::PHINode* phi_slice_intersection =
1635              ir_builder_->CreatePHI(slice_intersection->getType(), 2);
1636          phi_slice_intersection->addIncoming(slice_intersection_in_bounds,
1637                                              if_in_bounds.true_block);
1638          phi_slice_intersection->addIncoming(
1639              slice_intersection_out_of_bounds,
1640              if_start_needs_adjustment.after_block);
1641          slice_intersection = phi_slice_intersection;
1642
1643          llvm::PHINode* phi_index =
1644              ir_builder_->CreatePHI(slice_start_index[i]->getType(), 2);
1645          phi_index->addIncoming(slice_start_index[i], if_in_bounds.true_block);
1646          phi_index->addIncoming(slice_start_index_adjusted_phi,
1647                                 if_start_needs_adjustment.after_block);
1648          slice_start_index_adjusted[i] = phi_index;
1649        }
1650
1651        // Emit:
1652        // if (slice_intersection) -> return data from 'update'.
1653        // else                    -> return data from 'input'.
1654        llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
1655            llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
1656                                           module_),
1657            "ret_value_addr", ir_builder_);
1658        llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
1659            slice_intersection, "slice_intersection", ir_builder_);
1660
1661        // Handle true BB (return data from 'update')
1662        SetToFirstInsertPoint(if_data.true_block, ir_builder_);
1663        // Compute update index for intersection case.
1664        llvm_ir::IrArray::Index update_index(rank);
1665        for (int64 i = 0; i < rank; ++i) {
1666          llvm::Value* update_dim_size = llvm::ConstantInt::get(
1667              index[i]->getType(), update_hlo->shape().dimensions(i));
1668          // NOTE: Subtraction will be positive due to bounds checking above.
1669          update_index[i] = ir_builder_->CreateURem(
1670              ir_builder_->CreateSub(index[i], slice_start_index_adjusted[i]),
1671              update_dim_size);
1672        }
1673        TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
1674                            operand_to_generator.at(update_hlo)(update_index));
1675        ir_builder_->CreateStore(true_value, ret_value_addr);
1676
1677        // Handle false BB (return data from 'input')
1678        SetToFirstInsertPoint(if_data.false_block, ir_builder_);
1679        TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
1680                            operand_to_generator.at(input_hlo)(index));
1681        ir_builder_->CreateStore(false_value, ret_value_addr);
1682
1683        SetToFirstInsertPoint(if_data.after_block, ir_builder_);
1684        return ir_builder_->CreateLoad(ret_value_addr);
1685      };
1686    case HloOpcode::kReshape:
1687      CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
1688               ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
1689      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
1690        const HloInstruction* operand = hlo->operand(0);
1691        return operand_to_generator.at(operand)(index.SourceIndexOfReshape(
1692            hlo->shape(), operand->shape(), ir_builder_));
1693      };
1694    case HloOpcode::kTranspose:
1695      return [this, hlo,
1696              &operand_to_generator](const IrArray::Index& target_index) {
1697        return operand_to_generator.at(hlo->operand(0))(
1698            target_index.SourceIndexOfTranspose(
1699                hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions(),
1700                ir_builder_));
1701      };
1702    case HloOpcode::kRng:
1703      return MakeRngElementGenerator(hlo, operand_to_generator);
1704    case HloOpcode::kPad:
1705      return [=, &operand_to_generator](
1706                 const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
1707        auto index = padded_index;
1708        llvm::Value* in_bounds = ir_builder_->getTrue();
1709        for (size_t i = 0; i < index.size(); ++i) {
1710          auto index_typed_const = [=](int64 n) {
1711            return llvm::ConstantInt::get(index[i]->getType(), n);
1712          };
1713          const auto& pad_dim = hlo->padding_config().dimensions(i);
1714          index[i] = ir_builder_->CreateSub(
1715              index[i], index_typed_const(pad_dim.edge_padding_low()));
1716          in_bounds = ir_builder_->CreateAnd(
1717              in_bounds,
1718              ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)),
1719              "in_bounds");
1720          in_bounds = ir_builder_->CreateAnd(
1721              in_bounds,
1722              ir_builder_->CreateICmpEQ(
1723                  index_typed_const(0),
1724                  ir_builder_->CreateURem(
1725                      index[i],
1726                      index_typed_const(pad_dim.interior_padding() + 1))),
1727              "in_bounds");
1728          index[i] = ir_builder_->CreateSDiv(
1729              index[i], index_typed_const(pad_dim.interior_padding() + 1));
1730          in_bounds = ir_builder_->CreateAnd(
1731              in_bounds,
1732              ir_builder_->CreateICmpSLT(
1733                  index[i],
1734                  index_typed_const(hlo->operand(0)->shape().dimensions(i))),
1735              "in_bounds");
1736        }
1737
1738        // if (in_bounds) {
1739        //   ret_value = operand0[index];  // source
1740        // } else {
1741        //   ret_value = *operand1;        // padding
1742        // }
1743        llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
1744            llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
1745                                           module_),
1746            "pad_result_addr", ir_builder_);
1747        llvm_ir::LlvmIfData if_data =
1748            llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
1749        SetToFirstInsertPoint(if_data.true_block, ir_builder_);
1750        TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
1751                            operand_to_generator.at(hlo->operand(0))(index));
1752        ir_builder_->CreateStore(operand_value, ret_value_addr);
1753
1754        SetToFirstInsertPoint(if_data.false_block, ir_builder_);
1755        TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
1756                            operand_to_generator.at(hlo->operand(1))({}));
1757        ir_builder_->CreateStore(padding_value, ret_value_addr);
1758
1759        SetToFirstInsertPoint(if_data.after_block, ir_builder_);
1760        // Don't create phi(operand_value, padding_value) here, because invoking
1761        // operand_to_generator may create new basic blocks, making the parent
1762        // of operand_value or padding_value no longer a predecessor of
1763        // if_data.after_block.
1764        return ir_builder_->CreateLoad(ret_value_addr);
1765      };
1766
1767    case HloOpcode::kDot:
1768      return [=, &operand_to_generator](const IrArray::Index& dot_result_index)
1769                 -> StatusOr<llvm::Value*> {
1770        auto lhs_generator = operand_to_generator.at(hlo->operand(0));
1771        auto rhs_generator = operand_to_generator.at(hlo->operand(1));
1772        int64 contracted_dim_size = hlo->operand(0)->shape().dimensions(
1773            hlo->operand(0)->shape().dimensions_size() - 1);
1774        int64 lhs_dims = hlo->operand(0)->shape().dimensions_size();
1775        int64 rhs_dims = hlo->operand(1)->shape().dimensions_size();
1776
1777        std::unique_ptr<llvm_ir::ForLoop> inner_loop =
1778            llvm_ir::ForLoop::EmitForLoop(
1779                IrName(hlo, "inner"), ir_builder_->getInt64(0),
1780                ir_builder_->getInt64(contracted_dim_size),
1781                ir_builder_->getInt64(1), ir_builder_);
1782
1783        SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(),
1784                              ir_builder_);
1785        PrimitiveType primitive_type = hlo->shape().element_type();
1786        llvm::Type* primitive_type_llvm =
1787            llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
1788        llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry(
1789            primitive_type_llvm, "dot_acc", ir_builder_);
1790        ir_builder_->CreateStore(
1791            llvm::Constant::getNullValue(primitive_type_llvm),
1792            accumulator_alloca);
1793
1794        SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), ir_builder_);
1795
1796        // This is the inner reduction loop for a dot operation that produces
1797        // one element in the output.  If the operands to the dot operation have
1798        // shapes [A,B,C,T] and [D,T,E], the result has a shape [A,B,C,D,E].
1799        // Given an output index [a,b,c,d,e] in the result, we compute:
1800        //   sum(lhs[a,b,c,t]*rhs[d,t,e] for t in [0, T))
1801
1802        IrArray::Index lhs_index, rhs_index;
1803
1804        for (int64 i = 0; i < lhs_dims - 1; i++) {
1805          lhs_index.push_back(dot_result_index[i]);
1806        }
1807        lhs_index.push_back(inner_loop->GetIndVarValue());
1808
1809        for (int64 i = 0; i < rhs_dims - 2; i++) {
1810          rhs_index.push_back(dot_result_index[lhs_dims - 1 + i]);
1811        }
1812        rhs_index.push_back(inner_loop->GetIndVarValue());
1813        rhs_index.push_back(dot_result_index.back());
1814
1815        llvm::Value* current_accumulator =
1816            ir_builder_->CreateLoad(accumulator_alloca);
1817        TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
1818        TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
1819        llvm::Value* next_accumulator;
1820        if (primitive_util::IsComplexType(primitive_type)) {
1821          llvm::Value* product_real = ir_builder_->CreateFSub(
1822              ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
1823                                      EmitExtractReal(rhs_value)),
1824              ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
1825                                      EmitExtractImag(rhs_value)));
1826          llvm::Value* product_imag = ir_builder_->CreateFAdd(
1827              ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
1828                                      EmitExtractImag(rhs_value)),
1829              ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
1830                                      EmitExtractReal(rhs_value)));
1831          next_accumulator = ir_builder_->CreateInsertValue(
1832              current_accumulator,
1833              ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator),
1834                                      product_real),
1835              {0});
1836          next_accumulator = ir_builder_->CreateInsertValue(
1837              next_accumulator,
1838              ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator),
1839                                      product_imag),
1840              {1});
1841        } else if (primitive_util::IsFloatingPointType(primitive_type)) {
1842          next_accumulator = ir_builder_->CreateFAdd(
1843              current_accumulator,
1844              ir_builder_->CreateFMul(lhs_value, rhs_value));
1845        } else {
1846          next_accumulator = ir_builder_->CreateAdd(
1847              current_accumulator,
1848              ir_builder_->CreateMul(lhs_value, rhs_value));
1849        }
1850        ir_builder_->CreateStore(next_accumulator, accumulator_alloca);
1851
1852        SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), ir_builder_);
1853        return ir_builder_->CreateLoad(accumulator_alloca);
1854      };
1855    default:
1856      return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
1857        return Unimplemented("Unhandled opcode for elemental IR emission: %s",
1858                             HloOpcodeString(hlo->opcode()).c_str());
1859      };
1860  }
1861}
1862
1863llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const {
1864  return ir_builder_->CreateExtractValue(value, {0});
1865}
1866
1867llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const {
1868  return ir_builder_->CreateExtractValue(value, {1});
1869}
1870
1871llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
1872                                                    llvm::Value* real,
1873                                                    llvm::Value* imag) const {
1874  auto cplx_type =
1875      llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
1876  auto complex = ir_builder_->CreateInsertValue(
1877      llvm::ConstantAggregateZero::get(cplx_type), real, {0});
1878  if (imag != nullptr) {
1879    complex = ir_builder_->CreateInsertValue(complex, imag, {1});
1880  }
1881  return complex;
1882}
1883
1884}  // namespace xla
1885