NVPTXISelLowering.cpp revision 4d748eb0e4b55262619305c96a89c55c30bffe6c
1// 2// The LLVM Compiler Infrastructure 3// 4// This file is distributed under the University of Illinois Open Source 5// License. See LICENSE.TXT for details. 6// 7//===----------------------------------------------------------------------===// 8// 9// This file defines the interfaces that NVPTX uses to lower LLVM code into a 10// selection DAG. 11// 12//===----------------------------------------------------------------------===// 13 14#include "NVPTXISelLowering.h" 15#include "NVPTX.h" 16#include "NVPTXTargetMachine.h" 17#include "NVPTXTargetObjectFile.h" 18#include "NVPTXUtilities.h" 19#include "llvm/CodeGen/Analysis.h" 20#include "llvm/CodeGen/MachineFrameInfo.h" 21#include "llvm/CodeGen/MachineFunction.h" 22#include "llvm/CodeGen/MachineInstrBuilder.h" 23#include "llvm/CodeGen/MachineRegisterInfo.h" 24#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h" 25#include "llvm/IR/DerivedTypes.h" 26#include "llvm/IR/Function.h" 27#include "llvm/IR/GlobalValue.h" 28#include "llvm/IR/IntrinsicInst.h" 29#include "llvm/IR/Intrinsics.h" 30#include "llvm/IR/Module.h" 31#include "llvm/MC/MCSectionELF.h" 32#include "llvm/Support/CallSite.h" 33#include "llvm/Support/CommandLine.h" 34#include "llvm/Support/Debug.h" 35#include "llvm/Support/ErrorHandling.h" 36#include "llvm/Support/raw_ostream.h" 37#include <sstream> 38 39#undef DEBUG_TYPE 40#define DEBUG_TYPE "nvptx-lower" 41 42using namespace llvm; 43 44static unsigned int uniqueCallSite = 0; 45 46static cl::opt<bool> sched4reg( 47 "nvptx-sched4reg", 48 cl::desc("NVPTX Specific: schedule for register pressue"), cl::init(false)); 49 50static bool IsPTXVectorType(MVT VT) { 51 switch (VT.SimpleTy) { 52 default: 53 return false; 54 case MVT::v2i1: 55 case MVT::v4i1: 56 case MVT::v2i8: 57 case MVT::v4i8: 58 case MVT::v2i16: 59 case MVT::v4i16: 60 case MVT::v2i32: 61 case MVT::v4i32: 62 case MVT::v2i64: 63 case MVT::v2f32: 64 case MVT::v4f32: 65 case MVT::v2f64: 66 return true; 67 } 68} 69 70/// ComputePTXValueVTs - For the 
given Type \p Ty, returns the set of primitive 71/// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors 72/// into their primitive components. 73/// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the 74/// same number of types as the Ins/Outs arrays in LowerFormalArguments, 75/// LowerCall, and LowerReturn. 76static void ComputePTXValueVTs(const TargetLowering &TLI, Type *Ty, 77 SmallVectorImpl<EVT> &ValueVTs, 78 SmallVectorImpl<uint64_t> *Offsets = 0, 79 uint64_t StartingOffset = 0) { 80 SmallVector<EVT, 16> TempVTs; 81 SmallVector<uint64_t, 16> TempOffsets; 82 83 ComputeValueVTs(TLI, Ty, TempVTs, &TempOffsets, StartingOffset); 84 for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) { 85 EVT VT = TempVTs[i]; 86 uint64_t Off = TempOffsets[i]; 87 if (VT.isVector()) 88 for (unsigned j = 0, je = VT.getVectorNumElements(); j != je; ++j) { 89 ValueVTs.push_back(VT.getVectorElementType()); 90 if (Offsets) 91 Offsets->push_back(Off+j*VT.getVectorElementType().getStoreSize()); 92 } 93 else { 94 ValueVTs.push_back(VT); 95 if (Offsets) 96 Offsets->push_back(Off); 97 } 98 } 99} 100 101// NVPTXTargetLowering Constructor. 102NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM) 103 : TargetLowering(TM, new NVPTXTargetObjectFile()), nvTM(&TM), 104 nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) { 105 106 // always lower memset, memcpy, and memmove intrinsics to load/store 107 // instructions, rather 108 // then generating calls to memset, mempcy or memmove. 109 MaxStoresPerMemset = (unsigned) 0xFFFFFFFF; 110 MaxStoresPerMemcpy = (unsigned) 0xFFFFFFFF; 111 MaxStoresPerMemmove = (unsigned) 0xFFFFFFFF; 112 113 setBooleanContents(ZeroOrNegativeOneBooleanContent); 114 115 // Jump is Expensive. Don't create extra control flow for 'and', 'or' 116 // condition branches. 
117 setJumpIsExpensive(true); 118 119 // By default, use the Source scheduling 120 if (sched4reg) 121 setSchedulingPreference(Sched::RegPressure); 122 else 123 setSchedulingPreference(Sched::Source); 124 125 addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass); 126 addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass); 127 addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass); 128 addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass); 129 addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass); 130 addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass); 131 132 // Operations not directly supported by NVPTX. 133 setOperationAction(ISD::SELECT_CC, MVT::Other, Expand); 134 setOperationAction(ISD::BR_CC, MVT::f32, Expand); 135 setOperationAction(ISD::BR_CC, MVT::f64, Expand); 136 setOperationAction(ISD::BR_CC, MVT::i1, Expand); 137 setOperationAction(ISD::BR_CC, MVT::i8, Expand); 138 setOperationAction(ISD::BR_CC, MVT::i16, Expand); 139 setOperationAction(ISD::BR_CC, MVT::i32, Expand); 140 setOperationAction(ISD::BR_CC, MVT::i64, Expand); 141 // Some SIGN_EXTEND_INREG can be done using cvt instruction. 142 // For others we will expand to a SHL/SRA pair. 
143 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Legal); 144 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Legal); 145 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Legal); 146 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8 , Legal); 147 setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand); 148 149 if (nvptxSubtarget.hasROT64()) { 150 setOperationAction(ISD::ROTL, MVT::i64, Legal); 151 setOperationAction(ISD::ROTR, MVT::i64, Legal); 152 } else { 153 setOperationAction(ISD::ROTL, MVT::i64, Expand); 154 setOperationAction(ISD::ROTR, MVT::i64, Expand); 155 } 156 if (nvptxSubtarget.hasROT32()) { 157 setOperationAction(ISD::ROTL, MVT::i32, Legal); 158 setOperationAction(ISD::ROTR, MVT::i32, Legal); 159 } else { 160 setOperationAction(ISD::ROTL, MVT::i32, Expand); 161 setOperationAction(ISD::ROTR, MVT::i32, Expand); 162 } 163 164 setOperationAction(ISD::ROTL, MVT::i16, Expand); 165 setOperationAction(ISD::ROTR, MVT::i16, Expand); 166 setOperationAction(ISD::ROTL, MVT::i8, Expand); 167 setOperationAction(ISD::ROTR, MVT::i8, Expand); 168 setOperationAction(ISD::BSWAP, MVT::i16, Expand); 169 setOperationAction(ISD::BSWAP, MVT::i32, Expand); 170 setOperationAction(ISD::BSWAP, MVT::i64, Expand); 171 172 // Indirect branch is not supported. 173 // This also disables Jump Table creation. 174 setOperationAction(ISD::BR_JT, MVT::Other, Expand); 175 setOperationAction(ISD::BRIND, MVT::Other, Expand); 176 177 setOperationAction(ISD::GlobalAddress, MVT::i32, Custom); 178 setOperationAction(ISD::GlobalAddress, MVT::i64, Custom); 179 180 // We want to legalize constant related memmove and memcopy 181 // intrinsics. 182 setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom); 183 184 // Turn FP extload into load/fextend 185 setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand); 186 // Turn FP truncstore into trunc + store. 
187 setTruncStoreAction(MVT::f64, MVT::f32, Expand); 188 189 // PTX does not support load / store predicate registers 190 setOperationAction(ISD::LOAD, MVT::i1, Custom); 191 setOperationAction(ISD::STORE, MVT::i1, Custom); 192 193 setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote); 194 setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote); 195 setTruncStoreAction(MVT::i64, MVT::i1, Expand); 196 setTruncStoreAction(MVT::i32, MVT::i1, Expand); 197 setTruncStoreAction(MVT::i16, MVT::i1, Expand); 198 setTruncStoreAction(MVT::i8, MVT::i1, Expand); 199 200 // This is legal in NVPTX 201 setOperationAction(ISD::ConstantFP, MVT::f64, Legal); 202 setOperationAction(ISD::ConstantFP, MVT::f32, Legal); 203 204 // TRAP can be lowered to PTX trap 205 setOperationAction(ISD::TRAP, MVT::Other, Legal); 206 207 setOperationAction(ISD::ADDC, MVT::i64, Expand); 208 setOperationAction(ISD::ADDE, MVT::i64, Expand); 209 210 // Register custom handling for vector loads/stores 211 for (int i = MVT::FIRST_VECTOR_VALUETYPE; i <= MVT::LAST_VECTOR_VALUETYPE; 212 ++i) { 213 MVT VT = (MVT::SimpleValueType) i; 214 if (IsPTXVectorType(VT)) { 215 setOperationAction(ISD::LOAD, VT, Custom); 216 setOperationAction(ISD::STORE, VT, Custom); 217 setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom); 218 } 219 } 220 221 // Custom handling for i8 intrinsics 222 setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom); 223 224 setOperationAction(ISD::CTLZ, MVT::i16, Legal); 225 setOperationAction(ISD::CTLZ, MVT::i32, Legal); 226 setOperationAction(ISD::CTLZ, MVT::i64, Legal); 227 setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i16, Legal); 228 setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i32, Legal); 229 setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i64, Legal); 230 setOperationAction(ISD::CTTZ, MVT::i16, Expand); 231 setOperationAction(ISD::CTTZ, MVT::i32, Expand); 232 setOperationAction(ISD::CTTZ, MVT::i64, Expand); 233 setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i16, Expand); 234 
setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i32, Expand); 235 setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i64, Expand); 236 setOperationAction(ISD::CTPOP, MVT::i16, Legal); 237 setOperationAction(ISD::CTPOP, MVT::i32, Legal); 238 setOperationAction(ISD::CTPOP, MVT::i64, Legal); 239 240 // Now deduce the information based on the above mentioned 241 // actions 242 computeRegisterProperties(); 243} 244 245const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const { 246 switch (Opcode) { 247 default: 248 return 0; 249 case NVPTXISD::CALL: 250 return "NVPTXISD::CALL"; 251 case NVPTXISD::RET_FLAG: 252 return "NVPTXISD::RET_FLAG"; 253 case NVPTXISD::Wrapper: 254 return "NVPTXISD::Wrapper"; 255 case NVPTXISD::DeclareParam: 256 return "NVPTXISD::DeclareParam"; 257 case NVPTXISD::DeclareScalarParam: 258 return "NVPTXISD::DeclareScalarParam"; 259 case NVPTXISD::DeclareRet: 260 return "NVPTXISD::DeclareRet"; 261 case NVPTXISD::DeclareRetParam: 262 return "NVPTXISD::DeclareRetParam"; 263 case NVPTXISD::PrintCall: 264 return "NVPTXISD::PrintCall"; 265 case NVPTXISD::LoadParam: 266 return "NVPTXISD::LoadParam"; 267 case NVPTXISD::LoadParamV2: 268 return "NVPTXISD::LoadParamV2"; 269 case NVPTXISD::LoadParamV4: 270 return "NVPTXISD::LoadParamV4"; 271 case NVPTXISD::StoreParam: 272 return "NVPTXISD::StoreParam"; 273 case NVPTXISD::StoreParamV2: 274 return "NVPTXISD::StoreParamV2"; 275 case NVPTXISD::StoreParamV4: 276 return "NVPTXISD::StoreParamV4"; 277 case NVPTXISD::StoreParamS32: 278 return "NVPTXISD::StoreParamS32"; 279 case NVPTXISD::StoreParamU32: 280 return "NVPTXISD::StoreParamU32"; 281 case NVPTXISD::CallArgBegin: 282 return "NVPTXISD::CallArgBegin"; 283 case NVPTXISD::CallArg: 284 return "NVPTXISD::CallArg"; 285 case NVPTXISD::LastCallArg: 286 return "NVPTXISD::LastCallArg"; 287 case NVPTXISD::CallArgEnd: 288 return "NVPTXISD::CallArgEnd"; 289 case NVPTXISD::CallVoid: 290 return "NVPTXISD::CallVoid"; 291 case NVPTXISD::CallVal: 292 return 
"NVPTXISD::CallVal"; 293 case NVPTXISD::CallSymbol: 294 return "NVPTXISD::CallSymbol"; 295 case NVPTXISD::Prototype: 296 return "NVPTXISD::Prototype"; 297 case NVPTXISD::MoveParam: 298 return "NVPTXISD::MoveParam"; 299 case NVPTXISD::StoreRetval: 300 return "NVPTXISD::StoreRetval"; 301 case NVPTXISD::StoreRetvalV2: 302 return "NVPTXISD::StoreRetvalV2"; 303 case NVPTXISD::StoreRetvalV4: 304 return "NVPTXISD::StoreRetvalV4"; 305 case NVPTXISD::PseudoUseParam: 306 return "NVPTXISD::PseudoUseParam"; 307 case NVPTXISD::RETURN: 308 return "NVPTXISD::RETURN"; 309 case NVPTXISD::CallSeqBegin: 310 return "NVPTXISD::CallSeqBegin"; 311 case NVPTXISD::CallSeqEnd: 312 return "NVPTXISD::CallSeqEnd"; 313 case NVPTXISD::CallPrototype: 314 return "NVPTXISD::CallPrototype"; 315 case NVPTXISD::LoadV2: 316 return "NVPTXISD::LoadV2"; 317 case NVPTXISD::LoadV4: 318 return "NVPTXISD::LoadV4"; 319 case NVPTXISD::LDGV2: 320 return "NVPTXISD::LDGV2"; 321 case NVPTXISD::LDGV4: 322 return "NVPTXISD::LDGV4"; 323 case NVPTXISD::LDUV2: 324 return "NVPTXISD::LDUV2"; 325 case NVPTXISD::LDUV4: 326 return "NVPTXISD::LDUV4"; 327 case NVPTXISD::StoreV2: 328 return "NVPTXISD::StoreV2"; 329 case NVPTXISD::StoreV4: 330 return "NVPTXISD::StoreV4"; 331 } 332} 333 334bool NVPTXTargetLowering::shouldSplitVectorElementType(EVT VT) const { 335 return VT == MVT::i1; 336} 337 338SDValue 339NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const { 340 SDLoc dl(Op); 341 const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal(); 342 Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy()); 343 return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op); 344} 345 346std::string 347NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args, 348 const SmallVectorImpl<ISD::OutputArg> &Outs, 349 unsigned retAlignment, 350 const ImmutableCallSite *CS) const { 351 352 bool isABI = (nvptxSubtarget.getSmVersion() >= 20); 353 assert(isABI && "Non-ABI compilation is not supported"); 
354 if (!isABI) 355 return ""; 356 357 std::stringstream O; 358 O << "prototype_" << uniqueCallSite << " : .callprototype "; 359 360 if (retTy->getTypeID() == Type::VoidTyID) { 361 O << "()"; 362 } else { 363 O << "("; 364 if (retTy->isPrimitiveType() || retTy->isIntegerTy()) { 365 unsigned size = 0; 366 if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) { 367 size = ITy->getBitWidth(); 368 if (size < 32) 369 size = 32; 370 } else { 371 assert(retTy->isFloatingPointTy() && 372 "Floating point type expected here"); 373 size = retTy->getPrimitiveSizeInBits(); 374 } 375 376 O << ".param .b" << size << " _"; 377 } else if (isa<PointerType>(retTy)) { 378 O << ".param .b" << getPointerTy().getSizeInBits() << " _"; 379 } else { 380 if ((retTy->getTypeID() == Type::StructTyID) || isa<VectorType>(retTy)) { 381 SmallVector<EVT, 16> vtparts; 382 ComputeValueVTs(*this, retTy, vtparts); 383 unsigned totalsz = 0; 384 for (unsigned i = 0, e = vtparts.size(); i != e; ++i) { 385 unsigned elems = 1; 386 EVT elemtype = vtparts[i]; 387 if (vtparts[i].isVector()) { 388 elems = vtparts[i].getVectorNumElements(); 389 elemtype = vtparts[i].getVectorElementType(); 390 } 391 // TODO: no need to loop 392 for (unsigned j = 0, je = elems; j != je; ++j) { 393 unsigned sz = elemtype.getSizeInBits(); 394 if (elemtype.isInteger() && (sz < 8)) 395 sz = 8; 396 totalsz += sz / 8; 397 } 398 } 399 O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]"; 400 } else { 401 assert(false && "Unknown return type"); 402 } 403 } 404 O << ") "; 405 } 406 O << "_ ("; 407 408 bool first = true; 409 MVT thePointerTy = getPointerTy(); 410 411 unsigned OIdx = 0; 412 for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) { 413 Type *Ty = Args[i].Ty; 414 if (!first) { 415 O << ", "; 416 } 417 first = false; 418 419 if (Outs[OIdx].Flags.isByVal() == false) { 420 if (Ty->isAggregateType() || Ty->isVectorTy()) { 421 unsigned align = 0; 422 const CallInst *CallI = 
cast<CallInst>(CS->getInstruction()); 423 const DataLayout *TD = getDataLayout(); 424 // +1 because index 0 is reserved for return type alignment 425 if (!llvm::getAlign(*CallI, i + 1, align)) 426 align = TD->getABITypeAlignment(Ty); 427 unsigned sz = TD->getTypeAllocSize(Ty); 428 O << ".param .align " << align << " .b8 "; 429 O << "_"; 430 O << "[" << sz << "]"; 431 // update the index for Outs 432 SmallVector<EVT, 16> vtparts; 433 ComputeValueVTs(*this, Ty, vtparts); 434 if (unsigned len = vtparts.size()) 435 OIdx += len - 1; 436 continue; 437 } 438 // i8 types in IR will be i16 types in SDAG 439 assert((getValueType(Ty) == Outs[OIdx].VT || 440 (getValueType(Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) && 441 "type mismatch between callee prototype and arguments"); 442 // scalar type 443 unsigned sz = 0; 444 if (isa<IntegerType>(Ty)) { 445 sz = cast<IntegerType>(Ty)->getBitWidth(); 446 if (sz < 32) 447 sz = 32; 448 } else if (isa<PointerType>(Ty)) 449 sz = thePointerTy.getSizeInBits(); 450 else 451 sz = Ty->getPrimitiveSizeInBits(); 452 O << ".param .b" << sz << " "; 453 O << "_"; 454 continue; 455 } 456 const PointerType *PTy = dyn_cast<PointerType>(Ty); 457 assert(PTy && "Param with byval attribute should be a pointer type"); 458 Type *ETy = PTy->getElementType(); 459 460 unsigned align = Outs[OIdx].Flags.getByValAlign(); 461 unsigned sz = getDataLayout()->getTypeAllocSize(ETy); 462 O << ".param .align " << align << " .b8 "; 463 O << "_"; 464 O << "[" << sz << "]"; 465 } 466 O << ");"; 467 return O.str(); 468} 469 470unsigned 471NVPTXTargetLowering::getArgumentAlignment(SDValue Callee, 472 const ImmutableCallSite *CS, 473 Type *Ty, 474 unsigned Idx) const { 475 const DataLayout *TD = getDataLayout(); 476 unsigned Align = 0; 477 const Value *DirectCallee = CS->getCalledFunction(); 478 479 if (!DirectCallee) { 480 // We don't have a direct function symbol, but that may be because of 481 // constant cast instructions in the call. 
482 const Instruction *CalleeI = CS->getInstruction(); 483 assert(CalleeI && "Call target is not a function or derived value?"); 484 485 // With bitcast'd call targets, the instruction will be the call 486 if (isa<CallInst>(CalleeI)) { 487 // Check if we have call alignment metadata 488 if (llvm::getAlign(*cast<CallInst>(CalleeI), Idx, Align)) 489 return Align; 490 491 const Value *CalleeV = cast<CallInst>(CalleeI)->getCalledValue(); 492 // Ignore any bitcast instructions 493 while(isa<ConstantExpr>(CalleeV)) { 494 const ConstantExpr *CE = cast<ConstantExpr>(CalleeV); 495 if (!CE->isCast()) 496 break; 497 // Look through the bitcast 498 CalleeV = cast<ConstantExpr>(CalleeV)->getOperand(0); 499 } 500 501 // We have now looked past all of the bitcasts. Do we finally have a 502 // Function? 503 if (isa<Function>(CalleeV)) 504 DirectCallee = CalleeV; 505 } 506 } 507 508 // Check for function alignment information if we found that the 509 // ultimate target is a Function 510 if (DirectCallee) 511 if (llvm::getAlign(*cast<Function>(DirectCallee), Idx, Align)) 512 return Align; 513 514 // Call is indirect or alignment information is not available, fall back to 515 // the ABI type alignment 516 return TD->getABITypeAlignment(Ty); 517} 518 519SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI, 520 SmallVectorImpl<SDValue> &InVals) const { 521 SelectionDAG &DAG = CLI.DAG; 522 SDLoc dl = CLI.DL; 523 SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs; 524 SmallVectorImpl<SDValue> &OutVals = CLI.OutVals; 525 SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins; 526 SDValue Chain = CLI.Chain; 527 SDValue Callee = CLI.Callee; 528 bool &isTailCall = CLI.IsTailCall; 529 ArgListTy &Args = CLI.Args; 530 Type *retTy = CLI.RetTy; 531 ImmutableCallSite *CS = CLI.CS; 532 533 bool isABI = (nvptxSubtarget.getSmVersion() >= 20); 534 assert(isABI && "Non-ABI compilation is not supported"); 535 if (!isABI) 536 return Chain; 537 const DataLayout *TD = getDataLayout(); 538 
MachineFunction &MF = DAG.getMachineFunction(); 539 const Function *F = MF.getFunction(); 540 541 SDValue tempChain = Chain; 542 Chain = 543 DAG.getCALLSEQ_START(Chain, DAG.getIntPtrConstant(uniqueCallSite, true), 544 dl); 545 SDValue InFlag = Chain.getValue(1); 546 547 unsigned paramCount = 0; 548 // Args.size() and Outs.size() need not match. 549 // Outs.size() will be larger 550 // * if there is an aggregate argument with multiple fields (each field 551 // showing up separately in Outs) 552 // * if there is a vector argument with more than typical vector-length 553 // elements (generally if more than 4) where each vector element is 554 // individually present in Outs. 555 // So a different index should be used for indexing into Outs/OutVals. 556 // See similar issue in LowerFormalArguments. 557 unsigned OIdx = 0; 558 // Declare the .params or .reg need to pass values 559 // to the function 560 for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) { 561 EVT VT = Outs[OIdx].VT; 562 Type *Ty = Args[i].Ty; 563 564 if (Outs[OIdx].Flags.isByVal() == false) { 565 if (Ty->isAggregateType()) { 566 // aggregate 567 SmallVector<EVT, 16> vtparts; 568 ComputeValueVTs(*this, Ty, vtparts); 569 570 unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1); 571 // declare .param .align <align> .b8 .param<n>[<size>]; 572 unsigned sz = TD->getTypeAllocSize(Ty); 573 SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); 574 SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32), 575 DAG.getConstant(paramCount, MVT::i32), 576 DAG.getConstant(sz, MVT::i32), InFlag }; 577 Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs, 578 DeclareParamOps, 5); 579 InFlag = Chain.getValue(1); 580 unsigned curOffset = 0; 581 for (unsigned j = 0, je = vtparts.size(); j != je; ++j) { 582 unsigned elems = 1; 583 EVT elemtype = vtparts[j]; 584 if (vtparts[j].isVector()) { 585 elems = vtparts[j].getVectorNumElements(); 586 elemtype = 
vtparts[j].getVectorElementType(); 587 } 588 for (unsigned k = 0, ke = elems; k != ke; ++k) { 589 unsigned sz = elemtype.getSizeInBits(); 590 if (elemtype.isInteger() && (sz < 8)) 591 sz = 8; 592 SDValue StVal = OutVals[OIdx]; 593 if (elemtype.getSizeInBits() < 16) { 594 StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal); 595 } 596 SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); 597 SDValue CopyParamOps[] = { Chain, 598 DAG.getConstant(paramCount, MVT::i32), 599 DAG.getConstant(curOffset, MVT::i32), 600 StVal, InFlag }; 601 Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, 602 CopyParamVTs, &CopyParamOps[0], 5, 603 elemtype, MachinePointerInfo()); 604 InFlag = Chain.getValue(1); 605 curOffset += sz / 8; 606 ++OIdx; 607 } 608 } 609 if (vtparts.size() > 0) 610 --OIdx; 611 ++paramCount; 612 continue; 613 } 614 if (Ty->isVectorTy()) { 615 EVT ObjectVT = getValueType(Ty); 616 unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1); 617 // declare .param .align <align> .b8 .param<n>[<size>]; 618 unsigned sz = TD->getTypeAllocSize(Ty); 619 SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); 620 SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32), 621 DAG.getConstant(paramCount, MVT::i32), 622 DAG.getConstant(sz, MVT::i32), InFlag }; 623 Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs, 624 DeclareParamOps, 5); 625 InFlag = Chain.getValue(1); 626 unsigned NumElts = ObjectVT.getVectorNumElements(); 627 EVT EltVT = ObjectVT.getVectorElementType(); 628 EVT MemVT = EltVT; 629 bool NeedExtend = false; 630 if (EltVT.getSizeInBits() < 16) { 631 NeedExtend = true; 632 EltVT = MVT::i16; 633 } 634 635 // V1 store 636 if (NumElts == 1) { 637 SDValue Elt = OutVals[OIdx++]; 638 if (NeedExtend) 639 Elt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt); 640 641 SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); 642 SDValue CopyParamOps[] = { Chain, 643 DAG.getConstant(paramCount, 
MVT::i32), 644 DAG.getConstant(0, MVT::i32), Elt, 645 InFlag }; 646 Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, 647 CopyParamVTs, &CopyParamOps[0], 5, 648 MemVT, MachinePointerInfo()); 649 InFlag = Chain.getValue(1); 650 } else if (NumElts == 2) { 651 SDValue Elt0 = OutVals[OIdx++]; 652 SDValue Elt1 = OutVals[OIdx++]; 653 if (NeedExtend) { 654 Elt0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt0); 655 Elt1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt1); 656 } 657 658 SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); 659 SDValue CopyParamOps[] = { Chain, 660 DAG.getConstant(paramCount, MVT::i32), 661 DAG.getConstant(0, MVT::i32), Elt0, Elt1, 662 InFlag }; 663 Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParamV2, dl, 664 CopyParamVTs, &CopyParamOps[0], 6, 665 MemVT, MachinePointerInfo()); 666 InFlag = Chain.getValue(1); 667 } else { 668 unsigned curOffset = 0; 669 // V4 stores 670 // We have at least 4 elements (<3 x Ty> expands to 4 elements) and 671 // the 672 // vector will be expanded to a power of 2 elements, so we know we can 673 // always round up to the next multiple of 4 when creating the vector 674 // stores. 675 // e.g. 4 elem => 1 st.v4 676 // 6 elem => 2 st.v4 677 // 8 elem => 2 st.v4 678 // 11 elem => 3 st.v4 679 unsigned VecSize = 4; 680 if (EltVT.getSizeInBits() == 64) 681 VecSize = 2; 682 683 // This is potentially only part of a vector, so assume all elements 684 // are packed together. 
685 unsigned PerStoreOffset = MemVT.getStoreSizeInBits() / 8 * VecSize; 686 687 for (unsigned i = 0; i < NumElts; i += VecSize) { 688 // Get values 689 SDValue StoreVal; 690 SmallVector<SDValue, 8> Ops; 691 Ops.push_back(Chain); 692 Ops.push_back(DAG.getConstant(paramCount, MVT::i32)); 693 Ops.push_back(DAG.getConstant(curOffset, MVT::i32)); 694 695 unsigned Opc = NVPTXISD::StoreParamV2; 696 697 StoreVal = OutVals[OIdx++]; 698 if (NeedExtend) 699 StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal); 700 Ops.push_back(StoreVal); 701 702 if (i + 1 < NumElts) { 703 StoreVal = OutVals[OIdx++]; 704 if (NeedExtend) 705 StoreVal = 706 DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal); 707 } else { 708 StoreVal = DAG.getUNDEF(EltVT); 709 } 710 Ops.push_back(StoreVal); 711 712 if (VecSize == 4) { 713 Opc = NVPTXISD::StoreParamV4; 714 if (i + 2 < NumElts) { 715 StoreVal = OutVals[OIdx++]; 716 if (NeedExtend) 717 StoreVal = 718 DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal); 719 } else { 720 StoreVal = DAG.getUNDEF(EltVT); 721 } 722 Ops.push_back(StoreVal); 723 724 if (i + 3 < NumElts) { 725 StoreVal = OutVals[OIdx++]; 726 if (NeedExtend) 727 StoreVal = 728 DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal); 729 } else { 730 StoreVal = DAG.getUNDEF(EltVT); 731 } 732 Ops.push_back(StoreVal); 733 } 734 735 Ops.push_back(InFlag); 736 737 SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); 738 Chain = DAG.getMemIntrinsicNode(Opc, dl, CopyParamVTs, &Ops[0], 739 Ops.size(), MemVT, 740 MachinePointerInfo()); 741 InFlag = Chain.getValue(1); 742 curOffset += PerStoreOffset; 743 } 744 } 745 ++paramCount; 746 --OIdx; 747 continue; 748 } 749 // Plain scalar 750 // for ABI, declare .param .b<size> .param<n>; 751 unsigned sz = VT.getSizeInBits(); 752 bool needExtend = false; 753 if (VT.isInteger()) { 754 if (sz < 16) 755 needExtend = true; 756 if (sz < 32) 757 sz = 32; 758 } 759 SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); 760 SDValue 
DeclareParamOps[] = { Chain, 761 DAG.getConstant(paramCount, MVT::i32), 762 DAG.getConstant(sz, MVT::i32), 763 DAG.getConstant(0, MVT::i32), InFlag }; 764 Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs, 765 DeclareParamOps, 5); 766 InFlag = Chain.getValue(1); 767 SDValue OutV = OutVals[OIdx]; 768 if (needExtend) { 769 // zext/sext i1 to i16 770 unsigned opc = ISD::ZERO_EXTEND; 771 if (Outs[OIdx].Flags.isSExt()) 772 opc = ISD::SIGN_EXTEND; 773 OutV = DAG.getNode(opc, dl, MVT::i16, OutV); 774 } 775 SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); 776 SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32), 777 DAG.getConstant(0, MVT::i32), OutV, InFlag }; 778 779 unsigned opcode = NVPTXISD::StoreParam; 780 if (Outs[OIdx].Flags.isZExt()) 781 opcode = NVPTXISD::StoreParamU32; 782 else if (Outs[OIdx].Flags.isSExt()) 783 opcode = NVPTXISD::StoreParamS32; 784 Chain = DAG.getMemIntrinsicNode(opcode, dl, CopyParamVTs, CopyParamOps, 5, 785 VT, MachinePointerInfo()); 786 787 InFlag = Chain.getValue(1); 788 ++paramCount; 789 continue; 790 } 791 // struct or vector 792 SmallVector<EVT, 16> vtparts; 793 const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty); 794 assert(PTy && "Type of a byval parameter should be pointer"); 795 ComputeValueVTs(*this, PTy->getElementType(), vtparts); 796 797 // declare .param .align <align> .b8 .param<n>[<size>]; 798 unsigned sz = Outs[OIdx].Flags.getByValSize(); 799 SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); 800 // The ByValAlign in the Outs[OIdx].Flags is alway set at this point, 801 // so we don't need to worry about natural alignment or not. 802 // See TargetLowering::LowerCallTo(). 
803 SDValue DeclareParamOps[] = { 804 Chain, DAG.getConstant(Outs[OIdx].Flags.getByValAlign(), MVT::i32), 805 DAG.getConstant(paramCount, MVT::i32), DAG.getConstant(sz, MVT::i32), 806 InFlag 807 }; 808 Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs, 809 DeclareParamOps, 5); 810 InFlag = Chain.getValue(1); 811 unsigned curOffset = 0; 812 for (unsigned j = 0, je = vtparts.size(); j != je; ++j) { 813 unsigned elems = 1; 814 EVT elemtype = vtparts[j]; 815 if (vtparts[j].isVector()) { 816 elems = vtparts[j].getVectorNumElements(); 817 elemtype = vtparts[j].getVectorElementType(); 818 } 819 for (unsigned k = 0, ke = elems; k != ke; ++k) { 820 unsigned sz = elemtype.getSizeInBits(); 821 if (elemtype.isInteger() && (sz < 8)) 822 sz = 8; 823 SDValue srcAddr = 824 DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx], 825 DAG.getConstant(curOffset, getPointerTy())); 826 SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr, 827 MachinePointerInfo(), false, false, false, 828 0); 829 if (elemtype.getSizeInBits() < 16) { 830 theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal); 831 } 832 SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); 833 SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32), 834 DAG.getConstant(curOffset, MVT::i32), theVal, 835 InFlag }; 836 Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs, 837 CopyParamOps, 5, elemtype, 838 MachinePointerInfo()); 839 840 InFlag = Chain.getValue(1); 841 curOffset += sz / 8; 842 } 843 } 844 ++paramCount; 845 } 846 847 GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode()); 848 unsigned retAlignment = 0; 849 850 // Handle Result 851 if (Ins.size() > 0) { 852 SmallVector<EVT, 16> resvtparts; 853 ComputeValueVTs(*this, retTy, resvtparts); 854 855 // Declare 856 // .param .align 16 .b8 retval0[<size-in-bytes>], or 857 // .param .b<size-in-bits> retval0 858 unsigned resultsz = TD->getTypeAllocSizeInBits(retTy); 859 if 
(retTy->isPrimitiveType() || retTy->isIntegerTy() || 860 retTy->isPointerTy()) { 861 // Scalar needs to be at least 32bit wide 862 if (resultsz < 32) 863 resultsz = 32; 864 SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue); 865 SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32), 866 DAG.getConstant(resultsz, MVT::i32), 867 DAG.getConstant(0, MVT::i32), InFlag }; 868 Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs, 869 DeclareRetOps, 5); 870 InFlag = Chain.getValue(1); 871 } else { 872 retAlignment = getArgumentAlignment(Callee, CS, retTy, 0); 873 SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue); 874 SDValue DeclareRetOps[] = { Chain, 875 DAG.getConstant(retAlignment, MVT::i32), 876 DAG.getConstant(resultsz / 8, MVT::i32), 877 DAG.getConstant(0, MVT::i32), InFlag }; 878 Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs, 879 DeclareRetOps, 5); 880 InFlag = Chain.getValue(1); 881 } 882 } 883 884 if (!Func) { 885 // This is indirect function call case : PTX requires a prototype of the 886 // form 887 // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _); 888 // to be emitted, and the label has to used as the last arg of call 889 // instruction. 890 // The prototype is embedded in a string and put as the operand for a 891 // CallPrototype SDNode which will print out to the value of the string. 
892 SDVTList ProtoVTs = DAG.getVTList(MVT::Other, MVT::Glue); 893 std::string Proto = getPrototype(retTy, Args, Outs, retAlignment, CS); 894 const char *ProtoStr = 895 nvTM->getManagedStrPool()->getManagedString(Proto.c_str())->c_str(); 896 SDValue ProtoOps[] = { 897 Chain, DAG.getTargetExternalSymbol(ProtoStr, MVT::i32), InFlag, 898 }; 899 Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, ProtoVTs, &ProtoOps[0], 3); 900 InFlag = Chain.getValue(1); 901 } 902 // Op to just print "call" 903 SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue); 904 SDValue PrintCallOps[] = { 905 Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, MVT::i32), InFlag 906 }; 907 Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall), 908 dl, PrintCallVTs, PrintCallOps, 3); 909 InFlag = Chain.getValue(1); 910 911 // Ops to print out the function name 912 SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue); 913 SDValue CallVoidOps[] = { Chain, Callee, InFlag }; 914 Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps, 3); 915 InFlag = Chain.getValue(1); 916 917 // Ops to print out the param list 918 SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue); 919 SDValue CallArgBeginOps[] = { Chain, InFlag }; 920 Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs, 921 CallArgBeginOps, 2); 922 InFlag = Chain.getValue(1); 923 924 for (unsigned i = 0, e = paramCount; i != e; ++i) { 925 unsigned opcode; 926 if (i == (e - 1)) 927 opcode = NVPTXISD::LastCallArg; 928 else 929 opcode = NVPTXISD::CallArg; 930 SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue); 931 SDValue CallArgOps[] = { Chain, DAG.getConstant(1, MVT::i32), 932 DAG.getConstant(i, MVT::i32), InFlag }; 933 Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps, 4); 934 InFlag = Chain.getValue(1); 935 } 936 SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue); 937 SDValue CallArgEndOps[] = { Chain, DAG.getConstant(Func ? 
1 : 0, MVT::i32), 938 InFlag }; 939 Chain = 940 DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps, 3); 941 InFlag = Chain.getValue(1); 942 943 if (!Func) { 944 SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue); 945 SDValue PrototypeOps[] = { Chain, DAG.getConstant(uniqueCallSite, MVT::i32), 946 InFlag }; 947 Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps, 3); 948 InFlag = Chain.getValue(1); 949 } 950 951 // Generate loads from param memory/moves from registers for result 952 if (Ins.size() > 0) { 953 unsigned resoffset = 0; 954 if (retTy && retTy->isVectorTy()) { 955 EVT ObjectVT = getValueType(retTy); 956 unsigned NumElts = ObjectVT.getVectorNumElements(); 957 EVT EltVT = ObjectVT.getVectorElementType(); 958 assert(nvTM->getTargetLowering()->getNumRegisters(F->getContext(), 959 ObjectVT) == NumElts && 960 "Vector was not scalarized"); 961 unsigned sz = EltVT.getSizeInBits(); 962 bool needTruncate = sz < 16 ? true : false; 963 964 if (NumElts == 1) { 965 // Just a simple load 966 std::vector<EVT> LoadRetVTs; 967 if (needTruncate) { 968 // If loading i1 result, generate 969 // load i16 970 // trunc i16 to i1 971 LoadRetVTs.push_back(MVT::i16); 972 } else 973 LoadRetVTs.push_back(EltVT); 974 LoadRetVTs.push_back(MVT::Other); 975 LoadRetVTs.push_back(MVT::Glue); 976 std::vector<SDValue> LoadRetOps; 977 LoadRetOps.push_back(Chain); 978 LoadRetOps.push_back(DAG.getConstant(1, MVT::i32)); 979 LoadRetOps.push_back(DAG.getConstant(0, MVT::i32)); 980 LoadRetOps.push_back(InFlag); 981 SDValue retval = DAG.getMemIntrinsicNode( 982 NVPTXISD::LoadParam, dl, 983 DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0], 984 LoadRetOps.size(), EltVT, MachinePointerInfo()); 985 Chain = retval.getValue(1); 986 InFlag = retval.getValue(2); 987 SDValue Ret0 = retval; 988 if (needTruncate) 989 Ret0 = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Ret0); 990 InVals.push_back(Ret0); 991 } else if (NumElts == 2) { 992 // LoadV2 993 
std::vector<EVT> LoadRetVTs; 994 if (needTruncate) { 995 // If loading i1 result, generate 996 // load i16 997 // trunc i16 to i1 998 LoadRetVTs.push_back(MVT::i16); 999 LoadRetVTs.push_back(MVT::i16); 1000 } else { 1001 LoadRetVTs.push_back(EltVT); 1002 LoadRetVTs.push_back(EltVT); 1003 } 1004 LoadRetVTs.push_back(MVT::Other); 1005 LoadRetVTs.push_back(MVT::Glue); 1006 std::vector<SDValue> LoadRetOps; 1007 LoadRetOps.push_back(Chain); 1008 LoadRetOps.push_back(DAG.getConstant(1, MVT::i32)); 1009 LoadRetOps.push_back(DAG.getConstant(0, MVT::i32)); 1010 LoadRetOps.push_back(InFlag); 1011 SDValue retval = DAG.getMemIntrinsicNode( 1012 NVPTXISD::LoadParamV2, dl, 1013 DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0], 1014 LoadRetOps.size(), EltVT, MachinePointerInfo()); 1015 Chain = retval.getValue(2); 1016 InFlag = retval.getValue(3); 1017 SDValue Ret0 = retval.getValue(0); 1018 SDValue Ret1 = retval.getValue(1); 1019 if (needTruncate) { 1020 Ret0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret0); 1021 InVals.push_back(Ret0); 1022 Ret1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret1); 1023 InVals.push_back(Ret1); 1024 } else { 1025 InVals.push_back(Ret0); 1026 InVals.push_back(Ret1); 1027 } 1028 } else { 1029 // Split into N LoadV4 1030 unsigned Ofst = 0; 1031 unsigned VecSize = 4; 1032 unsigned Opc = NVPTXISD::LoadParamV4; 1033 if (EltVT.getSizeInBits() == 64) { 1034 VecSize = 2; 1035 Opc = NVPTXISD::LoadParamV2; 1036 } 1037 EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize); 1038 for (unsigned i = 0; i < NumElts; i += VecSize) { 1039 SmallVector<EVT, 8> LoadRetVTs; 1040 if (needTruncate) { 1041 // If loading i1 result, generate 1042 // load i16 1043 // trunc i16 to i1 1044 for (unsigned j = 0; j < VecSize; ++j) 1045 LoadRetVTs.push_back(MVT::i16); 1046 } else { 1047 for (unsigned j = 0; j < VecSize; ++j) 1048 LoadRetVTs.push_back(EltVT); 1049 } 1050 LoadRetVTs.push_back(MVT::Other); 1051 LoadRetVTs.push_back(MVT::Glue); 1052 
SmallVector<SDValue, 4> LoadRetOps; 1053 LoadRetOps.push_back(Chain); 1054 LoadRetOps.push_back(DAG.getConstant(1, MVT::i32)); 1055 LoadRetOps.push_back(DAG.getConstant(Ofst, MVT::i32)); 1056 LoadRetOps.push_back(InFlag); 1057 SDValue retval = DAG.getMemIntrinsicNode( 1058 Opc, dl, DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), 1059 &LoadRetOps[0], LoadRetOps.size(), EltVT, MachinePointerInfo()); 1060 if (VecSize == 2) { 1061 Chain = retval.getValue(2); 1062 InFlag = retval.getValue(3); 1063 } else { 1064 Chain = retval.getValue(4); 1065 InFlag = retval.getValue(5); 1066 } 1067 1068 for (unsigned j = 0; j < VecSize; ++j) { 1069 if (i + j >= NumElts) 1070 break; 1071 SDValue Elt = retval.getValue(j); 1072 if (needTruncate) 1073 Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt); 1074 InVals.push_back(Elt); 1075 } 1076 Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext())); 1077 } 1078 } 1079 } else { 1080 SmallVector<EVT, 16> VTs; 1081 ComputePTXValueVTs(*this, retTy, VTs); 1082 assert(VTs.size() == Ins.size() && "Bad value decomposition"); 1083 for (unsigned i = 0, e = Ins.size(); i != e; ++i) { 1084 unsigned sz = VTs[i].getSizeInBits(); 1085 bool needTruncate = sz < 8 ? true : false; 1086 if (VTs[i].isInteger() && (sz < 8)) 1087 sz = 8; 1088 1089 SmallVector<EVT, 4> LoadRetVTs; 1090 EVT TheLoadType = VTs[i]; 1091 if (retTy->isIntegerTy() && 1092 TD->getTypeAllocSizeInBits(retTy) < 32) { 1093 // This is for integer types only, and specifically not for 1094 // aggregates. 
1095 LoadRetVTs.push_back(MVT::i32); 1096 TheLoadType = MVT::i32; 1097 } else if (sz < 16) { 1098 // If loading i1/i8 result, generate 1099 // load i8 (-> i16) 1100 // trunc i16 to i1/i8 1101 LoadRetVTs.push_back(MVT::i16); 1102 } else 1103 LoadRetVTs.push_back(Ins[i].VT); 1104 LoadRetVTs.push_back(MVT::Other); 1105 LoadRetVTs.push_back(MVT::Glue); 1106 1107 SmallVector<SDValue, 4> LoadRetOps; 1108 LoadRetOps.push_back(Chain); 1109 LoadRetOps.push_back(DAG.getConstant(1, MVT::i32)); 1110 LoadRetOps.push_back(DAG.getConstant(resoffset, MVT::i32)); 1111 LoadRetOps.push_back(InFlag); 1112 SDValue retval = DAG.getMemIntrinsicNode( 1113 NVPTXISD::LoadParam, dl, 1114 DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0], 1115 LoadRetOps.size(), TheLoadType, MachinePointerInfo()); 1116 Chain = retval.getValue(1); 1117 InFlag = retval.getValue(2); 1118 SDValue Ret0 = retval.getValue(0); 1119 if (needTruncate) 1120 Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0); 1121 InVals.push_back(Ret0); 1122 resoffset += sz / 8; 1123 } 1124 } 1125 } 1126 1127 Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(uniqueCallSite, true), 1128 DAG.getIntPtrConstant(uniqueCallSite + 1, true), 1129 InFlag, dl); 1130 uniqueCallSite++; 1131 1132 // set isTailCall to false for now, until we figure out how to express 1133 // tail call optimization in PTX 1134 isTailCall = false; 1135 return Chain; 1136} 1137 1138// By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack() 1139// (see LegalizeDAG.cpp). This is slow and uses local memory. 
1140// We use extract/insert/build vector just as what LegalizeOp() does in llvm 2.5 1141SDValue 1142NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const { 1143 SDNode *Node = Op.getNode(); 1144 SDLoc dl(Node); 1145 SmallVector<SDValue, 8> Ops; 1146 unsigned NumOperands = Node->getNumOperands(); 1147 for (unsigned i = 0; i < NumOperands; ++i) { 1148 SDValue SubOp = Node->getOperand(i); 1149 EVT VVT = SubOp.getNode()->getValueType(0); 1150 EVT EltVT = VVT.getVectorElementType(); 1151 unsigned NumSubElem = VVT.getVectorNumElements(); 1152 for (unsigned j = 0; j < NumSubElem; ++j) { 1153 Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp, 1154 DAG.getIntPtrConstant(j))); 1155 } 1156 } 1157 return DAG.getNode(ISD::BUILD_VECTOR, dl, Node->getValueType(0), &Ops[0], 1158 Ops.size()); 1159} 1160 1161SDValue 1162NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { 1163 switch (Op.getOpcode()) { 1164 case ISD::RETURNADDR: 1165 return SDValue(); 1166 case ISD::FRAMEADDR: 1167 return SDValue(); 1168 case ISD::GlobalAddress: 1169 return LowerGlobalAddress(Op, DAG); 1170 case ISD::INTRINSIC_W_CHAIN: 1171 return Op; 1172 case ISD::BUILD_VECTOR: 1173 case ISD::EXTRACT_SUBVECTOR: 1174 return Op; 1175 case ISD::CONCAT_VECTORS: 1176 return LowerCONCAT_VECTORS(Op, DAG); 1177 case ISD::STORE: 1178 return LowerSTORE(Op, DAG); 1179 case ISD::LOAD: 1180 return LowerLOAD(Op, DAG); 1181 default: 1182 llvm_unreachable("Custom lowering not defined for operation"); 1183 } 1184} 1185 1186SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const { 1187 if (Op.getValueType() == MVT::i1) 1188 return LowerLOADi1(Op, DAG); 1189 else 1190 return SDValue(); 1191} 1192 1193// v = ld i1* addr 1194// => 1195// v1 = ld i8* addr (-> i16) 1196// v = trunc i16 to i1 1197SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const { 1198 SDNode *Node = Op.getNode(); 1199 LoadSDNode *LD = cast<LoadSDNode>(Node); 
1200 SDLoc dl(Node); 1201 assert(LD->getExtensionType() == ISD::NON_EXTLOAD); 1202 assert(Node->getValueType(0) == MVT::i1 && 1203 "Custom lowering for i1 load only"); 1204 SDValue newLD = 1205 DAG.getLoad(MVT::i16, dl, LD->getChain(), LD->getBasePtr(), 1206 LD->getPointerInfo(), LD->isVolatile(), LD->isNonTemporal(), 1207 LD->isInvariant(), LD->getAlignment()); 1208 SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD); 1209 // The legalizer (the caller) is expecting two values from the legalized 1210 // load, so we build a MergeValues node for it. See ExpandUnalignedLoad() 1211 // in LegalizeDAG.cpp which also uses MergeValues. 1212 SDValue Ops[] = { result, LD->getChain() }; 1213 return DAG.getMergeValues(Ops, 2, dl); 1214} 1215 1216SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const { 1217 EVT ValVT = Op.getOperand(1).getValueType(); 1218 if (ValVT == MVT::i1) 1219 return LowerSTOREi1(Op, DAG); 1220 else if (ValVT.isVector()) 1221 return LowerSTOREVector(Op, DAG); 1222 else 1223 return SDValue(); 1224} 1225 1226SDValue 1227NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const { 1228 SDNode *N = Op.getNode(); 1229 SDValue Val = N->getOperand(1); 1230 SDLoc DL(N); 1231 EVT ValVT = Val.getValueType(); 1232 1233 if (ValVT.isVector()) { 1234 // We only handle "native" vector sizes for now, e.g. <4 x double> is not 1235 // legal. We can (and should) split that into 2 stores of <2 x double> here 1236 // but I'm leaving that as a TODO for now. 
1237 if (!ValVT.isSimple()) 1238 return SDValue(); 1239 switch (ValVT.getSimpleVT().SimpleTy) { 1240 default: 1241 return SDValue(); 1242 case MVT::v2i8: 1243 case MVT::v2i16: 1244 case MVT::v2i32: 1245 case MVT::v2i64: 1246 case MVT::v2f32: 1247 case MVT::v2f64: 1248 case MVT::v4i8: 1249 case MVT::v4i16: 1250 case MVT::v4i32: 1251 case MVT::v4f32: 1252 // This is a "native" vector type 1253 break; 1254 } 1255 1256 unsigned Opcode = 0; 1257 EVT EltVT = ValVT.getVectorElementType(); 1258 unsigned NumElts = ValVT.getVectorNumElements(); 1259 1260 // Since StoreV2 is a target node, we cannot rely on DAG type legalization. 1261 // Therefore, we must ensure the type is legal. For i1 and i8, we set the 1262 // stored type to i16 and propogate the "real" type as the memory type. 1263 bool NeedExt = false; 1264 if (EltVT.getSizeInBits() < 16) 1265 NeedExt = true; 1266 1267 switch (NumElts) { 1268 default: 1269 return SDValue(); 1270 case 2: 1271 Opcode = NVPTXISD::StoreV2; 1272 break; 1273 case 4: { 1274 Opcode = NVPTXISD::StoreV4; 1275 break; 1276 } 1277 } 1278 1279 SmallVector<SDValue, 8> Ops; 1280 1281 // First is the chain 1282 Ops.push_back(N->getOperand(0)); 1283 1284 // Then the split values 1285 for (unsigned i = 0; i < NumElts; ++i) { 1286 SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val, 1287 DAG.getIntPtrConstant(i)); 1288 if (NeedExt) 1289 ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal); 1290 Ops.push_back(ExtVal); 1291 } 1292 1293 // Then any remaining arguments 1294 for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i) { 1295 Ops.push_back(N->getOperand(i)); 1296 } 1297 1298 MemSDNode *MemSD = cast<MemSDNode>(N); 1299 1300 SDValue NewSt = DAG.getMemIntrinsicNode( 1301 Opcode, DL, DAG.getVTList(MVT::Other), &Ops[0], Ops.size(), 1302 MemSD->getMemoryVT(), MemSD->getMemOperand()); 1303 1304 //return DCI.CombineTo(N, NewSt, true); 1305 return NewSt; 1306 } 1307 1308 return SDValue(); 1309} 1310 1311// st i1 v, addr 1312// => 
1313// v1 = zxt v to i16 1314// st.u8 i16, addr 1315SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const { 1316 SDNode *Node = Op.getNode(); 1317 SDLoc dl(Node); 1318 StoreSDNode *ST = cast<StoreSDNode>(Node); 1319 SDValue Tmp1 = ST->getChain(); 1320 SDValue Tmp2 = ST->getBasePtr(); 1321 SDValue Tmp3 = ST->getValue(); 1322 assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only"); 1323 unsigned Alignment = ST->getAlignment(); 1324 bool isVolatile = ST->isVolatile(); 1325 bool isNonTemporal = ST->isNonTemporal(); 1326 Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3); 1327 SDValue Result = DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2, 1328 ST->getPointerInfo(), MVT::i8, isNonTemporal, 1329 isVolatile, Alignment); 1330 return Result; 1331} 1332 1333SDValue NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname, 1334 int idx, EVT v) const { 1335 std::string *name = nvTM->getManagedStrPool()->getManagedString(inname); 1336 std::stringstream suffix; 1337 suffix << idx; 1338 *name += suffix.str(); 1339 return DAG.getTargetExternalSymbol(name->c_str(), v); 1340} 1341 1342SDValue 1343NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const { 1344 std::string ParamSym; 1345 raw_string_ostream ParamStr(ParamSym); 1346 1347 ParamStr << DAG.getMachineFunction().getName() << "_param_" << idx; 1348 ParamStr.flush(); 1349 1350 std::string *SavedStr = 1351 nvTM->getManagedStrPool()->getManagedString(ParamSym.c_str()); 1352 return DAG.getTargetExternalSymbol(SavedStr->c_str(), v); 1353} 1354 1355SDValue NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) { 1356 return getExtSymb(DAG, ".HLPPARAM", idx); 1357} 1358 1359// Check to see if the kernel argument is image*_t or sampler_t 1360 1361bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) { 1362 static const char *const specialTypes[] = { "struct._image2d_t", 1363 "struct._image3d_t", 1364 "struct._sampler_t" }; 
1365 1366 const Type *Ty = arg->getType(); 1367 const PointerType *PTy = dyn_cast<PointerType>(Ty); 1368 1369 if (!PTy) 1370 return false; 1371 1372 if (!context) 1373 return false; 1374 1375 const StructType *STy = dyn_cast<StructType>(PTy->getElementType()); 1376 const std::string TypeName = STy && !STy->isLiteral() ? STy->getName() : ""; 1377 1378 for (int i = 0, e = array_lengthof(specialTypes); i != e; ++i) 1379 if (TypeName == specialTypes[i]) 1380 return true; 1381 1382 return false; 1383} 1384 1385SDValue NVPTXTargetLowering::LowerFormalArguments( 1386 SDValue Chain, CallingConv::ID CallConv, bool isVarArg, 1387 const SmallVectorImpl<ISD::InputArg> &Ins, SDLoc dl, SelectionDAG &DAG, 1388 SmallVectorImpl<SDValue> &InVals) const { 1389 MachineFunction &MF = DAG.getMachineFunction(); 1390 const DataLayout *TD = getDataLayout(); 1391 1392 const Function *F = MF.getFunction(); 1393 const AttributeSet &PAL = F->getAttributes(); 1394 const TargetLowering *TLI = nvTM->getTargetLowering(); 1395 1396 SDValue Root = DAG.getRoot(); 1397 std::vector<SDValue> OutChains; 1398 1399 bool isKernel = llvm::isKernelFunction(*F); 1400 bool isABI = (nvptxSubtarget.getSmVersion() >= 20); 1401 assert(isABI && "Non-ABI compilation is not supported"); 1402 if (!isABI) 1403 return Chain; 1404 1405 std::vector<Type *> argTypes; 1406 std::vector<const Argument *> theArgs; 1407 for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end(); 1408 I != E; ++I) { 1409 theArgs.push_back(I); 1410 argTypes.push_back(I->getType()); 1411 } 1412 // argTypes.size() (or theArgs.size()) and Ins.size() need not match. 1413 // Ins.size() will be larger 1414 // * if there is an aggregate argument with multiple fields (each field 1415 // showing up separately in Ins) 1416 // * if there is a vector argument with more than typical vector-length 1417 // elements (generally if more than 4) where each vector element is 1418 // individually present in Ins. 
1419 // So a different index should be used for indexing into Ins. 1420 // See similar issue in LowerCall. 1421 unsigned InsIdx = 0; 1422 1423 int idx = 0; 1424 for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++idx, ++InsIdx) { 1425 Type *Ty = argTypes[i]; 1426 1427 // If the kernel argument is image*_t or sampler_t, convert it to 1428 // a i32 constant holding the parameter position. This can later 1429 // matched in the AsmPrinter to output the correct mangled name. 1430 if (isImageOrSamplerVal( 1431 theArgs[i], 1432 (theArgs[i]->getParent() ? theArgs[i]->getParent()->getParent() 1433 : 0))) { 1434 assert(isKernel && "Only kernels can have image/sampler params"); 1435 InVals.push_back(DAG.getConstant(i + 1, MVT::i32)); 1436 continue; 1437 } 1438 1439 if (theArgs[i]->use_empty()) { 1440 // argument is dead 1441 if (Ty->isAggregateType()) { 1442 SmallVector<EVT, 16> vtparts; 1443 1444 ComputePTXValueVTs(*this, Ty, vtparts); 1445 assert(vtparts.size() > 0 && "empty aggregate type not expected"); 1446 for (unsigned parti = 0, parte = vtparts.size(); parti != parte; 1447 ++parti) { 1448 EVT partVT = vtparts[parti]; 1449 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, partVT)); 1450 ++InsIdx; 1451 } 1452 if (vtparts.size() > 0) 1453 --InsIdx; 1454 continue; 1455 } 1456 if (Ty->isVectorTy()) { 1457 EVT ObjectVT = getValueType(Ty); 1458 unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT); 1459 for (unsigned parti = 0; parti < NumRegs; ++parti) { 1460 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT)); 1461 ++InsIdx; 1462 } 1463 if (NumRegs > 0) 1464 --InsIdx; 1465 continue; 1466 } 1467 InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT)); 1468 continue; 1469 } 1470 1471 // In the following cases, assign a node order of "idx+1" 1472 // to newly created nodes. The SDNodes for params have to 1473 // appear in the same order as their order of appearance 1474 // in the original function. "idx+1" holds that order. 
1475 if (PAL.hasAttribute(i + 1, Attribute::ByVal) == false) { 1476 if (Ty->isAggregateType()) { 1477 SmallVector<EVT, 16> vtparts; 1478 SmallVector<uint64_t, 16> offsets; 1479 1480 // NOTE: Here, we lose the ability to issue vector loads for vectors 1481 // that are a part of a struct. This should be investigated in the 1482 // future. 1483 ComputePTXValueVTs(*this, Ty, vtparts, &offsets, 0); 1484 assert(vtparts.size() > 0 && "empty aggregate type not expected"); 1485 bool aggregateIsPacked = false; 1486 if (StructType *STy = llvm::dyn_cast<StructType>(Ty)) 1487 aggregateIsPacked = STy->isPacked(); 1488 1489 SDValue Arg = getParamSymbol(DAG, idx, getPointerTy()); 1490 for (unsigned parti = 0, parte = vtparts.size(); parti != parte; 1491 ++parti) { 1492 EVT partVT = vtparts[parti]; 1493 Value *srcValue = Constant::getNullValue( 1494 PointerType::get(partVT.getTypeForEVT(F->getContext()), 1495 llvm::ADDRESS_SPACE_PARAM)); 1496 SDValue srcAddr = 1497 DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg, 1498 DAG.getConstant(offsets[parti], getPointerTy())); 1499 unsigned partAlign = 1500 aggregateIsPacked ? 1 1501 : TD->getABITypeAlignment( 1502 partVT.getTypeForEVT(F->getContext())); 1503 SDValue p; 1504 if (Ins[InsIdx].VT.getSizeInBits() > partVT.getSizeInBits()) { 1505 ISD::LoadExtType ExtOp = Ins[InsIdx].Flags.isSExt() ? 
1506 ISD::SEXTLOAD : ISD::ZEXTLOAD; 1507 p = DAG.getExtLoad(ExtOp, dl, Ins[InsIdx].VT, Root, srcAddr, 1508 MachinePointerInfo(srcValue), partVT, false, 1509 false, partAlign); 1510 } else { 1511 p = DAG.getLoad(partVT, dl, Root, srcAddr, 1512 MachinePointerInfo(srcValue), false, false, false, 1513 partAlign); 1514 } 1515 if (p.getNode()) 1516 p.getNode()->setIROrder(idx + 1); 1517 InVals.push_back(p); 1518 ++InsIdx; 1519 } 1520 if (vtparts.size() > 0) 1521 --InsIdx; 1522 continue; 1523 } 1524 if (Ty->isVectorTy()) { 1525 EVT ObjectVT = getValueType(Ty); 1526 SDValue Arg = getParamSymbol(DAG, idx, getPointerTy()); 1527 unsigned NumElts = ObjectVT.getVectorNumElements(); 1528 assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts && 1529 "Vector was not scalarized"); 1530 unsigned Ofst = 0; 1531 EVT EltVT = ObjectVT.getVectorElementType(); 1532 1533 // V1 load 1534 // f32 = load ... 1535 if (NumElts == 1) { 1536 // We only have one element, so just directly load it 1537 Value *SrcValue = Constant::getNullValue(PointerType::get( 1538 EltVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM)); 1539 SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg, 1540 DAG.getConstant(Ofst, getPointerTy())); 1541 SDValue P = DAG.getLoad( 1542 EltVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false, 1543 false, true, 1544 TD->getABITypeAlignment(EltVT.getTypeForEVT(F->getContext()))); 1545 if (P.getNode()) 1546 P.getNode()->setIROrder(idx + 1); 1547 1548 if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits()) 1549 P = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, P); 1550 InVals.push_back(P); 1551 Ofst += TD->getTypeAllocSize(EltVT.getTypeForEVT(F->getContext())); 1552 ++InsIdx; 1553 } else if (NumElts == 2) { 1554 // V2 load 1555 // f32,f32 = load ... 
1556 EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, 2); 1557 Value *SrcValue = Constant::getNullValue(PointerType::get( 1558 VecVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM)); 1559 SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg, 1560 DAG.getConstant(Ofst, getPointerTy())); 1561 SDValue P = DAG.getLoad( 1562 VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false, 1563 false, true, 1564 TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext()))); 1565 if (P.getNode()) 1566 P.getNode()->setIROrder(idx + 1); 1567 1568 SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P, 1569 DAG.getIntPtrConstant(0)); 1570 SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P, 1571 DAG.getIntPtrConstant(1)); 1572 1573 if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits()) { 1574 Elt0 = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt0); 1575 Elt1 = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt1); 1576 } 1577 1578 InVals.push_back(Elt0); 1579 InVals.push_back(Elt1); 1580 Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext())); 1581 InsIdx += 2; 1582 } else { 1583 // V4 loads 1584 // We have at least 4 elements (<3 x Ty> expands to 4 elements) and 1585 // the 1586 // vector will be expanded to a power of 2 elements, so we know we can 1587 // always round up to the next multiple of 4 when creating the vector 1588 // loads. 1589 // e.g. 
4 elem => 1 ld.v4 1590 // 6 elem => 2 ld.v4 1591 // 8 elem => 2 ld.v4 1592 // 11 elem => 3 ld.v4 1593 unsigned VecSize = 4; 1594 if (EltVT.getSizeInBits() == 64) { 1595 VecSize = 2; 1596 } 1597 EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize); 1598 for (unsigned i = 0; i < NumElts; i += VecSize) { 1599 Value *SrcValue = Constant::getNullValue( 1600 PointerType::get(VecVT.getTypeForEVT(F->getContext()), 1601 llvm::ADDRESS_SPACE_PARAM)); 1602 SDValue SrcAddr = 1603 DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg, 1604 DAG.getConstant(Ofst, getPointerTy())); 1605 SDValue P = DAG.getLoad( 1606 VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false, 1607 false, true, 1608 TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext()))); 1609 if (P.getNode()) 1610 P.getNode()->setIROrder(idx + 1); 1611 1612 for (unsigned j = 0; j < VecSize; ++j) { 1613 if (i + j >= NumElts) 1614 break; 1615 SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P, 1616 DAG.getIntPtrConstant(j)); 1617 if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits()) 1618 Elt = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt); 1619 InVals.push_back(Elt); 1620 } 1621 Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext())); 1622 } 1623 InsIdx += NumElts; 1624 } 1625 1626 if (NumElts > 0) 1627 --InsIdx; 1628 continue; 1629 } 1630 // A plain scalar. 1631 EVT ObjectVT = getValueType(Ty); 1632 // If ABI, load from the param symbol 1633 SDValue Arg = getParamSymbol(DAG, idx, getPointerTy()); 1634 Value *srcValue = Constant::getNullValue(PointerType::get( 1635 ObjectVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM)); 1636 SDValue p; 1637 if (ObjectVT.getSizeInBits() < Ins[InsIdx].VT.getSizeInBits()) { 1638 ISD::LoadExtType ExtOp = Ins[InsIdx].Flags.isSExt() ? 
1639 ISD::SEXTLOAD : ISD::ZEXTLOAD; 1640 p = DAG.getExtLoad(ExtOp, dl, Ins[InsIdx].VT, Root, Arg, 1641 MachinePointerInfo(srcValue), ObjectVT, false, false, 1642 TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext()))); 1643 } else { 1644 p = DAG.getLoad(Ins[InsIdx].VT, dl, Root, Arg, 1645 MachinePointerInfo(srcValue), false, false, false, 1646 TD->getABITypeAlignment(ObjectVT.getTypeForEVT(F->getContext()))); 1647 } 1648 if (p.getNode()) 1649 p.getNode()->setIROrder(idx + 1); 1650 InVals.push_back(p); 1651 continue; 1652 } 1653 1654 // Param has ByVal attribute 1655 // Return MoveParam(param symbol). 1656 // Ideally, the param symbol can be returned directly, 1657 // but when SDNode builder decides to use it in a CopyToReg(), 1658 // machine instruction fails because TargetExternalSymbol 1659 // (not lowered) is target dependent, and CopyToReg assumes 1660 // the source is lowered. 1661 EVT ObjectVT = getValueType(Ty); 1662 assert(ObjectVT == Ins[InsIdx].VT && 1663 "Ins type did not match function type"); 1664 SDValue Arg = getParamSymbol(DAG, idx, getPointerTy()); 1665 SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg); 1666 if (p.getNode()) 1667 p.getNode()->setIROrder(idx + 1); 1668 if (isKernel) 1669 InVals.push_back(p); 1670 else { 1671 SDValue p2 = DAG.getNode( 1672 ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT, 1673 DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32), p); 1674 InVals.push_back(p2); 1675 } 1676 } 1677 1678 // Clang will check explicit VarArg and issue error if any. However, Clang 1679 // will let code with 1680 // implicit var arg like f() pass. See bug 617733. 1681 // We treat this case as if the arg list is empty. 
1682 // if (F.isVarArg()) { 1683 // assert(0 && "VarArg not supported yet!"); 1684 //} 1685 1686 if (!OutChains.empty()) 1687 DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &OutChains[0], 1688 OutChains.size())); 1689 1690 return Chain; 1691} 1692 1693 1694SDValue 1695NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, 1696 bool isVarArg, 1697 const SmallVectorImpl<ISD::OutputArg> &Outs, 1698 const SmallVectorImpl<SDValue> &OutVals, 1699 SDLoc dl, SelectionDAG &DAG) const { 1700 MachineFunction &MF = DAG.getMachineFunction(); 1701 const Function *F = MF.getFunction(); 1702 Type *RetTy = F->getReturnType(); 1703 const DataLayout *TD = getDataLayout(); 1704 1705 bool isABI = (nvptxSubtarget.getSmVersion() >= 20); 1706 assert(isABI && "Non-ABI compilation is not supported"); 1707 if (!isABI) 1708 return Chain; 1709 1710 if (VectorType *VTy = dyn_cast<VectorType>(RetTy)) { 1711 // If we have a vector type, the OutVals array will be the scalarized 1712 // components and we have combine them into 1 or more vector stores. 
1713 unsigned NumElts = VTy->getNumElements(); 1714 assert(NumElts == Outs.size() && "Bad scalarization of return value"); 1715 1716 // const_cast can be removed in later LLVM versions 1717 EVT EltVT = getValueType(RetTy).getVectorElementType(); 1718 bool NeedExtend = false; 1719 if (EltVT.getSizeInBits() < 16) 1720 NeedExtend = true; 1721 1722 // V1 store 1723 if (NumElts == 1) { 1724 SDValue StoreVal = OutVals[0]; 1725 // We only have one element, so just directly store it 1726 if (NeedExtend) 1727 StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal); 1728 SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal }; 1729 Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl, 1730 DAG.getVTList(MVT::Other), &Ops[0], 3, 1731 EltVT, MachinePointerInfo()); 1732 1733 } else if (NumElts == 2) { 1734 // V2 store 1735 SDValue StoreVal0 = OutVals[0]; 1736 SDValue StoreVal1 = OutVals[1]; 1737 1738 if (NeedExtend) { 1739 StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal0); 1740 StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal1); 1741 } 1742 1743 SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal0, 1744 StoreVal1 }; 1745 Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetvalV2, dl, 1746 DAG.getVTList(MVT::Other), &Ops[0], 4, 1747 EltVT, MachinePointerInfo()); 1748 } else { 1749 // V4 stores 1750 // We have at least 4 elements (<3 x Ty> expands to 4 elements) and the 1751 // vector will be expanded to a power of 2 elements, so we know we can 1752 // always round up to the next multiple of 4 when creating the vector 1753 // stores. 1754 // e.g. 
4 elem => 1 st.v4 1755 // 6 elem => 2 st.v4 1756 // 8 elem => 2 st.v4 1757 // 11 elem => 3 st.v4 1758 1759 unsigned VecSize = 4; 1760 if (OutVals[0].getValueType().getSizeInBits() == 64) 1761 VecSize = 2; 1762 1763 unsigned Offset = 0; 1764 1765 EVT VecVT = 1766 EVT::getVectorVT(F->getContext(), OutVals[0].getValueType(), VecSize); 1767 unsigned PerStoreOffset = 1768 TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext())); 1769 1770 for (unsigned i = 0; i < NumElts; i += VecSize) { 1771 // Get values 1772 SDValue StoreVal; 1773 SmallVector<SDValue, 8> Ops; 1774 Ops.push_back(Chain); 1775 Ops.push_back(DAG.getConstant(Offset, MVT::i32)); 1776 unsigned Opc = NVPTXISD::StoreRetvalV2; 1777 EVT ExtendedVT = (NeedExtend) ? MVT::i16 : OutVals[0].getValueType(); 1778 1779 StoreVal = OutVals[i]; 1780 if (NeedExtend) 1781 StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal); 1782 Ops.push_back(StoreVal); 1783 1784 if (i + 1 < NumElts) { 1785 StoreVal = OutVals[i + 1]; 1786 if (NeedExtend) 1787 StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal); 1788 } else { 1789 StoreVal = DAG.getUNDEF(ExtendedVT); 1790 } 1791 Ops.push_back(StoreVal); 1792 1793 if (VecSize == 4) { 1794 Opc = NVPTXISD::StoreRetvalV4; 1795 if (i + 2 < NumElts) { 1796 StoreVal = OutVals[i + 2]; 1797 if (NeedExtend) 1798 StoreVal = 1799 DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal); 1800 } else { 1801 StoreVal = DAG.getUNDEF(ExtendedVT); 1802 } 1803 Ops.push_back(StoreVal); 1804 1805 if (i + 3 < NumElts) { 1806 StoreVal = OutVals[i + 3]; 1807 if (NeedExtend) 1808 StoreVal = 1809 DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal); 1810 } else { 1811 StoreVal = DAG.getUNDEF(ExtendedVT); 1812 } 1813 Ops.push_back(StoreVal); 1814 } 1815 1816 // Chain = DAG.getNode(Opc, dl, MVT::Other, &Ops[0], Ops.size()); 1817 Chain = 1818 DAG.getMemIntrinsicNode(Opc, dl, DAG.getVTList(MVT::Other), &Ops[0], 1819 Ops.size(), EltVT, MachinePointerInfo()); 1820 Offset += PerStoreOffset; 
1821 } 1822 } 1823 } else { 1824 SmallVector<EVT, 16> ValVTs; 1825 // const_cast is necessary since we are still using an LLVM version from 1826 // before the type system re-write. 1827 ComputePTXValueVTs(*this, RetTy, ValVTs); 1828 assert(ValVTs.size() == OutVals.size() && "Bad return value decomposition"); 1829 1830 unsigned SizeSoFar = 0; 1831 for (unsigned i = 0, e = Outs.size(); i != e; ++i) { 1832 SDValue theVal = OutVals[i]; 1833 EVT TheValType = theVal.getValueType(); 1834 unsigned numElems = 1; 1835 if (TheValType.isVector()) 1836 numElems = TheValType.getVectorNumElements(); 1837 for (unsigned j = 0, je = numElems; j != je; ++j) { 1838 SDValue TmpVal = theVal; 1839 if (TheValType.isVector()) 1840 TmpVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, 1841 TheValType.getVectorElementType(), TmpVal, 1842 DAG.getIntPtrConstant(j)); 1843 EVT TheStoreType = ValVTs[i]; 1844 if (RetTy->isIntegerTy() && 1845 TD->getTypeAllocSizeInBits(RetTy) < 32) { 1846 // The following zero-extension is for integer types only, and 1847 // specifically not for aggregates. 
1848 TmpVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, TmpVal); 1849 TheStoreType = MVT::i32; 1850 } 1851 else if (TmpVal.getValueType().getSizeInBits() < 16) 1852 TmpVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, TmpVal); 1853 1854 SDValue Ops[] = { Chain, DAG.getConstant(SizeSoFar, MVT::i32), TmpVal }; 1855 Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl, 1856 DAG.getVTList(MVT::Other), &Ops[0], 1857 3, TheStoreType, 1858 MachinePointerInfo()); 1859 if(TheValType.isVector()) 1860 SizeSoFar += 1861 TheStoreType.getVectorElementType().getStoreSizeInBits() / 8; 1862 else 1863 SizeSoFar += TheStoreType.getStoreSizeInBits()/8; 1864 } 1865 } 1866 } 1867 1868 return DAG.getNode(NVPTXISD::RET_FLAG, dl, MVT::Other, Chain); 1869} 1870 1871 1872void NVPTXTargetLowering::LowerAsmOperandForConstraint( 1873 SDValue Op, std::string &Constraint, std::vector<SDValue> &Ops, 1874 SelectionDAG &DAG) const { 1875 if (Constraint.length() > 1) 1876 return; 1877 else 1878 TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG); 1879} 1880 1881// NVPTX suuport vector of legal types of any length in Intrinsics because the 1882// NVPTX specific type legalizer 1883// will legalize them to the PTX supported length. 1884bool NVPTXTargetLowering::isTypeSupportedInIntrinsic(MVT VT) const { 1885 if (isTypeLegal(VT)) 1886 return true; 1887 if (VT.isVector()) { 1888 MVT eVT = VT.getVectorElementType(); 1889 if (isTypeLegal(eVT)) 1890 return true; 1891 } 1892 return false; 1893} 1894 1895// llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as 1896// TgtMemIntrinsic 1897// because we need the information that is only available in the "Value" type 1898// of destination 1899// pointer. In particular, the address space information. 
1900bool NVPTXTargetLowering::getTgtMemIntrinsic( 1901 IntrinsicInfo &Info, const CallInst &I, unsigned Intrinsic) const { 1902 switch (Intrinsic) { 1903 default: 1904 return false; 1905 1906 case Intrinsic::nvvm_atomic_load_add_f32: 1907 Info.opc = ISD::INTRINSIC_W_CHAIN; 1908 Info.memVT = MVT::f32; 1909 Info.ptrVal = I.getArgOperand(0); 1910 Info.offset = 0; 1911 Info.vol = 0; 1912 Info.readMem = true; 1913 Info.writeMem = true; 1914 Info.align = 0; 1915 return true; 1916 1917 case Intrinsic::nvvm_atomic_load_inc_32: 1918 case Intrinsic::nvvm_atomic_load_dec_32: 1919 Info.opc = ISD::INTRINSIC_W_CHAIN; 1920 Info.memVT = MVT::i32; 1921 Info.ptrVal = I.getArgOperand(0); 1922 Info.offset = 0; 1923 Info.vol = 0; 1924 Info.readMem = true; 1925 Info.writeMem = true; 1926 Info.align = 0; 1927 return true; 1928 1929 case Intrinsic::nvvm_ldu_global_i: 1930 case Intrinsic::nvvm_ldu_global_f: 1931 case Intrinsic::nvvm_ldu_global_p: 1932 1933 Info.opc = ISD::INTRINSIC_W_CHAIN; 1934 if (Intrinsic == Intrinsic::nvvm_ldu_global_i) 1935 Info.memVT = getValueType(I.getType()); 1936 else if (Intrinsic == Intrinsic::nvvm_ldu_global_p) 1937 Info.memVT = getValueType(I.getType()); 1938 else 1939 Info.memVT = MVT::f32; 1940 Info.ptrVal = I.getArgOperand(0); 1941 Info.offset = 0; 1942 Info.vol = 0; 1943 Info.readMem = true; 1944 Info.writeMem = false; 1945 Info.align = 0; 1946 return true; 1947 1948 } 1949 return false; 1950} 1951 1952/// isLegalAddressingMode - Return true if the addressing mode represented 1953/// by AM is legal for this target, for a load/store of the specified type. 
1954/// Used to guide target specific optimizations, like loop strength reduction 1955/// (LoopStrengthReduce.cpp) and memory optimization for address mode 1956/// (CodeGenPrepare.cpp) 1957bool NVPTXTargetLowering::isLegalAddressingMode(const AddrMode &AM, 1958 Type *Ty) const { 1959 1960 // AddrMode - This represents an addressing mode of: 1961 // BaseGV + BaseOffs + BaseReg + Scale*ScaleReg 1962 // 1963 // The legal address modes are 1964 // - [avar] 1965 // - [areg] 1966 // - [areg+immoff] 1967 // - [immAddr] 1968 1969 if (AM.BaseGV) { 1970 if (AM.BaseOffs || AM.HasBaseReg || AM.Scale) 1971 return false; 1972 return true; 1973 } 1974 1975 switch (AM.Scale) { 1976 case 0: // "r", "r+i" or "i" is allowed 1977 break; 1978 case 1: 1979 if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed. 1980 return false; 1981 // Otherwise we have r+i. 1982 break; 1983 default: 1984 // No scale > 1 is allowed 1985 return false; 1986 } 1987 return true; 1988} 1989 1990//===----------------------------------------------------------------------===// 1991// NVPTX Inline Assembly Support 1992//===----------------------------------------------------------------------===// 1993 1994/// getConstraintType - Given a constraint letter, return the type of 1995/// constraint it is for this target. 
NVPTXTargetLowering::ConstraintType
NVPTXTargetLowering::getConstraintType(const std::string &Constraint) const {
  // Single-letter constraints listed below are treated as register-class
  // constraints; everything else defers to the generic implementation.
  if (Constraint.size() == 1) {
    switch (Constraint[0]) {
    default:
      break;
    case 'r':
    case 'h':
    case 'c':
    case 'l':
    case 'f':
    case 'd':
    // NOTE(review): '0' is classified as a register class here, but
    // getRegForInlineAsmConstraint has no '0' case, so it falls through to
    // the generic TargetLowering lookup -- confirm this is intended.
    case '0':
    case 'N':
      return C_RegisterClass;
    }
  }
  return TargetLowering::getConstraintType(Constraint);
}

