SITypeRewriter.cpp revision dce4a407a24b04eebc6a376f8e62b41aaa7b071f
1//===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10/// \file
/// This pass performs the following type substitutions on all
/// non-compute shaders:
13///
14/// v16i8 => i128
15///   - v16i8 is used for constant memory resource descriptors.  This type is
16///      legal for some compute APIs, and we don't want to declare it as legal
17///      in the backend, because we want the legalizer to expand all v16i8
18///      operations.
19/// v1* => *
///   - Having v1* types complicates the legalizer, and we can easily replace
///     them with the element type.
22//===----------------------------------------------------------------------===//
23
24#include "AMDGPU.h"
25#include "llvm/IR/IRBuilder.h"
26#include "llvm/IR/InstVisitor.h"
27
28using namespace llvm;
29
30namespace {
31
32class SITypeRewriter : public FunctionPass,
33                       public InstVisitor<SITypeRewriter> {
34
35  static char ID;
36  Module *Mod;
37  Type *v16i8;
38  Type *v4i32;
39
40public:
41  SITypeRewriter() : FunctionPass(ID) { }
42  bool doInitialization(Module &M) override;
43  bool runOnFunction(Function &F) override;
44  const char *getPassName() const override {
45    return "SI Type Rewriter";
46  }
47  void visitLoadInst(LoadInst &I);
48  void visitCallInst(CallInst &I);
49  void visitBitCast(BitCastInst &I);
50};
51
52} // End anonymous namespace
53
54char SITypeRewriter::ID = 0;
55
56bool SITypeRewriter::doInitialization(Module &M) {
57  Mod = &M;
58  v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16);
59  v4i32 = VectorType::get(Type::getInt32Ty(M.getContext()), 4);
60  return false;
61}
62
63bool SITypeRewriter::runOnFunction(Function &F) {
64  AttributeSet Set = F.getAttributes();
65  Attribute A = Set.getAttribute(AttributeSet::FunctionIndex, "ShaderType");
66
67  unsigned ShaderType = ShaderType::COMPUTE;
68  if (A.isStringAttribute()) {
69    StringRef Str = A.getValueAsString();
70    Str.getAsInteger(0, ShaderType);
71  }
72  if (ShaderType == ShaderType::COMPUTE)
73    return false;
74
75  visit(F);
76  visit(F);
77
78  return false;
79}
80
81void SITypeRewriter::visitLoadInst(LoadInst &I) {
82  Value *Ptr = I.getPointerOperand();
83  Type *PtrTy = Ptr->getType();
84  Type *ElemTy = PtrTy->getPointerElementType();
85  IRBuilder<> Builder(&I);
86  if (ElemTy == v16i8)  {
87    Value *BitCast = Builder.CreateBitCast(Ptr,
88        PointerType::get(v4i32,PtrTy->getPointerAddressSpace()));
89    LoadInst *Load = Builder.CreateLoad(BitCast);
90    SmallVector <std::pair<unsigned, MDNode*>, 8> MD;
91    I.getAllMetadataOtherThanDebugLoc(MD);
92    for (unsigned i = 0, e = MD.size(); i != e; ++i) {
93      Load->setMetadata(MD[i].first, MD[i].second);
94    }
95    Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType());
96    I.replaceAllUsesWith(BitCastLoad);
97    I.eraseFromParent();
98  }
99}
100
101void SITypeRewriter::visitCallInst(CallInst &I) {
102  IRBuilder<> Builder(&I);
103
104  SmallVector <Value*, 8> Args;
105  SmallVector <Type*, 8> Types;
106  bool NeedToReplace = false;
107  Function *F = I.getCalledFunction();
108  std::string Name = F->getName().str();
109  for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) {
110    Value *Arg = I.getArgOperand(i);
111    if (Arg->getType() == v16i8) {
112      Args.push_back(Builder.CreateBitCast(Arg, v4i32));
113      Types.push_back(v4i32);
114      NeedToReplace = true;
115      Name = Name + ".v4i32";
116    } else if (Arg->getType()->isVectorTy() &&
117               Arg->getType()->getVectorNumElements() == 1 &&
118               Arg->getType()->getVectorElementType() ==
119                                              Type::getInt32Ty(I.getContext())){
120      Type *ElementTy = Arg->getType()->getVectorElementType();
121      std::string TypeName = "i32";
122      InsertElementInst *Def = dyn_cast<InsertElementInst>(Arg);
123      assert(Def);
124      Args.push_back(Def->getOperand(1));
125      Types.push_back(ElementTy);
126      std::string VecTypeName = "v1" + TypeName;
127      Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName);
128      NeedToReplace = true;
129    } else {
130      Args.push_back(Arg);
131      Types.push_back(Arg->getType());
132    }
133  }
134
135  if (!NeedToReplace) {
136    return;
137  }
138  Function *NewF = Mod->getFunction(Name);
139  if (!NewF) {
140    NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod);
141    NewF->setAttributes(F->getAttributes());
142  }
143  I.replaceAllUsesWith(Builder.CreateCall(NewF, Args));
144  I.eraseFromParent();
145}
146
147void SITypeRewriter::visitBitCast(BitCastInst &I) {
148  IRBuilder<> Builder(&I);
149  if (I.getDestTy() != v4i32) {
150    return;
151  }
152
153  if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) {
154    if (Op->getSrcTy() == v4i32) {
155      I.replaceAllUsesWith(Op->getOperand(0));
156      I.eraseFromParent();
157    }
158  }
159}
160
161FunctionPass *llvm::createSITypeRewriter() {
162  return new SITypeRewriter();
163}
164