/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1684f1b9049de86ba5614ce73f91232fd72eefbd1fJustin Lebar#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <stddef.h>
191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <string>
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include <vector>
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
2234cbf161d7b1191ad5c1b3bc02fc52d338e8b175Jiri Simsa#include "llvm/IR/Instructions.h"
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/shape_util.h"
251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/types.h"
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h"
271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/lib/strings/stringprintf.h"
281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/logging.h"
291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla {
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace llvm_ir {
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsvoid EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true,
344198e27be8115585ad6b5b141383fb7dc7856c24A. Unique TensorFlower                     llvm::Value* on_false, llvm::IRBuilder<>* ir_builder,
354198e27be8115585ad6b5b141383fb7dc7856c24A. Unique TensorFlower                     llvm::Module* module) {
361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  CHECK(ShapeUtil::IsScalar(pred.GetShape()));
371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  llvm::LoadInst* pred_value =
391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ir_builder->CreateLoad(pred.GetBasePointer(), "load_predicate_value");
401e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  llvm::Value* pred_cond = ir_builder->CreateICmpNE(
411e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      pred_value,
424198e27be8115585ad6b5b141383fb7dc7856c24A. Unique TensorFlower      llvm::ConstantInt::get(PrimitiveTypeToIrType(PRED, module), 0),
431e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      "boolean_predicate");
441e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  VLOG(2) << "HandleSelect for tuple:";
461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  VLOG(2) << "  pred_value: " << DumpToString(*pred_value);
471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  VLOG(2) << "  pred_cond: " << DumpToString(*pred_cond);
481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) {
501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    std::vector<llvm::Value*> element_index = {ir_builder->getInt64(0),
511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                               ir_builder->getInt64(i)};
521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    llvm::Value* on_true_element_address =
531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        ir_builder->CreateInBoundsGEP(on_true, element_index);
541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    llvm::Value* on_true_element = ir_builder->CreateLoad(
551e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        on_true_element_address,
561e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        tensorflow::strings::Printf("on_true_element_%d", i).c_str());
571e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    llvm::Value* on_false_element_address =
581e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        ir_builder->CreateInBoundsGEP(on_false, element_index);
591e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    llvm::Value* on_false_element = ir_builder->CreateLoad(
601e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        on_false_element_address,
611e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        tensorflow::strings::Printf("on_false_element_%d", i).c_str());
621e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
631e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    llvm::Value* output_element_address =
641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        ir_builder->CreateInBoundsGEP(select.GetBasePointer(), element_index);
651e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    ir_builder->CreateStore(
661e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        ir_builder->CreateSelect(
671e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            pred_cond, on_true_element, on_false_element,
681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            tensorflow::strings::Printf("select_output_element_%d", i).c_str()),
691e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        output_element_address);
701e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
711e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
731e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsvoid EmitTuple(IrArray tuple,
741e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins               tensorflow::gtl::ArraySlice<llvm::Value*> operands,
754198e27be8115585ad6b5b141383fb7dc7856c24A. Unique TensorFlower               llvm::IRBuilder<>* ir_builder, llvm::Module* module) {
761e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  for (size_t i = 0; i < operands.size(); ++i) {
778fcbef3428ce69de9cedafd0d4c0f141c79d418cJustin Lebar    auto* store = ir_builder->CreateStore(
781e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        ir_builder->CreatePointerCast(operands[i],
794198e27be8115585ad6b5b141383fb7dc7856c24A. Unique TensorFlower                                      PrimitiveTypeToIrType(TUPLE, module)),
801e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins        ir_builder->CreateInBoundsGEP(
811e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            tuple.GetBasePointer(),
821e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            {ir_builder->getInt64(0), ir_builder->getInt64(i)}));
838fcbef3428ce69de9cedafd0d4c0f141c79d418cJustin Lebar    tuple.AnnotateLoadStoreInstructionWithMetadata(store);
841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  }
851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsllvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins                                 int alignment, llvm::Value* operand,
894198e27be8115585ad6b5b141383fb7dc7856c24A. Unique TensorFlower                                 llvm::IRBuilder<>* ir_builder,
904198e27be8115585ad6b5b141383fb7dc7856c24A. Unique TensorFlower                                 llvm::Module* module) {
911e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  llvm::Value* element_ptr = ir_builder->CreateInBoundsGEP(
921e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      operand, {ir_builder->getInt64(0), ir_builder->getInt64(index)});
931e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  llvm::LoadInst* src_buffer = ir_builder->CreateLoad(element_ptr);
942a90713ef70f01392ac59899ca92376549c57126Justin Lebar
952a90713ef70f01392ac59899ca92376549c57126Justin Lebar  // Mark the loaded pointer as dereferenceable if we know its shape.
962a90713ef70f01392ac59899ca92376549c57126Justin Lebar  if (!ShapeUtil::IsOpaque(target_shape)) {
972a90713ef70f01392ac59899ca92376549c57126Justin Lebar    SetDereferenceableMetadataForLoad(
982a90713ef70f01392ac59899ca92376549c57126Justin Lebar        src_buffer,
992a90713ef70f01392ac59899ca92376549c57126Justin Lebar        ByteSizeOf(target_shape, src_buffer->getModule()->getDataLayout()));
1002a90713ef70f01392ac59899ca92376549c57126Justin Lebar  }
1011e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  SetAlignmentMetadataForLoad(src_buffer, alignment);
1022a90713ef70f01392ac59899ca92376549c57126Justin Lebar
1034198e27be8115585ad6b5b141383fb7dc7856c24A. Unique TensorFlower  llvm::Type* element_type = ShapeToIrType(target_shape, module);
1041e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  llvm::Value* ret_val =
1051e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      ir_builder->CreateBitCast(src_buffer, element_type->getPointerTo());
1061e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  return ret_val;
1071e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}
1081e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
1091e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace llvm_ir
1101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace xla
111