NVPTXISelLowering.cpp revision 36b56886974eae4f9c5ebc96befd3e7bfe5de338
//===-- NVPTXISelLowering.cpp - NVPTX DAG Lowering Implementation ---------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file defines the interfaces that NVPTX uses to lower LLVM code into a
// selection DAG.
//
//===----------------------------------------------------------------------===//

#include "NVPTXISelLowering.h"
#include "NVPTX.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXTargetObjectFile.h"
#include "NVPTXUtilities.h"
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/MC/MCSectionELF.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <sstream>

#undef DEBUG_TYPE
#define DEBUG_TYPE "nvptx-lower"

using namespace llvm;

static unsigned int uniqueCallSite = 0;

static cl::opt<bool> sched4reg(
    "nvptx-sched4reg",
    cl::desc("NVPTX Specific: schedule for register pressure"), cl::init(false));
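// The flag is consumed by llc like any other cl::opt, e.g. (illustrative):
//   llc -march=nvptx64 -nvptx-sched4reg kernel.ll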

static bool IsPTXVectorType(MVT VT) {
  switch (VT.SimpleTy) {
  default:
    return false;
  case MVT::v2i1:
  case MVT::v4i1:
  case MVT::v2i8:
  case MVT::v4i8:
  case MVT::v2i16:
  case MVT::v4i16:
  case MVT::v2i32:
  case MVT::v4i32:
  case MVT::v2i64:
  case MVT::v2f32:
  case MVT::v4f32:
  case MVT::v2f64:
    return true;
  }
}

/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
/// EVTs that compose it.  Unlike ComputeValueVTs, this will break apart vectors
/// into their primitive components.
/// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
/// same number of types as the Ins/Outs arrays in LowerFormalArguments,
/// LowerCall, and LowerReturn.
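/// For example, for Ty = <2 x float> this returns { f32, f32 } with offsets
/// { 0, 4 }, whereas ComputeValueVTs would return the single EVT v2f32.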
static void ComputePTXValueVTs(const TargetLowering &TLI, Type *Ty,
                               SmallVectorImpl<EVT> &ValueVTs,
                               SmallVectorImpl<uint64_t> *Offsets = 0,
                               uint64_t StartingOffset = 0) {
  SmallVector<EVT, 16> TempVTs;
  SmallVector<uint64_t, 16> TempOffsets;

  ComputeValueVTs(TLI, Ty, TempVTs, &TempOffsets, StartingOffset);
  for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) {
    EVT VT = TempVTs[i];
    uint64_t Off = TempOffsets[i];
    if (VT.isVector())
      for (unsigned j = 0, je = VT.getVectorNumElements(); j != je; ++j) {
        ValueVTs.push_back(VT.getVectorElementType());
        if (Offsets)
          Offsets->push_back(Off + j * VT.getVectorElementType().getStoreSize());
      }
    else {
      ValueVTs.push_back(VT);
      if (Offsets)
        Offsets->push_back(Off);
    }
  }
}

// NVPTXTargetLowering Constructor.
NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
    : TargetLowering(TM, new NVPTXTargetObjectFile()), nvTM(&TM),
      nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) {

  // Always lower memset, memcpy, and memmove intrinsics to load/store
  // instructions, rather than generating calls to memset, memcpy, or memmove.
  MaxStoresPerMemset = (unsigned) 0xFFFFFFFF;
  MaxStoresPerMemcpy = (unsigned) 0xFFFFFFFF;
  MaxStoresPerMemmove = (unsigned) 0xFFFFFFFF;

  setBooleanContents(ZeroOrNegativeOneBooleanContent);

  // Jump is Expensive. Don't create extra control flow for 'and', 'or'
  // condition branches.
  setJumpIsExpensive(true);

  // By default, use the Source scheduling
  if (sched4reg)
    setSchedulingPreference(Sched::RegPressure);
  else
    setSchedulingPreference(Sched::Source);

  addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
  addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
  addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
  addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
  addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
  addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);

  // Operations not directly supported by NVPTX.
  setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
  setOperationAction(ISD::BR_CC, MVT::f32, Expand);
  setOperationAction(ISD::BR_CC, MVT::f64, Expand);
  setOperationAction(ISD::BR_CC, MVT::i1, Expand);
  setOperationAction(ISD::BR_CC, MVT::i8, Expand);
  setOperationAction(ISD::BR_CC, MVT::i16, Expand);
  setOperationAction(ISD::BR_CC, MVT::i32, Expand);
  setOperationAction(ISD::BR_CC, MVT::i64, Expand);
  // Some SIGN_EXTEND_INREG can be done using cvt instruction.
  // For others we will expand to a SHL/SRA pair.
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Legal);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Legal);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Legal);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8, Legal);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);

  if (nvptxSubtarget.hasROT64()) {
    setOperationAction(ISD::ROTL, MVT::i64, Legal);
    setOperationAction(ISD::ROTR, MVT::i64, Legal);
  } else {
    setOperationAction(ISD::ROTL, MVT::i64, Expand);
    setOperationAction(ISD::ROTR, MVT::i64, Expand);
  }
  if (nvptxSubtarget.hasROT32()) {
    setOperationAction(ISD::ROTL, MVT::i32, Legal);
    setOperationAction(ISD::ROTR, MVT::i32, Legal);
  } else {
    setOperationAction(ISD::ROTL, MVT::i32, Expand);
    setOperationAction(ISD::ROTR, MVT::i32, Expand);
  }

  setOperationAction(ISD::ROTL, MVT::i16, Expand);
  setOperationAction(ISD::ROTR, MVT::i16, Expand);
  setOperationAction(ISD::ROTL, MVT::i8, Expand);
  setOperationAction(ISD::ROTR, MVT::i8, Expand);
  setOperationAction(ISD::BSWAP, MVT::i16, Expand);
  setOperationAction(ISD::BSWAP, MVT::i32, Expand);
  setOperationAction(ISD::BSWAP, MVT::i64, Expand);

  // Indirect branch is not supported.
  // This also disables Jump Table creation.
  setOperationAction(ISD::BR_JT, MVT::Other, Expand);
  setOperationAction(ISD::BRIND, MVT::Other, Expand);

  setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
  setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);

  // We want to legalize constant-related memmove and memcpy
  // intrinsics.
  setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);

  // Turn FP extload into load/fextend
  setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
  // Turn FP truncstore into trunc + store.
  setTruncStoreAction(MVT::f64, MVT::f32, Expand);

  // PTX does not support load/store of predicate registers
  setOperationAction(ISD::LOAD, MVT::i1, Custom);
  setOperationAction(ISD::STORE, MVT::i1, Custom);

  setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
  setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
  setTruncStoreAction(MVT::i64, MVT::i1, Expand);
  setTruncStoreAction(MVT::i32, MVT::i1, Expand);
  setTruncStoreAction(MVT::i16, MVT::i1, Expand);
  setTruncStoreAction(MVT::i8, MVT::i1, Expand);

  // This is legal in NVPTX
  setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
  setOperationAction(ISD::ConstantFP, MVT::f32, Legal);

  // TRAP can be lowered to PTX trap
  setOperationAction(ISD::TRAP, MVT::Other, Legal);

  setOperationAction(ISD::ADDC, MVT::i64, Expand);
  setOperationAction(ISD::ADDE, MVT::i64, Expand);

  // Register custom handling for vector loads/stores
  for (int i = MVT::FIRST_VECTOR_VALUETYPE; i <= MVT::LAST_VECTOR_VALUETYPE;
       ++i) {
    MVT VT = (MVT::SimpleValueType) i;
    if (IsPTXVectorType(VT)) {
      setOperationAction(ISD::LOAD, VT, Custom);
      setOperationAction(ISD::STORE, VT, Custom);
      setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
    }
  }

  // Custom handling for i8 intrinsics
  setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);

  setOperationAction(ISD::CTLZ, MVT::i16, Legal);
  setOperationAction(ISD::CTLZ, MVT::i32, Legal);
  setOperationAction(ISD::CTLZ, MVT::i64, Legal);
  setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i16, Legal);
  setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i32, Legal);
  setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i64, Legal);
  setOperationAction(ISD::CTTZ, MVT::i16, Expand);
  setOperationAction(ISD::CTTZ, MVT::i32, Expand);
  setOperationAction(ISD::CTTZ, MVT::i64, Expand);
  setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i16, Expand);
  setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i32, Expand);
  setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i64, Expand);
  setOperationAction(ISD::CTPOP, MVT::i16, Legal);
  setOperationAction(ISD::CTPOP, MVT::i32, Legal);
  setOperationAction(ISD::CTPOP, MVT::i64, Legal);

  // Now deduce the information based on the above mentioned actions
  computeRegisterProperties();
}

const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
  switch (Opcode) {
  default:
    return 0;
  case NVPTXISD::CALL:
    return "NVPTXISD::CALL";
  case NVPTXISD::RET_FLAG:
    return "NVPTXISD::RET_FLAG";
  case NVPTXISD::Wrapper:
    return "NVPTXISD::Wrapper";
  case NVPTXISD::DeclareParam:
    return "NVPTXISD::DeclareParam";
  case NVPTXISD::DeclareScalarParam:
    return "NVPTXISD::DeclareScalarParam";
  case NVPTXISD::DeclareRet:
    return "NVPTXISD::DeclareRet";
  case NVPTXISD::DeclareRetParam:
    return "NVPTXISD::DeclareRetParam";
  case NVPTXISD::PrintCall:
    return "NVPTXISD::PrintCall";
  case NVPTXISD::LoadParam:
    return "NVPTXISD::LoadParam";
  case NVPTXISD::LoadParamV2:
    return "NVPTXISD::LoadParamV2";
  case NVPTXISD::LoadParamV4:
    return "NVPTXISD::LoadParamV4";
  case NVPTXISD::StoreParam:
    return "NVPTXISD::StoreParam";
  case NVPTXISD::StoreParamV2:
    return "NVPTXISD::StoreParamV2";
  case NVPTXISD::StoreParamV4:
    return "NVPTXISD::StoreParamV4";
  case NVPTXISD::StoreParamS32:
    return "NVPTXISD::StoreParamS32";
  case NVPTXISD::StoreParamU32:
    return "NVPTXISD::StoreParamU32";
  case NVPTXISD::CallArgBegin:
    return "NVPTXISD::CallArgBegin";
  case NVPTXISD::CallArg:
    return "NVPTXISD::CallArg";
  case NVPTXISD::LastCallArg:
    return "NVPTXISD::LastCallArg";
  case NVPTXISD::CallArgEnd:
    return "NVPTXISD::CallArgEnd";
  case NVPTXISD::CallVoid:
    return "NVPTXISD::CallVoid";
  case NVPTXISD::CallVal:
    return "NVPTXISD::CallVal";
  case NVPTXISD::CallSymbol:
    return "NVPTXISD::CallSymbol";
  case NVPTXISD::Prototype:
    return "NVPTXISD::Prototype";
  case NVPTXISD::MoveParam:
    return "NVPTXISD::MoveParam";
  case NVPTXISD::StoreRetval:
    return "NVPTXISD::StoreRetval";
  case NVPTXISD::StoreRetvalV2:
    return "NVPTXISD::StoreRetvalV2";
  case NVPTXISD::StoreRetvalV4:
    return "NVPTXISD::StoreRetvalV4";
  case NVPTXISD::PseudoUseParam:
    return "NVPTXISD::PseudoUseParam";
  case NVPTXISD::RETURN:
    return "NVPTXISD::RETURN";
  case NVPTXISD::CallSeqBegin:
    return "NVPTXISD::CallSeqBegin";
  case NVPTXISD::CallSeqEnd:
    return "NVPTXISD::CallSeqEnd";
  case NVPTXISD::CallPrototype:
    return "NVPTXISD::CallPrototype";
  case NVPTXISD::LoadV2:
    return "NVPTXISD::LoadV2";
  case NVPTXISD::LoadV4:
    return "NVPTXISD::LoadV4";
  case NVPTXISD::LDGV2:
    return "NVPTXISD::LDGV2";
  case NVPTXISD::LDGV4:
    return "NVPTXISD::LDGV4";
  case NVPTXISD::LDUV2:
    return "NVPTXISD::LDUV2";
  case NVPTXISD::LDUV4:
    return "NVPTXISD::LDUV4";
  case NVPTXISD::StoreV2:
    return "NVPTXISD::StoreV2";
  case NVPTXISD::StoreV4:
    return "NVPTXISD::StoreV4";
  }
}

bool NVPTXTargetLowering::shouldSplitVectorType(EVT VT) const {
  return VT.getScalarType() == MVT::i1;
}

SDValue
NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
  SDLoc dl(Op);
  const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
  Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
  return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op);
}

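// Build the PTX ".callprototype" string used for indirect calls. For a callee
// of type "i32 (float)", for example, the result looks roughly like
// (illustrative):
//   prototype_0 : .callprototype (.param .b32 _) _ (.param .b32 _);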
std::string
NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
                                  unsigned retAlignment,
                                  const ImmutableCallSite *CS) const {

  bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
  assert(isABI && "Non-ABI compilation is not supported");
  if (!isABI)
    return "";

  std::stringstream O;
  O << "prototype_" << uniqueCallSite << " : .callprototype ";

  if (retTy->getTypeID() == Type::VoidTyID) {
    O << "()";
  } else {
    O << "(";
    if (retTy->isFloatingPointTy() || retTy->isIntegerTy()) {
      unsigned size = 0;
      if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
        size = ITy->getBitWidth();
        if (size < 32)
          size = 32;
      } else {
        assert(retTy->isFloatingPointTy() &&
               "Floating point type expected here");
        size = retTy->getPrimitiveSizeInBits();
      }

      O << ".param .b" << size << " _";
    } else if (isa<PointerType>(retTy)) {
      O << ".param .b" << getPointerTy().getSizeInBits() << " _";
    } else {
      if ((retTy->getTypeID() == Type::StructTyID) || isa<VectorType>(retTy)) {
        SmallVector<EVT, 16> vtparts;
        ComputeValueVTs(*this, retTy, vtparts);
        unsigned totalsz = 0;
        for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
          unsigned elems = 1;
          EVT elemtype = vtparts[i];
          if (vtparts[i].isVector()) {
            elems = vtparts[i].getVectorNumElements();
            elemtype = vtparts[i].getVectorElementType();
          }
          // TODO: no need to loop
          for (unsigned j = 0, je = elems; j != je; ++j) {
            unsigned sz = elemtype.getSizeInBits();
            if (elemtype.isInteger() && (sz < 8))
              sz = 8;
            totalsz += sz / 8;
          }
        }
        O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
      } else {
        assert(false && "Unknown return type");
      }
    }
    O << ") ";
  }
  O << "_ (";

  bool first = true;
  MVT thePointerTy = getPointerTy();

  unsigned OIdx = 0;
  for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
    Type *Ty = Args[i].Ty;
    if (!first) {
      O << ", ";
    }
    first = false;

    if (!Outs[OIdx].Flags.isByVal()) {
      if (Ty->isAggregateType() || Ty->isVectorTy()) {
        unsigned align = 0;
        const CallInst *CallI = cast<CallInst>(CS->getInstruction());
        const DataLayout *TD = getDataLayout();
        // +1 because index 0 is reserved for return type alignment
        if (!llvm::getAlign(*CallI, i + 1, align))
          align = TD->getABITypeAlignment(Ty);
        unsigned sz = TD->getTypeAllocSize(Ty);
        O << ".param .align " << align << " .b8 ";
        O << "_";
        O << "[" << sz << "]";
        // update the index for Outs
        SmallVector<EVT, 16> vtparts;
        ComputeValueVTs(*this, Ty, vtparts);
        if (unsigned len = vtparts.size())
          OIdx += len - 1;
        continue;
      }
      // i8 types in IR will be i16 types in SDAG
      assert((getValueType(Ty) == Outs[OIdx].VT ||
             (getValueType(Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) &&
             "type mismatch between callee prototype and arguments");
      // scalar type
      unsigned sz = 0;
      if (isa<IntegerType>(Ty)) {
        sz = cast<IntegerType>(Ty)->getBitWidth();
        if (sz < 32)
          sz = 32;
      } else if (isa<PointerType>(Ty))
        sz = thePointerTy.getSizeInBits();
      else
        sz = Ty->getPrimitiveSizeInBits();
      O << ".param .b" << sz << " ";
      O << "_";
      continue;
    }
    const PointerType *PTy = dyn_cast<PointerType>(Ty);
    assert(PTy && "Param with byval attribute should be a pointer type");
    Type *ETy = PTy->getElementType();

    unsigned align = Outs[OIdx].Flags.getByValAlign();
    unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
    O << ".param .align " << align << " .b8 ";
    O << "_";
    O << "[" << sz << "]";
  }
  O << ");";
  return O.str();
}

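// Compute the alignment for argument Idx of call site CS: prefer explicit
// align metadata on the call or on the (possibly bitcast) callee function,
// and fall back to the ABI type alignment otherwise.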
unsigned
NVPTXTargetLowering::getArgumentAlignment(SDValue Callee,
                                          const ImmutableCallSite *CS,
                                          Type *Ty,
                                          unsigned Idx) const {
  const DataLayout *TD = getDataLayout();
  unsigned Align = 0;
  const Value *DirectCallee = CS->getCalledFunction();

  if (!DirectCallee) {
    // We don't have a direct function symbol, but that may be because of
    // constant cast instructions in the call.
    const Instruction *CalleeI = CS->getInstruction();
    assert(CalleeI && "Call target is not a function or derived value?");

    // With bitcast'd call targets, the instruction will be the call
    if (isa<CallInst>(CalleeI)) {
      // Check if we have call alignment metadata
      if (llvm::getAlign(*cast<CallInst>(CalleeI), Idx, Align))
        return Align;

      const Value *CalleeV = cast<CallInst>(CalleeI)->getCalledValue();
      // Ignore any bitcast instructions
      while (isa<ConstantExpr>(CalleeV)) {
        const ConstantExpr *CE = cast<ConstantExpr>(CalleeV);
        if (!CE->isCast())
          break;
        // Look through the bitcast
        CalleeV = cast<ConstantExpr>(CalleeV)->getOperand(0);
      }

      // We have now looked past all of the bitcasts.  Do we finally have a
      // Function?
      if (isa<Function>(CalleeV))
        DirectCallee = CalleeV;
    }
  }

  // Check for function alignment information if we found that the
  // ultimate target is a Function
  if (DirectCallee)
    if (llvm::getAlign(*cast<Function>(DirectCallee), Idx, Align))
      return Align;

  // Call is indirect or alignment information is not available, fall back to
  // the ABI type alignment
  return TD->getABITypeAlignment(Ty);
}

SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                                       SmallVectorImpl<SDValue> &InVals) const {
  SelectionDAG &DAG = CLI.DAG;
  SDLoc dl = CLI.DL;
  SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs;
  SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
  SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
  SDValue Chain = CLI.Chain;
  SDValue Callee = CLI.Callee;
  bool &isTailCall = CLI.IsTailCall;
  ArgListTy &Args = CLI.Args;
  Type *retTy = CLI.RetTy;
  ImmutableCallSite *CS = CLI.CS;

  bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
  assert(isABI && "Non-ABI compilation is not supported");
  if (!isABI)
    return Chain;
  const DataLayout *TD = getDataLayout();
  MachineFunction &MF = DAG.getMachineFunction();
  const Function *F = MF.getFunction();

  SDValue tempChain = Chain;
  Chain =
      DAG.getCALLSEQ_START(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
                           dl);
  SDValue InFlag = Chain.getValue(1);

  unsigned paramCount = 0;
  // Args.size() and Outs.size() need not match.
  // Outs.size() will be larger
  //   * if there is an aggregate argument with multiple fields (each field
  //     showing up separately in Outs)
  //   * if there is a vector argument with more than typical vector-length
  //     elements (generally if more than 4) where each vector element is
  //     individually present in Outs.
  // So a different index should be used for indexing into Outs/OutVals.
  // See similar issue in LowerFormalArguments.
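  // For example, a single {i32, i32} argument occupies one Args slot but two
  // consecutive Outs/OutVals slots.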
  unsigned OIdx = 0;
  // Declare the .param or .reg variables needed to pass values
  // to the function.
  for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
    EVT VT = Outs[OIdx].VT;
    Type *Ty = Args[i].Ty;

    if (!Outs[OIdx].Flags.isByVal()) {
      if (Ty->isAggregateType()) {
        // aggregate
        SmallVector<EVT, 16> vtparts;
        ComputeValueVTs(*this, Ty, vtparts);

        unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
        // declare .param .align <align> .b8 .param<n>[<size>];
        unsigned sz = TD->getTypeAllocSize(Ty);
        SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
        SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32),
                                      DAG.getConstant(paramCount, MVT::i32),
                                      DAG.getConstant(sz, MVT::i32), InFlag };
        Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
                            DeclareParamOps, 5);
        InFlag = Chain.getValue(1);
        unsigned curOffset = 0;
        for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
          unsigned elems = 1;
          EVT elemtype = vtparts[j];
          if (vtparts[j].isVector()) {
            elems = vtparts[j].getVectorNumElements();
            elemtype = vtparts[j].getVectorElementType();
          }
          for (unsigned k = 0, ke = elems; k != ke; ++k) {
            unsigned sz = elemtype.getSizeInBits();
            if (elemtype.isInteger() && (sz < 8))
              sz = 8;
            SDValue StVal = OutVals[OIdx];
            if (elemtype.getSizeInBits() < 16) {
              StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
            }
            SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
            SDValue CopyParamOps[] = { Chain,
                                       DAG.getConstant(paramCount, MVT::i32),
                                       DAG.getConstant(curOffset, MVT::i32),
                                       StVal, InFlag };
            Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
                                            CopyParamVTs, &CopyParamOps[0], 5,
                                            elemtype, MachinePointerInfo());
            InFlag = Chain.getValue(1);
            curOffset += sz / 8;
            ++OIdx;
          }
        }
        if (vtparts.size() > 0)
          --OIdx;
        ++paramCount;
        continue;
      }
      if (Ty->isVectorTy()) {
        EVT ObjectVT = getValueType(Ty);
        unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
        // declare .param .align <align> .b8 .param<n>[<size>];
        unsigned sz = TD->getTypeAllocSize(Ty);
        SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
        SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32),
                                      DAG.getConstant(paramCount, MVT::i32),
                                      DAG.getConstant(sz, MVT::i32), InFlag };
        Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
                            DeclareParamOps, 5);
        InFlag = Chain.getValue(1);
        unsigned NumElts = ObjectVT.getVectorNumElements();
        EVT EltVT = ObjectVT.getVectorElementType();
        EVT MemVT = EltVT;
        bool NeedExtend = false;
        if (EltVT.getSizeInBits() < 16) {
          NeedExtend = true;
          EltVT = MVT::i16;
        }

        // V1 store
        if (NumElts == 1) {
          SDValue Elt = OutVals[OIdx++];
          if (NeedExtend)
            Elt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt);

          SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
          SDValue CopyParamOps[] = { Chain,
                                     DAG.getConstant(paramCount, MVT::i32),
                                     DAG.getConstant(0, MVT::i32), Elt,
                                     InFlag };
          Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
                                          CopyParamVTs, &CopyParamOps[0], 5,
                                          MemVT, MachinePointerInfo());
          InFlag = Chain.getValue(1);
        } else if (NumElts == 2) {
          SDValue Elt0 = OutVals[OIdx++];
          SDValue Elt1 = OutVals[OIdx++];
          if (NeedExtend) {
            Elt0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt0);
            Elt1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt1);
          }

          SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
          SDValue CopyParamOps[] = { Chain,
                                     DAG.getConstant(paramCount, MVT::i32),
                                     DAG.getConstant(0, MVT::i32), Elt0, Elt1,
                                     InFlag };
          Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParamV2, dl,
                                          CopyParamVTs, &CopyParamOps[0], 6,
                                          MemVT, MachinePointerInfo());
          InFlag = Chain.getValue(1);
        } else {
          unsigned curOffset = 0;
          // V4 stores
          // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
          // the vector will be expanded to a power of 2 elements, so we know
          // we can always round up to the next multiple of 4 when creating the
          // vector stores.
          // e.g.  4 elem => 1 st.v4
          //       6 elem => 2 st.v4
          //       8 elem => 2 st.v4
          //      11 elem => 3 st.v4
          unsigned VecSize = 4;
          if (EltVT.getSizeInBits() == 64)
            VecSize = 2;

          // This is potentially only part of a vector, so assume all elements
          // are packed together.
          unsigned PerStoreOffset = MemVT.getStoreSizeInBits() / 8 * VecSize;

          for (unsigned i = 0; i < NumElts; i += VecSize) {
            // Get values
            SDValue StoreVal;
            SmallVector<SDValue, 8> Ops;
            Ops.push_back(Chain);
            Ops.push_back(DAG.getConstant(paramCount, MVT::i32));
            Ops.push_back(DAG.getConstant(curOffset, MVT::i32));

            unsigned Opc = NVPTXISD::StoreParamV2;

            StoreVal = OutVals[OIdx++];
            if (NeedExtend)
              StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
            Ops.push_back(StoreVal);

            if (i + 1 < NumElts) {
              StoreVal = OutVals[OIdx++];
              if (NeedExtend)
                StoreVal =
                    DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
            } else {
              StoreVal = DAG.getUNDEF(EltVT);
            }
            Ops.push_back(StoreVal);

            if (VecSize == 4) {
              Opc = NVPTXISD::StoreParamV4;
              if (i + 2 < NumElts) {
                StoreVal = OutVals[OIdx++];
                if (NeedExtend)
                  StoreVal =
                      DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
              } else {
                StoreVal = DAG.getUNDEF(EltVT);
              }
              Ops.push_back(StoreVal);

              if (i + 3 < NumElts) {
                StoreVal = OutVals[OIdx++];
                if (NeedExtend)
                  StoreVal =
                      DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
              } else {
                StoreVal = DAG.getUNDEF(EltVT);
              }
              Ops.push_back(StoreVal);
            }

            Ops.push_back(InFlag);

            SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
            Chain = DAG.getMemIntrinsicNode(Opc, dl, CopyParamVTs, &Ops[0],
                                            Ops.size(), MemVT,
                                            MachinePointerInfo());
            InFlag = Chain.getValue(1);
            curOffset += PerStoreOffset;
          }
        }
        ++paramCount;
        --OIdx;
        continue;
      }
      // Plain scalar
      // for ABI,    declare .param .b<size> .param<n>;
      unsigned sz = VT.getSizeInBits();
      bool needExtend = false;
      if (VT.isInteger()) {
        if (sz < 16)
          needExtend = true;
        if (sz < 32)
          sz = 32;
      }
      SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
      SDValue DeclareParamOps[] = { Chain,
                                    DAG.getConstant(paramCount, MVT::i32),
                                    DAG.getConstant(sz, MVT::i32),
                                    DAG.getConstant(0, MVT::i32), InFlag };
      Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
                          DeclareParamOps, 5);
      InFlag = Chain.getValue(1);
      SDValue OutV = OutVals[OIdx];
      if (needExtend) {
        // zext/sext i1 to i16
        unsigned opc = ISD::ZERO_EXTEND;
        if (Outs[OIdx].Flags.isSExt())
          opc = ISD::SIGN_EXTEND;
        OutV = DAG.getNode(opc, dl, MVT::i16, OutV);
      }
      SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
      SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
                                 DAG.getConstant(0, MVT::i32), OutV, InFlag };

      unsigned opcode = NVPTXISD::StoreParam;
      if (Outs[OIdx].Flags.isZExt())
        opcode = NVPTXISD::StoreParamU32;
      else if (Outs[OIdx].Flags.isSExt())
        opcode = NVPTXISD::StoreParamS32;
      Chain = DAG.getMemIntrinsicNode(opcode, dl, CopyParamVTs, CopyParamOps, 5,
                                      VT, MachinePointerInfo());

      InFlag = Chain.getValue(1);
      ++paramCount;
      continue;
    }
    // struct or vector
    SmallVector<EVT, 16> vtparts;
    const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
    assert(PTy && "Type of a byval parameter should be pointer");
    ComputeValueVTs(*this, PTy->getElementType(), vtparts);

    // declare .param .align <align> .b8 .param<n>[<size>];
    unsigned sz = Outs[OIdx].Flags.getByValSize();
    SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    // The ByValAlign in Outs[OIdx].Flags is always set at this point,
    // so we do not need to worry about whether it is the natural alignment.
    // See TargetLowering::LowerCallTo().
    SDValue DeclareParamOps[] = {
      Chain, DAG.getConstant(Outs[OIdx].Flags.getByValAlign(), MVT::i32),
      DAG.getConstant(paramCount, MVT::i32), DAG.getConstant(sz, MVT::i32),
      InFlag
    };
    Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
                        DeclareParamOps, 5);
    InFlag = Chain.getValue(1);
    unsigned curOffset = 0;
    for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
      unsigned elems = 1;
      EVT elemtype = vtparts[j];
      if (vtparts[j].isVector()) {
        elems = vtparts[j].getVectorNumElements();
        elemtype = vtparts[j].getVectorElementType();
      }
      for (unsigned k = 0, ke = elems; k != ke; ++k) {
        unsigned sz = elemtype.getSizeInBits();
        if (elemtype.isInteger() && (sz < 8))
          sz = 8;
        SDValue srcAddr =
            DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
                        DAG.getConstant(curOffset, getPointerTy()));
        SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
                                     MachinePointerInfo(), false, false, false,
                                     0);
        if (elemtype.getSizeInBits() < 16) {
          theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal);
        }
        SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
        SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
                                   DAG.getConstant(curOffset, MVT::i32), theVal,
                                   InFlag };
        Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
                                        CopyParamOps, 5, elemtype,
                                        MachinePointerInfo());

        InFlag = Chain.getValue(1);
        curOffset += sz / 8;
      }
    }
    ++paramCount;
  }

  GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
  unsigned retAlignment = 0;

  // Handle Result
  if (Ins.size() > 0) {
    SmallVector<EVT, 16> resvtparts;
    ComputeValueVTs(*this, retTy, resvtparts);

    // Declare
    //  .param .align 16 .b8 retval0[<size-in-bytes>], or
    //  .param .b<size-in-bits> retval0
    unsigned resultsz = TD->getTypeAllocSizeInBits(retTy);
    if (retTy->isSingleValueType()) {
      // Scalar needs to be at least 32bit wide
      if (resultsz < 32)
        resultsz = 32;
      SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
      SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
                                  DAG.getConstant(resultsz, MVT::i32),
                                  DAG.getConstant(0, MVT::i32), InFlag };
      Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
                          DeclareRetOps, 5);
      InFlag = Chain.getValue(1);
    } else {
      retAlignment = getArgumentAlignment(Callee, CS, retTy, 0);
      SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
      SDValue DeclareRetOps[] = { Chain,
                                  DAG.getConstant(retAlignment, MVT::i32),
                                  DAG.getConstant(resultsz / 8, MVT::i32),
                                  DAG.getConstant(0, MVT::i32), InFlag };
      Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
                          DeclareRetOps, 5);
      InFlag = Chain.getValue(1);
    }
  }

  if (!Func) {
    // This is the indirect function call case: PTX requires a prototype of the
    // form
    // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
    // to be emitted, and the label has to be used as the last arg of the call
    // instruction.
    // The prototype is embedded in a string and passed as the operand of a
    // CallPrototype SDNode, which prints out as the value of the string.
    SDVTList ProtoVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    std::string Proto = getPrototype(retTy, Args, Outs, retAlignment, CS);
    const char *ProtoStr =
      nvTM->getManagedStrPool()->getManagedString(Proto.c_str())->c_str();
    SDValue ProtoOps[] = {
      Chain, DAG.getTargetExternalSymbol(ProtoStr, MVT::i32), InFlag,
    };
    Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, ProtoVTs, &ProtoOps[0], 3);
    InFlag = Chain.getValue(1);
  }
  // Op to just print "call"
  SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue PrintCallOps[] = {
    Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, MVT::i32), InFlag
  };
  Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall),
                      dl, PrintCallVTs, PrintCallOps, 3);
  InFlag = Chain.getValue(1);

  // Ops to print out the function name
  SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue CallVoidOps[] = { Chain, Callee, InFlag };
  Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps, 3);
  InFlag = Chain.getValue(1);

  // Ops to print out the param list
  SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue CallArgBeginOps[] = { Chain, InFlag };
  Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
                      CallArgBeginOps, 2);
  InFlag = Chain.getValue(1);

  for (unsigned i = 0, e = paramCount; i != e; ++i) {
    unsigned opcode;
    if (i == (e - 1))
      opcode = NVPTXISD::LastCallArg;
    else
      opcode = NVPTXISD::CallArg;
    SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    SDValue CallArgOps[] = { Chain, DAG.getConstant(1, MVT::i32),
                             DAG.getConstant(i, MVT::i32), InFlag };
    Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps, 4);
    InFlag = Chain.getValue(1);
  }
  SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue CallArgEndOps[] = { Chain, DAG.getConstant(Func ? 1 : 0, MVT::i32),
                              InFlag };
  Chain =
      DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps, 3);
  InFlag = Chain.getValue(1);

  if (!Func) {
    SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    SDValue PrototypeOps[] = { Chain, DAG.getConstant(uniqueCallSite, MVT::i32),
                               InFlag };
    Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps, 3);
    InFlag = Chain.getValue(1);
  }

  // Generate loads from param memory/moves from registers for result
  if (Ins.size() > 0) {
    unsigned resoffset = 0;
    if (retTy && retTy->isVectorTy()) {
      EVT ObjectVT = getValueType(retTy);
      unsigned NumElts = ObjectVT.getVectorNumElements();
      EVT EltVT = ObjectVT.getVectorElementType();
      assert(nvTM->getTargetLowering()->getNumRegisters(F->getContext(),
                                                        ObjectVT) == NumElts &&
             "Vector was not scalarized");
      unsigned sz = EltVT.getSizeInBits();
      bool needTruncate = sz < 16;

      if (NumElts == 1) {
        // Just a simple load
        std::vector<EVT> LoadRetVTs;
        if (needTruncate) {
          // If loading i1 result, generate
          //   load i16
          //   trunc i16 to i1
          LoadRetVTs.push_back(MVT::i16);
        } else
          LoadRetVTs.push_back(EltVT);
        LoadRetVTs.push_back(MVT::Other);
        LoadRetVTs.push_back(MVT::Glue);
        std::vector<SDValue> LoadRetOps;
        LoadRetOps.push_back(Chain);
        LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
        LoadRetOps.push_back(DAG.getConstant(0, MVT::i32));
        LoadRetOps.push_back(InFlag);
        SDValue retval = DAG.getMemIntrinsicNode(
            NVPTXISD::LoadParam, dl,
            DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
            LoadRetOps.size(), EltVT, MachinePointerInfo());
        Chain = retval.getValue(1);
        InFlag = retval.getValue(2);
        SDValue Ret0 = retval;
        if (needTruncate)
          Ret0 = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Ret0);
        InVals.push_back(Ret0);
      } else if (NumElts == 2) {
        // LoadV2
        std::vector<EVT> LoadRetVTs;
        if (needTruncate) {
          // If loading i1 result, generate
          //   load i16
          //   trunc i16 to i1
          LoadRetVTs.push_back(MVT::i16);
          LoadRetVTs.push_back(MVT::i16);
        } else {
          LoadRetVTs.push_back(EltVT);
          LoadRetVTs.push_back(EltVT);
        }
        LoadRetVTs.push_back(MVT::Other);
        LoadRetVTs.push_back(MVT::Glue);
        std::vector<SDValue> LoadRetOps;
        LoadRetOps.push_back(Chain);
        LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
        LoadRetOps.push_back(DAG.getConstant(0, MVT::i32));
        LoadRetOps.push_back(InFlag);
        SDValue retval = DAG.getMemIntrinsicNode(
            NVPTXISD::LoadParamV2, dl,
            DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
            LoadRetOps.size(), EltVT, MachinePointerInfo());
        Chain = retval.getValue(2);
        InFlag = retval.getValue(3);
        SDValue Ret0 = retval.getValue(0);
        SDValue Ret1 = retval.getValue(1);
        if (needTruncate) {
          Ret0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret0);
          InVals.push_back(Ret0);
          Ret1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret1);
          InVals.push_back(Ret1);
        } else {
          InVals.push_back(Ret0);
          InVals.push_back(Ret1);
        }
      } else {
        // Split into N LoadV4
        unsigned Ofst = 0;
        unsigned VecSize = 4;
        unsigned Opc = NVPTXISD::LoadParamV4;
        if (EltVT.getSizeInBits() == 64) {
          VecSize = 2;
          Opc = NVPTXISD::LoadParamV2;
        }
        EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
        for (unsigned i = 0; i < NumElts; i += VecSize) {
          SmallVector<EVT, 8> LoadRetVTs;
          if (needTruncate) {
            // If loading i1 result, generate
            //   load i16
            //   trunc i16 to i1
            for (unsigned j = 0; j < VecSize; ++j)
              LoadRetVTs.push_back(MVT::i16);
          } else {
            for (unsigned j = 0; j < VecSize; ++j)
              LoadRetVTs.push_back(EltVT);
          }
          LoadRetVTs.push_back(MVT::Other);
          LoadRetVTs.push_back(MVT::Glue);
          SmallVector<SDValue, 4> LoadRetOps;
          LoadRetOps.push_back(Chain);
          LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
          LoadRetOps.push_back(DAG.getConstant(Ofst, MVT::i32));
          LoadRetOps.push_back(InFlag);
          SDValue retval = DAG.getMemIntrinsicNode(
              Opc, dl, DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()),
              &LoadRetOps[0], LoadRetOps.size(), EltVT, MachinePointerInfo());
          if (VecSize == 2) {
            Chain = retval.getValue(2);
            InFlag = retval.getValue(3);
          } else {
            Chain = retval.getValue(4);
            InFlag = retval.getValue(5);
          }

          for (unsigned j = 0; j < VecSize; ++j) {
            if (i + j >= NumElts)
              break;
            SDValue Elt = retval.getValue(j);
            if (needTruncate)
              Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
            InVals.push_back(Elt);
          }
          Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
        }
      }
    } else {
      SmallVector<EVT, 16> VTs;
      ComputePTXValueVTs(*this, retTy, VTs);
      assert(VTs.size() == Ins.size() && "Bad value decomposition");
      for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
        unsigned sz = VTs[i].getSizeInBits();
        bool needTruncate = sz < 8;
        if (VTs[i].isInteger() && (sz < 8))
          sz = 8;

        SmallVector<EVT, 4> LoadRetVTs;
        EVT TheLoadType = VTs[i];
        if (retTy->isIntegerTy() &&
            TD->getTypeAllocSizeInBits(retTy) < 32) {
          // This is for integer types only, and specifically not for
          // aggregates.
          LoadRetVTs.push_back(MVT::i32);
          TheLoadType = MVT::i32;
        } else if (sz < 16) {
          // If loading i1/i8 result, generate
          //   load i8 (-> i16)
          //   trunc i16 to i1/i8
          LoadRetVTs.push_back(MVT::i16);
        } else
          LoadRetVTs.push_back(Ins[i].VT);
        LoadRetVTs.push_back(MVT::Other);
        LoadRetVTs.push_back(MVT::Glue);

        SmallVector<SDValue, 4> LoadRetOps;
        LoadRetOps.push_back(Chain);
        LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
        LoadRetOps.push_back(DAG.getConstant(resoffset, MVT::i32));
        LoadRetOps.push_back(InFlag);
        SDValue retval = DAG.getMemIntrinsicNode(
            NVPTXISD::LoadParam, dl,
            DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
            LoadRetOps.size(), TheLoadType, MachinePointerInfo());
        Chain = retval.getValue(1);
        InFlag = retval.getValue(2);
        SDValue Ret0 = retval.getValue(0);
        if (needTruncate)
          Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0);
        InVals.push_back(Ret0);
        resoffset += sz / 8;
      }
    }
  }

  Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
                             DAG.getIntPtrConstant(uniqueCallSite + 1, true),
                             InFlag, dl);
  uniqueCallSite++;

  // set isTailCall to false for now, until we figure out how to express
  // tail call optimization in PTX
  isTailCall = false;
  return Chain;
}

// By default, CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
// (see LegalizeDAG.cpp). This is slow and uses local memory.
// We instead use extract/build_vector, just as LegalizeOp() did in LLVM 2.5.
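// For example, (concat_vectors v2f32:a, v2f32:b) becomes a BUILD_VECTOR of
// the four scalars (extract_vector_elt a, 0..1) and (extract_vector_elt b,
// 0..1).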
SDValue
NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
  SDNode *Node = Op.getNode();
  SDLoc dl(Node);
  SmallVector<SDValue, 8> Ops;
  unsigned NumOperands = Node->getNumOperands();
  for (unsigned i = 0; i < NumOperands; ++i) {
    SDValue SubOp = Node->getOperand(i);
    EVT VVT = SubOp.getNode()->getValueType(0);
    EVT EltVT = VVT.getVectorElementType();
    unsigned NumSubElem = VVT.getVectorNumElements();
    for (unsigned j = 0; j < NumSubElem; ++j) {
      Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
                                DAG.getIntPtrConstant(j)));
    }
  }
  return DAG.getNode(ISD::BUILD_VECTOR, dl, Node->getValueType(0), &Ops[0],
                     Ops.size());
}

SDValue
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
  switch (Op.getOpcode()) {
  case ISD::RETURNADDR:
    return SDValue();
  case ISD::FRAMEADDR:
    return SDValue();
  case ISD::GlobalAddress:
    return LowerGlobalAddress(Op, DAG);
  case ISD::INTRINSIC_W_CHAIN:
    return Op;
  case ISD::BUILD_VECTOR:
  case ISD::EXTRACT_SUBVECTOR:
    return Op;
  case ISD::CONCAT_VECTORS:
    return LowerCONCAT_VECTORS(Op, DAG);
  case ISD::STORE:
    return LowerSTORE(Op, DAG);
  case ISD::LOAD:
    return LowerLOAD(Op, DAG);
  default:
    llvm_unreachable("Custom lowering not defined for operation");
  }
}

SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
  if (Op.getValueType() == MVT::i1)
    return LowerLOADi1(Op, DAG);
  else
    return SDValue();
}

// v = ld i1* addr
//   =>
// v1 = ld i8* addr (-> i16)
// v = trunc i16 to i1
SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
  SDNode *Node = Op.getNode();
  LoadSDNode *LD = cast<LoadSDNode>(Node);
  SDLoc dl(Node);
  assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
  assert(Node->getValueType(0) == MVT::i1 &&
         "Custom lowering for i1 load only");
  SDValue newLD =
      DAG.getLoad(MVT::i16, dl, LD->getChain(), LD->getBasePtr(),
                  LD->getPointerInfo(), LD->isVolatile(), LD->isNonTemporal(),
                  LD->isInvariant(), LD->getAlignment());
  SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
  // The legalizer (the caller) is expecting two values from the legalized
  // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
  // in LegalizeDAG.cpp which also uses MergeValues.
  SDValue Ops[] = { result, LD->getChain() };
  return DAG.getMergeValues(Ops, 2, dl);
}

SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
  EVT ValVT = Op.getOperand(1).getValueType();
  if (ValVT == MVT::i1)
    return LowerSTOREi1(Op, DAG);
  else if (ValVT.isVector())
    return LowerSTOREVector(Op, DAG);
  else
    return SDValue();
}

SDValue
NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
  SDNode *N = Op.getNode();
  SDValue Val = N->getOperand(1);
  SDLoc DL(N);
  EVT ValVT = Val.getValueType();

  if (ValVT.isVector()) {
    // We only handle "native" vector sizes for now, e.g. <4 x double> is not
    // legal.  We can (and should) split that into 2 stores of <2 x double> here
    // but I'm leaving that as a TODO for now.
    if (!ValVT.isSimple())
      return SDValue();
    switch (ValVT.getSimpleVT().SimpleTy) {
    default:
      return SDValue();
    case MVT::v2i8:
    case MVT::v2i16:
    case MVT::v2i32:
    case MVT::v2i64:
    case MVT::v2f32:
    case MVT::v2f64:
    case MVT::v4i8:
    case MVT::v4i16:
    case MVT::v4i32:
    case MVT::v4f32:
      // This is a "native" vector type
      break;
    }

    unsigned Opcode = 0;
    EVT EltVT = ValVT.getVectorElementType();
    unsigned NumElts = ValVT.getVectorNumElements();

    // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
    // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
    // stored type to i16 and propagate the "real" type as the memory type.
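    // For example, a <4 x i8> store becomes a StoreV4 of four i16 values
    // whose memory VT remains v4i8.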
    bool NeedExt = false;
    if (EltVT.getSizeInBits() < 16)
      NeedExt = true;

    switch (NumElts) {
    default:
      return SDValue();
    case 2:
      Opcode = NVPTXISD::StoreV2;
      break;
    case 4: {
      Opcode = NVPTXISD::StoreV4;
      break;
    }
    }

    SmallVector<SDValue, 8> Ops;

    // First is the chain
    Ops.push_back(N->getOperand(0));

    // Then the split values
    for (unsigned i = 0; i < NumElts; ++i) {
      SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
                                   DAG.getIntPtrConstant(i));
      if (NeedExt)
        ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
      Ops.push_back(ExtVal);
    }

    // Then any remaining arguments
    for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i) {
      Ops.push_back(N->getOperand(i));
    }

    MemSDNode *MemSD = cast<MemSDNode>(N);

    SDValue NewSt = DAG.getMemIntrinsicNode(
        Opcode, DL, DAG.getVTList(MVT::Other), &Ops[0], Ops.size(),
        MemSD->getMemoryVT(), MemSD->getMemOperand());

    //return DCI.CombineTo(N, NewSt, true);
    return NewSt;
  }

  return SDValue();
}

// st i1 v, addr
//    =>
// v1 = zext v to i16
// st.u8 i16, addr
SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
  SDNode *Node = Op.getNode();
  SDLoc dl(Node);
  StoreSDNode *ST = cast<StoreSDNode>(Node);
  SDValue Tmp1 = ST->getChain();
  SDValue Tmp2 = ST->getBasePtr();
  SDValue Tmp3 = ST->getValue();
  assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
  unsigned Alignment = ST->getAlignment();
  bool isVolatile = ST->isVolatile();
  bool isNonTemporal = ST->isNonTemporal();
  Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3);
  SDValue Result = DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2,
                                     ST->getPointerInfo(), MVT::i8, isNonTemporal,
                                     isVolatile, Alignment);
  return Result;
}

SDValue NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname,
                                        int idx, EVT v) const {
  std::string *name = nvTM->getManagedStrPool()->getManagedString(inname);
  std::stringstream suffix;
  suffix << idx;
  *name += suffix.str();
  return DAG.getTargetExternalSymbol(name->c_str(), v);
}

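// Returns a target external symbol of the form "<function>_param_<idx>",
// e.g. "foo_param_2" for parameter 2 of a function named foo.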
1341SDValue
1342NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const {
1343  std::string ParamSym;
1344  raw_string_ostream ParamStr(ParamSym);
1345
1346  ParamStr << DAG.getMachineFunction().getName() << "_param_" << idx;
1347  ParamStr.flush();
1348
1349  std::string *SavedStr =
1350    nvTM->getManagedStrPool()->getManagedString(ParamSym.c_str());
1351  return DAG.getTargetExternalSymbol(SavedStr->c_str(), v);
1352}
1353
1354SDValue NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) {
1355  return getExtSymb(DAG, ".HLPPARAM", idx);
1356}
1357
1358// Check to see if the kernel argument is image*_t or sampler_t
1359
1360bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) {
1361  static const char *const specialTypes[] = { "struct._image2d_t",
1362                                              "struct._image3d_t",
1363                                              "struct._sampler_t" };
1364
1365  const Type *Ty = arg->getType();
1366  const PointerType *PTy = dyn_cast<PointerType>(Ty);
1367
1368  if (!PTy)
1369    return false;
1370
1371  if (!context)
1372    return false;
1373
1374  const StructType *STy = dyn_cast<StructType>(PTy->getElementType());
1375  const std::string TypeName = STy && !STy->isLiteral() ? STy->getName() : "";
1376
1377  for (int i = 0, e = array_lengthof(specialTypes); i != e; ++i)
1378    if (TypeName == specialTypes[i])
1379      return true;
1380
1381  return false;
1382}

SDValue NVPTXTargetLowering::LowerFormalArguments(
    SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
    const SmallVectorImpl<ISD::InputArg> &Ins, SDLoc dl, SelectionDAG &DAG,
    SmallVectorImpl<SDValue> &InVals) const {
  MachineFunction &MF = DAG.getMachineFunction();
  const DataLayout *TD = getDataLayout();

  const Function *F = MF.getFunction();
  const AttributeSet &PAL = F->getAttributes();
  const TargetLowering *TLI = nvTM->getTargetLowering();

  SDValue Root = DAG.getRoot();
  std::vector<SDValue> OutChains;

  bool isKernel = llvm::isKernelFunction(*F);
  bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
  assert(isABI && "Non-ABI compilation is not supported");
  if (!isABI)
    return Chain;

  std::vector<Type *> argTypes;
  std::vector<const Argument *> theArgs;
  for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end();
       I != E; ++I) {
    theArgs.push_back(I);
    argTypes.push_back(I->getType());
  }
  // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
  // Ins.size() will be larger
  //   * if there is an aggregate argument with multiple fields (each field
  //     showing up separately in Ins)
  //   * if there is a vector argument with more elements than the typical
  //     vector length (generally more than 4), where each element is
  //     individually present in Ins.
  // So a different index should be used for indexing into Ins; see the
  // example below and the similar issue in LowerCall.
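  // For instance (illustrative only): an argument of type {i32, float}
  // contributes one entry to argTypes but two entries to Ins, so InsIdx
  // advances twice while i advances once.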
  unsigned InsIdx = 0;

  int idx = 0;
  for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++idx, ++InsIdx) {
    Type *Ty = argTypes[i];

    // If the kernel argument is image*_t or sampler_t, convert it to
    // an i32 constant holding the parameter position. This can later be
    // matched in the AsmPrinter to output the correct mangled name.
    if (isImageOrSamplerVal(
            theArgs[i],
            (theArgs[i]->getParent() ? theArgs[i]->getParent()->getParent()
                                     : 0))) {
      assert(isKernel && "Only kernels can have image/sampler params");
      InVals.push_back(DAG.getConstant(i + 1, MVT::i32));
      continue;
    }

    if (theArgs[i]->use_empty()) {
      // argument is dead
      if (Ty->isAggregateType()) {
        SmallVector<EVT, 16> vtparts;

        ComputePTXValueVTs(*this, Ty, vtparts);
        assert(vtparts.size() > 0 && "empty aggregate type not expected");
        for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
             ++parti) {
          EVT partVT = vtparts[parti];
          InVals.push_back(DAG.getNode(ISD::UNDEF, dl, partVT));
          ++InsIdx;
        }
        if (vtparts.size() > 0)
          --InsIdx;
        continue;
      }
      if (Ty->isVectorTy()) {
        EVT ObjectVT = getValueType(Ty);
        unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
        for (unsigned parti = 0; parti < NumRegs; ++parti) {
          InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
          ++InsIdx;
        }
        if (NumRegs > 0)
          --InsIdx;
        continue;
      }
      InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
      continue;
    }

    // In the following cases, assign a node order of "idx+1"
    // to newly created nodes. The SDNodes for params have to
    // appear in the same order as their order of appearance
    // in the original function. "idx+1" holds that order.
    if (!PAL.hasAttribute(i + 1, Attribute::ByVal)) {
      if (Ty->isAggregateType()) {
        SmallVector<EVT, 16> vtparts;
        SmallVector<uint64_t, 16> offsets;

        // NOTE: Here, we lose the ability to issue vector loads for vectors
        // that are a part of a struct.  This should be investigated in the
        // future.
        ComputePTXValueVTs(*this, Ty, vtparts, &offsets, 0);
        assert(vtparts.size() > 0 && "empty aggregate type not expected");
        bool aggregateIsPacked = false;
        if (StructType *STy = llvm::dyn_cast<StructType>(Ty))
          aggregateIsPacked = STy->isPacked();

        SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
        for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
             ++parti) {
          EVT partVT = vtparts[parti];
          Value *srcValue = Constant::getNullValue(
              PointerType::get(partVT.getTypeForEVT(F->getContext()),
                               llvm::ADDRESS_SPACE_PARAM));
          SDValue srcAddr =
              DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
                          DAG.getConstant(offsets[parti], getPointerTy()));
          unsigned partAlign =
              aggregateIsPacked ? 1
                                : TD->getABITypeAlignment(
                                      partVT.getTypeForEVT(F->getContext()));
          SDValue p;
          if (Ins[InsIdx].VT.getSizeInBits() > partVT.getSizeInBits()) {
            ISD::LoadExtType ExtOp = Ins[InsIdx].Flags.isSExt() ?
                                     ISD::SEXTLOAD : ISD::ZEXTLOAD;
            p = DAG.getExtLoad(ExtOp, dl, Ins[InsIdx].VT, Root, srcAddr,
                               MachinePointerInfo(srcValue), partVT, false,
                               false, partAlign);
          } else {
            p = DAG.getLoad(partVT, dl, Root, srcAddr,
                            MachinePointerInfo(srcValue), false, false, false,
                            partAlign);
          }
          if (p.getNode())
            p.getNode()->setIROrder(idx + 1);
          InVals.push_back(p);
          ++InsIdx;
        }
        if (vtparts.size() > 0)
          --InsIdx;
        continue;
      }
      if (Ty->isVectorTy()) {
        EVT ObjectVT = getValueType(Ty);
        SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
        unsigned NumElts = ObjectVT.getVectorNumElements();
        assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts &&
               "Vector was not scalarized");
        unsigned Ofst = 0;
        EVT EltVT = ObjectVT.getVectorElementType();

        // V1 load
        // f32 = load ...
        if (NumElts == 1) {
          // We only have one element, so just directly load it
          Value *SrcValue = Constant::getNullValue(PointerType::get(
              EltVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
          SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
                                        DAG.getConstant(Ofst, getPointerTy()));
          SDValue P = DAG.getLoad(
              EltVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
              false, true,
              TD->getABITypeAlignment(EltVT.getTypeForEVT(F->getContext())));
          if (P.getNode())
            P.getNode()->setIROrder(idx + 1);

          if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
            P = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, P);
          InVals.push_back(P);
          Ofst += TD->getTypeAllocSize(EltVT.getTypeForEVT(F->getContext()));
          ++InsIdx;
        } else if (NumElts == 2) {
          // V2 load
          // f32,f32 = load ...
          EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, 2);
          Value *SrcValue = Constant::getNullValue(PointerType::get(
              VecVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
          SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
                                        DAG.getConstant(Ofst, getPointerTy()));
          SDValue P = DAG.getLoad(
              VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
              false, true,
              TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
          if (P.getNode())
            P.getNode()->setIROrder(idx + 1);

          SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
                                     DAG.getIntPtrConstant(0));
          SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
                                     DAG.getIntPtrConstant(1));

          if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits()) {
            Elt0 = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt0);
            Elt1 = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt1);
          }

          InVals.push_back(Elt0);
          InVals.push_back(Elt1);
          Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
          InsIdx += 2;
        } else {
          // V4 loads
          // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
          // the vector will be expanded to a power of 2 elements, so we know
          // we can always round up to the next multiple of 4 when creating
          // the vector loads.
          // e.g.  4 elem => 1 ld.v4
          //       6 elem => 2 ld.v4
          //       8 elem => 2 ld.v4
          //      11 elem => 3 ld.v4
          unsigned VecSize = 4;
          if (EltVT.getSizeInBits() == 64) {
            VecSize = 2;
          }
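          // Using v2 chunks for 64-bit elements reflects the PTX limit on
          // vector accesses: a single ld.v2/ld.v4 moves at most 128 bits,
          // so a .v4 of 64-bit elements would exceed it (an observation
          // about the PTX ISA, not a property checked here).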
          EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
          for (unsigned i = 0; i < NumElts; i += VecSize) {
            Value *SrcValue = Constant::getNullValue(
                PointerType::get(VecVT.getTypeForEVT(F->getContext()),
                                 llvm::ADDRESS_SPACE_PARAM));
            SDValue SrcAddr =
                DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
                            DAG.getConstant(Ofst, getPointerTy()));
            SDValue P = DAG.getLoad(
                VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
                false, true,
                TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
            if (P.getNode())
              P.getNode()->setIROrder(idx + 1);

            for (unsigned j = 0; j < VecSize; ++j) {
              if (i + j >= NumElts)
                break;
              SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
                                        DAG.getIntPtrConstant(j));
              if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
                Elt = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt);
              InVals.push_back(Elt);
            }
            Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
          }
          InsIdx += NumElts;
        }

        if (NumElts > 0)
          --InsIdx;
        continue;
      }
      // A plain scalar.
      EVT ObjectVT = getValueType(Ty);
      // If ABI, load from the param symbol
      SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
      Value *srcValue = Constant::getNullValue(PointerType::get(
          ObjectVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
      SDValue p;
      if (ObjectVT.getSizeInBits() < Ins[InsIdx].VT.getSizeInBits()) {
        ISD::LoadExtType ExtOp = Ins[InsIdx].Flags.isSExt() ?
                                 ISD::SEXTLOAD : ISD::ZEXTLOAD;
        p = DAG.getExtLoad(ExtOp, dl, Ins[InsIdx].VT, Root, Arg,
                           MachinePointerInfo(srcValue), ObjectVT, false, false,
                           TD->getABITypeAlignment(
                               ObjectVT.getTypeForEVT(F->getContext())));
      } else {
        p = DAG.getLoad(Ins[InsIdx].VT, dl, Root, Arg,
                        MachinePointerInfo(srcValue), false, false, false,
                        TD->getABITypeAlignment(
                            ObjectVT.getTypeForEVT(F->getContext())));
      }
      if (p.getNode())
        p.getNode()->setIROrder(idx + 1);
      InVals.push_back(p);
      continue;
    }

    // Param has ByVal attribute
    // Return MoveParam(param symbol).
    // Ideally, the param symbol could be returned directly, but when the
    // SDNode builder decides to use it in a CopyToReg(), the machine
    // instruction fails because a TargetExternalSymbol (not yet lowered)
    // is target dependent, and CopyToReg assumes the source is lowered.
    EVT ObjectVT = getValueType(Ty);
    assert(ObjectVT == Ins[InsIdx].VT &&
           "Ins type did not match function type");
    SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
    SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
    if (p.getNode())
      p.getNode()->setIROrder(idx + 1);
    if (isKernel)
      InVals.push_back(p);
    else {
      SDValue p2 = DAG.getNode(
          ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT,
          DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32), p);
      InVals.push_back(p2);
    }
  }

  // Clang will check for explicit varargs and issue an error if any are
  // present. However, Clang will let code with an implicit vararg list,
  // like f(), pass. See bug 617733.
  // We treat this case as if the arg list is empty.
  // if (F.isVarArg()) {
  // assert(0 && "VarArg not supported yet!");
  //}

  if (!OutChains.empty())
    DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &OutChains[0],
                            OutChains.size()));

  return Chain;
}
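
// Sketch of the overall effect (illustrative, for a hypothetical kernel
// "kernel void k(float a, int b)"): each formal argument becomes a load
// from the .param address space at the symbol built by getParamSymbol, so
// 'a' is read from [k_param_0] and 'b' from [k_param_1], in argument order.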

SDValue
NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
                                 bool isVarArg,
                                 const SmallVectorImpl<ISD::OutputArg> &Outs,
                                 const SmallVectorImpl<SDValue> &OutVals,
                                 SDLoc dl, SelectionDAG &DAG) const {
  MachineFunction &MF = DAG.getMachineFunction();
  const Function *F = MF.getFunction();
  Type *RetTy = F->getReturnType();
  const DataLayout *TD = getDataLayout();

  bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
  assert(isABI && "Non-ABI compilation is not supported");
  if (!isABI)
    return Chain;

  if (VectorType *VTy = dyn_cast<VectorType>(RetTy)) {
    // If we have a vector type, the OutVals array will be the scalarized
    // components and we have to combine them into 1 or more vector stores.
    unsigned NumElts = VTy->getNumElements();
    assert(NumElts == Outs.size() && "Bad scalarization of return value");

    // const_cast can be removed in later LLVM versions
    EVT EltVT = getValueType(RetTy).getVectorElementType();
    bool NeedExtend = false;
    if (EltVT.getSizeInBits() < 16)
      NeedExtend = true;

    // V1 store
    if (NumElts == 1) {
      SDValue StoreVal = OutVals[0];
      // We only have one element, so just directly store it
      if (NeedExtend)
        StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
      SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal };
      Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
                                      DAG.getVTList(MVT::Other), &Ops[0], 3,
                                      EltVT, MachinePointerInfo());

    } else if (NumElts == 2) {
      // V2 store
      SDValue StoreVal0 = OutVals[0];
      SDValue StoreVal1 = OutVals[1];

      if (NeedExtend) {
        StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal0);
        StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal1);
      }

      SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal0,
                        StoreVal1 };
      Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetvalV2, dl,
                                      DAG.getVTList(MVT::Other), &Ops[0], 4,
                                      EltVT, MachinePointerInfo());
    } else {
      // V4 stores
      // We have at least 4 elements (<3 x Ty> expands to 4 elements) and the
      // vector will be expanded to a power of 2 elements, so we know we can
      // always round up to the next multiple of 4 when creating the vector
      // stores.
      // e.g.  4 elem => 1 st.v4
      //       6 elem => 2 st.v4
      //       8 elem => 2 st.v4
      //      11 elem => 3 st.v4

      unsigned VecSize = 4;
      if (OutVals[0].getValueType().getSizeInBits() == 64)
        VecSize = 2;

      unsigned Offset = 0;

      EVT VecVT =
          EVT::getVectorVT(F->getContext(), OutVals[0].getValueType(), VecSize);
      unsigned PerStoreOffset =
          TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));

      for (unsigned i = 0; i < NumElts; i += VecSize) {
        // Get values
        SDValue StoreVal;
        SmallVector<SDValue, 8> Ops;
        Ops.push_back(Chain);
        Ops.push_back(DAG.getConstant(Offset, MVT::i32));
        unsigned Opc = NVPTXISD::StoreRetvalV2;
        EVT ExtendedVT = (NeedExtend) ? MVT::i16 : OutVals[0].getValueType();

        StoreVal = OutVals[i];
        if (NeedExtend)
          StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
        Ops.push_back(StoreVal);

        if (i + 1 < NumElts) {
          StoreVal = OutVals[i + 1];
          if (NeedExtend)
            StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
        } else {
          StoreVal = DAG.getUNDEF(ExtendedVT);
        }
        Ops.push_back(StoreVal);

        if (VecSize == 4) {
          Opc = NVPTXISD::StoreRetvalV4;
          if (i + 2 < NumElts) {
            StoreVal = OutVals[i + 2];
            if (NeedExtend)
              StoreVal =
                  DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
          } else {
            StoreVal = DAG.getUNDEF(ExtendedVT);
          }
          Ops.push_back(StoreVal);

          if (i + 3 < NumElts) {
            StoreVal = OutVals[i + 3];
            if (NeedExtend)
              StoreVal =
                  DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
          } else {
            StoreVal = DAG.getUNDEF(ExtendedVT);
          }
          Ops.push_back(StoreVal);
        }

        // Chain = DAG.getNode(Opc, dl, MVT::Other, &Ops[0], Ops.size());
        Chain =
            DAG.getMemIntrinsicNode(Opc, dl, DAG.getVTList(MVT::Other), &Ops[0],
                                    Ops.size(), EltVT, MachinePointerInfo());
        Offset += PerStoreOffset;
      }
    }
  } else {
    SmallVector<EVT, 16> ValVTs;
    // const_cast is necessary since we are still using an LLVM version from
    // before the type system re-write.
    ComputePTXValueVTs(*this, RetTy, ValVTs);
    assert(ValVTs.size() == OutVals.size() && "Bad return value decomposition");

    unsigned SizeSoFar = 0;
    for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
      SDValue theVal = OutVals[i];
      EVT TheValType = theVal.getValueType();
      unsigned numElems = 1;
      if (TheValType.isVector())
        numElems = TheValType.getVectorNumElements();
      for (unsigned j = 0, je = numElems; j != je; ++j) {
        SDValue TmpVal = theVal;
        if (TheValType.isVector())
          TmpVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
                               TheValType.getVectorElementType(), TmpVal,
                               DAG.getIntPtrConstant(j));
        EVT TheStoreType = ValVTs[i];
        if (RetTy->isIntegerTy() &&
            TD->getTypeAllocSizeInBits(RetTy) < 32) {
          // The following zero-extension is for integer types only, and
          // specifically not for aggregates.
          TmpVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, TmpVal);
          TheStoreType = MVT::i32;
        }
        else if (TmpVal.getValueType().getSizeInBits() < 16)
          TmpVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, TmpVal);

        SDValue Ops[] = { Chain, DAG.getConstant(SizeSoFar, MVT::i32), TmpVal };
        Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
                                        DAG.getVTList(MVT::Other), &Ops[0],
                                        3, TheStoreType,
                                        MachinePointerInfo());
        if (TheValType.isVector())
          SizeSoFar +=
            TheStoreType.getVectorElementType().getStoreSizeInBits() / 8;
        else
          SizeSoFar += TheStoreType.getStoreSizeInBits() / 8;
      }
    }
  }

  return DAG.getNode(NVPTXISD::RET_FLAG, dl, MVT::Other, Chain);
}
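
// Worked example (a sketch of the scalar path above, not verified output):
// for a function returning i8, TheStoreType is widened to MVT::i32, so the
// value is zero-extended and written by a single StoreRetval at offset 0;
// the printed .param return slot is then expected to be 32 bits wide.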

void NVPTXTargetLowering::LowerAsmOperandForConstraint(
    SDValue Op, std::string &Constraint, std::vector<SDValue> &Ops,
    SelectionDAG &DAG) const {
  if (Constraint.length() > 1)
    return;

  TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
}

// NVPTX supports vectors of legal types of any length in intrinsics, because
// the NVPTX-specific type legalizer will legalize them to the PTX-supported
// length.
bool NVPTXTargetLowering::isTypeSupportedInIntrinsic(MVT VT) const {
  if (isTypeLegal(VT))
    return true;
  if (VT.isVector()) {
    MVT eVT = VT.getVectorElementType();
    if (isTypeLegal(eVT))
      return true;
  }
  return false;
}

// llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
// TgtMemIntrinsic because we need the information that is only available in
// the "Value" type of the destination pointer, in particular the address
// space information.
bool NVPTXTargetLowering::getTgtMemIntrinsic(
    IntrinsicInfo &Info, const CallInst &I, unsigned Intrinsic) const {
  switch (Intrinsic) {
  default:
    return false;

  case Intrinsic::nvvm_atomic_load_add_f32:
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::f32;
    Info.ptrVal = I.getArgOperand(0);
    Info.offset = 0;
    Info.vol = 0;
    Info.readMem = true;
    Info.writeMem = true;
    Info.align = 0;
    return true;

  case Intrinsic::nvvm_atomic_load_inc_32:
  case Intrinsic::nvvm_atomic_load_dec_32:
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::i32;
    Info.ptrVal = I.getArgOperand(0);
    Info.offset = 0;
    Info.vol = 0;
    Info.readMem = true;
    Info.writeMem = true;
    Info.align = 0;
    return true;

  case Intrinsic::nvvm_ldu_global_i:
  case Intrinsic::nvvm_ldu_global_f:
  case Intrinsic::nvvm_ldu_global_p:
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    if (Intrinsic == Intrinsic::nvvm_ldu_global_i ||
        Intrinsic == Intrinsic::nvvm_ldu_global_p)
      Info.memVT = getValueType(I.getType());
    else
      Info.memVT = MVT::f32;
    Info.ptrVal = I.getArgOperand(0);
    Info.offset = 0;
    Info.vol = 0;
    Info.readMem = true;
    Info.writeMem = false;
    Info.align = 0;
    return true;
  }
  return false;
}

/// isLegalAddressingMode - Return true if the addressing mode represented
/// by AM is legal for this target, for a load/store of the specified type.
/// Used to guide target specific optimizations, like loop strength reduction
/// (LoopStrengthReduce.cpp) and memory optimization for address mode
/// (CodeGenPrepare.cpp)
bool NVPTXTargetLowering::isLegalAddressingMode(const AddrMode &AM,
                                                Type *Ty) const {

  // AddrMode - This represents an addressing mode of:
  //    BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
  //
  // The legal address modes are
  // - [avar]
  // - [areg]
  // - [areg+immoff]
  // - [immAddr]

  if (AM.BaseGV) {
    if (AM.BaseOffs || AM.HasBaseReg || AM.Scale)
      return false;
    return true;
  }

  switch (AM.Scale) {
  case 0: // "r", "r+i" or "i" is allowed
    break;
  case 1:
    if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
      return false;
    // Otherwise we have r+i.
    break;
  default:
    // No scale > 1 is allowed
    return false;
  }
  return true;
}
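
// For example (illustrative): {BaseReg, BaseOffs=16} is accepted and can
// become an access like [%r1+16], while reg+reg forms and any Scale > 1 are
// rejected, since PTX addressing has no scaled-index mode.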

//===----------------------------------------------------------------------===//
//                         NVPTX Inline Assembly Support
//===----------------------------------------------------------------------===//

/// getConstraintType - Given a constraint letter, return the type of
/// constraint it is for this target.
NVPTXTargetLowering::ConstraintType
NVPTXTargetLowering::getConstraintType(const std::string &Constraint) const {
  if (Constraint.size() == 1) {
    switch (Constraint[0]) {
    default:
      break;
    case 'r':
    case 'h':
    case 'c':
    case 'l':
    case 'f':
    case 'd':
    case '0':
    case 'N':
      return C_RegisterClass;
    }
  }
  return TargetLowering::getConstraintType(Constraint);
}

std::pair<unsigned, const TargetRegisterClass *>
NVPTXTargetLowering::getRegForInlineAsmConstraint(const std::string &Constraint,
                                                  MVT VT) const {
  if (Constraint.size() == 1) {
    switch (Constraint[0]) {
    case 'c':
    case 'h':
      return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
    case 'r':
      return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
    case 'l':
    case 'N':
      return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
    case 'f':
      return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
    case 'd':
      return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
    }
  }
  return TargetLowering::getRegForInlineAsmConstraint(Constraint, VT);
}

/// getFunctionAlignment - Return the Log2 alignment of this function.
unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const {
  return 4;
}
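
// Note: the returned value is a log2 alignment, so 4 requests 2^4 = 16-byte
// alignment for every function.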

/// ReplaceLoadVector - Convert vector loads into multi-output scalar loads.
static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
                              SmallVectorImpl<SDValue> &Results) {
  EVT ResVT = N->getValueType(0);
  SDLoc DL(N);

  assert(ResVT.isVector() && "Vector load must have vector type");

  // We only handle "native" vector sizes for now, e.g. <4 x double> is not
  // legal.  We can (and should) split that into 2 loads of <2 x double> here
  // but I'm leaving that as a TODO for now.
  assert(ResVT.isSimple() && "Can only handle simple types");
  switch (ResVT.getSimpleVT().SimpleTy) {
  default:
    return;
  case MVT::v2i8:
  case MVT::v2i16:
  case MVT::v2i32:
  case MVT::v2i64:
  case MVT::v2f32:
  case MVT::v2f64:
  case MVT::v4i8:
  case MVT::v4i16:
  case MVT::v4i32:
  case MVT::v4f32:
    // This is a "native" vector type
    break;
  }

  EVT EltVT = ResVT.getVectorElementType();
  unsigned NumElts = ResVT.getVectorNumElements();

  // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
  // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
  // loaded type to i16 and propagate the "real" type as the memory type.
  bool NeedTrunc = false;
  if (EltVT.getSizeInBits() < 16) {
    EltVT = MVT::i16;
    NeedTrunc = true;
  }

  unsigned Opcode = 0;
  SDVTList LdResVTs;

  switch (NumElts) {
  default:
    return;
  case 2:
    Opcode = NVPTXISD::LoadV2;
    LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
    break;
  case 4: {
    Opcode = NVPTXISD::LoadV4;
    EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
    LdResVTs = DAG.getVTList(ListVTs, 5);
    break;
  }
  }

  SmallVector<SDValue, 8> OtherOps;

  // Copy regular operands
  for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
    OtherOps.push_back(N->getOperand(i));

  LoadSDNode *LD = cast<LoadSDNode>(N);

  // The select routine does not have access to the LoadSDNode instance, so
  // pass along the extension information
  OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType()));

  SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, &OtherOps[0],
                                          OtherOps.size(), LD->getMemoryVT(),
                                          LD->getMemOperand());

  SmallVector<SDValue, 4> ScalarRes;

  for (unsigned i = 0; i < NumElts; ++i) {
    SDValue Res = NewLD.getValue(i);
    if (NeedTrunc)
      Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
    ScalarRes.push_back(Res);
  }

  SDValue LoadChain = NewLD.getValue(NumElts);

  SDValue BuildVec =
      DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);

  Results.push_back(BuildVec);
  Results.push_back(LoadChain);
}
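
// For instance (a sketch of the transformation above): a DAG load of
// <4 x float> becomes one NVPTXISD::LoadV4 node producing
// {f32, f32, f32, f32, ch}; the scalar results are re-packed with
// BUILD_VECTOR, and isel can then select a single ld.v4.f32.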

static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
                                     SmallVectorImpl<SDValue> &Results) {
  SDValue Chain = N->getOperand(0);
  SDValue Intrin = N->getOperand(1);
  SDLoc DL(N);

  // Get the intrinsic ID
  unsigned IntrinNo = cast<ConstantSDNode>(Intrin.getNode())->getZExtValue();
  switch (IntrinNo) {
  default:
    return;
  case Intrinsic::nvvm_ldg_global_i:
  case Intrinsic::nvvm_ldg_global_f:
  case Intrinsic::nvvm_ldg_global_p:
  case Intrinsic::nvvm_ldu_global_i:
  case Intrinsic::nvvm_ldu_global_f:
  case Intrinsic::nvvm_ldu_global_p: {
    EVT ResVT = N->getValueType(0);

    if (ResVT.isVector()) {
      // Vector LDG/LDU

      unsigned NumElts = ResVT.getVectorNumElements();
      EVT EltVT = ResVT.getVectorElementType();

      // Since LDU/LDG are target nodes, we cannot rely on DAG type
      // legalization.  Therefore, we must ensure the type is legal.  For i1
      // and i8, we set the loaded type to i16 and propagate the "real" type
      // as the memory type.
      bool NeedTrunc = false;
      if (EltVT.getSizeInBits() < 16) {
        EltVT = MVT::i16;
        NeedTrunc = true;
      }

      unsigned Opcode = 0;
      SDVTList LdResVTs;

      switch (NumElts) {
      default:
        return;
      case 2:
        switch (IntrinNo) {
        default:
          return;
        case Intrinsic::nvvm_ldg_global_i:
        case Intrinsic::nvvm_ldg_global_f:
        case Intrinsic::nvvm_ldg_global_p:
          Opcode = NVPTXISD::LDGV2;
          break;
        case Intrinsic::nvvm_ldu_global_i:
        case Intrinsic::nvvm_ldu_global_f:
        case Intrinsic::nvvm_ldu_global_p:
          Opcode = NVPTXISD::LDUV2;
          break;
        }
        LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
        break;
      case 4: {
        switch (IntrinNo) {
        default:
          return;
        case Intrinsic::nvvm_ldg_global_i:
        case Intrinsic::nvvm_ldg_global_f:
        case Intrinsic::nvvm_ldg_global_p:
          Opcode = NVPTXISD::LDGV4;
          break;
        case Intrinsic::nvvm_ldu_global_i:
        case Intrinsic::nvvm_ldu_global_f:
        case Intrinsic::nvvm_ldu_global_p:
          Opcode = NVPTXISD::LDUV4;
          break;
        }
        EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
        LdResVTs = DAG.getVTList(ListVTs, 5);
        break;
      }
      }

      SmallVector<SDValue, 8> OtherOps;

      // Copy regular operands

      OtherOps.push_back(Chain); // Chain
                                 // Skip operand 1 (intrinsic ID)
      // Others
      for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i)
        OtherOps.push_back(N->getOperand(i));

      MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);

      SDValue NewLD = DAG.getMemIntrinsicNode(
          Opcode, DL, LdResVTs, &OtherOps[0], OtherOps.size(),
          MemSD->getMemoryVT(), MemSD->getMemOperand());

      SmallVector<SDValue, 4> ScalarRes;

      for (unsigned i = 0; i < NumElts; ++i) {
        SDValue Res = NewLD.getValue(i);
        if (NeedTrunc)
          Res =
              DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
        ScalarRes.push_back(Res);
      }

      SDValue LoadChain = NewLD.getValue(NumElts);

      SDValue BuildVec =
          DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);

      Results.push_back(BuildVec);
      Results.push_back(LoadChain);
    } else {
      // i8 LDG/LDU
      assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
             "Custom handling of non-i8 ldu/ldg?");

      // Just copy all operands as-is
      SmallVector<SDValue, 4> Ops;
      for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
        Ops.push_back(N->getOperand(i));

      // Force output to i16
      SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);

      MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);

      // We make sure the memory type is i8, which will be used during isel
      // to select the proper instruction.
      SDValue NewLD =
          DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, &Ops[0],
                                  Ops.size(), MVT::i8, MemSD->getMemOperand());

      Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
                                    NewLD.getValue(0)));
      Results.push_back(NewLD.getValue(1));
    }
  }
  }
}

void NVPTXTargetLowering::ReplaceNodeResults(
    SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
  switch (N->getOpcode()) {
  default:
    report_fatal_error("Unhandled custom legalization");
  case ISD::LOAD:
    ReplaceLoadVector(N, DAG, Results);
    return;
  case ISD::INTRINSIC_W_CHAIN:
    ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
    return;
  }
}

// Pin NVPTXSection's and NVPTXTargetObjectFile's vtables to this file.
void NVPTXSection::anchor() {}

NVPTXTargetObjectFile::~NVPTXTargetObjectFile() {
  delete TextSection;
  delete DataSection;
  delete BSSSection;
  delete ReadOnlySection;

  delete StaticCtorSection;
  delete StaticDtorSection;
  delete LSDASection;
  delete EHFrameSection;
  delete DwarfAbbrevSection;
  delete DwarfInfoSection;
  delete DwarfLineSection;
  delete DwarfFrameSection;
  delete DwarfPubTypesSection;
  delete DwarfDebugInlineSection;
  delete DwarfStrSection;
  delete DwarfLocSection;
  delete DwarfARangesSection;
  delete DwarfRangesSection;
  delete DwarfMacroInfoSection;
}
