NVPTXISelLowering.cpp revision 354362524a72b3fa43a6c09380b7ae3b2380cbba
//===-- NVPTXISelLowering.cpp - NVPTX DAG Lowering Implementation ---------==//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file defines the interfaces that NVPTX uses to lower LLVM code into a
// selection DAG.
//
//===----------------------------------------------------------------------===//

#include "NVPTXISelLowering.h"
#include "NVPTX.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXTargetObjectFile.h"
#include "NVPTXUtilities.h"
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetLoweringObjectFileImpl.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/MC/MCSectionELF.h"
#include "llvm/Support/CallSite.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <sstream>

#undef DEBUG_TYPE
#define DEBUG_TYPE "nvptx-lower"

using namespace llvm;

static unsigned int uniqueCallSite = 0;

static cl::opt<bool> sched4reg(
    "nvptx-sched4reg",
    cl::desc("NVPTX Specific: schedule for register pressure"), cl::init(false));

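// IsPTXVectorType - Return true if the given MVT is one of the vector shapes
// (v2/v4 of the basic integer and floating-point types) that PTX vector loads
// and stores can handle directly.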
static bool IsPTXVectorType(MVT VT) {
  switch (VT.SimpleTy) {
  default:
    return false;
  case MVT::v2i1:
  case MVT::v4i1:
  case MVT::v2i8:
  case MVT::v4i8:
  case MVT::v2i16:
  case MVT::v4i16:
  case MVT::v2i32:
  case MVT::v4i32:
  case MVT::v2i64:
  case MVT::v2f32:
  case MVT::v4f32:
  case MVT::v2f64:
    return true;
  }
}

/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
/// EVTs that compose it.  Unlike ComputeValueVTs, this will break apart vectors
/// into their primitive components.
/// NOTE: This is a band-aid for code that expects ComputeValueVTs to return the
/// same number of types as the Ins/Outs arrays in LowerFormalArguments,
/// LowerCall, and LowerReturn.
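/// For example, a <2 x float> value yields two f32 EVTs at offsets 0 and 4,
/// whereas ComputeValueVTs would return a single v2f32.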
static void ComputePTXValueVTs(const TargetLowering &TLI, Type *Ty,
                               SmallVectorImpl<EVT> &ValueVTs,
                               SmallVectorImpl<uint64_t> *Offsets = 0,
                               uint64_t StartingOffset = 0) {
  SmallVector<EVT, 16> TempVTs;
  SmallVector<uint64_t, 16> TempOffsets;

  ComputeValueVTs(TLI, Ty, TempVTs, &TempOffsets, StartingOffset);
  for (unsigned i = 0, e = TempVTs.size(); i != e; ++i) {
    EVT VT = TempVTs[i];
    uint64_t Off = TempOffsets[i];
    if (VT.isVector())
      for (unsigned j = 0, je = VT.getVectorNumElements(); j != je; ++j) {
        ValueVTs.push_back(VT.getVectorElementType());
        if (Offsets)
          Offsets->push_back(Off + j * VT.getVectorElementType().getStoreSize());
      }
    else {
      ValueVTs.push_back(VT);
      if (Offsets)
        Offsets->push_back(Off);
    }
  }
}

// NVPTXTargetLowering Constructor.
NVPTXTargetLowering::NVPTXTargetLowering(NVPTXTargetMachine &TM)
    : TargetLowering(TM, new NVPTXTargetObjectFile()), nvTM(&TM),
      nvptxSubtarget(TM.getSubtarget<NVPTXSubtarget>()) {

  // Always lower memset, memcpy, and memmove intrinsics to load/store
  // instructions, rather than generating calls to memset, memcpy, or memmove.
  MaxStoresPerMemset = (unsigned) 0xFFFFFFFF;
  MaxStoresPerMemcpy = (unsigned) 0xFFFFFFFF;
  MaxStoresPerMemmove = (unsigned) 0xFFFFFFFF;

  setBooleanContents(ZeroOrNegativeOneBooleanContent);

  // Jump is Expensive. Don't create extra control flow for 'and', 'or'
  // condition branches.
  setJumpIsExpensive(true);

  // By default, use the Source scheduling
  if (sched4reg)
    setSchedulingPreference(Sched::RegPressure);
  else
    setSchedulingPreference(Sched::Source);

  addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
  addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
  addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
  addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
  addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
  addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);

  // Operations not directly supported by NVPTX.
  setOperationAction(ISD::SELECT_CC, MVT::Other, Expand);
  setOperationAction(ISD::BR_CC, MVT::f32, Expand);
  setOperationAction(ISD::BR_CC, MVT::f64, Expand);
  setOperationAction(ISD::BR_CC, MVT::i1, Expand);
  setOperationAction(ISD::BR_CC, MVT::i8, Expand);
  setOperationAction(ISD::BR_CC, MVT::i16, Expand);
  setOperationAction(ISD::BR_CC, MVT::i32, Expand);
  setOperationAction(ISD::BR_CC, MVT::i64, Expand);
  // Some SIGN_EXTEND_INREG can be done using cvt instruction.
  // For others we will expand to a SHL/SRA pair.
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i64, Legal);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i32, Legal);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i16, Legal);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i8, Legal);
  setOperationAction(ISD::SIGN_EXTEND_INREG, MVT::i1, Expand);

  if (nvptxSubtarget.hasROT64()) {
    setOperationAction(ISD::ROTL, MVT::i64, Legal);
    setOperationAction(ISD::ROTR, MVT::i64, Legal);
  } else {
    setOperationAction(ISD::ROTL, MVT::i64, Expand);
    setOperationAction(ISD::ROTR, MVT::i64, Expand);
  }
  if (nvptxSubtarget.hasROT32()) {
    setOperationAction(ISD::ROTL, MVT::i32, Legal);
    setOperationAction(ISD::ROTR, MVT::i32, Legal);
  } else {
    setOperationAction(ISD::ROTL, MVT::i32, Expand);
    setOperationAction(ISD::ROTR, MVT::i32, Expand);
  }

  setOperationAction(ISD::ROTL, MVT::i16, Expand);
  setOperationAction(ISD::ROTR, MVT::i16, Expand);
  setOperationAction(ISD::ROTL, MVT::i8, Expand);
  setOperationAction(ISD::ROTR, MVT::i8, Expand);
  setOperationAction(ISD::BSWAP, MVT::i16, Expand);
  setOperationAction(ISD::BSWAP, MVT::i32, Expand);
  setOperationAction(ISD::BSWAP, MVT::i64, Expand);

  // Indirect branch is not supported.
  // This also disables Jump Table creation.
  setOperationAction(ISD::BR_JT, MVT::Other, Expand);
  setOperationAction(ISD::BRIND, MVT::Other, Expand);

  setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
  setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);

  // We want to legalize constant-related memmove and memcpy
  // intrinsics.
  setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);

  // Turn FP extload into load/fextend
  setLoadExtAction(ISD::EXTLOAD, MVT::f32, Expand);
  // Turn FP truncstore into trunc + store.
  setTruncStoreAction(MVT::f64, MVT::f32, Expand);

  // PTX does not support load / store predicate registers
  setOperationAction(ISD::LOAD, MVT::i1, Custom);
  setOperationAction(ISD::STORE, MVT::i1, Custom);

  setLoadExtAction(ISD::SEXTLOAD, MVT::i1, Promote);
  setLoadExtAction(ISD::ZEXTLOAD, MVT::i1, Promote);
  setTruncStoreAction(MVT::i64, MVT::i1, Expand);
  setTruncStoreAction(MVT::i32, MVT::i1, Expand);
  setTruncStoreAction(MVT::i16, MVT::i1, Expand);
  setTruncStoreAction(MVT::i8, MVT::i1, Expand);

  // This is legal in NVPTX
  setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
  setOperationAction(ISD::ConstantFP, MVT::f32, Legal);

  // TRAP can be lowered to PTX trap
  setOperationAction(ISD::TRAP, MVT::Other, Legal);

  setOperationAction(ISD::ADDC, MVT::i64, Expand);
  setOperationAction(ISD::ADDE, MVT::i64, Expand);

  // Register custom handling for vector loads/stores
  for (int i = MVT::FIRST_VECTOR_VALUETYPE; i <= MVT::LAST_VECTOR_VALUETYPE;
       ++i) {
    MVT VT = (MVT::SimpleValueType) i;
    if (IsPTXVectorType(VT)) {
      setOperationAction(ISD::LOAD, VT, Custom);
      setOperationAction(ISD::STORE, VT, Custom);
      setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
    }
  }

  // Custom handling for i8 intrinsics
  setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);

  setOperationAction(ISD::CTLZ, MVT::i16, Legal);
  setOperationAction(ISD::CTLZ, MVT::i32, Legal);
  setOperationAction(ISD::CTLZ, MVT::i64, Legal);
  setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i16, Legal);
  setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i32, Legal);
  setOperationAction(ISD::CTLZ_ZERO_UNDEF, MVT::i64, Legal);
  setOperationAction(ISD::CTTZ, MVT::i16, Expand);
  setOperationAction(ISD::CTTZ, MVT::i32, Expand);
  setOperationAction(ISD::CTTZ, MVT::i64, Expand);
  setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i16, Expand);
  setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i32, Expand);
  setOperationAction(ISD::CTTZ_ZERO_UNDEF, MVT::i64, Expand);
  setOperationAction(ISD::CTPOP, MVT::i16, Legal);
  setOperationAction(ISD::CTPOP, MVT::i32, Legal);
  setOperationAction(ISD::CTPOP, MVT::i64, Legal);

  // Now deduce the information based on the preceding actions.
  computeRegisterProperties();
}

const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
  switch (Opcode) {
  default:
    return 0;
  case NVPTXISD::CALL:
    return "NVPTXISD::CALL";
  case NVPTXISD::RET_FLAG:
    return "NVPTXISD::RET_FLAG";
  case NVPTXISD::Wrapper:
    return "NVPTXISD::Wrapper";
  case NVPTXISD::DeclareParam:
    return "NVPTXISD::DeclareParam";
  case NVPTXISD::DeclareScalarParam:
    return "NVPTXISD::DeclareScalarParam";
  case NVPTXISD::DeclareRet:
    return "NVPTXISD::DeclareRet";
  case NVPTXISD::DeclareRetParam:
    return "NVPTXISD::DeclareRetParam";
  case NVPTXISD::PrintCall:
    return "NVPTXISD::PrintCall";
  case NVPTXISD::LoadParam:
    return "NVPTXISD::LoadParam";
  case NVPTXISD::LoadParamV2:
    return "NVPTXISD::LoadParamV2";
  case NVPTXISD::LoadParamV4:
    return "NVPTXISD::LoadParamV4";
  case NVPTXISD::StoreParam:
    return "NVPTXISD::StoreParam";
  case NVPTXISD::StoreParamV2:
    return "NVPTXISD::StoreParamV2";
  case NVPTXISD::StoreParamV4:
    return "NVPTXISD::StoreParamV4";
  case NVPTXISD::StoreParamS32:
    return "NVPTXISD::StoreParamS32";
  case NVPTXISD::StoreParamU32:
    return "NVPTXISD::StoreParamU32";
  case NVPTXISD::CallArgBegin:
    return "NVPTXISD::CallArgBegin";
  case NVPTXISD::CallArg:
    return "NVPTXISD::CallArg";
  case NVPTXISD::LastCallArg:
    return "NVPTXISD::LastCallArg";
  case NVPTXISD::CallArgEnd:
    return "NVPTXISD::CallArgEnd";
  case NVPTXISD::CallVoid:
    return "NVPTXISD::CallVoid";
  case NVPTXISD::CallVal:
    return "NVPTXISD::CallVal";
  case NVPTXISD::CallSymbol:
    return "NVPTXISD::CallSymbol";
  case NVPTXISD::Prototype:
    return "NVPTXISD::Prototype";
  case NVPTXISD::MoveParam:
    return "NVPTXISD::MoveParam";
  case NVPTXISD::StoreRetval:
    return "NVPTXISD::StoreRetval";
  case NVPTXISD::StoreRetvalV2:
    return "NVPTXISD::StoreRetvalV2";
  case NVPTXISD::StoreRetvalV4:
    return "NVPTXISD::StoreRetvalV4";
  case NVPTXISD::PseudoUseParam:
    return "NVPTXISD::PseudoUseParam";
  case NVPTXISD::RETURN:
    return "NVPTXISD::RETURN";
  case NVPTXISD::CallSeqBegin:
    return "NVPTXISD::CallSeqBegin";
  case NVPTXISD::CallSeqEnd:
    return "NVPTXISD::CallSeqEnd";
  case NVPTXISD::CallPrototype:
    return "NVPTXISD::CallPrototype";
  case NVPTXISD::LoadV2:
    return "NVPTXISD::LoadV2";
  case NVPTXISD::LoadV4:
    return "NVPTXISD::LoadV4";
  case NVPTXISD::LDGV2:
    return "NVPTXISD::LDGV2";
  case NVPTXISD::LDGV4:
    return "NVPTXISD::LDGV4";
  case NVPTXISD::LDUV2:
    return "NVPTXISD::LDUV2";
  case NVPTXISD::LDUV4:
    return "NVPTXISD::LDUV4";
  case NVPTXISD::StoreV2:
    return "NVPTXISD::StoreV2";
  case NVPTXISD::StoreV4:
    return "NVPTXISD::StoreV4";
  }
}

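// Vector element types of i1 must be split apart: PTX has no loads or stores
// of predicate registers, so predicate vectors cannot be kept whole.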
bool NVPTXTargetLowering::shouldSplitVectorElementType(EVT VT) const {
  return VT == MVT::i1;
}

SDValue
NVPTXTargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
  SDLoc dl(Op);
  const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
  Op = DAG.getTargetGlobalAddress(GV, dl, getPointerTy());
  return DAG.getNode(NVPTXISD::Wrapper, dl, getPointerTy(), Op);
}

std::string
NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
                                  const SmallVectorImpl<ISD::OutputArg> &Outs,
                                  unsigned retAlignment,
                                  const ImmutableCallSite *CS) const {

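  // Builds the .callprototype declaration that PTX requires for indirect
  // calls, e.g.
  //   prototype_0 : .callprototype (.param .b32 _) _ (.param .b32 _);
  // (the operand list above is illustrative; the emitted string depends on
  // the actual return type and argument list handled below).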
  bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
  assert(isABI && "Non-ABI compilation is not supported");
  if (!isABI)
    return "";

  std::stringstream O;
  O << "prototype_" << uniqueCallSite << " : .callprototype ";

  if (retTy->getTypeID() == Type::VoidTyID) {
    O << "()";
  } else {
    O << "(";
    if (retTy->isPrimitiveType() || retTy->isIntegerTy()) {
      unsigned size = 0;
      if (const IntegerType *ITy = dyn_cast<IntegerType>(retTy)) {
        size = ITy->getBitWidth();
        if (size < 32)
          size = 32;
      } else {
        assert(retTy->isFloatingPointTy() &&
               "Floating point type expected here");
        size = retTy->getPrimitiveSizeInBits();
      }

      O << ".param .b" << size << " _";
    } else if (isa<PointerType>(retTy)) {
      O << ".param .b" << getPointerTy().getSizeInBits() << " _";
    } else {
      if ((retTy->getTypeID() == Type::StructTyID) || isa<VectorType>(retTy)) {
        SmallVector<EVT, 16> vtparts;
        ComputeValueVTs(*this, retTy, vtparts);
        unsigned totalsz = 0;
        for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
          unsigned elems = 1;
          EVT elemtype = vtparts[i];
          if (vtparts[i].isVector()) {
            elems = vtparts[i].getVectorNumElements();
            elemtype = vtparts[i].getVectorElementType();
          }
          // TODO: no need to loop
          for (unsigned j = 0, je = elems; j != je; ++j) {
            unsigned sz = elemtype.getSizeInBits();
            if (elemtype.isInteger() && (sz < 8))
              sz = 8;
            totalsz += sz / 8;
          }
        }
        O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
      } else {
        assert(false && "Unknown return type");
      }
    }
    O << ") ";
  }
  O << "_ (";

  bool first = true;
  MVT thePointerTy = getPointerTy();

  unsigned OIdx = 0;
  for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
    Type *Ty = Args[i].Ty;
    if (!first) {
      O << ", ";
    }
    first = false;

    if (!Outs[OIdx].Flags.isByVal()) {
      if (Ty->isAggregateType() || Ty->isVectorTy()) {
        unsigned align = 0;
        const CallInst *CallI = cast<CallInst>(CS->getInstruction());
        const DataLayout *TD = getDataLayout();
        // +1 because index 0 is reserved for return type alignment
        if (!llvm::getAlign(*CallI, i + 1, align))
          align = TD->getABITypeAlignment(Ty);
        unsigned sz = TD->getTypeAllocSize(Ty);
        O << ".param .align " << align << " .b8 ";
        O << "_";
        O << "[" << sz << "]";
        // update the index for Outs
        SmallVector<EVT, 16> vtparts;
        ComputeValueVTs(*this, Ty, vtparts);
        if (unsigned len = vtparts.size())
          OIdx += len - 1;
        continue;
      }
      // i8 types in IR will be i16 types in SDAG
      assert((getValueType(Ty) == Outs[OIdx].VT ||
              (getValueType(Ty) == MVT::i8 && Outs[OIdx].VT == MVT::i16)) &&
             "type mismatch between callee prototype and arguments");
      // scalar type
      unsigned sz = 0;
      if (isa<IntegerType>(Ty)) {
        sz = cast<IntegerType>(Ty)->getBitWidth();
        if (sz < 32)
          sz = 32;
      } else if (isa<PointerType>(Ty))
        sz = thePointerTy.getSizeInBits();
      else
        sz = Ty->getPrimitiveSizeInBits();
      O << ".param .b" << sz << " ";
      O << "_";
      continue;
    }
    const PointerType *PTy = dyn_cast<PointerType>(Ty);
    assert(PTy && "Param with byval attribute should be a pointer type");
    Type *ETy = PTy->getElementType();

    unsigned align = Outs[OIdx].Flags.getByValAlign();
    unsigned sz = getDataLayout()->getTypeAllocSize(ETy);
    O << ".param .align " << align << " .b8 ";
    O << "_";
    O << "[" << sz << "]";
  }
  O << ");";
  return O.str();
}

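// getArgumentAlignment - Return the alignment to use for a call argument or
// return value of type Ty at position Idx (Idx 0 is the return value).
// Prefers explicit align metadata on the call or the callee; otherwise falls
// back to the ABI type alignment.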
unsigned
NVPTXTargetLowering::getArgumentAlignment(SDValue Callee,
                                          const ImmutableCallSite *CS,
                                          Type *Ty,
                                          unsigned Idx) const {
  const DataLayout *TD = getDataLayout();
  unsigned Align = 0;
  const Value *DirectCallee = CS->getCalledFunction();

  if (!DirectCallee) {
    // We don't have a direct function symbol, but that may be because of
    // constant cast instructions in the call.
    const Instruction *CalleeI = CS->getInstruction();
    assert(CalleeI && "Call target is not a function or derived value?");

    // With bitcast'd call targets, the instruction will be the call
    if (isa<CallInst>(CalleeI)) {
      // Check if we have call alignment metadata
      if (llvm::getAlign(*cast<CallInst>(CalleeI), Idx, Align))
        return Align;

      const Value *CalleeV = cast<CallInst>(CalleeI)->getCalledValue();
      // Ignore any bitcast instructions
      while (isa<ConstantExpr>(CalleeV)) {
        const ConstantExpr *CE = cast<ConstantExpr>(CalleeV);
        if (!CE->isCast())
          break;
        // Look through the bitcast
        CalleeV = cast<ConstantExpr>(CalleeV)->getOperand(0);
      }

      // We have now looked past all of the bitcasts.  Do we finally have a
      // Function?
      if (isa<Function>(CalleeV))
        DirectCallee = CalleeV;
    }
  }

  // Check for function alignment information if we found that the
  // ultimate target is a Function
  if (DirectCallee)
    if (llvm::getAlign(*cast<Function>(DirectCallee), Idx, Align))
      return Align;

  // Call is indirect or alignment information is not available, fall back to
  // the ABI type alignment
  return TD->getABITypeAlignment(Ty);
}

SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
                                       SmallVectorImpl<SDValue> &InVals) const {
  SelectionDAG &DAG = CLI.DAG;
  SDLoc dl = CLI.DL;
  SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs;
  SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
  SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
  SDValue Chain = CLI.Chain;
  SDValue Callee = CLI.Callee;
  bool &isTailCall = CLI.IsTailCall;
  ArgListTy &Args = CLI.Args;
  Type *retTy = CLI.RetTy;
  ImmutableCallSite *CS = CLI.CS;

  bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
  assert(isABI && "Non-ABI compilation is not supported");
  if (!isABI)
    return Chain;
  const DataLayout *TD = getDataLayout();
  MachineFunction &MF = DAG.getMachineFunction();
  const Function *F = MF.getFunction();

  SDValue tempChain = Chain;
  Chain =
      DAG.getCALLSEQ_START(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
                           dl);
  SDValue InFlag = Chain.getValue(1);
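
  // Every node emitted for the call sequence below is glued to its
  // predecessor through MVT::Glue (threaded via InFlag), so that instruction
  // selection keeps the whole PTX call sequence (param declarations, param
  // stores, and the call itself) contiguous.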

  unsigned paramCount = 0;
  // Args.size() and Outs.size() need not match.
  // Outs.size() will be larger
  //   * if there is an aggregate argument with multiple fields (each field
  //     showing up separately in Outs)
  //   * if there is a vector argument with more than typical vector-length
  //     elements (generally if more than 4) where each vector element is
  //     individually present in Outs.
  // So a different index should be used for indexing into Outs/OutVals.
  // See similar issue in LowerFormalArguments.
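  // For example, an aggregate argument {i32, i32} is a single Args entry but
  // contributes two entries to Outs/OutVals.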
  unsigned OIdx = 0;
  // Declare the .param or .reg spaces needed to pass values to the function.
  for (unsigned i = 0, e = Args.size(); i != e; ++i, ++OIdx) {
    EVT VT = Outs[OIdx].VT;
    Type *Ty = Args[i].Ty;

    if (!Outs[OIdx].Flags.isByVal()) {
      if (Ty->isAggregateType()) {
        // aggregate
        SmallVector<EVT, 16> vtparts;
        ComputeValueVTs(*this, Ty, vtparts);

        unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
        // declare .param .align <align> .b8 .param<n>[<size>];
        unsigned sz = TD->getTypeAllocSize(Ty);
        SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
        SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32),
                                      DAG.getConstant(paramCount, MVT::i32),
                                      DAG.getConstant(sz, MVT::i32), InFlag };
        Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
                            DeclareParamOps, 5);
        InFlag = Chain.getValue(1);
        unsigned curOffset = 0;
        for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
          unsigned elems = 1;
          EVT elemtype = vtparts[j];
          if (vtparts[j].isVector()) {
            elems = vtparts[j].getVectorNumElements();
            elemtype = vtparts[j].getVectorElementType();
          }
          for (unsigned k = 0, ke = elems; k != ke; ++k) {
            unsigned sz = elemtype.getSizeInBits();
            if (elemtype.isInteger() && (sz < 8))
              sz = 8;
            SDValue StVal = OutVals[OIdx];
            if (elemtype.getSizeInBits() < 16) {
              StVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, StVal);
            }
            SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
            SDValue CopyParamOps[] = { Chain,
                                       DAG.getConstant(paramCount, MVT::i32),
                                       DAG.getConstant(curOffset, MVT::i32),
                                       StVal, InFlag };
            Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
                                            CopyParamVTs, &CopyParamOps[0], 5,
                                            elemtype, MachinePointerInfo());
            InFlag = Chain.getValue(1);
            curOffset += sz / 8;
            ++OIdx;
          }
        }
        if (vtparts.size() > 0)
          --OIdx;
        ++paramCount;
        continue;
      }
      if (Ty->isVectorTy()) {
        EVT ObjectVT = getValueType(Ty);
        unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
        // declare .param .align <align> .b8 .param<n>[<size>];
        unsigned sz = TD->getTypeAllocSize(Ty);
        SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
        SDValue DeclareParamOps[] = { Chain, DAG.getConstant(align, MVT::i32),
                                      DAG.getConstant(paramCount, MVT::i32),
                                      DAG.getConstant(sz, MVT::i32), InFlag };
        Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
                            DeclareParamOps, 5);
        InFlag = Chain.getValue(1);
        unsigned NumElts = ObjectVT.getVectorNumElements();
        EVT EltVT = ObjectVT.getVectorElementType();
        EVT MemVT = EltVT;
        bool NeedExtend = false;
        if (EltVT.getSizeInBits() < 16) {
          NeedExtend = true;
          EltVT = MVT::i16;
        }

        // V1 store
        if (NumElts == 1) {
          SDValue Elt = OutVals[OIdx++];
          if (NeedExtend)
            Elt = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt);

          SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
          SDValue CopyParamOps[] = { Chain,
                                     DAG.getConstant(paramCount, MVT::i32),
                                     DAG.getConstant(0, MVT::i32), Elt,
                                     InFlag };
          Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
                                          CopyParamVTs, &CopyParamOps[0], 5,
                                          MemVT, MachinePointerInfo());
          InFlag = Chain.getValue(1);
        } else if (NumElts == 2) {
          SDValue Elt0 = OutVals[OIdx++];
          SDValue Elt1 = OutVals[OIdx++];
          if (NeedExtend) {
            Elt0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt0);
            Elt1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Elt1);
          }

          SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
          SDValue CopyParamOps[] = { Chain,
                                     DAG.getConstant(paramCount, MVT::i32),
                                     DAG.getConstant(0, MVT::i32), Elt0, Elt1,
                                     InFlag };
          Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParamV2, dl,
                                          CopyParamVTs, &CopyParamOps[0], 6,
                                          MemVT, MachinePointerInfo());
          InFlag = Chain.getValue(1);
        } else {
          unsigned curOffset = 0;
          // V4 stores
          // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
          // the vector will be expanded to a power of 2 elements, so we know
          // we can always round up to the next multiple of 4 when creating
          // the vector stores.
          // e.g.  4 elem => 1 st.v4
          //       6 elem => 2 st.v4
          //       8 elem => 2 st.v4
          //      11 elem => 3 st.v4
          unsigned VecSize = 4;
          if (EltVT.getSizeInBits() == 64)
            VecSize = 2;

          // This is potentially only part of a vector, so assume all elements
          // are packed together.
          unsigned PerStoreOffset = MemVT.getStoreSizeInBits() / 8 * VecSize;

          for (unsigned i = 0; i < NumElts; i += VecSize) {
            // Get values
            SDValue StoreVal;
            SmallVector<SDValue, 8> Ops;
            Ops.push_back(Chain);
            Ops.push_back(DAG.getConstant(paramCount, MVT::i32));
            Ops.push_back(DAG.getConstant(curOffset, MVT::i32));

            unsigned Opc = NVPTXISD::StoreParamV2;

            StoreVal = OutVals[OIdx++];
            if (NeedExtend)
              StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
            Ops.push_back(StoreVal);

            if (i + 1 < NumElts) {
              StoreVal = OutVals[OIdx++];
              if (NeedExtend)
                StoreVal =
                    DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
            } else {
              StoreVal = DAG.getUNDEF(EltVT);
            }
            Ops.push_back(StoreVal);

            if (VecSize == 4) {
              Opc = NVPTXISD::StoreParamV4;
              if (i + 2 < NumElts) {
                StoreVal = OutVals[OIdx++];
                if (NeedExtend)
                  StoreVal =
                      DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
              } else {
                StoreVal = DAG.getUNDEF(EltVT);
              }
              Ops.push_back(StoreVal);

              if (i + 3 < NumElts) {
                StoreVal = OutVals[OIdx++];
                if (NeedExtend)
                  StoreVal =
                      DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
              } else {
                StoreVal = DAG.getUNDEF(EltVT);
              }
              Ops.push_back(StoreVal);
            }

            Ops.push_back(InFlag);

            SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
            Chain = DAG.getMemIntrinsicNode(Opc, dl, CopyParamVTs, &Ops[0],
                                            Ops.size(), MemVT,
                                            MachinePointerInfo());
            InFlag = Chain.getValue(1);
            curOffset += PerStoreOffset;
          }
        }
        ++paramCount;
        --OIdx;
        continue;
      }
      // Plain scalar
      // for ABI, declare .param .b<size> .param<n>;
      unsigned sz = VT.getSizeInBits();
      bool needExtend = false;
      if (VT.isInteger()) {
        if (sz < 16)
          needExtend = true;
        if (sz < 32)
          sz = 32;
      }
      SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
      SDValue DeclareParamOps[] = { Chain,
                                    DAG.getConstant(paramCount, MVT::i32),
                                    DAG.getConstant(sz, MVT::i32),
                                    DAG.getConstant(0, MVT::i32), InFlag };
      Chain = DAG.getNode(NVPTXISD::DeclareScalarParam, dl, DeclareParamVTs,
                          DeclareParamOps, 5);
      InFlag = Chain.getValue(1);
      SDValue OutV = OutVals[OIdx];
      if (needExtend) {
        // zext/sext i1 to i16
        unsigned opc = ISD::ZERO_EXTEND;
        if (Outs[OIdx].Flags.isSExt())
          opc = ISD::SIGN_EXTEND;
        OutV = DAG.getNode(opc, dl, MVT::i16, OutV);
      }
      SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
      SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
                                 DAG.getConstant(0, MVT::i32), OutV, InFlag };

      unsigned opcode = NVPTXISD::StoreParam;
      if (Outs[OIdx].Flags.isZExt())
        opcode = NVPTXISD::StoreParamU32;
      else if (Outs[OIdx].Flags.isSExt())
        opcode = NVPTXISD::StoreParamS32;
      Chain = DAG.getMemIntrinsicNode(opcode, dl, CopyParamVTs, CopyParamOps, 5,
                                      VT, MachinePointerInfo());

      InFlag = Chain.getValue(1);
      ++paramCount;
      continue;
    }
    // struct or vector
    SmallVector<EVT, 16> vtparts;
    const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
    assert(PTy && "Type of a byval parameter should be pointer");
    ComputeValueVTs(*this, PTy->getElementType(), vtparts);

    // declare .param .align <align> .b8 .param<n>[<size>];
    unsigned sz = Outs[OIdx].Flags.getByValSize();
    SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    // The ByValAlign in the Outs[OIdx].Flags is always set at this point,
    // so we don't need to worry about natural alignment or not.
    // See TargetLowering::LowerCallTo().
    SDValue DeclareParamOps[] = {
      Chain, DAG.getConstant(Outs[OIdx].Flags.getByValAlign(), MVT::i32),
      DAG.getConstant(paramCount, MVT::i32), DAG.getConstant(sz, MVT::i32),
      InFlag
    };
    Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
                        DeclareParamOps, 5);
    InFlag = Chain.getValue(1);
    unsigned curOffset = 0;
    for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
      unsigned elems = 1;
      EVT elemtype = vtparts[j];
      if (vtparts[j].isVector()) {
        elems = vtparts[j].getVectorNumElements();
        elemtype = vtparts[j].getVectorElementType();
      }
      for (unsigned k = 0, ke = elems; k != ke; ++k) {
        unsigned sz = elemtype.getSizeInBits();
        if (elemtype.isInteger() && (sz < 8))
          sz = 8;
        SDValue srcAddr =
            DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
                        DAG.getConstant(curOffset, getPointerTy()));
        SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
                                     MachinePointerInfo(), false, false, false,
                                     0);
        if (elemtype.getSizeInBits() < 16) {
          theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal);
        }
        SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
        SDValue CopyParamOps[] = { Chain, DAG.getConstant(paramCount, MVT::i32),
                                   DAG.getConstant(curOffset, MVT::i32), theVal,
                                   InFlag };
        Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, CopyParamVTs,
                                        CopyParamOps, 5, elemtype,
                                        MachinePointerInfo());

        InFlag = Chain.getValue(1);
        curOffset += sz / 8;
      }
    }
    ++paramCount;
  }

  GlobalAddressSDNode *Func = dyn_cast<GlobalAddressSDNode>(Callee.getNode());
  unsigned retAlignment = 0;

  // Handle Result
  if (Ins.size() > 0) {
    SmallVector<EVT, 16> resvtparts;
    ComputeValueVTs(*this, retTy, resvtparts);

    // Declare
    //  .param .align 16 .b8 retval0[<size-in-bytes>], or
    //  .param .b<size-in-bits> retval0
    unsigned resultsz = TD->getTypeAllocSizeInBits(retTy);
    if (retTy->isPrimitiveType() || retTy->isIntegerTy() ||
        retTy->isPointerTy()) {
      // Scalar needs to be at least 32 bits wide
      if (resultsz < 32)
        resultsz = 32;
      SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
      SDValue DeclareRetOps[] = { Chain, DAG.getConstant(1, MVT::i32),
                                  DAG.getConstant(resultsz, MVT::i32),
                                  DAG.getConstant(0, MVT::i32), InFlag };
      Chain = DAG.getNode(NVPTXISD::DeclareRet, dl, DeclareRetVTs,
                          DeclareRetOps, 5);
      InFlag = Chain.getValue(1);
    } else {
      retAlignment = getArgumentAlignment(Callee, CS, retTy, 0);
      SDVTList DeclareRetVTs = DAG.getVTList(MVT::Other, MVT::Glue);
      SDValue DeclareRetOps[] = { Chain,
                                  DAG.getConstant(retAlignment, MVT::i32),
                                  DAG.getConstant(resultsz / 8, MVT::i32),
                                  DAG.getConstant(0, MVT::i32), InFlag };
      Chain = DAG.getNode(NVPTXISD::DeclareRetParam, dl, DeclareRetVTs,
                          DeclareRetOps, 5);
      InFlag = Chain.getValue(1);
    }
  }

  if (!Func) {
    // This is the indirect function call case: PTX requires a prototype of
    // the form
    // proto_0 : .callprototype(.param .b32 _) _ (.param .b32 _);
    // to be emitted, and the label has to be used as the last arg of the call
    // instruction.
    // The prototype is embedded in a string and used as the operand of a
    // CallPrototype SDNode, which prints out as the value of the string.
    SDVTList ProtoVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    std::string Proto = getPrototype(retTy, Args, Outs, retAlignment, CS);
    const char *ProtoStr =
      nvTM->getManagedStrPool()->getManagedString(Proto.c_str())->c_str();
    SDValue ProtoOps[] = {
      Chain, DAG.getTargetExternalSymbol(ProtoStr, MVT::i32), InFlag,
    };
    Chain = DAG.getNode(NVPTXISD::CallPrototype, dl, ProtoVTs, &ProtoOps[0], 3);
    InFlag = Chain.getValue(1);
  }
  // Op to just print "call"
  SDVTList PrintCallVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue PrintCallOps[] = {
    Chain, DAG.getConstant((Ins.size() == 0) ? 0 : 1, MVT::i32), InFlag
  };
  Chain = DAG.getNode(Func ? (NVPTXISD::PrintCallUni) : (NVPTXISD::PrintCall),
                      dl, PrintCallVTs, PrintCallOps, 3);
  InFlag = Chain.getValue(1);

  // Ops to print out the function name
  SDVTList CallVoidVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue CallVoidOps[] = { Chain, Callee, InFlag };
  Chain = DAG.getNode(NVPTXISD::CallVoid, dl, CallVoidVTs, CallVoidOps, 3);
  InFlag = Chain.getValue(1);

  // Ops to print out the param list
  SDVTList CallArgBeginVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue CallArgBeginOps[] = { Chain, InFlag };
  Chain = DAG.getNode(NVPTXISD::CallArgBegin, dl, CallArgBeginVTs,
                      CallArgBeginOps, 2);
  InFlag = Chain.getValue(1);

  for (unsigned i = 0, e = paramCount; i != e; ++i) {
    unsigned opcode;
    if (i == (e - 1))
      opcode = NVPTXISD::LastCallArg;
    else
      opcode = NVPTXISD::CallArg;
    SDVTList CallArgVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    SDValue CallArgOps[] = { Chain, DAG.getConstant(1, MVT::i32),
                             DAG.getConstant(i, MVT::i32), InFlag };
    Chain = DAG.getNode(opcode, dl, CallArgVTs, CallArgOps, 4);
    InFlag = Chain.getValue(1);
  }
  SDVTList CallArgEndVTs = DAG.getVTList(MVT::Other, MVT::Glue);
  SDValue CallArgEndOps[] = { Chain, DAG.getConstant(Func ? 1 : 0, MVT::i32),
                              InFlag };
  Chain =
      DAG.getNode(NVPTXISD::CallArgEnd, dl, CallArgEndVTs, CallArgEndOps, 3);
  InFlag = Chain.getValue(1);

  if (!Func) {
    SDVTList PrototypeVTs = DAG.getVTList(MVT::Other, MVT::Glue);
    SDValue PrototypeOps[] = { Chain, DAG.getConstant(uniqueCallSite, MVT::i32),
                               InFlag };
    Chain = DAG.getNode(NVPTXISD::Prototype, dl, PrototypeVTs, PrototypeOps, 3);
    InFlag = Chain.getValue(1);
  }

  // Generate loads from param memory/moves from registers for result
  if (Ins.size() > 0) {
    unsigned resoffset = 0;
    if (retTy && retTy->isVectorTy()) {
      EVT ObjectVT = getValueType(retTy);
      unsigned NumElts = ObjectVT.getVectorNumElements();
      EVT EltVT = ObjectVT.getVectorElementType();
      assert(nvTM->getTargetLowering()->getNumRegisters(F->getContext(),
                                                        ObjectVT) == NumElts &&
             "Vector was not scalarized");
      unsigned sz = EltVT.getSizeInBits();
      bool needTruncate = sz < 16;

      if (NumElts == 1) {
        // Just a simple load
        std::vector<EVT> LoadRetVTs;
        if (needTruncate) {
          // If loading i1 result, generate
          //   load i16
          //   trunc i16 to i1
          LoadRetVTs.push_back(MVT::i16);
        } else
          LoadRetVTs.push_back(EltVT);
        LoadRetVTs.push_back(MVT::Other);
        LoadRetVTs.push_back(MVT::Glue);
        std::vector<SDValue> LoadRetOps;
        LoadRetOps.push_back(Chain);
        LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
        LoadRetOps.push_back(DAG.getConstant(0, MVT::i32));
        LoadRetOps.push_back(InFlag);
        SDValue retval = DAG.getMemIntrinsicNode(
            NVPTXISD::LoadParam, dl,
            DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
            LoadRetOps.size(), EltVT, MachinePointerInfo());
        Chain = retval.getValue(1);
        InFlag = retval.getValue(2);
        SDValue Ret0 = retval;
        if (needTruncate)
          Ret0 = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Ret0);
        InVals.push_back(Ret0);
      } else if (NumElts == 2) {
        // LoadV2
        std::vector<EVT> LoadRetVTs;
        if (needTruncate) {
          // If loading i1 result, generate
          //   load i16
          //   trunc i16 to i1
          LoadRetVTs.push_back(MVT::i16);
          LoadRetVTs.push_back(MVT::i16);
        } else {
          LoadRetVTs.push_back(EltVT);
          LoadRetVTs.push_back(EltVT);
        }
        LoadRetVTs.push_back(MVT::Other);
        LoadRetVTs.push_back(MVT::Glue);
        std::vector<SDValue> LoadRetOps;
        LoadRetOps.push_back(Chain);
        LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
        LoadRetOps.push_back(DAG.getConstant(0, MVT::i32));
        LoadRetOps.push_back(InFlag);
        SDValue retval = DAG.getMemIntrinsicNode(
            NVPTXISD::LoadParamV2, dl,
            DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
            LoadRetOps.size(), EltVT, MachinePointerInfo());
        Chain = retval.getValue(2);
        InFlag = retval.getValue(3);
        SDValue Ret0 = retval.getValue(0);
        SDValue Ret1 = retval.getValue(1);
        if (needTruncate) {
          Ret0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret0);
          InVals.push_back(Ret0);
          Ret1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Ret1);
          InVals.push_back(Ret1);
        } else {
          InVals.push_back(Ret0);
          InVals.push_back(Ret1);
        }
      } else {
        // Split into N LoadV4
        unsigned Ofst = 0;
        unsigned VecSize = 4;
        unsigned Opc = NVPTXISD::LoadParamV4;
        if (EltVT.getSizeInBits() == 64) {
          VecSize = 2;
          Opc = NVPTXISD::LoadParamV2;
        }
        EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
        for (unsigned i = 0; i < NumElts; i += VecSize) {
          SmallVector<EVT, 8> LoadRetVTs;
          if (needTruncate) {
            // If loading i1 result, generate
            //   load i16
            //   trunc i16 to i1
            for (unsigned j = 0; j < VecSize; ++j)
              LoadRetVTs.push_back(MVT::i16);
          } else {
            for (unsigned j = 0; j < VecSize; ++j)
              LoadRetVTs.push_back(EltVT);
          }
          LoadRetVTs.push_back(MVT::Other);
          LoadRetVTs.push_back(MVT::Glue);
          SmallVector<SDValue, 4> LoadRetOps;
          LoadRetOps.push_back(Chain);
          LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
          LoadRetOps.push_back(DAG.getConstant(Ofst, MVT::i32));
          LoadRetOps.push_back(InFlag);
          SDValue retval = DAG.getMemIntrinsicNode(
              Opc, dl, DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()),
              &LoadRetOps[0], LoadRetOps.size(), EltVT, MachinePointerInfo());
          if (VecSize == 2) {
            Chain = retval.getValue(2);
            InFlag = retval.getValue(3);
          } else {
            Chain = retval.getValue(4);
            InFlag = retval.getValue(5);
          }

          for (unsigned j = 0; j < VecSize; ++j) {
            if (i + j >= NumElts)
              break;
            SDValue Elt = retval.getValue(j);
            if (needTruncate)
              Elt = DAG.getNode(ISD::TRUNCATE, dl, EltVT, Elt);
            InVals.push_back(Elt);
          }
          Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
        }
      }
    } else {
      SmallVector<EVT, 16> VTs;
      ComputePTXValueVTs(*this, retTy, VTs);
      assert(VTs.size() == Ins.size() && "Bad value decomposition");
      for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
        unsigned sz = VTs[i].getSizeInBits();
        bool needTruncate = sz < 8;
        if (VTs[i].isInteger() && (sz < 8))
          sz = 8;

        SmallVector<EVT, 4> LoadRetVTs;
        EVT TheLoadType = VTs[i];
        if (retTy->isIntegerTy() &&
            TD->getTypeAllocSizeInBits(retTy) < 32) {
          // This is for integer types only, and specifically not for
          // aggregates.
          LoadRetVTs.push_back(MVT::i32);
          TheLoadType = MVT::i32;
        } else if (sz < 16) {
          // If loading i1/i8 result, generate
          //   load i8 (-> i16)
          //   trunc i16 to i1/i8
          LoadRetVTs.push_back(MVT::i16);
        } else
          LoadRetVTs.push_back(Ins[i].VT);
        LoadRetVTs.push_back(MVT::Other);
        LoadRetVTs.push_back(MVT::Glue);

        SmallVector<SDValue, 4> LoadRetOps;
        LoadRetOps.push_back(Chain);
        LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
        LoadRetOps.push_back(DAG.getConstant(resoffset, MVT::i32));
        LoadRetOps.push_back(InFlag);
        SDValue retval = DAG.getMemIntrinsicNode(
            NVPTXISD::LoadParam, dl,
            DAG.getVTList(&LoadRetVTs[0], LoadRetVTs.size()), &LoadRetOps[0],
            LoadRetOps.size(), TheLoadType, MachinePointerInfo());
        Chain = retval.getValue(1);
        InFlag = retval.getValue(2);
        SDValue Ret0 = retval.getValue(0);
        if (needTruncate)
          Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0);
        InVals.push_back(Ret0);
        resoffset += sz / 8;
      }
    }
  }

  Chain = DAG.getCALLSEQ_END(Chain, DAG.getIntPtrConstant(uniqueCallSite, true),
                             DAG.getIntPtrConstant(uniqueCallSite + 1, true),
                             InFlag, dl);
  uniqueCallSite++;

  // set isTailCall to false for now, until we figure out how to express
  // tail call optimization in PTX
  isTailCall = false;
  return Chain;
}

// By default CONCAT_VECTORS is lowered by ExpandVectorBuildThroughStack()
// (see LegalizeDAG.cpp). This is slow and uses local memory.
// Instead, we lower it with an extract/build-vector sequence, just as
// LegalizeOp() did in LLVM 2.5.
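// For example, concatenating two <2 x float> values becomes a BUILD_VECTOR of
// the four extracted f32 elements.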
SDValue
NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
  SDNode *Node = Op.getNode();
  SDLoc dl(Node);
  SmallVector<SDValue, 8> Ops;
  unsigned NumOperands = Node->getNumOperands();
  for (unsigned i = 0; i < NumOperands; ++i) {
    SDValue SubOp = Node->getOperand(i);
    EVT VVT = SubOp.getNode()->getValueType(0);
    EVT EltVT = VVT.getVectorElementType();
    unsigned NumSubElem = VVT.getVectorNumElements();
    for (unsigned j = 0; j < NumSubElem; ++j) {
      Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, SubOp,
                                DAG.getIntPtrConstant(j)));
    }
  }
  return DAG.getNode(ISD::BUILD_VECTOR, dl, Node->getValueType(0), &Ops[0],
                     Ops.size());
}

SDValue
NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
  switch (Op.getOpcode()) {
  case ISD::RETURNADDR:
    return SDValue();
  case ISD::FRAMEADDR:
    return SDValue();
  case ISD::GlobalAddress:
    return LowerGlobalAddress(Op, DAG);
  case ISD::INTRINSIC_W_CHAIN:
    return Op;
  case ISD::BUILD_VECTOR:
  case ISD::EXTRACT_SUBVECTOR:
    return Op;
  case ISD::CONCAT_VECTORS:
    return LowerCONCAT_VECTORS(Op, DAG);
  case ISD::STORE:
    return LowerSTORE(Op, DAG);
  case ISD::LOAD:
    return LowerLOAD(Op, DAG);
  default:
    llvm_unreachable("Custom lowering not defined for operation");
  }
}

SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
  if (Op.getValueType() == MVT::i1)
    return LowerLOADi1(Op, DAG);
  else
    return SDValue();
}

// v = ld i1* addr
//   =>
// v1 = ld i8* addr (-> i16)
// v = trunc i16 to i1
SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
  SDNode *Node = Op.getNode();
  LoadSDNode *LD = cast<LoadSDNode>(Node);
  SDLoc dl(Node);
  assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
  assert(Node->getValueType(0) == MVT::i1 &&
         "Custom lowering for i1 load only");
  SDValue newLD =
      DAG.getLoad(MVT::i16, dl, LD->getChain(), LD->getBasePtr(),
                  LD->getPointerInfo(), LD->isVolatile(), LD->isNonTemporal(),
                  LD->isInvariant(), LD->getAlignment());
  SDValue result = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, newLD);
  // The legalizer (the caller) is expecting two values from the legalized
  // load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
  // in LegalizeDAG.cpp which also uses MergeValues.
  SDValue Ops[] = { result, LD->getChain() };
  return DAG.getMergeValues(Ops, 2, dl);
}

SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
  EVT ValVT = Op.getOperand(1).getValueType();
  if (ValVT == MVT::i1)
    return LowerSTOREi1(Op, DAG);
  else if (ValVT.isVector())
    return LowerSTOREVector(Op, DAG);
  else
    return SDValue();
}

SDValue
NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
  SDNode *N = Op.getNode();
  SDValue Val = N->getOperand(1);
  SDLoc DL(N);
  EVT ValVT = Val.getValueType();

  if (ValVT.isVector()) {
    // We only handle "native" vector sizes for now, e.g. <4 x double> is not
    // legal.  We can (and should) split that into 2 stores of <2 x double> here
    // but I'm leaving that as a TODO for now.
    if (!ValVT.isSimple())
      return SDValue();
    switch (ValVT.getSimpleVT().SimpleTy) {
    default:
      return SDValue();
    case MVT::v2i8:
    case MVT::v2i16:
    case MVT::v2i32:
    case MVT::v2i64:
    case MVT::v2f32:
    case MVT::v2f64:
    case MVT::v4i8:
    case MVT::v4i16:
    case MVT::v4i32:
    case MVT::v4f32:
      // This is a "native" vector type
      break;
    }

    unsigned Opcode = 0;
    EVT EltVT = ValVT.getVectorElementType();
    unsigned NumElts = ValVT.getVectorNumElements();

    // Since StoreV2 is a target node, we cannot rely on DAG type legalization.
    // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
    // stored type to i16 and propagate the "real" type as the memory type.
    bool NeedExt = false;
    if (EltVT.getSizeInBits() < 16)
      NeedExt = true;

    switch (NumElts) {
    default:
      return SDValue();
    case 2:
      Opcode = NVPTXISD::StoreV2;
      break;
    case 4: {
      Opcode = NVPTXISD::StoreV4;
      break;
    }
    }

    SmallVector<SDValue, 8> Ops;

    // First is the chain
    Ops.push_back(N->getOperand(0));

    // Then the split values
    for (unsigned i = 0; i < NumElts; ++i) {
      SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
                                   DAG.getIntPtrConstant(i));
      if (NeedExt)
        ExtVal = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i16, ExtVal);
      Ops.push_back(ExtVal);
    }

    // Then any remaining arguments
    for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i) {
      Ops.push_back(N->getOperand(i));
    }

    MemSDNode *MemSD = cast<MemSDNode>(N);

    SDValue NewSt = DAG.getMemIntrinsicNode(
        Opcode, DL, DAG.getVTList(MVT::Other), &Ops[0], Ops.size(),
        MemSD->getMemoryVT(), MemSD->getMemOperand());

    //return DCI.CombineTo(N, NewSt, true);
    return NewSt;
  }

  return SDValue();
}

// st i1 v, addr
//    =>
// v1 = zxt v to i16
// st.u8 i16, addr
SDValue NVPTXTargetLowering::LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const {
  SDNode *Node = Op.getNode();
  SDLoc dl(Node);
  StoreSDNode *ST = cast<StoreSDNode>(Node);
  SDValue Tmp1 = ST->getChain();
  SDValue Tmp2 = ST->getBasePtr();
  SDValue Tmp3 = ST->getValue();
  assert(Tmp3.getValueType() == MVT::i1 && "Custom lowering for i1 store only");
  unsigned Alignment = ST->getAlignment();
  bool isVolatile = ST->isVolatile();
  bool isNonTemporal = ST->isNonTemporal();
  Tmp3 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, Tmp3);
  SDValue Result = DAG.getTruncStore(Tmp1, dl, Tmp3, Tmp2,
                                     ST->getPointerInfo(), MVT::i8, isNonTemporal,
                                     isVolatile, Alignment);
  return Result;
}

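// getExtSymb - Return a target external symbol named "<inname><idx>". The
// name is built in the target machine's managed string pool so the returned
// C string stays live for the duration of codegen.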
SDValue NVPTXTargetLowering::getExtSymb(SelectionDAG &DAG, const char *inname,
                                        int idx, EVT v) const {
  std::string *name = nvTM->getManagedStrPool()->getManagedString(inname);
  std::stringstream suffix;
  suffix << idx;
  *name += suffix.str();
  return DAG.getTargetExternalSymbol(name->c_str(), v);
}

SDValue
NVPTXTargetLowering::getParamSymbol(SelectionDAG &DAG, int idx, EVT v) const {
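  // Parameter symbols are named "<function>_param_<idx>", e.g. the first
  // parameter of a function foo becomes "foo_param_0".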
  std::string ParamSym;
  raw_string_ostream ParamStr(ParamSym);

  ParamStr << DAG.getMachineFunction().getName() << "_param_" << idx;
  ParamStr.flush();

  std::string *SavedStr =
    nvTM->getManagedStrPool()->getManagedString(ParamSym.c_str());
  return DAG.getTargetExternalSymbol(SavedStr->c_str(), v);
}

SDValue NVPTXTargetLowering::getParamHelpSymbol(SelectionDAG &DAG, int idx) {
  return getExtSymb(DAG, ".HLPPARAM", idx);
}

// Check to see if the kernel argument is image*_t or sampler_t

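// Recognition is by the names of the opaque struct types that OpenCL front
// ends conventionally emit for such arguments (an assumption about the front
// end, not something this file enforces).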
bool llvm::isImageOrSamplerVal(const Value *arg, const Module *context) {
  static const char *const specialTypes[] = { "struct._image2d_t",
                                              "struct._image3d_t",
                                              "struct._sampler_t" };

  const Type *Ty = arg->getType();
  const PointerType *PTy = dyn_cast<PointerType>(Ty);

  if (!PTy)
    return false;

  if (!context)
    return false;

  const StructType *STy = dyn_cast<StructType>(PTy->getElementType());
  const std::string TypeName = STy && !STy->isLiteral() ? STy->getName() : "";

  for (int i = 0, e = array_lengthof(specialTypes); i != e; ++i)
    if (TypeName == specialTypes[i])
      return true;

  return false;
}

1385SDValue NVPTXTargetLowering::LowerFormalArguments(
1386    SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
1387    const SmallVectorImpl<ISD::InputArg> &Ins, SDLoc dl, SelectionDAG &DAG,
1388    SmallVectorImpl<SDValue> &InVals) const {
1389  MachineFunction &MF = DAG.getMachineFunction();
1390  const DataLayout *TD = getDataLayout();
1391
1392  const Function *F = MF.getFunction();
1393  const AttributeSet &PAL = F->getAttributes();
1394  const TargetLowering *TLI = nvTM->getTargetLowering();
1395
1396  SDValue Root = DAG.getRoot();
1397  std::vector<SDValue> OutChains;
1398
1399  bool isKernel = llvm::isKernelFunction(*F);
1400  bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
1401  assert(isABI && "Non-ABI compilation is not supported");
1402  if (!isABI)
1403    return Chain;
1404
1405  std::vector<Type *> argTypes;
1406  std::vector<const Argument *> theArgs;
1407  for (Function::const_arg_iterator I = F->arg_begin(), E = F->arg_end();
1408       I != E; ++I) {
1409    theArgs.push_back(I);
1410    argTypes.push_back(I->getType());
1411  }
1412  // argTypes.size() (or theArgs.size()) and Ins.size() need not match.
1413  // Ins.size() will be larger
1414  //   * if there is an aggregate argument with multiple fields (each field
1415  //     showing up separately in Ins)
1416  //   * if there is a vector argument with more than typical vector-length
1417  //     elements (generally if more than 4) where each vector element is
1418  //     individually present in Ins.
1419  // So a different index should be used for indexing into Ins.
1420  // See similar issue in LowerCall.
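  // For example (a sketch): a single aggregate argument {i32, float} %a
  // contributes one entry to theArgs/argTypes but two entries to Ins (one
  // for each field), so InsIdx below advances independently of i.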
  unsigned InsIdx = 0;

  int idx = 0;
  for (unsigned i = 0, e = theArgs.size(); i != e; ++i, ++idx, ++InsIdx) {
    Type *Ty = argTypes[i];

    // If the kernel argument is image*_t or sampler_t, convert it to
    // an i32 constant holding the parameter position. This can later be
    // matched in the AsmPrinter to output the correct mangled name.
    if (isImageOrSamplerVal(
            theArgs[i],
            (theArgs[i]->getParent() ? theArgs[i]->getParent()->getParent()
                                     : 0))) {
      assert(isKernel && "Only kernels can have image/sampler params");
      InVals.push_back(DAG.getConstant(i + 1, MVT::i32));
      continue;
    }

    if (theArgs[i]->use_empty()) {
      // Argument is dead; push UNDEF values of the right types.
      if (Ty->isAggregateType()) {
        SmallVector<EVT, 16> vtparts;

        ComputePTXValueVTs(*this, Ty, vtparts);
        assert(vtparts.size() > 0 && "empty aggregate type not expected");
        for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
             ++parti) {
          EVT partVT = vtparts[parti];
          InVals.push_back(DAG.getNode(ISD::UNDEF, dl, partVT));
          ++InsIdx;
        }
        if (vtparts.size() > 0)
          --InsIdx;
        continue;
      }
      if (Ty->isVectorTy()) {
        EVT ObjectVT = getValueType(Ty);
        unsigned NumRegs = TLI->getNumRegisters(F->getContext(), ObjectVT);
        for (unsigned parti = 0; parti < NumRegs; ++parti) {
          InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
          ++InsIdx;
        }
        if (NumRegs > 0)
          --InsIdx;
        continue;
      }
      InVals.push_back(DAG.getNode(ISD::UNDEF, dl, Ins[InsIdx].VT));
      continue;
    }

    // In the following cases, assign a node order of "idx+1"
    // to newly created nodes. The SDNodes for params have to
    // appear in the same order as their order of appearance
    // in the original function. "idx+1" holds that order.
    if (!PAL.hasAttribute(i + 1, Attribute::ByVal)) {
      if (Ty->isAggregateType()) {
        SmallVector<EVT, 16> vtparts;
        SmallVector<uint64_t, 16> offsets;

        // NOTE: Here, we lose the ability to issue vector loads for vectors
        // that are a part of a struct.  This should be investigated in the
        // future.
        ComputePTXValueVTs(*this, Ty, vtparts, &offsets, 0);
        assert(vtparts.size() > 0 && "empty aggregate type not expected");
        bool aggregateIsPacked = false;
        if (StructType *STy = llvm::dyn_cast<StructType>(Ty))
          aggregateIsPacked = STy->isPacked();

        SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
        for (unsigned parti = 0, parte = vtparts.size(); parti != parte;
             ++parti) {
          EVT partVT = vtparts[parti];
          Value *srcValue = Constant::getNullValue(
              PointerType::get(partVT.getTypeForEVT(F->getContext()),
                               llvm::ADDRESS_SPACE_PARAM));
          SDValue srcAddr =
              DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
                          DAG.getConstant(offsets[parti], getPointerTy()));
          unsigned partAlign =
              aggregateIsPacked ? 1
                                : TD->getABITypeAlignment(
                                      partVT.getTypeForEVT(F->getContext()));
          SDValue p;
          if (Ins[InsIdx].VT.getSizeInBits() > partVT.getSizeInBits()) {
            ISD::LoadExtType ExtOp = Ins[InsIdx].Flags.isSExt() ?
                                     ISD::SEXTLOAD : ISD::ZEXTLOAD;
            p = DAG.getExtLoad(ExtOp, dl, Ins[InsIdx].VT, Root, srcAddr,
                               MachinePointerInfo(srcValue), partVT, false,
                               false, partAlign);
          } else {
            p = DAG.getLoad(partVT, dl, Root, srcAddr,
                            MachinePointerInfo(srcValue), false, false, false,
                            partAlign);
          }
          if (p.getNode())
            p.getNode()->setIROrder(idx + 1);
          InVals.push_back(p);
          ++InsIdx;
        }
        if (vtparts.size() > 0)
          --InsIdx;
        continue;
      }
      if (Ty->isVectorTy()) {
        EVT ObjectVT = getValueType(Ty);
        SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
        unsigned NumElts = ObjectVT.getVectorNumElements();
        assert(TLI->getNumRegisters(F->getContext(), ObjectVT) == NumElts &&
               "Vector was not scalarized");
        unsigned Ofst = 0;
        EVT EltVT = ObjectVT.getVectorElementType();

        // V1 load
        // f32 = load ...
        if (NumElts == 1) {
          // We only have one element, so just directly load it
          Value *SrcValue = Constant::getNullValue(PointerType::get(
              EltVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
          SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
                                        DAG.getConstant(Ofst, getPointerTy()));
          SDValue P = DAG.getLoad(
              EltVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
              false, true,
              TD->getABITypeAlignment(EltVT.getTypeForEVT(F->getContext())));
          if (P.getNode())
            P.getNode()->setIROrder(idx + 1);

          if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
            P = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, P);
          InVals.push_back(P);
          Ofst += TD->getTypeAllocSize(EltVT.getTypeForEVT(F->getContext()));
          ++InsIdx;
        } else if (NumElts == 2) {
          // V2 load
          // f32,f32 = load ...
          EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, 2);
          Value *SrcValue = Constant::getNullValue(PointerType::get(
              VecVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
          SDValue SrcAddr = DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
                                        DAG.getConstant(Ofst, getPointerTy()));
          SDValue P = DAG.getLoad(
              VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
              false, true,
              TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
          if (P.getNode())
            P.getNode()->setIROrder(idx + 1);

          SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
                                     DAG.getIntPtrConstant(0));
          SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
                                     DAG.getIntPtrConstant(1));

          if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits()) {
            Elt0 = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt0);
            Elt1 = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt1);
          }

          InVals.push_back(Elt0);
          InVals.push_back(Elt1);
          Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
          InsIdx += 2;
        } else {
          // V4 loads
          // We have at least 4 elements (<3 x Ty> expands to 4 elements) and
          // the vector will be expanded to a power of 2 elements, so we know
          // we can always round up to the next multiple of 4 when creating
          // the vector loads.
          // e.g.  4 elem => 1 ld.v4
          //       6 elem => 2 ld.v4
          //       8 elem => 2 ld.v4
          //      11 elem => 3 ld.v4
          unsigned VecSize = 4;
          if (EltVT.getSizeInBits() == 64) {
            VecSize = 2;
          }
          EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
          for (unsigned i = 0; i < NumElts; i += VecSize) {
            Value *SrcValue = Constant::getNullValue(
                PointerType::get(VecVT.getTypeForEVT(F->getContext()),
                                 llvm::ADDRESS_SPACE_PARAM));
            SDValue SrcAddr =
                DAG.getNode(ISD::ADD, dl, getPointerTy(), Arg,
                            DAG.getConstant(Ofst, getPointerTy()));
            SDValue P = DAG.getLoad(
                VecVT, dl, Root, SrcAddr, MachinePointerInfo(SrcValue), false,
                false, true,
                TD->getABITypeAlignment(VecVT.getTypeForEVT(F->getContext())));
            if (P.getNode())
              P.getNode()->setIROrder(idx + 1);

            for (unsigned j = 0; j < VecSize; ++j) {
              if (i + j >= NumElts)
                break;
              SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, EltVT, P,
                                        DAG.getIntPtrConstant(j));
              if (Ins[InsIdx].VT.getSizeInBits() > EltVT.getSizeInBits())
                Elt = DAG.getNode(ISD::ANY_EXTEND, dl, Ins[InsIdx].VT, Elt);
              InVals.push_back(Elt);
            }
            Ofst += TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));
          }
          InsIdx += NumElts;
        }

        if (NumElts > 0)
          --InsIdx;
        continue;
      }
      // A plain scalar.
      EVT ObjectVT = getValueType(Ty);
      // If ABI, load from the param symbol
      SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
      Value *srcValue = Constant::getNullValue(PointerType::get(
          ObjectVT.getTypeForEVT(F->getContext()), llvm::ADDRESS_SPACE_PARAM));
      SDValue p;
      if (ObjectVT.getSizeInBits() < Ins[InsIdx].VT.getSizeInBits()) {
        ISD::LoadExtType ExtOp = Ins[InsIdx].Flags.isSExt() ?
                                 ISD::SEXTLOAD : ISD::ZEXTLOAD;
        p = DAG.getExtLoad(ExtOp, dl, Ins[InsIdx].VT, Root, Arg,
                           MachinePointerInfo(srcValue), ObjectVT, false, false,
                           TD->getABITypeAlignment(
                               ObjectVT.getTypeForEVT(F->getContext())));
      } else {
        p = DAG.getLoad(Ins[InsIdx].VT, dl, Root, Arg,
                        MachinePointerInfo(srcValue), false, false, false,
                        TD->getABITypeAlignment(
                            ObjectVT.getTypeForEVT(F->getContext())));
      }
      if (p.getNode())
        p.getNode()->setIROrder(idx + 1);
      InVals.push_back(p);
      continue;
    }

    // Param has the ByVal attribute: return MoveParam(param symbol).
    // Ideally, the param symbol could be returned directly, but when the
    // SDNode builder decides to use it in a CopyToReg(), the machine
    // instruction fails because a TargetExternalSymbol (which is never
    // lowered) is target dependent, and CopyToReg assumes the source is
    // lowered.
    EVT ObjectVT = getValueType(Ty);
    assert(ObjectVT == Ins[InsIdx].VT &&
           "Ins type did not match function type");
    SDValue Arg = getParamSymbol(DAG, idx, getPointerTy());
    SDValue p = DAG.getNode(NVPTXISD::MoveParam, dl, ObjectVT, Arg);
    if (p.getNode())
      p.getNode()->setIROrder(idx + 1);
    if (isKernel)
      InVals.push_back(p);
    else {
      SDValue p2 = DAG.getNode(
          ISD::INTRINSIC_WO_CHAIN, dl, ObjectVT,
          DAG.getConstant(Intrinsic::nvvm_ptr_local_to_gen, MVT::i32), p);
      InVals.push_back(p2);
    }
  }

  // Clang will check an explicit VarArg and issue an error if one is present.
  // However, Clang will let code with an implicit vararg, like f(), pass.
  // See bug 617733. We treat this case as if the arg list were empty.
  // if (F.isVarArg()) {
  //   assert(0 && "VarArg not supported yet!");
  // }

  if (!OutChains.empty())
    DAG.setRoot(DAG.getNode(ISD::TokenFactor, dl, MVT::Other, &OutChains[0],
                            OutChains.size()));

  return Chain;
}
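
// As a concrete illustration of the scalar path above (a sketch): a kernel
// "kern" taking a single float parameter lowers it to an f32 load from the
// external symbol "kern_param_0" in the param address space
// (llvm::ADDRESS_SPACE_PARAM).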

SDValue
NVPTXTargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
                                 bool isVarArg,
                                 const SmallVectorImpl<ISD::OutputArg> &Outs,
                                 const SmallVectorImpl<SDValue> &OutVals,
                                 SDLoc dl, SelectionDAG &DAG) const {
  MachineFunction &MF = DAG.getMachineFunction();
  const Function *F = MF.getFunction();
  Type *RetTy = F->getReturnType();
  const DataLayout *TD = getDataLayout();

  bool isABI = (nvptxSubtarget.getSmVersion() >= 20);
  assert(isABI && "Non-ABI compilation is not supported");
  if (!isABI)
    return Chain;

  if (VectorType *VTy = dyn_cast<VectorType>(RetTy)) {
    // If we have a vector type, the OutVals array will hold the scalarized
    // components, and we have to combine them into 1 or more vector stores.
    unsigned NumElts = VTy->getNumElements();
    assert(NumElts == Outs.size() && "Bad scalarization of return value");

    EVT EltVT = getValueType(RetTy).getVectorElementType();
    bool NeedExtend = false;
    if (EltVT.getSizeInBits() < 16)
      NeedExtend = true;

    // V1 store
    if (NumElts == 1) {
      SDValue StoreVal = OutVals[0];
      // We only have one element, so just directly store it
      if (NeedExtend)
        StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal);
      SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal };
      Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
                                      DAG.getVTList(MVT::Other), &Ops[0], 3,
                                      EltVT, MachinePointerInfo());
    } else if (NumElts == 2) {
      // V2 store
      SDValue StoreVal0 = OutVals[0];
      SDValue StoreVal1 = OutVals[1];

      if (NeedExtend) {
        StoreVal0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal0);
        StoreVal1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, StoreVal1);
      }

      SDValue Ops[] = { Chain, DAG.getConstant(0, MVT::i32), StoreVal0,
                        StoreVal1 };
      Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetvalV2, dl,
                                      DAG.getVTList(MVT::Other), &Ops[0], 4,
                                      EltVT, MachinePointerInfo());
    } else {
      // V4 stores
      // We have at least 4 elements (<3 x Ty> expands to 4 elements) and the
      // vector will be expanded to a power of 2 elements, so we know we can
      // always round up to the next multiple of 4 when creating the vector
      // stores.
      // e.g.  4 elem => 1 st.v4
      //       6 elem => 2 st.v4
      //       8 elem => 2 st.v4
      //      11 elem => 3 st.v4

      unsigned VecSize = 4;
      if (OutVals[0].getValueType().getSizeInBits() == 64)
        VecSize = 2;

      unsigned Offset = 0;

      EVT VecVT =
          EVT::getVectorVT(F->getContext(), OutVals[0].getValueType(), VecSize);
      unsigned PerStoreOffset =
          TD->getTypeAllocSize(VecVT.getTypeForEVT(F->getContext()));

      for (unsigned i = 0; i < NumElts; i += VecSize) {
        // Get values
        SDValue StoreVal;
        SmallVector<SDValue, 8> Ops;
        Ops.push_back(Chain);
        Ops.push_back(DAG.getConstant(Offset, MVT::i32));
        unsigned Opc = NVPTXISD::StoreRetvalV2;
        EVT ExtendedVT = (NeedExtend) ? MVT::i16 : OutVals[0].getValueType();

        StoreVal = OutVals[i];
        if (NeedExtend)
          StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
        Ops.push_back(StoreVal);

        if (i + 1 < NumElts) {
          StoreVal = OutVals[i + 1];
          if (NeedExtend)
            StoreVal = DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
        } else {
          StoreVal = DAG.getUNDEF(ExtendedVT);
        }
        Ops.push_back(StoreVal);

        if (VecSize == 4) {
          Opc = NVPTXISD::StoreRetvalV4;
          if (i + 2 < NumElts) {
            StoreVal = OutVals[i + 2];
            if (NeedExtend)
              StoreVal =
                  DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
          } else {
            StoreVal = DAG.getUNDEF(ExtendedVT);
          }
          Ops.push_back(StoreVal);

          if (i + 3 < NumElts) {
            StoreVal = OutVals[i + 3];
            if (NeedExtend)
              StoreVal =
                  DAG.getNode(ISD::ZERO_EXTEND, dl, ExtendedVT, StoreVal);
          } else {
            StoreVal = DAG.getUNDEF(ExtendedVT);
          }
          Ops.push_back(StoreVal);
        }

        Chain =
            DAG.getMemIntrinsicNode(Opc, dl, DAG.getVTList(MVT::Other), &Ops[0],
                                    Ops.size(), EltVT, MachinePointerInfo());
        Offset += PerStoreOffset;
      }
    }
  } else {
    SmallVector<EVT, 16> ValVTs;
    ComputePTXValueVTs(*this, RetTy, ValVTs);
    assert(ValVTs.size() == OutVals.size() && "Bad return value decomposition");

    unsigned SizeSoFar = 0;
    for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
      SDValue theVal = OutVals[i];
      EVT TheValType = theVal.getValueType();
      unsigned numElems = 1;
      if (TheValType.isVector())
        numElems = TheValType.getVectorNumElements();
      for (unsigned j = 0, je = numElems; j != je; ++j) {
        SDValue TmpVal = theVal;
        if (TheValType.isVector())
          TmpVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
                               TheValType.getVectorElementType(), TmpVal,
                               DAG.getIntPtrConstant(j));
        EVT TheStoreType = ValVTs[i];
        if (RetTy->isIntegerTy() &&
            TD->getTypeAllocSizeInBits(RetTy) < 32) {
          // The following zero-extension is for integer types only, and
          // specifically not for aggregates.
          TmpVal = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i32, TmpVal);
          TheStoreType = MVT::i32;
        } else if (TmpVal.getValueType().getSizeInBits() < 16)
          TmpVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, TmpVal);

        SDValue Ops[] = { Chain, DAG.getConstant(SizeSoFar, MVT::i32), TmpVal };
        Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreRetval, dl,
                                        DAG.getVTList(MVT::Other), &Ops[0],
                                        3, TheStoreType,
                                        MachinePointerInfo());
        if (TheValType.isVector())
          SizeSoFar +=
              TheStoreType.getVectorElementType().getStoreSizeInBits() / 8;
        else
          SizeSoFar += TheStoreType.getStoreSizeInBits() / 8;
      }
    }
  }

  return DAG.getNode(NVPTXISD::RET_FLAG, dl, MVT::Other, Chain);
}
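
// A sketch of the two paths above: returning <2 x float> emits one
// NVPTXISD::StoreRetvalV2 at offset 0, while returning i8 takes the scalar
// path, is zero-extended to i32 (sub-32-bit integer returns are widened),
// and is written with a single NVPTXISD::StoreRetval.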

void NVPTXTargetLowering::LowerAsmOperandForConstraint(
    SDValue Op, std::string &Constraint, std::vector<SDValue> &Ops,
    SelectionDAG &DAG) const {
  if (Constraint.length() > 1)
    return;
  TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
}

// NVPTX supports vectors of legal types of any length in intrinsics, because
// the NVPTX-specific type legalizer will legalize them to a PTX-supported
// length.
bool NVPTXTargetLowering::isTypeSupportedInIntrinsic(MVT VT) const {
  if (isTypeLegal(VT))
    return true;
  if (VT.isVector()) {
    MVT eVT = VT.getVectorElementType();
    if (isTypeLegal(eVT))
      return true;
  }
  return false;
}
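
// For instance, a v8f32 value is reported as supported here even though
// v8f32 is not itself a legal register type, because its f32 element type
// is legal and the NVPTX type legalizer can split the vector into
// PTX-supported pieces.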

// llvm.ptx.memcpy.const and llvm.ptx.memmove.const need to be modeled as
// TgtMemIntrinsic because we need the information that is only available in
// the "Value" type of the destination pointer -- in particular, the address
// space information.
bool NVPTXTargetLowering::getTgtMemIntrinsic(
    IntrinsicInfo &Info, const CallInst &I, unsigned Intrinsic) const {
  switch (Intrinsic) {
  default:
    return false;

  case Intrinsic::nvvm_atomic_load_add_f32:
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::f32;
    Info.ptrVal = I.getArgOperand(0);
    Info.offset = 0;
    Info.vol = 0;
    Info.readMem = true;
    Info.writeMem = true;
    Info.align = 0;
    return true;

  case Intrinsic::nvvm_atomic_load_inc_32:
  case Intrinsic::nvvm_atomic_load_dec_32:
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    Info.memVT = MVT::i32;
    Info.ptrVal = I.getArgOperand(0);
    Info.offset = 0;
    Info.vol = 0;
    Info.readMem = true;
    Info.writeMem = true;
    Info.align = 0;
    return true;

  case Intrinsic::nvvm_ldu_global_i:
  case Intrinsic::nvvm_ldu_global_f:
  case Intrinsic::nvvm_ldu_global_p:
    Info.opc = ISD::INTRINSIC_W_CHAIN;
    if (Intrinsic == Intrinsic::nvvm_ldu_global_i ||
        Intrinsic == Intrinsic::nvvm_ldu_global_p)
      Info.memVT = getValueType(I.getType());
    else
      Info.memVT = MVT::f32;
    Info.ptrVal = I.getArgOperand(0);
    Info.offset = 0;
    Info.vol = 0;
    Info.readMem = true;
    Info.writeMem = false;
    Info.align = 0;
    return true;
  }
  return false;
}

/// isLegalAddressingMode - Return true if the addressing mode represented
/// by AM is legal for this target, for a load/store of the specified type.
/// Used to guide target specific optimizations, like loop strength reduction
/// (LoopStrengthReduce.cpp) and memory optimization for address mode
/// (CodeGenPrepare.cpp)
bool NVPTXTargetLowering::isLegalAddressingMode(const AddrMode &AM,
                                                Type *Ty) const {
  // AddrMode - This represents an addressing mode of:
  //    BaseGV + BaseOffs + BaseReg + Scale*ScaleReg
  //
  // The legal address modes are
  // - [avar]
  // - [areg]
  // - [areg+immoff]
  // - [immAddr]

  if (AM.BaseGV) {
    if (AM.BaseOffs || AM.HasBaseReg || AM.Scale)
      return false;
    return true;
  }

  switch (AM.Scale) {
  case 0: // "r", "r+i" or "i" is allowed
    break;
  case 1:
    if (AM.HasBaseReg) // "r+r+i" or "r+r" is not allowed.
      return false;
    // Otherwise we have r+i.
    break;
  default:
    // No scale > 1 is allowed
    return false;
  }
  return true;
}
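
// Worked examples of the rules above (a sketch): [globalvar] and [reg+imm]
// (Scale == 0, or Scale == 1 with no extra base register) are accepted;
// [globalvar+imm], [reg+reg], and any Scale > 1 are rejected.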

//===----------------------------------------------------------------------===//
//                         NVPTX Inline Assembly Support
//===----------------------------------------------------------------------===//

/// getConstraintType - Given a constraint letter, return the type of
/// constraint it is for this target.
NVPTXTargetLowering::ConstraintType
NVPTXTargetLowering::getConstraintType(const std::string &Constraint) const {
  if (Constraint.size() == 1) {
    switch (Constraint[0]) {
    default:
      break;
    case 'r':
    case 'h':
    case 'c':
    case 'l':
    case 'f':
    case 'd':
    case '0':
    case 'N':
      return C_RegisterClass;
    }
  }
  return TargetLowering::getConstraintType(Constraint);
}

std::pair<unsigned, const TargetRegisterClass *>
NVPTXTargetLowering::getRegForInlineAsmConstraint(const std::string &Constraint,
                                                  MVT VT) const {
  if (Constraint.size() == 1) {
    switch (Constraint[0]) {
    case 'c':
    case 'h':
      return std::make_pair(0U, &NVPTX::Int16RegsRegClass);
    case 'r':
      return std::make_pair(0U, &NVPTX::Int32RegsRegClass);
    case 'l':
    case 'N':
      return std::make_pair(0U, &NVPTX::Int64RegsRegClass);
    case 'f':
      return std::make_pair(0U, &NVPTX::Float32RegsRegClass);
    case 'd':
      return std::make_pair(0U, &NVPTX::Float64RegsRegClass);
    }
  }
  return TargetLowering::getRegForInlineAsmConstraint(Constraint, VT);
}

/// getFunctionAlignment - Return the Log2 alignment of this function, i.e.
/// 4, meaning functions are aligned to 2^4 = 16 bytes.
unsigned NVPTXTargetLowering::getFunctionAlignment(const Function *) const {
  return 4;
}

/// ReplaceLoadVector - Convert vector loads into multi-output scalar loads.
static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
                              SmallVectorImpl<SDValue> &Results) {
  EVT ResVT = N->getValueType(0);
  SDLoc DL(N);

  assert(ResVT.isVector() && "Vector load must have vector type");

  // We only handle "native" vector sizes for now, e.g. <4 x double> is not
  // legal.  We can (and should) split that into 2 loads of <2 x double> here
  // but I'm leaving that as a TODO for now.
  assert(ResVT.isSimple() && "Can only handle simple types");
  switch (ResVT.getSimpleVT().SimpleTy) {
  default:
    return;
  case MVT::v2i8:
  case MVT::v2i16:
  case MVT::v2i32:
  case MVT::v2i64:
  case MVT::v2f32:
  case MVT::v2f64:
  case MVT::v4i8:
  case MVT::v4i16:
  case MVT::v4i32:
  case MVT::v4f32:
    // This is a "native" vector type
    break;
  }

  EVT EltVT = ResVT.getVectorElementType();
  unsigned NumElts = ResVT.getVectorNumElements();

  // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
  // Therefore, we must ensure the type is legal.  For i1 and i8, we set the
  // loaded type to i16 and propagate the "real" type as the memory type.
  bool NeedTrunc = false;
  if (EltVT.getSizeInBits() < 16) {
    EltVT = MVT::i16;
    NeedTrunc = true;
  }

  unsigned Opcode = 0;
  SDVTList LdResVTs;

  switch (NumElts) {
  default:
    return;
  case 2:
    Opcode = NVPTXISD::LoadV2;
    LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
    break;
  case 4: {
    Opcode = NVPTXISD::LoadV4;
    EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
    LdResVTs = DAG.getVTList(ListVTs, 5);
    break;
  }
  }

  SmallVector<SDValue, 8> OtherOps;

  // Copy regular operands
  for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
    OtherOps.push_back(N->getOperand(i));

  LoadSDNode *LD = cast<LoadSDNode>(N);

  // The select routine does not have access to the LoadSDNode instance, so
  // pass along the extension information
  OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType()));

  SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, &OtherOps[0],
                                          OtherOps.size(), LD->getMemoryVT(),
                                          LD->getMemOperand());

  SmallVector<SDValue, 4> ScalarRes;

  for (unsigned i = 0; i < NumElts; ++i) {
    SDValue Res = NewLD.getValue(i);
    if (NeedTrunc)
      Res = DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
    ScalarRes.push_back(Res);
  }

  SDValue LoadChain = NewLD.getValue(NumElts);

  SDValue BuildVec =
      DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);

  Results.push_back(BuildVec);
  Results.push_back(LoadChain);
}
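
// For example (a sketch): a load of <4 x float> becomes one NVPTXISD::LoadV4
// producing four f32 values plus a chain, reassembled with BUILD_VECTOR; a
// <2 x i8> load instead yields two i16 values (the memory type stays the
// original one) that are truncated back to i8.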

static void ReplaceINTRINSIC_W_CHAIN(SDNode *N, SelectionDAG &DAG,
                                     SmallVectorImpl<SDValue> &Results) {
  SDValue Chain = N->getOperand(0);
  SDValue Intrin = N->getOperand(1);
  SDLoc DL(N);

  // Get the intrinsic ID
  unsigned IntrinNo = cast<ConstantSDNode>(Intrin.getNode())->getZExtValue();
  switch (IntrinNo) {
  default:
    return;
  case Intrinsic::nvvm_ldg_global_i:
  case Intrinsic::nvvm_ldg_global_f:
  case Intrinsic::nvvm_ldg_global_p:
  case Intrinsic::nvvm_ldu_global_i:
  case Intrinsic::nvvm_ldu_global_f:
  case Intrinsic::nvvm_ldu_global_p: {
    EVT ResVT = N->getValueType(0);

    if (ResVT.isVector()) {
      // Vector LDG/LDU
      unsigned NumElts = ResVT.getVectorNumElements();
      EVT EltVT = ResVT.getVectorElementType();

      // Since LDU/LDG are target nodes, we cannot rely on DAG type
      // legalization.  Therefore, we must ensure the type is legal.  For i1
      // and i8, we set the loaded type to i16 and propagate the "real" type
      // as the memory type.
      bool NeedTrunc = false;
      if (EltVT.getSizeInBits() < 16) {
        EltVT = MVT::i16;
        NeedTrunc = true;
      }

      unsigned Opcode = 0;
      SDVTList LdResVTs;

      switch (NumElts) {
      default:
        return;
      case 2:
        switch (IntrinNo) {
        default:
          return;
        case Intrinsic::nvvm_ldg_global_i:
        case Intrinsic::nvvm_ldg_global_f:
        case Intrinsic::nvvm_ldg_global_p:
          Opcode = NVPTXISD::LDGV2;
          break;
        case Intrinsic::nvvm_ldu_global_i:
        case Intrinsic::nvvm_ldu_global_f:
        case Intrinsic::nvvm_ldu_global_p:
          Opcode = NVPTXISD::LDUV2;
          break;
        }
        LdResVTs = DAG.getVTList(EltVT, EltVT, MVT::Other);
        break;
      case 4: {
        switch (IntrinNo) {
        default:
          return;
        case Intrinsic::nvvm_ldg_global_i:
        case Intrinsic::nvvm_ldg_global_f:
        case Intrinsic::nvvm_ldg_global_p:
          Opcode = NVPTXISD::LDGV4;
          break;
        case Intrinsic::nvvm_ldu_global_i:
        case Intrinsic::nvvm_ldu_global_f:
        case Intrinsic::nvvm_ldu_global_p:
          Opcode = NVPTXISD::LDUV4;
          break;
        }
        EVT ListVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other };
        LdResVTs = DAG.getVTList(ListVTs, 5);
        break;
      }
      }

      // Copy regular operands: the chain first, then everything after the
      // intrinsic ID (operand 1), which is dropped.
      SmallVector<SDValue, 8> OtherOps;
      OtherOps.push_back(Chain);
      for (unsigned i = 2, e = N->getNumOperands(); i != e; ++i)
        OtherOps.push_back(N->getOperand(i));

      MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);

      SDValue NewLD = DAG.getMemIntrinsicNode(
          Opcode, DL, LdResVTs, &OtherOps[0], OtherOps.size(),
          MemSD->getMemoryVT(), MemSD->getMemOperand());

      SmallVector<SDValue, 4> ScalarRes;

      for (unsigned i = 0; i < NumElts; ++i) {
        SDValue Res = NewLD.getValue(i);
        if (NeedTrunc)
          Res =
              DAG.getNode(ISD::TRUNCATE, DL, ResVT.getVectorElementType(), Res);
        ScalarRes.push_back(Res);
      }

      SDValue LoadChain = NewLD.getValue(NumElts);

      SDValue BuildVec =
          DAG.getNode(ISD::BUILD_VECTOR, DL, ResVT, &ScalarRes[0], NumElts);

      Results.push_back(BuildVec);
      Results.push_back(LoadChain);
    } else {
      // i8 LDG/LDU
      assert(ResVT.isSimple() && ResVT.getSimpleVT().SimpleTy == MVT::i8 &&
             "Custom handling of non-i8 ldu/ldg?");

      // Just copy all operands as-is
      SmallVector<SDValue, 4> Ops;
      for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i)
        Ops.push_back(N->getOperand(i));

      // Force output to i16
      SDVTList LdResVTs = DAG.getVTList(MVT::i16, MVT::Other);

      MemIntrinsicSDNode *MemSD = cast<MemIntrinsicSDNode>(N);

      // We make sure the memory type is i8, which will be used during isel
      // to select the proper instruction.
      SDValue NewLD =
          DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, LdResVTs, &Ops[0],
                                  Ops.size(), MVT::i8, MemSD->getMemOperand());

      Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
                                    NewLD.getValue(0)));
      Results.push_back(NewLD.getValue(1));
    }
  }
  }
}
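
// For example (a sketch): an Intrinsic::nvvm_ldg_global_f call returning
// <4 x float> becomes one NVPTXISD::LDGV4 with four f32 results plus a
// chain, rebuilt via BUILD_VECTOR; a scalar i8 ldu/ldg is widened to an i16
// result with MVT::i8 as the memory type and truncated afterwards.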

void NVPTXTargetLowering::ReplaceNodeResults(
    SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
  switch (N->getOpcode()) {
  default:
    report_fatal_error("Unhandled custom legalization");
  case ISD::LOAD:
    ReplaceLoadVector(N, DAG, Results);
    return;
  case ISD::INTRINSIC_W_CHAIN:
    ReplaceINTRINSIC_W_CHAIN(N, DAG, Results);
    return;
  }
}

// Pin NVPTXSection's and NVPTXTargetObjectFile's vtables to this file.
void NVPTXSection::anchor() {}

NVPTXTargetObjectFile::~NVPTXTargetObjectFile() {
  delete TextSection;
  delete DataSection;
  delete BSSSection;
  delete ReadOnlySection;

  delete StaticCtorSection;
  delete StaticDtorSection;
  delete LSDASection;
  delete EHFrameSection;
  delete DwarfAbbrevSection;
  delete DwarfInfoSection;
  delete DwarfLineSection;
  delete DwarfFrameSection;
  delete DwarfPubTypesSection;
  delete DwarfDebugInlineSection;
  delete DwarfStrSection;
  delete DwarfLocSection;
  delete DwarfARangesSection;
  delete DwarfRangesSection;
  delete DwarfMacroInfoSection;
}