/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_

#include <string>

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"

namespace xla {
// A thin wrapper around llvm_loop.h to make code generating structured control
// flow more readable.
class KernelSupportLibrary {
 public:
  // `ir_builder` is the llvm::IRBuilder instance used to generate LLVM IR.
  // If `prevent_unrolling` is true then unrolling is explicitly disabled on
  // every loop generated by this instance of KernelSupportLibrary.  Likewise,
  // if `prevent_vectorization` is true then vectorization is explicitly
  // disabled on every loop generated by this instance.
  explicit KernelSupportLibrary(llvm::IRBuilder<>* ir_builder,
                                bool prevent_unrolling = true,
                                bool prevent_vectorization = true)
      : ir_builder_(ir_builder),
        prevent_unrolling_(prevent_unrolling),
        prevent_vectorization_(prevent_vectorization) {}

  // Generates the following control flow structure ("s<" denotes a signed
  // comparison):
  //
  //   if (`start` < `end`) {
  //     `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/true)`;
  //     for (i64 i = `start` + `step`; i s< `end`; i += `step`)
  //       `for_body_generator(/*ind_var=*/i, /*is_first_iteration=*/false)`;
  //   }
  void For(
      tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
      llvm::Value* step,
      const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
          for_body_generator);

  // Convenience overload of the above that takes the loop bounds and step as
  // compile-time constants.
  void For(
      tensorflow::StringPiece name, int64 start, int64 end, int64 step,
      const std::function<void(llvm::Value* ind_var, bool is_first_iteration)>&
          for_body_generator) {
    For(name, /*start=*/ir_builder_->getInt64(start),
        /*end=*/ir_builder_->getInt64(end),
        /*step=*/ir_builder_->getInt64(step), for_body_generator);
  }

  // Generates the following control flow structure if `peel_first_iteration` is
  // true:
  //
  //   if (`start` < `end`) {
  //     `for_body_generator(/*ind_var=*/start, /*is_first_iteration=*/true)`;
  //     for (i64 i = `start` + `step`; i s< `end`; i += `step`)
  //       `for_body_generator(/*ind_var=*/i, /*is_first_iteration=*/false)`;
  //   }
  //
  // and the following if `peel_first_iteration` is false:
  //
  //   for (i64 i = `start`; i s< `end`; i += `step`)
  //     `for_body_generator(/*ind_var=*/i,
  //                         /*is_first_iteration=*/(i == `start`))`;
  //
  // Note that in the non-peeled form `is_first_iteration` is an llvm::Value*
  // computed at runtime, rather than a compile-time bool.
  void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
           llvm::Value* step, bool peel_first_iteration,
           const std::function<void(llvm::Value* ind_var,
                                    llvm::Value* is_first_iteration)>&
               for_body_generator);

  // Convenience overload of the above that takes `step` as a compile-time
  // constant.
  void For(tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
           int64 step, bool peel_first_iteration,
           const std::function<void(llvm::Value* ind_var,
                                    llvm::Value* is_first_iteration)>&
               for_body_generator) {
    For(name, /*start=*/start, /*end=*/end,
        /*step=*/ir_builder_->getInt64(step), peel_first_iteration,
        for_body_generator);
  }

  // Overload for loop bodies that do not care whether the current iteration is
  // the first one; no first-iteration peeling is performed.
  void For(
      tensorflow::StringPiece name, llvm::Value* start, llvm::Value* end,
      llvm::Value* step,
      const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
    For(name, start, end, step,
        /*peel_first_iteration=*/false,
        [&](llvm::Value* indvar, llvm::Value*) { for_body_generator(indvar); });
  }

  // As above, but with the loop bounds and step given as compile-time
  // constants.
  void For(
      tensorflow::StringPiece name, int64 start, int64 end, int64 step,
      const std::function<void(llvm::Value* ind_var)>& for_body_generator) {
    For(name, /*start=*/ir_builder_->getInt64(start),
        /*end=*/ir_builder_->getInt64(end),
        /*step=*/ir_builder_->getInt64(step), for_body_generator);
  }

  // Generates the following control flow structure:
  //
  //   if (`condition`)
  //     `true_block_generator()`;
  //   else
  //     `false_block_generator()`;
  //
  // `false_block_generator` defaults to a no-op, yielding a plain `if`.
  void If(llvm::Value* condition,
          const std::function<void()>& true_block_generator,
          const std::function<void()>& false_block_generator = []() {});

  using ArgumentVector = tensorflow::gtl::ArraySlice<llvm::Value*>;

  // Generates the following control flow structure:
  //
  //  define @`kernel_name`(arg0, arg1, ... arg`arguments.size()`) {
  //    kernel_body_generator({arg0, arg1, ... arg`arguments.size()`});
  //  }
  //
  //  ...
  //  call @`kernel_name`(arguments[0], arguments[1] ...)
  //  ...
  //
  // If a function called `kernel_name` is already present in the module then
  // that function is re-used.  In that sense we're using the llvm::Module as a
  // cache of outlined kernels, keyed by function name.
  //
  // If any of the values in `arguments` is nullptr (i.e. a nullptr
  // llvm::Value*) then we ignore it when generating LLVM IR, and instead pass
  // in a nullptr llvm::Value* in its position to `kernel_body_generator`.
  // Currently we only support at most one nullptr value in `arguments`.
  static void EmitAndCallOutlinedKernel(
      bool enable_fast_math, bool optimize_for_size,
      llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name,
      ArgumentVector arguments,
      const std::function<void(ArgumentVector)>& kernel_body_generator);

  // Thin wrappers around the more general EmitAndCallOutlinedKernel above;
  // this one fixes the argument count at three.
  static void EmitAndCallOutlinedKernel(
      bool enable_fast_math, bool optimize_for_size,
      llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name,
      llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2,
      const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*)>&
          kernel_body_generator) {
    EmitAndCallOutlinedKernel(
        enable_fast_math, optimize_for_size, ir_builder, kernel_name,
        {arg0, arg1, arg2}, [&](ArgumentVector args) {
          kernel_body_generator(args[0], args[1], args[2]);
        });
  }

  // As above, with the argument count fixed at four.
  static void EmitAndCallOutlinedKernel(
      bool enable_fast_math, bool optimize_for_size,
      llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name,
      llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2,
      llvm::Value* arg3,
      const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*,
                               llvm::Value*)>& kernel_body_generator) {
    EmitAndCallOutlinedKernel(
        enable_fast_math, optimize_for_size, ir_builder, kernel_name,
        {arg0, arg1, arg2, arg3}, [&](ArgumentVector args) {
          kernel_body_generator(args[0], args[1], args[2], args[3]);
        });
  }

 private:
  llvm::IRBuilder<>* ir_builder_;
  bool prevent_unrolling_;
  bool prevent_vectorization_;
};
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_