slang_rs_backend.cpp revision 7b51b55e4467605a599e868a0dde7cb95c5ab76e
1/*
2 * Copyright 2010-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
19#include <string>
20#include <vector>
21
22#include "llvm/ADT/Twine.h"
23#include "llvm/ADT/StringExtras.h"
24
25#include "llvm/Constant.h"
26#include "llvm/Constants.h"
27#include "llvm/DerivedTypes.h"
28#include "llvm/Function.h"
29#include "llvm/Metadata.h"
30#include "llvm/Module.h"
31
32#include "llvm/Support/IRBuilder.h"
33
34#include "slang_assert.h"
35#include "slang_rs.h"
36#include "slang_rs_context.h"
37#include "slang_rs_export_foreach.h"
38#include "slang_rs_export_func.h"
39#include "slang_rs_export_type.h"
40#include "slang_rs_export_var.h"
41#include "slang_rs_metadata.h"
42
43namespace slang {
44
45RSBackend::RSBackend(RSContext *Context,
46                     clang::DiagnosticsEngine *DiagEngine,
47                     const clang::CodeGenOptions &CodeGenOpts,
48                     const clang::TargetOptions &TargetOpts,
49                     PragmaList *Pragmas,
50                     llvm::raw_ostream *OS,
51                     Slang::OutputType OT,
52                     clang::SourceManager &SourceMgr,
53                     bool AllowRSPrefix)
54  : Backend(DiagEngine, CodeGenOpts, TargetOpts, Pragmas, OS, OT),
55    mContext(Context),
56    mSourceMgr(SourceMgr),
57    mAllowRSPrefix(AllowRSPrefix),
58    mExportVarMetadata(NULL),
59    mExportFuncMetadata(NULL),
60    mExportForEachNameMetadata(NULL),
61    mExportForEachSignatureMetadata(NULL),
62    mExportTypeMetadata(NULL),
63    mRSObjectSlotsMetadata(NULL),
64    mRefCount(mContext->getASTContext()) {
65}
66
67// 1) Add zero initialization of local RS object types
68void RSBackend::AnnotateFunction(clang::FunctionDecl *FD) {
69  if (FD &&
70      FD->hasBody() &&
71      !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr)) {
72    mRefCount.Init();
73    mRefCount.Visit(FD->getBody());
74  }
75  return;
76}
77
78bool RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
79  // Disallow user-defined functions with prefix "rs"
80  if (!mAllowRSPrefix) {
81    // Iterate all function declarations in the program.
82    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
83         I != E; I++) {
84      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
85      if (FD == NULL)
86        continue;
87      if (!FD->getName().startswith("rs"))  // Check prefix
88        continue;
89      if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
90        mDiagEngine.Report(
91          clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
92          mDiagEngine.getCustomDiagID(clang::DiagnosticsEngine::Error,
93                                      "invalid function name prefix, "
94                                      "\"rs\" is reserved: '%0'"))
95          << FD->getName();
96    }
97  }
98
99  // Process any non-static function declarations
100  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
101    clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
102    if (FD && FD->isGlobal()) {
103      AnnotateFunction(FD);
104    }
105  }
106
107  return Backend::HandleTopLevelDecl(D);
108}
109
110namespace {
111
112static bool ValidateVarDecl(clang::VarDecl *VD) {
113  if (!VD) {
114    return true;
115  }
116
117  clang::ASTContext &C = VD->getASTContext();
118  const clang::Type *T = VD->getType().getTypePtr();
119  bool valid = true;
120
121  if (VD->getLinkage() == clang::ExternalLinkage) {
122    llvm::StringRef TypeName;
123    if (!RSExportType::NormalizeType(T, TypeName, &C.getDiagnostics(), VD)) {
124      valid = false;
125    }
126  }
127  valid &= RSExportType::ValidateVarDecl(VD);
128
129  return valid;
130}
131
132static bool ValidateASTContext(clang::ASTContext &C) {
133  bool valid = true;
134  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
135  for (clang::DeclContext::decl_iterator DI = TUDecl->decls_begin(),
136          DE = TUDecl->decls_end();
137       DI != DE;
138       DI++) {
139    clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*DI);
140    if (VD && !ValidateVarDecl(VD)) {
141      valid = false;
142    }
143  }
144
145  return valid;
146}
147
148}  // namespace
149
150void RSBackend::HandleTranslationUnitPre(clang::ASTContext &C) {
151  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
152
153  if (!ValidateASTContext(C)) {
154    return;
155  }
156
157  int version = mContext->getVersion();
158  if (version == 0) {
159    // Not setting a version is an error
160    mDiagEngine.Report(
161        mSourceMgr.getLocForEndOfFile(mSourceMgr.getMainFileID()),
162        mDiagEngine.getCustomDiagID(
163            clang::DiagnosticsEngine::Error,
164            "missing pragma for version in source file"));
165  } else {
166    slangAssert(version == 1);
167  }
168
169  // Create a static global destructor if necessary (to handle RS object
170  // runtime cleanup).
171  clang::FunctionDecl *FD = mRefCount.CreateStaticGlobalDtor();
172  if (FD) {
173    HandleTopLevelDecl(clang::DeclGroupRef(FD));
174  }
175
176  // Process any static function declarations
177  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
178          E = TUDecl->decls_end(); I != E; I++) {
179    if ((I->getKind() >= clang::Decl::firstFunction) &&
180        (I->getKind() <= clang::Decl::lastFunction)) {
181      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
182      if (FD && !FD->isGlobal()) {
183        AnnotateFunction(FD);
184      }
185    }
186  }
187
188  return;
189}
190
191///////////////////////////////////////////////////////////////////////////////
192void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
193  if (!mContext->processExport()) {
194    return;
195  }
196
197  // Dump export variable info
198  if (mContext->hasExportVar()) {
199    int slotCount = 0;
200    if (mExportVarMetadata == NULL)
201      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
202
203    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
204
205    // We emit slot information (#rs_object_slots) for any reference counted
206    // RS type or pointer (which can also be bound).
207
208    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
209            E = mContext->export_vars_end();
210         I != E;
211         I++) {
212      const RSExportVar *EV = *I;
213      const RSExportType *ET = EV->getType();
214      bool countsAsRSObject = false;
215
216      // Variable name
217      ExportVarInfo.push_back(
218          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
219
220      // Type name
221      switch (ET->getClass()) {
222        case RSExportType::ExportClassPrimitive: {
223          const RSExportPrimitiveType *PT =
224              static_cast<const RSExportPrimitiveType*>(ET);
225          ExportVarInfo.push_back(
226              llvm::MDString::get(
227                mLLVMContext, llvm::utostr_32(PT->getType())));
228          if (PT->isRSObjectType()) {
229            countsAsRSObject = true;
230          }
231          break;
232        }
233        case RSExportType::ExportClassPointer: {
234          ExportVarInfo.push_back(
235              llvm::MDString::get(
236                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
237                  ->getPointeeType()->getName()).c_str()));
238          break;
239        }
240        case RSExportType::ExportClassMatrix: {
241          ExportVarInfo.push_back(
242              llvm::MDString::get(
243                mLLVMContext, llvm::utostr_32(
244                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
245                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
246          break;
247        }
248        case RSExportType::ExportClassVector:
249        case RSExportType::ExportClassConstantArray:
250        case RSExportType::ExportClassRecord: {
251          ExportVarInfo.push_back(
252              llvm::MDString::get(mLLVMContext,
253                EV->getType()->getName().c_str()));
254          break;
255        }
256      }
257
258      mExportVarMetadata->addOperand(
259          llvm::MDNode::get(mLLVMContext, ExportVarInfo));
260      ExportVarInfo.clear();
261
262      if (mRSObjectSlotsMetadata == NULL) {
263        mRSObjectSlotsMetadata =
264            M->getOrInsertNamedMetadata(RS_OBJECT_SLOTS_MN);
265      }
266
267      if (countsAsRSObject) {
268        mRSObjectSlotsMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
269            llvm::MDString::get(mLLVMContext, llvm::utostr_32(slotCount))));
270      }
271
272      slotCount++;
273    }
274  }
275
276  // Dump export function info
277  if (mContext->hasExportFunc()) {
278    if (mExportFuncMetadata == NULL)
279      mExportFuncMetadata =
280          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
281
282    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
283
284    for (RSContext::const_export_func_iterator
285            I = mContext->export_funcs_begin(),
286            E = mContext->export_funcs_end();
287         I != E;
288         I++) {
289      const RSExportFunc *EF = *I;
290
291      // Function name
292      if (!EF->hasParam()) {
293        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
294                                                     EF->getName().c_str()));
295      } else {
296        llvm::Function *F = M->getFunction(EF->getName());
297        llvm::Function *HelperFunction;
298        const std::string HelperFunctionName(".helper_" + EF->getName());
299
300        slangAssert(F && "Function marked as exported disappeared in Bitcode");
301
302        // Create helper function
303        {
304          llvm::StructType *HelperFunctionParameterTy = NULL;
305
306          if (!F->getArgumentList().empty()) {
307            std::vector<llvm::Type*> HelperFunctionParameterTys;
308            for (llvm::Function::arg_iterator AI = F->arg_begin(),
309                 AE = F->arg_end(); AI != AE; AI++)
310              HelperFunctionParameterTys.push_back(AI->getType());
311
312            HelperFunctionParameterTy =
313                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
314          }
315
316          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
317            fprintf(stderr, "Failed to export function %s: parameter type "
318                            "mismatch during creation of helper function.\n",
319                    EF->getName().c_str());
320
321            const RSExportRecordType *Expected = EF->getParamPacketType();
322            if (Expected) {
323              fprintf(stderr, "Expected:\n");
324              Expected->getLLVMType()->dump();
325            }
326            if (HelperFunctionParameterTy) {
327              fprintf(stderr, "Got:\n");
328              HelperFunctionParameterTy->dump();
329            }
330          }
331
332          std::vector<llvm::Type*> Params;
333          if (HelperFunctionParameterTy) {
334            llvm::PointerType *HelperFunctionParameterTyP =
335                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
336            Params.push_back(HelperFunctionParameterTyP);
337          }
338
339          llvm::FunctionType * HelperFunctionType =
340              llvm::FunctionType::get(F->getReturnType(),
341                                      Params,
342                                      /* IsVarArgs = */false);
343
344          HelperFunction =
345              llvm::Function::Create(HelperFunctionType,
346                                     llvm::GlobalValue::ExternalLinkage,
347                                     HelperFunctionName,
348                                     M);
349
350          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
351          HelperFunction->setCallingConv(F->getCallingConv());
352
353          // Create helper function body
354          {
355            llvm::Argument *HelperFunctionParameter =
356                &(*HelperFunction->arg_begin());
357            llvm::BasicBlock *BB =
358                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
359            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
360            llvm::SmallVector<llvm::Value*, 6> Params;
361            llvm::Value *Idx[2];
362
363            Idx[0] =
364                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
365
366            // getelementptr and load instruction for all elements in
367            // parameter .p
368            for (size_t i = 0; i < EF->getNumParameters(); i++) {
369              // getelementptr
370              Idx[1] = llvm::ConstantInt::get(
371                llvm::Type::getInt32Ty(mLLVMContext), i);
372
373              llvm::Value *Ptr =
374                IB->CreateInBoundsGEP(HelperFunctionParameter, Idx);
375
376              // load
377              llvm::Value *V = IB->CreateLoad(Ptr);
378              Params.push_back(V);
379            }
380
381            // Call and pass the all elements as parameter to F
382            llvm::CallInst *CI = IB->CreateCall(F, Params);
383
384            CI->setCallingConv(F->getCallingConv());
385
386            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
387              IB->CreateRetVoid();
388            else
389              IB->CreateRet(CI);
390
391            delete IB;
392          }
393        }
394
395        ExportFuncInfo.push_back(
396            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
397      }
398
399      mExportFuncMetadata->addOperand(
400          llvm::MDNode::get(mLLVMContext, ExportFuncInfo));
401      ExportFuncInfo.clear();
402    }
403  }
404
405  // Dump export function info
406  if (mContext->hasExportForEach()) {
407    if (mExportForEachNameMetadata == NULL) {
408      mExportForEachNameMetadata =
409          M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_NAME_MN);
410    }
411    if (mExportForEachSignatureMetadata == NULL) {
412      mExportForEachSignatureMetadata =
413          M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_MN);
414    }
415
416    llvm::SmallVector<llvm::Value*, 1> ExportForEachName;
417    llvm::SmallVector<llvm::Value*, 1> ExportForEachInfo;
418
419    for (RSContext::const_export_foreach_iterator
420            I = mContext->export_foreach_begin(),
421            E = mContext->export_foreach_end();
422         I != E;
423         I++) {
424      const RSExportForEach *EFE = *I;
425
426      ExportForEachName.push_back(
427          llvm::MDString::get(mLLVMContext, EFE->getName().c_str()));
428
429      mExportForEachNameMetadata->addOperand(
430          llvm::MDNode::get(mLLVMContext, ExportForEachName));
431      ExportForEachName.clear();
432
433      ExportForEachInfo.push_back(
434          llvm::MDString::get(mLLVMContext,
435                              llvm::utostr_32(EFE->getSignatureMetadata())));
436
437      mExportForEachSignatureMetadata->addOperand(
438          llvm::MDNode::get(mLLVMContext, ExportForEachInfo));
439      ExportForEachInfo.clear();
440    }
441  }
442
443  // Dump export type info
444  if (mContext->hasExportType()) {
445    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
446
447    for (RSContext::const_export_type_iterator
448            I = mContext->export_types_begin(),
449            E = mContext->export_types_end();
450         I != E;
451         I++) {
452      // First, dump type name list to export
453      const RSExportType *ET = I->getValue();
454
455      ExportTypeInfo.clear();
456      // Type name
457      ExportTypeInfo.push_back(
458          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
459
460      if (ET->getClass() == RSExportType::ExportClassRecord) {
461        const RSExportRecordType *ERT =
462            static_cast<const RSExportRecordType*>(ET);
463
464        if (mExportTypeMetadata == NULL)
465          mExportTypeMetadata =
466              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
467
468        mExportTypeMetadata->addOperand(
469            llvm::MDNode::get(mLLVMContext, ExportTypeInfo));
470
471        // Now, export struct field information to %[struct name]
472        std::string StructInfoMetadataName("%");
473        StructInfoMetadataName.append(ET->getName());
474        llvm::NamedMDNode *StructInfoMetadata =
475            M->getOrInsertNamedMetadata(StructInfoMetadataName);
476        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
477
478        slangAssert(StructInfoMetadata->getNumOperands() == 0 &&
479                    "Metadata with same name was created before");
480        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
481                FE = ERT->fields_end();
482             FI != FE;
483             FI++) {
484          const RSExportRecordType::Field *F = *FI;
485
486          // 1. field name
487          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
488                                                  F->getName().c_str()));
489
490          // 2. field type name
491          FieldInfo.push_back(
492              llvm::MDString::get(mLLVMContext,
493                                  F->getType()->getName().c_str()));
494
495          // 3. field kind
496          switch (F->getType()->getClass()) {
497            case RSExportType::ExportClassPrimitive:
498            case RSExportType::ExportClassVector: {
499              const RSExportPrimitiveType *EPT =
500                  static_cast<const RSExportPrimitiveType*>(F->getType());
501              FieldInfo.push_back(
502                  llvm::MDString::get(mLLVMContext,
503                                      llvm::itostr(EPT->getKind())));
504              break;
505            }
506
507            default: {
508              FieldInfo.push_back(
509                  llvm::MDString::get(mLLVMContext,
510                                      llvm::itostr(
511                                        RSExportPrimitiveType::DataKindUser)));
512              break;
513            }
514          }
515
516          StructInfoMetadata->addOperand(
517              llvm::MDNode::get(mLLVMContext, FieldInfo));
518          FieldInfo.clear();
519        }
520      }   // ET->getClass() == RSExportType::ExportClassRecord
521    }
522  }
523
524  return;
525}
526
527RSBackend::~RSBackend() {
528  return;
529}
530
531}  // namespace slang
532