RSScriptGroupFusion.cpp revision 8c12d615b4ed4b1d782722a125dd1d43bc44a71b
1/*
2 * Copyright 2015, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "bcc/Renderscript/RSScriptGroupFusion.h"
18
19#include "bcc/Assert.h"
20#include "bcc/BCCContext.h"
21#include "bcc/Source.h"
22#include "bcc/Support/Log.h"
23#include "bcinfo/MetadataExtractor.h"
24#include "llvm/ADT/StringExtras.h"
25#include "llvm/IR/DataLayout.h"
26#include "llvm/IR/IRBuilder.h"
27#include "llvm/IR/Module.h"
28
29using llvm::Function;
30using llvm::Module;
31
32using std::string;
33
34namespace bcc {
35
36namespace {
37
38const Function* getInvokeFunction(const Source& source, const int slot,
39                                  Module* newModule) {
40  Module* module = const_cast<Module*>(&source.getModule());
41  bcinfo::MetadataExtractor metadata(module);
42  if (!metadata.extract()) {
43    return nullptr;
44  }
45  const char* functionName = metadata.getExportFuncNameList()[slot];
46  Function* func = newModule->getFunction(functionName);
47  // Materialize the function so that later the caller can inspect its argument
48  // and return types.
49  newModule->materialize(func);
50  return func;
51}
52
53const Function*
54getFunction(Module* mergedModule, const Source* source, const int slot,
55            uint32_t* signature) {
56  bcinfo::MetadataExtractor metadata(&source->getModule());
57  metadata.extract();
58
59  const char* functionName = metadata.getExportForEachNameList()[slot];
60  if (functionName == nullptr || !functionName[0]) {
61    return nullptr;
62  }
63
64  if (metadata.getExportForEachInputCountList()[slot] > 1) {
65    // TODO: Handle multiple inputs.
66    ALOGW("Kernel %s has multiple inputs", functionName);
67    return nullptr;
68  }
69
70  if (signature != nullptr) {
71    *signature = metadata.getExportForEachSignatureList()[slot];
72  }
73
74  const Function* function = mergedModule->getFunction(functionName);
75
76  return function;
77}
78
79// The whitelist of supported signature bits. Context or user data arguments are
80// not currently supported in kernel fusion. To support them or any new kinds of
81// arguments in the future, it requires not only listing the signature bits here,
82// but also implementing additional necessary fusion logic in the getFusedFuncSig(),
83// getFusedFuncType(), and fuseKernels() functions below.
84constexpr uint32_t ExpectedSignatureBits =
85        bcinfo::MD_SIG_In |
86        bcinfo::MD_SIG_Out |
87        bcinfo::MD_SIG_X |
88        bcinfo::MD_SIG_Y |
89        bcinfo::MD_SIG_Z |
90        bcinfo::MD_SIG_Kernel;
91
92int getFusedFuncSig(const std::vector<Source*>& sources,
93                    const std::vector<int>& slots,
94                    uint32_t* retSig) {
95  *retSig = 0;
96  uint32_t firstSignature = 0;
97  uint32_t signature = 0;
98  auto slotIter = slots.begin();
99  for (const Source* source : sources) {
100    const int slot = *slotIter++;
101    bcinfo::MetadataExtractor metadata(&source->getModule());
102    metadata.extract();
103
104    if (metadata.getExportForEachInputCountList()[slot] > 1) {
105      // TODO: Handle multiple inputs in kernel fusion.
106      ALOGW("Kernel %d in source %p has multiple inputs", slot, source);
107      return -1;
108    }
109
110    signature = metadata.getExportForEachSignatureList()[slot];
111    if (signature & ~ExpectedSignatureBits) {
112      ALOGW("Unexpected signature %x seen while fusing kernels", signature);
113      return -1;
114    }
115
116    if (firstSignature == 0) {
117      firstSignature = signature;
118    }
119
120    *retSig |= signature;
121  }
122
123  if (!bcinfo::MetadataExtractor::hasForEachSignatureIn(firstSignature)) {
124    *retSig &= ~bcinfo::MD_SIG_In;
125  }
126
127  if (!bcinfo::MetadataExtractor::hasForEachSignatureOut(signature)) {
128    *retSig &= ~bcinfo::MD_SIG_Out;
129  }
130
131  return 0;
132}
133
134llvm::FunctionType* getFusedFuncType(bcc::BCCContext& Context,
135                                     const std::vector<Source*>& sources,
136                                     const std::vector<int>& slots,
137                                     Module* M,
138                                     uint32_t* signature) {
139  int error = getFusedFuncSig(sources, slots, signature);
140
141  if (error < 0) {
142    return nullptr;
143  }
144
145  const Function* firstF = getFunction(M, sources.front(), slots.front(), nullptr);
146
147  bccAssert (firstF != nullptr);
148
149  llvm::SmallVector<llvm::Type*, 8> ArgTys;
150
151  if (bcinfo::MetadataExtractor::hasForEachSignatureIn(*signature)) {
152    ArgTys.push_back(firstF->arg_begin()->getType());
153  }
154
155  llvm::Type* I32Ty = llvm::IntegerType::get(Context.getLLVMContext(), 32);
156  if (bcinfo::MetadataExtractor::hasForEachSignatureX(*signature)) {
157    ArgTys.push_back(I32Ty);
158  }
159  if (bcinfo::MetadataExtractor::hasForEachSignatureY(*signature)) {
160    ArgTys.push_back(I32Ty);
161  }
162  if (bcinfo::MetadataExtractor::hasForEachSignatureZ(*signature)) {
163    ArgTys.push_back(I32Ty);
164  }
165
166  const Function* lastF = getFunction(M, sources.back(), slots.back(), nullptr);
167
168  bccAssert (lastF != nullptr);
169
170  llvm::Type* retTy = lastF->getReturnType();
171
172  return llvm::FunctionType::get(retTy, ArgTys, false);
173}
174
175}  // anonymous namespace
176
177bool fuseKernels(bcc::BCCContext& Context,
178                 const std::vector<Source *>& sources,
179                 const std::vector<int>& slots,
180                 const std::string& fusedName,
181                 Module* mergedModule) {
182  bccAssert(sources.size() == slots.size() && "sources and slots differ in size");
183
184  uint32_t fusedFunctionSignature;
185
186  llvm::FunctionType* fusedType =
187          getFusedFuncType(Context, sources, slots, mergedModule, &fusedFunctionSignature);
188
189  if (fusedType == nullptr) {
190    return false;
191  }
192
193  Function* fusedKernel =
194          (Function*)(mergedModule->getOrInsertFunction(fusedName, fusedType));
195
196  llvm::LLVMContext& ctxt = Context.getLLVMContext();
197
198  llvm::BasicBlock* block = llvm::BasicBlock::Create(ctxt, "entry", fusedKernel);
199  llvm::IRBuilder<> builder(block);
200
201  Function::arg_iterator argIter = fusedKernel->arg_begin();
202
203  llvm::Value* dataElement = nullptr;
204  if (bcinfo::MetadataExtractor::hasForEachSignatureIn(fusedFunctionSignature)) {
205    dataElement = argIter++;
206    dataElement->setName("DataIn");
207  }
208
209  llvm::Value* X = nullptr;
210  if (bcinfo::MetadataExtractor::hasForEachSignatureX(fusedFunctionSignature)) {
211    X = argIter++;
212    X->setName("x");
213  }
214
215  llvm::Value* Y = nullptr;
216  if (bcinfo::MetadataExtractor::hasForEachSignatureY(fusedFunctionSignature)) {
217    Y = argIter++;
218    Y->setName("y");
219  }
220
221  llvm::Value* Z = nullptr;
222  if (bcinfo::MetadataExtractor::hasForEachSignatureZ(fusedFunctionSignature)) {
223    Z = argIter++;
224    Z->setName("z");
225  }
226
227  auto slotIter = slots.begin();
228  for (const Source* source : sources) {
229    int slot = *slotIter++;
230
231    uint32_t inputFunctionSignature;
232    const Function* inputFunction =
233            getFunction(mergedModule, source, slot, &inputFunctionSignature);
234    if (inputFunction == nullptr) {
235      return false;
236    }
237
238    // Don't try to fuse a non-kernel
239    if (!bcinfo::MetadataExtractor::hasForEachSignatureKernel(inputFunctionSignature)) {
240      return false;
241    }
242
243    std::vector<llvm::Value*> args;
244
245    if (bcinfo::MetadataExtractor::hasForEachSignatureIn(inputFunctionSignature)) {
246      if (dataElement == nullptr) {
247        return false;
248      }
249
250      const llvm::FunctionType* funcTy = inputFunction->getFunctionType();
251      llvm::Type* firstArgType = funcTy->getParamType(0);
252
253      if (!dataElement->getType()->canLosslesslyBitCastTo(firstArgType)) {
254        return false;
255      }
256
257      args.push_back(dataElement);
258    } else {
259      // Only the first kernel in a batch is allowed to have no input
260      if (slotIter != slots.begin()) {
261        return false;
262      }
263    }
264
265    if (bcinfo::MetadataExtractor::hasForEachSignatureX(inputFunctionSignature)) {
266      args.push_back(X);
267    }
268
269    if (bcinfo::MetadataExtractor::hasForEachSignatureY(inputFunctionSignature)) {
270      args.push_back(Y);
271    }
272
273    if (bcinfo::MetadataExtractor::hasForEachSignatureZ(inputFunctionSignature)) {
274      args.push_back(Z);
275    }
276
277    dataElement = builder.CreateCall((llvm::Value*)inputFunction, args);
278  }
279
280  if (fusedKernel->getReturnType()->isVoidTy()) {
281    builder.CreateRetVoid();
282  } else {
283    builder.CreateRet(dataElement);
284  }
285
286  llvm::NamedMDNode* ExportForEachNameMD =
287    mergedModule->getOrInsertNamedMetadata("#rs_export_foreach_name");
288
289  llvm::MDString* nameMDStr = llvm::MDString::get(ctxt, fusedName);
290  llvm::MDNode* nameMDNode = llvm::MDNode::get(ctxt, nameMDStr);
291  ExportForEachNameMD->addOperand(nameMDNode);
292
293  llvm::NamedMDNode* ExportForEachMD =
294    mergedModule->getOrInsertNamedMetadata("#rs_export_foreach");
295  llvm::MDString* sigMDStr = llvm::MDString::get(ctxt,
296                                                 llvm::utostr_32(fusedFunctionSignature));
297  llvm::MDNode* sigMDNode = llvm::MDNode::get(ctxt, sigMDStr);
298  ExportForEachMD->addOperand(sigMDNode);
299
300  return true;
301}
302
303bool renameInvoke(BCCContext& Context, const Source* source, const int slot,
304                  const std::string& newName, Module* module) {
305  const llvm::Function* F = getInvokeFunction(*source, slot, module);
306  std::vector<llvm::Type*> params;
307  for (auto I = F->arg_begin(), E = F->arg_end(); I != E; ++I) {
308    params.push_back(I->getType());
309  }
310  llvm::Type* returnTy = F->getReturnType();
311
312  llvm::FunctionType* batchFuncTy =
313          llvm::FunctionType::get(returnTy, params, false);
314
315  llvm::Function* newF =
316          llvm::Function::Create(batchFuncTy,
317                                 llvm::GlobalValue::ExternalLinkage, newName,
318                                 module);
319
320  llvm::BasicBlock* block = llvm::BasicBlock::Create(Context.getLLVMContext(),
321                                                     "entry", newF);
322  llvm::IRBuilder<> builder(block);
323
324  llvm::Function::arg_iterator argIter = newF->arg_begin();
325  llvm::Value* arg1 = argIter++;
326  builder.CreateCall((llvm::Value*)F, arg1);
327
328  builder.CreateRetVoid();
329
330  llvm::NamedMDNode* ExportFuncNameMD =
331          module->getOrInsertNamedMetadata("#rs_export_func");
332  llvm::MDString* strMD = llvm::MDString::get(module->getContext(), newName);
333  llvm::MDNode* nodeMD = llvm::MDNode::get(module->getContext(), strMD);
334  ExportFuncNameMD->addOperand(nodeMD);
335
336  return true;
337}
338
339}  // namespace bcc
340