/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_

#include <functional>
#include <string>

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"

namespace xla {
// A thin wrapper around llvm_loop.h to make code generating structured control
// flow more readable.
class KernelSupportLibrary {
 public:
  // `ir_builder` is the llvm::IRBuilder instance used to generate LLVM IR.
  // If `prevent_unrolling` is true then unrolling is explicitly disabled on
  // every loop generated by this instance of KernelSupportLibrary.  Likewise,
  // if `prevent_vectorization` is true then vectorization is explicitly
  // disabled on every loop generated by this instance.
  explicit KernelSupportLibrary(llvm::IRBuilder<>* ir_builder,
                                bool prevent_unrolling = true,
                                bool prevent_vectorization = true)
      : ir_builder_(ir_builder),
        prevent_unrolling_(prevent_unrolling),
        prevent_vectorization_(prevent_vectorization) {}
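
  // Example (a sketch, not part of the library): constructing instances with
  // the default and relaxed loop metadata, assuming `b` is an
  // llvm::IRBuilder<> positioned inside the function being emitted:
  //
  //   KernelSupportLibrary ksl(&b);  // no unrolling, no vectorization
  //   KernelSupportLibrary ksl_relaxed(&b, /*prevent_unrolling=*/false,
  //                                    /*prevent_vectorization=*/false);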

  // Generates the following control flow structure:
  //
  //   if (`start` < `end`) {
  //     `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/true)`;
  //     for (i64 i = `start` + `step`; i s< `end`; i += `step`)
  //       `for_body_generator(/*ind_var=*/i, /*is_first_iteration=*/false)`;
  //   }
  void For(
      tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
      llvm::Value* step,
      const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
          for_body_generator);

  void For(
      tensorflow::StringPiece name, int64 start, int64 end, int64 step,
      const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
          for_body_generator) {
    For(name, /*start=*/ir_builder_->getInt64(start),
        /*end=*/ir_builder_->getInt64(end),
        /*step=*/ir_builder_->getInt64(step), for_body_generator);
  }
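
  // Example (a sketch, not part of the library): emit a loop that stores 0.0f
  // into the first 16 elements of a float buffer, assuming `b` is an
  // llvm::IRBuilder<> and `out` is an llvm::Value* of pointer-to-float type:
  //
  //   KernelSupportLibrary ksl(&b);
  //   ksl.For("init_out", /*start=*/0, /*end=*/16, /*step=*/1,
  //           [&](llvm::Value* i, bool is_first_iteration) {
  //             llvm::Value* addr = b.CreateGEP(out, i);
  //             b.CreateStore(llvm::ConstantFP::get(b.getFloatTy(), 0.0),
  //                           addr);
  //           });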

  // Generates the following control flow structure if `peel_first_iteration`
  // is true:
  //
  //   if (`start` < `end`) {
  //     `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/true)`;
  //     for (i64 i = `start` + `step`; i s< `end`; i += `step`)
  //       `for_body_generator(/*ind_var=*/i, /*is_first_iteration=*/false)`;
  //   }
  //
  // and the following if `peel_first_iteration` is false:
  //
  //   for (i64 i = `start`; i s< `end`; i += `step`)
  //     `for_body_generator(/*ind_var=*/i,
  //                         /*is_first_iteration=*/(i == `start`))`;
  void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
           llvm::Value* step, bool peel_first_iteration,
           const std::function<void(llvm::Value* ind_var,
                                    llvm::Value* is_first_iteration)>&
               for_body_generator);

  void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
           int64 step, bool peel_first_iteration,
           const std::function<void(llvm::Value* ind_var,
                                    llvm::Value* is_first_iteration)>&
               for_body_generator) {
    For(name, /*start=*/start, /*end=*/end,
        /*step=*/ir_builder_->getInt64(step), peel_first_iteration,
        for_body_generator);
  }
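
  // Example (a sketch, not part of the library): when `peel_first_iteration`
  // is false, `is_first_iteration` arrives as an llvm::Value* (an i1), so any
  // first-iteration special casing must itself be emitted as IR, e.g. via
  // `If`.  Assumes `start` and `end` are llvm::Value* holding i64 bounds:
  //
  //   ksl.For("accumulate", /*start=*/start, /*end=*/end, /*step=*/1,
  //           /*peel_first_iteration=*/false,
  //           [&](llvm::Value* i, llvm::Value* is_first_iteration) {
  //             ksl.If(is_first_iteration,
  //                    [&]() { /* emit initialization */ },
  //                    [&]() { /* emit accumulation */ });
  //           });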

  void For(
      tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
      llvm::Value* step,
      const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
    For(name, start, end, step,
        /*peel_first_iteration=*/false,
        [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); });
  }

  void For(
      tensorflow::StringPiece name, int64 start, int64 end, int64 step,
      const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
    For(name, /*start=*/ir_builder_->getInt64(start),
        /*end=*/ir_builder_->getInt64(end),
        /*step=*/ir_builder_->getInt64(step), for_body_generator);
  }

  // Generates the following control flow structure:
  //
  //   if (`condition`)
  //     `true_block_generator()`;
  //   else
  //     `false_block_generator()`;
  void If(llvm::Value* condition,
          const std::function<void()>& true_block_generator,
          const std::function<void()>& false_block_generator = []() {});
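
  // Example (a sketch, not part of the library): clamp a value at zero before
  // storing it, assuming `b` is an llvm::IRBuilder<>, `v` holds an i64 and
  // `ptr` points to an i64 slot:
  //
  //   ksl.If(b.CreateICmpSLT(v, b.getInt64(0)),
  //          /*true_block_generator=*/
  //          [&]() { b.CreateStore(b.getInt64(0), ptr); },
  //          /*false_block_generator=*/
  //          [&]() { b.CreateStore(v, ptr); });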

  using ArgumentVector = tensorflow::gtl::ArraySlice<llvm::Value*>;

  // Generates the following control flow structure:
  //
  //  define @`kernel_name`(arg0, arg1, ... arg`arguments.size() - 1`) {
  //    kernel_body_generator({arg0, arg1, ... arg`arguments.size() - 1`});
  //  }
  //
  //  ...
  //  call @`kernel_name`(arguments[0], arguments[1] ...)
  //  ...
  //
  // If a function called `kernel_name` is already present in the module then
  // that function is re-used.  In that sense we're using the llvm::Module as a
  // cache of outlined kernels, keyed by function name.
  //
  // If any of the values in `arguments` is nullptr (i.e. a nullptr
  // llvm::Value*) then we ignore it when generating LLVM IR, and instead pass
  // a nullptr llvm::Value* in its position to `kernel_body_generator`.
  // Currently we only support at most one nullptr value in `arguments`.
  static void EmitAndCallOutlinedKernel(
      bool enable_fast_math, bool optimize_for_size,
      llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name,
      ArgumentVector arguments,
      const std::function<void(ArgumentVector)>& kernel_body_generator);
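
  // Example (a sketch, not part of the library): outline a kernel that copies
  // one element, assuming `b` is an llvm::IRBuilder<> and `src`/`dst` are
  // llvm::Value* pointers.  A later call with the same "copy_elem" name
  // re-uses the function emitted here instead of generating a new one:
  //
  //   KernelSupportLibrary::EmitAndCallOutlinedKernel(
  //       /*enable_fast_math=*/false, /*optimize_for_size=*/false, &b,
  //       "copy_elem", {src, dst},
  //       [&](KernelSupportLibrary::ArgumentVector args) {
  //         b.CreateStore(b.CreateLoad(args[0]), args[1]);
  //       });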

  // Thin wrappers around the more general EmitAndCallOutlinedKernel above.
  static void EmitAndCallOutlinedKernel(
      bool enable_fast_math, bool optimize_for_size,
      llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name,
      llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2,
      const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*)>&
          kernel_body_generator) {
    EmitAndCallOutlinedKernel(
        enable_fast_math, optimize_for_size, ir_builder, kernel_name,
        {arg0, arg1, arg2}, [&](ArgumentVector args) {
          kernel_body_generator(args[0], args[1], args[2]);
        });
  }

  static void EmitAndCallOutlinedKernel(
      bool enable_fast_math, bool optimize_for_size,
      llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name,
      llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2,
      llvm::Value* arg3,
      const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*,
                               llvm::Value*)>& kernel_body_generator) {
    EmitAndCallOutlinedKernel(
        enable_fast_math, optimize_for_size, ir_builder, kernel_name,
        {arg0, arg1, arg2, arg3}, [&](ArgumentVector args) {
          kernel_body_generator(args[0], args[1], args[2], args[3]);
        });
  }

 private:
  llvm::IRBuilder<>* ir_builder_;
  bool prevent_unrolling_;
  bool prevent_vectorization_;
};
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_