/// getRegForInlineAsmConstraint - Map a single-letter inline-asm register
/// constraint to the NVPTX register class used for it. Unhandled constraints
/// are resolved by the generic TargetLowering implementation.
std::pair<unsigned, const TargetRegisterClass *>
NVPTXTargetLowering::getRegForInlineAsmConstraint(const std::string &Constraint,
                                                  MVT VT) const {
  if (Constraint.size() == 1) {
    switch (Constraint[0]) {
    // 'c' and 'h' both select 16-bit integer registers.
    case 'c':
      return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
    case 'h':
      return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
    // 'r': 32-bit integer registers.
    case 'r':
      return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
    // 'l' and 'N': 64-bit integer registers.
    case 'l':
    case 'N':
      return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
    // 'f': 32-bit float registers; 'd': 64-bit float registers.
    case 'f':
      return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
    case 'd':
      return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
    }
  }
  return TargetLowering::getRegForInlineAsmConstraint(Constraint, VT);
}

/// getFunctionAlignment - Return the Log2 alignment of this function
/// (4, i.e. 16 bytes).
unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const {
  return 4;
}

/// ReplaceLoadVector - Convert vector loads into multi-output scalar loads.
///
/// A vector LOAD of one of the "native" types listed below is replaced by a
/// NVPTXISD::LoadV2 / LoadV4 node that yields one scalar per element plus a
/// chain. On success, Results receives the rebuilt vector (BUILD_VECTOR) and
/// the new load's output chain, in that order. For any other vector type the
/// function returns without adding results.
static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
                              SmallVectorImpl<SDValue> &Results) {
  EVT ResVT = N->getValueType(0);
  SDLoc DL(N);

  assert(ResVT.isVector() && "Vector load must have vector type");

  // We only handle "native" vector sizes for now, e.g. <4 x double> is not
  // legal.  We can (and should) split that into 2 loads of <2 x double> here
  // but I'm leaving that as a TODO for now.
  assert(ResVT.isSimple() && "Can only handle simple types");
  switch (ResVT.getSimpleVT().SimpleTy) {
  default:
    return;
  case MVT::v2i8:
  case MVT::v2i16:
  case MVT::v2i32:
  case MVT::v2i64:
  case MVT::v2f32:
  case MVT::v2f64:
  case MVT::v4i8:
  case MVT::v4i16:
  case MVT::v4i32:
  case MVT::v4f32:
    // This is a "native" vector type
    break;
  }

  EVT EltVT = ResVT.getVectorElementType();
  unsigned NumElts = ResVT.getVectorNumElements();

  // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
  // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
  // loaded type to i16 and propagate the "real" type as the memory type.
  bool NeedTrunc = false;
  if (EltVT.getSizeInBits() < 16) {
    EltVT = MVT::i16;
    NeedTrunc = true;
  }

  unsigned Opcode = 0;
  SDVTList LdResVTs;

  // Result list: one value per element, followed by the output chain.
  switch (NumElts) {
  default:
    return;
  case 2:
    Opcode = NVPTXISD::LoadV2;
    LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
    break;
  case 4: {
    Opcode = NVPTXISD::LoadV4;
    EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
    LdResVTs = DAG.getVTList(ListVTs, 5);
    break;
  }
  }

  SmallVector<SDValue, 8> OtherOps;

  // Copy regular operands
  for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
    OtherOps.push_back(N->getOperand(i));

  LoadSDNode *LD = cast<LoadSDNode>(N);

  // The select routine does not have access to the LoadSDNode instance, so
  // pass along the extension information
  OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType()));

  // Keep the original memory VT and memory operand so the narrow element
  // type (e.g. i8) is still visible to instruction selection.
  SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, &OtherOps[0],
                                          OtherOps.size(), LD->getMemoryVT(),
                                          LD->getMemOperand());

  SmallVector<SDValue, 4> ScalarRes;

  // Collect the per-element results, truncating back to the original
  // element type if they were widened to i16 above.
  for (unsigned i = 0; i < NumElts; ++i) {
    SDValue Res = NewLD.getValue(i);
    if (NeedTrunc)
      Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
    ScalarRes.push_back(Res);
  }

  // The chain is the result immediately after the element values.
  SDValue LoadChain = NewLD.getValue(NumElts);

  SDValue BuildVec =
      DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);

  Results.push_back(BuildVec);
  Results.push_back(LoadChain);
}

