1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
17
18#include "llvm/Support/raw_ostream.h"
19#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
20#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
21
22namespace xla {
23namespace cpu {
// Builds the scalar/vector IR types this library hands out.  scalar_type_ is
// derived from the XLA primitive type first, since the pointer and vector
// variants below are all constructed from it.
VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type,
                                           int64 vector_size,
                                           llvm::IRBuilder<>* ir_builder,
                                           std::string name)
    : vector_size_(vector_size),
      primitive_type_(primitive_type),
      ir_builder_(ir_builder),
      name_(std::move(name)) {
  scalar_type_ = llvm_ir::PrimitiveTypeToIrType(
      primitive_type, ir_builder_->GetInsertBlock()->getModule());
  scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_);
  vector_type_ = llvm::VectorType::get(scalar_type_, vector_size);
  vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_);
}
38
39static string TypeToString(llvm::Type* type) {
40  std::string o;
41  llvm::raw_string_ostream ostream(o);
42  type->print(ostream);
43  return ostream.str();
44}
45
46void VectorSupportLibrary::AssertCorrectTypes(
47    std::initializer_list<llvm::Value*> values) {
48  for (llvm::Value* v : values) {
49    llvm::Type* type = v->getType();
50    if (type != scalar_type() && type != vector_type()) {
51      LOG(FATAL) << "Expected either " << TypeToString(scalar_type()) << " or "
52                 << TypeToString(vector_type()) << " but got "
53                 << TypeToString(type);
54    }
55  }
56}
57
// Emits lhs * rhs after checking both operands are of this library's scalar
// or vector type.
llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return MulInternal(lhs, rhs);
}
62
63llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs,
64                                               llvm::Value* rhs) {
65  if (scalar_type_->isFloatingPointTy()) {
66    return ir_builder()->CreateFMul(lhs, rhs, name());
67  } else {
68    return ir_builder()->CreateMul(lhs, rhs, name());
69  }
70}
71
// Emits lhs + rhs after checking both operands are of this library's scalar
// or vector type.
llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return AddInternal(lhs, rhs);
}
76
77llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) {
78  AssertCorrectTypes({lhs, rhs});
79  return ir_builder()->CreateFSub(lhs, rhs);
80}
81
82llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) {
83  AssertCorrectTypes({lhs, rhs});
84  if (scalar_type_->isFloatingPointTy()) {
85    return llvm_ir::EmitFloatMax(lhs, rhs, ir_builder_);
86  } else {
87    LOG(FATAL) << "Max for integers is unimplemented";
88  }
89}
90
// Emits a call to the llvm.floor intrinsic on `a` (scalar or vector form,
// chosen by a's own type).
llvm::Value* VectorSupportLibrary::Floor(llvm::Value* a) {
  AssertCorrectTypes({a});
  return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {a},
                                      {a->getType()}, ir_builder());
}
96
97llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) {
98  AssertCorrectTypes({lhs, rhs});
99  if (scalar_type_->isFloatingPointTy()) {
100    return ir_builder()->CreateFDiv(lhs, rhs, name());
101  } else {
102    LOG(FATAL) << "Division for integers is unimplemented";
103  }
104}
105
106llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a,
107                                         const llvm::APFloat& low,
108                                         const llvm::APFloat& high) {
109  AssertCorrectTypes({a});
110  llvm::Type* type = a->getType();
111  CHECK(low.compare(high) == llvm::APFloat::cmpLessThan);
112  CHECK(scalar_type_->isFloatingPointTy());
113  return llvm_ir::EmitFloatMin(
114      llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), ir_builder_),
115      GetConstantFloat(type, high), ir_builder_);
116}
117
// Emits an ordered float equality compare and widens the i1 result into an
// all-ones / all-zeros float-typed mask via I1ToFloat.
llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs,
                                              llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return I1ToFloat(ir_builder()->CreateFCmpOEQ(lhs, rhs, name()));
}
123
// Emits an ordered less-than compare and widens the i1 result into an
// all-ones / all-zeros float-typed mask via I1ToFloat.
llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs,
                                               llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return I1ToFloat(ir_builder()->CreateFCmpOLT(lhs, rhs, name()));
}
129
// Emits an unordered-or-less-than-or-equal compare (true also when either
// operand is NaN) and widens the i1 result into a float-typed mask.
llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs,
                                               llvm::Value* rhs) {
  AssertCorrectTypes({lhs, rhs});
  return I1ToFloat(ir_builder()->CreateFCmpULE(lhs, rhs, name()));
}
135
136llvm::Value* VectorSupportLibrary::I1ToFloat(llvm::Value* i1) {
137  bool is_vector = llvm::isa<llvm::VectorType>(i1->getType());
138  llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector);
139  return ir_builder()->CreateBitCast(
140      ir_builder()->CreateSExt(i1, integer_type, name()),
141      is_vector ? vector_type() : scalar_type(), name());
142}
143
144llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) {
145  CHECK(scalar_type()->isFloatingPointTy());
146  const llvm::DataLayout& data_layout =
147      ir_builder()->GetInsertBlock()->getModule()->getDataLayout();
148  int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type());
149  llvm::Type* scalar_int_type = ir_builder()->getIntNTy(float_size_bits);
150  if (vector) {
151    return llvm::VectorType::get(scalar_int_type, vector_size());
152  } else {
153    return scalar_int_type;
154  }
155}
156
// Splats scalar `x` across all vector_size() lanes of a new vector.
llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) {
  CHECK_EQ(x->getType(), scalar_type());
  return ir_builder()->CreateVectorSplat(vector_size(), x, name());
}
161
162llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs,
163                                            llvm::Value* rhs) {
164  AssertCorrectTypes({lhs, rhs});
165  llvm::Type* int_type =
166      IntegerTypeForFloatSize(lhs->getType() == vector_type());
167  return ir_builder()->CreateBitCast(
168      ir_builder()->CreateAnd(
169          ir_builder()->CreateBitCast(lhs, int_type, name()),
170          ir_builder()->CreateBitCast(rhs, int_type, name()), name()),
171      vector_type());
172}
173
174llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) {
175  AssertCorrectTypes({lhs});
176  llvm::Type* int_type =
177      IntegerTypeForFloatSize(lhs->getType() == vector_type());
178  return ir_builder()->CreateBitCast(
179      ir_builder()->CreateNot(
180          ir_builder()->CreateBitCast(lhs, int_type, name()), name()),
181      vector_type());
182}
183
184llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) {
185  AssertCorrectTypes({lhs, rhs});
186  llvm::Type* int_type =
187      IntegerTypeForFloatSize(lhs->getType() == vector_type());
188  return ir_builder()->CreateBitCast(
189      ir_builder()->CreateOr(ir_builder()->CreateBitCast(lhs, int_type, name()),
190                             ir_builder()->CreateBitCast(rhs, int_type, name()),
191                             name()),
192      vector_type(), name());
193}
194
195llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs,
196                                               llvm::Value* rhs) {
197  if (scalar_type_->isFloatingPointTy()) {
198    return ir_builder()->CreateFAdd(lhs, rhs, name());
199  } else {
200    return ir_builder()->CreateAdd(lhs, rhs, name());
201  }
202}
203
204llvm::Value* VectorSupportLibrary::ComputeOffsetPointer(
205    llvm::Value* base_pointer, llvm::Value* offset_elements) {
206  if (base_pointer->getType() != scalar_pointer_type()) {
207    base_pointer = ir_builder()->CreateBitCast(base_pointer,
208                                               scalar_pointer_type(), name());
209  }
210  return ir_builder()->CreateInBoundsGEP(base_pointer, {offset_elements},
211                                         name());
212}
213
214llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) {
215  if (pointer->getType() != vector_pointer_type()) {
216    pointer =
217        ir_builder()->CreateBitCast(pointer, vector_pointer_type(), name());
218  }
219  return ir_builder()->CreateAlignedLoad(
220      pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
221}
222
223llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) {
224  if (pointer->getType() != scalar_pointer_type()) {
225    pointer =
226        ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name());
227  }
228  return ir_builder()->CreateAlignedLoad(
229      pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name());
230}
231
232void VectorSupportLibrary::StoreVector(llvm::Value* value,
233                                       llvm::Value* pointer) {
234  AssertCorrectTypes({value});
235  if (pointer->getType() != vector_pointer_type()) {
236    pointer = ir_builder()->CreateBitCast(pointer, vector_pointer_type());
237  }
238  ir_builder()->CreateAlignedStore(
239      value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
240}
241
242void VectorSupportLibrary::StoreScalar(llvm::Value* value,
243                                       llvm::Value* pointer) {
244  AssertCorrectTypes({value});
245  if (pointer->getType() != scalar_pointer_type()) {
246    pointer =
247        ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name());
248  }
249  ir_builder()->CreateAlignedStore(
250      value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_));
251}
252
253llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) {
254  if (pointer->getType() != scalar_pointer_type()) {
255    pointer =
256        ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name());
257  }
258  return ir_builder()->CreateVectorSplat(
259      vector_size(), ir_builder()->CreateLoad(pointer), name());
260}
261
// Sums all lanes of `vector` into one scalar using a log-depth
// shuffle-and-add tree.  The halving loop (`i >>= 1`) only reaches 1 when
// vector_size() is a power of two — assumed here; TODO confirm callers
// guarantee this.
llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) {
  llvm::SmallVector<llvm::Constant*, 32> mask(vector_size(), nullptr);
  for (unsigned i = vector_size(); i != 1; i >>= 1) {
    // On every iteration, shuffle the upper half of the i live lanes down to
    // the bottom of the vector (all other mask lanes are undef), then add the
    // shuffled vector to the running vector; partial sums accumulate in the
    // low lanes.

    for (unsigned j = 0; j < vector_size(); ++j) {
      if (j < (i / 2)) {
        mask[j] = ir_builder()->getInt32(i / 2 + j);
      } else {
        mask[j] = llvm::UndefValue::get(ir_builder()->getInt32Ty());
      }
    }

    llvm::Value* half_remaining_lanes = ir_builder()->CreateShuffleVector(
        vector, llvm::UndefValue::get(vector_type()),
        llvm::ConstantVector::get(mask), "");
    vector = Add(vector, half_remaining_lanes);
  }

  // Lane 0 now holds the sum over all original lanes.
  return ir_builder()->CreateExtractElement(vector, ir_builder()->getInt32(0),
                                            name());
}
285
// Emits an AVX-style horizontal add of two vectors: adjacent lane pairs from
// lhs and rhs are summed, with the results interleaved per-half as the x86
// vhaddps/vhaddpd instructions do.  Requires an even vector width.
llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs,
                                                         llvm::Value* rhs) {
  CHECK_EQ(lhs->getType(), vector_type());
  CHECK_EQ(rhs->getType(), vector_type());
  CHECK_EQ(vector_size() % 2, 0);

  llvm::SmallVector<llvm::Constant*, 32> mask_a, mask_b;

  // Adding the values shuffled using mask_a and mask_b gives us the
  // AVX-style horizontal add we want.  The masks work as documented
  // in https://llvm.org/docs/LangRef.html#shufflevector-instruction
  //
  // Here are the masks for vector_size() == 8:
  //
  //    index: |0 |1 |2 | 3 |4 |5 | 6 | 7
  //   --------+--+--+--+---+--+--+---+---
  //   mask_a: |0 |2 |8 |10 |4 |6 |12 |14
  //   mask_b: |1 |3 |9 |11 |5 |7 |13 |15
  //
  // So, as an example, the value at lane 3 of the result vector is
  // the result of adding lane 10 and lane 11 in the combined lhs++rhs
  // vector, which are the lanes 2 and 3 in the rhs vector.
  for (int i = 0; i < vector_size(); i += 2) {
    int increment = i < vector_size() / 2 ? 0 : (vector_size() / 2);
    mask_a.push_back(ir_builder()->getInt32(increment + i));
    mask_b.push_back(ir_builder()->getInt32(increment + i + 1));
  }
  for (int i = 0; i < vector_size(); i += 2) {
    int increment = i < vector_size() / 2 ? (vector_size() / 2) : vector_size();
    mask_a.push_back(ir_builder()->getInt32(increment + i));
    mask_b.push_back(ir_builder()->getInt32(increment + i + 1));
  }

  // shuffle_0 gathers the even members of each pair, shuffle_1 the odd ones;
  // their sum is the pairwise horizontal add.
  llvm::Value* shuffle_0 = ir_builder()->CreateShuffleVector(
      lhs, rhs, llvm::ConstantVector::get(mask_a));
  llvm::Value* shuffle_1 = ir_builder()->CreateShuffleVector(
      lhs, rhs, llvm::ConstantVector::get(mask_b));

  return Add(shuffle_0, shuffle_1);
}
326
327llvm::Value* VectorSupportLibrary::ExtractLowHalf(llvm::Value* vector) {
328  llvm::SmallVector<llvm::Constant*, 32> mask;
329  for (int i = 0; i < vector_size() / 2; i++) {
330    mask.push_back(ir_builder()->getInt32(i));
331  }
332
333  return ir_builder()->CreateShuffleVector(vector,
334                                           llvm::UndefValue::get(vector_type()),
335                                           llvm::ConstantVector::get(mask));
336}
337
338llvm::Value* VectorSupportLibrary::ExtractHighHalf(llvm::Value* vector) {
339  llvm::SmallVector<llvm::Constant*, 32> mask;
340  for (int i = 0; i < vector_size() / 2; i++) {
341    mask.push_back(ir_builder()->getInt32(i + vector_size() / 2));
342  }
343
344  return ir_builder()->CreateShuffleVector(vector,
345                                           llvm::UndefValue::get(vector_type()),
346                                           llvm::ConstantVector::get(mask));
347}
348
349std::vector<llvm::Value*> VectorSupportLibrary::ComputeHorizontalSums(
350    std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
351  const int x86_avx_vector_elements =
352      TargetMachineFeatures::kX86AvxVectorByteSize / scalar_byte_size();
353  if (vector_size() == x86_avx_vector_elements &&
354      vectors.size() == x86_avx_vector_elements) {
355    return ComputeAvxOptimizedHorizontalSums(std::move(vectors), init_values);
356  }
357
358  std::vector<llvm::Value*> result;
359  std::transform(vectors.begin(), vectors.end(), std::back_inserter(result),
360                 [this](llvm::Value* vector) { return AddReduce(vector); });
361  if (init_values) {
362    for (int64 i = 0, e = result.size(); i < e; i++) {
363      result[i] = Add(result[i], ir_builder()->CreateExtractElement(
364                                     init_values, ir_builder()->getInt32(i)));
365    }
366  }
367  return result;
368}
369
// Horizontal-sum path for the AVX-shaped case (number of vectors == lane
// count == AVX element count, checked by the caller).  The hard-coded 8 and
// 4 below assume 8 vectors of 8 lanes each — NOTE(review): appears tied to
// the f32/32-byte-AVX configuration; confirm for other element sizes.
std::vector<llvm::Value*>
VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums(
    std::vector<llvm::Value*> vectors, llvm::Value* init_values) {
  // Repeatedly combine vector pairs with AVX-style horizontal adds, halving
  // the vector count each round, until two vectors remain.  Requires an even
  // (power-of-two) number of input vectors.
  while (vectors.size() != 2) {
    std::vector<llvm::Value*> new_vectors;
    for (int i = 0; i < vectors.size(); i += 2) {
      new_vectors.push_back(AvxStyleHorizontalAdd(vectors[i], vectors[i + 1]));
    }

    vectors = std::move(new_vectors);
  }

  // Fold each survivor's two halves together; `low` carries the sums for the
  // first half of the inputs, `high` for the second half.  init_values lanes
  // are folded into the matching half when provided.
  llvm::Value* low =
      AddInternal(ExtractLowHalf(vectors[0]), ExtractHighHalf(vectors[0]));
  if (init_values) {
    low = AddInternal(ExtractLowHalf(init_values), low);
  }
  llvm::Value* high =
      AddInternal(ExtractLowHalf(vectors[1]), ExtractHighHalf(vectors[1]));
  if (init_values) {
    high = AddInternal(ExtractHighHalf(init_values), high);
  }

  // Extract the 8 scalar results: lanes 0-3 of `low`, then lanes 0-3 of
  // `high`.
  std::vector<llvm::Value*> results;
  for (int i = 0; i < 8; i++) {
    llvm::Value* scalar_result = ir_builder()->CreateExtractElement(
        i < 4 ? low : high, ir_builder()->getInt32(i % 4), name());
    results.push_back(scalar_result);
  }

  return results;
}
402
// Returns the all-zeros constant of this library's vector type.
llvm::Value* VectorSupportLibrary::GetZeroVector() {
  return llvm::Constant::getNullValue(vector_type());
}
406
// Returns the all-zeros constant of this library's scalar type.
llvm::Value* VectorSupportLibrary::GetZeroScalar() {
  return llvm::Constant::getNullValue(scalar_type());
}
410
// Backs the variable with an unnamed alloca emitted at the entry of the
// current function.
LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* ir_builder)
    : ir_builder_(ir_builder) {
  alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", ir_builder_);
}
415
// Emits a load of the variable's current value at the builder's insert point.
llvm::Value* LlvmVariable::Get() const {
  return ir_builder_->CreateLoad(alloca_);
}
419
// Emits a store of `new_value` into the variable's backing alloca.
void LlvmVariable::Set(llvm::Value* new_value) {
  ir_builder_->CreateStore(new_value, alloca_);
}
423}  // namespace cpu
424}  // namespace xla
425