// slang_rs_backend.cpp revision 3a9ca1f0d6bd8f12c2bb2adea51f95c255996180
1#include "slang_rs_backend.h"
2
3#include <vector>
4#include <string>
5
6#include "llvm/Metadata.h"
7#include "llvm/Constant.h"
8#include "llvm/Constants.h"
9#include "llvm/Module.h"
10#include "llvm/Function.h"
11#include "llvm/DerivedTypes.h"
12
13#include "llvm/Support/IRBuilder.h"
14
15#include "llvm/ADT/Twine.h"
16#include "llvm/ADT/StringExtras.h"
17
18#include "clang/AST/DeclGroup.h"
19
20#include "slang_rs_context.h"
21#include "slang_rs_export_var.h"
22#include "slang_rs_export_func.h"
23#include "slang_rs_export_type.h"
24
25using namespace slang;
26
27RSBackend::RSBackend(RSContext *Context,
28                     clang::Diagnostic &Diags,
29                     const clang::CodeGenOptions &CodeGenOpts,
30                     const clang::TargetOptions &TargetOpts,
31                     const PragmaList &Pragmas,
32                     llvm::raw_ostream *OS,
33                     Slang::OutputType OT,
34                     clang::SourceManager &SourceMgr,
35                     bool AllowRSPrefix)
36    : Backend(Diags,
37              CodeGenOpts,
38              TargetOpts,
39              Pragmas,
40              OS,
41              OT),
42      mContext(Context),
43      mSourceMgr(SourceMgr),
44      mAllowRSPrefix(AllowRSPrefix),
45      mExportVarMetadata(NULL),
46      mExportFuncMetadata(NULL),
47      mExportTypeMetadata(NULL) {
48  return;
49}
50
51void RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
52  // Disallow user-defined functions with prefix "rs"
53  if (!mAllowRSPrefix) {
54    clang::DeclGroupRef::iterator I;
55    for (I = D.begin(); I != D.end(); I++) {
56      clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
57      if (!FD || !FD->isThisDeclarationADefinition()) continue;
58      if (FD->getName().startswith("rs")) {
59        mDiags.Report(clang::FullSourceLoc(FD->getLocStart(), mSourceMgr),
60                      mDiags.getCustomDiagID(clang::Diagnostic::Error,
61                                             "invalid function name prefix,"
62                                             " \"rs\" is reserved: '%0'")
63                      )
64            << FD->getNameAsString();
65      }
66    }
67  }
68
69  Backend::HandleTopLevelDecl(D);
70  return;
71}
72
73void RSBackend::HandleTranslationUnitEx(clang::ASTContext &Ctx) {
74  assert((&Ctx == mContext->getASTContext()) && "Unexpected AST context change"
75                                                " during LLVM IR generation");
76  mContext->processExport();
77
78  // Dump export variable info
79  if (mContext->hasExportVar()) {
80    if (mExportVarMetadata == NULL)
81      mExportVarMetadata = mpModule->getOrInsertNamedMetadata("#rs_export_var");
82
83    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
84
85    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
86            E = mContext->export_vars_end();
87         I != E;
88         I++) {
89      const RSExportVar *EV = *I;
90      const RSExportType *ET = EV->getType();
91
92      // Variable name
93      ExportVarInfo.push_back(
94          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
95
96      // Type name
97      if (ET->getClass() == RSExportType::ExportClassPrimitive)
98        ExportVarInfo.push_back(
99            llvm::MDString::get(
100                mLLVMContext, llvm::utostr_32(
101                    static_cast<const RSExportPrimitiveType*>(ET)->getType())));
102      else if (ET->getClass() == RSExportType::ExportClassPointer)
103        ExportVarInfo.push_back(
104            llvm::MDString::get(
105                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
106                               ->getPointeeType()->getName()).c_str()));
107      else
108        ExportVarInfo.push_back(
109            llvm::MDString::get(mLLVMContext,
110                                EV->getType()->getName().c_str()));
111
112      mExportVarMetadata->addOperand(
113          llvm::MDNode::get(mLLVMContext,
114                            ExportVarInfo.data(),
115                            ExportVarInfo.size()) );
116
117      ExportVarInfo.clear();
118    }
119  }
120
121  // Dump export function info
122  if (mContext->hasExportFunc()) {
123    if (mExportFuncMetadata == NULL)
124      mExportFuncMetadata =
125          mpModule->getOrInsertNamedMetadata("#rs_export_func");
126
127    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
128
129    for (RSContext::const_export_func_iterator
130            I = mContext->export_funcs_begin(),
131            E = mContext->export_funcs_end();
132         I != E;
133         I++) {
134      const RSExportFunc *EF = *I;
135
136      // Function name
137      if (!EF->hasParam()) {
138        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
139                                                     EF->getName().c_str()));
140      } else {
141        llvm::Function *F = mpModule->getFunction(EF->getName());
142        llvm::Function *HelperFunction;
143        const std::string HelperFunctionName(".helper_" + EF->getName());
144
145        assert(F && "Function marked as exported disappeared in Bitcode");
146
147        // Create helper function
148        {
149          llvm::StructType *HelperFunctionParameterTy = NULL;
150
151          if (!F->getArgumentList().empty()) {
152            std::vector<const llvm::Type*> HelperFunctionParameterTys;
153            for (llvm::Function::arg_iterator AI = F->arg_begin(),
154                 AE = F->arg_end(); AI != AE; AI++)
155              HelperFunctionParameterTys.push_back(AI->getType());
156
157            HelperFunctionParameterTy =
158                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
159          }
160
161          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
162            fprintf(stderr, "Failed to export function %s: parameter type "
163                            "mismatch during creation of helper function.\n",
164                    EF->getName().c_str());
165
166            const RSExportRecordType *Expected = EF->getParamPacketType();
167            if (Expected) {
168              fprintf(stderr, "Expected:\n");
169              Expected->getLLVMType()->dump();
170            }
171            if (HelperFunctionParameterTy) {
172              fprintf(stderr, "Got:\n");
173              HelperFunctionParameterTy->dump();
174            }
175          }
176
177          std::vector<const llvm::Type*> Params;
178          if (HelperFunctionParameterTy) {
179            llvm::PointerType *HelperFunctionParameterTyP =
180                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
181            Params.push_back(HelperFunctionParameterTyP);
182          }
183
184          llvm::FunctionType * HelperFunctionType =
185              llvm::FunctionType::get(F->getReturnType(),
186                                      Params,
187                                      /* IsVarArgs = */false);
188
189          HelperFunction =
190              llvm::Function::Create(HelperFunctionType,
191                                     llvm::GlobalValue::ExternalLinkage,
192                                     HelperFunctionName,
193                                     mpModule);
194
195          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
196          HelperFunction->setCallingConv(F->getCallingConv());
197
198          // Create helper function body
199          {
200            llvm::Argument *HelperFunctionParameter =
201                &(*HelperFunction->arg_begin());
202            llvm::BasicBlock *BB =
203                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
204            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
205            llvm::SmallVector<llvm::Value*, 6> Params;
206            llvm::Value *Idx[2];
207
208            Idx[0] =
209                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
210
211            // getelementptr and load instruction for all elements in
212            // parameter .p
213            for (size_t i = 0; i < EF->getNumParameters(); i++) {
214              // getelementptr
215              Idx[1] =
216                  llvm::ConstantInt::get(
217                      llvm::Type::getInt32Ty(mLLVMContext), i);
218              llvm::Value *Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter,
219                                                       Idx,
220                                                       Idx + 2);
221
222              // load
223              llvm::Value *V = IB->CreateLoad(Ptr);
224              Params.push_back(V);
225            }
226
227            // Call and pass the all elements as paramter to F
228            llvm::CallInst *CI = IB->CreateCall(F,
229                                                Params.data(),
230                                                Params.data() + Params.size());
231
232            CI->setCallingConv(F->getCallingConv());
233
234            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
235              IB->CreateRetVoid();
236            else
237              IB->CreateRet(CI);
238
239            delete IB;
240          }
241        }
242
243        ExportFuncInfo.push_back(
244            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
245      }
246
247      mExportFuncMetadata->addOperand(
248          llvm::MDNode::get(mLLVMContext,
249                            ExportFuncInfo.data(),
250                            ExportFuncInfo.size()));
251
252      ExportFuncInfo.clear();
253    }
254  }
255
256  // Dump export type info
257  if (mContext->hasExportType()) {
258    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
259
260    for (RSContext::const_export_type_iterator
261            I = mContext->export_types_begin(),
262            E = mContext->export_types_end();
263         I != E;
264         I++) {
265      // First, dump type name list to export
266      const RSExportType *ET = I->getValue();
267
268      ExportTypeInfo.clear();
269      // Type name
270      ExportTypeInfo.push_back(
271          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
272
273      if (ET->getClass() == RSExportType::ExportClassRecord) {
274        const RSExportRecordType *ERT =
275            static_cast<const RSExportRecordType*>(ET);
276
277        if (mExportTypeMetadata == NULL)
278          mExportTypeMetadata =
279              mpModule->getOrInsertNamedMetadata("#rs_export_type");
280
281        mExportTypeMetadata->addOperand(
282            llvm::MDNode::get(mLLVMContext,
283                              ExportTypeInfo.data(),
284                              ExportTypeInfo.size()));
285
286        // Now, export struct field information to %[struct name]
287        std::string StructInfoMetadataName("%");
288        StructInfoMetadataName.append(ET->getName());
289        llvm::NamedMDNode *StructInfoMetadata =
290            mpModule->getOrInsertNamedMetadata(StructInfoMetadataName);
291        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
292
293        assert(StructInfoMetadata->getNumOperands() == 0 &&
294               "Metadata with same name was created before");
295        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
296                FE = ERT->fields_end();
297             FI != FE;
298             FI++) {
299          const RSExportRecordType::Field *F = *FI;
300
301          // 1. field name
302          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
303                                                  F->getName().c_str()));
304
305          // 2. field type name
306          FieldInfo.push_back(
307              llvm::MDString::get(mLLVMContext,
308                                  F->getType()->getName().c_str()));
309
310          // 3. field kind
311          switch (F->getType()->getClass()) {
312            case RSExportType::ExportClassPrimitive:
313            case RSExportType::ExportClassVector: {
314              const RSExportPrimitiveType *EPT =
315                  static_cast<const RSExportPrimitiveType*>(F->getType());
316              FieldInfo.push_back(
317                  llvm::MDString::get(mLLVMContext,
318                                      llvm::itostr(EPT->getKind())));
319              break;
320            }
321
322            default: {
323              FieldInfo.push_back(
324                  llvm::MDString::get(mLLVMContext,
325                                      llvm::itostr(
326                                        RSExportPrimitiveType::DataKindUser)));
327              break;
328            }
329          }
330
331          StructInfoMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
332                                                           FieldInfo.data(),
333                                                           FieldInfo.size()));
334
335          FieldInfo.clear();
336        }
337      }   // ET->getClass() == RSExportType::ExportClassRecord
338    }
339  }
340
341  return;
342}
343
344RSBackend::~RSBackend() {
345  return;
346}
347