RSForEachExpand.cpp revision 2b04086acbef6520ae2c54a868b1271abf053122
1/*
2 * Copyright 2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "bcc/Assert.h"
18#include "bcc/Renderscript/RSTransforms.h"
19
20#include <cstdlib>
21
22#include <llvm/DerivedTypes.h>
23#include <llvm/Function.h>
24#include <llvm/Instructions.h>
25#include <llvm/Module.h>
26#include <llvm/Pass.h>
27#include <llvm/Support/IRBuilder.h>
28#include <llvm/Target/TargetData.h>
29#include <llvm/Type.h>
30
31#include "bcc/Config/Config.h"
32#include "bcc/Renderscript/RSInfo.h"
33#include "bcc/Support/Log.h"
34
35using namespace bcc;
36
37namespace {
38
39/* RSForEachExpandPass - This pass operates on functions that are able to be
40 * called via rsForEach() or "foreach_<NAME>". We create an inner loop for the
41 * ForEach-able function to be invoked over the appropriate data cells of the
42 * input/output allocations (adjusting other relevant parameters as we go). We
43 * support doing this for any ForEach-able compute kernels. The new function
44 * name is the original function name followed by ".expand". Note that we
45 * still generate code for the original function.
46 */
47class RSForEachExpandPass : public llvm::ModulePass {
48private:
49  static char ID;
50
51  llvm::Module *M;
52  llvm::LLVMContext *C;
53
54  const RSInfo::ExportForeachFuncListTy &mFuncs;
55
56  // Turns on optimization of allocation stride values.
57  bool mEnableStepOpt;
58
59  uint32_t getRootSignature(llvm::Function *F) {
60    const llvm::NamedMDNode *ExportForEachMetadata =
61        M->getNamedMetadata("#rs_export_foreach");
62
63    if (!ExportForEachMetadata) {
64      llvm::SmallVector<llvm::Type*, 8> RootArgTys;
65      for (llvm::Function::arg_iterator B = F->arg_begin(),
66                                        E = F->arg_end();
67           B != E;
68           ++B) {
69        RootArgTys.push_back(B->getType());
70      }
71
72      // For pre-ICS bitcode, we may not have signature information. In that
73      // case, we use the size of the RootArgTys to select the number of
74      // arguments.
75      return (1 << RootArgTys.size()) - 1;
76    }
77
78    bccAssert(ExportForEachMetadata->getNumOperands() > 0);
79
80    // We only handle the case for legacy root() functions here, so this is
81    // hard-coded to look at only the first such function.
82    llvm::MDNode *SigNode = ExportForEachMetadata->getOperand(0);
83    if (SigNode != NULL && SigNode->getNumOperands() == 1) {
84      llvm::Value *SigVal = SigNode->getOperand(0);
85      if (SigVal->getValueID() == llvm::Value::MDStringVal) {
86        llvm::StringRef SigString =
87            static_cast<llvm::MDString*>(SigVal)->getString();
88        uint32_t Signature = 0;
89        if (SigString.getAsInteger(10, Signature)) {
90          ALOGE("Non-integer signature value '%s'", SigString.str().c_str());
91          return 0;
92        }
93        return Signature;
94      }
95    }
96
97    return 0;
98  }
99
100  // Get the actual value we should use to step through an allocation.
101  // TD - Target Data size/layout information.
102  // T - Type of allocation (should be a pointer).
103  // OrigStep - Original step increment (root.expand() input from driver).
104  llvm::Value *getStepValue(llvm::TargetData *TD, llvm::Type *T,
105                            llvm::Value *OrigStep) {
106    bccAssert(TD);
107    bccAssert(T);
108    bccAssert(OrigStep);
109    llvm::PointerType *PT = llvm::dyn_cast<llvm::PointerType>(T);
110    llvm::Type *VoidPtrTy = llvm::Type::getInt8PtrTy(*C);
111    if (mEnableStepOpt && T != VoidPtrTy && PT) {
112      llvm::Type *ET = PT->getElementType();
113      uint64_t ETSize = TD->getTypeStoreSize(ET);
114      llvm::Type *Int32Ty = llvm::Type::getInt32Ty(*C);
115      return llvm::ConstantInt::get(Int32Ty, ETSize);
116    } else {
117      return OrigStep;
118    }
119  }
120
121  static bool hasIn(uint32_t Signature) {
122    return Signature & 1;
123  }
124
125  static bool hasOut(uint32_t Signature) {
126    return Signature & 2;
127  }
128
129  static bool hasUsrData(uint32_t Signature) {
130    return Signature & 4;
131  }
132
133  static bool hasX(uint32_t Signature) {
134    return Signature & 8;
135  }
136
137  static bool hasY(uint32_t Signature) {
138    return Signature & 16;
139  }
140
141public:
142  RSForEachExpandPass(const RSInfo::ExportForeachFuncListTy &pForeachFuncs,
143                      bool pEnableStepOpt)
144      : ModulePass(ID), M(NULL), C(NULL), mFuncs(pForeachFuncs),
145        mEnableStepOpt(pEnableStepOpt) {
146  }
147
148  /* Performs the actual optimization on a selected function. On success, the
149   * Module will contain a new function of the name "<NAME>.expand" that
150   * invokes <NAME>() in a loop with the appropriate parameters.
151   */
152  bool ExpandFunction(llvm::Function *F, uint32_t Signature) {
153    ALOGV("Expanding ForEach-able Function %s", F->getName().str().c_str());
154
155    if (!Signature) {
156      Signature = getRootSignature(F);
157      if (!Signature) {
158        // We couldn't determine how to expand this function based on its
159        // function signature.
160        return false;
161      }
162    }
163
164    llvm::TargetData TD(M);
165
166    llvm::Type *VoidPtrTy = llvm::Type::getInt8PtrTy(*C);
167    llvm::Type *Int32Ty = llvm::Type::getInt32Ty(*C);
168    llvm::Type *SizeTy = Int32Ty;
169
170    /* Defined in frameworks/base/libs/rs/rs_hal.h:
171     *
172     * struct RsForEachStubParamStruct {
173     *   const void *in;
174     *   void *out;
175     *   const void *usr;
176     *   size_t usr_len;
177     *   uint32_t x;
178     *   uint32_t y;
179     *   uint32_t z;
180     *   uint32_t lod;
181     *   enum RsAllocationCubemapFace face;
182     *   uint32_t ar[16];
183     * };
184     */
185    llvm::SmallVector<llvm::Type*, 9> StructTys;
186    StructTys.push_back(VoidPtrTy);  // const void *in
187    StructTys.push_back(VoidPtrTy);  // void *out
188    StructTys.push_back(VoidPtrTy);  // const void *usr
189    StructTys.push_back(SizeTy);     // size_t usr_len
190    StructTys.push_back(Int32Ty);    // uint32_t x
191    StructTys.push_back(Int32Ty);    // uint32_t y
192    StructTys.push_back(Int32Ty);    // uint32_t z
193    StructTys.push_back(Int32Ty);    // uint32_t lod
194    StructTys.push_back(Int32Ty);    // enum RsAllocationCubemapFace
195    StructTys.push_back(llvm::ArrayType::get(Int32Ty, 16));  // uint32_t ar[16]
196
197    llvm::Type *ForEachStubPtrTy = llvm::StructType::create(
198        StructTys, "RsForEachStubParamStruct")->getPointerTo();
199
200    /* Create the function signature for our expanded function.
201     * void (const RsForEachStubParamStruct *p, uint32_t x1, uint32_t x2,
202     *       uint32_t instep, uint32_t outstep)
203     */
204    llvm::SmallVector<llvm::Type*, 8> ParamTys;
205    ParamTys.push_back(ForEachStubPtrTy);  // const RsForEachStubParamStruct *p
206    ParamTys.push_back(Int32Ty);           // uint32_t x1
207    ParamTys.push_back(Int32Ty);           // uint32_t x2
208    ParamTys.push_back(Int32Ty);           // uint32_t instep
209    ParamTys.push_back(Int32Ty);           // uint32_t outstep
210
211    llvm::FunctionType *FT =
212        llvm::FunctionType::get(llvm::Type::getVoidTy(*C), ParamTys, false);
213    llvm::Function *ExpandedFunc =
214        llvm::Function::Create(FT,
215                               llvm::GlobalValue::ExternalLinkage,
216                               F->getName() + ".expand", M);
217
218    // Create and name the actual arguments to this expanded function.
219    llvm::SmallVector<llvm::Argument*, 8> ArgVec;
220    for (llvm::Function::arg_iterator B = ExpandedFunc->arg_begin(),
221                                      E = ExpandedFunc->arg_end();
222         B != E;
223         ++B) {
224      ArgVec.push_back(B);
225    }
226
227    if (ArgVec.size() != 5) {
228      ALOGE("Incorrect number of arguments to function: %zu",
229            ArgVec.size());
230      return false;
231    }
232    llvm::Value *Arg_p = ArgVec[0];
233    llvm::Value *Arg_x1 = ArgVec[1];
234    llvm::Value *Arg_x2 = ArgVec[2];
235    llvm::Value *Arg_instep = ArgVec[3];
236    llvm::Value *Arg_outstep = ArgVec[4];
237
238    Arg_p->setName("p");
239    Arg_x1->setName("x1");
240    Arg_x2->setName("x2");
241    Arg_instep->setName("arg_instep");
242    Arg_outstep->setName("arg_outstep");
243
244    llvm::Value *InStep = NULL;
245    llvm::Value *OutStep = NULL;
246
247    // Construct the actual function body.
248    llvm::BasicBlock *Begin =
249        llvm::BasicBlock::Create(*C, "Begin", ExpandedFunc);
250    llvm::IRBuilder<> Builder(Begin);
251
252    // uint32_t X = x1;
253    llvm::AllocaInst *AX = Builder.CreateAlloca(Int32Ty, 0, "AX");
254    Builder.CreateStore(Arg_x1, AX);
255
256    // Collect and construct the arguments for the kernel().
257    // Note that we load any loop-invariant arguments before entering the Loop.
258    llvm::Function::arg_iterator Args = F->arg_begin();
259
260    llvm::Type *InTy = NULL;
261    llvm::AllocaInst *AIn = NULL;
262    if (hasIn(Signature)) {
263      InTy = Args->getType();
264      AIn = Builder.CreateAlloca(InTy, 0, "AIn");
265      InStep = getStepValue(&TD, InTy, Arg_instep);
266      InStep->setName("instep");
267      Builder.CreateStore(Builder.CreatePointerCast(Builder.CreateLoad(
268          Builder.CreateStructGEP(Arg_p, 0)), InTy), AIn);
269      Args++;
270    }
271
272    llvm::Type *OutTy = NULL;
273    llvm::AllocaInst *AOut = NULL;
274    if (hasOut(Signature)) {
275      OutTy = Args->getType();
276      AOut = Builder.CreateAlloca(OutTy, 0, "AOut");
277      OutStep = getStepValue(&TD, OutTy, Arg_outstep);
278      OutStep->setName("outstep");
279      Builder.CreateStore(Builder.CreatePointerCast(Builder.CreateLoad(
280          Builder.CreateStructGEP(Arg_p, 1)), OutTy), AOut);
281      Args++;
282    }
283
284    llvm::Value *UsrData = NULL;
285    if (hasUsrData(Signature)) {
286      llvm::Type *UsrDataTy = Args->getType();
287      UsrData = Builder.CreatePointerCast(Builder.CreateLoad(
288          Builder.CreateStructGEP(Arg_p, 2)), UsrDataTy);
289      UsrData->setName("UsrData");
290      Args++;
291    }
292
293    if (hasX(Signature)) {
294      Args++;
295    }
296
297    llvm::Value *Y = NULL;
298    if (hasY(Signature)) {
299      Y = Builder.CreateLoad(Builder.CreateStructGEP(Arg_p, 5), "Y");
300      Args++;
301    }
302
303    bccAssert(Args == F->arg_end());
304
305    llvm::BasicBlock *Loop = llvm::BasicBlock::Create(*C, "Loop", ExpandedFunc);
306    llvm::BasicBlock *Exit = llvm::BasicBlock::Create(*C, "Exit", ExpandedFunc);
307
308    // if (x1 < x2) goto Loop; else goto Exit;
309    llvm::Value *Cond = Builder.CreateICmpSLT(Arg_x1, Arg_x2);
310    Builder.CreateCondBr(Cond, Loop, Exit);
311
312    // Loop:
313    Builder.SetInsertPoint(Loop);
314
315    // Populate the actual call to kernel().
316    llvm::SmallVector<llvm::Value*, 8> RootArgs;
317
318    llvm::Value *In = NULL;
319    llvm::Value *Out = NULL;
320
321    if (AIn) {
322      In = Builder.CreateLoad(AIn, "In");
323      RootArgs.push_back(In);
324    }
325
326    if (AOut) {
327      Out = Builder.CreateLoad(AOut, "Out");
328      RootArgs.push_back(Out);
329    }
330
331    if (UsrData) {
332      RootArgs.push_back(UsrData);
333    }
334
335    // We always have to load X, since it is used to iterate through the loop.
336    llvm::Value *X = Builder.CreateLoad(AX, "X");
337    if (hasX(Signature)) {
338      RootArgs.push_back(X);
339    }
340
341    if (Y) {
342      RootArgs.push_back(Y);
343    }
344
345    Builder.CreateCall(F, RootArgs);
346
347    if (In) {
348      // In += instep
349      llvm::Value *NewIn = Builder.CreateIntToPtr(Builder.CreateNUWAdd(
350          Builder.CreatePtrToInt(In, Int32Ty), InStep), InTy);
351      Builder.CreateStore(NewIn, AIn);
352    }
353
354    if (Out) {
355      // Out += outstep
356      llvm::Value *NewOut = Builder.CreateIntToPtr(Builder.CreateNUWAdd(
357          Builder.CreatePtrToInt(Out, Int32Ty), OutStep), OutTy);
358      Builder.CreateStore(NewOut, AOut);
359    }
360
361    // X++;
362    llvm::Value *XPlusOne =
363        Builder.CreateNUWAdd(X, llvm::ConstantInt::get(Int32Ty, 1));
364    Builder.CreateStore(XPlusOne, AX);
365
366    // If (X < x2) goto Loop; else goto Exit;
367    Cond = Builder.CreateICmpSLT(XPlusOne, Arg_x2);
368    Builder.CreateCondBr(Cond, Loop, Exit);
369
370    // Exit:
371    Builder.SetInsertPoint(Exit);
372    Builder.CreateRetVoid();
373
374    return true;
375  }
376
377  virtual bool runOnModule(llvm::Module &M) {
378    bool Changed = false;
379    this->M = &M;
380    C = &M.getContext();
381
382    for (RSInfo::ExportForeachFuncListTy::const_iterator
383             func_iter = mFuncs.begin(), func_end = mFuncs.end();
384         func_iter != func_end; func_iter++) {
385      const char *name = func_iter->first;
386      uint32_t signature = func_iter->second;
387      llvm::Function *kernel = M.getFunction(name);
388      if (kernel && kernel->getReturnType()->isVoidTy()) {
389        Changed |= ExpandFunction(kernel, signature);
390      }
391    }
392
393    return Changed;
394  }
395
396  virtual const char *getPassName() const {
397    return "ForEach-able Function Expansion";
398  }
399
400}; // end RSForEachExpandPass
401
402} // end anonymous namespace
403
404char RSForEachExpandPass::ID = 0;
405
406namespace bcc {
407
408llvm::ModulePass *
409createRSForEachExpandPass(const RSInfo::ExportForeachFuncListTy &pForeachFuncs,
410                          bool pEnableStepOpt){
411  return new RSForEachExpandPass(pForeachFuncs, pEnableStepOpt);
412}
413
414} // end namespace bcc
415