/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
17#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
18
#include <initializer_list>
#include <string>
#include <vector>

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
26
27namespace xla {
28namespace cpu {
29
// Simple wrappers around llvm::APFloat::APFloat to make the calling code more
// obvious.
32
33inline llvm::APFloat GetIeeeF32(float f) { return llvm::APFloat(f); }
34inline llvm::APFloat GetIeeeF32FromBitwiseRep(int32 bitwise_value) {
35  return llvm::APFloat(llvm::APFloat::IEEEsingle(),
36                       llvm::APInt(/*numBits=*/32, /*val=*/bitwise_value));
37}
38
// A thin wrapper around llvm_util.h to make code generating vector math flow
// more readable.
41class VectorSupportLibrary {
42 public:
43  // This VectorSupportLibrary instance remembers `primitive_type` and
44  // `vector_size`, and these are implicitly used by the methods on this
45  // instance (i.e. LoadVector will load a vector of type <`vector_size` x
46  // `primitive_type`>).
47  VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size,
48                       llvm::IRBuilder<>* ir_builder, std::string name);
49
50  llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs);
51  llvm::Value* Mul(int64 lhs, llvm::Value* rhs) {
52    return Mul(ir_builder()->getInt64(lhs), rhs);
53  }
54  llvm::Value* Mul(const llvm::APFloat& lhs, llvm::Value* rhs) {
55    return Mul(GetConstantFloat(rhs->getType(), lhs), rhs);
56  }
57
58  // If your call resolved to these then you probably wanted the versions taking
59  // APFloat.
60  llvm::Value* Mul(double lhs, llvm::Value* rhs) = delete;
61  llvm::Value* Mul(float lhs, llvm::Value* rhs) = delete;
62
63  llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs);
64  llvm::Value* Add(int64 lhs, llvm::Value* rhs) {
65    return Add(ir_builder()->getInt64(lhs), rhs);
66  }
67  llvm::Value* Add(const llvm::APFloat& lhs, llvm::Value* rhs) {
68    return Add(GetConstantFloat(rhs->getType(), lhs), rhs);
69  }
70
71  // If your call resolved to these then you probably wanted the versions taking
72  // APFloat.
73  llvm::Value* Add(double lhs, llvm::Value* rhs) = delete;
74  llvm::Value* Add(float lhs, llvm::Value* rhs) = delete;
75
76  llvm::Value* Sub(llvm::Value* lhs, llvm::Value* rhs);
77  llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) {
78    return Sub(lhs, GetConstantFloat(lhs->getType(), rhs));
79  }
80  llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs);
81  llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs) {
82    return Max(GetConstantFloat(rhs->getType(), lhs), rhs);
83  }
84  llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs);
85
86  llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) {
87    return Add(c, Mul(a, b));
88  }
89
90  llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, const llvm::APFloat& c) {
91    return Add(GetConstantFloat(vector_type(), c), Mul(a, b));
92  }
93
94  llvm::Value* MulAdd(llvm::Value* a, const llvm::APFloat& b,
95                      const llvm::APFloat& c) {
96    return Add(GetConstantFloat(a->getType(), c),
97               Mul(a, GetConstantFloat(a->getType(), b)));
98  }
99
100  llvm::Value* Floor(llvm::Value* a);
101
102  llvm::Value* Clamp(llvm::Value* a, const llvm::APFloat& low,
103                     const llvm::APFloat& high);
104  llvm::Value* SplatFloat(const llvm::APFloat& d) {
105    return GetConstantFloat(vector_type(), d);
106  }
107
108  // These compare instructions return a floating point typed mask instead of an
109  // i1.  For instance, on a vector typed input, lanes where the predicate is
110  // true get a float with all ones and other lanes get a float with all zeros.
111  // This is slightly odd from the perspective of LLVM's type system, but it
112  // makes kernel IR generation code written using VectorSupportLibrary (its
113  // raison d'etre) less cluttered.
114
115  llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs);
116  llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs);
117  llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs);
118  llvm::Value* FCmpOLTMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
119    return FCmpOLTMask(lhs, GetConstantFloat(lhs->getType(), rhs));
120  }
121
122  // These boolean operations operate on the bitwise values of the floating
123  // point inputs.  They return a (vector of) float(s) but like in the mask
124  // generating predicates above this type system oddity makes the kernel IR
125  // generation code less cluttered.
126  llvm::Value* FloatAnd(llvm::Value* lhs, llvm::Value* rhs);
127  llvm::Value* FloatAnd(llvm::Value* lhs, const llvm::APFloat& rhs) {
128    return FloatAnd(lhs, GetConstantFloat(lhs->getType(), rhs));
129  }
130  llvm::Value* FloatOr(llvm::Value* lhs, llvm::Value* rhs);
131  llvm::Value* FloatOr(llvm::Value* lhs, const llvm::APFloat& rhs) {
132    return FloatOr(lhs, GetConstantFloat(lhs->getType(), rhs));
133  }
134  llvm::Value* FloatNot(llvm::Value* lhs);
135  llvm::Value* FloatAndNot(llvm::Value* lhs, llvm::Value* rhs) {
136    return FloatAnd(FloatNot(lhs), rhs);
137  }
138
139  llvm::Value* BroadcastScalar(llvm::Value* x);
140  llvm::Value* BroadcastScalar(const llvm::APFloat& d) {
141    return BroadcastScalar(GetConstantFloat(scalar_type(), d));
142  }
143
144  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
145                                    llvm::Value* offset_elements);
146  llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
147                                    int64 offset_elements) {
148    return ComputeOffsetPointer(base_pointer,
149                                ir_builder()->getInt64(offset_elements));
150  }
151
152  llvm::Value* LoadVector(llvm::Value* pointer);
153
154  llvm::Value* LoadVector(llvm::Value* base_pointer,
155                          llvm::Value* offset_elements) {
156    return LoadVector(ComputeOffsetPointer(base_pointer, offset_elements));
157  }
158
159  llvm::Value* LoadVector(llvm::Value* base_pointer, int64 offset_elements) {
160    return LoadVector(base_pointer, ir_builder()->getInt64(offset_elements));
161  }
162
163  llvm::Value* LoadScalar(llvm::Value* pointer);
164
165  llvm::Value* LoadScalar(llvm::Value* base_pointer,
166                          llvm::Value* offset_elements) {
167    return LoadScalar(ComputeOffsetPointer(base_pointer, offset_elements));
168  }
169
170  llvm::Value* LoadScalar(llvm::Value* base_pointer, int64 offset_elements) {
171    return LoadScalar(base_pointer, ir_builder()->getInt64(offset_elements));
172  }
173
174  void StoreVector(llvm::Value* value, llvm::Value* pointer);
175
176  void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
177                   llvm::Value* offset_elements) {
178    StoreVector(value, ComputeOffsetPointer(base_pointer, offset_elements));
179  }
180
181  void StoreVector(llvm::Value* value, llvm::Value* base_pointer,
182                   int64 offset_elements) {
183    StoreVector(value, base_pointer, ir_builder()->getInt64(offset_elements));
184  }
185
186  void StoreScalar(llvm::Value* value, llvm::Value* pointer);
187  void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
188                   llvm::Value* offset_elements) {
189    StoreScalar(value, ComputeOffsetPointer(base_pointer, offset_elements));
190  }
191
192  void StoreScalar(llvm::Value* value, llvm::Value* base_pointer,
193                   int64 offset_elements) {
194    StoreScalar(base_pointer, ir_builder()->getInt64(offset_elements));
195  }
196
197  llvm::Value* LoadBroadcast(llvm::Value* pointer);
198  llvm::Value* LoadBroadcast(llvm::Value* base_pointer,
199                             llvm::Value* offset_elements) {
200    return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements));
201  }
202  llvm::Value* LoadBroadcast(llvm::Value* base_pointer, int64 offset_elements) {
203    return LoadBroadcast(base_pointer, ir_builder()->getInt64(offset_elements));
204  }
205
206  // Compute the horizontal sum of each vector in `vectors`.  The i'th element
207  // in the result vector is the (scalar) horizontal sum of the i'th vector in
208  // `vectors`.  If `init_values` is not nullptr then the value in the i'th lane
209  // in `init_values` is added to the i'th horizontal sum.
210  std::vector<llvm::Value*> ComputeHorizontalSums(
211      std::vector<llvm::Value*> vectors, llvm::Value* init_values = nullptr);
212
213  llvm::Value* GetZeroVector();
214  llvm::Value* GetZeroScalar();
215
216  llvm::IRBuilder<>* ir_builder() const { return ir_builder_; }
217  int64 vector_size() const { return vector_size_; }
218  llvm::Type* vector_type() const { return vector_type_; }
219  llvm::Type* vector_pointer_type() const { return vector_pointer_type_; }
220  llvm::Type* scalar_type() const { return scalar_type_; }
221  llvm::Type* scalar_pointer_type() const { return scalar_pointer_type_; }
222  int64 scalar_byte_size() const {
223    return primitive_util::BitWidth(primitive_type_) / 8;
224  }
225
226  const std::string& name() const { return name_; }
227
228 private:
229  llvm::Value* ExtractLowHalf(llvm::Value*);
230  llvm::Value* ExtractHighHalf(llvm::Value*);
231
232  llvm::Value* MulInternal(llvm::Value* lhs, llvm::Value* rhs);
233  llvm::Value* AddInternal(llvm::Value* lhs, llvm::Value* rhs);
234
235  llvm::Value* AddReduce(llvm::Value* vector);
236
237  // Checks that each value in `values` is either of type scalar_type() or
238  // vector_type().  This LOG(FATAL)'s so it should only be called in cases
239  // where a mismatching type is a programmer bug.
240  void AssertCorrectTypes(std::initializer_list<llvm::Value*> values);
241
242  // Perform an X86 AVX style horizontal add between `lhs` and `rhs`.  The
243  // resulting IR for an 8-float wide vector is expected to lower to a single
244  // vhaddps instruction on a CPU that supports vhaddps, and not be too bad in
245  // other cases.
246  //
247  // For a vector width of 8, the result vector is computed as:
248  //   Result[0] = Lhs[0] + Lhs[1]
249  //   Result[1] = Lhs[2] + Lhs[3]
250  //   Result[2] = Rhs[0] + Rhs[1]
251  //   Result[3] = Rhs[2] + Rhs[3]
252  //   Result[4] = Lhs[4] + Lhs[5]
253  //   Result[5] = Lhs[6] + Lhs[7]
254  //   Result[6] = Rhs[4] + Rhs[5]
255  //   Result[7] = Rhs[6] + Rhs[7]
256  llvm::Value* AvxStyleHorizontalAdd(llvm::Value* lhs, llvm::Value* rhs);
257
258  std::vector<llvm::Value*> ComputeAvxOptimizedHorizontalSums(
259      std::vector<llvm::Value*> vectors, llvm::Value* init_values);
260
261  llvm::Type* IntegerTypeForFloatSize(bool vector);
262  llvm::Value* I1ToFloat(llvm::Value* i1);
263  llvm::Value* GetConstantFloat(llvm::Type* type, const llvm::APFloat& f) {
264    llvm::Constant* scalar_value = llvm::ConstantFP::get(type->getContext(), f);
265    if (llvm::isa<llvm::VectorType>(type)) {
266      return llvm::ConstantVector::getSplat(vector_size(), scalar_value);
267    }
268    return scalar_value;
269  }
270
271  int64 vector_size_;
272  PrimitiveType primitive_type_;
273  llvm::IRBuilder<>* ir_builder_;
274  llvm::Type* vector_type_;
275  llvm::Type* vector_pointer_type_;
276  llvm::Type* scalar_type_;
277  llvm::Type* scalar_pointer_type_;
278  std::string name_;
279};
280
// This wraps an alloca-backed stack variable which LLVM's SSA construction pass
// can later convert to a SSA value.
class LlvmVariable {
 public:
  // Creates an alloca of the given type using `ir_builder`; the variable is
  // uninitialized until Set() is called.
  LlvmVariable(llvm::Type*, llvm::IRBuilder<>* ir_builder);

  // Emits a load of the variable's current value.
  llvm::Value* Get() const;
  // Emits a store of `new_value` into the variable.
  void Set(llvm::Value* new_value);

 private:
  llvm::AllocaInst* alloca_;    // Backing stack slot.
  llvm::IRBuilder<>* ir_builder_;  // Not owned.
};
294
// An LlvmVariable whose type is `vector_support`'s vector_type(), initialized
// to `initial_value`.
class VectorVariable : public LlvmVariable {
 public:
  VectorVariable(VectorSupportLibrary* vector_support,
                 llvm::Value* initial_value)
      : LlvmVariable(vector_support->vector_type(),
                     vector_support->ir_builder()) {
    Set(initial_value);
  }
};
304
// An LlvmVariable whose type is `vector_support`'s scalar_type(), initialized
// to `initial_value`.
class ScalarVariable : public LlvmVariable {
 public:
  ScalarVariable(VectorSupportLibrary* vector_support,
                 llvm::Value* initial_value)
      : LlvmVariable(vector_support->scalar_type(),
                     vector_support->ir_builder()) {
    Set(initial_value);
  }
};
314}  // namespace cpu
315}  // namespace xla
316
317#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_
318