1/*
2 * Copyright 2016, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "RSAllocationUtils.h"
18
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#include "cxxabi.h"

#include <cstdlib>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
31
32#define DEBUG_TYPE "rs2spirv-rs-allocation-utils"
33
34using namespace llvm;
35
36namespace rs2spirv {
37
38bool isRSAllocation(const GlobalVariable &GV) {
39  auto *PT = cast<PointerType>(GV.getType());
40  DEBUG(PT->dump());
41
42  auto *VT = PT->getElementType();
43  DEBUG(VT->dump());
44  std::string TypeName;
45  raw_string_ostream RSO(TypeName);
46  VT->print(RSO);
47  RSO.str(); // Force flush.
48  DEBUG(dbgs() << "TypeName: " << TypeName << '\n');
49
50  return TypeName.find("struct.rs_allocation") != std::string::npos;
51}
52
53bool getRSAllocationInfo(Module &M, SmallVectorImpl<RSAllocationInfo> &Allocs) {
54  DEBUG(dbgs() << "getRSAllocationInfo\n");
55  for (auto &GV : M.globals()) {
56    if (GV.isDeclaration() || !isRSAllocation(GV))
57      continue;
58
59    Allocs.push_back({'%' + GV.getName().str(), None, &GV, -1});
60  }
61
62  return true;
63}
64
65// Collect Allocation access calls into the Calls
66// Also update Allocs with assigned ID.
67// After calling this function, Allocs would contain the mapping from
68// GV name to the corresponding ID.
69bool getRSAllocAccesses(SmallVectorImpl<RSAllocationInfo> &Allocs,
70                        SmallVectorImpl<RSAllocationCallInfo> &Calls) {
71  DEBUG(dbgs() << "getRSGEATCalls\n");
72  DEBUG(dbgs() << "\n\n~~~~~~~~~~~~~~~~~~~~~\n\n");
73
74  std::unordered_map<const Value *, const GlobalVariable *> Mapping;
75  int id_assigned = 0;
76
77  for (auto &A : Allocs) {
78    auto *GV = A.GlobalVar;
79    std::vector<User *> WorkList(GV->user_begin(), GV->user_end());
80    size_t Idx = 0;
81
82    while (Idx < WorkList.size()) {
83      auto *U = WorkList[Idx];
84      DEBUG(dbgs() << "Visiting ");
85      DEBUG(U->dump());
86      ++Idx;
87      auto It = Mapping.find(U);
88      if (It != Mapping.end()) {
89        if (It->second == GV) {
90          continue;
91        } else {
92          errs() << "Duplicate global mapping discovered!\n";
93          errs() << "\nGlobal: ";
94          GV->print(errs());
95          errs() << "\nExisting mapping: ";
96          It->second->print(errs());
97          errs() << "\nUser: ";
98          U->print(errs());
99          errs() << '\n';
100
101          return false;
102        }
103      }
104
105      Mapping[U] = GV;
106      DEBUG(dbgs() << "New mapping: ");
107      DEBUG(U->print(dbgs()));
108      DEBUG(dbgs() << " -> " << GV->getName() << '\n');
109
110      if (auto *FCall = dyn_cast<CallInst>(U)) {
111        if (auto *F = FCall->getCalledFunction()) {
112          const auto FName = F->getName();
113          DEBUG(dbgs() << "Discovered function call to : " << FName << '\n');
114          // Treat memcpy as moves for the purpose of this analysis
115          if (FName.startswith("llvm.memcpy")) {
116            assert(FCall->getNumArgOperands() > 0);
117            Value *CopyDest = FCall->getArgOperand(0);
118            // We are interested in the users of the dest operand of
119            // memcpy here
120            Value *LocalCopy = CopyDest->stripPointerCasts();
121            User *NewU = dyn_cast<User>(LocalCopy);
122            assert(NewU);
123            WorkList.push_back(NewU);
124            continue;
125          }
126
127          char *demangled = __cxxabiv1::__cxa_demangle(
128              FName.str().c_str(), nullptr, nullptr, nullptr);
129          if (!demangled)
130            continue;
131          const StringRef DemangledNameRef(demangled);
132          DEBUG(dbgs() << "Demangled name: " << DemangledNameRef << '\n');
133
134          const StringRef GEAPrefix = "rsGetElementAt_";
135          const StringRef SEAPrefix = "rsSetElementAt_";
136          const StringRef DIMXPrefix = "rsAllocationGetDimX";
137          assert(GEAPrefix.size() == SEAPrefix.size());
138
139          const bool IsGEA = DemangledNameRef.startswith(GEAPrefix);
140          const bool IsSEA = DemangledNameRef.startswith(SEAPrefix);
141          const bool IsDIMX = DemangledNameRef.startswith(DIMXPrefix);
142
143          assert(IsGEA || IsSEA || IsDIMX);
144          if (!A.hasID()) {
145            A.assignID(id_assigned++);
146          }
147
148          if (IsGEA || IsSEA) {
149            DEBUG(dbgs() << "Found rsAlloc function!\n");
150
151            const auto Kind =
152                IsGEA ? RSAllocAccessKind::GEA : RSAllocAccessKind::SEA;
153
154            const auto RSElementTy =
155                DemangledNameRef.drop_front(GEAPrefix.size());
156
157            Calls.push_back({A, FCall, Kind, RSElementTy.str()});
158            continue;
159          } else if (DemangledNameRef.startswith(GEAPrefix.drop_back()) ||
160                     DemangledNameRef.startswith(SEAPrefix.drop_back())) {
161            errs() << "Untyped accesses to global rs_allocations are not "
162                      "supported.\n";
163            return false;
164          } else if (IsDIMX) {
165            DEBUG(dbgs() << "Found rsAllocationGetDimX function!\n");
166            const auto Kind = RSAllocAccessKind::DIMX;
167            Calls.push_back({A, FCall, Kind, ""});
168          }
169        }
170      }
171
172      // TODO: Consider using set-like container to reduce computational
173      // complexity.
174      for (auto *NewU : U->users())
175        if (std::find(WorkList.begin(), WorkList.end(), NewU) == WorkList.end())
176          WorkList.push_back(NewU);
177    }
178  }
179
180  std::unordered_map<const GlobalVariable *, std::string> GVAccessTypes;
181
182  for (auto &Access : Calls) {
183    auto AccessElemTyIt = GVAccessTypes.find(Access.RSAlloc.GlobalVar);
184    if (AccessElemTyIt != GVAccessTypes.end() &&
185        AccessElemTyIt->second != Access.RSElementTy) {
186      errs() << "Could not infere element type for: ";
187      Access.RSAlloc.GlobalVar->print(errs());
188      errs() << '\n';
189      return false;
190    } else if (AccessElemTyIt == GVAccessTypes.end()) {
191      GVAccessTypes.emplace(Access.RSAlloc.GlobalVar, Access.RSElementTy);
192      Access.RSAlloc.RSElementType = Access.RSElementTy;
193    }
194  }
195
196  DEBUG(dbgs() << "\n\n~~~~~~~~~~~~~~~~~~~~~\n\n");
197  return true;
198}
199
200bool solidifyRSAllocAccess(Module &M, RSAllocationCallInfo CallInfo) {
201  DEBUG(dbgs() << "solidifyRSAllocAccess " << CallInfo.RSAlloc.VarName << '\n');
202  auto *FCall = CallInfo.FCall;
203  auto *Fun = FCall->getCalledFunction();
204  assert(Fun);
205
206  StringRef FName;
207  if (CallInfo.Kind == RSAllocAccessKind::DIMX)
208    FName = "rsAllocationGetDimX";
209  else
210    FName = Fun->getName();
211
212  std::ostringstream OSS;
213  OSS << "__rsov_" << FName.str();
214  // Make up uint32_t F(uint32_t)
215  Type *UInt32Ty = IntegerType::get(M.getContext(), 32);
216  auto *NewFT = FunctionType::get(UInt32Ty, ArrayRef<Type *>(UInt32Ty), false);
217
218  auto *NewF = Function::Create(NewFT, // Fun->getFunctionType(),
219                                Function::ExternalLinkage, OSS.str(), &M);
220  FCall->setCalledFunction(NewF);
221  FCall->setArgOperand(0, ConstantInt::get(UInt32Ty, 0, false));
222  NewF->setAttributes(Fun->getAttributes());
223
224  DEBUG(M.dump());
225
226  return true;
227}
228
229} // namespace rs2spirv
230