slang_rs_backend.cpp revision dd6206bb61bf8df2ed6b643abe8a29c48a315685
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
19#include <string>
20#include <vector>
21
22#include "llvm/ADT/Twine.h"
23#include "llvm/ADT/StringExtras.h"
24
25#include "llvm/Constant.h"
26#include "llvm/Constants.h"
27#include "llvm/DerivedTypes.h"
28#include "llvm/Function.h"
29#include "llvm/Metadata.h"
30#include "llvm/Module.h"
31
32#include "llvm/Support/IRBuilder.h"
33
34#include "slang_rs.h"
35#include "slang_rs_context.h"
36#include "slang_rs_export_func.h"
37#include "slang_rs_export_type.h"
38#include "slang_rs_export_var.h"
39#include "slang_rs_metadata.h"
40
41namespace slang {
42
43RSBackend::RSBackend(RSContext *Context,
44                     clang::Diagnostic *Diags,
45                     const clang::CodeGenOptions &CodeGenOpts,
46                     const clang::TargetOptions &TargetOpts,
47                     const PragmaList &Pragmas,
48                     llvm::raw_ostream *OS,
49                     Slang::OutputType OT,
50                     clang::SourceManager &SourceMgr,
51                     bool AllowRSPrefix)
52    : Backend(Diags,
53              CodeGenOpts,
54              TargetOpts,
55              Pragmas,
56              OS,
57              OT),
58      mContext(Context),
59      mSourceMgr(SourceMgr),
60      mAllowRSPrefix(AllowRSPrefix),
61      mExportVarMetadata(NULL),
62      mExportFuncMetadata(NULL),
63      mExportTypeMetadata(NULL) {
64  return;
65}
66
67// 1) Add zero initialization of local RS object types
68void RSBackend::AnnotateFunction(clang::FunctionDecl *FD) {
69  if (FD &&
70      FD->hasBody() &&
71      !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr)) {
72    mRefCount.Init(mContext->getASTContext());
73    mRefCount.Visit(FD->getBody());
74  }
75  return;
76}
77
78void RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
79  // Disallow user-defined functions with prefix "rs"
80  if (!mAllowRSPrefix) {
81    // Iterate all function declarations in the program.
82    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
83         I != E; I++) {
84      clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
85      if (FD == NULL)
86        continue;
87      if (!FD->getName().startswith("rs"))  // Check prefix
88        continue;
89      if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
90        mDiags.Report(clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
91                      mDiags.getCustomDiagID(clang::Diagnostic::Error,
92                                             "invalid function name prefix, "
93                                             "\"rs\" is reserved: '%0'"))
94            << FD->getName();
95    }
96  }
97
98  // Process any non-static function declarations
99  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
100    AnnotateFunction(dyn_cast<clang::FunctionDecl>(*I));
101  }
102
103  Backend::HandleTopLevelDecl(D);
104  return;
105}
106
107namespace {
108
109bool ValidateVar(clang::VarDecl *VD, clang::Diagnostic *Diags,
110    clang::SourceManager *SM) {
111  llvm::StringRef TypeName;
112  const clang::Type *T = VD->getType().getTypePtr();
113  if (!RSExportType::NormalizeType(T, TypeName, Diags, SM, VD)) {
114    return false;
115  }
116  return true;
117}
118
119bool ValidateASTContext(clang::ASTContext &C, clang::Diagnostic &Diags) {
120  bool valid = true;
121  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
122  for (clang::DeclContext::decl_iterator DI = TUDecl->decls_begin(),
123          DE = TUDecl->decls_end();
124       DI != DE;
125       DI++) {
126    if (DI->getKind() == clang::Decl::Var) {
127      clang::VarDecl *VD = (clang::VarDecl*) (*DI);
128      if (VD->getLinkage() == clang::ExternalLinkage) {
129        if (!ValidateVar(VD, &Diags, &C.getSourceManager())) {
130          valid = false;
131        }
132      }
133    }
134  }
135
136  return valid;
137}
138
139}  // namespace
140
141void RSBackend::HandleTranslationUnitPre(clang::ASTContext &C) {
142  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
143
144  if (!ValidateASTContext(C, mDiags)) {
145    return;
146  }
147
148  // Process any static function declarations
149  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
150          E = TUDecl->decls_end(); I != E; I++) {
151    if ((I->getKind() >= clang::Decl::firstFunction) &&
152        (I->getKind() <= clang::Decl::lastFunction)) {
153      AnnotateFunction(static_cast<clang::FunctionDecl*>(*I));
154    }
155  }
156
157  return;
158}
159
160///////////////////////////////////////////////////////////////////////////////
161void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
162  if (!mContext->processExport()) {
163    mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
164                                         "elements cannot be exported"));
165    return;
166  }
167
168  // Dump export variable info
169  if (mContext->hasExportVar()) {
170    if (mExportVarMetadata == NULL)
171      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
172
173    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
174
175    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
176            E = mContext->export_vars_end();
177         I != E;
178         I++) {
179      const RSExportVar *EV = *I;
180      const RSExportType *ET = EV->getType();
181
182      // Variable name
183      ExportVarInfo.push_back(
184          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
185
186      // Type name
187      switch (ET->getClass()) {
188        case RSExportType::ExportClassPrimitive: {
189          ExportVarInfo.push_back(
190              llvm::MDString::get(
191                mLLVMContext, llvm::utostr_32(
192                  static_cast<const RSExportPrimitiveType*>(ET)->getType())));
193          break;
194        }
195        case RSExportType::ExportClassPointer: {
196          ExportVarInfo.push_back(
197              llvm::MDString::get(
198                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
199                  ->getPointeeType()->getName()).c_str()));
200          break;
201        }
202        case RSExportType::ExportClassMatrix: {
203          ExportVarInfo.push_back(
204              llvm::MDString::get(
205                mLLVMContext, llvm::utostr_32(
206                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
207                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
208          break;
209        }
210        case RSExportType::ExportClassVector:
211        case RSExportType::ExportClassConstantArray:
212        case RSExportType::ExportClassRecord: {
213          ExportVarInfo.push_back(
214              llvm::MDString::get(mLLVMContext,
215                EV->getType()->getName().c_str()));
216          break;
217        }
218      }
219
220      mExportVarMetadata->addOperand(
221          llvm::MDNode::get(mLLVMContext,
222                            ExportVarInfo.data(),
223                            ExportVarInfo.size()) );
224
225      ExportVarInfo.clear();
226    }
227  }
228
229  // Dump export function info
230  if (mContext->hasExportFunc()) {
231    if (mExportFuncMetadata == NULL)
232      mExportFuncMetadata =
233          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
234
235    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
236
237    for (RSContext::const_export_func_iterator
238            I = mContext->export_funcs_begin(),
239            E = mContext->export_funcs_end();
240         I != E;
241         I++) {
242      const RSExportFunc *EF = *I;
243
244      // Function name
245      if (!EF->hasParam()) {
246        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
247                                                     EF->getName().c_str()));
248      } else {
249        llvm::Function *F = M->getFunction(EF->getName());
250        llvm::Function *HelperFunction;
251        const std::string HelperFunctionName(".helper_" + EF->getName());
252
253        assert(F && "Function marked as exported disappeared in Bitcode");
254
255        // Create helper function
256        {
257          llvm::StructType *HelperFunctionParameterTy = NULL;
258
259          if (!F->getArgumentList().empty()) {
260            std::vector<const llvm::Type*> HelperFunctionParameterTys;
261            for (llvm::Function::arg_iterator AI = F->arg_begin(),
262                 AE = F->arg_end(); AI != AE; AI++)
263              HelperFunctionParameterTys.push_back(AI->getType());
264
265            HelperFunctionParameterTy =
266                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
267          }
268
269          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
270            fprintf(stderr, "Failed to export function %s: parameter type "
271                            "mismatch during creation of helper function.\n",
272                    EF->getName().c_str());
273
274            const RSExportRecordType *Expected = EF->getParamPacketType();
275            if (Expected) {
276              fprintf(stderr, "Expected:\n");
277              Expected->getLLVMType()->dump();
278            }
279            if (HelperFunctionParameterTy) {
280              fprintf(stderr, "Got:\n");
281              HelperFunctionParameterTy->dump();
282            }
283          }
284
285          std::vector<const llvm::Type*> Params;
286          if (HelperFunctionParameterTy) {
287            llvm::PointerType *HelperFunctionParameterTyP =
288                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
289            Params.push_back(HelperFunctionParameterTyP);
290          }
291
292          llvm::FunctionType * HelperFunctionType =
293              llvm::FunctionType::get(F->getReturnType(),
294                                      Params,
295                                      /* IsVarArgs = */false);
296
297          HelperFunction =
298              llvm::Function::Create(HelperFunctionType,
299                                     llvm::GlobalValue::ExternalLinkage,
300                                     HelperFunctionName,
301                                     M);
302
303          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
304          HelperFunction->setCallingConv(F->getCallingConv());
305
306          // Create helper function body
307          {
308            llvm::Argument *HelperFunctionParameter =
309                &(*HelperFunction->arg_begin());
310            llvm::BasicBlock *BB =
311                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
312            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
313            llvm::SmallVector<llvm::Value*, 6> Params;
314            llvm::Value *Idx[2];
315
316            Idx[0] =
317                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
318
319            // getelementptr and load instruction for all elements in
320            // parameter .p
321            for (size_t i = 0; i < EF->getNumParameters(); i++) {
322              // getelementptr
323              Idx[1] =
324                  llvm::ConstantInt::get(
325                      llvm::Type::getInt32Ty(mLLVMContext), i);
326              llvm::Value *Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter,
327                                                       Idx,
328                                                       Idx + 2);
329
330              // load
331              llvm::Value *V = IB->CreateLoad(Ptr);
332              Params.push_back(V);
333            }
334
335            // Call and pass the all elements as paramter to F
336            llvm::CallInst *CI = IB->CreateCall(F,
337                                                Params.data(),
338                                                Params.data() + Params.size());
339
340            CI->setCallingConv(F->getCallingConv());
341
342            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
343              IB->CreateRetVoid();
344            else
345              IB->CreateRet(CI);
346
347            delete IB;
348          }
349        }
350
351        ExportFuncInfo.push_back(
352            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
353      }
354
355      mExportFuncMetadata->addOperand(
356          llvm::MDNode::get(mLLVMContext,
357                            ExportFuncInfo.data(),
358                            ExportFuncInfo.size()));
359
360      ExportFuncInfo.clear();
361    }
362  }
363
364  // Dump export type info
365  if (mContext->hasExportType()) {
366    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
367
368    for (RSContext::const_export_type_iterator
369            I = mContext->export_types_begin(),
370            E = mContext->export_types_end();
371         I != E;
372         I++) {
373      // First, dump type name list to export
374      const RSExportType *ET = I->getValue();
375
376      ExportTypeInfo.clear();
377      // Type name
378      ExportTypeInfo.push_back(
379          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
380
381      if (ET->getClass() == RSExportType::ExportClassRecord) {
382        const RSExportRecordType *ERT =
383            static_cast<const RSExportRecordType*>(ET);
384
385        if (mExportTypeMetadata == NULL)
386          mExportTypeMetadata =
387              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
388
389        mExportTypeMetadata->addOperand(
390            llvm::MDNode::get(mLLVMContext,
391                              ExportTypeInfo.data(),
392                              ExportTypeInfo.size()));
393
394        // Now, export struct field information to %[struct name]
395        std::string StructInfoMetadataName("%");
396        StructInfoMetadataName.append(ET->getName());
397        llvm::NamedMDNode *StructInfoMetadata =
398            M->getOrInsertNamedMetadata(StructInfoMetadataName);
399        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
400
401        assert(StructInfoMetadata->getNumOperands() == 0 &&
402               "Metadata with same name was created before");
403        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
404                FE = ERT->fields_end();
405             FI != FE;
406             FI++) {
407          const RSExportRecordType::Field *F = *FI;
408
409          // 1. field name
410          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
411                                                  F->getName().c_str()));
412
413          // 2. field type name
414          FieldInfo.push_back(
415              llvm::MDString::get(mLLVMContext,
416                                  F->getType()->getName().c_str()));
417
418          // 3. field kind
419          switch (F->getType()->getClass()) {
420            case RSExportType::ExportClassPrimitive:
421            case RSExportType::ExportClassVector: {
422              const RSExportPrimitiveType *EPT =
423                  static_cast<const RSExportPrimitiveType*>(F->getType());
424              FieldInfo.push_back(
425                  llvm::MDString::get(mLLVMContext,
426                                      llvm::itostr(EPT->getKind())));
427              break;
428            }
429
430            default: {
431              FieldInfo.push_back(
432                  llvm::MDString::get(mLLVMContext,
433                                      llvm::itostr(
434                                        RSExportPrimitiveType::DataKindUser)));
435              break;
436            }
437          }
438
439          StructInfoMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
440                                                           FieldInfo.data(),
441                                                           FieldInfo.size()));
442
443          FieldInfo.clear();
444        }
445      }   // ET->getClass() == RSExportType::ExportClassRecord
446    }
447  }
448
449  return;
450}
451
452RSBackend::~RSBackend() {
453  return;
454}
455
456}  // namespace slang
457