// RSScriptGroupFusion.cpp, revision 8c12d615b4ed4b1d782722a125dd1d43bc44a71b
1/* 2 * Copyright 2015, The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#include "bcc/Renderscript/RSScriptGroupFusion.h" 18 19#include "bcc/Assert.h" 20#include "bcc/BCCContext.h" 21#include "bcc/Source.h" 22#include "bcc/Support/Log.h" 23#include "bcinfo/MetadataExtractor.h" 24#include "llvm/ADT/StringExtras.h" 25#include "llvm/IR/DataLayout.h" 26#include "llvm/IR/IRBuilder.h" 27#include "llvm/IR/Module.h" 28 29using llvm::Function; 30using llvm::Module; 31 32using std::string; 33 34namespace bcc { 35 36namespace { 37 38const Function* getInvokeFunction(const Source& source, const int slot, 39 Module* newModule) { 40 Module* module = const_cast<Module*>(&source.getModule()); 41 bcinfo::MetadataExtractor metadata(module); 42 if (!metadata.extract()) { 43 return nullptr; 44 } 45 const char* functionName = metadata.getExportFuncNameList()[slot]; 46 Function* func = newModule->getFunction(functionName); 47 // Materialize the function so that later the caller can inspect its argument 48 // and return types. 
49 newModule->materialize(func); 50 return func; 51} 52 53const Function* 54getFunction(Module* mergedModule, const Source* source, const int slot, 55 uint32_t* signature) { 56 bcinfo::MetadataExtractor metadata(&source->getModule()); 57 metadata.extract(); 58 59 const char* functionName = metadata.getExportForEachNameList()[slot]; 60 if (functionName == nullptr || !functionName[0]) { 61 return nullptr; 62 } 63 64 if (metadata.getExportForEachInputCountList()[slot] > 1) { 65 // TODO: Handle multiple inputs. 66 ALOGW("Kernel %s has multiple inputs", functionName); 67 return nullptr; 68 } 69 70 if (signature != nullptr) { 71 *signature = metadata.getExportForEachSignatureList()[slot]; 72 } 73 74 const Function* function = mergedModule->getFunction(functionName); 75 76 return function; 77} 78 79// The whitelist of supported signature bits. Context or user data arguments are 80// not currently supported in kernel fusion. To support them or any new kinds of 81// arguments in the future, it requires not only listing the signature bits here, 82// but also implementing additional necessary fusion logic in the getFusedFuncSig(), 83// getFusedFuncType(), and fuseKernels() functions below. 84constexpr uint32_t ExpectedSignatureBits = 85 bcinfo::MD_SIG_In | 86 bcinfo::MD_SIG_Out | 87 bcinfo::MD_SIG_X | 88 bcinfo::MD_SIG_Y | 89 bcinfo::MD_SIG_Z | 90 bcinfo::MD_SIG_Kernel; 91 92int getFusedFuncSig(const std::vector<Source*>& sources, 93 const std::vector<int>& slots, 94 uint32_t* retSig) { 95 *retSig = 0; 96 uint32_t firstSignature = 0; 97 uint32_t signature = 0; 98 auto slotIter = slots.begin(); 99 for (const Source* source : sources) { 100 const int slot = *slotIter++; 101 bcinfo::MetadataExtractor metadata(&source->getModule()); 102 metadata.extract(); 103 104 if (metadata.getExportForEachInputCountList()[slot] > 1) { 105 // TODO: Handle multiple inputs in kernel fusion. 
106 ALOGW("Kernel %d in source %p has multiple inputs", slot, source); 107 return -1; 108 } 109 110 signature = metadata.getExportForEachSignatureList()[slot]; 111 if (signature & ~ExpectedSignatureBits) { 112 ALOGW("Unexpected signature %x seen while fusing kernels", signature); 113 return -1; 114 } 115 116 if (firstSignature == 0) { 117 firstSignature = signature; 118 } 119 120 *retSig |= signature; 121 } 122 123 if (!bcinfo::MetadataExtractor::hasForEachSignatureIn(firstSignature)) { 124 *retSig &= ~bcinfo::MD_SIG_In; 125 } 126 127 if (!bcinfo::MetadataExtractor::hasForEachSignatureOut(signature)) { 128 *retSig &= ~bcinfo::MD_SIG_Out; 129 } 130 131 return 0; 132} 133 134llvm::FunctionType* getFusedFuncType(bcc::BCCContext& Context, 135 const std::vector<Source*>& sources, 136 const std::vector<int>& slots, 137 Module* M, 138 uint32_t* signature) { 139 int error = getFusedFuncSig(sources, slots, signature); 140 141 if (error < 0) { 142 return nullptr; 143 } 144 145 const Function* firstF = getFunction(M, sources.front(), slots.front(), nullptr); 146 147 bccAssert (firstF != nullptr); 148 149 llvm::SmallVector<llvm::Type*, 8> ArgTys; 150 151 if (bcinfo::MetadataExtractor::hasForEachSignatureIn(*signature)) { 152 ArgTys.push_back(firstF->arg_begin()->getType()); 153 } 154 155 llvm::Type* I32Ty = llvm::IntegerType::get(Context.getLLVMContext(), 32); 156 if (bcinfo::MetadataExtractor::hasForEachSignatureX(*signature)) { 157 ArgTys.push_back(I32Ty); 158 } 159 if (bcinfo::MetadataExtractor::hasForEachSignatureY(*signature)) { 160 ArgTys.push_back(I32Ty); 161 } 162 if (bcinfo::MetadataExtractor::hasForEachSignatureZ(*signature)) { 163 ArgTys.push_back(I32Ty); 164 } 165 166 const Function* lastF = getFunction(M, sources.back(), slots.back(), nullptr); 167 168 bccAssert (lastF != nullptr); 169 170 llvm::Type* retTy = lastF->getReturnType(); 171 172 return llvm::FunctionType::get(retTy, ArgTys, false); 173} 174 175} // anonymous namespace 176 177bool 
fuseKernels(bcc::BCCContext& Context, 178 const std::vector<Source *>& sources, 179 const std::vector<int>& slots, 180 const std::string& fusedName, 181 Module* mergedModule) { 182 bccAssert(sources.size() == slots.size() && "sources and slots differ in size"); 183 184 uint32_t fusedFunctionSignature; 185 186 llvm::FunctionType* fusedType = 187 getFusedFuncType(Context, sources, slots, mergedModule, &fusedFunctionSignature); 188 189 if (fusedType == nullptr) { 190 return false; 191 } 192 193 Function* fusedKernel = 194 (Function*)(mergedModule->getOrInsertFunction(fusedName, fusedType)); 195 196 llvm::LLVMContext& ctxt = Context.getLLVMContext(); 197 198 llvm::BasicBlock* block = llvm::BasicBlock::Create(ctxt, "entry", fusedKernel); 199 llvm::IRBuilder<> builder(block); 200 201 Function::arg_iterator argIter = fusedKernel->arg_begin(); 202 203 llvm::Value* dataElement = nullptr; 204 if (bcinfo::MetadataExtractor::hasForEachSignatureIn(fusedFunctionSignature)) { 205 dataElement = argIter++; 206 dataElement->setName("DataIn"); 207 } 208 209 llvm::Value* X = nullptr; 210 if (bcinfo::MetadataExtractor::hasForEachSignatureX(fusedFunctionSignature)) { 211 X = argIter++; 212 X->setName("x"); 213 } 214 215 llvm::Value* Y = nullptr; 216 if (bcinfo::MetadataExtractor::hasForEachSignatureY(fusedFunctionSignature)) { 217 Y = argIter++; 218 Y->setName("y"); 219 } 220 221 llvm::Value* Z = nullptr; 222 if (bcinfo::MetadataExtractor::hasForEachSignatureZ(fusedFunctionSignature)) { 223 Z = argIter++; 224 Z->setName("z"); 225 } 226 227 auto slotIter = slots.begin(); 228 for (const Source* source : sources) { 229 int slot = *slotIter++; 230 231 uint32_t inputFunctionSignature; 232 const Function* inputFunction = 233 getFunction(mergedModule, source, slot, &inputFunctionSignature); 234 if (inputFunction == nullptr) { 235 return false; 236 } 237 238 // Don't try to fuse a non-kernel 239 if (!bcinfo::MetadataExtractor::hasForEachSignatureKernel(inputFunctionSignature)) { 240 return 
false; 241 } 242 243 std::vector<llvm::Value*> args; 244 245 if (bcinfo::MetadataExtractor::hasForEachSignatureIn(inputFunctionSignature)) { 246 if (dataElement == nullptr) { 247 return false; 248 } 249 250 const llvm::FunctionType* funcTy = inputFunction->getFunctionType(); 251 llvm::Type* firstArgType = funcTy->getParamType(0); 252 253 if (!dataElement->getType()->canLosslesslyBitCastTo(firstArgType)) { 254 return false; 255 } 256 257 args.push_back(dataElement); 258 } else { 259 // Only the first kernel in a batch is allowed to have no input 260 if (slotIter != slots.begin()) { 261 return false; 262 } 263 } 264 265 if (bcinfo::MetadataExtractor::hasForEachSignatureX(inputFunctionSignature)) { 266 args.push_back(X); 267 } 268 269 if (bcinfo::MetadataExtractor::hasForEachSignatureY(inputFunctionSignature)) { 270 args.push_back(Y); 271 } 272 273 if (bcinfo::MetadataExtractor::hasForEachSignatureZ(inputFunctionSignature)) { 274 args.push_back(Z); 275 } 276 277 dataElement = builder.CreateCall((llvm::Value*)inputFunction, args); 278 } 279 280 if (fusedKernel->getReturnType()->isVoidTy()) { 281 builder.CreateRetVoid(); 282 } else { 283 builder.CreateRet(dataElement); 284 } 285 286 llvm::NamedMDNode* ExportForEachNameMD = 287 mergedModule->getOrInsertNamedMetadata("#rs_export_foreach_name"); 288 289 llvm::MDString* nameMDStr = llvm::MDString::get(ctxt, fusedName); 290 llvm::MDNode* nameMDNode = llvm::MDNode::get(ctxt, nameMDStr); 291 ExportForEachNameMD->addOperand(nameMDNode); 292 293 llvm::NamedMDNode* ExportForEachMD = 294 mergedModule->getOrInsertNamedMetadata("#rs_export_foreach"); 295 llvm::MDString* sigMDStr = llvm::MDString::get(ctxt, 296 llvm::utostr_32(fusedFunctionSignature)); 297 llvm::MDNode* sigMDNode = llvm::MDNode::get(ctxt, sigMDStr); 298 ExportForEachMD->addOperand(sigMDNode); 299 300 return true; 301} 302 303bool renameInvoke(BCCContext& Context, const Source* source, const int slot, 304 const std::string& newName, Module* module) { 305 const 
llvm::Function* F = getInvokeFunction(*source, slot, module); 306 std::vector<llvm::Type*> params; 307 for (auto I = F->arg_begin(), E = F->arg_end(); I != E; ++I) { 308 params.push_back(I->getType()); 309 } 310 llvm::Type* returnTy = F->getReturnType(); 311 312 llvm::FunctionType* batchFuncTy = 313 llvm::FunctionType::get(returnTy, params, false); 314 315 llvm::Function* newF = 316 llvm::Function::Create(batchFuncTy, 317 llvm::GlobalValue::ExternalLinkage, newName, 318 module); 319 320 llvm::BasicBlock* block = llvm::BasicBlock::Create(Context.getLLVMContext(), 321 "entry", newF); 322 llvm::IRBuilder<> builder(block); 323 324 llvm::Function::arg_iterator argIter = newF->arg_begin(); 325 llvm::Value* arg1 = argIter++; 326 builder.CreateCall((llvm::Value*)F, arg1); 327 328 builder.CreateRetVoid(); 329 330 llvm::NamedMDNode* ExportFuncNameMD = 331 module->getOrInsertNamedMetadata("#rs_export_func"); 332 llvm::MDString* strMD = llvm::MDString::get(module->getContext(), newName); 333 llvm::MDNode* nodeMD = llvm::MDNode::get(module->getContext(), strMD); 334 ExportFuncNameMD->addOperand(nodeMD); 335 336 return true; 337} 338 339} // namespace bcc 340