slang_rs_backend.cpp revision d0b5edd02be5f09c1d8d211f4a06b031a7b66510
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
19#include <string>
20#include <vector>
21
22#include "llvm/ADT/Twine.h"
23#include "llvm/ADT/StringExtras.h"
24
25#include "llvm/Constant.h"
26#include "llvm/Constants.h"
27#include "llvm/DerivedTypes.h"
28#include "llvm/Function.h"
29#include "llvm/Metadata.h"
30#include "llvm/Module.h"
31
32#include "llvm/Support/IRBuilder.h"
33
34#include "slang_assert.h"
35#include "slang_rs.h"
36#include "slang_rs_context.h"
37#include "slang_rs_export_func.h"
38#include "slang_rs_export_type.h"
39#include "slang_rs_export_var.h"
40#include "slang_rs_metadata.h"
41
42namespace slang {
43
44RSBackend::RSBackend(RSContext *Context,
45                     clang::Diagnostic *Diags,
46                     const clang::CodeGenOptions &CodeGenOpts,
47                     const clang::TargetOptions &TargetOpts,
48                     PragmaList *Pragmas,
49                     llvm::raw_ostream *OS,
50                     Slang::OutputType OT,
51                     clang::SourceManager &SourceMgr,
52                     bool AllowRSPrefix)
53    : Backend(Diags,
54              CodeGenOpts,
55              TargetOpts,
56              Pragmas,
57              OS,
58              OT),
59      mContext(Context),
60      mSourceMgr(SourceMgr),
61      mAllowRSPrefix(AllowRSPrefix),
62      mExportVarMetadata(NULL),
63      mExportFuncMetadata(NULL),
64      mExportTypeMetadata(NULL),
65      mRSObjectSlotsMetadata(NULL),
66      mRefCount(mContext->getASTContext()) {
67  return;
68}
69
70// 1) Add zero initialization of local RS object types
71void RSBackend::AnnotateFunction(clang::FunctionDecl *FD) {
72  if (FD &&
73      FD->hasBody() &&
74      !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr)) {
75    mRefCount.Init();
76    mRefCount.Visit(FD->getBody());
77  }
78  return;
79}
80
81void RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
82  // Disallow user-defined functions with prefix "rs"
83  if (!mAllowRSPrefix) {
84    // Iterate all function declarations in the program.
85    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
86         I != E; I++) {
87      clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
88      if (FD == NULL)
89        continue;
90      if (!FD->getName().startswith("rs"))  // Check prefix
91        continue;
92      if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
93        mDiags.Report(clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
94                      mDiags.getCustomDiagID(clang::Diagnostic::Error,
95                                             "invalid function name prefix, "
96                                             "\"rs\" is reserved: '%0'"))
97            << FD->getName();
98    }
99  }
100
101  // Process any non-static function declarations
102  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
103    clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
104    if (FD && FD->isGlobal()) {
105      AnnotateFunction(FD);
106    }
107  }
108
109  Backend::HandleTopLevelDecl(D);
110  return;
111}
112
113namespace {
114
115bool ValidateVar(clang::VarDecl *VD, clang::Diagnostic *Diags,
116    clang::SourceManager *SM) {
117  llvm::StringRef TypeName;
118  const clang::Type *T = VD->getType().getTypePtr();
119  if (!RSExportType::NormalizeType(T, TypeName, Diags, SM, VD)) {
120    return false;
121  }
122  return true;
123}
124
125bool ValidateASTContext(clang::ASTContext &C, clang::Diagnostic &Diags) {
126  bool valid = true;
127  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
128  for (clang::DeclContext::decl_iterator DI = TUDecl->decls_begin(),
129          DE = TUDecl->decls_end();
130       DI != DE;
131       DI++) {
132    if (DI->getKind() == clang::Decl::Var) {
133      clang::VarDecl *VD = (clang::VarDecl*) (*DI);
134      if (VD->getLinkage() == clang::ExternalLinkage) {
135        if (!ValidateVar(VD, &Diags, &C.getSourceManager())) {
136          valid = false;
137        }
138      }
139    }
140  }
141
142  return valid;
143}
144
145}  // namespace
146
147void RSBackend::HandleTranslationUnitPre(clang::ASTContext &C) {
148  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
149
150  if (!ValidateASTContext(C, mDiags)) {
151    return;
152  }
153
154  int version = mContext->getVersion();
155  if (version == 0) {
156    // Not setting a version is an error
157    mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
158                      "Missing pragma for version in source file"));
159  } else if (version > 1) {
160    mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
161                      "Pragma for version in source file must be set to 1"));
162  }
163
164  // Process any static function declarations
165  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
166          E = TUDecl->decls_end(); I != E; I++) {
167    if ((I->getKind() >= clang::Decl::firstFunction) &&
168        (I->getKind() <= clang::Decl::lastFunction)) {
169      clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
170      if (FD && !FD->isGlobal()) {
171        AnnotateFunction(FD);
172      }
173    }
174  }
175
176  return;
177}
178
179///////////////////////////////////////////////////////////////////////////////
180void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
181  if (!mContext->processExport()) {
182    mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
183                                         "elements cannot be exported"));
184    return;
185  }
186
187  // Dump export variable info
188  if (mContext->hasExportVar()) {
189    int slotCount = 0;
190    if (mExportVarMetadata == NULL)
191      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
192
193    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
194    llvm::SmallVector<llvm::Value*, 1> SlotVarInfo;
195
196    // We emit slot information (#rs_object_slots) for any reference counted
197    // RS type or pointer (which can also be bound).
198
199    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
200            E = mContext->export_vars_end();
201         I != E;
202         I++) {
203      const RSExportVar *EV = *I;
204      const RSExportType *ET = EV->getType();
205      bool countsAsRSObject = false;
206
207      // Variable name
208      ExportVarInfo.push_back(
209          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
210
211      // Type name
212      switch (ET->getClass()) {
213        case RSExportType::ExportClassPrimitive: {
214          const RSExportPrimitiveType *PT =
215              static_cast<const RSExportPrimitiveType*>(ET);
216          ExportVarInfo.push_back(
217              llvm::MDString::get(
218                mLLVMContext, llvm::utostr_32(PT->getType())));
219          if (PT->isRSObjectType()) {
220            countsAsRSObject = true;
221          }
222          break;
223        }
224        case RSExportType::ExportClassPointer: {
225          ExportVarInfo.push_back(
226              llvm::MDString::get(
227                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
228                  ->getPointeeType()->getName()).c_str()));
229          break;
230        }
231        case RSExportType::ExportClassMatrix: {
232          ExportVarInfo.push_back(
233              llvm::MDString::get(
234                mLLVMContext, llvm::utostr_32(
235                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
236                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
237          break;
238        }
239        case RSExportType::ExportClassVector:
240        case RSExportType::ExportClassConstantArray:
241        case RSExportType::ExportClassRecord: {
242          ExportVarInfo.push_back(
243              llvm::MDString::get(mLLVMContext,
244                EV->getType()->getName().c_str()));
245          break;
246        }
247      }
248
249      mExportVarMetadata->addOperand(
250          llvm::MDNode::get(mLLVMContext,
251                            ExportVarInfo.data(),
252                            ExportVarInfo.size()) );
253
254      ExportVarInfo.clear();
255
256      if (mRSObjectSlotsMetadata == NULL) {
257        mRSObjectSlotsMetadata =
258            M->getOrInsertNamedMetadata(RS_OBJECT_SLOTS_MN);
259      }
260
261      if (countsAsRSObject) {
262        SlotVarInfo.push_back(
263            llvm::MDString::get(mLLVMContext, llvm::utostr_32(slotCount)));
264
265        mRSObjectSlotsMetadata->addOperand(
266            llvm::MDNode::get(mLLVMContext,
267                              SlotVarInfo.data(),
268                              SlotVarInfo.size()));
269      }
270
271      slotCount++;
272      SlotVarInfo.clear();
273    }
274  }
275
276  // Dump export function info
277  if (mContext->hasExportFunc()) {
278    if (mExportFuncMetadata == NULL)
279      mExportFuncMetadata =
280          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
281
282    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
283
284    for (RSContext::const_export_func_iterator
285            I = mContext->export_funcs_begin(),
286            E = mContext->export_funcs_end();
287         I != E;
288         I++) {
289      const RSExportFunc *EF = *I;
290
291      // Function name
292      if (!EF->hasParam()) {
293        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
294                                                     EF->getName().c_str()));
295      } else {
296        llvm::Function *F = M->getFunction(EF->getName());
297        llvm::Function *HelperFunction;
298        const std::string HelperFunctionName(".helper_" + EF->getName());
299
300        slangAssert(F && "Function marked as exported disappeared in Bitcode");
301
302        // Create helper function
303        {
304          llvm::StructType *HelperFunctionParameterTy = NULL;
305
306          if (!F->getArgumentList().empty()) {
307            std::vector<const llvm::Type*> HelperFunctionParameterTys;
308            for (llvm::Function::arg_iterator AI = F->arg_begin(),
309                 AE = F->arg_end(); AI != AE; AI++)
310              HelperFunctionParameterTys.push_back(AI->getType());
311
312            HelperFunctionParameterTy =
313                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
314          }
315
316          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
317            fprintf(stderr, "Failed to export function %s: parameter type "
318                            "mismatch during creation of helper function.\n",
319                    EF->getName().c_str());
320
321            const RSExportRecordType *Expected = EF->getParamPacketType();
322            if (Expected) {
323              fprintf(stderr, "Expected:\n");
324              Expected->getLLVMType()->dump();
325            }
326            if (HelperFunctionParameterTy) {
327              fprintf(stderr, "Got:\n");
328              HelperFunctionParameterTy->dump();
329            }
330          }
331
332          std::vector<const llvm::Type*> Params;
333          if (HelperFunctionParameterTy) {
334            llvm::PointerType *HelperFunctionParameterTyP =
335                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
336            Params.push_back(HelperFunctionParameterTyP);
337          }
338
339          llvm::FunctionType * HelperFunctionType =
340              llvm::FunctionType::get(F->getReturnType(),
341                                      Params,
342                                      /* IsVarArgs = */false);
343
344          HelperFunction =
345              llvm::Function::Create(HelperFunctionType,
346                                     llvm::GlobalValue::ExternalLinkage,
347                                     HelperFunctionName,
348                                     M);
349
350          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
351          HelperFunction->setCallingConv(F->getCallingConv());
352
353          // Create helper function body
354          {
355            llvm::Argument *HelperFunctionParameter =
356                &(*HelperFunction->arg_begin());
357            llvm::BasicBlock *BB =
358                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
359            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
360            llvm::SmallVector<llvm::Value*, 6> Params;
361            llvm::Value *Idx[2];
362
363            Idx[0] =
364                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
365
366            // getelementptr and load instruction for all elements in
367            // parameter .p
368            for (size_t i = 0; i < EF->getNumParameters(); i++) {
369              // getelementptr
370              Idx[1] =
371                  llvm::ConstantInt::get(
372                      llvm::Type::getInt32Ty(mLLVMContext), i);
373              llvm::Value *Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter,
374                                                       Idx,
375                                                       Idx + 2);
376
377              // load
378              llvm::Value *V = IB->CreateLoad(Ptr);
379              Params.push_back(V);
380            }
381
382            // Call and pass the all elements as paramter to F
383            llvm::CallInst *CI = IB->CreateCall(F,
384                                                Params.data(),
385                                                Params.data() + Params.size());
386
387            CI->setCallingConv(F->getCallingConv());
388
389            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
390              IB->CreateRetVoid();
391            else
392              IB->CreateRet(CI);
393
394            delete IB;
395          }
396        }
397
398        ExportFuncInfo.push_back(
399            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
400      }
401
402      mExportFuncMetadata->addOperand(
403          llvm::MDNode::get(mLLVMContext,
404                            ExportFuncInfo.data(),
405                            ExportFuncInfo.size()));
406
407      ExportFuncInfo.clear();
408    }
409  }
410
411  // Dump export type info
412  if (mContext->hasExportType()) {
413    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
414
415    for (RSContext::const_export_type_iterator
416            I = mContext->export_types_begin(),
417            E = mContext->export_types_end();
418         I != E;
419         I++) {
420      // First, dump type name list to export
421      const RSExportType *ET = I->getValue();
422
423      ExportTypeInfo.clear();
424      // Type name
425      ExportTypeInfo.push_back(
426          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
427
428      if (ET->getClass() == RSExportType::ExportClassRecord) {
429        const RSExportRecordType *ERT =
430            static_cast<const RSExportRecordType*>(ET);
431
432        if (mExportTypeMetadata == NULL)
433          mExportTypeMetadata =
434              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
435
436        mExportTypeMetadata->addOperand(
437            llvm::MDNode::get(mLLVMContext,
438                              ExportTypeInfo.data(),
439                              ExportTypeInfo.size()));
440
441        // Now, export struct field information to %[struct name]
442        std::string StructInfoMetadataName("%");
443        StructInfoMetadataName.append(ET->getName());
444        llvm::NamedMDNode *StructInfoMetadata =
445            M->getOrInsertNamedMetadata(StructInfoMetadataName);
446        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
447
448        slangAssert(StructInfoMetadata->getNumOperands() == 0 &&
449                    "Metadata with same name was created before");
450        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
451                FE = ERT->fields_end();
452             FI != FE;
453             FI++) {
454          const RSExportRecordType::Field *F = *FI;
455
456          // 1. field name
457          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
458                                                  F->getName().c_str()));
459
460          // 2. field type name
461          FieldInfo.push_back(
462              llvm::MDString::get(mLLVMContext,
463                                  F->getType()->getName().c_str()));
464
465          // 3. field kind
466          switch (F->getType()->getClass()) {
467            case RSExportType::ExportClassPrimitive:
468            case RSExportType::ExportClassVector: {
469              const RSExportPrimitiveType *EPT =
470                  static_cast<const RSExportPrimitiveType*>(F->getType());
471              FieldInfo.push_back(
472                  llvm::MDString::get(mLLVMContext,
473                                      llvm::itostr(EPT->getKind())));
474              break;
475            }
476
477            default: {
478              FieldInfo.push_back(
479                  llvm::MDString::get(mLLVMContext,
480                                      llvm::itostr(
481                                        RSExportPrimitiveType::DataKindUser)));
482              break;
483            }
484          }
485
486          StructInfoMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
487                                                           FieldInfo.data(),
488                                                           FieldInfo.size()));
489
490          FieldInfo.clear();
491        }
492      }   // ET->getClass() == RSExportType::ExportClassRecord
493    }
494  }
495
496  return;
497}
498
499RSBackend::~RSBackend() {
500  return;
501}
502
503}  // namespace slang
504