slang_rs_backend.cpp revision 7aff4a0a124209fdf93ecbcd7aed701d39ba094b
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
19#include <string>
20#include <vector>
21
22#include "llvm/ADT/Twine.h"
23#include "llvm/ADT/StringExtras.h"
24
25#include "llvm/Constant.h"
26#include "llvm/Constants.h"
27#include "llvm/DerivedTypes.h"
28#include "llvm/Function.h"
29#include "llvm/Metadata.h"
30#include "llvm/Module.h"
31
32#include "llvm/Support/IRBuilder.h"
33
34#include "slang_assert.h"
35#include "slang_rs.h"
36#include "slang_rs_context.h"
37#include "slang_rs_export_foreach.h"
38#include "slang_rs_export_func.h"
39#include "slang_rs_export_type.h"
40#include "slang_rs_export_var.h"
41#include "slang_rs_metadata.h"
42
43namespace slang {
44
45RSBackend::RSBackend(RSContext *Context,
46                     clang::DiagnosticsEngine *DiagEngine,
47                     const clang::CodeGenOptions &CodeGenOpts,
48                     const clang::TargetOptions &TargetOpts,
49                     PragmaList *Pragmas,
50                     llvm::raw_ostream *OS,
51                     Slang::OutputType OT,
52                     clang::SourceManager &SourceMgr,
53                     bool AllowRSPrefix)
54  : Backend(DiagEngine, CodeGenOpts, TargetOpts, Pragmas, OS, OT),
55    mContext(Context),
56    mSourceMgr(SourceMgr),
57    mAllowRSPrefix(AllowRSPrefix),
58    mExportVarMetadata(NULL),
59    mExportFuncMetadata(NULL),
60    mExportForEachMetadata(NULL),
61    mExportTypeMetadata(NULL),
62    mRSObjectSlotsMetadata(NULL),
63    mRefCount(mContext->getASTContext()) {
64}
65
66// 1) Add zero initialization of local RS object types
67void RSBackend::AnnotateFunction(clang::FunctionDecl *FD) {
68  if (FD &&
69      FD->hasBody() &&
70      !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr)) {
71    mRefCount.Init();
72    mRefCount.Visit(FD->getBody());
73  }
74  return;
75}
76
77bool RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
78  // Disallow user-defined functions with prefix "rs"
79  if (!mAllowRSPrefix) {
80    // Iterate all function declarations in the program.
81    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
82         I != E; I++) {
83      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
84      if (FD == NULL)
85        continue;
86      if (!FD->getName().startswith("rs"))  // Check prefix
87        continue;
88      if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
89        mDiagEngine.Report(
90          clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
91          mDiagEngine.getCustomDiagID(clang::DiagnosticsEngine::Error,
92                                      "invalid function name prefix, "
93                                      "\"rs\" is reserved: '%0'"))
94          << FD->getName();
95    }
96  }
97
98  // Process any non-static function declarations
99  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
100    clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
101    if (FD && FD->isGlobal()) {
102      AnnotateFunction(FD);
103    }
104  }
105
106  return Backend::HandleTopLevelDecl(D);
107}
108
109namespace {
110
111static bool ValidateVarDecl(clang::VarDecl *VD) {
112  if (!VD) {
113    return true;
114  }
115
116  clang::ASTContext &C = VD->getASTContext();
117  const clang::Type *T = VD->getType().getTypePtr();
118  bool valid = true;
119
120  if (VD->getLinkage() == clang::ExternalLinkage) {
121    llvm::StringRef TypeName;
122    if (!RSExportType::NormalizeType(T, TypeName, &C.getDiagnostics(), VD)) {
123      valid = false;
124    }
125  }
126  valid &= RSExportType::ValidateVarDecl(VD);
127
128  return valid;
129}
130
131static bool ValidateASTContext(clang::ASTContext &C) {
132  bool valid = true;
133  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
134  for (clang::DeclContext::decl_iterator DI = TUDecl->decls_begin(),
135          DE = TUDecl->decls_end();
136       DI != DE;
137       DI++) {
138    clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*DI);
139    if (VD && !ValidateVarDecl(VD)) {
140      valid = false;
141    }
142  }
143
144  return valid;
145}
146
147}  // namespace
148
149void RSBackend::HandleTranslationUnitPre(clang::ASTContext &C) {
150  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
151
152  if (!ValidateASTContext(C)) {
153    return;
154  }
155
156  int version = mContext->getVersion();
157  if (version == 0) {
158    // Not setting a version is an error
159    mDiagEngine.Report(
160        mSourceMgr.getLocForEndOfFile(mSourceMgr.getMainFileID()),
161        mDiagEngine.getCustomDiagID(
162            clang::DiagnosticsEngine::Error,
163            "missing pragma for version in source file"));
164  } else {
165    slangAssert(version == 1);
166  }
167
168  // Create a static global destructor if necessary (to handle RS object
169  // runtime cleanup).
170  clang::FunctionDecl *FD = mRefCount.CreateStaticGlobalDtor();
171  if (FD) {
172    HandleTopLevelDecl(clang::DeclGroupRef(FD));
173  }
174
175  // Process any static function declarations
176  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
177          E = TUDecl->decls_end(); I != E; I++) {
178    if ((I->getKind() >= clang::Decl::firstFunction) &&
179        (I->getKind() <= clang::Decl::lastFunction)) {
180      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
181      if (FD && !FD->isGlobal()) {
182        AnnotateFunction(FD);
183      }
184    }
185  }
186
187  return;
188}
189
190///////////////////////////////////////////////////////////////////////////////
191void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
192  if (!mContext->processExport()) {
193    return;
194  }
195
196  // Dump export variable info
197  if (mContext->hasExportVar()) {
198    int slotCount = 0;
199    if (mExportVarMetadata == NULL)
200      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
201
202    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
203
204    // We emit slot information (#rs_object_slots) for any reference counted
205    // RS type or pointer (which can also be bound).
206
207    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
208            E = mContext->export_vars_end();
209         I != E;
210         I++) {
211      const RSExportVar *EV = *I;
212      const RSExportType *ET = EV->getType();
213      bool countsAsRSObject = false;
214
215      // Variable name
216      ExportVarInfo.push_back(
217          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
218
219      // Type name
220      switch (ET->getClass()) {
221        case RSExportType::ExportClassPrimitive: {
222          const RSExportPrimitiveType *PT =
223              static_cast<const RSExportPrimitiveType*>(ET);
224          ExportVarInfo.push_back(
225              llvm::MDString::get(
226                mLLVMContext, llvm::utostr_32(PT->getType())));
227          if (PT->isRSObjectType()) {
228            countsAsRSObject = true;
229          }
230          break;
231        }
232        case RSExportType::ExportClassPointer: {
233          ExportVarInfo.push_back(
234              llvm::MDString::get(
235                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
236                  ->getPointeeType()->getName()).c_str()));
237          break;
238        }
239        case RSExportType::ExportClassMatrix: {
240          ExportVarInfo.push_back(
241              llvm::MDString::get(
242                mLLVMContext, llvm::utostr_32(
243                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
244                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
245          break;
246        }
247        case RSExportType::ExportClassVector:
248        case RSExportType::ExportClassConstantArray:
249        case RSExportType::ExportClassRecord: {
250          ExportVarInfo.push_back(
251              llvm::MDString::get(mLLVMContext,
252                EV->getType()->getName().c_str()));
253          break;
254        }
255      }
256
257      mExportVarMetadata->addOperand(
258          llvm::MDNode::get(mLLVMContext, ExportVarInfo));
259      ExportVarInfo.clear();
260
261      if (mRSObjectSlotsMetadata == NULL) {
262        mRSObjectSlotsMetadata =
263            M->getOrInsertNamedMetadata(RS_OBJECT_SLOTS_MN);
264      }
265
266      if (countsAsRSObject) {
267        mRSObjectSlotsMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
268            llvm::MDString::get(mLLVMContext, llvm::utostr_32(slotCount))));
269      }
270
271      slotCount++;
272    }
273  }
274
275  // Dump export function info
276  if (mContext->hasExportFunc()) {
277    if (mExportFuncMetadata == NULL)
278      mExportFuncMetadata =
279          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
280
281    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
282
283    for (RSContext::const_export_func_iterator
284            I = mContext->export_funcs_begin(),
285            E = mContext->export_funcs_end();
286         I != E;
287         I++) {
288      const RSExportFunc *EF = *I;
289
290      // Function name
291      if (!EF->hasParam()) {
292        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
293                                                     EF->getName().c_str()));
294      } else {
295        llvm::Function *F = M->getFunction(EF->getName());
296        llvm::Function *HelperFunction;
297        const std::string HelperFunctionName(".helper_" + EF->getName());
298
299        slangAssert(F && "Function marked as exported disappeared in Bitcode");
300
301        // Create helper function
302        {
303          llvm::StructType *HelperFunctionParameterTy = NULL;
304
305          if (!F->getArgumentList().empty()) {
306            std::vector<llvm::Type*> HelperFunctionParameterTys;
307            for (llvm::Function::arg_iterator AI = F->arg_begin(),
308                 AE = F->arg_end(); AI != AE; AI++)
309              HelperFunctionParameterTys.push_back(AI->getType());
310
311            HelperFunctionParameterTy =
312                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
313          }
314
315          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
316            fprintf(stderr, "Failed to export function %s: parameter type "
317                            "mismatch during creation of helper function.\n",
318                    EF->getName().c_str());
319
320            const RSExportRecordType *Expected = EF->getParamPacketType();
321            if (Expected) {
322              fprintf(stderr, "Expected:\n");
323              Expected->getLLVMType()->dump();
324            }
325            if (HelperFunctionParameterTy) {
326              fprintf(stderr, "Got:\n");
327              HelperFunctionParameterTy->dump();
328            }
329          }
330
331          std::vector<llvm::Type*> Params;
332          if (HelperFunctionParameterTy) {
333            llvm::PointerType *HelperFunctionParameterTyP =
334                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
335            Params.push_back(HelperFunctionParameterTyP);
336          }
337
338          llvm::FunctionType * HelperFunctionType =
339              llvm::FunctionType::get(F->getReturnType(),
340                                      Params,
341                                      /* IsVarArgs = */false);
342
343          HelperFunction =
344              llvm::Function::Create(HelperFunctionType,
345                                     llvm::GlobalValue::ExternalLinkage,
346                                     HelperFunctionName,
347                                     M);
348
349          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
350          HelperFunction->setCallingConv(F->getCallingConv());
351
352          // Create helper function body
353          {
354            llvm::Argument *HelperFunctionParameter =
355                &(*HelperFunction->arg_begin());
356            llvm::BasicBlock *BB =
357                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
358            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
359            llvm::SmallVector<llvm::Value*, 6> Params;
360            llvm::Value *Idx[2];
361
362            Idx[0] =
363                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
364
365            // getelementptr and load instruction for all elements in
366            // parameter .p
367            for (size_t i = 0; i < EF->getNumParameters(); i++) {
368              // getelementptr
369              Idx[1] = llvm::ConstantInt::get(
370                llvm::Type::getInt32Ty(mLLVMContext), i);
371
372              llvm::Value *Ptr =
373                IB->CreateInBoundsGEP(HelperFunctionParameter, Idx);
374
375              // load
376              llvm::Value *V = IB->CreateLoad(Ptr);
377              Params.push_back(V);
378            }
379
380            // Call and pass the all elements as parameter to F
381            llvm::CallInst *CI = IB->CreateCall(F, Params);
382
383            CI->setCallingConv(F->getCallingConv());
384
385            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
386              IB->CreateRetVoid();
387            else
388              IB->CreateRet(CI);
389
390            delete IB;
391          }
392        }
393
394        ExportFuncInfo.push_back(
395            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
396      }
397
398      mExportFuncMetadata->addOperand(
399          llvm::MDNode::get(mLLVMContext, ExportFuncInfo));
400      ExportFuncInfo.clear();
401    }
402  }
403
404  // Dump export function info
405  if (mContext->hasExportForEach()) {
406    if (mExportForEachMetadata == NULL)
407      mExportForEachMetadata =
408          M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_MN);
409
410    llvm::SmallVector<llvm::Value*, 1> ExportForEachInfo;
411
412    for (RSContext::const_export_foreach_iterator
413            I = mContext->export_foreach_begin(),
414            E = mContext->export_foreach_end();
415         I != E;
416         I++) {
417      const RSExportForEach *EFE = *I;
418
419      ExportForEachInfo.push_back(
420          llvm::MDString::get(mLLVMContext,
421                              llvm::utostr_32(EFE->getMetadataEncoding())));
422
423      mExportForEachMetadata->addOperand(
424          llvm::MDNode::get(mLLVMContext, ExportForEachInfo));
425      ExportForEachInfo.clear();
426    }
427  }
428
429  // Dump export type info
430  if (mContext->hasExportType()) {
431    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
432
433    for (RSContext::const_export_type_iterator
434            I = mContext->export_types_begin(),
435            E = mContext->export_types_end();
436         I != E;
437         I++) {
438      // First, dump type name list to export
439      const RSExportType *ET = I->getValue();
440
441      ExportTypeInfo.clear();
442      // Type name
443      ExportTypeInfo.push_back(
444          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
445
446      if (ET->getClass() == RSExportType::ExportClassRecord) {
447        const RSExportRecordType *ERT =
448            static_cast<const RSExportRecordType*>(ET);
449
450        if (mExportTypeMetadata == NULL)
451          mExportTypeMetadata =
452              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
453
454        mExportTypeMetadata->addOperand(
455            llvm::MDNode::get(mLLVMContext, ExportTypeInfo));
456
457        // Now, export struct field information to %[struct name]
458        std::string StructInfoMetadataName("%");
459        StructInfoMetadataName.append(ET->getName());
460        llvm::NamedMDNode *StructInfoMetadata =
461            M->getOrInsertNamedMetadata(StructInfoMetadataName);
462        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
463
464        slangAssert(StructInfoMetadata->getNumOperands() == 0 &&
465                    "Metadata with same name was created before");
466        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
467                FE = ERT->fields_end();
468             FI != FE;
469             FI++) {
470          const RSExportRecordType::Field *F = *FI;
471
472          // 1. field name
473          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
474                                                  F->getName().c_str()));
475
476          // 2. field type name
477          FieldInfo.push_back(
478              llvm::MDString::get(mLLVMContext,
479                                  F->getType()->getName().c_str()));
480
481          // 3. field kind
482          switch (F->getType()->getClass()) {
483            case RSExportType::ExportClassPrimitive:
484            case RSExportType::ExportClassVector: {
485              const RSExportPrimitiveType *EPT =
486                  static_cast<const RSExportPrimitiveType*>(F->getType());
487              FieldInfo.push_back(
488                  llvm::MDString::get(mLLVMContext,
489                                      llvm::itostr(EPT->getKind())));
490              break;
491            }
492
493            default: {
494              FieldInfo.push_back(
495                  llvm::MDString::get(mLLVMContext,
496                                      llvm::itostr(
497                                        RSExportPrimitiveType::DataKindUser)));
498              break;
499            }
500          }
501
502          StructInfoMetadata->addOperand(
503              llvm::MDNode::get(mLLVMContext, FieldInfo));
504          FieldInfo.clear();
505        }
506      }   // ET->getClass() == RSExportType::ExportClassRecord
507    }
508  }
509
510  return;
511}
512
513RSBackend::~RSBackend() {
514  return;
515}
516
517}  // namespace slang
518