SITypeRewriter.cpp revision dce4a407a24b04eebc6a376f8e62b41aaa7b071f
1//===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===// 2// 3// The LLVM Compiler Infrastructure 4// 5// This file is distributed under the University of Illinois Open Source 6// License. See LICENSE.TXT for details. 7// 8//===----------------------------------------------------------------------===// 9// 10/// \file 11/// This pass removes performs the following type substitution on all 12/// non-compute shaders: 13/// 14/// v16i8 => i128 15/// - v16i8 is used for constant memory resource descriptors. This type is 16/// legal for some compute APIs, and we don't want to declare it as legal 17/// in the backend, because we want the legalizer to expand all v16i8 18/// operations. 19/// v1* => * 20/// - Having v1* types complicates the legalizer and we can easily replace 21/// - them with the element type. 22//===----------------------------------------------------------------------===// 23 24#include "AMDGPU.h" 25#include "llvm/IR/IRBuilder.h" 26#include "llvm/IR/InstVisitor.h" 27 28using namespace llvm; 29 30namespace { 31 32class SITypeRewriter : public FunctionPass, 33 public InstVisitor<SITypeRewriter> { 34 35 static char ID; 36 Module *Mod; 37 Type *v16i8; 38 Type *v4i32; 39 40public: 41 SITypeRewriter() : FunctionPass(ID) { } 42 bool doInitialization(Module &M) override; 43 bool runOnFunction(Function &F) override; 44 const char *getPassName() const override { 45 return "SI Type Rewriter"; 46 } 47 void visitLoadInst(LoadInst &I); 48 void visitCallInst(CallInst &I); 49 void visitBitCast(BitCastInst &I); 50}; 51 52} // End anonymous namespace 53 54char SITypeRewriter::ID = 0; 55 56bool SITypeRewriter::doInitialization(Module &M) { 57 Mod = &M; 58 v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16); 59 v4i32 = VectorType::get(Type::getInt32Ty(M.getContext()), 4); 60 return false; 61} 62 63bool SITypeRewriter::runOnFunction(Function &F) { 64 AttributeSet Set = F.getAttributes(); 65 Attribute A = Set.getAttribute(AttributeSet::FunctionIndex, "ShaderType"); 66 67 unsigned ShaderType = ShaderType::COMPUTE; 68 if (A.isStringAttribute()) { 69 StringRef Str = A.getValueAsString(); 70 Str.getAsInteger(0, ShaderType); 71 } 72 if (ShaderType == ShaderType::COMPUTE) 73 return false; 74 75 visit(F); 76 visit(F); 77 78 return false; 79} 80 81void SITypeRewriter::visitLoadInst(LoadInst &I) { 82 Value *Ptr = I.getPointerOperand(); 83 Type *PtrTy = Ptr->getType(); 84 Type *ElemTy = PtrTy->getPointerElementType(); 85 IRBuilder<> Builder(&I); 86 if (ElemTy == v16i8) { 87 Value *BitCast = Builder.CreateBitCast(Ptr, 88 PointerType::get(v4i32,PtrTy->getPointerAddressSpace())); 89 LoadInst *Load = Builder.CreateLoad(BitCast); 90 SmallVector <std::pair<unsigned, MDNode*>, 8> MD; 91 I.getAllMetadataOtherThanDebugLoc(MD); 92 for (unsigned i = 0, e = MD.size(); i != e; ++i) { 93 Load->setMetadata(MD[i].first, MD[i].second); 94 } 95 Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType()); 96 I.replaceAllUsesWith(BitCastLoad); 97 I.eraseFromParent(); 98 } 99} 100 101void SITypeRewriter::visitCallInst(CallInst &I) { 102 IRBuilder<> Builder(&I); 103 104 SmallVector <Value*, 8> Args; 105 SmallVector <Type*, 8> Types; 106 bool NeedToReplace = false; 107 Function *F = I.getCalledFunction(); 108 std::string Name = F->getName().str(); 109 for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) { 110 Value *Arg = I.getArgOperand(i); 111 if (Arg->getType() == v16i8) { 112 Args.push_back(Builder.CreateBitCast(Arg, v4i32)); 113 Types.push_back(v4i32); 114 NeedToReplace = true; 115 Name = Name + ".v4i32"; 116 } else if (Arg->getType()->isVectorTy() && 117 Arg->getType()->getVectorNumElements() == 1 && 118 Arg->getType()->getVectorElementType() == 119 Type::getInt32Ty(I.getContext())){ 120 Type *ElementTy = Arg->getType()->getVectorElementType(); 121 std::string TypeName = "i32"; 122 InsertElementInst *Def = dyn_cast<InsertElementInst>(Arg); 123 assert(Def); 124 Args.push_back(Def->getOperand(1)); 125 Types.push_back(ElementTy); 126 std::string VecTypeName = "v1" + TypeName; 127 Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName); 128 NeedToReplace = true; 129 } else { 130 Args.push_back(Arg); 131 Types.push_back(Arg->getType()); 132 } 133 } 134 135 if (!NeedToReplace) { 136 return; 137 } 138 Function *NewF = Mod->getFunction(Name); 139 if (!NewF) { 140 NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod); 141 NewF->setAttributes(F->getAttributes()); 142 } 143 I.replaceAllUsesWith(Builder.CreateCall(NewF, Args)); 144 I.eraseFromParent(); 145} 146 147void SITypeRewriter::visitBitCast(BitCastInst &I) { 148 IRBuilder<> Builder(&I); 149 if (I.getDestTy() != v4i32) { 150 return; 151 } 152 153 if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) { 154 if (Op->getSrcTy() == v4i32) { 155 I.replaceAllUsesWith(Op->getOperand(0)); 156 I.eraseFromParent(); 157 } 158 } 159} 160 161FunctionPass *llvm::createSITypeRewriter() { 162 return new SITypeRewriter(); 163} 164