1//===-- AMDGPUPromoteAlloca.cpp - Promote Allocas -------------------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// This pass eliminates allocas by either converting them into vectors or
11// by migrating them to local address space.
12//
13//===----------------------------------------------------------------------===//
14
15#include "AMDGPU.h"
16#include "AMDGPUSubtarget.h"
17#include "llvm/Analysis/ValueTracking.h"
18#include "llvm/IR/IRBuilder.h"
19#include "llvm/IR/InstVisitor.h"
20#include "llvm/Support/Debug.h"
21
22#define DEBUG_TYPE "amdgpu-promote-alloca"
23
24using namespace llvm;
25
26namespace {
27
28class AMDGPUPromoteAlloca : public FunctionPass,
29                       public InstVisitor<AMDGPUPromoteAlloca> {
30
31  static char ID;
32  Module *Mod;
33  const AMDGPUSubtarget &ST;
34  int LocalMemAvailable;
35
36public:
37  AMDGPUPromoteAlloca(const AMDGPUSubtarget &st) : FunctionPass(ID), ST(st),
38                                                   LocalMemAvailable(0) { }
39  virtual bool doInitialization(Module &M);
40  virtual bool runOnFunction(Function &F);
41  virtual const char *getPassName() const {
42    return "AMDGPU Promote Alloca";
43  }
44  void visitAlloca(AllocaInst &I);
45};
46
47} // End anonymous namespace
48
49char AMDGPUPromoteAlloca::ID = 0;
50
51bool AMDGPUPromoteAlloca::doInitialization(Module &M) {
52  Mod = &M;
53  return false;
54}
55
56bool AMDGPUPromoteAlloca::runOnFunction(Function &F) {
57
58  const FunctionType *FTy = F.getFunctionType();
59
60  LocalMemAvailable = ST.getLocalMemorySize();
61
62
63  // If the function has any arguments in the local address space, then it's
64  // possible these arguments require the entire local memory space, so
65  // we cannot use local memory in the pass.
66  for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i) {
67    const Type *ParamTy = FTy->getParamType(i);
68    if (ParamTy->isPointerTy() &&
69        ParamTy->getPointerAddressSpace() == AMDGPUAS::LOCAL_ADDRESS) {
70      LocalMemAvailable = 0;
71      DEBUG(dbgs() << "Function has local memory argument.  Promoting to "
72                      "local memory disabled.\n");
73      break;
74    }
75  }
76
77  if (LocalMemAvailable > 0) {
78    // Check how much local memory is being used by global objects
79    for (Module::global_iterator I = Mod->global_begin(),
80                                 E = Mod->global_end(); I != E; ++I) {
81      GlobalVariable *GV = I;
82      PointerType *GVTy = GV->getType();
83      if (GVTy->getAddressSpace() != AMDGPUAS::LOCAL_ADDRESS)
84        continue;
85      for (Value::use_iterator U = GV->use_begin(),
86                               UE = GV->use_end(); U != UE; ++U) {
87        Instruction *Use = dyn_cast<Instruction>(*U);
88        if (!Use)
89          continue;
90        if (Use->getParent()->getParent() == &F)
91          LocalMemAvailable -=
92              Mod->getDataLayout()->getTypeAllocSize(GVTy->getElementType());
93      }
94    }
95  }
96
97  LocalMemAvailable = std::max(0, LocalMemAvailable);
98  DEBUG(dbgs() << LocalMemAvailable << "bytes free in local memory.\n");
99
100  visit(F);
101
102  return false;
103}
104
105static VectorType *arrayTypeToVecType(const Type *ArrayTy) {
106  return VectorType::get(ArrayTy->getArrayElementType(),
107                         ArrayTy->getArrayNumElements());
108}
109
110static Value* calculateVectorIndex(Value *Ptr,
111                                  std::map<GetElementPtrInst*, Value*> GEPIdx) {
112  if (isa<AllocaInst>(Ptr))
113    return Constant::getNullValue(Type::getInt32Ty(Ptr->getContext()));
114
115  GetElementPtrInst *GEP = cast<GetElementPtrInst>(Ptr);
116
117  return GEPIdx[GEP];
118}
119
120static Value* GEPToVectorIndex(GetElementPtrInst *GEP) {
121  // FIXME we only support simple cases
122  if (GEP->getNumOperands() != 3)
123    return NULL;
124
125  ConstantInt *I0 = dyn_cast<ConstantInt>(GEP->getOperand(1));
126  if (!I0 || !I0->isZero())
127    return NULL;
128
129  return GEP->getOperand(2);
130}
131
132// Not an instruction handled below to turn into a vector.
133//
134// TODO: Check isTriviallyVectorizable for calls and handle other
135// instructions.
136static bool canVectorizeInst(Instruction *Inst) {
137  switch (Inst->getOpcode()) {
138  case Instruction::Load:
139  case Instruction::Store:
140  case Instruction::BitCast:
141  case Instruction::AddrSpaceCast:
142    return true;
143  default:
144    return false;
145  }
146}
147
148static bool tryPromoteAllocaToVector(AllocaInst *Alloca) {
149  Type *AllocaTy = Alloca->getAllocatedType();
150
151  DEBUG(dbgs() << "Alloca Candidate for vectorization \n");
152
153  // FIXME: There is no reason why we can't support larger arrays, we
154  // are just being conservative for now.
155  if (!AllocaTy->isArrayTy() ||
156      AllocaTy->getArrayElementType()->isVectorTy() ||
157      AllocaTy->getArrayNumElements() > 4) {
158
159    DEBUG(dbgs() << "  Cannot convert type to vector");
160    return false;
161  }
162
163  std::map<GetElementPtrInst*, Value*> GEPVectorIdx;
164  std::vector<Value*> WorkList;
165  for (User *AllocaUser : Alloca->users()) {
166    GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(AllocaUser);
167    if (!GEP) {
168      if (!canVectorizeInst(cast<Instruction>(AllocaUser)))
169        return false;
170
171      WorkList.push_back(AllocaUser);
172      continue;
173    }
174
175    Value *Index = GEPToVectorIndex(GEP);
176
177    // If we can't compute a vector index from this GEP, then we can't
178    // promote this alloca to vector.
179    if (!Index) {
180      DEBUG(dbgs() << "  Cannot compute vector index for GEP " << *GEP << '\n');
181      return false;
182    }
183
184    GEPVectorIdx[GEP] = Index;
185    for (User *GEPUser : AllocaUser->users()) {
186      if (!canVectorizeInst(cast<Instruction>(GEPUser)))
187        return false;
188
189      WorkList.push_back(GEPUser);
190    }
191  }
192
193  VectorType *VectorTy = arrayTypeToVecType(AllocaTy);
194
195  DEBUG(dbgs() << "  Converting alloca to vector "
196        << *AllocaTy << " -> " << *VectorTy << '\n');
197
198  for (std::vector<Value*>::iterator I = WorkList.begin(),
199                                     E = WorkList.end(); I != E; ++I) {
200    Instruction *Inst = cast<Instruction>(*I);
201    IRBuilder<> Builder(Inst);
202    switch (Inst->getOpcode()) {
203    case Instruction::Load: {
204      Value *Ptr = Inst->getOperand(0);
205      Value *Index = calculateVectorIndex(Ptr, GEPVectorIdx);
206      Value *BitCast = Builder.CreateBitCast(Alloca, VectorTy->getPointerTo(0));
207      Value *VecValue = Builder.CreateLoad(BitCast);
208      Value *ExtractElement = Builder.CreateExtractElement(VecValue, Index);
209      Inst->replaceAllUsesWith(ExtractElement);
210      Inst->eraseFromParent();
211      break;
212    }
213    case Instruction::Store: {
214      Value *Ptr = Inst->getOperand(1);
215      Value *Index = calculateVectorIndex(Ptr, GEPVectorIdx);
216      Value *BitCast = Builder.CreateBitCast(Alloca, VectorTy->getPointerTo(0));
217      Value *VecValue = Builder.CreateLoad(BitCast);
218      Value *NewVecValue = Builder.CreateInsertElement(VecValue,
219                                                       Inst->getOperand(0),
220                                                       Index);
221      Builder.CreateStore(NewVecValue, BitCast);
222      Inst->eraseFromParent();
223      break;
224    }
225    case Instruction::BitCast:
226    case Instruction::AddrSpaceCast:
227      break;
228
229    default:
230      Inst->dump();
231      llvm_unreachable("Inconsistency in instructions promotable to vector");
232    }
233  }
234  return true;
235}
236
237static void collectUsesWithPtrTypes(Value *Val, std::vector<Value*> &WorkList) {
238  for (User *User : Val->users()) {
239    if(std::find(WorkList.begin(), WorkList.end(), User) != WorkList.end())
240      continue;
241    if (isa<CallInst>(User)) {
242      WorkList.push_back(User);
243      continue;
244    }
245    if (!User->getType()->isPointerTy())
246      continue;
247    WorkList.push_back(User);
248    collectUsesWithPtrTypes(User, WorkList);
249  }
250}
251
252void AMDGPUPromoteAlloca::visitAlloca(AllocaInst &I) {
253  IRBuilder<> Builder(&I);
254
255  // First try to replace the alloca with a vector
256  Type *AllocaTy = I.getAllocatedType();
257
258  DEBUG(dbgs() << "Trying to promote " << I << '\n');
259
260  if (tryPromoteAllocaToVector(&I))
261    return;
262
263  DEBUG(dbgs() << " alloca is not a candidate for vectorization.\n");
264
265  // FIXME: This is the maximum work group size.  We should try to get
266  // value from the reqd_work_group_size function attribute if it is
267  // available.
268  unsigned WorkGroupSize = 256;
269  int AllocaSize = WorkGroupSize *
270      Mod->getDataLayout()->getTypeAllocSize(AllocaTy);
271
272  if (AllocaSize > LocalMemAvailable) {
273    DEBUG(dbgs() << " Not enough local memory to promote alloca.\n");
274    return;
275  }
276
277  DEBUG(dbgs() << "Promoting alloca to local memory\n");
278  LocalMemAvailable -= AllocaSize;
279
280  GlobalVariable *GV = new GlobalVariable(
281      *Mod, ArrayType::get(I.getAllocatedType(), 256), false,
282      GlobalValue::ExternalLinkage, 0, I.getName(), 0,
283      GlobalVariable::NotThreadLocal, AMDGPUAS::LOCAL_ADDRESS);
284
285  FunctionType *FTy = FunctionType::get(
286      Type::getInt32Ty(Mod->getContext()), false);
287  AttributeSet AttrSet;
288  AttrSet.addAttribute(Mod->getContext(), 0, Attribute::ReadNone);
289
290  Value *ReadLocalSizeY = Mod->getOrInsertFunction(
291      "llvm.r600.read.local.size.y", FTy, AttrSet);
292  Value *ReadLocalSizeZ = Mod->getOrInsertFunction(
293      "llvm.r600.read.local.size.z", FTy, AttrSet);
294  Value *ReadTIDIGX = Mod->getOrInsertFunction(
295      "llvm.r600.read.tidig.x", FTy, AttrSet);
296  Value *ReadTIDIGY = Mod->getOrInsertFunction(
297      "llvm.r600.read.tidig.y", FTy, AttrSet);
298  Value *ReadTIDIGZ = Mod->getOrInsertFunction(
299      "llvm.r600.read.tidig.z", FTy, AttrSet);
300
301
302  Value *TCntY = Builder.CreateCall(ReadLocalSizeY);
303  Value *TCntZ = Builder.CreateCall(ReadLocalSizeZ);
304  Value *TIdX  = Builder.CreateCall(ReadTIDIGX);
305  Value *TIdY  = Builder.CreateCall(ReadTIDIGY);
306  Value *TIdZ  = Builder.CreateCall(ReadTIDIGZ);
307
308  Value *Tmp0 = Builder.CreateMul(TCntY, TCntZ);
309  Tmp0 = Builder.CreateMul(Tmp0, TIdX);
310  Value *Tmp1 = Builder.CreateMul(TIdY, TCntZ);
311  Value *TID = Builder.CreateAdd(Tmp0, Tmp1);
312  TID = Builder.CreateAdd(TID, TIdZ);
313
314  std::vector<Value*> Indices;
315  Indices.push_back(Constant::getNullValue(Type::getInt32Ty(Mod->getContext())));
316  Indices.push_back(TID);
317
318  Value *Offset = Builder.CreateGEP(GV, Indices);
319  I.mutateType(Offset->getType());
320  I.replaceAllUsesWith(Offset);
321  I.eraseFromParent();
322
323  std::vector<Value*> WorkList;
324
325  collectUsesWithPtrTypes(Offset, WorkList);
326
327  for (std::vector<Value*>::iterator i = WorkList.begin(),
328                                     e = WorkList.end(); i != e; ++i) {
329    Value *V = *i;
330    CallInst *Call = dyn_cast<CallInst>(V);
331    if (!Call) {
332      Type *EltTy = V->getType()->getPointerElementType();
333      PointerType *NewTy = PointerType::get(EltTy, AMDGPUAS::LOCAL_ADDRESS);
334      V->mutateType(NewTy);
335      continue;
336    }
337
338    IntrinsicInst *Intr = dyn_cast<IntrinsicInst>(Call);
339    if (!Intr) {
340      std::vector<Type*> ArgTypes;
341      for (unsigned ArgIdx = 0, ArgEnd = Call->getNumArgOperands();
342                                ArgIdx != ArgEnd; ++ArgIdx) {
343        ArgTypes.push_back(Call->getArgOperand(ArgIdx)->getType());
344      }
345      Function *F = Call->getCalledFunction();
346      FunctionType *NewType = FunctionType::get(Call->getType(), ArgTypes,
347                                                F->isVarArg());
348      Constant *C = Mod->getOrInsertFunction(StringRef(F->getName().str() + ".local"), NewType,
349                                             F->getAttributes());
350      Function *NewF = cast<Function>(C);
351      Call->setCalledFunction(NewF);
352      continue;
353    }
354
355    Builder.SetInsertPoint(Intr);
356    switch (Intr->getIntrinsicID()) {
357    case Intrinsic::lifetime_start:
358    case Intrinsic::lifetime_end:
359      // These intrinsics are for address space 0 only
360      Intr->eraseFromParent();
361      continue;
362    case Intrinsic::memcpy: {
363      MemCpyInst *MemCpy = cast<MemCpyInst>(Intr);
364      Builder.CreateMemCpy(MemCpy->getRawDest(), MemCpy->getRawSource(),
365                           MemCpy->getLength(), MemCpy->getAlignment(),
366                           MemCpy->isVolatile());
367      Intr->eraseFromParent();
368      continue;
369    }
370    case Intrinsic::memset: {
371      MemSetInst *MemSet = cast<MemSetInst>(Intr);
372      Builder.CreateMemSet(MemSet->getRawDest(), MemSet->getValue(),
373                           MemSet->getLength(), MemSet->getAlignment(),
374                           MemSet->isVolatile());
375      Intr->eraseFromParent();
376      continue;
377    }
378    default:
379      Intr->dump();
380      llvm_unreachable("Don't know how to promote alloca intrinsic use.");
381    }
382  }
383}
384
385FunctionPass *llvm::createAMDGPUPromoteAlloca(const AMDGPUSubtarget &ST) {
386  return new AMDGPUPromoteAlloca(ST);
387}
388