slang_rs_backend.cpp revision c97a333bc84ce8c28c96d07734cbded75c914639
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
19#include <string>
20#include <vector>
21
22#include "llvm/ADT/Twine.h"
23#include "llvm/ADT/StringExtras.h"
24
25#include "llvm/Constant.h"
26#include "llvm/Constants.h"
27#include "llvm/DerivedTypes.h"
28#include "llvm/Function.h"
29#include "llvm/Metadata.h"
30#include "llvm/Module.h"
31
32#include "llvm/Support/IRBuilder.h"
33
34#include "slang_rs.h"
35#include "slang_rs_context.h"
36#include "slang_rs_export_func.h"
37#include "slang_rs_export_type.h"
38#include "slang_rs_export_var.h"
39#include "slang_rs_metadata.h"
40
41namespace slang {
42
43RSBackend::RSBackend(RSContext *Context,
44                     clang::Diagnostic *Diags,
45                     const clang::CodeGenOptions &CodeGenOpts,
46                     const clang::TargetOptions &TargetOpts,
47                     const PragmaList &Pragmas,
48                     llvm::raw_ostream *OS,
49                     Slang::OutputType OT,
50                     clang::SourceManager &SourceMgr,
51                     bool AllowRSPrefix)
52    : Backend(Diags,
53              CodeGenOpts,
54              TargetOpts,
55              Pragmas,
56              OS,
57              OT),
58      mContext(Context),
59      mSourceMgr(SourceMgr),
60      mAllowRSPrefix(AllowRSPrefix),
61      mExportVarMetadata(NULL),
62      mExportFuncMetadata(NULL),
63      mExportTypeMetadata(NULL) {
64  return;
65}
66
67// 1) Add zero initialization of local RS object types
68void RSBackend::AnnotateFunction(clang::FunctionDecl *FD) {
69  if (FD &&
70      FD->hasBody() &&
71      !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr)) {
72    mRefCount.Init(mContext->getASTContext());
73    mRefCount.Visit(FD->getBody());
74  }
75  return;
76}
77
78void RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
79  // Disallow user-defined functions with prefix "rs"
80  if (!mAllowRSPrefix) {
81    // Iterate all function declarations in the program.
82    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
83         I != E; I++) {
84      clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
85      if (FD == NULL)
86        continue;
87      if (!FD->getName().startswith("rs"))  // Check prefix
88        continue;
89      if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
90        mDiags.Report(clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
91                      mDiags.getCustomDiagID(clang::Diagnostic::Error,
92                                             "invalid function name prefix, "
93                                             "\"rs\" is reserved: '%0'"))
94            << FD->getName();
95    }
96  }
97
98  // Process any non-static function declarations
99  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
100    AnnotateFunction(dyn_cast<clang::FunctionDecl>(*I));
101  }
102
103  Backend::HandleTopLevelDecl(D);
104  return;
105}
106
107namespace {
108
109bool ValidateVar(clang::VarDecl *VD) {
110  llvm::StringRef TypeName;
111  const clang::Type *T = VD->getType().getTypePtr();
112  if (!RSExportType::NormalizeType(T, TypeName)) {
113    return false;
114  }
115  return true;
116}
117
118bool ValidateASTContext(clang::ASTContext &C, clang::Diagnostic &Diags) {
119  bool valid = true;
120  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
121  for (clang::DeclContext::decl_iterator DI = TUDecl->decls_begin(),
122          DE = TUDecl->decls_end();
123       DI != DE;
124       DI++) {
125    if (DI->getKind() == clang::Decl::Var) {
126      clang::VarDecl *VD = (clang::VarDecl*) (*DI);
127      if (VD->getLinkage() == clang::ExternalLinkage) {
128        if (!ValidateVar(VD)) {
129          valid = false;
130          VD = VD->getCanonicalDecl();
131          Diags.Report(clang::FullSourceLoc(VD->getLocation(),
132                                            C.getSourceManager()),
133                       Diags.getCustomDiagID(clang::Diagnostic::Error,
134                                             "variable cannot be "
135                                             "exported: '%0'"))
136              << VD->getName();
137        }
138      }
139    }
140  }
141
142  return valid;
143}
144
145}  // end namespace
146
147void RSBackend::HandleTranslationUnitPre(clang::ASTContext &C) {
148  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
149
150  if (!ValidateASTContext(C, mDiags)) {
151    return;
152  }
153
154  // Process any static function declarations
155  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
156          E = TUDecl->decls_end(); I != E; I++) {
157    if ((I->getKind() >= clang::Decl::firstFunction) &&
158        (I->getKind() <= clang::Decl::lastFunction)) {
159      AnnotateFunction(static_cast<clang::FunctionDecl*>(*I));
160    }
161  }
162
163  return;
164}
165
166///////////////////////////////////////////////////////////////////////////////
167void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
168  if (!mContext->processExport()) {
169    mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
170                                         "elements cannot be exported"));
171    return;
172  }
173
174  // Dump export variable info
175  if (mContext->hasExportVar()) {
176    if (mExportVarMetadata == NULL)
177      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
178
179    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
180
181    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
182            E = mContext->export_vars_end();
183         I != E;
184         I++) {
185      const RSExportVar *EV = *I;
186      const RSExportType *ET = EV->getType();
187
188      // Variable name
189      ExportVarInfo.push_back(
190          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
191
192      // Type name
193      switch (ET->getClass()) {
194        case RSExportType::ExportClassPrimitive: {
195          ExportVarInfo.push_back(
196              llvm::MDString::get(
197                mLLVMContext, llvm::utostr_32(
198                  static_cast<const RSExportPrimitiveType*>(ET)->getType())));
199          break;
200        }
201        case RSExportType::ExportClassPointer: {
202          ExportVarInfo.push_back(
203              llvm::MDString::get(
204                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
205                  ->getPointeeType()->getName()).c_str()));
206          break;
207        }
208        case RSExportType::ExportClassMatrix: {
209          ExportVarInfo.push_back(
210              llvm::MDString::get(
211                mLLVMContext, llvm::utostr_32(
212                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
213                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
214          break;
215        }
216        case RSExportType::ExportClassVector:
217        case RSExportType::ExportClassConstantArray:
218        case RSExportType::ExportClassRecord: {
219          ExportVarInfo.push_back(
220              llvm::MDString::get(mLLVMContext,
221                EV->getType()->getName().c_str()));
222          break;
223        }
224      }
225
226      mExportVarMetadata->addOperand(
227          llvm::MDNode::get(mLLVMContext,
228                            ExportVarInfo.data(),
229                            ExportVarInfo.size()) );
230
231      ExportVarInfo.clear();
232    }
233  }
234
235  // Dump export function info
236  if (mContext->hasExportFunc()) {
237    if (mExportFuncMetadata == NULL)
238      mExportFuncMetadata =
239          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
240
241    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
242
243    for (RSContext::const_export_func_iterator
244            I = mContext->export_funcs_begin(),
245            E = mContext->export_funcs_end();
246         I != E;
247         I++) {
248      const RSExportFunc *EF = *I;
249
250      // Function name
251      if (!EF->hasParam()) {
252        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
253                                                     EF->getName().c_str()));
254      } else {
255        llvm::Function *F = M->getFunction(EF->getName());
256        llvm::Function *HelperFunction;
257        const std::string HelperFunctionName(".helper_" + EF->getName());
258
259        assert(F && "Function marked as exported disappeared in Bitcode");
260
261        // Create helper function
262        {
263          llvm::StructType *HelperFunctionParameterTy = NULL;
264
265          if (!F->getArgumentList().empty()) {
266            std::vector<const llvm::Type*> HelperFunctionParameterTys;
267            for (llvm::Function::arg_iterator AI = F->arg_begin(),
268                 AE = F->arg_end(); AI != AE; AI++)
269              HelperFunctionParameterTys.push_back(AI->getType());
270
271            HelperFunctionParameterTy =
272                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
273          }
274
275          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
276            fprintf(stderr, "Failed to export function %s: parameter type "
277                            "mismatch during creation of helper function.\n",
278                    EF->getName().c_str());
279
280            const RSExportRecordType *Expected = EF->getParamPacketType();
281            if (Expected) {
282              fprintf(stderr, "Expected:\n");
283              Expected->getLLVMType()->dump();
284            }
285            if (HelperFunctionParameterTy) {
286              fprintf(stderr, "Got:\n");
287              HelperFunctionParameterTy->dump();
288            }
289          }
290
291          std::vector<const llvm::Type*> Params;
292          if (HelperFunctionParameterTy) {
293            llvm::PointerType *HelperFunctionParameterTyP =
294                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
295            Params.push_back(HelperFunctionParameterTyP);
296          }
297
298          llvm::FunctionType * HelperFunctionType =
299              llvm::FunctionType::get(F->getReturnType(),
300                                      Params,
301                                      /* IsVarArgs = */false);
302
303          HelperFunction =
304              llvm::Function::Create(HelperFunctionType,
305                                     llvm::GlobalValue::ExternalLinkage,
306                                     HelperFunctionName,
307                                     M);
308
309          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
310          HelperFunction->setCallingConv(F->getCallingConv());
311
312          // Create helper function body
313          {
314            llvm::Argument *HelperFunctionParameter =
315                &(*HelperFunction->arg_begin());
316            llvm::BasicBlock *BB =
317                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
318            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
319            llvm::SmallVector<llvm::Value*, 6> Params;
320            llvm::Value *Idx[2];
321
322            Idx[0] =
323                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
324
325            // getelementptr and load instruction for all elements in
326            // parameter .p
327            for (size_t i = 0; i < EF->getNumParameters(); i++) {
328              // getelementptr
329              Idx[1] =
330                  llvm::ConstantInt::get(
331                      llvm::Type::getInt32Ty(mLLVMContext), i);
332              llvm::Value *Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter,
333                                                       Idx,
334                                                       Idx + 2);
335
336              // load
337              llvm::Value *V = IB->CreateLoad(Ptr);
338              Params.push_back(V);
339            }
340
341            // Call and pass the all elements as paramter to F
342            llvm::CallInst *CI = IB->CreateCall(F,
343                                                Params.data(),
344                                                Params.data() + Params.size());
345
346            CI->setCallingConv(F->getCallingConv());
347
348            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
349              IB->CreateRetVoid();
350            else
351              IB->CreateRet(CI);
352
353            delete IB;
354          }
355        }
356
357        ExportFuncInfo.push_back(
358            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
359      }
360
361      mExportFuncMetadata->addOperand(
362          llvm::MDNode::get(mLLVMContext,
363                            ExportFuncInfo.data(),
364                            ExportFuncInfo.size()));
365
366      ExportFuncInfo.clear();
367    }
368  }
369
370  // Dump export type info
371  if (mContext->hasExportType()) {
372    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
373
374    for (RSContext::const_export_type_iterator
375            I = mContext->export_types_begin(),
376            E = mContext->export_types_end();
377         I != E;
378         I++) {
379      // First, dump type name list to export
380      const RSExportType *ET = I->getValue();
381
382      ExportTypeInfo.clear();
383      // Type name
384      ExportTypeInfo.push_back(
385          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
386
387      if (ET->getClass() == RSExportType::ExportClassRecord) {
388        const RSExportRecordType *ERT =
389            static_cast<const RSExportRecordType*>(ET);
390
391        if (mExportTypeMetadata == NULL)
392          mExportTypeMetadata =
393              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
394
395        mExportTypeMetadata->addOperand(
396            llvm::MDNode::get(mLLVMContext,
397                              ExportTypeInfo.data(),
398                              ExportTypeInfo.size()));
399
400        // Now, export struct field information to %[struct name]
401        std::string StructInfoMetadataName("%");
402        StructInfoMetadataName.append(ET->getName());
403        llvm::NamedMDNode *StructInfoMetadata =
404            M->getOrInsertNamedMetadata(StructInfoMetadataName);
405        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
406
407        assert(StructInfoMetadata->getNumOperands() == 0 &&
408               "Metadata with same name was created before");
409        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
410                FE = ERT->fields_end();
411             FI != FE;
412             FI++) {
413          const RSExportRecordType::Field *F = *FI;
414
415          // 1. field name
416          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
417                                                  F->getName().c_str()));
418
419          // 2. field type name
420          FieldInfo.push_back(
421              llvm::MDString::get(mLLVMContext,
422                                  F->getType()->getName().c_str()));
423
424          // 3. field kind
425          switch (F->getType()->getClass()) {
426            case RSExportType::ExportClassPrimitive:
427            case RSExportType::ExportClassVector: {
428              const RSExportPrimitiveType *EPT =
429                  static_cast<const RSExportPrimitiveType*>(F->getType());
430              FieldInfo.push_back(
431                  llvm::MDString::get(mLLVMContext,
432                                      llvm::itostr(EPT->getKind())));
433              break;
434            }
435
436            default: {
437              FieldInfo.push_back(
438                  llvm::MDString::get(mLLVMContext,
439                                      llvm::itostr(
440                                        RSExportPrimitiveType::DataKindUser)));
441              break;
442            }
443          }
444
445          StructInfoMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
446                                                           FieldInfo.data(),
447                                                           FieldInfo.size()));
448
449          FieldInfo.clear();
450        }
451      }   // ET->getClass() == RSExportType::ExportClassRecord
452    }
453  }
454
455  return;
456}
457
458RSBackend::~RSBackend() {
459  return;
460}
461
462}  // namespace slang
463