// slang_rs_backend.cpp revision 0da0a7dc51c25943fe31d0bfccbdfee326a3199c
1#include "slang_rs_backend.h"
2
3#include <vector>
4#include <string>
5
6#include "llvm/Metadata.h"
7#include "llvm/Constant.h"
8#include "llvm/Constants.h"
9#include "llvm/Module.h"
10#include "llvm/Function.h"
11#include "llvm/DerivedTypes.h"
12
13#include "llvm/Support/IRBuilder.h"
14
15#include "llvm/ADT/Twine.h"
16#include "llvm/ADT/StringExtras.h"
17
18#include "clang/AST/DeclGroup.h"
19
20#include "slang_rs_context.h"
21#include "slang_rs_export_var.h"
22#include "slang_rs_export_func.h"
23#include "slang_rs_export_type.h"
24
25using namespace slang;
26
27RSBackend::RSBackend(RSContext *Context,
28                     clang::Diagnostic &Diags,
29                     const clang::CodeGenOptions &CodeGenOpts,
30                     const clang::TargetOptions &TargetOpts,
31                     const PragmaList &Pragmas,
32                     llvm::raw_ostream *OS,
33                     SlangCompilerOutputTy OutputType,
34                     clang::SourceManager &SourceMgr,
35                     bool AllowRSPrefix)
36    : Backend(Diags,
37              CodeGenOpts,
38              TargetOpts,
39              Pragmas,
40              OS,
41              OutputType,
42              SourceMgr,
43              AllowRSPrefix),
44      mContext(Context),
45      mExportVarMetadata(NULL),
46      mExportFuncMetadata(NULL),
47      mExportTypeMetadata(NULL) {
48  return;
49}
50
51void RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
52  Backend::HandleTopLevelDecl(D);
53  return;
54}
55
56void RSBackend::HandleTranslationUnitEx(clang::ASTContext &Ctx) {
57  assert((&Ctx == mContext->getASTContext()) && "Unexpected AST context change"
58                                                " during LLVM IR generation");
59  mContext->processExport();
60
61  // Dump export variable info
62  if (mContext->hasExportVar()) {
63    if (mExportVarMetadata == NULL)
64      mExportVarMetadata = mpModule->getOrInsertNamedMetadata("#rs_export_var");
65
66    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
67
68    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
69            E = mContext->export_vars_end();
70         I != E;
71         I++) {
72      const RSExportVar *EV = *I;
73      const RSExportType *ET = EV->getType();
74
75      // Variable name
76      ExportVarInfo.push_back(
77          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
78
79      // Type name
80      if (ET->getClass() == RSExportType::ExportClassPrimitive)
81        ExportVarInfo.push_back(
82            llvm::MDString::get(
83                mLLVMContext, llvm::utostr_32(
84                    static_cast<const RSExportPrimitiveType*>(ET)->getType())));
85      else if (ET->getClass() == RSExportType::ExportClassPointer)
86        ExportVarInfo.push_back(
87            llvm::MDString::get(
88                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
89                               ->getPointeeType()->getName()).c_str()));
90      else
91        ExportVarInfo.push_back(
92            llvm::MDString::get(mLLVMContext,
93                                EV->getType()->getName().c_str()));
94
95      mExportVarMetadata->addOperand(
96          llvm::MDNode::get(mLLVMContext,
97                            ExportVarInfo.data(),
98                            ExportVarInfo.size()) );
99
100      ExportVarInfo.clear();
101    }
102  }
103
104  // Dump export function info
105  if (mContext->hasExportFunc()) {
106    if (mExportFuncMetadata == NULL)
107      mExportFuncMetadata =
108          mpModule->getOrInsertNamedMetadata("#rs_export_func");
109
110    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
111
112    for (RSContext::const_export_func_iterator
113            I = mContext->export_funcs_begin(),
114            E = mContext->export_funcs_end();
115         I != E;
116         I++) {
117      const RSExportFunc *EF = *I;
118
119      // Function name
120      if (!EF->hasParam()) {
121        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
122                                                     EF->getName().c_str()));
123      } else {
124        llvm::Function *F = mpModule->getFunction(EF->getName());
125        llvm::Function *HelperFunction;
126        const std::string HelperFunctionName(".helper_" + EF->getName());
127
128        assert(F && "Function marked as exported disappeared in Bitcode");
129
130        // Create helper function
131        {
132          llvm::StructType *HelperFunctionParameterTy = NULL;
133
134          if (!F->getArgumentList().empty()) {
135            std::vector<const llvm::Type*> HelperFunctionParameterTys;
136            for (llvm::Function::arg_iterator AI = F->arg_begin(),
137                 AE = F->arg_end(); AI != AE; AI++)
138              HelperFunctionParameterTys.push_back(AI->getType());
139
140            HelperFunctionParameterTy =
141                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
142          }
143
144          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
145            fprintf(stderr, "Failed to export function %s: parameter type "
146                            "mismatch during creation of helper function.\n",
147                    EF->getName().c_str());
148
149            const RSExportRecordType *Expected = EF->getParamPacketType();
150            if (Expected) {
151              fprintf(stderr, "Expected:\n");
152              Expected->getLLVMType()->dump();
153            }
154            if (HelperFunctionParameterTy) {
155              fprintf(stderr, "Got:\n");
156              HelperFunctionParameterTy->dump();
157            }
158          }
159
160          std::vector<const llvm::Type*> Params;
161          if (HelperFunctionParameterTy) {
162            llvm::PointerType *HelperFunctionParameterTyP =
163                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
164            Params.push_back(HelperFunctionParameterTyP);
165          }
166
167          llvm::FunctionType * HelperFunctionType =
168              llvm::FunctionType::get(F->getReturnType(),
169                                      Params,
170                                      /* IsVarArgs = */false);
171
172          HelperFunction =
173              llvm::Function::Create(HelperFunctionType,
174                                     llvm::GlobalValue::ExternalLinkage,
175                                     HelperFunctionName,
176                                     mpModule);
177
178          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
179          HelperFunction->setCallingConv(F->getCallingConv());
180
181          // Create helper function body
182          {
183            llvm::Argument *HelperFunctionParameter =
184                &(*HelperFunction->arg_begin());
185            llvm::BasicBlock *BB =
186                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
187            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
188            llvm::SmallVector<llvm::Value*, 6> Params;
189            llvm::Value *Idx[2];
190
191            Idx[0] =
192                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
193
194            // getelementptr and load instruction for all elements in
195            // parameter .p
196            for (size_t i = 0; i < EF->getNumParameters(); i++) {
197              // getelementptr
198              Idx[1] =
199                  llvm::ConstantInt::get(
200                      llvm::Type::getInt32Ty(mLLVMContext), i);
201              llvm::Value *Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter,
202                                                       Idx,
203                                                       Idx + 2);
204
205              // load
206              llvm::Value *V = IB->CreateLoad(Ptr);
207              Params.push_back(V);
208            }
209
210            // Call and pass the all elements as paramter to F
211            llvm::CallInst *CI = IB->CreateCall(F,
212                                                Params.data(),
213                                                Params.data() + Params.size());
214
215            CI->setCallingConv(F->getCallingConv());
216
217            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
218              IB->CreateRetVoid();
219            else
220              IB->CreateRet(CI);
221
222            delete IB;
223          }
224        }
225
226        ExportFuncInfo.push_back(
227            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
228      }
229
230      mExportFuncMetadata->addOperand(
231          llvm::MDNode::get(mLLVMContext,
232                            ExportFuncInfo.data(),
233                            ExportFuncInfo.size()));
234
235      ExportFuncInfo.clear();
236    }
237  }
238
239  // Dump export type info
240  if (mContext->hasExportType()) {
241    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
242
243    for (RSContext::const_export_type_iterator
244            I = mContext->export_types_begin(),
245            E = mContext->export_types_end();
246         I != E;
247         I++) {
248      // First, dump type name list to export
249      const RSExportType *ET = I->getValue();
250
251      ExportTypeInfo.clear();
252      // Type name
253      ExportTypeInfo.push_back(
254          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
255
256      if (ET->getClass() == RSExportType::ExportClassRecord) {
257        const RSExportRecordType *ERT =
258            static_cast<const RSExportRecordType*>(ET);
259
260        if (mExportTypeMetadata == NULL)
261          mExportTypeMetadata =
262              mpModule->getOrInsertNamedMetadata("#rs_export_type");
263
264        mExportTypeMetadata->addOperand(
265            llvm::MDNode::get(mLLVMContext,
266                              ExportTypeInfo.data(),
267                              ExportTypeInfo.size()));
268
269        // Now, export struct field information to %[struct name]
270        std::string StructInfoMetadataName("%");
271        StructInfoMetadataName.append(ET->getName());
272        llvm::NamedMDNode *StructInfoMetadata =
273            mpModule->getOrInsertNamedMetadata(StructInfoMetadataName);
274        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
275
276        assert(StructInfoMetadata->getNumOperands() == 0 &&
277               "Metadata with same name was created before");
278        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
279                FE = ERT->fields_end();
280             FI != FE;
281             FI++) {
282          const RSExportRecordType::Field *F = *FI;
283
284          // 1. field name
285          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
286                                                  F->getName().c_str()));
287
288          // 2. field type name
289          FieldInfo.push_back(
290              llvm::MDString::get(mLLVMContext,
291                                  F->getType()->getName().c_str()));
292
293          // 3. field kind
294          switch (F->getType()->getClass()) {
295            case RSExportType::ExportClassPrimitive:
296            case RSExportType::ExportClassVector: {
297              const RSExportPrimitiveType *EPT =
298                  static_cast<const RSExportPrimitiveType*>(F->getType());
299              FieldInfo.push_back(
300                  llvm::MDString::get(mLLVMContext,
301                                      llvm::itostr(EPT->getKind())));
302              break;
303            }
304
305            default: {
306              FieldInfo.push_back(
307                  llvm::MDString::get(mLLVMContext,
308                                      llvm::itostr(
309                                        RSExportPrimitiveType::DataKindUser)));
310              break;
311            }
312          }
313
314          StructInfoMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
315                                                           FieldInfo.data(),
316                                                           FieldInfo.size()));
317
318          FieldInfo.clear();
319        }
320      }   // ET->getClass() == RSExportType::ExportClassRecord
321    }
322  }
323
324  return;
325}
326
327RSBackend::~RSBackend() {
328  return;
329}
330