slang_rs_backend.cpp revision 8004db04a13a3031ca21dc0b2bb8875d5bbd4473
1/*
2 * Copyright 2010-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
19#include <string>
20#include <vector>
21
22#include "clang/Frontend/CodeGenOptions.h"
23
24#include "llvm/ADT/Twine.h"
25#include "llvm/ADT/StringExtras.h"
26
27#include "llvm/Constant.h"
28#include "llvm/Constants.h"
29#include "llvm/DerivedTypes.h"
30#include "llvm/Function.h"
31#include "llvm/Metadata.h"
32#include "llvm/Module.h"
33
34#include "llvm/Support/DebugLoc.h"
35#include "llvm/Support/IRBuilder.h"
36
37#include "slang_assert.h"
38#include "slang_rs.h"
39#include "slang_rs_context.h"
40#include "slang_rs_export_foreach.h"
41#include "slang_rs_export_func.h"
42#include "slang_rs_export_type.h"
43#include "slang_rs_export_var.h"
44#include "slang_rs_metadata.h"
45
46namespace slang {
47
48RSBackend::RSBackend(RSContext *Context,
49                     clang::DiagnosticsEngine *DiagEngine,
50                     const clang::CodeGenOptions &CodeGenOpts,
51                     const clang::TargetOptions &TargetOpts,
52                     PragmaList *Pragmas,
53                     llvm::raw_ostream *OS,
54                     Slang::OutputType OT,
55                     clang::SourceManager &SourceMgr,
56                     bool AllowRSPrefix)
57  : Backend(DiagEngine, CodeGenOpts, TargetOpts, Pragmas, OS, OT),
58    mContext(Context),
59    mSourceMgr(SourceMgr),
60    mAllowRSPrefix(AllowRSPrefix),
61    mExportVarMetadata(NULL),
62    mExportFuncMetadata(NULL),
63    mExportForEachNameMetadata(NULL),
64    mExportForEachSignatureMetadata(NULL),
65    mExportTypeMetadata(NULL),
66    mRSObjectSlotsMetadata(NULL),
67    mRefCount(mContext->getASTContext()) {
68}
69
70// 1) Add zero initialization of local RS object types
71void RSBackend::AnnotateFunction(clang::FunctionDecl *FD) {
72  if (FD &&
73      FD->hasBody() &&
74      !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr)) {
75    mRefCount.Init();
76    mRefCount.Visit(FD->getBody());
77  }
78  return;
79}
80
81bool RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
82  // Disallow user-defined functions with prefix "rs"
83  if (!mAllowRSPrefix) {
84    // Iterate all function declarations in the program.
85    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
86         I != E; I++) {
87      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
88      if (FD == NULL)
89        continue;
90      if (!FD->getName().startswith("rs"))  // Check prefix
91        continue;
92      if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
93        mDiagEngine.Report(
94          clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
95          mDiagEngine.getCustomDiagID(clang::DiagnosticsEngine::Error,
96                                      "invalid function name prefix, "
97                                      "\"rs\" is reserved: '%0'"))
98          << FD->getName();
99    }
100  }
101
102  // Process any non-static function declarations
103  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
104    clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
105    if (FD && FD->isGlobal()) {
106      AnnotateFunction(FD);
107    }
108  }
109
110  return Backend::HandleTopLevelDecl(D);
111}
112
113namespace {
114
115static bool ValidateVarDecl(clang::VarDecl *VD) {
116  if (!VD) {
117    return true;
118  }
119
120  clang::ASTContext &C = VD->getASTContext();
121  const clang::Type *T = VD->getType().getTypePtr();
122  bool valid = true;
123
124  if (VD->getLinkage() == clang::ExternalLinkage) {
125    llvm::StringRef TypeName;
126    if (!RSExportType::NormalizeType(T, TypeName, &C.getDiagnostics(), VD)) {
127      valid = false;
128    }
129  }
130  valid &= RSExportType::ValidateVarDecl(VD);
131
132  return valid;
133}
134
135static bool ValidateASTContext(clang::ASTContext &C) {
136  bool valid = true;
137  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
138  for (clang::DeclContext::decl_iterator DI = TUDecl->decls_begin(),
139          DE = TUDecl->decls_end();
140       DI != DE;
141       DI++) {
142    clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*DI);
143    if (VD && !ValidateVarDecl(VD)) {
144      valid = false;
145    }
146  }
147
148  return valid;
149}
150
151}  // namespace
152
153void RSBackend::HandleTranslationUnitPre(clang::ASTContext &C) {
154  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
155
156  if (!ValidateASTContext(C)) {
157    return;
158  }
159
160  int version = mContext->getVersion();
161  if (version == 0) {
162    // Not setting a version is an error
163    mDiagEngine.Report(
164        mSourceMgr.getLocForEndOfFile(mSourceMgr.getMainFileID()),
165        mDiagEngine.getCustomDiagID(
166            clang::DiagnosticsEngine::Error,
167            "missing pragma for version in source file"));
168  } else {
169    slangAssert(version == 1);
170  }
171
172  // Create a static global destructor if necessary (to handle RS object
173  // runtime cleanup).
174  clang::FunctionDecl *FD = mRefCount.CreateStaticGlobalDtor();
175  if (FD) {
176    HandleTopLevelDecl(clang::DeclGroupRef(FD));
177  }
178
179  // Process any static function declarations
180  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
181          E = TUDecl->decls_end(); I != E; I++) {
182    if ((I->getKind() >= clang::Decl::firstFunction) &&
183        (I->getKind() <= clang::Decl::lastFunction)) {
184      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
185      if (FD && !FD->isGlobal()) {
186        AnnotateFunction(FD);
187      }
188    }
189  }
190
191  return;
192}
193
194///////////////////////////////////////////////////////////////////////////////
195void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
196  if (!mContext->processExport()) {
197    return;
198  }
199
200  // Write optimization level
201  llvm::SmallVector<llvm::Value*, 1> OptimizationOption;
202  OptimizationOption.push_back(llvm::ConstantInt::get(
203    mLLVMContext, llvm::APInt(32, mCodeGenOpts.OptimizationLevel)));
204
205  // Dump export variable info
206  if (mContext->hasExportVar()) {
207    int slotCount = 0;
208    if (mExportVarMetadata == NULL)
209      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
210
211    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
212
213    // We emit slot information (#rs_object_slots) for any reference counted
214    // RS type or pointer (which can also be bound).
215
216    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
217            E = mContext->export_vars_end();
218         I != E;
219         I++) {
220      const RSExportVar *EV = *I;
221      const RSExportType *ET = EV->getType();
222      bool countsAsRSObject = false;
223
224      // Variable name
225      ExportVarInfo.push_back(
226          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
227
228      // Type name
229      switch (ET->getClass()) {
230        case RSExportType::ExportClassPrimitive: {
231          const RSExportPrimitiveType *PT =
232              static_cast<const RSExportPrimitiveType*>(ET);
233          ExportVarInfo.push_back(
234              llvm::MDString::get(
235                mLLVMContext, llvm::utostr_32(PT->getType())));
236          if (PT->isRSObjectType()) {
237            countsAsRSObject = true;
238          }
239          break;
240        }
241        case RSExportType::ExportClassPointer: {
242          ExportVarInfo.push_back(
243              llvm::MDString::get(
244                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
245                  ->getPointeeType()->getName()).c_str()));
246          break;
247        }
248        case RSExportType::ExportClassMatrix: {
249          ExportVarInfo.push_back(
250              llvm::MDString::get(
251                mLLVMContext, llvm::utostr_32(
252                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
253                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
254          break;
255        }
256        case RSExportType::ExportClassVector:
257        case RSExportType::ExportClassConstantArray:
258        case RSExportType::ExportClassRecord: {
259          ExportVarInfo.push_back(
260              llvm::MDString::get(mLLVMContext,
261                EV->getType()->getName().c_str()));
262          break;
263        }
264      }
265
266      mExportVarMetadata->addOperand(
267          llvm::MDNode::get(mLLVMContext, ExportVarInfo));
268      ExportVarInfo.clear();
269
270      if (mRSObjectSlotsMetadata == NULL) {
271        mRSObjectSlotsMetadata =
272            M->getOrInsertNamedMetadata(RS_OBJECT_SLOTS_MN);
273      }
274
275      if (countsAsRSObject) {
276        mRSObjectSlotsMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
277            llvm::MDString::get(mLLVMContext, llvm::utostr_32(slotCount))));
278      }
279
280      slotCount++;
281    }
282  }
283
284  // Dump export function info
285  if (mContext->hasExportFunc()) {
286    if (mExportFuncMetadata == NULL)
287      mExportFuncMetadata =
288          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
289
290    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
291
292    for (RSContext::const_export_func_iterator
293            I = mContext->export_funcs_begin(),
294            E = mContext->export_funcs_end();
295         I != E;
296         I++) {
297      const RSExportFunc *EF = *I;
298
299      // Function name
300      if (!EF->hasParam()) {
301        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
302                                                     EF->getName().c_str()));
303      } else {
304        llvm::Function *F = M->getFunction(EF->getName());
305        llvm::Function *HelperFunction;
306        const std::string HelperFunctionName(".helper_" + EF->getName());
307
308        slangAssert(F && "Function marked as exported disappeared in Bitcode");
309
310        // Create helper function
311        {
312          llvm::StructType *HelperFunctionParameterTy = NULL;
313
314          if (!F->getArgumentList().empty()) {
315            std::vector<llvm::Type*> HelperFunctionParameterTys;
316            for (llvm::Function::arg_iterator AI = F->arg_begin(),
317                 AE = F->arg_end(); AI != AE; AI++)
318              HelperFunctionParameterTys.push_back(AI->getType());
319
320            HelperFunctionParameterTy =
321                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
322          }
323
324          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
325            fprintf(stderr, "Failed to export function %s: parameter type "
326                            "mismatch during creation of helper function.\n",
327                    EF->getName().c_str());
328
329            const RSExportRecordType *Expected = EF->getParamPacketType();
330            if (Expected) {
331              fprintf(stderr, "Expected:\n");
332              Expected->getLLVMType()->dump();
333            }
334            if (HelperFunctionParameterTy) {
335              fprintf(stderr, "Got:\n");
336              HelperFunctionParameterTy->dump();
337            }
338          }
339
340          std::vector<llvm::Type*> Params;
341          if (HelperFunctionParameterTy) {
342            llvm::PointerType *HelperFunctionParameterTyP =
343                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
344            Params.push_back(HelperFunctionParameterTyP);
345          }
346
347          llvm::FunctionType * HelperFunctionType =
348              llvm::FunctionType::get(F->getReturnType(),
349                                      Params,
350                                      /* IsVarArgs = */false);
351
352          HelperFunction =
353              llvm::Function::Create(HelperFunctionType,
354                                     llvm::GlobalValue::ExternalLinkage,
355                                     HelperFunctionName,
356                                     M);
357
358          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
359          HelperFunction->setCallingConv(F->getCallingConv());
360
361          // Create helper function body
362          {
363            llvm::Argument *HelperFunctionParameter =
364                &(*HelperFunction->arg_begin());
365            llvm::BasicBlock *BB =
366                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
367            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
368            llvm::SmallVector<llvm::Value*, 6> Params;
369            llvm::Value *Idx[2];
370
371            Idx[0] =
372                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
373
374            // getelementptr and load instruction for all elements in
375            // parameter .p
376            for (size_t i = 0; i < EF->getNumParameters(); i++) {
377              // getelementptr
378              Idx[1] = llvm::ConstantInt::get(
379                llvm::Type::getInt32Ty(mLLVMContext), i);
380
381              llvm::Value *Ptr =
382                IB->CreateInBoundsGEP(HelperFunctionParameter, Idx);
383
384              // load
385              llvm::Value *V = IB->CreateLoad(Ptr);
386              Params.push_back(V);
387            }
388
389            // Call and pass the all elements as parameter to F
390            llvm::CallInst *CI = IB->CreateCall(F, Params);
391
392            CI->setCallingConv(F->getCallingConv());
393
394            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
395              IB->CreateRetVoid();
396            else
397              IB->CreateRet(CI);
398
399            delete IB;
400          }
401        }
402
403        ExportFuncInfo.push_back(
404            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
405      }
406
407      mExportFuncMetadata->addOperand(
408          llvm::MDNode::get(mLLVMContext, ExportFuncInfo));
409      ExportFuncInfo.clear();
410    }
411  }
412
413  // Dump export function info
414  if (mContext->hasExportForEach()) {
415    if (mExportForEachNameMetadata == NULL) {
416      mExportForEachNameMetadata =
417          M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_NAME_MN);
418    }
419    if (mExportForEachSignatureMetadata == NULL) {
420      mExportForEachSignatureMetadata =
421          M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_MN);
422    }
423
424    llvm::SmallVector<llvm::Value*, 1> ExportForEachName;
425    llvm::SmallVector<llvm::Value*, 1> ExportForEachInfo;
426
427    for (RSContext::const_export_foreach_iterator
428            I = mContext->export_foreach_begin(),
429            E = mContext->export_foreach_end();
430         I != E;
431         I++) {
432      const RSExportForEach *EFE = *I;
433
434      ExportForEachName.push_back(
435          llvm::MDString::get(mLLVMContext, EFE->getName().c_str()));
436
437      mExportForEachNameMetadata->addOperand(
438          llvm::MDNode::get(mLLVMContext, ExportForEachName));
439      ExportForEachName.clear();
440
441      ExportForEachInfo.push_back(
442          llvm::MDString::get(mLLVMContext,
443                              llvm::utostr_32(EFE->getSignatureMetadata())));
444
445      mExportForEachSignatureMetadata->addOperand(
446          llvm::MDNode::get(mLLVMContext, ExportForEachInfo));
447      ExportForEachInfo.clear();
448    }
449  }
450
451  // Dump export type info
452  if (mContext->hasExportType()) {
453    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
454
455    for (RSContext::const_export_type_iterator
456            I = mContext->export_types_begin(),
457            E = mContext->export_types_end();
458         I != E;
459         I++) {
460      // First, dump type name list to export
461      const RSExportType *ET = I->getValue();
462
463      ExportTypeInfo.clear();
464      // Type name
465      ExportTypeInfo.push_back(
466          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
467
468      if (ET->getClass() == RSExportType::ExportClassRecord) {
469        const RSExportRecordType *ERT =
470            static_cast<const RSExportRecordType*>(ET);
471
472        if (mExportTypeMetadata == NULL)
473          mExportTypeMetadata =
474              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
475
476        mExportTypeMetadata->addOperand(
477            llvm::MDNode::get(mLLVMContext, ExportTypeInfo));
478
479        // Now, export struct field information to %[struct name]
480        std::string StructInfoMetadataName("%");
481        StructInfoMetadataName.append(ET->getName());
482        llvm::NamedMDNode *StructInfoMetadata =
483            M->getOrInsertNamedMetadata(StructInfoMetadataName);
484        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
485
486        slangAssert(StructInfoMetadata->getNumOperands() == 0 &&
487                    "Metadata with same name was created before");
488        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
489                FE = ERT->fields_end();
490             FI != FE;
491             FI++) {
492          const RSExportRecordType::Field *F = *FI;
493
494          // 1. field name
495          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
496                                                  F->getName().c_str()));
497
498          // 2. field type name
499          FieldInfo.push_back(
500              llvm::MDString::get(mLLVMContext,
501                                  F->getType()->getName().c_str()));
502
503          StructInfoMetadata->addOperand(
504              llvm::MDNode::get(mLLVMContext, FieldInfo));
505          FieldInfo.clear();
506        }
507      }   // ET->getClass() == RSExportType::ExportClassRecord
508    }
509  }
510
511  return;
512}
513
514RSBackend::~RSBackend() {
515  return;
516}
517
518}  // namespace slang
519