slang_rs_backend.cpp revision c808a99831115928b4648f4c8b86dc682594217a
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
19#include <string>
20#include <vector>
21
22#include "llvm/ADT/Twine.h"
23#include "llvm/ADT/StringExtras.h"
24
25#include "llvm/Constant.h"
26#include "llvm/Constants.h"
27#include "llvm/DerivedTypes.h"
28#include "llvm/Function.h"
29#include "llvm/Metadata.h"
30#include "llvm/Module.h"
31
32#include "llvm/Support/IRBuilder.h"
33
34#include "slang_rs.h"
35#include "slang_rs_context.h"
36#include "slang_rs_export_func.h"
37#include "slang_rs_export_type.h"
38#include "slang_rs_export_var.h"
39#include "slang_rs_metadata.h"
40
41namespace slang {
42
43RSBackend::RSBackend(RSContext *Context,
44                     clang::Diagnostic *Diags,
45                     const clang::CodeGenOptions &CodeGenOpts,
46                     const clang::TargetOptions &TargetOpts,
47                     const PragmaList &Pragmas,
48                     llvm::raw_ostream *OS,
49                     Slang::OutputType OT,
50                     clang::SourceManager &SourceMgr,
51                     bool AllowRSPrefix)
52    : Backend(Diags,
53              CodeGenOpts,
54              TargetOpts,
55              Pragmas,
56              OS,
57              OT),
58      mContext(Context),
59      mSourceMgr(SourceMgr),
60      mAllowRSPrefix(AllowRSPrefix),
61      mExportVarMetadata(NULL),
62      mExportFuncMetadata(NULL),
63      mExportTypeMetadata(NULL) {
64  return;
65}
66
67// 1) Add zero initialization of local RS object types
68void RSBackend::AnnotateFunction(clang::FunctionDecl *FD) {
69  if (FD &&
70      FD->hasBody() &&
71      !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr)) {
72    mRefCount.Init(mContext->getASTContext());
73    mRefCount.Visit(FD->getBody());
74  }
75  return;
76}
77
78void RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
79  // Disallow user-defined functions with prefix "rs"
80  if (!mAllowRSPrefix) {
81    // Iterate all function declarations in the program.
82    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
83         I != E; I++) {
84      clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
85      if (FD == NULL)
86        continue;
87      if (!FD->getName().startswith("rs"))  // Check prefix
88        continue;
89      if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
90        mDiags.Report(clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
91                      mDiags.getCustomDiagID(clang::Diagnostic::Error,
92                                             "invalid function name prefix, "
93                                             "\"rs\" is reserved: '%0'"))
94            << FD->getName();
95    }
96  }
97
98  // Process any non-static function declarations
99  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
100    AnnotateFunction(dyn_cast<clang::FunctionDecl>(*I));
101  }
102
103  Backend::HandleTopLevelDecl(D);
104  return;
105}
106
107void RSBackend::HandleTranslationUnitPre(clang::ASTContext& C) {
108  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
109
110  // Process any static function declarations
111  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
112          E = TUDecl->decls_end(); I != E; I++) {
113    if ((I->getKind() >= clang::Decl::firstFunction) &&
114        (I->getKind() <= clang::Decl::lastFunction)) {
115      AnnotateFunction(static_cast<clang::FunctionDecl*>(*I));
116    }
117  }
118
119  return;
120}
121
122///////////////////////////////////////////////////////////////////////////////
123void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
124  if (!mContext->processExport()) {
125    mDiags.Report(mDiags.getCustomDiagID(clang::Diagnostic::Error,
126                                         "elements cannot be exported"));
127  }
128
129  // Dump export variable info
130  if (mContext->hasExportVar()) {
131    if (mExportVarMetadata == NULL)
132      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
133
134    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
135
136    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
137            E = mContext->export_vars_end();
138         I != E;
139         I++) {
140      const RSExportVar *EV = *I;
141      const RSExportType *ET = EV->getType();
142
143      // Variable name
144      ExportVarInfo.push_back(
145          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
146
147      // Type name
148      switch (ET->getClass()) {
149        case RSExportType::ExportClassPrimitive: {
150          ExportVarInfo.push_back(
151              llvm::MDString::get(
152                mLLVMContext, llvm::utostr_32(
153                  static_cast<const RSExportPrimitiveType*>(ET)->getType())));
154          break;
155        }
156        case RSExportType::ExportClassPointer: {
157          ExportVarInfo.push_back(
158              llvm::MDString::get(
159                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
160                  ->getPointeeType()->getName()).c_str()));
161          break;
162        }
163        case RSExportType::ExportClassMatrix: {
164          ExportVarInfo.push_back(
165              llvm::MDString::get(
166                mLLVMContext, llvm::utostr_32(
167                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
168                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
169          break;
170        }
171        case RSExportType::ExportClassVector:
172        case RSExportType::ExportClassConstantArray:
173        case RSExportType::ExportClassRecord: {
174          ExportVarInfo.push_back(
175              llvm::MDString::get(mLLVMContext,
176                EV->getType()->getName().c_str()));
177          break;
178        }
179      }
180
181      mExportVarMetadata->addOperand(
182          llvm::MDNode::get(mLLVMContext,
183                            ExportVarInfo.data(),
184                            ExportVarInfo.size()) );
185
186      ExportVarInfo.clear();
187    }
188  }
189
190  // Dump export function info
191  if (mContext->hasExportFunc()) {
192    if (mExportFuncMetadata == NULL)
193      mExportFuncMetadata =
194          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
195
196    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
197
198    for (RSContext::const_export_func_iterator
199            I = mContext->export_funcs_begin(),
200            E = mContext->export_funcs_end();
201         I != E;
202         I++) {
203      const RSExportFunc *EF = *I;
204
205      // Function name
206      if (!EF->hasParam()) {
207        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
208                                                     EF->getName().c_str()));
209      } else {
210        llvm::Function *F = M->getFunction(EF->getName());
211        llvm::Function *HelperFunction;
212        const std::string HelperFunctionName(".helper_" + EF->getName());
213
214        assert(F && "Function marked as exported disappeared in Bitcode");
215
216        // Create helper function
217        {
218          llvm::StructType *HelperFunctionParameterTy = NULL;
219
220          if (!F->getArgumentList().empty()) {
221            std::vector<const llvm::Type*> HelperFunctionParameterTys;
222            for (llvm::Function::arg_iterator AI = F->arg_begin(),
223                 AE = F->arg_end(); AI != AE; AI++)
224              HelperFunctionParameterTys.push_back(AI->getType());
225
226            HelperFunctionParameterTy =
227                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
228          }
229
230          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
231            fprintf(stderr, "Failed to export function %s: parameter type "
232                            "mismatch during creation of helper function.\n",
233                    EF->getName().c_str());
234
235            const RSExportRecordType *Expected = EF->getParamPacketType();
236            if (Expected) {
237              fprintf(stderr, "Expected:\n");
238              Expected->getLLVMType()->dump();
239            }
240            if (HelperFunctionParameterTy) {
241              fprintf(stderr, "Got:\n");
242              HelperFunctionParameterTy->dump();
243            }
244          }
245
246          std::vector<const llvm::Type*> Params;
247          if (HelperFunctionParameterTy) {
248            llvm::PointerType *HelperFunctionParameterTyP =
249                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
250            Params.push_back(HelperFunctionParameterTyP);
251          }
252
253          llvm::FunctionType * HelperFunctionType =
254              llvm::FunctionType::get(F->getReturnType(),
255                                      Params,
256                                      /* IsVarArgs = */false);
257
258          HelperFunction =
259              llvm::Function::Create(HelperFunctionType,
260                                     llvm::GlobalValue::ExternalLinkage,
261                                     HelperFunctionName,
262                                     M);
263
264          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
265          HelperFunction->setCallingConv(F->getCallingConv());
266
267          // Create helper function body
268          {
269            llvm::Argument *HelperFunctionParameter =
270                &(*HelperFunction->arg_begin());
271            llvm::BasicBlock *BB =
272                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
273            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
274            llvm::SmallVector<llvm::Value*, 6> Params;
275            llvm::Value *Idx[2];
276
277            Idx[0] =
278                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
279
280            // getelementptr and load instruction for all elements in
281            // parameter .p
282            for (size_t i = 0; i < EF->getNumParameters(); i++) {
283              // getelementptr
284              Idx[1] =
285                  llvm::ConstantInt::get(
286                      llvm::Type::getInt32Ty(mLLVMContext), i);
287              llvm::Value *Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter,
288                                                       Idx,
289                                                       Idx + 2);
290
291              // load
292              llvm::Value *V = IB->CreateLoad(Ptr);
293              Params.push_back(V);
294            }
295
296            // Call and pass the all elements as paramter to F
297            llvm::CallInst *CI = IB->CreateCall(F,
298                                                Params.data(),
299                                                Params.data() + Params.size());
300
301            CI->setCallingConv(F->getCallingConv());
302
303            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
304              IB->CreateRetVoid();
305            else
306              IB->CreateRet(CI);
307
308            delete IB;
309          }
310        }
311
312        ExportFuncInfo.push_back(
313            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
314      }
315
316      mExportFuncMetadata->addOperand(
317          llvm::MDNode::get(mLLVMContext,
318                            ExportFuncInfo.data(),
319                            ExportFuncInfo.size()));
320
321      ExportFuncInfo.clear();
322    }
323  }
324
325  // Dump export type info
326  if (mContext->hasExportType()) {
327    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
328
329    for (RSContext::const_export_type_iterator
330            I = mContext->export_types_begin(),
331            E = mContext->export_types_end();
332         I != E;
333         I++) {
334      // First, dump type name list to export
335      const RSExportType *ET = I->getValue();
336
337      ExportTypeInfo.clear();
338      // Type name
339      ExportTypeInfo.push_back(
340          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
341
342      if (ET->getClass() == RSExportType::ExportClassRecord) {
343        const RSExportRecordType *ERT =
344            static_cast<const RSExportRecordType*>(ET);
345
346        if (mExportTypeMetadata == NULL)
347          mExportTypeMetadata =
348              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
349
350        mExportTypeMetadata->addOperand(
351            llvm::MDNode::get(mLLVMContext,
352                              ExportTypeInfo.data(),
353                              ExportTypeInfo.size()));
354
355        // Now, export struct field information to %[struct name]
356        std::string StructInfoMetadataName("%");
357        StructInfoMetadataName.append(ET->getName());
358        llvm::NamedMDNode *StructInfoMetadata =
359            M->getOrInsertNamedMetadata(StructInfoMetadataName);
360        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
361
362        assert(StructInfoMetadata->getNumOperands() == 0 &&
363               "Metadata with same name was created before");
364        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
365                FE = ERT->fields_end();
366             FI != FE;
367             FI++) {
368          const RSExportRecordType::Field *F = *FI;
369
370          // 1. field name
371          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
372                                                  F->getName().c_str()));
373
374          // 2. field type name
375          FieldInfo.push_back(
376              llvm::MDString::get(mLLVMContext,
377                                  F->getType()->getName().c_str()));
378
379          // 3. field kind
380          switch (F->getType()->getClass()) {
381            case RSExportType::ExportClassPrimitive:
382            case RSExportType::ExportClassVector: {
383              const RSExportPrimitiveType *EPT =
384                  static_cast<const RSExportPrimitiveType*>(F->getType());
385              FieldInfo.push_back(
386                  llvm::MDString::get(mLLVMContext,
387                                      llvm::itostr(EPT->getKind())));
388              break;
389            }
390
391            default: {
392              FieldInfo.push_back(
393                  llvm::MDString::get(mLLVMContext,
394                                      llvm::itostr(
395                                        RSExportPrimitiveType::DataKindUser)));
396              break;
397            }
398          }
399
400          StructInfoMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
401                                                           FieldInfo.data(),
402                                                           FieldInfo.size()));
403
404          FieldInfo.clear();
405        }
406      }   // ET->getClass() == RSExportType::ExportClassRecord
407    }
408  }
409
410  return;
411}
412
413RSBackend::~RSBackend() {
414  return;
415}
416
417}  // namespace slang
418