slang_rs_backend.cpp revision 5baf6324a97430016026419deaef246ad75430fc
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
19#include <string>
20#include <vector>
21
22#include "llvm/ADT/Twine.h"
23#include "llvm/ADT/StringExtras.h"
24
25#include "llvm/Constant.h"
26#include "llvm/Constants.h"
27#include "llvm/DerivedTypes.h"
28#include "llvm/Function.h"
29#include "llvm/Metadata.h"
30#include "llvm/Module.h"
31
32#include "llvm/Support/IRBuilder.h"
33
34#include "slang_assert.h"
35#include "slang_rs.h"
36#include "slang_rs_context.h"
37#include "slang_rs_export_func.h"
38#include "slang_rs_export_type.h"
39#include "slang_rs_export_var.h"
40#include "slang_rs_metadata.h"
41
42namespace slang {
43
44RSBackend::RSBackend(RSContext *Context,
45                     clang::Diagnostic *Diags,
46                     const clang::CodeGenOptions &CodeGenOpts,
47                     const clang::TargetOptions &TargetOpts,
48                     PragmaList *Pragmas,
49                     llvm::raw_ostream *OS,
50                     Slang::OutputType OT,
51                     clang::SourceManager &SourceMgr,
52                     bool AllowRSPrefix)
53    : Backend(Diags,
54              CodeGenOpts,
55              TargetOpts,
56              Pragmas,
57              OS,
58              OT),
59      mContext(Context),
60      mSourceMgr(SourceMgr),
61      mAllowRSPrefix(AllowRSPrefix),
62      mExportVarMetadata(NULL),
63      mExportFuncMetadata(NULL),
64      mExportTypeMetadata(NULL),
65      mRSObjectSlotsMetadata(NULL),
66      mRefCount(mContext->getASTContext()) {
67  return;
68}
69
70// 1) Add zero initialization of local RS object types
71void RSBackend::AnnotateFunction(clang::FunctionDecl *FD) {
72  if (FD &&
73      FD->hasBody() &&
74      !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr)) {
75    mRefCount.Init();
76    mRefCount.Visit(FD->getBody());
77  }
78  return;
79}
80
81void RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
82  // Disallow user-defined functions with prefix "rs"
83  if (!mAllowRSPrefix) {
84    // Iterate all function declarations in the program.
85    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
86         I != E; I++) {
87      clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
88      if (FD == NULL)
89        continue;
90      if (!FD->getName().startswith("rs"))  // Check prefix
91        continue;
92      if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
93        mDiags.Report(clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
94                      mDiags.getCustomDiagID(clang::Diagnostic::Error,
95                                             "invalid function name prefix, "
96                                             "\"rs\" is reserved: '%0'"))
97            << FD->getName();
98    }
99  }
100
101  // Process any non-static function declarations
102  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
103    clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
104    if (FD && FD->isGlobal()) {
105      AnnotateFunction(FD);
106    }
107  }
108
109  Backend::HandleTopLevelDecl(D);
110  return;
111}
112
113namespace {
114
115static bool ValidateVarDecl(clang::VarDecl *VD) {
116  if (!VD) {
117    return true;
118  }
119
120  clang::ASTContext &C = VD->getASTContext();
121  const clang::Type *T = VD->getType().getTypePtr();
122  bool valid = true;
123
124  if (VD->getLinkage() == clang::ExternalLinkage) {
125    llvm::StringRef TypeName;
126    if (!RSExportType::NormalizeType(T, TypeName, &C.getDiagnostics(), VD)) {
127      valid = false;
128    }
129  }
130  valid &= RSExportType::ValidateVarDecl(VD);
131
132  return valid;
133}
134
135static bool ValidateASTContext(clang::ASTContext &C) {
136  bool valid = true;
137  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
138  for (clang::DeclContext::decl_iterator DI = TUDecl->decls_begin(),
139          DE = TUDecl->decls_end();
140       DI != DE;
141       DI++) {
142    clang::VarDecl *VD = dyn_cast<clang::VarDecl>(*DI);
143    if (VD && !ValidateVarDecl(VD)) {
144      valid = false;
145    }
146  }
147
148  return valid;
149}
150
151}  // namespace
152
153void RSBackend::HandleTranslationUnitPre(clang::ASTContext &C) {
154  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
155
156  if (!ValidateASTContext(C)) {
157    return;
158  }
159
160  int version = mContext->getVersion();
161  if (version == 0) {
162    // Not setting a version is an error
163    mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
164                      "Missing pragma for version in source file"));
165  } else if (version > 1) {
166    mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
167                      "Pragma for version in source file must be set to 1"));
168  }
169
170  // Process any static function declarations
171  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
172          E = TUDecl->decls_end(); I != E; I++) {
173    if ((I->getKind() >= clang::Decl::firstFunction) &&
174        (I->getKind() <= clang::Decl::lastFunction)) {
175      clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
176      if (FD && !FD->isGlobal()) {
177        AnnotateFunction(FD);
178      }
179    }
180  }
181
182  return;
183}
184
185///////////////////////////////////////////////////////////////////////////////
186void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
187  if (!mContext->processExport()) {
188    return;
189  }
190
191  // Dump export variable info
192  if (mContext->hasExportVar()) {
193    int slotCount = 0;
194    if (mExportVarMetadata == NULL)
195      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
196
197    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
198    llvm::SmallVector<llvm::Value*, 1> SlotVarInfo;
199
200    // We emit slot information (#rs_object_slots) for any reference counted
201    // RS type or pointer (which can also be bound).
202
203    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
204            E = mContext->export_vars_end();
205         I != E;
206         I++) {
207      const RSExportVar *EV = *I;
208      const RSExportType *ET = EV->getType();
209      bool countsAsRSObject = false;
210
211      // Variable name
212      ExportVarInfo.push_back(
213          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
214
215      // Type name
216      switch (ET->getClass()) {
217        case RSExportType::ExportClassPrimitive: {
218          const RSExportPrimitiveType *PT =
219              static_cast<const RSExportPrimitiveType*>(ET);
220          ExportVarInfo.push_back(
221              llvm::MDString::get(
222                mLLVMContext, llvm::utostr_32(PT->getType())));
223          if (PT->isRSObjectType()) {
224            countsAsRSObject = true;
225          }
226          break;
227        }
228        case RSExportType::ExportClassPointer: {
229          ExportVarInfo.push_back(
230              llvm::MDString::get(
231                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
232                  ->getPointeeType()->getName()).c_str()));
233          break;
234        }
235        case RSExportType::ExportClassMatrix: {
236          ExportVarInfo.push_back(
237              llvm::MDString::get(
238                mLLVMContext, llvm::utostr_32(
239                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
240                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
241          break;
242        }
243        case RSExportType::ExportClassVector:
244        case RSExportType::ExportClassConstantArray:
245        case RSExportType::ExportClassRecord: {
246          ExportVarInfo.push_back(
247              llvm::MDString::get(mLLVMContext,
248                EV->getType()->getName().c_str()));
249          break;
250        }
251      }
252
253      mExportVarMetadata->addOperand(
254          llvm::MDNode::get(mLLVMContext,
255                            ExportVarInfo.data(),
256                            ExportVarInfo.size()) );
257
258      ExportVarInfo.clear();
259
260      if (mRSObjectSlotsMetadata == NULL) {
261        mRSObjectSlotsMetadata =
262            M->getOrInsertNamedMetadata(RS_OBJECT_SLOTS_MN);
263      }
264
265      if (countsAsRSObject) {
266        SlotVarInfo.push_back(
267            llvm::MDString::get(mLLVMContext, llvm::utostr_32(slotCount)));
268
269        mRSObjectSlotsMetadata->addOperand(
270            llvm::MDNode::get(mLLVMContext,
271                              SlotVarInfo.data(),
272                              SlotVarInfo.size()));
273      }
274
275      slotCount++;
276      SlotVarInfo.clear();
277    }
278  }
279
280  // Dump export function info
281  if (mContext->hasExportFunc()) {
282    if (mExportFuncMetadata == NULL)
283      mExportFuncMetadata =
284          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
285
286    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
287
288    for (RSContext::const_export_func_iterator
289            I = mContext->export_funcs_begin(),
290            E = mContext->export_funcs_end();
291         I != E;
292         I++) {
293      const RSExportFunc *EF = *I;
294
295      // Function name
296      if (!EF->hasParam()) {
297        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
298                                                     EF->getName().c_str()));
299      } else {
300        llvm::Function *F = M->getFunction(EF->getName());
301        llvm::Function *HelperFunction;
302        const std::string HelperFunctionName(".helper_" + EF->getName());
303
304        slangAssert(F && "Function marked as exported disappeared in Bitcode");
305
306        // Create helper function
307        {
308          llvm::StructType *HelperFunctionParameterTy = NULL;
309
310          if (!F->getArgumentList().empty()) {
311            std::vector<const llvm::Type*> HelperFunctionParameterTys;
312            for (llvm::Function::arg_iterator AI = F->arg_begin(),
313                 AE = F->arg_end(); AI != AE; AI++)
314              HelperFunctionParameterTys.push_back(AI->getType());
315
316            HelperFunctionParameterTy =
317                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
318          }
319
320          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
321            fprintf(stderr, "Failed to export function %s: parameter type "
322                            "mismatch during creation of helper function.\n",
323                    EF->getName().c_str());
324
325            const RSExportRecordType *Expected = EF->getParamPacketType();
326            if (Expected) {
327              fprintf(stderr, "Expected:\n");
328              Expected->getLLVMType()->dump();
329            }
330            if (HelperFunctionParameterTy) {
331              fprintf(stderr, "Got:\n");
332              HelperFunctionParameterTy->dump();
333            }
334          }
335
336          std::vector<const llvm::Type*> Params;
337          if (HelperFunctionParameterTy) {
338            llvm::PointerType *HelperFunctionParameterTyP =
339                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
340            Params.push_back(HelperFunctionParameterTyP);
341          }
342
343          llvm::FunctionType * HelperFunctionType =
344              llvm::FunctionType::get(F->getReturnType(),
345                                      Params,
346                                      /* IsVarArgs = */false);
347
348          HelperFunction =
349              llvm::Function::Create(HelperFunctionType,
350                                     llvm::GlobalValue::ExternalLinkage,
351                                     HelperFunctionName,
352                                     M);
353
354          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
355          HelperFunction->setCallingConv(F->getCallingConv());
356
357          // Create helper function body
358          {
359            llvm::Argument *HelperFunctionParameter =
360                &(*HelperFunction->arg_begin());
361            llvm::BasicBlock *BB =
362                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
363            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
364            llvm::SmallVector<llvm::Value*, 6> Params;
365            llvm::Value *Idx[2];
366
367            Idx[0] =
368                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
369
370            // getelementptr and load instruction for all elements in
371            // parameter .p
372            for (size_t i = 0; i < EF->getNumParameters(); i++) {
373              // getelementptr
374              Idx[1] =
375                  llvm::ConstantInt::get(
376                      llvm::Type::getInt32Ty(mLLVMContext), i);
377              llvm::Value *Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter,
378                                                       Idx,
379                                                       Idx + 2);
380
381              // load
382              llvm::Value *V = IB->CreateLoad(Ptr);
383              Params.push_back(V);
384            }
385
386            // Call and pass the all elements as paramter to F
387            llvm::CallInst *CI = IB->CreateCall(F,
388                                                Params.data(),
389                                                Params.data() + Params.size());
390
391            CI->setCallingConv(F->getCallingConv());
392
393            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
394              IB->CreateRetVoid();
395            else
396              IB->CreateRet(CI);
397
398            delete IB;
399          }
400        }
401
402        ExportFuncInfo.push_back(
403            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
404      }
405
406      mExportFuncMetadata->addOperand(
407          llvm::MDNode::get(mLLVMContext,
408                            ExportFuncInfo.data(),
409                            ExportFuncInfo.size()));
410
411      ExportFuncInfo.clear();
412    }
413  }
414
415  // Dump export type info
416  if (mContext->hasExportType()) {
417    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
418
419    for (RSContext::const_export_type_iterator
420            I = mContext->export_types_begin(),
421            E = mContext->export_types_end();
422         I != E;
423         I++) {
424      // First, dump type name list to export
425      const RSExportType *ET = I->getValue();
426
427      ExportTypeInfo.clear();
428      // Type name
429      ExportTypeInfo.push_back(
430          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
431
432      if (ET->getClass() == RSExportType::ExportClassRecord) {
433        const RSExportRecordType *ERT =
434            static_cast<const RSExportRecordType*>(ET);
435
436        if (mExportTypeMetadata == NULL)
437          mExportTypeMetadata =
438              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
439
440        mExportTypeMetadata->addOperand(
441            llvm::MDNode::get(mLLVMContext,
442                              ExportTypeInfo.data(),
443                              ExportTypeInfo.size()));
444
445        // Now, export struct field information to %[struct name]
446        std::string StructInfoMetadataName("%");
447        StructInfoMetadataName.append(ET->getName());
448        llvm::NamedMDNode *StructInfoMetadata =
449            M->getOrInsertNamedMetadata(StructInfoMetadataName);
450        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
451
452        slangAssert(StructInfoMetadata->getNumOperands() == 0 &&
453                    "Metadata with same name was created before");
454        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
455                FE = ERT->fields_end();
456             FI != FE;
457             FI++) {
458          const RSExportRecordType::Field *F = *FI;
459
460          // 1. field name
461          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
462                                                  F->getName().c_str()));
463
464          // 2. field type name
465          FieldInfo.push_back(
466              llvm::MDString::get(mLLVMContext,
467                                  F->getType()->getName().c_str()));
468
469          // 3. field kind
470          switch (F->getType()->getClass()) {
471            case RSExportType::ExportClassPrimitive:
472            case RSExportType::ExportClassVector: {
473              const RSExportPrimitiveType *EPT =
474                  static_cast<const RSExportPrimitiveType*>(F->getType());
475              FieldInfo.push_back(
476                  llvm::MDString::get(mLLVMContext,
477                                      llvm::itostr(EPT->getKind())));
478              break;
479            }
480
481            default: {
482              FieldInfo.push_back(
483                  llvm::MDString::get(mLLVMContext,
484                                      llvm::itostr(
485                                        RSExportPrimitiveType::DataKindUser)));
486              break;
487            }
488          }
489
490          StructInfoMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
491                                                           FieldInfo.data(),
492                                                           FieldInfo.size()));
493
494          FieldInfo.clear();
495        }
496      }   // ET->getClass() == RSExportType::ExportClassRecord
497    }
498  }
499
500  return;
501}
502
503RSBackend::~RSBackend() {
504  return;
505}
506
507}  // namespace slang
508