slang_rs_backend.cpp revision 23c4358f12bd9d0ba7166eceebd683db95a41b3f
1/*
2 * Copyright 2010-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
19#include <string>
20#include <vector>
21
22#include "clang/AST/ASTContext.h"
23#include "clang/Frontend/CodeGenOptions.h"
24
25#include "llvm/ADT/Twine.h"
26#include "llvm/ADT/StringExtras.h"
27
28#include "llvm/IR/Constant.h"
29#include "llvm/IR/Constants.h"
30#include "llvm/IR/DerivedTypes.h"
31#include "llvm/IR/Function.h"
32#include "llvm/IR/IRBuilder.h"
33#include "llvm/IR/Metadata.h"
34#include "llvm/IR/Module.h"
35
36#include "llvm/Support/DebugLoc.h"
37
38#include "slang_assert.h"
39#include "slang_rs.h"
40#include "slang_rs_context.h"
41#include "slang_rs_export_foreach.h"
42#include "slang_rs_export_func.h"
43#include "slang_rs_export_type.h"
44#include "slang_rs_export_var.h"
45#include "slang_rs_metadata.h"
46
47namespace slang {
48
49RSBackend::RSBackend(RSContext *Context,
50                     clang::DiagnosticsEngine *DiagEngine,
51                     const clang::CodeGenOptions &CodeGenOpts,
52                     const clang::TargetOptions &TargetOpts,
53                     PragmaList *Pragmas,
54                     llvm::raw_ostream *OS,
55                     Slang::OutputType OT,
56                     clang::SourceManager &SourceMgr,
57                     bool AllowRSPrefix,
58                     bool IsFilterscript)
59  : Backend(DiagEngine, CodeGenOpts, TargetOpts, Pragmas, OS, OT),
60    mContext(Context),
61    mSourceMgr(SourceMgr),
62    mAllowRSPrefix(AllowRSPrefix),
63    mIsFilterscript(IsFilterscript),
64    mExportVarMetadata(NULL),
65    mExportFuncMetadata(NULL),
66    mExportForEachNameMetadata(NULL),
67    mExportForEachSignatureMetadata(NULL),
68    mExportTypeMetadata(NULL),
69    mRSObjectSlotsMetadata(NULL),
70    mRefCount(mContext->getASTContext()),
71    mASTChecker(mContext->getASTContext(), mContext->getTargetAPI(),
72                IsFilterscript) {
73}
74
75// 1) Add zero initialization of local RS object types
76void RSBackend::AnnotateFunction(clang::FunctionDecl *FD) {
77  if (FD &&
78      FD->hasBody() &&
79      !SlangRS::IsLocInRSHeaderFile(FD->getLocation(), mSourceMgr)) {
80    mRefCount.Init();
81    mRefCount.Visit(FD->getBody());
82  }
83  return;
84}
85
86bool RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
87  // Disallow user-defined functions with prefix "rs"
88  if (!mAllowRSPrefix) {
89    // Iterate all function declarations in the program.
90    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
91         I != E; I++) {
92      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
93      if (FD == NULL)
94        continue;
95      if (!FD->getName().startswith("rs"))  // Check prefix
96        continue;
97      if (!SlangRS::IsLocInRSHeaderFile(FD->getLocation(), mSourceMgr))
98        mDiagEngine.Report(
99          clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
100          mDiagEngine.getCustomDiagID(clang::DiagnosticsEngine::Error,
101                                      "invalid function name prefix, "
102                                      "\"rs\" is reserved: '%0'"))
103          << FD->getName();
104    }
105  }
106
107  // Process any non-static function declarations
108  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
109    clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
110    if (FD && FD->isGlobal()) {
111      // Check that we don't have any array parameters being misintrepeted as
112      // kernel pointers due to the C type system's array to pointer decay.
113      size_t numParams = FD->getNumParams();
114      for (size_t i = 0; i < numParams; i++) {
115        const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
116        clang::QualType QT = PVD->getOriginalType();
117        if (QT->isArrayType()) {
118          mDiagEngine.Report(
119            clang::FullSourceLoc(PVD->getTypeSpecStartLoc(), mSourceMgr),
120            mDiagEngine.getCustomDiagID(clang::DiagnosticsEngine::Error,
121                                        "exported function parameters may "
122                                        "not have array type: %0")) << QT;
123        }
124      }
125      AnnotateFunction(FD);
126    }
127  }
128
129  return Backend::HandleTopLevelDecl(D);
130}
131
132
133void RSBackend::HandleTranslationUnitPre(clang::ASTContext &C) {
134  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
135
136  // If we have an invalid RS/FS AST, don't check further.
137  if (!mASTChecker.Validate()) {
138    return;
139  }
140
141  if (mIsFilterscript) {
142    mContext->addPragma("rs_fp_relaxed", "");
143  }
144
145  int version = mContext->getVersion();
146  if (version == 0) {
147    // Not setting a version is an error
148    mDiagEngine.Report(
149        mSourceMgr.getLocForEndOfFile(mSourceMgr.getMainFileID()),
150        mDiagEngine.getCustomDiagID(
151            clang::DiagnosticsEngine::Error,
152            "missing pragma for version in source file"));
153  } else {
154    slangAssert(version == 1);
155  }
156
157  if (mContext->getReflectJavaPackageName().empty()) {
158    mDiagEngine.Report(
159        mSourceMgr.getLocForEndOfFile(mSourceMgr.getMainFileID()),
160        mDiagEngine.getCustomDiagID(clang::DiagnosticsEngine::Error,
161                                    "missing \"#pragma rs "
162                                    "java_package_name(com.foo.bar)\" "
163                                    "in source file"));
164    return;
165  }
166
167  // Create a static global destructor if necessary (to handle RS object
168  // runtime cleanup).
169  clang::FunctionDecl *FD = mRefCount.CreateStaticGlobalDtor();
170  if (FD) {
171    HandleTopLevelDecl(clang::DeclGroupRef(FD));
172  }
173
174  // Process any static function declarations
175  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
176          E = TUDecl->decls_end(); I != E; I++) {
177    if ((I->getKind() >= clang::Decl::firstFunction) &&
178        (I->getKind() <= clang::Decl::lastFunction)) {
179      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
180      if (FD && !FD->isGlobal()) {
181        AnnotateFunction(FD);
182      }
183    }
184  }
185
186  return;
187}
188
189///////////////////////////////////////////////////////////////////////////////
190void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
191  if (!mContext->processExport()) {
192    return;
193  }
194
195  // Write optimization level
196  llvm::SmallVector<llvm::Value*, 1> OptimizationOption;
197  OptimizationOption.push_back(llvm::ConstantInt::get(
198    mLLVMContext, llvm::APInt(32, mCodeGenOpts.OptimizationLevel)));
199
200  // Dump export variable info
201  if (mContext->hasExportVar()) {
202    int slotCount = 0;
203    if (mExportVarMetadata == NULL)
204      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
205
206    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
207
208    // We emit slot information (#rs_object_slots) for any reference counted
209    // RS type or pointer (which can also be bound).
210
211    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
212            E = mContext->export_vars_end();
213         I != E;
214         I++) {
215      const RSExportVar *EV = *I;
216      const RSExportType *ET = EV->getType();
217      bool countsAsRSObject = false;
218
219      // Variable name
220      ExportVarInfo.push_back(
221          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
222
223      // Type name
224      switch (ET->getClass()) {
225        case RSExportType::ExportClassPrimitive: {
226          const RSExportPrimitiveType *PT =
227              static_cast<const RSExportPrimitiveType*>(ET);
228          ExportVarInfo.push_back(
229              llvm::MDString::get(
230                mLLVMContext, llvm::utostr_32(PT->getType())));
231          if (PT->isRSObjectType()) {
232            countsAsRSObject = true;
233          }
234          break;
235        }
236        case RSExportType::ExportClassPointer: {
237          ExportVarInfo.push_back(
238              llvm::MDString::get(
239                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
240                  ->getPointeeType()->getName()).c_str()));
241          break;
242        }
243        case RSExportType::ExportClassMatrix: {
244          ExportVarInfo.push_back(
245              llvm::MDString::get(
246                mLLVMContext, llvm::utostr_32(
247                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
248                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
249          break;
250        }
251        case RSExportType::ExportClassVector:
252        case RSExportType::ExportClassConstantArray:
253        case RSExportType::ExportClassRecord: {
254          ExportVarInfo.push_back(
255              llvm::MDString::get(mLLVMContext,
256                EV->getType()->getName().c_str()));
257          break;
258        }
259      }
260
261      mExportVarMetadata->addOperand(
262          llvm::MDNode::get(mLLVMContext, ExportVarInfo));
263      ExportVarInfo.clear();
264
265      if (mRSObjectSlotsMetadata == NULL) {
266        mRSObjectSlotsMetadata =
267            M->getOrInsertNamedMetadata(RS_OBJECT_SLOTS_MN);
268      }
269
270      if (countsAsRSObject) {
271        mRSObjectSlotsMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
272            llvm::MDString::get(mLLVMContext, llvm::utostr_32(slotCount))));
273      }
274
275      slotCount++;
276    }
277  }
278
279  // Dump export function info
280  if (mContext->hasExportFunc()) {
281    if (mExportFuncMetadata == NULL)
282      mExportFuncMetadata =
283          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
284
285    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
286
287    for (RSContext::const_export_func_iterator
288            I = mContext->export_funcs_begin(),
289            E = mContext->export_funcs_end();
290         I != E;
291         I++) {
292      const RSExportFunc *EF = *I;
293
294      // Function name
295      if (!EF->hasParam()) {
296        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
297                                                     EF->getName().c_str()));
298      } else {
299        llvm::Function *F = M->getFunction(EF->getName());
300        llvm::Function *HelperFunction;
301        const std::string HelperFunctionName(".helper_" + EF->getName());
302
303        slangAssert(F && "Function marked as exported disappeared in Bitcode");
304
305        // Create helper function
306        {
307          llvm::StructType *HelperFunctionParameterTy = NULL;
308
309          if (!F->getArgumentList().empty()) {
310            std::vector<llvm::Type*> HelperFunctionParameterTys;
311            for (llvm::Function::arg_iterator AI = F->arg_begin(),
312                 AE = F->arg_end(); AI != AE; AI++)
313              HelperFunctionParameterTys.push_back(AI->getType());
314
315            HelperFunctionParameterTy =
316                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
317          }
318
319          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
320            fprintf(stderr, "Failed to export function %s: parameter type "
321                            "mismatch during creation of helper function.\n",
322                    EF->getName().c_str());
323
324            const RSExportRecordType *Expected = EF->getParamPacketType();
325            if (Expected) {
326              fprintf(stderr, "Expected:\n");
327              Expected->getLLVMType()->dump();
328            }
329            if (HelperFunctionParameterTy) {
330              fprintf(stderr, "Got:\n");
331              HelperFunctionParameterTy->dump();
332            }
333          }
334
335          std::vector<llvm::Type*> Params;
336          if (HelperFunctionParameterTy) {
337            llvm::PointerType *HelperFunctionParameterTyP =
338                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
339            Params.push_back(HelperFunctionParameterTyP);
340          }
341
342          llvm::FunctionType * HelperFunctionType =
343              llvm::FunctionType::get(F->getReturnType(),
344                                      Params,
345                                      /* IsVarArgs = */false);
346
347          HelperFunction =
348              llvm::Function::Create(HelperFunctionType,
349                                     llvm::GlobalValue::ExternalLinkage,
350                                     HelperFunctionName,
351                                     M);
352
353          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
354          HelperFunction->setCallingConv(F->getCallingConv());
355
356          // Create helper function body
357          {
358            llvm::Argument *HelperFunctionParameter =
359                &(*HelperFunction->arg_begin());
360            llvm::BasicBlock *BB =
361                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
362            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
363            llvm::SmallVector<llvm::Value*, 6> Params;
364            llvm::Value *Idx[2];
365
366            Idx[0] =
367                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
368
369            // getelementptr and load instruction for all elements in
370            // parameter .p
371            for (size_t i = 0; i < EF->getNumParameters(); i++) {
372              // getelementptr
373              Idx[1] = llvm::ConstantInt::get(
374                llvm::Type::getInt32Ty(mLLVMContext), i);
375
376              llvm::Value *Ptr =
377                IB->CreateInBoundsGEP(HelperFunctionParameter, Idx);
378
379              // load
380              llvm::Value *V = IB->CreateLoad(Ptr);
381              Params.push_back(V);
382            }
383
384            // Call and pass the all elements as parameter to F
385            llvm::CallInst *CI = IB->CreateCall(F, Params);
386
387            CI->setCallingConv(F->getCallingConv());
388
389            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
390              IB->CreateRetVoid();
391            else
392              IB->CreateRet(CI);
393
394            delete IB;
395          }
396        }
397
398        ExportFuncInfo.push_back(
399            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
400      }
401
402      mExportFuncMetadata->addOperand(
403          llvm::MDNode::get(mLLVMContext, ExportFuncInfo));
404      ExportFuncInfo.clear();
405    }
406  }
407
408  // Dump export function info
409  if (mContext->hasExportForEach()) {
410    if (mExportForEachNameMetadata == NULL) {
411      mExportForEachNameMetadata =
412          M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_NAME_MN);
413    }
414    if (mExportForEachSignatureMetadata == NULL) {
415      mExportForEachSignatureMetadata =
416          M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_MN);
417    }
418
419    llvm::SmallVector<llvm::Value*, 1> ExportForEachName;
420    llvm::SmallVector<llvm::Value*, 1> ExportForEachInfo;
421
422    for (RSContext::const_export_foreach_iterator
423            I = mContext->export_foreach_begin(),
424            E = mContext->export_foreach_end();
425         I != E;
426         I++) {
427      const RSExportForEach *EFE = *I;
428
429      ExportForEachName.push_back(
430          llvm::MDString::get(mLLVMContext, EFE->getName().c_str()));
431
432      mExportForEachNameMetadata->addOperand(
433          llvm::MDNode::get(mLLVMContext, ExportForEachName));
434      ExportForEachName.clear();
435
436      ExportForEachInfo.push_back(
437          llvm::MDString::get(mLLVMContext,
438                              llvm::utostr_32(EFE->getSignatureMetadata())));
439
440      mExportForEachSignatureMetadata->addOperand(
441          llvm::MDNode::get(mLLVMContext, ExportForEachInfo));
442      ExportForEachInfo.clear();
443    }
444  }
445
446  // Dump export type info
447  if (mContext->hasExportType()) {
448    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
449
450    for (RSContext::const_export_type_iterator
451            I = mContext->export_types_begin(),
452            E = mContext->export_types_end();
453         I != E;
454         I++) {
455      // First, dump type name list to export
456      const RSExportType *ET = I->getValue();
457
458      ExportTypeInfo.clear();
459      // Type name
460      ExportTypeInfo.push_back(
461          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
462
463      if (ET->getClass() == RSExportType::ExportClassRecord) {
464        const RSExportRecordType *ERT =
465            static_cast<const RSExportRecordType*>(ET);
466
467        if (mExportTypeMetadata == NULL)
468          mExportTypeMetadata =
469              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
470
471        mExportTypeMetadata->addOperand(
472            llvm::MDNode::get(mLLVMContext, ExportTypeInfo));
473
474        // Now, export struct field information to %[struct name]
475        std::string StructInfoMetadataName("%");
476        StructInfoMetadataName.append(ET->getName());
477        llvm::NamedMDNode *StructInfoMetadata =
478            M->getOrInsertNamedMetadata(StructInfoMetadataName);
479        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
480
481        slangAssert(StructInfoMetadata->getNumOperands() == 0 &&
482                    "Metadata with same name was created before");
483        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
484                FE = ERT->fields_end();
485             FI != FE;
486             FI++) {
487          const RSExportRecordType::Field *F = *FI;
488
489          // 1. field name
490          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
491                                                  F->getName().c_str()));
492
493          // 2. field type name
494          FieldInfo.push_back(
495              llvm::MDString::get(mLLVMContext,
496                                  F->getType()->getName().c_str()));
497
498          StructInfoMetadata->addOperand(
499              llvm::MDNode::get(mLLVMContext, FieldInfo));
500          FieldInfo.clear();
501        }
502      }   // ET->getClass() == RSExportType::ExportClassRecord
503    }
504  }
505
506  return;
507}
508
509RSBackend::~RSBackend() {
510  return;
511}
512
513}  // namespace slang
514