1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ 17#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ 18 19#include "tensorflow/compiler/xla/service/buffer_assignment.h" 20#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" 21#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" 22#include "tensorflow/compiler/xla/service/hlo_instruction.h" 23#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" 24 25// Utilities related to emitting LLVM IR for various HLO ops. 26 27namespace xla { 28namespace llvm_ir { 29 30// Checks if we can emit code for the given DynamicUpdateSlice node that updates 31// its input in place. Returns true if the dynamic-update-slice's 32// array-to-be-updated and output share the same BufferAllocation::Slice. 33// 34// dynamic_update_slice must be a DynamicUpdateSlice op. 35bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, 36 const BufferAssignment& assignment); 37 38// Checks if the given fusion node is amenable to being implemented by 39// EmitFusedDynamicUpdateSliceInPlace. 40inline bool CanEmitFusedDynamicUpdateSliceInPlace( 41 HloInstruction* fusion, const BufferAssignment& assignment) { 42 CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); 43 HloInstruction* fused_root = fusion->fused_expression_root(); 44 if (fused_root->opcode() != HloOpcode::kDynamicUpdateSlice || 45 fusion->fusion_kind() != HloInstruction::FusionKind::kLoop) { 46 return false; 47 } 48 // Walk DynamicUpdateSlice operand(0) to fused parameter and get its 49 // associated operand. See if it shares an allocation with this operand. 50 HloInstruction* fusion_operand; 51 ShapeIndex index; 52 std::tie(fusion_operand, index) = 53 fused_root->mutable_operand(0)->LatestNonGteAncestorAndIndex(); 54 if (fusion_operand->opcode() != HloOpcode::kParameter) { 55 return false; 56 } 57 auto* operand = fusion->operand(fusion_operand->parameter_number()); 58 return assignment.HasAllocationAt(operand, index) && 59 assignment.HasAllocationAt(fusion, {}) && 60 assignment.SharesSliceAtIndex(fusion, {}, operand, index); 61} 62 63// Emits IR for running the given dynamic-update-slice op in-place -- that is, 64// where the input and output buffers share the same slice, so we can simply 65// modify the input/output buffer without touching any of the other elements. 66Status EmitDynamicUpdateSliceInPlace( 67 tensorflow::gtl::ArraySlice<IrArray> operand_arrays, 68 const IrArray& output_array, tensorflow::StringPiece name, 69 llvm::IRBuilder<>* ir_builder); 70 71// Given a loop-fusion node whose root is a dynamic-update-slice op whose 72// array-to-be-updated and output share the same buffer slice, emits 73// (sequential) code for a fusion node that does the dynamic-update-slice in 74// place. 75Status EmitFusedDynamicUpdateSliceInPlace( 76 HloInstruction* fusion, 77 tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays, 78 const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, 79 llvm::IRBuilder<>* ir_builder); 80 81// Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with 82// the given launch dimensions. 83Status EmitParallelFusedDynamicUpdateSliceInPlace( 84 HloInstruction* fusion, 85 tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays, 86 const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, 87 const gpu::LaunchDimensions& launch_dimensions, 88 llvm::IRBuilder<>* ir_builder); 89 90} // namespace llvm_ir 91} // namespace xla 92 93#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ 94