RSForEachExpand.cpp revision e4a73f68e1b338881adf682c458e0b4b92ecd91e
1/*
2 * Copyright 2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "bcc/Assert.h"
18#include "bcc/Renderscript/RSTransforms.h"
19
20#include <cstdlib>
21
22#include <llvm/IR/DerivedTypes.h>
23#include <llvm/IR/Function.h>
24#include <llvm/IR/Instructions.h>
25#include <llvm/IR/IRBuilder.h>
26#include <llvm/IR/Module.h>
27#include <llvm/Pass.h>
28#include <llvm/Support/raw_ostream.h>
29#include <llvm/IR/DataLayout.h>
30#include <llvm/IR/Type.h>
31#include <llvm/Transforms/Utils/BasicBlockUtils.h>
32
33#include "bcc/Config/Config.h"
34#include "bcc/Renderscript/RSInfo.h"
35#include "bcc/Support/Log.h"
36
37using namespace bcc;
38
39namespace {
40
41/* RSForEachExpandPass - This pass operates on functions that are able to be
42 * called via rsForEach() or "foreach_<NAME>". We create an inner loop for the
43 * ForEach-able function to be invoked over the appropriate data cells of the
44 * input/output allocations (adjusting other relevant parameters as we go). We
45 * support doing this for any ForEach-able compute kernels. The new function
46 * name is the original function name followed by ".expand". Note that we
47 * still generate code for the original function.
48 */
49class RSForEachExpandPass : public llvm::ModulePass {
50private:
51  static char ID;
52
53  llvm::Module *M;
54  llvm::LLVMContext *C;
55
56  const RSInfo::ExportForeachFuncListTy &mFuncs;
57
58  // Turns on optimization of allocation stride values.
59  bool mEnableStepOpt;
60
61  uint32_t getRootSignature(llvm::Function *F) {
62    const llvm::NamedMDNode *ExportForEachMetadata =
63        M->getNamedMetadata("#rs_export_foreach");
64
65    if (!ExportForEachMetadata) {
66      llvm::SmallVector<llvm::Type*, 8> RootArgTys;
67      for (llvm::Function::arg_iterator B = F->arg_begin(),
68                                        E = F->arg_end();
69           B != E;
70           ++B) {
71        RootArgTys.push_back(B->getType());
72      }
73
74      // For pre-ICS bitcode, we may not have signature information. In that
75      // case, we use the size of the RootArgTys to select the number of
76      // arguments.
77      return (1 << RootArgTys.size()) - 1;
78    }
79
80    if (ExportForEachMetadata->getNumOperands() == 0) {
81      return 0;
82    }
83
84    bccAssert(ExportForEachMetadata->getNumOperands() > 0);
85
86    // We only handle the case for legacy root() functions here, so this is
87    // hard-coded to look at only the first such function.
88    llvm::MDNode *SigNode = ExportForEachMetadata->getOperand(0);
89    if (SigNode != NULL && SigNode->getNumOperands() == 1) {
90      llvm::Value *SigVal = SigNode->getOperand(0);
91      if (SigVal->getValueID() == llvm::Value::MDStringVal) {
92        llvm::StringRef SigString =
93            static_cast<llvm::MDString*>(SigVal)->getString();
94        uint32_t Signature = 0;
95        if (SigString.getAsInteger(10, Signature)) {
96          ALOGE("Non-integer signature value '%s'", SigString.str().c_str());
97          return 0;
98        }
99        return Signature;
100      }
101    }
102
103    return 0;
104  }
105
106  // Get the actual value we should use to step through an allocation.
107  // DL - Target Data size/layout information.
108  // T - Type of allocation (should be a pointer).
109  // OrigStep - Original step increment (root.expand() input from driver).
110  llvm::Value *getStepValue(llvm::DataLayout *DL, llvm::Type *T,
111                            llvm::Value *OrigStep) {
112    bccAssert(DL);
113    bccAssert(T);
114    bccAssert(OrigStep);
115    llvm::PointerType *PT = llvm::dyn_cast<llvm::PointerType>(T);
116    llvm::Type *VoidPtrTy = llvm::Type::getInt8PtrTy(*C);
117    if (mEnableStepOpt && T != VoidPtrTy && PT) {
118      llvm::Type *ET = PT->getElementType();
119      uint64_t ETSize = DL->getTypeAllocSize(ET);
120      llvm::Type *Int32Ty = llvm::Type::getInt32Ty(*C);
121      return llvm::ConstantInt::get(Int32Ty, ETSize);
122    } else {
123      return OrigStep;
124    }
125  }
126
127  static bool hasIn(uint32_t Signature) {
128    return Signature & 0x01;
129  }
130
131  static bool hasOut(uint32_t Signature) {
132    return Signature & 0x02;
133  }
134
135  static bool hasUsrData(uint32_t Signature) {
136    return Signature & 0x04;
137  }
138
139  static bool hasX(uint32_t Signature) {
140    return Signature & 0x08;
141  }
142
143  static bool hasY(uint32_t Signature) {
144    return Signature & 0x10;
145  }
146
147  static bool isKernel(uint32_t Signature) {
148    return Signature & 0x20;
149  }
150
151  /// @brief Returns the type of the ForEach stub parameter structure.
152  ///
153  /// Renderscript uses a single structure in which all parameters are passed
154  /// to keep the signature of the expanded function independent of the
155  /// parameters passed to it.
156  llvm::Type *getForeachStubTy() {
157    llvm::Type *VoidPtrTy = llvm::Type::getInt8PtrTy(*C);
158    llvm::Type *Int32Ty = llvm::Type::getInt32Ty(*C);
159    llvm::Type *SizeTy = Int32Ty;
160    /* Defined in frameworks/base/libs/rs/rs_hal.h:
161     *
162     * struct RsForEachStubParamStruct {
163     *   const void *in;
164     *   void *out;
165     *   const void *usr;
166     *   size_t usr_len;
167     *   uint32_t x;
168     *   uint32_t y;
169     *   uint32_t z;
170     *   uint32_t lod;
171     *   enum RsAllocationCubemapFace face;
172     *   uint32_t ar[16];
173     * };
174     */
175    llvm::SmallVector<llvm::Type*, 9> StructTys;
176    StructTys.push_back(VoidPtrTy);  // const void *in
177    StructTys.push_back(VoidPtrTy);  // void *out
178    StructTys.push_back(VoidPtrTy);  // const void *usr
179    StructTys.push_back(SizeTy);     // size_t usr_len
180    StructTys.push_back(Int32Ty);    // uint32_t x
181    StructTys.push_back(Int32Ty);    // uint32_t y
182    StructTys.push_back(Int32Ty);    // uint32_t z
183    StructTys.push_back(Int32Ty);    // uint32_t lod
184    StructTys.push_back(Int32Ty);    // enum RsAllocationCubemapFace
185    StructTys.push_back(llvm::ArrayType::get(Int32Ty, 16));  // uint32_t ar[16]
186
187    return llvm::StructType::create(StructTys, "RsForEachStubParamStruct");
188  }
189
190  /// @brief Create skeleton of the expanded function.
191  ///
192  /// This creates a function with the following signature:
193  ///
194  ///   void (const RsForEachStubParamStruct *p, uint32_t x1, uint32_t x2,
195  ///         uint32_t instep, uint32_t outstep)
196  ///
197  llvm::Function *createEmptyExpandedFunction(llvm::StringRef OldName) {
198    llvm::Type *ForEachStubPtrTy = getForeachStubTy()->getPointerTo();
199    llvm::Type *Int32Ty = llvm::Type::getInt32Ty(*C);
200
201    llvm::SmallVector<llvm::Type*, 8> ParamTys;
202    ParamTys.push_back(ForEachStubPtrTy);  // const RsForEachStubParamStruct *p
203    ParamTys.push_back(Int32Ty);           // uint32_t x1
204    ParamTys.push_back(Int32Ty);           // uint32_t x2
205    ParamTys.push_back(Int32Ty);           // uint32_t instep
206    ParamTys.push_back(Int32Ty);           // uint32_t outstep
207
208    llvm::FunctionType *FT =
209        llvm::FunctionType::get(llvm::Type::getVoidTy(*C), ParamTys, false);
210    llvm::Function *F =
211        llvm::Function::Create(FT, llvm::GlobalValue::ExternalLinkage,
212                               OldName + ".expand", M);
213
214    llvm::Function::arg_iterator AI = F->arg_begin();
215
216    AI->setName("p");
217    AI++;
218    AI->setName("x1");
219    AI++;
220    AI->setName("x2");
221    AI++;
222    AI->setName("arg_instep");
223    AI++;
224    AI->setName("arg_outstep");
225    AI++;
226
227    assert(AI == F->arg_end());
228
229    llvm::BasicBlock *Begin = llvm::BasicBlock::Create(*C, "Begin", F);
230    llvm::IRBuilder<> Builder(Begin);
231    Builder.CreateRetVoid();
232
233    return F;
234  }
235
236  /// @brief Create an empty loop
237  ///
238  /// Create a loop of the form:
239  ///
240  /// for (i = LowerBound; i < UpperBound; i++)
241  ///   ;
242  ///
243  /// After the loop has been created, the builder is set such that
244  /// instructions can be added to the loop body.
245  ///
246  /// @param Builder The builder to use to build this loop. The current
247  ///                position of the builder is the position the loop
248  ///                will be inserted.
249  /// @param LowerBound The first value of the loop iterator
250  /// @param UpperBound The maximal value of the loop iterator
251  /// @param LoopIV A reference that will be set to the loop iterator.
252  /// @return The BasicBlock that will be executed after the loop.
253  llvm::BasicBlock *createLoop(llvm::IRBuilder<> &Builder,
254                               llvm::Value *LowerBound,
255                               llvm::Value *UpperBound,
256                               llvm::PHINode **LoopIV) {
257    assert(LowerBound->getType() == UpperBound->getType());
258
259    llvm::BasicBlock *CondBB, *AfterBB, *HeaderBB;
260    llvm::Value *Cond, *IVNext;
261    llvm::PHINode *IV;
262
263    CondBB = Builder.GetInsertBlock();
264    AfterBB = llvm::SplitBlock(CondBB, Builder.GetInsertPoint(), this);
265    HeaderBB = llvm::BasicBlock::Create(*C, "Loop", CondBB->getParent());
266
267    // if (LowerBound < Upperbound)
268    //   goto LoopHeader
269    // else
270    //   goto AfterBB
271    CondBB->getTerminator()->eraseFromParent();
272    Builder.SetInsertPoint(CondBB);
273    Cond = Builder.CreateICmpSLT(LowerBound, UpperBound);
274    Builder.CreateCondBr(Cond, HeaderBB, AfterBB);
275
276    // iv = PHI [CondBB -> LowerBound], [LoopHeader -> NextIV ]
277    // iv.next = iv + 1
278    // if (iv.next < Upperbound)
279    //   goto LoopHeader
280    // else
281    //   goto AfterBB
282    Builder.SetInsertPoint(HeaderBB);
283    IV = Builder.CreatePHI(LowerBound->getType(), 2, "X");
284    IV->addIncoming(LowerBound, CondBB);
285    IVNext = Builder.CreateNUWAdd(IV, Builder.getInt32(1));
286    IV->addIncoming(IVNext, HeaderBB);
287    Cond = Builder.CreateICmpSLT(IVNext, UpperBound);
288    Builder.CreateCondBr(Cond, HeaderBB, AfterBB);
289    AfterBB->setName("Exit");
290    Builder.SetInsertPoint(HeaderBB->getFirstNonPHI());
291    *LoopIV = IV;
292    return AfterBB;
293  }
294
295public:
296  RSForEachExpandPass(const RSInfo::ExportForeachFuncListTy &pForeachFuncs,
297                      bool pEnableStepOpt)
298      : ModulePass(ID), M(NULL), C(NULL), mFuncs(pForeachFuncs),
299        mEnableStepOpt(pEnableStepOpt) {
300  }
301
302  /* Performs the actual optimization on a selected function. On success, the
303   * Module will contain a new function of the name "<NAME>.expand" that
304   * invokes <NAME>() in a loop with the appropriate parameters.
305   */
306  bool ExpandFunction(llvm::Function *F, uint32_t Signature) {
307    ALOGV("Expanding ForEach-able Function %s", F->getName().str().c_str());
308
309    if (!Signature) {
310      Signature = getRootSignature(F);
311      if (!Signature) {
312        // We couldn't determine how to expand this function based on its
313        // function signature.
314        return false;
315      }
316    }
317
318    llvm::DataLayout DL(M);
319
320    llvm::Type *Int32Ty = llvm::Type::getInt32Ty(*C);
321    llvm::Function *ExpandedFunc = createEmptyExpandedFunction(F->getName());
322
323    // Create and name the actual arguments to this expanded function.
324    llvm::SmallVector<llvm::Argument*, 8> ArgVec;
325    for (llvm::Function::arg_iterator B = ExpandedFunc->arg_begin(),
326                                      E = ExpandedFunc->arg_end();
327         B != E;
328         ++B) {
329      ArgVec.push_back(B);
330    }
331
332    if (ArgVec.size() != 5) {
333      ALOGE("Incorrect number of arguments to function: %zu",
334            ArgVec.size());
335      return false;
336    }
337    llvm::Value *Arg_p = ArgVec[0];
338    llvm::Value *Arg_x1 = ArgVec[1];
339    llvm::Value *Arg_x2 = ArgVec[2];
340    llvm::Value *Arg_instep = ArgVec[3];
341    llvm::Value *Arg_outstep = ArgVec[4];
342
343    llvm::Value *InStep = NULL;
344    llvm::Value *OutStep = NULL;
345
346    // Construct the actual function body.
347    llvm::IRBuilder<> Builder(ExpandedFunc->getEntryBlock().begin());
348
349    // Collect and construct the arguments for the kernel().
350    // Note that we load any loop-invariant arguments before entering the Loop.
351    llvm::Function::arg_iterator Args = F->arg_begin();
352
353    llvm::Type *InTy = NULL;
354    llvm::AllocaInst *AIn = NULL;
355    if (hasIn(Signature)) {
356      InTy = Args->getType();
357      AIn = Builder.CreateAlloca(InTy, 0, "AIn");
358      InStep = getStepValue(&DL, InTy, Arg_instep);
359      InStep->setName("instep");
360      Builder.CreateStore(Builder.CreatePointerCast(Builder.CreateLoad(
361          Builder.CreateStructGEP(Arg_p, 0)), InTy), AIn);
362      Args++;
363    }
364
365    llvm::Type *OutTy = NULL;
366    llvm::AllocaInst *AOut = NULL;
367    if (hasOut(Signature)) {
368      OutTy = Args->getType();
369      AOut = Builder.CreateAlloca(OutTy, 0, "AOut");
370      OutStep = getStepValue(&DL, OutTy, Arg_outstep);
371      OutStep->setName("outstep");
372      Builder.CreateStore(Builder.CreatePointerCast(Builder.CreateLoad(
373          Builder.CreateStructGEP(Arg_p, 1)), OutTy), AOut);
374      Args++;
375    }
376
377    llvm::Value *UsrData = NULL;
378    if (hasUsrData(Signature)) {
379      llvm::Type *UsrDataTy = Args->getType();
380      UsrData = Builder.CreatePointerCast(Builder.CreateLoad(
381          Builder.CreateStructGEP(Arg_p, 2)), UsrDataTy);
382      UsrData->setName("UsrData");
383      Args++;
384    }
385
386    if (hasX(Signature)) {
387      Args++;
388    }
389
390    llvm::Value *Y = NULL;
391    if (hasY(Signature)) {
392      Y = Builder.CreateLoad(Builder.CreateStructGEP(Arg_p, 5), "Y");
393      Args++;
394    }
395
396    bccAssert(Args == F->arg_end());
397
398    llvm::PHINode *IV;
399    createLoop(Builder, Arg_x1, Arg_x2, &IV);
400
401    // Populate the actual call to kernel().
402    llvm::SmallVector<llvm::Value*, 8> RootArgs;
403
404    llvm::Value *InPtr = NULL;
405    llvm::Value *OutPtr = NULL;
406
407    if (AIn) {
408      InPtr = Builder.CreateLoad(AIn, "InPtr");
409      RootArgs.push_back(InPtr);
410    }
411
412    if (AOut) {
413      OutPtr = Builder.CreateLoad(AOut, "OutPtr");
414      RootArgs.push_back(OutPtr);
415    }
416
417    if (UsrData) {
418      RootArgs.push_back(UsrData);
419    }
420
421    llvm::Value *X = IV;
422    if (hasX(Signature)) {
423      RootArgs.push_back(X);
424    }
425
426    if (Y) {
427      RootArgs.push_back(Y);
428    }
429
430    Builder.CreateCall(F, RootArgs);
431
432    if (InPtr) {
433      // InPtr += instep
434      llvm::Value *NewIn = Builder.CreateIntToPtr(Builder.CreateNUWAdd(
435          Builder.CreatePtrToInt(InPtr, Int32Ty), InStep), InTy);
436      Builder.CreateStore(NewIn, AIn);
437    }
438
439    if (OutPtr) {
440      // OutPtr += outstep
441      llvm::Value *NewOut = Builder.CreateIntToPtr(Builder.CreateNUWAdd(
442          Builder.CreatePtrToInt(OutPtr, Int32Ty), OutStep), OutTy);
443      Builder.CreateStore(NewOut, AOut);
444    }
445
446    return true;
447  }
448
449  /* Expand a pass-by-value kernel.
450   */
451  bool ExpandKernel(llvm::Function *F, uint32_t Signature) {
452    bccAssert(isKernel(Signature));
453    ALOGV("Expanding kernel Function %s", F->getName().str().c_str());
454
455    // TODO: Refactor this to share functionality with ExpandFunction.
456    llvm::DataLayout DL(M);
457
458    llvm::Type *Int32Ty = llvm::Type::getInt32Ty(*C);
459    llvm::Function *ExpandedFunc = createEmptyExpandedFunction(F->getName());
460
461    // Create and name the actual arguments to this expanded function.
462    llvm::SmallVector<llvm::Argument*, 8> ArgVec;
463    for (llvm::Function::arg_iterator B = ExpandedFunc->arg_begin(),
464                                      E = ExpandedFunc->arg_end();
465         B != E;
466         ++B) {
467      ArgVec.push_back(B);
468    }
469
470    if (ArgVec.size() != 5) {
471      ALOGE("Incorrect number of arguments to function: %zu",
472            ArgVec.size());
473      return false;
474    }
475    llvm::Value *Arg_p = ArgVec[0];
476    llvm::Value *Arg_x1 = ArgVec[1];
477    llvm::Value *Arg_x2 = ArgVec[2];
478    llvm::Value *Arg_instep = ArgVec[3];
479    llvm::Value *Arg_outstep = ArgVec[4];
480
481    llvm::Value *InStep = NULL;
482    llvm::Value *OutStep = NULL;
483
484    // Construct the actual function body.
485    llvm::IRBuilder<> Builder(ExpandedFunc->getEntryBlock().begin());
486
487    // Collect and construct the arguments for the kernel().
488    // Note that we load any loop-invariant arguments before entering the Loop.
489    llvm::Function::arg_iterator Args = F->arg_begin();
490
491    llvm::Type *OutTy = NULL;
492    llvm::AllocaInst *AOut = NULL;
493    bool PassOutByReference = false;
494    if (hasOut(Signature)) {
495      llvm::Type *OutBaseTy = F->getReturnType();
496      if (OutBaseTy->isVoidTy()) {
497        PassOutByReference = true;
498        OutTy = Args->getType();
499        Args++;
500      } else {
501        OutTy = OutBaseTy->getPointerTo();
502        // We don't increment Args, since we are using the actual return type.
503      }
504      AOut = Builder.CreateAlloca(OutTy, 0, "AOut");
505      OutStep = getStepValue(&DL, OutTy, Arg_outstep);
506      OutStep->setName("outstep");
507      Builder.CreateStore(Builder.CreatePointerCast(Builder.CreateLoad(
508          Builder.CreateStructGEP(Arg_p, 1)), OutTy), AOut);
509    }
510
511    llvm::Type *InBaseTy = NULL;
512    llvm::Type *InTy = NULL;
513    llvm::AllocaInst *AIn = NULL;
514    if (hasIn(Signature)) {
515      InBaseTy = Args->getType();
516      InTy =InBaseTy->getPointerTo();
517      AIn = Builder.CreateAlloca(InTy, 0, "AIn");
518      InStep = getStepValue(&DL, InTy, Arg_instep);
519      InStep->setName("instep");
520      Builder.CreateStore(Builder.CreatePointerCast(Builder.CreateLoad(
521          Builder.CreateStructGEP(Arg_p, 0)), InTy), AIn);
522      Args++;
523    }
524
525    // No usrData parameter on kernels.
526    bccAssert(!hasUsrData(Signature));
527
528    if (hasX(Signature)) {
529      Args++;
530    }
531
532    llvm::Value *Y = NULL;
533    if (hasY(Signature)) {
534      Y = Builder.CreateLoad(Builder.CreateStructGEP(Arg_p, 5), "Y");
535      Args++;
536    }
537
538    bccAssert(Args == F->arg_end());
539
540    llvm::PHINode *IV;
541    createLoop(Builder, Arg_x1, Arg_x2, &IV);
542
543    // Populate the actual call to kernel().
544    llvm::SmallVector<llvm::Value*, 8> RootArgs;
545
546    llvm::Value *InPtr = NULL;
547    llvm::Value *In = NULL;
548    llvm::Value *OutPtr = NULL;
549
550    if (PassOutByReference) {
551      OutPtr = Builder.CreateLoad(AOut, "OutPtr");
552      RootArgs.push_back(OutPtr);
553    }
554
555    if (AIn) {
556      InPtr = Builder.CreateLoad(AIn, "InPtr");
557      In = Builder.CreateLoad(InPtr, "In");
558      RootArgs.push_back(In);
559    }
560
561    llvm::Value *X = IV;
562    if (hasX(Signature)) {
563      RootArgs.push_back(X);
564    }
565
566    if (Y) {
567      RootArgs.push_back(Y);
568    }
569
570    llvm::Value *RetVal = Builder.CreateCall(F, RootArgs);
571
572    if (AOut && !PassOutByReference) {
573      OutPtr = Builder.CreateLoad(AOut, "OutPtr");
574      Builder.CreateStore(RetVal, OutPtr);
575    }
576
577    if (InPtr) {
578      // InPtr += instep
579      llvm::Value *NewIn = Builder.CreateIntToPtr(Builder.CreateNUWAdd(
580          Builder.CreatePtrToInt(InPtr, Int32Ty), InStep), InTy);
581      Builder.CreateStore(NewIn, AIn);
582    }
583
584    if (OutPtr) {
585      // OutPtr += outstep
586      llvm::Value *NewOut = Builder.CreateIntToPtr(Builder.CreateNUWAdd(
587          Builder.CreatePtrToInt(OutPtr, Int32Ty), OutStep), OutTy);
588      Builder.CreateStore(NewOut, AOut);
589    }
590
591    return true;
592  }
593
594  virtual bool runOnModule(llvm::Module &M) {
595    bool Changed = false;
596    this->M = &M;
597    C = &M.getContext();
598
599    for (RSInfo::ExportForeachFuncListTy::const_iterator
600             func_iter = mFuncs.begin(), func_end = mFuncs.end();
601         func_iter != func_end; func_iter++) {
602      const char *name = func_iter->first;
603      uint32_t signature = func_iter->second;
604      llvm::Function *kernel = M.getFunction(name);
605      if (kernel && isKernel(signature)) {
606        Changed |= ExpandKernel(kernel, signature);
607      }
608      else if (kernel && kernel->getReturnType()->isVoidTy()) {
609        Changed |= ExpandFunction(kernel, signature);
610      }
611    }
612
613    return Changed;
614  }
615
616  virtual const char *getPassName() const {
617    return "ForEach-able Function Expansion";
618  }
619
620}; // end RSForEachExpandPass
621
622} // end anonymous namespace
623
624char RSForEachExpandPass::ID = 0;
625
626namespace bcc {
627
628llvm::ModulePass *
629createRSForEachExpandPass(const RSInfo::ExportForeachFuncListTy &pForeachFuncs,
630                          bool pEnableStepOpt){
631  return new RSForEachExpandPass(pForeachFuncs, pEnableStepOpt);
632}
633
634} // end namespace bcc
635