slang_rs_backend.cpp revision 6e6578a360497f78a181e63d7783422a9c9bfb15
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
19#include <string>
20#include <vector>
21
22#include "llvm/ADT/Twine.h"
23#include "llvm/ADT/StringExtras.h"
24
25#include "llvm/Constant.h"
26#include "llvm/Constants.h"
27#include "llvm/DerivedTypes.h"
28#include "llvm/Function.h"
29#include "llvm/Metadata.h"
30#include "llvm/Module.h"
31
32#include "llvm/Support/IRBuilder.h"
33
34#include "slang_assert.h"
35#include "slang_rs.h"
36#include "slang_rs_context.h"
37#include "slang_rs_export_func.h"
38#include "slang_rs_export_type.h"
39#include "slang_rs_export_var.h"
40#include "slang_rs_metadata.h"
41
42namespace slang {
43
44RSBackend::RSBackend(RSContext *Context,
45                     clang::Diagnostic *Diags,
46                     const clang::CodeGenOptions &CodeGenOpts,
47                     const clang::TargetOptions &TargetOpts,
48                     PragmaList *Pragmas,
49                     llvm::raw_ostream *OS,
50                     Slang::OutputType OT,
51                     clang::SourceManager &SourceMgr,
52                     bool AllowRSPrefix)
53    : Backend(Diags,
54              CodeGenOpts,
55              TargetOpts,
56              Pragmas,
57              OS,
58              OT),
59      mContext(Context),
60      mSourceMgr(SourceMgr),
61      mAllowRSPrefix(AllowRSPrefix),
62      mExportVarMetadata(NULL),
63      mExportFuncMetadata(NULL),
64      mExportTypeMetadata(NULL),
65      mRSObjectSlotsMetadata(NULL) {
66  return;
67}
68
69// 1) Add zero initialization of local RS object types
70void RSBackend::AnnotateFunction(clang::FunctionDecl *FD) {
71  if (FD &&
72      FD->hasBody() &&
73      !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr)) {
74    mRefCount.Init(mContext->getASTContext());
75    mRefCount.Visit(FD->getBody());
76  }
77  return;
78}
79
80void RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
81  // Disallow user-defined functions with prefix "rs"
82  if (!mAllowRSPrefix) {
83    // Iterate all function declarations in the program.
84    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
85         I != E; I++) {
86      clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
87      if (FD == NULL)
88        continue;
89      if (!FD->getName().startswith("rs"))  // Check prefix
90        continue;
91      if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
92        mDiags.Report(clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
93                      mDiags.getCustomDiagID(clang::Diagnostic::Error,
94                                             "invalid function name prefix, "
95                                             "\"rs\" is reserved: '%0'"))
96            << FD->getName();
97    }
98  }
99
100  // Process any non-static function declarations
101  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
102    clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
103    if (FD && FD->isGlobal()) {
104      AnnotateFunction(FD);
105    }
106  }
107
108  Backend::HandleTopLevelDecl(D);
109  return;
110}
111
112namespace {
113
114bool ValidateVar(clang::VarDecl *VD, clang::Diagnostic *Diags,
115    clang::SourceManager *SM) {
116  llvm::StringRef TypeName;
117  const clang::Type *T = VD->getType().getTypePtr();
118  if (!RSExportType::NormalizeType(T, TypeName, Diags, SM, VD)) {
119    return false;
120  }
121  return true;
122}
123
124bool ValidateASTContext(clang::ASTContext &C, clang::Diagnostic &Diags) {
125  bool valid = true;
126  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
127  for (clang::DeclContext::decl_iterator DI = TUDecl->decls_begin(),
128          DE = TUDecl->decls_end();
129       DI != DE;
130       DI++) {
131    if (DI->getKind() == clang::Decl::Var) {
132      clang::VarDecl *VD = (clang::VarDecl*) (*DI);
133      if (VD->getLinkage() == clang::ExternalLinkage) {
134        if (!ValidateVar(VD, &Diags, &C.getSourceManager())) {
135          valid = false;
136        }
137      }
138    }
139  }
140
141  return valid;
142}
143
144}  // namespace
145
146void RSBackend::HandleTranslationUnitPre(clang::ASTContext &C) {
147  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
148
149  if (!ValidateASTContext(C, mDiags)) {
150    return;
151  }
152
153  int version = mContext->getVersion();
154  if (version == 0) {
155    // Not setting a version is an error
156    mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
157                      "Missing pragma for version in source file"));
158  } else if (version > 1) {
159    mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
160                      "Pragma for version in source file must be set to 1"));
161  }
162
163  // Process any static function declarations
164  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
165          E = TUDecl->decls_end(); I != E; I++) {
166    if ((I->getKind() >= clang::Decl::firstFunction) &&
167        (I->getKind() <= clang::Decl::lastFunction)) {
168      clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
169      if (FD && !FD->isGlobal()) {
170        AnnotateFunction(FD);
171      }
172    }
173  }
174
175  return;
176}
177
178///////////////////////////////////////////////////////////////////////////////
179void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
180  if (!mContext->processExport()) {
181    mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
182                                         "elements cannot be exported"));
183    return;
184  }
185
186  // Dump export variable info
187  if (mContext->hasExportVar()) {
188    int slotCount = 0;
189    if (mExportVarMetadata == NULL)
190      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
191
192    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
193    llvm::SmallVector<llvm::Value*, 1> SlotVarInfo;
194
195    // We emit slot information (#rs_object_slots) for any reference counted
196    // RS type or pointer (which can also be bound).
197
198    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
199            E = mContext->export_vars_end();
200         I != E;
201         I++) {
202      const RSExportVar *EV = *I;
203      const RSExportType *ET = EV->getType();
204      bool countsAsRSObject = false;
205
206      // Variable name
207      ExportVarInfo.push_back(
208          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
209
210      // Type name
211      switch (ET->getClass()) {
212        case RSExportType::ExportClassPrimitive: {
213          const RSExportPrimitiveType *PT =
214              static_cast<const RSExportPrimitiveType*>(ET);
215          ExportVarInfo.push_back(
216              llvm::MDString::get(
217                mLLVMContext, llvm::utostr_32(PT->getType())));
218          if (PT->isRSObjectType()) {
219            countsAsRSObject = true;
220          }
221          break;
222        }
223        case RSExportType::ExportClassPointer: {
224          ExportVarInfo.push_back(
225              llvm::MDString::get(
226                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
227                  ->getPointeeType()->getName()).c_str()));
228          break;
229        }
230        case RSExportType::ExportClassMatrix: {
231          ExportVarInfo.push_back(
232              llvm::MDString::get(
233                mLLVMContext, llvm::utostr_32(
234                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
235                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
236          break;
237        }
238        case RSExportType::ExportClassVector:
239        case RSExportType::ExportClassConstantArray:
240        case RSExportType::ExportClassRecord: {
241          ExportVarInfo.push_back(
242              llvm::MDString::get(mLLVMContext,
243                EV->getType()->getName().c_str()));
244          break;
245        }
246      }
247
248      mExportVarMetadata->addOperand(
249          llvm::MDNode::get(mLLVMContext,
250                            ExportVarInfo.data(),
251                            ExportVarInfo.size()) );
252
253      ExportVarInfo.clear();
254
255      if (mRSObjectSlotsMetadata == NULL) {
256        mRSObjectSlotsMetadata =
257            M->getOrInsertNamedMetadata(RS_OBJECT_SLOTS_MN);
258      }
259
260      if (countsAsRSObject) {
261        SlotVarInfo.push_back(
262            llvm::MDString::get(mLLVMContext, llvm::utostr_32(slotCount)));
263
264        mRSObjectSlotsMetadata->addOperand(
265            llvm::MDNode::get(mLLVMContext,
266                              SlotVarInfo.data(),
267                              SlotVarInfo.size()));
268      }
269
270      slotCount++;
271      SlotVarInfo.clear();
272    }
273  }
274
275  // Dump export function info
276  if (mContext->hasExportFunc()) {
277    if (mExportFuncMetadata == NULL)
278      mExportFuncMetadata =
279          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
280
281    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
282
283    for (RSContext::const_export_func_iterator
284            I = mContext->export_funcs_begin(),
285            E = mContext->export_funcs_end();
286         I != E;
287         I++) {
288      const RSExportFunc *EF = *I;
289
290      // Function name
291      if (!EF->hasParam()) {
292        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
293                                                     EF->getName().c_str()));
294      } else {
295        llvm::Function *F = M->getFunction(EF->getName());
296        llvm::Function *HelperFunction;
297        const std::string HelperFunctionName(".helper_" + EF->getName());
298
299        slangAssert(F && "Function marked as exported disappeared in Bitcode");
300
301        // Create helper function
302        {
303          llvm::StructType *HelperFunctionParameterTy = NULL;
304
305          if (!F->getArgumentList().empty()) {
306            std::vector<const llvm::Type*> HelperFunctionParameterTys;
307            for (llvm::Function::arg_iterator AI = F->arg_begin(),
308                 AE = F->arg_end(); AI != AE; AI++)
309              HelperFunctionParameterTys.push_back(AI->getType());
310
311            HelperFunctionParameterTy =
312                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
313          }
314
315          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
316            fprintf(stderr, "Failed to export function %s: parameter type "
317                            "mismatch during creation of helper function.\n",
318                    EF->getName().c_str());
319
320            const RSExportRecordType *Expected = EF->getParamPacketType();
321            if (Expected) {
322              fprintf(stderr, "Expected:\n");
323              Expected->getLLVMType()->dump();
324            }
325            if (HelperFunctionParameterTy) {
326              fprintf(stderr, "Got:\n");
327              HelperFunctionParameterTy->dump();
328            }
329          }
330
331          std::vector<const llvm::Type*> Params;
332          if (HelperFunctionParameterTy) {
333            llvm::PointerType *HelperFunctionParameterTyP =
334                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
335            Params.push_back(HelperFunctionParameterTyP);
336          }
337
338          llvm::FunctionType * HelperFunctionType =
339              llvm::FunctionType::get(F->getReturnType(),
340                                      Params,
341                                      /* IsVarArgs = */false);
342
343          HelperFunction =
344              llvm::Function::Create(HelperFunctionType,
345                                     llvm::GlobalValue::ExternalLinkage,
346                                     HelperFunctionName,
347                                     M);
348
349          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
350          HelperFunction->setCallingConv(F->getCallingConv());
351
352          // Create helper function body
353          {
354            llvm::Argument *HelperFunctionParameter =
355                &(*HelperFunction->arg_begin());
356            llvm::BasicBlock *BB =
357                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
358            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
359            llvm::SmallVector<llvm::Value*, 6> Params;
360            llvm::Value *Idx[2];
361
362            Idx[0] =
363                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
364
365            // getelementptr and load instruction for all elements in
366            // parameter .p
367            for (size_t i = 0; i < EF->getNumParameters(); i++) {
368              // getelementptr
369              Idx[1] =
370                  llvm::ConstantInt::get(
371                      llvm::Type::getInt32Ty(mLLVMContext), i);
372              llvm::Value *Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter,
373                                                       Idx,
374                                                       Idx + 2);
375
376              // load
377              llvm::Value *V = IB->CreateLoad(Ptr);
378              Params.push_back(V);
379            }
380
381            // Call and pass the all elements as paramter to F
382            llvm::CallInst *CI = IB->CreateCall(F,
383                                                Params.data(),
384                                                Params.data() + Params.size());
385
386            CI->setCallingConv(F->getCallingConv());
387
388            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
389              IB->CreateRetVoid();
390            else
391              IB->CreateRet(CI);
392
393            delete IB;
394          }
395        }
396
397        ExportFuncInfo.push_back(
398            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
399      }
400
401      mExportFuncMetadata->addOperand(
402          llvm::MDNode::get(mLLVMContext,
403                            ExportFuncInfo.data(),
404                            ExportFuncInfo.size()));
405
406      ExportFuncInfo.clear();
407    }
408  }
409
410  // Dump export type info
411  if (mContext->hasExportType()) {
412    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
413
414    for (RSContext::const_export_type_iterator
415            I = mContext->export_types_begin(),
416            E = mContext->export_types_end();
417         I != E;
418         I++) {
419      // First, dump type name list to export
420      const RSExportType *ET = I->getValue();
421
422      ExportTypeInfo.clear();
423      // Type name
424      ExportTypeInfo.push_back(
425          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
426
427      if (ET->getClass() == RSExportType::ExportClassRecord) {
428        const RSExportRecordType *ERT =
429            static_cast<const RSExportRecordType*>(ET);
430
431        if (mExportTypeMetadata == NULL)
432          mExportTypeMetadata =
433              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
434
435        mExportTypeMetadata->addOperand(
436            llvm::MDNode::get(mLLVMContext,
437                              ExportTypeInfo.data(),
438                              ExportTypeInfo.size()));
439
440        // Now, export struct field information to %[struct name]
441        std::string StructInfoMetadataName("%");
442        StructInfoMetadataName.append(ET->getName());
443        llvm::NamedMDNode *StructInfoMetadata =
444            M->getOrInsertNamedMetadata(StructInfoMetadataName);
445        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
446
447        slangAssert(StructInfoMetadata->getNumOperands() == 0 &&
448                    "Metadata with same name was created before");
449        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
450                FE = ERT->fields_end();
451             FI != FE;
452             FI++) {
453          const RSExportRecordType::Field *F = *FI;
454
455          // 1. field name
456          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
457                                                  F->getName().c_str()));
458
459          // 2. field type name
460          FieldInfo.push_back(
461              llvm::MDString::get(mLLVMContext,
462                                  F->getType()->getName().c_str()));
463
464          // 3. field kind
465          switch (F->getType()->getClass()) {
466            case RSExportType::ExportClassPrimitive:
467            case RSExportType::ExportClassVector: {
468              const RSExportPrimitiveType *EPT =
469                  static_cast<const RSExportPrimitiveType*>(F->getType());
470              FieldInfo.push_back(
471                  llvm::MDString::get(mLLVMContext,
472                                      llvm::itostr(EPT->getKind())));
473              break;
474            }
475
476            default: {
477              FieldInfo.push_back(
478                  llvm::MDString::get(mLLVMContext,
479                                      llvm::itostr(
480                                        RSExportPrimitiveType::DataKindUser)));
481              break;
482            }
483          }
484
485          StructInfoMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
486                                                           FieldInfo.data(),
487                                                           FieldInfo.size()));
488
489          FieldInfo.clear();
490        }
491      }   // ET->getClass() == RSExportType::ExportClassRecord
492    }
493  }
494
495  return;
496}
497
498RSBackend::~RSBackend() {
499  return;
500}
501
502}  // namespace slang
503