/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/ops.h"

#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"

namespace xla {
namespace llvm_ir {

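// Returns true if the dynamic-update-slice can be executed in place: buffer
// assignment must have given the op and its input operand the same top-level
// allocation and the same slice within it.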
bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice,
                                  const BufferAssignment& assignment) {
  CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode());
  const HloInstruction* operand = dynamic_update_slice->operand(0);
  return assignment.HasTopLevelAllocation(dynamic_update_slice) &&
         assignment.HasTopLevelAllocation(operand) &&
         assignment.SharesTopLevelSlice(dynamic_update_slice, operand);
}

// Shared implementation of EmitDynamicUpdateSliceInPlace and
// EmitFusedDynamicUpdateSliceInPlace.
//
// Emits a sequential loop if launch_dimensions is null.
static Status EmitDynamicUpdateSliceInPlaceImpl(
    const Shape& update_shape, const ElementGenerator& start_indices_generator,
    ElementGenerator update_array_generator, const IrArray& output_array,
    const gpu::LaunchDimensions* launch_dimensions,
    tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder) {
  const Shape& output_shape = output_array.GetShape();

  // Read start indices from start_indices_generator.
  const int64 rank = ShapeUtil::Rank(output_shape);
  IrArray::Index start_index(rank);
  for (int64 i = 0; i < rank; ++i) {
    IrArray::Index dim_index({ir_builder->getInt64(i)});
    TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index));
  }

  auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status {
    // Calculate output_index, where we'll write the value from update.  For
    // each dimension,
    //
    //   output_index[dim] = (start_index[dim] + update_index[dim]) % dim_size.
    //
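    // For example, with dim_size = 8, start_index[dim] = 6, and
    // update_index[dim] = 3, we write to output_index[dim] = (6 + 3) % 8 = 1;
    // i.e. out-of-bounds updates wrap around rather than writing past the end.
    //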
    IrArray::Index output_index(rank);
    for (int64 i = 0; i < rank; ++i) {
      llvm::Value* dim_size = llvm::ConstantInt::get(
          update_index[i]->getType(), output_shape.dimensions(i));
      llvm::Value* start_index0 = ir_builder->CreateZExtOrBitCast(
          start_index[i], update_index[i]->getType());
      output_index[i] = ir_builder->CreateURem(
          ir_builder->CreateAdd(start_index0, update_index[i]), dim_size);
    }

    // Do output[output_index] = update[update_index].
    TF_ASSIGN_OR_RETURN(llvm::Value * update_data,
                        update_array_generator(update_index));
    output_array.EmitWriteArrayElement(output_index, update_data, ir_builder);
    return Status::OK();
  };

  if (launch_dimensions != nullptr) {
    return gpu::ParallelLoopEmitter(loop_body_emitter, update_shape,
                                    *launch_dimensions, ir_builder)
        .EmitLoop(name);
  }
  return LoopEmitter(loop_body_emitter, update_shape, ir_builder)
      .EmitLoop(name);
}

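// Emits output = dynamic-update-slice(output, update, start_indices) in
// place, where operand_arrays holds the arrays for {input, update,
// start_indices} and output_array aliases the input array.
//
// A sketch of a call site, assuming hypothetical `input_array`,
// `update_array`, `start_indices_array`, and an IRBuilder `b` from the
// surrounding emitter:
//
//   TF_RETURN_IF_ERROR(llvm_ir::EmitDynamicUpdateSliceInPlace(
//       {input_array, update_array, start_indices_array},
//       /*output_array=*/input_array, "dynamic_update_slice", &b));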
Status EmitDynamicUpdateSliceInPlace(
    tensorflow::gtl::ArraySlice<IrArray> operand_arrays,
    const IrArray& output_array, tensorflow::StringPiece name,
    llvm::IRBuilder<>* ir_builder) {
  VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name;

  // No need to use operand_arrays[0], the input array of the
  // dynamic-update-slice, because we know it aliases the op's output.
  IrArray update_array = operand_arrays[1];
  IrArray start_indices_array = operand_arrays[2];
  Shape update_shape = update_array.GetShape();

  ElementGenerator start_indices_generator = [&](const IrArray::Index& index) {
    return start_indices_array.EmitReadArrayElement(index, ir_builder);
  };
  ElementGenerator update_array_generator = [&](const IrArray::Index& index) {
    return update_array.EmitReadArrayElement(index, ir_builder);
  };

  return EmitDynamicUpdateSliceInPlaceImpl(
      update_shape, start_indices_generator, update_array_generator,
      output_array, /*launch_dimensions=*/nullptr, name, ir_builder);
}

// Shared implementation for EmitFusedDynamicUpdateSliceInPlace and
// EmitParallelFusedDynamicUpdateSliceInPlace.
//
// Emits a sequential loop if launch_dimensions is null.
static Status EmitFusedDynamicUpdateSliceInPlaceImpl(
    HloInstruction* fusion,
    tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays,
    const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
    const gpu::LaunchDimensions* launch_dimensions,
    llvm::IRBuilder<>* ir_builder) {
  CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
  VLOG(2) << "EmitFusedDynamicUpdateSliceInPlace for "
          << fusion->ToShortString();

  auto* dynamic_update_slice = fusion->fused_expression_root();

  const auto* update = dynamic_update_slice->operand(1);
  const auto* start_indices = dynamic_update_slice->operand(2);
  Shape update_shape = update->shape();

  // Our in-place dynamic-update-slice implementation emits a loop over
  // update_shape.  To emit a cache-friendly loop, we need to know that
  // shape's layout.
  //
  // But update lives inside a fusion node -- it's never materialized in
  // memory, so its shape doesn't carry a layout.  In this case we use the
  // layout of the fusion node for iteration, since that corresponds to the
  // order in memory of the buffer we'll be writing to.
  //
  // (This isn't necessarily optimal; in some cases it might be faster to peek
  // through the chain of ops that gives us the update operand and use the
  // layout of its source buffer(s).  But this is no worse than what we do
  // with fusion elsewhere.)
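  //
  // For example, if the fusion output has layout {0,1} (dimension 0 minor),
  // copying that layout onto update_shape below makes the emitted loop walk
  // the update in the same minor-to-major order as the output buffer.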
  TF_RETURN_IF_ERROR(
      LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape));

  // Create element generators for update and start_indices.
  FusedIrEmitter fused_emitter(fusion_operand_arrays, elemental_emitter);
  TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter));
  ElementGenerator update_array_generator = fused_emitter.GetGenerator(update);
  ElementGenerator start_indices_generator =
      fused_emitter.GetGenerator(start_indices);

  return EmitDynamicUpdateSliceInPlaceImpl(
      update_shape, start_indices_generator, update_array_generator,
      fusion_output_array, launch_dimensions, IrName(fusion), ir_builder);
}

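// Emits a sequential in-place dynamic-update-slice for a fusion node whose
// root is a dynamic-update-slice that writes into the fusion's own output
// buffer.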
Status EmitFusedDynamicUpdateSliceInPlace(
    HloInstruction* fusion,
    tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays,
    const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
    llvm::IRBuilder<>* ir_builder) {
  return EmitFusedDynamicUpdateSliceInPlaceImpl(
      fusion, fusion_operand_arrays, fusion_output_array, elemental_emitter,
      /*launch_dimensions=*/nullptr, ir_builder);
}

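// Like EmitFusedDynamicUpdateSliceInPlace, but emits a parallel loop using
// the given launch dimensions.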
Status EmitParallelFusedDynamicUpdateSliceInPlace(
    HloInstruction* fusion,
    tensorflow::gtl::ArraySlice<IrArray> fusion_operand_arrays,
    const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter,
    const gpu::LaunchDimensions& launch_dimensions,
    llvm::IRBuilder<>* ir_builder) {
  return EmitFusedDynamicUpdateSliceInPlaceImpl(
      fusion, fusion_operand_arrays, fusion_output_array, elemental_emitter,
      &launch_dimensions, ir_builder);
}

}  // namespace llvm_ir
}  // namespace xla