// CostModel.cpp revision dce4a407a24b04eebc6a376f8e62b41aaa7b071f
1//===- CostModel.cpp ------ Cost Model Analysis ---------------------------===// 2// 3// The LLVM Compiler Infrastructure 4// 5// This file is distributed under the University of Illinois Open Source 6// License. See LICENSE.TXT for details. 7// 8//===----------------------------------------------------------------------===// 9// 10// This file defines the cost model analysis. It provides a very basic cost 11// estimation for LLVM-IR. This analysis uses the services of the codegen 12// to approximate the cost of any IR instruction when lowered to machine 13// instructions. The cost results are unit-less and the cost number represents 14// the throughput of the machine assuming that all loads hit the cache, all 15// branches are predicted, etc. The cost numbers can be added in order to 16// compare two or more transformation alternatives. 17// 18//===----------------------------------------------------------------------===// 19 20#include "llvm/ADT/STLExtras.h" 21#include "llvm/Analysis/Passes.h" 22#include "llvm/Analysis/TargetTransformInfo.h" 23#include "llvm/IR/Function.h" 24#include "llvm/IR/Instructions.h" 25#include "llvm/IR/IntrinsicInst.h" 26#include "llvm/IR/Value.h" 27#include "llvm/Pass.h" 28#include "llvm/Support/CommandLine.h" 29#include "llvm/Support/Debug.h" 30#include "llvm/Support/raw_ostream.h" 31using namespace llvm; 32 33#define CM_NAME "cost-model" 34#define DEBUG_TYPE CM_NAME 35 36static cl::opt<bool> EnableReduxCost("costmodel-reduxcost", cl::init(false), 37 cl::Hidden, 38 cl::desc("Recognize reduction patterns.")); 39 40namespace { 41 class CostModelAnalysis : public FunctionPass { 42 43 public: 44 static char ID; // Class identification, replacement for typeinfo 45 CostModelAnalysis() : FunctionPass(ID), F(nullptr), TTI(nullptr) { 46 initializeCostModelAnalysisPass( 47 *PassRegistry::getPassRegistry()); 48 } 49 50 /// Returns the expected cost of the instruction. 51 /// Returns -1 if the cost is unknown. 
52 /// Note, this method does not cache the cost calculation and it 53 /// can be expensive in some cases. 54 unsigned getInstructionCost(const Instruction *I) const; 55 56 private: 57 void getAnalysisUsage(AnalysisUsage &AU) const override; 58 bool runOnFunction(Function &F) override; 59 void print(raw_ostream &OS, const Module*) const override; 60 61 /// The function that we analyze. 62 Function *F; 63 /// Target information. 64 const TargetTransformInfo *TTI; 65 }; 66} // End of anonymous namespace 67 68// Register this pass. 69char CostModelAnalysis::ID = 0; 70static const char cm_name[] = "Cost Model Analysis"; 71INITIALIZE_PASS_BEGIN(CostModelAnalysis, CM_NAME, cm_name, false, true) 72INITIALIZE_PASS_END (CostModelAnalysis, CM_NAME, cm_name, false, true) 73 74FunctionPass *llvm::createCostModelAnalysisPass() { 75 return new CostModelAnalysis(); 76} 77 78void 79CostModelAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { 80 AU.setPreservesAll(); 81} 82 83bool 84CostModelAnalysis::runOnFunction(Function &F) { 85 this->F = &F; 86 TTI = getAnalysisIfAvailable<TargetTransformInfo>(); 87 88 return false; 89} 90 91static bool isReverseVectorMask(SmallVectorImpl<int> &Mask) { 92 for (unsigned i = 0, MaskSize = Mask.size(); i < MaskSize; ++i) 93 if (Mask[i] > 0 && Mask[i] != (int)(MaskSize - 1 - i)) 94 return false; 95 return true; 96} 97 98static TargetTransformInfo::OperandValueKind getOperandInfo(Value *V) { 99 TargetTransformInfo::OperandValueKind OpInfo = 100 TargetTransformInfo::OK_AnyValue; 101 102 // Check for a splat of a constant or for a non uniform vector of constants. 
103 if (isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) { 104 OpInfo = TargetTransformInfo::OK_NonUniformConstantValue; 105 if (cast<Constant>(V)->getSplatValue() != nullptr) 106 OpInfo = TargetTransformInfo::OK_UniformConstantValue; 107 } 108 109 return OpInfo; 110} 111 112static bool matchPairwiseShuffleMask(ShuffleVectorInst *SI, bool IsLeft, 113 unsigned Level) { 114 // We don't need a shuffle if we just want to have element 0 in position 0 of 115 // the vector. 116 if (!SI && Level == 0 && IsLeft) 117 return true; 118 else if (!SI) 119 return false; 120 121 SmallVector<int, 32> Mask(SI->getType()->getVectorNumElements(), -1); 122 123 // Build a mask of 0, 2, ... (left) or 1, 3, ... (right) depending on whether 124 // we look at the left or right side. 125 for (unsigned i = 0, e = (1 << Level), val = !IsLeft; i != e; ++i, val += 2) 126 Mask[i] = val; 127 128 SmallVector<int, 16> ActualMask = SI->getShuffleMask(); 129 if (Mask != ActualMask) 130 return false; 131 132 return true; 133} 134 135static bool matchPairwiseReductionAtLevel(const BinaryOperator *BinOp, 136 unsigned Level, unsigned NumLevels) { 137 // Match one level of pairwise operations. 
138 // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef, 139 // <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef> 140 // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef, 141 // <4 x i32> <i32 1, i32 3, i32 undef, i32 undef> 142 // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1 143 if (BinOp == nullptr) 144 return false; 145 146 assert(BinOp->getType()->isVectorTy() && "Expecting a vector type"); 147 148 unsigned Opcode = BinOp->getOpcode(); 149 Value *L = BinOp->getOperand(0); 150 Value *R = BinOp->getOperand(1); 151 152 ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(L); 153 if (!LS && Level) 154 return false; 155 ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(R); 156 if (!RS && Level) 157 return false; 158 159 // On level 0 we can omit one shufflevector instruction. 160 if (!Level && !RS && !LS) 161 return false; 162 163 // Shuffle inputs must match. 164 Value *NextLevelOpL = LS ? LS->getOperand(0) : nullptr; 165 Value *NextLevelOpR = RS ? RS->getOperand(0) : nullptr; 166 Value *NextLevelOp = nullptr; 167 if (NextLevelOpR && NextLevelOpL) { 168 // If we have two shuffles their operands must match. 169 if (NextLevelOpL != NextLevelOpR) 170 return false; 171 172 NextLevelOp = NextLevelOpL; 173 } else if (Level == 0 && (NextLevelOpR || NextLevelOpL)) { 174 // On the first level we can omit the shufflevector <0, undef,...>. So the 175 // input to the other shufflevector <1, undef> must match with one of the 176 // inputs to the current binary operation. 177 // Example: 178 // %NextLevelOpL = shufflevector %R, <1, undef ...> 179 // %BinOp = fadd %NextLevelOpL, %R 180 if (NextLevelOpL && NextLevelOpL != R) 181 return false; 182 else if (NextLevelOpR && NextLevelOpR != L) 183 return false; 184 185 NextLevelOp = NextLevelOpL ? R : L; 186 } else 187 return false; 188 189 // Check that the next levels binary operation exists and matches with the 190 // current one. 
191 BinaryOperator *NextLevelBinOp = nullptr; 192 if (Level + 1 != NumLevels) { 193 if (!(NextLevelBinOp = dyn_cast<BinaryOperator>(NextLevelOp))) 194 return false; 195 else if (NextLevelBinOp->getOpcode() != Opcode) 196 return false; 197 } 198 199 // Shuffle mask for pairwise operation must match. 200 if (matchPairwiseShuffleMask(LS, true, Level)) { 201 if (!matchPairwiseShuffleMask(RS, false, Level)) 202 return false; 203 } else if (matchPairwiseShuffleMask(RS, true, Level)) { 204 if (!matchPairwiseShuffleMask(LS, false, Level)) 205 return false; 206 } else 207 return false; 208 209 if (++Level == NumLevels) 210 return true; 211 212 // Match next level. 213 return matchPairwiseReductionAtLevel(NextLevelBinOp, Level, NumLevels); 214} 215 216static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot, 217 unsigned &Opcode, Type *&Ty) { 218 if (!EnableReduxCost) 219 return false; 220 221 // Need to extract the first element. 222 ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1)); 223 unsigned Idx = ~0u; 224 if (CI) 225 Idx = CI->getZExtValue(); 226 if (Idx != 0) 227 return false; 228 229 BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0)); 230 if (!RdxStart) 231 return false; 232 233 Type *VecTy = ReduxRoot->getOperand(0)->getType(); 234 unsigned NumVecElems = VecTy->getVectorNumElements(); 235 if (!isPowerOf2_32(NumVecElems)) 236 return false; 237 238 // We look for a sequence of shuffle,shuffle,add triples like the following 239 // that builds a pairwise reduction tree. 
240 // 241 // (X0, X1, X2, X3) 242 // (X0 + X1, X2 + X3, undef, undef) 243 // ((X0 + X1) + (X2 + X3), undef, undef, undef) 244 // 245 // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef, 246 // <4 x i32> <i32 0, i32 2 , i32 undef, i32 undef> 247 // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef, 248 // <4 x i32> <i32 1, i32 3, i32 undef, i32 undef> 249 // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1 250 // %rdx.shuf.1.0 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef, 251 // <4 x i32> <i32 0, i32 undef, i32 undef, i32 undef> 252 // %rdx.shuf.1.1 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef, 253 // <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef> 254 // %bin.rdx8 = fadd <4 x float> %rdx.shuf.1.0, %rdx.shuf.1.1 255 // %r = extractelement <4 x float> %bin.rdx8, i32 0 256 if (!matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems))) 257 return false; 258 259 Opcode = RdxStart->getOpcode(); 260 Ty = VecTy; 261 262 return true; 263} 264 265static std::pair<Value *, ShuffleVectorInst *> 266getShuffleAndOtherOprd(BinaryOperator *B) { 267 268 Value *L = B->getOperand(0); 269 Value *R = B->getOperand(1); 270 ShuffleVectorInst *S = nullptr; 271 272 if ((S = dyn_cast<ShuffleVectorInst>(L))) 273 return std::make_pair(R, S); 274 275 S = dyn_cast<ShuffleVectorInst>(R); 276 return std::make_pair(L, S); 277} 278 279static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot, 280 unsigned &Opcode, Type *&Ty) { 281 if (!EnableReduxCost) 282 return false; 283 284 // Need to extract the first element. 
285 ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1)); 286 unsigned Idx = ~0u; 287 if (CI) 288 Idx = CI->getZExtValue(); 289 if (Idx != 0) 290 return false; 291 292 BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0)); 293 if (!RdxStart) 294 return false; 295 unsigned RdxOpcode = RdxStart->getOpcode(); 296 297 Type *VecTy = ReduxRoot->getOperand(0)->getType(); 298 unsigned NumVecElems = VecTy->getVectorNumElements(); 299 if (!isPowerOf2_32(NumVecElems)) 300 return false; 301 302 // We look for a sequence of shuffles and adds like the following matching one 303 // fadd, shuffle vector pair at a time. 304 // 305 // %rdx.shuf = shufflevector <4 x float> %rdx, <4 x float> undef, 306 // <4 x i32> <i32 2, i32 3, i32 undef, i32 undef> 307 // %bin.rdx = fadd <4 x float> %rdx, %rdx.shuf 308 // %rdx.shuf7 = shufflevector <4 x float> %bin.rdx, <4 x float> undef, 309 // <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef> 310 // %bin.rdx8 = fadd <4 x float> %bin.rdx, %rdx.shuf7 311 // %r = extractelement <4 x float> %bin.rdx8, i32 0 312 313 unsigned MaskStart = 1; 314 Value *RdxOp = RdxStart; 315 SmallVector<int, 32> ShuffleMask(NumVecElems, 0); 316 unsigned NumVecElemsRemain = NumVecElems; 317 while (NumVecElemsRemain - 1) { 318 // Check for the right reduction operation. 319 BinaryOperator *BinOp; 320 if (!(BinOp = dyn_cast<BinaryOperator>(RdxOp))) 321 return false; 322 if (BinOp->getOpcode() != RdxOpcode) 323 return false; 324 325 Value *NextRdxOp; 326 ShuffleVectorInst *Shuffle; 327 std::tie(NextRdxOp, Shuffle) = getShuffleAndOtherOprd(BinOp); 328 329 // Check the current reduction operation and the shuffle use the same value. 330 if (Shuffle == nullptr) 331 return false; 332 if (Shuffle->getOperand(0) != NextRdxOp) 333 return false; 334 335 // Check that shuffle masks matches. 336 for (unsigned j = 0; j != MaskStart; ++j) 337 ShuffleMask[j] = MaskStart + j; 338 // Fill the rest of the mask with -1 for undef. 
339 std::fill(&ShuffleMask[MaskStart], ShuffleMask.end(), -1); 340 341 SmallVector<int, 16> Mask = Shuffle->getShuffleMask(); 342 if (ShuffleMask != Mask) 343 return false; 344 345 RdxOp = NextRdxOp; 346 NumVecElemsRemain /= 2; 347 MaskStart *= 2; 348 } 349 350 Opcode = RdxOpcode; 351 Ty = VecTy; 352 return true; 353} 354 355unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const { 356 if (!TTI) 357 return -1; 358 359 switch (I->getOpcode()) { 360 case Instruction::GetElementPtr:{ 361 Type *ValTy = I->getOperand(0)->getType()->getPointerElementType(); 362 return TTI->getAddressComputationCost(ValTy); 363 } 364 365 case Instruction::Ret: 366 case Instruction::PHI: 367 case Instruction::Br: { 368 return TTI->getCFInstrCost(I->getOpcode()); 369 } 370 case Instruction::Add: 371 case Instruction::FAdd: 372 case Instruction::Sub: 373 case Instruction::FSub: 374 case Instruction::Mul: 375 case Instruction::FMul: 376 case Instruction::UDiv: 377 case Instruction::SDiv: 378 case Instruction::FDiv: 379 case Instruction::URem: 380 case Instruction::SRem: 381 case Instruction::FRem: 382 case Instruction::Shl: 383 case Instruction::LShr: 384 case Instruction::AShr: 385 case Instruction::And: 386 case Instruction::Or: 387 case Instruction::Xor: { 388 TargetTransformInfo::OperandValueKind Op1VK = 389 getOperandInfo(I->getOperand(0)); 390 TargetTransformInfo::OperandValueKind Op2VK = 391 getOperandInfo(I->getOperand(1)); 392 return TTI->getArithmeticInstrCost(I->getOpcode(), I->getType(), Op1VK, 393 Op2VK); 394 } 395 case Instruction::Select: { 396 const SelectInst *SI = cast<SelectInst>(I); 397 Type *CondTy = SI->getCondition()->getType(); 398 return TTI->getCmpSelInstrCost(I->getOpcode(), I->getType(), CondTy); 399 } 400 case Instruction::ICmp: 401 case Instruction::FCmp: { 402 Type *ValTy = I->getOperand(0)->getType(); 403 return TTI->getCmpSelInstrCost(I->getOpcode(), ValTy); 404 } 405 case Instruction::Store: { 406 const StoreInst *SI = cast<StoreInst>(I); 
407 Type *ValTy = SI->getValueOperand()->getType(); 408 return TTI->getMemoryOpCost(I->getOpcode(), ValTy, 409 SI->getAlignment(), 410 SI->getPointerAddressSpace()); 411 } 412 case Instruction::Load: { 413 const LoadInst *LI = cast<LoadInst>(I); 414 return TTI->getMemoryOpCost(I->getOpcode(), I->getType(), 415 LI->getAlignment(), 416 LI->getPointerAddressSpace()); 417 } 418 case Instruction::ZExt: 419 case Instruction::SExt: 420 case Instruction::FPToUI: 421 case Instruction::FPToSI: 422 case Instruction::FPExt: 423 case Instruction::PtrToInt: 424 case Instruction::IntToPtr: 425 case Instruction::SIToFP: 426 case Instruction::UIToFP: 427 case Instruction::Trunc: 428 case Instruction::FPTrunc: 429 case Instruction::BitCast: 430 case Instruction::AddrSpaceCast: { 431 Type *SrcTy = I->getOperand(0)->getType(); 432 return TTI->getCastInstrCost(I->getOpcode(), I->getType(), SrcTy); 433 } 434 case Instruction::ExtractElement: { 435 const ExtractElementInst * EEI = cast<ExtractElementInst>(I); 436 ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1)); 437 unsigned Idx = -1; 438 if (CI) 439 Idx = CI->getZExtValue(); 440 441 // Try to match a reduction sequence (series of shufflevector and vector 442 // adds followed by a extractelement). 
443 unsigned ReduxOpCode; 444 Type *ReduxType; 445 446 if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType)) 447 return TTI->getReductionCost(ReduxOpCode, ReduxType, false); 448 else if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType)) 449 return TTI->getReductionCost(ReduxOpCode, ReduxType, true); 450 451 return TTI->getVectorInstrCost(I->getOpcode(), 452 EEI->getOperand(0)->getType(), Idx); 453 } 454 case Instruction::InsertElement: { 455 const InsertElementInst * IE = cast<InsertElementInst>(I); 456 ConstantInt *CI = dyn_cast<ConstantInt>(IE->getOperand(2)); 457 unsigned Idx = -1; 458 if (CI) 459 Idx = CI->getZExtValue(); 460 return TTI->getVectorInstrCost(I->getOpcode(), 461 IE->getType(), Idx); 462 } 463 case Instruction::ShuffleVector: { 464 const ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I); 465 Type *VecTypOp0 = Shuffle->getOperand(0)->getType(); 466 unsigned NumVecElems = VecTypOp0->getVectorNumElements(); 467 SmallVector<int, 16> Mask = Shuffle->getShuffleMask(); 468 469 if (NumVecElems == Mask.size() && isReverseVectorMask(Mask)) 470 return TTI->getShuffleCost(TargetTransformInfo::SK_Reverse, VecTypOp0, 0, 471 nullptr); 472 return -1; 473 } 474 case Instruction::Call: 475 if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { 476 SmallVector<Type*, 4> Tys; 477 for (unsigned J = 0, JE = II->getNumArgOperands(); J != JE; ++J) 478 Tys.push_back(II->getArgOperand(J)->getType()); 479 480 return TTI->getIntrinsicInstrCost(II->getIntrinsicID(), II->getType(), 481 Tys); 482 } 483 return -1; 484 default: 485 // We don't have any information on this instruction. 
486 return -1; 487 } 488} 489 490void CostModelAnalysis::print(raw_ostream &OS, const Module*) const { 491 if (!F) 492 return; 493 494 for (Function::iterator B = F->begin(), BE = F->end(); B != BE; ++B) { 495 for (BasicBlock::iterator it = B->begin(), e = B->end(); it != e; ++it) { 496 Instruction *Inst = it; 497 unsigned Cost = getInstructionCost(Inst); 498 if (Cost != (unsigned)-1) 499 OS << "Cost Model: Found an estimated cost of " << Cost; 500 else 501 OS << "Cost Model: Unknown cost"; 502 503 OS << " for instruction: "<< *Inst << "\n"; 504 } 505 } 506} 507