slang_rs_backend.cpp revision c460b37ffb50819a32c2a8967754b6f784b28263
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
19#include <string>
20#include <vector>
21
22#include "clang/Frontend/CodeGenOptions.h"
23
24#include "llvm/ADT/Twine.h"
25#include "llvm/ADT/StringExtras.h"
26
27#include "llvm/Constant.h"
28#include "llvm/Constants.h"
29#include "llvm/DerivedTypes.h"
30#include "llvm/Function.h"
31#include "llvm/Metadata.h"
32#include "llvm/Module.h"
33
34#include "llvm/Support/DebugLoc.h"
35#include "llvm/Support/IRBuilder.h"
36
37#include "slang_assert.h"
38#include "slang_rs.h"
39#include "slang_rs_context.h"
40#include "slang_rs_export_foreach.h"
41#include "slang_rs_export_func.h"
42#include "slang_rs_export_type.h"
43#include "slang_rs_export_var.h"
44#include "slang_rs_metadata.h"
45
46namespace slang {
47
48RSBackend::RSBackend(RSContext *Context,
49                     clang::DiagnosticsEngine *DiagEngine,
50                     const clang::CodeGenOptions &CodeGenOpts,
51                     const clang::TargetOptions &TargetOpts,
52                     PragmaList *Pragmas,
53                     llvm::raw_ostream *OS,
54                     Slang::OutputType OT,
55                     clang::SourceManager &SourceMgr,
56                     bool AllowRSPrefix)
57  : Backend(DiagEngine, CodeGenOpts, TargetOpts, Pragmas, OS, OT),
58    mContext(Context),
59    mSourceMgr(SourceMgr),
60    mAllowRSPrefix(AllowRSPrefix),
61    mExportVarMetadata(NULL),
62    mExportFuncMetadata(NULL),
63    mExportForEachMetadata(NULL),
64    mExportTypeMetadata(NULL),
65    mRSObjectSlotsMetadata(NULL),
66    mRSOptimizationMetadata(NULL),
67    mRefCount(mContext->getASTContext()) {
68}
69
70// 1) Add zero initialization of local RS object types
71void RSBackend::AnnotateFunction(clang::FunctionDecl *FD) {
72  if (FD &&
73      FD->hasBody() &&
74      !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr)) {
75    mRefCount.Init();
76    mRefCount.Visit(FD->getBody());
77  }
78  return;
79}
80
81void RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
82  // Disallow user-defined functions with prefix "rs"
83  if (!mAllowRSPrefix) {
84    // Iterate all function declarations in the program.
85    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
86         I != E; I++) {
87      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
88      if (FD == NULL)
89        continue;
90      if (!FD->getName().startswith("rs"))  // Check prefix
91        continue;
92      if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
93        mDiagEngine.Report(
94          clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
95          mDiagEngine.getCustomDiagID(clang::DiagnosticsEngine::Error,
96                                      "invalid function name prefix, "
97                                      "\"rs\" is reserved: '%0'"))
98          << FD->getName();
99    }
100  }
101
102  // Process any non-static function declarations
103  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
104    clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
105    if (FD && FD->isGlobal()) {
106      AnnotateFunction(FD);
107    }
108  }
109
110  Backend::HandleTopLevelDecl(D);
111  return;
112}
113
114namespace {
115
116static bool ValidateVarDecl(clang::VarDecl *VD) {
117  if (!VD) {
118    return true;
119  }
120
121  clang::ASTContext &C = VD->getASTContext();
122  const clang::Type *T = VD->getType().getTypePtr();
123  bool valid = true;
124
125  if (VD->getLinkage() == clang::ExternalLinkage) {
126    llvm::StringRef TypeName;
127    if (!RSExportType::NormalizeType(T, TypeName, &C.getDiagnostics(), VD)) {
128      valid = false;
129    }
130  }
131  valid &= RSExportType::ValidateVarDecl(VD);
132
133  return valid;
134}
135
136static bool ValidateASTContext(clang::ASTContext &C) {
137  bool valid = true;
138  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
139  for (clang::DeclContext::decl_iterator DI = TUDecl->decls_begin(),
140          DE = TUDecl->decls_end();
141       DI != DE;
142       DI++) {
143    clang::VarDecl *VD = llvm::dyn_cast<clang::VarDecl>(*DI);
144    if (VD && !ValidateVarDecl(VD)) {
145      valid = false;
146    }
147  }
148
149  return valid;
150}
151
152}  // namespace
153
154void RSBackend::HandleTranslationUnitPre(clang::ASTContext &C) {
155  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
156
157  if (!ValidateASTContext(C)) {
158    return;
159  }
160
161  int version = mContext->getVersion();
162  if (version == 0) {
163    // Not setting a version is an error
164    mDiagEngine.Report(mDiagEngine.getCustomDiagID(
165      clang::DiagnosticsEngine::Error,
166      "Missing pragma for version in source file"));
167  } else if (version > 1) {
168    mDiagEngine.Report(mDiagEngine.getCustomDiagID(
169      clang::DiagnosticsEngine::Error,
170      "Pragma for version in source file must be set to 1"));
171  }
172
173  // Create a static global destructor if necessary (to handle RS object
174  // runtime cleanup).
175  clang::FunctionDecl *FD = mRefCount.CreateStaticGlobalDtor();
176  if (FD) {
177    HandleTopLevelDecl(clang::DeclGroupRef(FD));
178  }
179
180  // Process any static function declarations
181  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
182          E = TUDecl->decls_end(); I != E; I++) {
183    if ((I->getKind() >= clang::Decl::firstFunction) &&
184        (I->getKind() <= clang::Decl::lastFunction)) {
185      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
186      if (FD && !FD->isGlobal()) {
187        AnnotateFunction(FD);
188      }
189    }
190  }
191
192  return;
193}
194
195///////////////////////////////////////////////////////////////////////////////
196void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
197  if (!mContext->processExport()) {
198    return;
199  }
200
201  // Write optimization level
202  llvm::SmallVector<llvm::Value*, 1> OptimizationOption;
203  OptimizationOption.push_back(llvm::ConstantInt::get(
204    mLLVMContext, llvm::APInt(32, mCodeGenOpts.OptimizationLevel)));
205
206  if (mRSOptimizationMetadata == NULL)
207    mRSOptimizationMetadata = M->getOrInsertNamedMetadata(OPTIMIZATION_LEVEL_MN);
208  mRSOptimizationMetadata->addOperand(
209    llvm::MDNode::get(mLLVMContext, OptimizationOption));
210
211  // Dump export variable info
212  if (mContext->hasExportVar()) {
213    int slotCount = 0;
214    if (mExportVarMetadata == NULL)
215      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
216
217    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
218
219    // We emit slot information (#rs_object_slots) for any reference counted
220    // RS type or pointer (which can also be bound).
221
222    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
223            E = mContext->export_vars_end();
224         I != E;
225         I++) {
226      const RSExportVar *EV = *I;
227      const RSExportType *ET = EV->getType();
228      bool countsAsRSObject = false;
229
230      // Variable name
231      ExportVarInfo.push_back(
232          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
233
234      // Type name
235      switch (ET->getClass()) {
236        case RSExportType::ExportClassPrimitive: {
237          const RSExportPrimitiveType *PT =
238              static_cast<const RSExportPrimitiveType*>(ET);
239          ExportVarInfo.push_back(
240              llvm::MDString::get(
241                mLLVMContext, llvm::utostr_32(PT->getType())));
242          if (PT->isRSObjectType()) {
243            countsAsRSObject = true;
244          }
245          break;
246        }
247        case RSExportType::ExportClassPointer: {
248          ExportVarInfo.push_back(
249              llvm::MDString::get(
250                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
251                  ->getPointeeType()->getName()).c_str()));
252          break;
253        }
254        case RSExportType::ExportClassMatrix: {
255          ExportVarInfo.push_back(
256              llvm::MDString::get(
257                mLLVMContext, llvm::utostr_32(
258                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
259                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
260          break;
261        }
262        case RSExportType::ExportClassVector:
263        case RSExportType::ExportClassConstantArray:
264        case RSExportType::ExportClassRecord: {
265          ExportVarInfo.push_back(
266              llvm::MDString::get(mLLVMContext,
267                EV->getType()->getName().c_str()));
268          break;
269        }
270      }
271
272      mExportVarMetadata->addOperand(
273          llvm::MDNode::get(mLLVMContext, ExportVarInfo));
274      ExportVarInfo.clear();
275
276      if (mRSObjectSlotsMetadata == NULL) {
277        mRSObjectSlotsMetadata =
278            M->getOrInsertNamedMetadata(RS_OBJECT_SLOTS_MN);
279      }
280
281      if (countsAsRSObject) {
282        mRSObjectSlotsMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
283            llvm::MDString::get(mLLVMContext, llvm::utostr_32(slotCount))));
284      }
285
286      slotCount++;
287    }
288  }
289
290  // Dump export function info
291  if (mContext->hasExportFunc()) {
292    if (mExportFuncMetadata == NULL)
293      mExportFuncMetadata =
294          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
295
296    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
297
298    for (RSContext::const_export_func_iterator
299            I = mContext->export_funcs_begin(),
300            E = mContext->export_funcs_end();
301         I != E;
302         I++) {
303      const RSExportFunc *EF = *I;
304
305      // Function name
306      if (!EF->hasParam()) {
307        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
308                                                     EF->getName().c_str()));
309      } else {
310        llvm::Function *F = M->getFunction(EF->getName());
311        llvm::Function *HelperFunction;
312        const std::string HelperFunctionName(".helper_" + EF->getName());
313
314        slangAssert(F && "Function marked as exported disappeared in Bitcode");
315
316        // Create helper function
317        {
318          llvm::StructType *HelperFunctionParameterTy = NULL;
319
320          if (!F->getArgumentList().empty()) {
321            std::vector<llvm::Type*> HelperFunctionParameterTys;
322            for (llvm::Function::arg_iterator AI = F->arg_begin(),
323                 AE = F->arg_end(); AI != AE; AI++)
324              HelperFunctionParameterTys.push_back(AI->getType());
325
326            HelperFunctionParameterTy =
327                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
328          }
329
330          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
331            fprintf(stderr, "Failed to export function %s: parameter type "
332                            "mismatch during creation of helper function.\n",
333                    EF->getName().c_str());
334
335            const RSExportRecordType *Expected = EF->getParamPacketType();
336            if (Expected) {
337              fprintf(stderr, "Expected:\n");
338              Expected->getLLVMType()->dump();
339            }
340            if (HelperFunctionParameterTy) {
341              fprintf(stderr, "Got:\n");
342              HelperFunctionParameterTy->dump();
343            }
344          }
345
346          std::vector<llvm::Type*> Params;
347          if (HelperFunctionParameterTy) {
348            llvm::PointerType *HelperFunctionParameterTyP =
349                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
350            Params.push_back(HelperFunctionParameterTyP);
351          }
352
353          llvm::FunctionType * HelperFunctionType =
354              llvm::FunctionType::get(F->getReturnType(),
355                                      Params,
356                                      /* IsVarArgs = */false);
357
358          HelperFunction =
359              llvm::Function::Create(HelperFunctionType,
360                                     llvm::GlobalValue::ExternalLinkage,
361                                     HelperFunctionName,
362                                     M);
363
364          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
365          HelperFunction->setCallingConv(F->getCallingConv());
366
367          // Create helper function body
368          {
369            llvm::Argument *HelperFunctionParameter =
370                &(*HelperFunction->arg_begin());
371            llvm::BasicBlock *BB =
372                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
373            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
374            llvm::SmallVector<llvm::Value*, 6> Params;
375            llvm::Value *Idx[2];
376
377            Idx[0] =
378                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
379
380            // getelementptr and load instruction for all elements in
381            // parameter .p
382            for (size_t i = 0; i < EF->getNumParameters(); i++) {
383              // getelementptr
384              Idx[1] = llvm::ConstantInt::get(
385                llvm::Type::getInt32Ty(mLLVMContext), i);
386
387              llvm::Value *Ptr =
388                IB->CreateInBoundsGEP(HelperFunctionParameter, Idx);
389
390              // load
391              llvm::Value *V = IB->CreateLoad(Ptr);
392              Params.push_back(V);
393            }
394
395            // Call and pass the all elements as parameter to F
396            llvm::CallInst *CI = IB->CreateCall(F, Params);
397
398            CI->setCallingConv(F->getCallingConv());
399
400            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
401              IB->CreateRetVoid();
402            else
403              IB->CreateRet(CI);
404
405            delete IB;
406          }
407        }
408
409        ExportFuncInfo.push_back(
410            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
411      }
412
413      mExportFuncMetadata->addOperand(
414          llvm::MDNode::get(mLLVMContext, ExportFuncInfo));
415      ExportFuncInfo.clear();
416    }
417  }
418
419  // Dump export function info
420  if (mContext->hasExportForEach()) {
421    if (mExportForEachMetadata == NULL)
422      mExportForEachMetadata =
423          M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_MN);
424
425    llvm::SmallVector<llvm::Value*, 1> ExportForEachInfo;
426
427    for (RSContext::const_export_foreach_iterator
428            I = mContext->export_foreach_begin(),
429            E = mContext->export_foreach_end();
430         I != E;
431         I++) {
432      const RSExportForEach *EFE = *I;
433
434      ExportForEachInfo.push_back(
435          llvm::MDString::get(mLLVMContext,
436                              llvm::utostr_32(EFE->getMetadataEncoding())));
437
438      mExportForEachMetadata->addOperand(
439          llvm::MDNode::get(mLLVMContext, ExportForEachInfo));
440      ExportForEachInfo.clear();
441    }
442  }
443
444  // Dump export type info
445  if (mContext->hasExportType()) {
446    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
447
448    for (RSContext::const_export_type_iterator
449            I = mContext->export_types_begin(),
450            E = mContext->export_types_end();
451         I != E;
452         I++) {
453      // First, dump type name list to export
454      const RSExportType *ET = I->getValue();
455
456      ExportTypeInfo.clear();
457      // Type name
458      ExportTypeInfo.push_back(
459          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
460
461      if (ET->getClass() == RSExportType::ExportClassRecord) {
462        const RSExportRecordType *ERT =
463            static_cast<const RSExportRecordType*>(ET);
464
465        if (mExportTypeMetadata == NULL)
466          mExportTypeMetadata =
467              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
468
469        mExportTypeMetadata->addOperand(
470            llvm::MDNode::get(mLLVMContext, ExportTypeInfo));
471
472        // Now, export struct field information to %[struct name]
473        std::string StructInfoMetadataName("%");
474        StructInfoMetadataName.append(ET->getName());
475        llvm::NamedMDNode *StructInfoMetadata =
476            M->getOrInsertNamedMetadata(StructInfoMetadataName);
477        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
478
479        slangAssert(StructInfoMetadata->getNumOperands() == 0 &&
480                    "Metadata with same name was created before");
481        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
482                FE = ERT->fields_end();
483             FI != FE;
484             FI++) {
485          const RSExportRecordType::Field *F = *FI;
486
487          // 1. field name
488          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
489                                                  F->getName().c_str()));
490
491          // 2. field type name
492          FieldInfo.push_back(
493              llvm::MDString::get(mLLVMContext,
494                                  F->getType()->getName().c_str()));
495
496          // 3. field kind
497          switch (F->getType()->getClass()) {
498            case RSExportType::ExportClassPrimitive:
499            case RSExportType::ExportClassVector: {
500              const RSExportPrimitiveType *EPT =
501                  static_cast<const RSExportPrimitiveType*>(F->getType());
502              FieldInfo.push_back(
503                  llvm::MDString::get(mLLVMContext,
504                                      llvm::itostr(EPT->getKind())));
505              break;
506            }
507
508            default: {
509              FieldInfo.push_back(
510                  llvm::MDString::get(mLLVMContext,
511                                      llvm::itostr(
512                                        RSExportPrimitiveType::DataKindUser)));
513              break;
514            }
515          }
516
517          StructInfoMetadata->addOperand(
518              llvm::MDNode::get(mLLVMContext, FieldInfo));
519          FieldInfo.clear();
520        }
521      }   // ET->getClass() == RSExportType::ExportClassRecord
522    }
523  }
524
525  return;
526}
527
528RSBackend::~RSBackend() {
529  return;
530}
531
532}  // namespace slang
533