/// ReplaceINTRINSIC_W_CHAIN - Custom result replacement for the ldg/ldu
/// intrinsics.
///
/// - Vector results (2 or 4 elements) become LDGV2/LDGV4/LDUV2/LDUV4 target
///   nodes whose scalar results are reassembled with BUILD_VECTOR.
/// - Scalar results must be i8; they are loaded as i16 and truncated.
///
/// On success, Results receives the replacement value and the output chain.
/// Other intrinsics, or vectors with any other element count, are left
/// untouched (no results added).
static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
                                     SmallVectorImpl<SDValue> &Results) {
  SDValue Chain = N->getOperand(0);
  SDValue Intrin = N->getOperand(1);
  SDLoc DL(N);

  // Get the intrinsic ID
  unsigned IntrinNo = cast<ConstantSDNode>(Intrin.getNode())->getZExtValue();
  switch (IntrinNo) {
  default:
    return;
  case Intrinsic::nvvm_ldg_global_i:
  case Intrinsic::nvvm_ldg_global_f:
  case Intrinsic::nvvm_ldg_global_p:
  case Intrinsic::nvvm_ldu_global_i:
  case Intrinsic::nvvm_ldu_global_f:
  case Intrinsic::nvvm_ldu_global_p: {
    EVT ResVT = N->getValueType(0);

    if (ResVT.isVector()) {
      // Vector LDG/LDU

      unsigned NumElts = ResVT.getVectorNumElements();
      EVT EltVT = ResVT.getVectorElementType();

      // Since LDU/LDG are target nodes, we cannot rely on DAG type
      // legalization.
      // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
      // loaded type to i16 and propagate the "real" type as the memory type.
      bool NeedTrunc = false;
      if (EltVT.getSizeInBits() < 16) {
        EltVT = MVT::i16;
        NeedTrunc = true;
      }

      unsigned Opcode = 0;
      SDVTList LdResVTs;

      // Pick the target opcode from the element count (2 or 4) and the
      // intrinsic family (ldg vs. ldu).
      switch (NumElts) {
      default:
        return;
      case 2:
        switch (IntrinNo) {
        default:
          return;
        case Intrinsic::nvvm_ldg_global_i:
        case Intrinsic::nvvm_ldg_global_f:
        case Intrinsic::nvvm_ldg_global_p:
          Opcode = NVPTXISD::LDGV2;
          break;
        case Intrinsic::nvvm_ldu_global_i:
        case Intrinsic::nvvm_ldu_global_f:
        case Intrinsic::nvvm_ldu_global_p:
          Opcode = NVPTXISD::LDUV2;
          break;
        }
        LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
        break;
      case 4: {
        switch (IntrinNo) {
        default:
          return;
        case Intrinsic::nvvm_ldg_global_i:
        case Intrinsic::nvvm_ldg_global_f:
        case Intrinsic::nvvm_ldg_global_p:
          Opcode = NVPTXISD::LDGV4;
          break;
        case Intrinsic::nvvm_ldu_global_i:
        case Intrinsic::nvvm_ldu_global_f:
        case Intrinsic::nvvm_ldu_global_p:
          Opcode = NVPTXISD::LDUV4;
          break;
        }
        EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
        LdResVTs = DAG.getVTList(ListVTs, 5);
        break;
      }
      }

      SmallVector<SDValue, 8> OtherOps;

      // Copy regular operands

      OtherOps.push_back(Chain); // Chain
                                 // Skip operand 1 (intrinsic ID)
      // Others
      for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i)
        OtherOps.push_back(N->getOperand(i));

      // NOTE(review): cast<> requires N to be a MemIntrinsicSDNode, which
      // presumably depends on getTgtMemIntrinsic reporting this intrinsic --
      // confirm it covers the ldg intrinsics as well as ldu.
      MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);

      SDValue NewLD = DAG.getMemIntrinsicNode(
          Opcode, DL, LdResVTs, &OtherOps[0], OtherOps.size(),
          MemSD->getMemoryVT(), MemSD->getMemOperand());

      SmallVector<SDValue, 4> ScalarRes;

      // Gather per-element results, truncating widened i16 values back to
      // the original element type.
      for (unsigned i = 0; i < NumElts; ++i) {
        SDValue Res = NewLD.getValue(i);
        if (NeedTrunc)
          Res =
              DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
        ScalarRes.push_back(Res);
      }

      // The chain follows the element values in the result list.
      SDValue LoadChain = NewLD.getValue(NumElts);

      SDValue BuildVec =
          DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);

      Results.push_back(BuildVec);
      Results.push_back(LoadChain);
    } else {
      // i8 LDG/LDU
      assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
             "Custom handling of non-i8 ldu/ldg?");

      // Just copy all operands as-is
      SmallVector<SDValue, 4> Ops;
      for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
        Ops.push_back(N->getOperand(i));

      // Force output to i16
      SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);

      MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);

      // We make sure the memory type is i8, which will be used during isel
      // to select the proper instruction.
      SDValue NewLD =
          DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, &Ops[0],
                                  Ops.size(), MVT::i8, MemSD->getMemOperand());

      // Narrow the i16 result back to the i8 the caller expects.
      Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
                                    NewLD.getValue(0)));
      Results.push_back(NewLD.getValue(1));
    }
  }
  }
}

/// ReplaceNodeResults - Entry point for custom result-type legalization.
/// LOAD nodes go to ReplaceLoadVector and INTRINSIC_W_CHAIN nodes go to
/// ReplaceINTRINSIC_W_CHAIN; any other opcode is a fatal error.
void NVPTXTargetLowering::ReplaceNodeResults(
    SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
  switch (N->getOpcode()) {
  default:
    report_fatal_error("Unhandled custom legalization");
  case ISD::LOAD:
    ReplaceLoadVector(N, DAG, Results);
    return;
  case ISD::INTRINSIC_W_CHAIN:
    ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
    return;
  }
}