1/*
2 * Copyright 2015, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "bcc/Renderscript/RSScriptGroupFusion.h"
18
19#include "bcc/Assert.h"
20#include "bcc/BCCContext.h"
21#include "bcc/Source.h"
22#include "bcc/Support/Log.h"
23#include "bcinfo/MetadataExtractor.h"
24#include "llvm/ADT/StringExtras.h"
25#include "llvm/IR/DataLayout.h"
26#include "llvm/IR/IRBuilder.h"
27#include "llvm/IR/Module.h"
28#include "llvm/Support/raw_ostream.h"
29
30using llvm::Function;
31using llvm::Module;
32
33using std::string;
34
35namespace bcc {
36
37namespace {
38
39const Function* getInvokeFunction(const Source& source, const int slot,
40                                  Module* newModule) {
41
42  bcinfo::MetadataExtractor &metadata = *source.getMetadata();
43  const char* functionName = metadata.getExportFuncNameList()[slot];
44  Function* func = newModule->getFunction(functionName);
45  // Materialize the function so that later the caller can inspect its argument
46  // and return types.
47  newModule->materialize(func);
48  return func;
49}
50
51const Function*
52getFunction(Module* mergedModule, const Source* source, const int slot,
53            uint32_t* signature) {
54
55  bcinfo::MetadataExtractor &metadata = *source->getMetadata();
56  const char* functionName = metadata.getExportForEachNameList()[slot];
57  if (functionName == nullptr || !functionName[0]) {
58    ALOGE("Kernel fusion (module %s slot %d): failed to find kernel function",
59          source->getName().c_str(), slot);
60    return nullptr;
61  }
62
63  if (metadata.getExportForEachInputCountList()[slot] > 1) {
64    ALOGE("Kernel fusion (module %s function %s): cannot handle multiple inputs",
65          source->getName().c_str(), functionName);
66    return nullptr;
67  }
68
69  if (signature != nullptr) {
70    *signature = metadata.getExportForEachSignatureList()[slot];
71  }
72
73  const Function* function = mergedModule->getFunction(functionName);
74
75  return function;
76}
77
78// The whitelist of supported signature bits. Context or user data arguments are
79// not currently supported in kernel fusion. To support them or any new kinds of
80// arguments in the future, it requires not only listing the signature bits here,
81// but also implementing additional necessary fusion logic in the getFusedFuncSig(),
82// getFusedFuncType(), and fuseKernels() functions below.
83constexpr uint32_t ExpectedSignatureBits =
84        bcinfo::MD_SIG_In |
85        bcinfo::MD_SIG_Out |
86        bcinfo::MD_SIG_X |
87        bcinfo::MD_SIG_Y |
88        bcinfo::MD_SIG_Z |
89        bcinfo::MD_SIG_Kernel;
90
91int getFusedFuncSig(const std::vector<Source*>& sources,
92                    const std::vector<int>& slots,
93                    uint32_t* retSig) {
94  *retSig = 0;
95  uint32_t firstSignature = 0;
96  uint32_t signature = 0;
97  auto slotIter = slots.begin();
98  for (const Source* source : sources) {
99    const int slot = *slotIter++;
100    bcinfo::MetadataExtractor &metadata = *source->getMetadata();
101
102    if (metadata.getExportForEachInputCountList()[slot] > 1) {
103      ALOGE("Kernel fusion (module %s slot %d): cannot handle multiple inputs",
104            source->getName().c_str(), slot);
105      return -1;
106    }
107
108    signature = metadata.getExportForEachSignatureList()[slot];
109    if (signature & ~ExpectedSignatureBits) {
110      ALOGE("Kernel fusion (module %s slot %d): Unexpected signature %x",
111            source->getName().c_str(), slot, signature);
112      return -1;
113    }
114
115    if (firstSignature == 0) {
116      firstSignature = signature;
117    }
118
119    *retSig |= signature;
120  }
121
122  if (!bcinfo::MetadataExtractor::hasForEachSignatureIn(firstSignature)) {
123    *retSig &= ~bcinfo::MD_SIG_In;
124  }
125
126  if (!bcinfo::MetadataExtractor::hasForEachSignatureOut(signature)) {
127    *retSig &= ~bcinfo::MD_SIG_Out;
128  }
129
130  return 0;
131}
132
133llvm::FunctionType* getFusedFuncType(bcc::BCCContext& Context,
134                                     const std::vector<Source*>& sources,
135                                     const std::vector<int>& slots,
136                                     Module* M,
137                                     uint32_t* signature) {
138  int error = getFusedFuncSig(sources, slots, signature);
139
140  if (error < 0) {
141    return nullptr;
142  }
143
144  const Function* firstF = getFunction(M, sources.front(), slots.front(), nullptr);
145
146  bccAssert (firstF != nullptr);
147
148  llvm::SmallVector<llvm::Type*, 8> ArgTys;
149
150  if (bcinfo::MetadataExtractor::hasForEachSignatureIn(*signature)) {
151    ArgTys.push_back(firstF->arg_begin()->getType());
152  }
153
154  llvm::Type* I32Ty = llvm::IntegerType::get(Context.getLLVMContext(), 32);
155  if (bcinfo::MetadataExtractor::hasForEachSignatureX(*signature)) {
156    ArgTys.push_back(I32Ty);
157  }
158  if (bcinfo::MetadataExtractor::hasForEachSignatureY(*signature)) {
159    ArgTys.push_back(I32Ty);
160  }
161  if (bcinfo::MetadataExtractor::hasForEachSignatureZ(*signature)) {
162    ArgTys.push_back(I32Ty);
163  }
164
165  const Function* lastF = getFunction(M, sources.back(), slots.back(), nullptr);
166
167  bccAssert (lastF != nullptr);
168
169  llvm::Type* retTy = lastF->getReturnType();
170
171  return llvm::FunctionType::get(retTy, ArgTys, false);
172}
173
174}  // anonymous namespace
175
176bool fuseKernels(bcc::BCCContext& Context,
177                 const std::vector<Source *>& sources,
178                 const std::vector<int>& slots,
179                 const std::string& fusedName,
180                 Module* mergedModule) {
181  bccAssert(sources.size() == slots.size() && "sources and slots differ in size");
182
183  uint32_t fusedFunctionSignature;
184
185  llvm::FunctionType* fusedType =
186          getFusedFuncType(Context, sources, slots, mergedModule, &fusedFunctionSignature);
187
188  if (fusedType == nullptr) {
189    return false;
190  }
191
192  Function* fusedKernel =
193          (Function*)(mergedModule->getOrInsertFunction(fusedName, fusedType));
194
195  llvm::LLVMContext& ctxt = Context.getLLVMContext();
196
197  llvm::BasicBlock* block = llvm::BasicBlock::Create(ctxt, "entry", fusedKernel);
198  llvm::IRBuilder<> builder(block);
199
200  Function::arg_iterator argIter = fusedKernel->arg_begin();
201
202  llvm::Value* dataElement = nullptr;
203  if (bcinfo::MetadataExtractor::hasForEachSignatureIn(fusedFunctionSignature)) {
204    dataElement = &*(argIter++);
205    dataElement->setName("DataIn");
206  }
207
208  llvm::Value* X = nullptr;
209  if (bcinfo::MetadataExtractor::hasForEachSignatureX(fusedFunctionSignature)) {
210    X = &*(argIter++);
211    X->setName("x");
212  }
213
214  llvm::Value* Y = nullptr;
215  if (bcinfo::MetadataExtractor::hasForEachSignatureY(fusedFunctionSignature)) {
216    Y = &*(argIter++);
217    Y->setName("y");
218  }
219
220  llvm::Value* Z = nullptr;
221  if (bcinfo::MetadataExtractor::hasForEachSignatureZ(fusedFunctionSignature)) {
222    Z = &*(argIter++);
223    Z->setName("z");
224  }
225
226  auto slotIter = slots.begin();
227  for (const Source* source : sources) {
228    int slot = *slotIter;
229
230    uint32_t inputFunctionSignature;
231    const Function* inputFunction =
232            getFunction(mergedModule, source, slot, &inputFunctionSignature);
233    if (inputFunction == nullptr) {
234      // Either failed to find the kernel function, or the function has multiple inputs.
235      return false;
236    }
237
238    // Don't try to fuse a non-kernel
239    if (!bcinfo::MetadataExtractor::hasForEachSignatureKernel(inputFunctionSignature)) {
240      ALOGE("Kernel fusion (module %s function %s): not a kernel",
241            source->getName().c_str(), inputFunction->getName().str().c_str());
242      return false;
243    }
244
245    std::vector<llvm::Value*> args;
246
247    if (bcinfo::MetadataExtractor::hasForEachSignatureIn(inputFunctionSignature)) {
248      if (dataElement == nullptr) {
249        ALOGE("Kernel fusion (module %s function %s): expected input, but got null",
250              source->getName().c_str(), inputFunction->getName().str().c_str());
251        return false;
252      }
253
254      const llvm::FunctionType* funcTy = inputFunction->getFunctionType();
255      llvm::Type* firstArgType = funcTy->getParamType(0);
256
257      if (dataElement->getType() != firstArgType) {
258        std::string msg;
259        llvm::raw_string_ostream rso(msg);
260        rso << "Mismatching argument type, expected ";
261        firstArgType->print(rso);
262        rso << ", received ";
263        dataElement->getType()->print(rso);
264        ALOGE("Kernel fusion (module %s function %s): %s", source->getName().c_str(),
265              inputFunction->getName().str().c_str(), rso.str().c_str());
266        return false;
267      }
268
269      args.push_back(dataElement);
270    } else {
271      // Only the first kernel in a batch is allowed to have no input
272      if (slotIter != slots.begin()) {
273        ALOGE("Kernel fusion (module %s function %s): function not first in batch takes no input",
274              source->getName().c_str(), inputFunction->getName().str().c_str());
275        return false;
276      }
277    }
278
279    if (bcinfo::MetadataExtractor::hasForEachSignatureX(inputFunctionSignature)) {
280      args.push_back(X);
281    }
282
283    if (bcinfo::MetadataExtractor::hasForEachSignatureY(inputFunctionSignature)) {
284      args.push_back(Y);
285    }
286
287    if (bcinfo::MetadataExtractor::hasForEachSignatureZ(inputFunctionSignature)) {
288      args.push_back(Z);
289    }
290
291    dataElement = builder.CreateCall((llvm::Value*)inputFunction, args);
292
293    slotIter++;
294  }
295
296  if (fusedKernel->getReturnType()->isVoidTy()) {
297    builder.CreateRetVoid();
298  } else {
299    builder.CreateRet(dataElement);
300  }
301
302  llvm::NamedMDNode* ExportForEachNameMD =
303    mergedModule->getOrInsertNamedMetadata("#rs_export_foreach_name");
304
305  llvm::MDString* nameMDStr = llvm::MDString::get(ctxt, fusedName);
306  llvm::MDNode* nameMDNode = llvm::MDNode::get(ctxt, nameMDStr);
307  ExportForEachNameMD->addOperand(nameMDNode);
308
309  llvm::NamedMDNode* ExportForEachMD =
310    mergedModule->getOrInsertNamedMetadata("#rs_export_foreach");
311  llvm::MDString* sigMDStr = llvm::MDString::get(ctxt,
312                                                 llvm::utostr_32(fusedFunctionSignature));
313  llvm::MDNode* sigMDNode = llvm::MDNode::get(ctxt, sigMDStr);
314  ExportForEachMD->addOperand(sigMDNode);
315
316  return true;
317}
318
319bool renameInvoke(BCCContext& Context, const Source* source, const int slot,
320                  const std::string& newName, Module* module) {
321  const llvm::Function* F = getInvokeFunction(*source, slot, module);
322  std::vector<llvm::Type*> params;
323  for (auto I = F->arg_begin(), E = F->arg_end(); I != E; ++I) {
324    params.push_back(I->getType());
325  }
326  llvm::Type* returnTy = F->getReturnType();
327
328  llvm::FunctionType* batchFuncTy =
329          llvm::FunctionType::get(returnTy, params, false);
330
331  llvm::Function* newF =
332          llvm::Function::Create(batchFuncTy,
333                                 llvm::GlobalValue::ExternalLinkage, newName,
334                                 module);
335
336  llvm::BasicBlock* block = llvm::BasicBlock::Create(Context.getLLVMContext(),
337                                                     "entry", newF);
338  llvm::IRBuilder<> builder(block);
339
340  llvm::Function::arg_iterator argIter = newF->arg_begin();
341  llvm::Value* arg1 = &*(argIter++);
342  builder.CreateCall((llvm::Value*)F, arg1);
343
344  builder.CreateRetVoid();
345
346  llvm::NamedMDNode* ExportFuncNameMD =
347          module->getOrInsertNamedMetadata("#rs_export_func");
348  llvm::MDString* strMD = llvm::MDString::get(module->getContext(), newName);
349  llvm::MDNode* nodeMD = llvm::MDNode::get(module->getContext(), strMD);
350  ExportFuncNameMD->addOperand(nodeMD);
351
352  return true;
353}
354
355}  // namespace bcc
356