slang_rs_backend.cpp revision 9207a2e495c8363606861e4f034504ec5c153dab
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
19#include <string>
20#include <vector>
21
22#include "llvm/ADT/Twine.h"
23#include "llvm/ADT/StringExtras.h"
24
25#include "llvm/Constant.h"
26#include "llvm/Constants.h"
27#include "llvm/DerivedTypes.h"
28#include "llvm/Function.h"
29#include "llvm/Metadata.h"
30#include "llvm/Module.h"
31
32#include "llvm/Support/IRBuilder.h"
33
34#include "slang_assert.h"
35#include "slang_rs.h"
36#include "slang_rs_context.h"
37#include "slang_rs_export_foreach.h"
38#include "slang_rs_export_func.h"
39#include "slang_rs_export_type.h"
40#include "slang_rs_export_var.h"
41#include "slang_rs_metadata.h"
42
43namespace slang {
44
45RSBackend::RSBackend(RSContext *Context,
46                     clang::DiagnosticsEngine *DiagEngine,
47                     const clang::CodeGenOptions &CodeGenOpts,
48                     const clang::TargetOptions &TargetOpts,
49                     PragmaList *Pragmas,
50                     llvm::raw_ostream *OS,
51                     Slang::OutputType OT,
52                     clang::SourceManager &SourceMgr,
53                     bool AllowRSPrefix)
54  : Backend(DiagEngine, CodeGenOpts, TargetOpts, Pragmas, OS, OT),
55    mContext(Context),
56    mSourceMgr(SourceMgr),
57    mAllowRSPrefix(AllowRSPrefix),
58    mExportVarMetadata(NULL),
59    mExportFuncMetadata(NULL),
60    mExportForEachMetadata(NULL),
61    mExportTypeMetadata(NULL),
62    mRSObjectSlotsMetadata(NULL),
63    mRefCount(mContext->getASTContext()) {
64}
65
66// 1) Add zero initialization of local RS object types
67void RSBackend::AnnotateFunction(clang::FunctionDecl *FD) {
68  if (FD &&
69      FD->hasBody() &&
70      !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr)) {
71    mRefCount.Init();
72    mRefCount.Visit(FD->getBody());
73  }
74  return;
75}
76
77void RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
78  // Disallow user-defined functions with prefix "rs"
79  if (!mAllowRSPrefix) {
80    // Iterate all function declarations in the program.
81    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
82         I != E; I++) {
83      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
84      if (FD == NULL)
85        continue;
86      if (!FD->getName().startswith("rs"))  // Check prefix
87        continue;
88      if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
89        mDiagEngine.Report(
90          clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
91          mDiagEngine.getCustomDiagID(clang::DiagnosticsEngine::Error,
92                                      "invalid function name prefix, "
93                                      "\"rs\" is reserved: '%0'"))
94          << FD->getName();
95    }
96  }
97
98  // Process any non-static function declarations
99  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
100    clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
101    if (FD && FD->isGlobal()) {
102      AnnotateFunction(FD);
103    }
104  }
105
106  Backend::HandleTopLevelDecl(D);
107  return;
108}
109
110namespace {
111
112static bool ValidateVarDecl(clang::VarDecl *VD) {
113  if (!VD) {
114    return true;
115  }
116
117  clang::ASTContext &C = VD->getASTContext();
118  const clang::Type *T = VD->getType().getTypePtr();
119  bool valid = true;
120
121  if (VD->getLinkage() == clang::ExternalLinkage) {
122    llvm::StringRef TypeName;
123    if (!RSExportType::NormalizeType(T, TypeName, &C.getDiagnostics(), VD)) {
124      valid = false;
125    }
126  }
127  valid &= RSExportType::ValidateVarDecl(VD);
128
129  return valid;
130}
131
132static bool ValidateASTContext(clang::ASTContext &C) {
133  bool valid = true;
134  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
135  for (clang::DeclContext::decl_iterator DI = TUDecl->decls_begin(),
136          DE = TUDecl->decls_end();
137       DI != DE;
138       DI++) {
139    clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*DI);
140    if (VD && !ValidateVarDecl(VD)) {
141      valid = false;
142    }
143  }
144
145  return valid;
146}
147
148}  // namespace
149
150void RSBackend::HandleTranslationUnitPre(clang::ASTContext &C) {
151  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
152
153  if (!ValidateASTContext(C)) {
154    return;
155  }
156
157  int version = mContext->getVersion();
158  if (version == 0) {
159    // Not setting a version is an error
160    mDiagEngine.Report(mDiagEngine.getCustomDiagID(
161      clang::DiagnosticsEngine::Error,
162      "Missing pragma for version in source file"));
163  } else if (version > 1) {
164    mDiagEngine.Report(mDiagEngine.getCustomDiagID(
165      clang::DiagnosticsEngine::Error,
166      "Pragma for version in source file must be set to 1"));
167  }
168
169  // Create a static global destructor if necessary (to handle RS object
170  // runtime cleanup).
171  clang::FunctionDecl *FD = mRefCount.CreateStaticGlobalDtor();
172  if (FD) {
173    HandleTopLevelDecl(clang::DeclGroupRef(FD));
174  }
175
176  // Process any static function declarations
177  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
178          E = TUDecl->decls_end(); I != E; I++) {
179    if ((I->getKind() >= clang::Decl::firstFunction) &&
180        (I->getKind() <= clang::Decl::lastFunction)) {
181      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
182      if (FD && !FD->isGlobal()) {
183        AnnotateFunction(FD);
184      }
185    }
186  }
187
188  return;
189}
190
191///////////////////////////////////////////////////////////////////////////////
192void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
193  if (!mContext->processExport()) {
194    return;
195  }
196
197  // Dump export variable info
198  if (mContext->hasExportVar()) {
199    int slotCount = 0;
200    if (mExportVarMetadata == NULL)
201      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
202
203    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
204
205    // We emit slot information (#rs_object_slots) for any reference counted
206    // RS type or pointer (which can also be bound).
207
208    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
209            E = mContext->export_vars_end();
210         I != E;
211         I++) {
212      const RSExportVar *EV = *I;
213      const RSExportType *ET = EV->getType();
214      bool countsAsRSObject = false;
215
216      // Variable name
217      ExportVarInfo.push_back(
218          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
219
220      // Type name
221      switch (ET->getClass()) {
222        case RSExportType::ExportClassPrimitive: {
223          const RSExportPrimitiveType *PT =
224              static_cast<const RSExportPrimitiveType*>(ET);
225          ExportVarInfo.push_back(
226              llvm::MDString::get(
227                mLLVMContext, llvm::utostr_32(PT->getType())));
228          if (PT->isRSObjectType()) {
229            countsAsRSObject = true;
230          }
231          break;
232        }
233        case RSExportType::ExportClassPointer: {
234          ExportVarInfo.push_back(
235              llvm::MDString::get(
236                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
237                  ->getPointeeType()->getName()).c_str()));
238          break;
239        }
240        case RSExportType::ExportClassMatrix: {
241          ExportVarInfo.push_back(
242              llvm::MDString::get(
243                mLLVMContext, llvm::utostr_32(
244                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
245                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
246          break;
247        }
248        case RSExportType::ExportClassVector:
249        case RSExportType::ExportClassConstantArray:
250        case RSExportType::ExportClassRecord: {
251          ExportVarInfo.push_back(
252              llvm::MDString::get(mLLVMContext,
253                EV->getType()->getName().c_str()));
254          break;
255        }
256      }
257
258      mExportVarMetadata->addOperand(
259          llvm::MDNode::get(mLLVMContext, ExportVarInfo));
260      ExportVarInfo.clear();
261
262      if (mRSObjectSlotsMetadata == NULL) {
263        mRSObjectSlotsMetadata =
264            M->getOrInsertNamedMetadata(RS_OBJECT_SLOTS_MN);
265      }
266
267      if (countsAsRSObject) {
268        mRSObjectSlotsMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
269            llvm::MDString::get(mLLVMContext, llvm::utostr_32(slotCount))));
270      }
271
272      slotCount++;
273    }
274  }
275
276  // Dump export function info
277  if (mContext->hasExportFunc()) {
278    if (mExportFuncMetadata == NULL)
279      mExportFuncMetadata =
280          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
281
282    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
283
284    for (RSContext::const_export_func_iterator
285            I = mContext->export_funcs_begin(),
286            E = mContext->export_funcs_end();
287         I != E;
288         I++) {
289      const RSExportFunc *EF = *I;
290
291      // Function name
292      if (!EF->hasParam()) {
293        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
294                                                     EF->getName().c_str()));
295      } else {
296        llvm::Function *F = M->getFunction(EF->getName());
297        llvm::Function *HelperFunction;
298        const std::string HelperFunctionName(".helper_" + EF->getName());
299
300        slangAssert(F && "Function marked as exported disappeared in Bitcode");
301
302        // Create helper function
303        {
304          llvm::StructType *HelperFunctionParameterTy = NULL;
305
306          if (!F->getArgumentList().empty()) {
307            std::vector<llvm::Type*> HelperFunctionParameterTys;
308            for (llvm::Function::arg_iterator AI = F->arg_begin(),
309                 AE = F->arg_end(); AI != AE; AI++)
310              HelperFunctionParameterTys.push_back(AI->getType());
311
312            HelperFunctionParameterTy =
313                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
314          }
315
316          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
317            fprintf(stderr, "Failed to export function %s: parameter type "
318                            "mismatch during creation of helper function.\n",
319                    EF->getName().c_str());
320
321            const RSExportRecordType *Expected = EF->getParamPacketType();
322            if (Expected) {
323              fprintf(stderr, "Expected:\n");
324              Expected->getLLVMType()->dump();
325            }
326            if (HelperFunctionParameterTy) {
327              fprintf(stderr, "Got:\n");
328              HelperFunctionParameterTy->dump();
329            }
330          }
331
332          std::vector<llvm::Type*> Params;
333          if (HelperFunctionParameterTy) {
334            llvm::PointerType *HelperFunctionParameterTyP =
335                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
336            Params.push_back(HelperFunctionParameterTyP);
337          }
338
339          llvm::FunctionType * HelperFunctionType =
340              llvm::FunctionType::get(F->getReturnType(),
341                                      Params,
342                                      /* IsVarArgs = */false);
343
344          HelperFunction =
345              llvm::Function::Create(HelperFunctionType,
346                                     llvm::GlobalValue::ExternalLinkage,
347                                     HelperFunctionName,
348                                     M);
349
350          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
351          HelperFunction->setCallingConv(F->getCallingConv());
352
353          // Create helper function body
354          {
355            llvm::Argument *HelperFunctionParameter =
356                &(*HelperFunction->arg_begin());
357            llvm::BasicBlock *BB =
358                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
359            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
360            llvm::SmallVector<llvm::Value*, 6> Params;
361            llvm::Value *Idx[2];
362
363            Idx[0] =
364                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
365
366            // getelementptr and load instruction for all elements in
367            // parameter .p
368            for (size_t i = 0; i < EF->getNumParameters(); i++) {
369              // getelementptr
370              Idx[1] = llvm::ConstantInt::get(
371                llvm::Type::getInt32Ty(mLLVMContext), i);
372
373              llvm::Value *Ptr =
374                IB->CreateInBoundsGEP(HelperFunctionParameter, Idx);
375
376              // load
377              llvm::Value *V = IB->CreateLoad(Ptr);
378              Params.push_back(V);
379            }
380
381            // Call and pass the all elements as parameter to F
382            llvm::CallInst *CI = IB->CreateCall(F, Params);
383
384            CI->setCallingConv(F->getCallingConv());
385
386            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
387              IB->CreateRetVoid();
388            else
389              IB->CreateRet(CI);
390
391            delete IB;
392          }
393        }
394
395        ExportFuncInfo.push_back(
396            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
397      }
398
399      mExportFuncMetadata->addOperand(
400          llvm::MDNode::get(mLLVMContext, ExportFuncInfo));
401      ExportFuncInfo.clear();
402    }
403  }
404
405  // Dump export function info
406  if (mContext->hasExportForEach()) {
407    if (mExportForEachMetadata == NULL)
408      mExportForEachMetadata =
409          M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_MN);
410
411    llvm::SmallVector<llvm::Value*, 1> ExportForEachInfo;
412
413    for (RSContext::const_export_foreach_iterator
414            I = mContext->export_foreach_begin(),
415            E = mContext->export_foreach_end();
416         I != E;
417         I++) {
418      const RSExportForEach *EFE = *I;
419
420      ExportForEachInfo.push_back(
421          llvm::MDString::get(mLLVMContext,
422                              llvm::utostr_32(EFE->getMetadataEncoding())));
423
424      mExportForEachMetadata->addOperand(
425          llvm::MDNode::get(mLLVMContext, ExportForEachInfo));
426      ExportForEachInfo.clear();
427    }
428  }
429
430  // Dump export type info
431  if (mContext->hasExportType()) {
432    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
433
434    for (RSContext::const_export_type_iterator
435            I = mContext->export_types_begin(),
436            E = mContext->export_types_end();
437         I != E;
438         I++) {
439      // First, dump type name list to export
440      const RSExportType *ET = I->getValue();
441
442      ExportTypeInfo.clear();
443      // Type name
444      ExportTypeInfo.push_back(
445          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
446
447      if (ET->getClass() == RSExportType::ExportClassRecord) {
448        const RSExportRecordType *ERT =
449            static_cast<const RSExportRecordType*>(ET);
450
451        if (mExportTypeMetadata == NULL)
452          mExportTypeMetadata =
453              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
454
455        mExportTypeMetadata->addOperand(
456            llvm::MDNode::get(mLLVMContext, ExportTypeInfo));
457
458        // Now, export struct field information to %[struct name]
459        std::string StructInfoMetadataName("%");
460        StructInfoMetadataName.append(ET->getName());
461        llvm::NamedMDNode *StructInfoMetadata =
462            M->getOrInsertNamedMetadata(StructInfoMetadataName);
463        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
464
465        slangAssert(StructInfoMetadata->getNumOperands() == 0 &&
466                    "Metadata with same name was created before");
467        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
468                FE = ERT->fields_end();
469             FI != FE;
470             FI++) {
471          const RSExportRecordType::Field *F = *FI;
472
473          // 1. field name
474          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
475                                                  F->getName().c_str()));
476
477          // 2. field type name
478          FieldInfo.push_back(
479              llvm::MDString::get(mLLVMContext,
480                                  F->getType()->getName().c_str()));
481
482          // 3. field kind
483          switch (F->getType()->getClass()) {
484            case RSExportType::ExportClassPrimitive:
485            case RSExportType::ExportClassVector: {
486              const RSExportPrimitiveType *EPT =
487                  static_cast<const RSExportPrimitiveType*>(F->getType());
488              FieldInfo.push_back(
489                  llvm::MDString::get(mLLVMContext,
490                                      llvm::itostr(EPT->getKind())));
491              break;
492            }
493
494            default: {
495              FieldInfo.push_back(
496                  llvm::MDString::get(mLLVMContext,
497                                      llvm::itostr(
498                                        RSExportPrimitiveType::DataKindUser)));
499              break;
500            }
501          }
502
503          StructInfoMetadata->addOperand(
504              llvm::MDNode::get(mLLVMContext, FieldInfo));
505          FieldInfo.clear();
506        }
507      }   // ET->getClass() == RSExportType::ExportClassRecord
508    }
509  }
510
511  return;
512}
513
514RSBackend::~RSBackend() {
515  return;
516}
517
518}  // namespace slang
519