slang_rs_backend.cpp revision 3fd0a94a5cf1656569b1aea07043cc63939dcb46
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
19#include <string>
20#include <vector>
21
22#include "llvm/ADT/Twine.h"
23#include "llvm/ADT/StringExtras.h"
24
25#include "llvm/Constant.h"
26#include "llvm/Constants.h"
27#include "llvm/DerivedTypes.h"
28#include "llvm/Function.h"
29#include "llvm/Metadata.h"
30#include "llvm/Module.h"
31
32#include "llvm/Support/IRBuilder.h"
33
34#include "slang_rs.h"
35#include "slang_rs_context.h"
36#include "slang_rs_export_func.h"
37#include "slang_rs_export_type.h"
38#include "slang_rs_export_var.h"
39#include "slang_rs_metadata.h"
40
41namespace slang {
42
43RSBackend::RSBackend(RSContext *Context,
44                     clang::Diagnostic *Diags,
45                     const clang::CodeGenOptions &CodeGenOpts,
46                     const clang::TargetOptions &TargetOpts,
47                     PragmaList *Pragmas,
48                     llvm::raw_ostream *OS,
49                     Slang::OutputType OT,
50                     clang::SourceManager &SourceMgr,
51                     bool AllowRSPrefix)
52    : Backend(Diags,
53              CodeGenOpts,
54              TargetOpts,
55              Pragmas,
56              OS,
57              OT),
58      mContext(Context),
59      mSourceMgr(SourceMgr),
60      mAllowRSPrefix(AllowRSPrefix),
61      mExportVarMetadata(NULL),
62      mExportFuncMetadata(NULL),
63      mExportTypeMetadata(NULL) {
64  return;
65}
66
67// 1) Add zero initialization of local RS object types
68void RSBackend::AnnotateFunction(clang::FunctionDecl *FD) {
69  if (FD &&
70      FD->hasBody() &&
71      !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr)) {
72    mRefCount.Init(mContext->getASTContext());
73    mRefCount.Visit(FD->getBody());
74  }
75  return;
76}
77
78void RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
79  // Disallow user-defined functions with prefix "rs"
80  if (!mAllowRSPrefix) {
81    // Iterate all function declarations in the program.
82    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
83         I != E; I++) {
84      clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
85      if (FD == NULL)
86        continue;
87      if (!FD->getName().startswith("rs"))  // Check prefix
88        continue;
89      if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
90        mDiags.Report(clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
91                      mDiags.getCustomDiagID(clang::Diagnostic::Error,
92                                             "invalid function name prefix, "
93                                             "\"rs\" is reserved: '%0'"))
94            << FD->getName();
95    }
96  }
97
98  // Process any non-static function declarations
99  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
100    clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
101    if (FD && FD->isGlobal()) {
102      AnnotateFunction(FD);
103    }
104  }
105
106  Backend::HandleTopLevelDecl(D);
107  return;
108}
109
110namespace {
111
112bool ValidateVar(clang::VarDecl *VD, clang::Diagnostic *Diags,
113    clang::SourceManager *SM) {
114  llvm::StringRef TypeName;
115  const clang::Type *T = VD->getType().getTypePtr();
116  if (!RSExportType::NormalizeType(T, TypeName, Diags, SM, VD)) {
117    return false;
118  }
119  return true;
120}
121
122bool ValidateASTContext(clang::ASTContext &C, clang::Diagnostic &Diags) {
123  bool valid = true;
124  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
125  for (clang::DeclContext::decl_iterator DI = TUDecl->decls_begin(),
126          DE = TUDecl->decls_end();
127       DI != DE;
128       DI++) {
129    if (DI->getKind() == clang::Decl::Var) {
130      clang::VarDecl *VD = (clang::VarDecl*) (*DI);
131      if (VD->getLinkage() == clang::ExternalLinkage) {
132        if (!ValidateVar(VD, &Diags, &C.getSourceManager())) {
133          valid = false;
134        }
135      }
136    }
137  }
138
139  return valid;
140}
141
142}  // namespace
143
144void RSBackend::HandleTranslationUnitPre(clang::ASTContext &C) {
145  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
146
147  if (!ValidateASTContext(C, mDiags)) {
148    return;
149  }
150
151  int version = mContext->getVersion();
152  if (version == 0) {
153    // Not setting a version is an error
154    mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
155                      "Missing pragma for version in source file"));
156  } else if (version > 1) {
157    mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
158                      "Pragma for version in source file must be set to 1"));
159  }
160
161  // Process any static function declarations
162  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
163          E = TUDecl->decls_end(); I != E; I++) {
164    if ((I->getKind() >= clang::Decl::firstFunction) &&
165        (I->getKind() <= clang::Decl::lastFunction)) {
166      clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
167      if (FD && !FD->isGlobal()) {
168        AnnotateFunction(FD);
169      }
170    }
171  }
172
173  return;
174}
175
176///////////////////////////////////////////////////////////////////////////////
177void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
178  if (!mContext->processExport()) {
179    mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
180                                         "elements cannot be exported"));
181    return;
182  }
183
184  // Dump export variable info
185  if (mContext->hasExportVar()) {
186    if (mExportVarMetadata == NULL)
187      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
188
189    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
190
191    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
192            E = mContext->export_vars_end();
193         I != E;
194         I++) {
195      const RSExportVar *EV = *I;
196      const RSExportType *ET = EV->getType();
197
198      // Variable name
199      ExportVarInfo.push_back(
200          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
201
202      // Type name
203      switch (ET->getClass()) {
204        case RSExportType::ExportClassPrimitive: {
205          ExportVarInfo.push_back(
206              llvm::MDString::get(
207                mLLVMContext, llvm::utostr_32(
208                  static_cast<const RSExportPrimitiveType*>(ET)->getType())));
209          break;
210        }
211        case RSExportType::ExportClassPointer: {
212          ExportVarInfo.push_back(
213              llvm::MDString::get(
214                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
215                  ->getPointeeType()->getName()).c_str()));
216          break;
217        }
218        case RSExportType::ExportClassMatrix: {
219          ExportVarInfo.push_back(
220              llvm::MDString::get(
221                mLLVMContext, llvm::utostr_32(
222                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
223                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
224          break;
225        }
226        case RSExportType::ExportClassVector:
227        case RSExportType::ExportClassConstantArray:
228        case RSExportType::ExportClassRecord: {
229          ExportVarInfo.push_back(
230              llvm::MDString::get(mLLVMContext,
231                EV->getType()->getName().c_str()));
232          break;
233        }
234      }
235
236      mExportVarMetadata->addOperand(
237          llvm::MDNode::get(mLLVMContext,
238                            ExportVarInfo.data(),
239                            ExportVarInfo.size()) );
240
241      ExportVarInfo.clear();
242    }
243  }
244
245  // Dump export function info
246  if (mContext->hasExportFunc()) {
247    if (mExportFuncMetadata == NULL)
248      mExportFuncMetadata =
249          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
250
251    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
252
253    for (RSContext::const_export_func_iterator
254            I = mContext->export_funcs_begin(),
255            E = mContext->export_funcs_end();
256         I != E;
257         I++) {
258      const RSExportFunc *EF = *I;
259
260      // Function name
261      if (!EF->hasParam()) {
262        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
263                                                     EF->getName().c_str()));
264      } else {
265        llvm::Function *F = M->getFunction(EF->getName());
266        llvm::Function *HelperFunction;
267        const std::string HelperFunctionName(".helper_" + EF->getName());
268
269        assert(F && "Function marked as exported disappeared in Bitcode");
270
271        // Create helper function
272        {
273          llvm::StructType *HelperFunctionParameterTy = NULL;
274
275          if (!F->getArgumentList().empty()) {
276            std::vector<const llvm::Type*> HelperFunctionParameterTys;
277            for (llvm::Function::arg_iterator AI = F->arg_begin(),
278                 AE = F->arg_end(); AI != AE; AI++)
279              HelperFunctionParameterTys.push_back(AI->getType());
280
281            HelperFunctionParameterTy =
282                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
283          }
284
285          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
286            fprintf(stderr, "Failed to export function %s: parameter type "
287                            "mismatch during creation of helper function.\n",
288                    EF->getName().c_str());
289
290            const RSExportRecordType *Expected = EF->getParamPacketType();
291            if (Expected) {
292              fprintf(stderr, "Expected:\n");
293              Expected->getLLVMType()->dump();
294            }
295            if (HelperFunctionParameterTy) {
296              fprintf(stderr, "Got:\n");
297              HelperFunctionParameterTy->dump();
298            }
299          }
300
301          std::vector<const llvm::Type*> Params;
302          if (HelperFunctionParameterTy) {
303            llvm::PointerType *HelperFunctionParameterTyP =
304                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
305            Params.push_back(HelperFunctionParameterTyP);
306          }
307
308          llvm::FunctionType * HelperFunctionType =
309              llvm::FunctionType::get(F->getReturnType(),
310                                      Params,
311                                      /* IsVarArgs = */false);
312
313          HelperFunction =
314              llvm::Function::Create(HelperFunctionType,
315                                     llvm::GlobalValue::ExternalLinkage,
316                                     HelperFunctionName,
317                                     M);
318
319          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
320          HelperFunction->setCallingConv(F->getCallingConv());
321
322          // Create helper function body
323          {
324            llvm::Argument *HelperFunctionParameter =
325                &(*HelperFunction->arg_begin());
326            llvm::BasicBlock *BB =
327                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
328            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
329            llvm::SmallVector<llvm::Value*, 6> Params;
330            llvm::Value *Idx[2];
331
332            Idx[0] =
333                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
334
335            // getelementptr and load instruction for all elements in
336            // parameter .p
337            for (size_t i = 0; i < EF->getNumParameters(); i++) {
338              // getelementptr
339              Idx[1] =
340                  llvm::ConstantInt::get(
341                      llvm::Type::getInt32Ty(mLLVMContext), i);
342              llvm::Value *Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter,
343                                                       Idx,
344                                                       Idx + 2);
345
346              // load
347              llvm::Value *V = IB->CreateLoad(Ptr);
348              Params.push_back(V);
349            }
350
351            // Call and pass the all elements as paramter to F
352            llvm::CallInst *CI = IB->CreateCall(F,
353                                                Params.data(),
354                                                Params.data() + Params.size());
355
356            CI->setCallingConv(F->getCallingConv());
357
358            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
359              IB->CreateRetVoid();
360            else
361              IB->CreateRet(CI);
362
363            delete IB;
364          }
365        }
366
367        ExportFuncInfo.push_back(
368            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
369      }
370
371      mExportFuncMetadata->addOperand(
372          llvm::MDNode::get(mLLVMContext,
373                            ExportFuncInfo.data(),
374                            ExportFuncInfo.size()));
375
376      ExportFuncInfo.clear();
377    }
378  }
379
380  // Dump export type info
381  if (mContext->hasExportType()) {
382    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
383
384    for (RSContext::const_export_type_iterator
385            I = mContext->export_types_begin(),
386            E = mContext->export_types_end();
387         I != E;
388         I++) {
389      // First, dump type name list to export
390      const RSExportType *ET = I->getValue();
391
392      ExportTypeInfo.clear();
393      // Type name
394      ExportTypeInfo.push_back(
395          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
396
397      if (ET->getClass() == RSExportType::ExportClassRecord) {
398        const RSExportRecordType *ERT =
399            static_cast<const RSExportRecordType*>(ET);
400
401        if (mExportTypeMetadata == NULL)
402          mExportTypeMetadata =
403              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
404
405        mExportTypeMetadata->addOperand(
406            llvm::MDNode::get(mLLVMContext,
407                              ExportTypeInfo.data(),
408                              ExportTypeInfo.size()));
409
410        // Now, export struct field information to %[struct name]
411        std::string StructInfoMetadataName("%");
412        StructInfoMetadataName.append(ET->getName());
413        llvm::NamedMDNode *StructInfoMetadata =
414            M->getOrInsertNamedMetadata(StructInfoMetadataName);
415        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
416
417        assert(StructInfoMetadata->getNumOperands() == 0 &&
418               "Metadata with same name was created before");
419        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
420                FE = ERT->fields_end();
421             FI != FE;
422             FI++) {
423          const RSExportRecordType::Field *F = *FI;
424
425          // 1. field name
426          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
427                                                  F->getName().c_str()));
428
429          // 2. field type name
430          FieldInfo.push_back(
431              llvm::MDString::get(mLLVMContext,
432                                  F->getType()->getName().c_str()));
433
434          // 3. field kind
435          switch (F->getType()->getClass()) {
436            case RSExportType::ExportClassPrimitive:
437            case RSExportType::ExportClassVector: {
438              const RSExportPrimitiveType *EPT =
439                  static_cast<const RSExportPrimitiveType*>(F->getType());
440              FieldInfo.push_back(
441                  llvm::MDString::get(mLLVMContext,
442                                      llvm::itostr(EPT->getKind())));
443              break;
444            }
445
446            default: {
447              FieldInfo.push_back(
448                  llvm::MDString::get(mLLVMContext,
449                                      llvm::itostr(
450                                        RSExportPrimitiveType::DataKindUser)));
451              break;
452            }
453          }
454
455          StructInfoMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
456                                                           FieldInfo.data(),
457                                                           FieldInfo.size()));
458
459          FieldInfo.clear();
460        }
461      }   // ET->getClass() == RSExportType::ExportClassRecord
462    }
463  }
464
465  return;
466}
467
468RSBackend::~RSBackend() {
469  return;
470}
471
472}  // namespace slang
473