RSScriptGroupFusion.cpp revision 531d08c85971e47f58aedc093fbe83f1b909703e
/*
 * Copyright 2015, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#include "bcc/Renderscript/RSScriptGroupFusion.h"
18
19#include "bcc/Assert.h"
20#include "bcc/BCCContext.h"
21#include "bcc/Source.h"
22#include "bcc/Support/Log.h"
23#include "bcinfo/MetadataExtractor.h"
24#include "llvm/ADT/StringExtras.h"
25#include "llvm/IR/DataLayout.h"
26#include "llvm/IR/IRBuilder.h"
27#include "llvm/IR/Module.h"
28#include "llvm/Support/raw_ostream.h"
29
30using llvm::Function;
31using llvm::Module;
32
33using std::string;
34
35namespace bcc {
36
37namespace {
38
39const Function* getInvokeFunction(const Source& source, const int slot,
40                                  Module* newModule) {
41  Module* module = const_cast<Module*>(&source.getModule());
42  bcinfo::MetadataExtractor metadata(module);
43  if (!metadata.extract()) {
44    ALOGE("Kernel fusion (module %s slot %d): failed to extract metadata",
45          source.getName().c_str(), slot);
46    return nullptr;
47  }
48  const char* functionName = metadata.getExportFuncNameList()[slot];
49  Function* func = newModule->getFunction(functionName);
50  // Materialize the function so that later the caller can inspect its argument
51  // and return types.
52  newModule->materialize(func);
53  return func;
54}
55
56const Function*
57getFunction(Module* mergedModule, const Source* source, const int slot,
58            uint32_t* signature) {
59  bcinfo::MetadataExtractor metadata(&source->getModule());
60  metadata.extract();
61
62  const char* functionName = metadata.getExportForEachNameList()[slot];
63  if (functionName == nullptr || !functionName[0]) {
64    ALOGE("Kernel fusion (module %s slot %d): failed to find kernel function",
65          source->getName().c_str(), slot);
66    return nullptr;
67  }
68
69  if (metadata.getExportForEachInputCountList()[slot] > 1) {
70    ALOGE("Kernel fusion (module %s function %s): cannot handle multiple inputs",
71          source->getName().c_str(), functionName);
72    return nullptr;
73  }
74
75  if (signature != nullptr) {
76    *signature = metadata.getExportForEachSignatureList()[slot];
77  }
78
79  const Function* function = mergedModule->getFunction(functionName);
80
81  return function;
82}
83
84// The whitelist of supported signature bits. Context or user data arguments are
85// not currently supported in kernel fusion. To support them or any new kinds of
86// arguments in the future, it requires not only listing the signature bits here,
87// but also implementing additional necessary fusion logic in the getFusedFuncSig(),
88// getFusedFuncType(), and fuseKernels() functions below.
89constexpr uint32_t ExpectedSignatureBits =
90        bcinfo::MD_SIG_In |
91        bcinfo::MD_SIG_Out |
92        bcinfo::MD_SIG_X |
93        bcinfo::MD_SIG_Y |
94        bcinfo::MD_SIG_Z |
95        bcinfo::MD_SIG_Kernel;
96
97int getFusedFuncSig(const std::vector<Source*>& sources,
98                    const std::vector<int>& slots,
99                    uint32_t* retSig) {
100  *retSig = 0;
101  uint32_t firstSignature = 0;
102  uint32_t signature = 0;
103  auto slotIter = slots.begin();
104  for (const Source* source : sources) {
105    const int slot = *slotIter++;
106    bcinfo::MetadataExtractor metadata(&source->getModule());
107    metadata.extract();
108
109    if (metadata.getExportForEachInputCountList()[slot] > 1) {
110      ALOGE("Kernel fusion (module %s slot %d): cannot handle multiple inputs",
111            source->getName().c_str(), slot);
112      return -1;
113    }
114
115    signature = metadata.getExportForEachSignatureList()[slot];
116    if (signature & ~ExpectedSignatureBits) {
117      ALOGE("Kernel fusion (module %s slot %d): Unexpected signature %x",
118            source->getName().c_str(), slot, signature);
119      return -1;
120    }
121
122    if (firstSignature == 0) {
123      firstSignature = signature;
124    }
125
126    *retSig |= signature;
127  }
128
129  if (!bcinfo::MetadataExtractor::hasForEachSignatureIn(firstSignature)) {
130    *retSig &= ~bcinfo::MD_SIG_In;
131  }
132
133  if (!bcinfo::MetadataExtractor::hasForEachSignatureOut(signature)) {
134    *retSig &= ~bcinfo::MD_SIG_Out;
135  }
136
137  return 0;
138}
139
140llvm::FunctionType* getFusedFuncType(bcc::BCCContext& Context,
141                                     const std::vector<Source*>& sources,
142                                     const std::vector<int>& slots,
143                                     Module* M,
144                                     uint32_t* signature) {
145  int error = getFusedFuncSig(sources, slots, signature);
146
147  if (error < 0) {
148    return nullptr;
149  }
150
151  const Function* firstF = getFunction(M, sources.front(), slots.front(), nullptr);
152
153  bccAssert (firstF != nullptr);
154
155  llvm::SmallVector<llvm::Type*, 8> ArgTys;
156
157  if (bcinfo::MetadataExtractor::hasForEachSignatureIn(*signature)) {
158    ArgTys.push_back(firstF->arg_begin()->getType());
159  }
160
161  llvm::Type* I32Ty = llvm::IntegerType::get(Context.getLLVMContext(), 32);
162  if (bcinfo::MetadataExtractor::hasForEachSignatureX(*signature)) {
163    ArgTys.push_back(I32Ty);
164  }
165  if (bcinfo::MetadataExtractor::hasForEachSignatureY(*signature)) {
166    ArgTys.push_back(I32Ty);
167  }
168  if (bcinfo::MetadataExtractor::hasForEachSignatureZ(*signature)) {
169    ArgTys.push_back(I32Ty);
170  }
171
172  const Function* lastF = getFunction(M, sources.back(), slots.back(), nullptr);
173
174  bccAssert (lastF != nullptr);
175
176  llvm::Type* retTy = lastF->getReturnType();
177
178  return llvm::FunctionType::get(retTy, ArgTys, false);
179}
180
181}  // anonymous namespace
182
183bool fuseKernels(bcc::BCCContext& Context,
184                 const std::vector<Source *>& sources,
185                 const std::vector<int>& slots,
186                 const std::string& fusedName,
187                 Module* mergedModule) {
188  bccAssert(sources.size() == slots.size() && "sources and slots differ in size");
189
190  uint32_t fusedFunctionSignature;
191
192  llvm::FunctionType* fusedType =
193          getFusedFuncType(Context, sources, slots, mergedModule, &fusedFunctionSignature);
194
195  if (fusedType == nullptr) {
196    return false;
197  }
198
199  Function* fusedKernel =
200          (Function*)(mergedModule->getOrInsertFunction(fusedName, fusedType));
201
202  llvm::LLVMContext& ctxt = Context.getLLVMContext();
203
204  llvm::BasicBlock* block = llvm::BasicBlock::Create(ctxt, "entry", fusedKernel);
205  llvm::IRBuilder<> builder(block);
206
207  Function::arg_iterator argIter = fusedKernel->arg_begin();
208
209  llvm::Value* dataElement = nullptr;
210  if (bcinfo::MetadataExtractor::hasForEachSignatureIn(fusedFunctionSignature)) {
211    dataElement = argIter++;
212    dataElement->setName("DataIn");
213  }
214
215  llvm::Value* X = nullptr;
216  if (bcinfo::MetadataExtractor::hasForEachSignatureX(fusedFunctionSignature)) {
217    X = argIter++;
218    X->setName("x");
219  }
220
221  llvm::Value* Y = nullptr;
222  if (bcinfo::MetadataExtractor::hasForEachSignatureY(fusedFunctionSignature)) {
223    Y = argIter++;
224    Y->setName("y");
225  }
226
227  llvm::Value* Z = nullptr;
228  if (bcinfo::MetadataExtractor::hasForEachSignatureZ(fusedFunctionSignature)) {
229    Z = argIter++;
230    Z->setName("z");
231  }
232
233  auto slotIter = slots.begin();
234  for (const Source* source : sources) {
235    int slot = *slotIter;
236
237    uint32_t inputFunctionSignature;
238    const Function* inputFunction =
239            getFunction(mergedModule, source, slot, &inputFunctionSignature);
240    if (inputFunction == nullptr) {
241      // Either failed to find the kernel function, or the function has multiple inputs.
242      return false;
243    }
244
245    // Don't try to fuse a non-kernel
246    if (!bcinfo::MetadataExtractor::hasForEachSignatureKernel(inputFunctionSignature)) {
247      ALOGE("Kernel fusion (module %s function %s): not a kernel",
248            source->getName().c_str(), inputFunction->getName().str().c_str());
249      return false;
250    }
251
252    std::vector<llvm::Value*> args;
253
254    if (bcinfo::MetadataExtractor::hasForEachSignatureIn(inputFunctionSignature)) {
255      if (dataElement == nullptr) {
256        ALOGE("Kernel fusion (module %s function %s): expected input, but got null",
257              source->getName().c_str(), inputFunction->getName().str().c_str());
258        return false;
259      }
260
261      const llvm::FunctionType* funcTy = inputFunction->getFunctionType();
262      llvm::Type* firstArgType = funcTy->getParamType(0);
263
264      if (dataElement->getType() != firstArgType) {
265        std::string msg;
266        llvm::raw_string_ostream rso(msg);
267        rso << "Mismatching argument type, expected ";
268        firstArgType->print(rso);
269        rso << ", received ";
270        dataElement->getType()->print(rso);
271        ALOGE("Kernel fusion (module %s function %s): %s", source->getName().c_str(),
272              inputFunction->getName().str().c_str(), rso.str().c_str());
273        return false;
274      }
275
276      args.push_back(dataElement);
277    } else {
278      // Only the first kernel in a batch is allowed to have no input
279      if (slotIter != slots.begin()) {
280        ALOGE("Kernel fusion (module %s function %s): function not first in batch takes no input",
281              source->getName().c_str(), inputFunction->getName().str().c_str());
282        return false;
283      }
284    }
285
286    if (bcinfo::MetadataExtractor::hasForEachSignatureX(inputFunctionSignature)) {
287      args.push_back(X);
288    }
289
290    if (bcinfo::MetadataExtractor::hasForEachSignatureY(inputFunctionSignature)) {
291      args.push_back(Y);
292    }
293
294    if (bcinfo::MetadataExtractor::hasForEachSignatureZ(inputFunctionSignature)) {
295      args.push_back(Z);
296    }
297
298    dataElement = builder.CreateCall((llvm::Value*)inputFunction, args);
299
300    slotIter++;
301  }
302
303  if (fusedKernel->getReturnType()->isVoidTy()) {
304    builder.CreateRetVoid();
305  } else {
306    builder.CreateRet(dataElement);
307  }
308
309  llvm::NamedMDNode* ExportForEachNameMD =
310    mergedModule->getOrInsertNamedMetadata("#rs_export_foreach_name");
311
312  llvm::MDString* nameMDStr = llvm::MDString::get(ctxt, fusedName);
313  llvm::MDNode* nameMDNode = llvm::MDNode::get(ctxt, nameMDStr);
314  ExportForEachNameMD->addOperand(nameMDNode);
315
316  llvm::NamedMDNode* ExportForEachMD =
317    mergedModule->getOrInsertNamedMetadata("#rs_export_foreach");
318  llvm::MDString* sigMDStr = llvm::MDString::get(ctxt,
319                                                 llvm::utostr_32(fusedFunctionSignature));
320  llvm::MDNode* sigMDNode = llvm::MDNode::get(ctxt, sigMDStr);
321  ExportForEachMD->addOperand(sigMDNode);
322
323  return true;
324}
325
326bool renameInvoke(BCCContext& Context, const Source* source, const int slot,
327                  const std::string& newName, Module* module) {
328  const llvm::Function* F = getInvokeFunction(*source, slot, module);
329  std::vector<llvm::Type*> params;
330  for (auto I = F->arg_begin(), E = F->arg_end(); I != E; ++I) {
331    params.push_back(I->getType());
332  }
333  llvm::Type* returnTy = F->getReturnType();
334
335  llvm::FunctionType* batchFuncTy =
336          llvm::FunctionType::get(returnTy, params, false);
337
338  llvm::Function* newF =
339          llvm::Function::Create(batchFuncTy,
340                                 llvm::GlobalValue::ExternalLinkage, newName,
341                                 module);
342
343  llvm::BasicBlock* block = llvm::BasicBlock::Create(Context.getLLVMContext(),
344                                                     "entry", newF);
345  llvm::IRBuilder<> builder(block);
346
347  llvm::Function::arg_iterator argIter = newF->arg_begin();
348  llvm::Value* arg1 = argIter++;
349  builder.CreateCall((llvm::Value*)F, arg1);
350
351  builder.CreateRetVoid();
352
353  llvm::NamedMDNode* ExportFuncNameMD =
354          module->getOrInsertNamedMetadata("#rs_export_func");
355  llvm::MDString* strMD = llvm::MDString::get(module->getContext(), newName);
356  llvm::MDNode* nodeMD = llvm::MDNode::get(module->getContext(), strMD);
357  ExportFuncNameMD->addOperand(nodeMD);
358
359  return true;
360}
361
362}  // namespace bcc
363