1/*
2 * Copyright 2014, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "Assert.h"
18#include "Log.h"
19#include "RSTransforms.h"
20#include "RSUtils.h"
21#include "rsDefines.h"
22
23#include "bcc/Config.h"
24#include "bcinfo/MetadataExtractor.h"
25
26#include <cstdlib>
27
28#include <llvm/IR/DataLayout.h>
29#include <llvm/IR/DerivedTypes.h>
30#include <llvm/IR/Function.h>
31#include <llvm/IR/Instructions.h>
32#include <llvm/IR/IRBuilder.h>
33#include <llvm/IR/MDBuilder.h>
34#include <llvm/IR/Module.h>
35#include <llvm/IR/Type.h>
36#include <llvm/Pass.h>
37#include <llvm/Support/raw_ostream.h>
38#include <llvm/Transforms/Utils/BasicBlockUtils.h>
39
40using namespace bcc;
41
42namespace {
43
44class RSInvokeHelperPass : public llvm::FunctionPass {
45private:
46  static char ID;
47
48  llvm::StructType* rsAllocationType;
49  llvm::StructType* rsElementType;
50  llvm::StructType* rsSamplerType;
51  llvm::StructType* rsScriptType;
52  llvm::StructType* rsTypeType;
53
54  llvm::Constant* rsAllocationSetObj;
55  llvm::Constant* rsElementSetObj;
56  llvm::Constant* rsSamplerSetObj;
57  llvm::Constant* rsScriptSetObj;
58  llvm::Constant* rsTypeSetObj;
59
60
61public:
62  RSInvokeHelperPass()
63    : FunctionPass(ID) {
64
65    }
66
67  virtual void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
68    // This pass does not use any other analysis passes, but it does
69    // modify the existing functions in the module (thus altering the CFG).
70  }
71
72  virtual bool doInitialization(llvm::Module &M) override {
73    llvm::FunctionType * SetObjType = nullptr;
74    llvm::SmallVector<llvm::Type*, 4> rsBaseObj;
75    rsBaseObj.append(4, llvm::Type::getInt64PtrTy(M.getContext()));
76
77    rsAllocationType = llvm::StructType::create(rsBaseObj, kAllocationTypeName);
78    rsElementType = llvm::StructType::create(rsBaseObj, kElementTypeName);
79    rsSamplerType = llvm::StructType::create(rsBaseObj, kSamplerTypeName);
80    rsScriptType = llvm::StructType::create(rsBaseObj, kScriptTypeName);
81    rsTypeType = llvm::StructType::create(rsBaseObj, kTypeTypeName);
82
83    llvm::SmallVector<llvm::Value*, 1> SetObjParams;
84    llvm::SmallVector<llvm::Type*, 2> SetObjTypeParams;
85
86    // get rsSetObject(rs_allocation*, rs_allocation*)
87    // according to AArch64 calling convention, these are both pointers because of the size of the struct
88    SetObjTypeParams.push_back(rsAllocationType->getPointerTo());
89    SetObjTypeParams.push_back(rsAllocationType->getPointerTo());
90    SetObjType = llvm::FunctionType::get(llvm::Type::getVoidTy(M.getContext()), SetObjTypeParams, false);
91    rsAllocationSetObj = M.getOrInsertFunction("_Z11rsSetObjectP13rs_allocationS_", SetObjType);
92    SetObjTypeParams.clear();
93
94    SetObjTypeParams.push_back(rsElementType->getPointerTo());
95    SetObjTypeParams.push_back(rsElementType->getPointerTo());
96    SetObjType = llvm::FunctionType::get(llvm::Type::getVoidTy(M.getContext()), SetObjTypeParams, false);
97    rsElementSetObj = M.getOrInsertFunction("_Z11rsSetObjectP10rs_elementS_", SetObjType);
98    SetObjTypeParams.clear();
99
100    SetObjTypeParams.push_back(rsSamplerType->getPointerTo());
101    SetObjTypeParams.push_back(rsSamplerType->getPointerTo());
102    SetObjType = llvm::FunctionType::get(llvm::Type::getVoidTy(M.getContext()), SetObjTypeParams, false);
103    rsSamplerSetObj = M.getOrInsertFunction("_Z11rsSetObjectP10rs_samplerS_", SetObjType);
104    SetObjTypeParams.clear();
105
106    SetObjTypeParams.push_back(rsScriptType->getPointerTo());
107    SetObjTypeParams.push_back(rsScriptType->getPointerTo());
108    SetObjType = llvm::FunctionType::get(llvm::Type::getVoidTy(M.getContext()), SetObjTypeParams, false);
109    rsScriptSetObj = M.getOrInsertFunction("_Z11rsSetObjectP9rs_scriptS_", SetObjType);
110    SetObjTypeParams.clear();
111
112    SetObjTypeParams.push_back(rsTypeType->getPointerTo());
113    SetObjTypeParams.push_back(rsTypeType->getPointerTo());
114    SetObjType = llvm::FunctionType::get(llvm::Type::getVoidTy(M.getContext()), SetObjTypeParams, false);
115    rsTypeSetObj = M.getOrInsertFunction("_Z11rsSetObjectP7rs_typeS_", SetObjType);
116    SetObjTypeParams.clear();
117
118    return true;
119  }
120
121  bool insertSetObjectHelper(llvm::CallInst *Call, llvm::Value *V, enum RsDataType DT) {
122    llvm::Constant *SetObj = nullptr;
123    llvm::StructType *RSStructType = nullptr;
124    switch (DT) {
125    case RS_TYPE_ALLOCATION:
126      SetObj = rsAllocationSetObj;
127      RSStructType = rsAllocationType;
128      break;
129    case RS_TYPE_ELEMENT:
130      SetObj = rsElementSetObj;
131      RSStructType = rsElementType;
132      break;
133    case RS_TYPE_SAMPLER:
134      SetObj = rsSamplerSetObj;
135      RSStructType = rsSamplerType;
136      break;
137    case RS_TYPE_SCRIPT:
138      SetObj = rsScriptSetObj;
139      RSStructType = rsScriptType;
140      break;
141    case RS_TYPE_TYPE:
142      SetObj = rsTypeSetObj;
143      RSStructType = rsTypeType;
144      break;
145    default:
146      return false; // this is for graphics types and matrices; do nothing
147    }
148
149
150    llvm::CastInst* CastedValue = llvm::CastInst::CreatePointerCast(V, RSStructType->getPointerTo(), "", Call);
151
152    llvm::SmallVector<llvm::Value*, 2> SetObjParams;
153    SetObjParams.push_back(CastedValue);
154    SetObjParams.push_back(CastedValue);
155
156    llvm::CallInst::Create(SetObj, SetObjParams, "", Call);
157
158
159    return true;
160  }
161
162
163  // this only modifies .helper functions that take certain RS base object types
164  virtual bool runOnFunction(llvm::Function &F) override {
165    if (!F.getName().startswith(".helper"))
166      return false;
167
168    bool changed = false;
169    const llvm::Function::ArgumentListType &argList(F.getArgumentList());
170    bool containsBaseObj = false;
171
172    // .helper methods should have one arg only, an anonymous struct
173    // that struct may contain BaseObjs
174    for (auto arg = argList.begin(); arg != argList.end(); arg++) {
175      llvm::Type *argType = arg->getType();
176      if (!argType->isPointerTy() || !argType->getPointerElementType()->isStructTy())
177        continue;
178
179      llvm::StructType *argStructType = llvm::dyn_cast<llvm::StructType>(argType->getPointerElementType());
180
181      for (unsigned int i = 0; i < argStructType->getNumElements(); i++) {
182        llvm::Type *currentType = argStructType->getElementType(i);
183        if (currentType->isStructTy() && currentType->getStructName().startswith("struct.rs_")) {
184          containsBaseObj = true;
185        }
186      }
187      break;
188    }
189
190
191    if (containsBaseObj) {
192      // modify the thing that should not be
193      auto &BBList(F.getBasicBlockList());
194      for (auto &BB : BBList) {
195        auto &InstList(BB.getInstList());
196        for (auto &Inst : InstList) {
197          // don't care about anything except call instructions that we didn't already add
198          if (llvm::CallInst *call = llvm::dyn_cast<llvm::CallInst>(&Inst)) {
199            for (unsigned int i = 0; i < call->getNumArgOperands(); i++) {
200              llvm::Value *V = call->getArgOperand(i);
201              llvm::Type *T = V->getType();
202              enum RsDataType DT = RS_TYPE_NONE;
203              if (T->isPointerTy() && T->getPointerElementType()->isStructTy()) {
204                DT = getRsDataTypeForType(T->getPointerElementType());
205              }
206              if (DT != RS_TYPE_NONE) {
207                // generate the new call instruction and insert it
208                changed |= insertSetObjectHelper(call, V, DT);
209              }
210            }
211          }
212        }
213      }
214    }
215
216    return changed;
217  }
218
219  virtual const char *getPassName() const override {
220    return ".helper method expansion for large RS objects";
221  }
222}; // end RSInvokeHelperPass class
223} // end anonymous namespace
224
225char RSInvokeHelperPass::ID = 0;
226
227namespace bcc {
228
229llvm::FunctionPass *
230createRSInvokeHelperPass(){
231  return new RSInvokeHelperPass();
232}
233
234}
235