slang_rs_backend.cpp revision 96ab06cbe40b2d73c0eb614f814cd761d8962b6b
/*
 * Copyright 2010, The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#include "slang_rs_backend.h"
18
19#include <string>
20#include <vector>
21
22#include "llvm/ADT/Twine.h"
23#include "llvm/ADT/StringExtras.h"
24
25#include "llvm/Constant.h"
26#include "llvm/Constants.h"
27#include "llvm/DerivedTypes.h"
28#include "llvm/Function.h"
29#include "llvm/Metadata.h"
30#include "llvm/Module.h"
31
32#include "llvm/Support/IRBuilder.h"
33
34#include "slang_rs.h"
35#include "slang_rs_context.h"
36#include "slang_rs_export_func.h"
37#include "slang_rs_export_type.h"
38#include "slang_rs_export_var.h"
39#include "slang_rs_metadata.h"
40
41namespace slang {
42
43RSBackend::RSBackend(RSContext *Context,
44                     clang::Diagnostic *Diags,
45                     const clang::CodeGenOptions &CodeGenOpts,
46                     const clang::TargetOptions &TargetOpts,
47                     const PragmaList &Pragmas,
48                     llvm::raw_ostream *OS,
49                     Slang::OutputType OT,
50                     clang::SourceManager &SourceMgr,
51                     bool AllowRSPrefix)
52    : Backend(Diags,
53              CodeGenOpts,
54              TargetOpts,
55              Pragmas,
56              OS,
57              OT),
58      mContext(Context),
59      mSourceMgr(SourceMgr),
60      mAllowRSPrefix(AllowRSPrefix),
61      mExportVarMetadata(NULL),
62      mExportFuncMetadata(NULL),
63      mExportTypeMetadata(NULL) {
64  return;
65}
66
67// 1) Add zero initialization of local RS object types
68void RSBackend::AnnotateFunction(clang::FunctionDecl *FD) {
69  if (FD &&
70      FD->hasBody() &&
71      !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr)) {
72    mRefCount.Init(mContext->getASTContext());
73    mRefCount.Visit(FD->getBody());
74  }
75  return;
76}
77
78void RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
79  // Disallow user-defined functions with prefix "rs"
80  if (!mAllowRSPrefix) {
81    // Iterate all function declarations in the program.
82    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
83         I != E; I++) {
84      clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
85      if (FD == NULL)
86        continue;
87      if (!FD->getName().startswith("rs"))  // Check prefix
88        continue;
89      if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
90        mDiags.Report(clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
91                      mDiags.getCustomDiagID(clang::Diagnostic::Error,
92                                             "invalid function name prefix, "
93                                             "\"rs\" is reserved: '%0'"))
94            << FD->getName();
95    }
96  }
97
98  // Process any non-static function declarations
99  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
100    AnnotateFunction(dyn_cast<clang::FunctionDecl>(*I));
101  }
102
103  Backend::HandleTopLevelDecl(D);
104  return;
105}
106
107namespace {
108
109bool ValidateVar(clang::VarDecl *VD, clang::Diagnostic *Diags,
110    clang::SourceManager *SM) {
111  llvm::StringRef TypeName;
112  const clang::Type *T = VD->getType().getTypePtr();
113  if (!RSExportType::NormalizeType(T, TypeName, Diags, SM, VD)) {
114    return false;
115  }
116  return true;
117}
118
119bool ValidateASTContext(clang::ASTContext &C, clang::Diagnostic &Diags) {
120  bool valid = true;
121  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
122  for (clang::DeclContext::decl_iterator DI = TUDecl->decls_begin(),
123          DE = TUDecl->decls_end();
124       DI != DE;
125       DI++) {
126    if (DI->getKind() == clang::Decl::Var) {
127      clang::VarDecl *VD = (clang::VarDecl*) (*DI);
128      if (VD->getLinkage() == clang::ExternalLinkage) {
129        if (!ValidateVar(VD, &Diags, &C.getSourceManager())) {
130          valid = false;
131        }
132      }
133    }
134  }
135
136  return valid;
137}
138
139}  // namespace
140
141void RSBackend::HandleTranslationUnitPre(clang::ASTContext &C) {
142  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
143
144  if (!ValidateASTContext(C, mDiags)) {
145    return;
146  }
147
148  int version = mContext->getVersion();
149  if (version == 0) {
150    // Not setting a version is an error
151    mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
152                      "Missing pragma for version in source file"));
153  } else if (version > 1) {
154    mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
155                      "Pragma for version in source file must be set to 1"));
156  }
157
158  // Process any static function declarations
159  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
160          E = TUDecl->decls_end(); I != E; I++) {
161    if ((I->getKind() >= clang::Decl::firstFunction) &&
162        (I->getKind() <= clang::Decl::lastFunction)) {
163      AnnotateFunction(static_cast<clang::FunctionDecl*>(*I));
164    }
165  }
166
167  return;
168}
169
170///////////////////////////////////////////////////////////////////////////////
171void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
172  if (!mContext->processExport()) {
173    mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
174                                         "elements cannot be exported"));
175    return;
176  }
177
178  // Dump export variable info
179  if (mContext->hasExportVar()) {
180    if (mExportVarMetadata == NULL)
181      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
182
183    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
184
185    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
186            E = mContext->export_vars_end();
187         I != E;
188         I++) {
189      const RSExportVar *EV = *I;
190      const RSExportType *ET = EV->getType();
191
192      // Variable name
193      ExportVarInfo.push_back(
194          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
195
196      // Type name
197      switch (ET->getClass()) {
198        case RSExportType::ExportClassPrimitive: {
199          ExportVarInfo.push_back(
200              llvm::MDString::get(
201                mLLVMContext, llvm::utostr_32(
202                  static_cast<const RSExportPrimitiveType*>(ET)->getType())));
203          break;
204        }
205        case RSExportType::ExportClassPointer: {
206          ExportVarInfo.push_back(
207              llvm::MDString::get(
208                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
209                  ->getPointeeType()->getName()).c_str()));
210          break;
211        }
212        case RSExportType::ExportClassMatrix: {
213          ExportVarInfo.push_back(
214              llvm::MDString::get(
215                mLLVMContext, llvm::utostr_32(
216                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
217                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
218          break;
219        }
220        case RSExportType::ExportClassVector:
221        case RSExportType::ExportClassConstantArray:
222        case RSExportType::ExportClassRecord: {
223          ExportVarInfo.push_back(
224              llvm::MDString::get(mLLVMContext,
225                EV->getType()->getName().c_str()));
226          break;
227        }
228      }
229
230      mExportVarMetadata->addOperand(
231          llvm::MDNode::get(mLLVMContext,
232                            ExportVarInfo.data(),
233                            ExportVarInfo.size()) );
234
235      ExportVarInfo.clear();
236    }
237  }
238
239  // Dump export function info
240  if (mContext->hasExportFunc()) {
241    if (mExportFuncMetadata == NULL)
242      mExportFuncMetadata =
243          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
244
245    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
246
247    for (RSContext::const_export_func_iterator
248            I = mContext->export_funcs_begin(),
249            E = mContext->export_funcs_end();
250         I != E;
251         I++) {
252      const RSExportFunc *EF = *I;
253
254      // Function name
255      if (!EF->hasParam()) {
256        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
257                                                     EF->getName().c_str()));
258      } else {
259        llvm::Function *F = M->getFunction(EF->getName());
260        llvm::Function *HelperFunction;
261        const std::string HelperFunctionName(".helper_" + EF->getName());
262
263        assert(F && "Function marked as exported disappeared in Bitcode");
264
265        // Create helper function
266        {
267          llvm::StructType *HelperFunctionParameterTy = NULL;
268
269          if (!F->getArgumentList().empty()) {
270            std::vector<const llvm::Type*> HelperFunctionParameterTys;
271            for (llvm::Function::arg_iterator AI = F->arg_begin(),
272                 AE = F->arg_end(); AI != AE; AI++)
273              HelperFunctionParameterTys.push_back(AI->getType());
274
275            HelperFunctionParameterTy =
276                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
277          }
278
279          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
280            fprintf(stderr, "Failed to export function %s: parameter type "
281                            "mismatch during creation of helper function.\n",
282                    EF->getName().c_str());
283
284            const RSExportRecordType *Expected = EF->getParamPacketType();
285            if (Expected) {
286              fprintf(stderr, "Expected:\n");
287              Expected->getLLVMType()->dump();
288            }
289            if (HelperFunctionParameterTy) {
290              fprintf(stderr, "Got:\n");
291              HelperFunctionParameterTy->dump();
292            }
293          }
294
295          std::vector<const llvm::Type*> Params;
296          if (HelperFunctionParameterTy) {
297            llvm::PointerType *HelperFunctionParameterTyP =
298                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
299            Params.push_back(HelperFunctionParameterTyP);
300          }
301
302          llvm::FunctionType * HelperFunctionType =
303              llvm::FunctionType::get(F->getReturnType(),
304                                      Params,
305                                      /* IsVarArgs = */false);
306
307          HelperFunction =
308              llvm::Function::Create(HelperFunctionType,
309                                     llvm::GlobalValue::ExternalLinkage,
310                                     HelperFunctionName,
311                                     M);
312
313          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
314          HelperFunction->setCallingConv(F->getCallingConv());
315
316          // Create helper function body
317          {
318            llvm::Argument *HelperFunctionParameter =
319                &(*HelperFunction->arg_begin());
320            llvm::BasicBlock *BB =
321                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
322            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
323            llvm::SmallVector<llvm::Value*, 6> Params;
324            llvm::Value *Idx[2];
325
326            Idx[0] =
327                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
328
329            // getelementptr and load instruction for all elements in
330            // parameter .p
331            for (size_t i = 0; i < EF->getNumParameters(); i++) {
332              // getelementptr
333              Idx[1] =
334                  llvm::ConstantInt::get(
335                      llvm::Type::getInt32Ty(mLLVMContext), i);
336              llvm::Value *Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter,
337                                                       Idx,
338                                                       Idx + 2);
339
340              // load
341              llvm::Value *V = IB->CreateLoad(Ptr);
342              Params.push_back(V);
343            }
344
345            // Call and pass the all elements as paramter to F
346            llvm::CallInst *CI = IB->CreateCall(F,
347                                                Params.data(),
348                                                Params.data() + Params.size());
349
350            CI->setCallingConv(F->getCallingConv());
351
352            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
353              IB->CreateRetVoid();
354            else
355              IB->CreateRet(CI);
356
357            delete IB;
358          }
359        }
360
361        ExportFuncInfo.push_back(
362            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
363      }
364
365      mExportFuncMetadata->addOperand(
366          llvm::MDNode::get(mLLVMContext,
367                            ExportFuncInfo.data(),
368                            ExportFuncInfo.size()));
369
370      ExportFuncInfo.clear();
371    }
372  }
373
374  // Dump export type info
375  if (mContext->hasExportType()) {
376    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
377
378    for (RSContext::const_export_type_iterator
379            I = mContext->export_types_begin(),
380            E = mContext->export_types_end();
381         I != E;
382         I++) {
383      // First, dump type name list to export
384      const RSExportType *ET = I->getValue();
385
386      ExportTypeInfo.clear();
387      // Type name
388      ExportTypeInfo.push_back(
389          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
390
391      if (ET->getClass() == RSExportType::ExportClassRecord) {
392        const RSExportRecordType *ERT =
393            static_cast<const RSExportRecordType*>(ET);
394
395        if (mExportTypeMetadata == NULL)
396          mExportTypeMetadata =
397              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
398
399        mExportTypeMetadata->addOperand(
400            llvm::MDNode::get(mLLVMContext,
401                              ExportTypeInfo.data(),
402                              ExportTypeInfo.size()));
403
404        // Now, export struct field information to %[struct name]
405        std::string StructInfoMetadataName("%");
406        StructInfoMetadataName.append(ET->getName());
407        llvm::NamedMDNode *StructInfoMetadata =
408            M->getOrInsertNamedMetadata(StructInfoMetadataName);
409        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
410
411        assert(StructInfoMetadata->getNumOperands() == 0 &&
412               "Metadata with same name was created before");
413        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
414                FE = ERT->fields_end();
415             FI != FE;
416             FI++) {
417          const RSExportRecordType::Field *F = *FI;
418
419          // 1. field name
420          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
421                                                  F->getName().c_str()));
422
423          // 2. field type name
424          FieldInfo.push_back(
425              llvm::MDString::get(mLLVMContext,
426                                  F->getType()->getName().c_str()));
427
428          // 3. field kind
429          switch (F->getType()->getClass()) {
430            case RSExportType::ExportClassPrimitive:
431            case RSExportType::ExportClassVector: {
432              const RSExportPrimitiveType *EPT =
433                  static_cast<const RSExportPrimitiveType*>(F->getType());
434              FieldInfo.push_back(
435                  llvm::MDString::get(mLLVMContext,
436                                      llvm::itostr(EPT->getKind())));
437              break;
438            }
439
440            default: {
441              FieldInfo.push_back(
442                  llvm::MDString::get(mLLVMContext,
443                                      llvm::itostr(
444                                        RSExportPrimitiveType::DataKindUser)));
445              break;
446            }
447          }
448
449          StructInfoMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
450                                                           FieldInfo.data(),
451                                                           FieldInfo.size()));
452
453          FieldInfo.clear();
454        }
455      }   // ET->getClass() == RSExportType::ExportClassRecord
456    }
457  }
458
459  return;
460}
461
462RSBackend::~RSBackend() {
463  return;
464}
465
466}  // namespace slang
467