1//===-- SITypeRewriter.cpp - Remove unwanted types ------------------------===// 2// 3// The LLVM Compiler Infrastructure 4// 5// This file is distributed under the University of Illinois Open Source 6// License. See LICENSE.TXT for details. 7// 8//===----------------------------------------------------------------------===// 9// 10/// \file 11/// This pass removes performs the following type substitution on all 12/// non-compute shaders: 13/// 14/// v16i8 => i128 15/// - v16i8 is used for constant memory resource descriptors. This type is 16/// legal for some compute APIs, and we don't want to declare it as legal 17/// in the backend, because we want the legalizer to expand all v16i8 18/// operations. 19/// v1* => * 20/// - Having v1* types complicates the legalizer and we can easily replace 21/// - them with the element type. 22//===----------------------------------------------------------------------===// 23 24#include "AMDGPU.h" 25#include "llvm/IR/IRBuilder.h" 26#include "llvm/IR/InstVisitor.h" 27 28using namespace llvm; 29 30namespace { 31 32class SITypeRewriter : public FunctionPass, 33 public InstVisitor<SITypeRewriter> { 34 35 static char ID; 36 Module *Mod; 37 Type *v16i8; 38 Type *v4i32; 39 40public: 41 SITypeRewriter() : FunctionPass(ID) { } 42 bool doInitialization(Module &M) override; 43 bool runOnFunction(Function &F) override; 44 const char *getPassName() const override { 45 return "SI Type Rewriter"; 46 } 47 void visitLoadInst(LoadInst &I); 48 void visitCallInst(CallInst &I); 49 void visitBitCast(BitCastInst &I); 50}; 51 52} // End anonymous namespace 53 54char SITypeRewriter::ID = 0; 55 56bool SITypeRewriter::doInitialization(Module &M) { 57 Mod = &M; 58 v16i8 = VectorType::get(Type::getInt8Ty(M.getContext()), 16); 59 v4i32 = VectorType::get(Type::getInt32Ty(M.getContext()), 4); 60 return false; 61} 62 63bool SITypeRewriter::runOnFunction(Function &F) { 64 Attribute A = F.getFnAttribute("ShaderType"); 65 66 unsigned ShaderType = ShaderType::COMPUTE; 67 if (A.isStringAttribute()) { 68 StringRef Str = A.getValueAsString(); 69 Str.getAsInteger(0, ShaderType); 70 } 71 if (ShaderType == ShaderType::COMPUTE) 72 return false; 73 74 visit(F); 75 visit(F); 76 77 return false; 78} 79 80void SITypeRewriter::visitLoadInst(LoadInst &I) { 81 Value *Ptr = I.getPointerOperand(); 82 Type *PtrTy = Ptr->getType(); 83 Type *ElemTy = PtrTy->getPointerElementType(); 84 IRBuilder<> Builder(&I); 85 if (ElemTy == v16i8) { 86 Value *BitCast = Builder.CreateBitCast(Ptr, 87 PointerType::get(v4i32,PtrTy->getPointerAddressSpace())); 88 LoadInst *Load = Builder.CreateLoad(BitCast); 89 SmallVector<std::pair<unsigned, MDNode *>, 8> MD; 90 I.getAllMetadataOtherThanDebugLoc(MD); 91 for (unsigned i = 0, e = MD.size(); i != e; ++i) { 92 Load->setMetadata(MD[i].first, MD[i].second); 93 } 94 Value *BitCastLoad = Builder.CreateBitCast(Load, I.getType()); 95 I.replaceAllUsesWith(BitCastLoad); 96 I.eraseFromParent(); 97 } 98} 99 100void SITypeRewriter::visitCallInst(CallInst &I) { 101 IRBuilder<> Builder(&I); 102 103 SmallVector <Value*, 8> Args; 104 SmallVector <Type*, 8> Types; 105 bool NeedToReplace = false; 106 Function *F = I.getCalledFunction(); 107 std::string Name = F->getName(); 108 for (unsigned i = 0, e = I.getNumArgOperands(); i != e; ++i) { 109 Value *Arg = I.getArgOperand(i); 110 if (Arg->getType() == v16i8) { 111 Args.push_back(Builder.CreateBitCast(Arg, v4i32)); 112 Types.push_back(v4i32); 113 NeedToReplace = true; 114 Name = Name + ".v4i32"; 115 } else if (Arg->getType()->isVectorTy() && 116 Arg->getType()->getVectorNumElements() == 1 && 117 Arg->getType()->getVectorElementType() == 118 Type::getInt32Ty(I.getContext())){ 119 Type *ElementTy = Arg->getType()->getVectorElementType(); 120 std::string TypeName = "i32"; 121 InsertElementInst *Def = cast<InsertElementInst>(Arg); 122 Args.push_back(Def->getOperand(1)); 123 Types.push_back(ElementTy); 124 std::string VecTypeName = "v1" + TypeName; 125 Name = Name.replace(Name.find(VecTypeName), VecTypeName.length(), TypeName); 126 NeedToReplace = true; 127 } else { 128 Args.push_back(Arg); 129 Types.push_back(Arg->getType()); 130 } 131 } 132 133 if (!NeedToReplace) { 134 return; 135 } 136 Function *NewF = Mod->getFunction(Name); 137 if (!NewF) { 138 NewF = Function::Create(FunctionType::get(F->getReturnType(), Types, false), GlobalValue::ExternalLinkage, Name, Mod); 139 NewF->setAttributes(F->getAttributes()); 140 } 141 I.replaceAllUsesWith(Builder.CreateCall(NewF, Args)); 142 I.eraseFromParent(); 143} 144 145void SITypeRewriter::visitBitCast(BitCastInst &I) { 146 IRBuilder<> Builder(&I); 147 if (I.getDestTy() != v4i32) { 148 return; 149 } 150 151 if (BitCastInst *Op = dyn_cast<BitCastInst>(I.getOperand(0))) { 152 if (Op->getSrcTy() == v4i32) { 153 I.replaceAllUsesWith(Op->getOperand(0)); 154 I.eraseFromParent(); 155 } 156 } 157} 158 159FunctionPass *llvm::createSITypeRewriter() { 160 return new SITypeRewriter(); 161} 162