1/*
2 * Copyright 2010-2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
19#include <string>
20#include <vector>
21
22#include "clang/AST/ASTContext.h"
23#include "clang/Frontend/CodeGenOptions.h"
24
25#include "llvm/ADT/Twine.h"
26#include "llvm/ADT/StringExtras.h"
27
28#include "llvm/IR/Constant.h"
29#include "llvm/IR/Constants.h"
30#include "llvm/IR/DerivedTypes.h"
31#include "llvm/IR/Function.h"
32#include "llvm/IR/IRBuilder.h"
33#include "llvm/IR/Metadata.h"
34#include "llvm/IR/Module.h"
35
36#include "llvm/IR/DebugLoc.h"
37
38#include "slang_assert.h"
39#include "slang_rs.h"
40#include "slang_rs_context.h"
41#include "slang_rs_export_foreach.h"
42#include "slang_rs_export_func.h"
43#include "slang_rs_export_type.h"
44#include "slang_rs_export_var.h"
45#include "slang_rs_metadata.h"
46
47namespace slang {
48
49RSBackend::RSBackend(RSContext *Context,
50                     clang::DiagnosticsEngine *DiagEngine,
51                     const clang::CodeGenOptions &CodeGenOpts,
52                     const clang::TargetOptions &TargetOpts,
53                     PragmaList *Pragmas,
54                     llvm::raw_ostream *OS,
55                     Slang::OutputType OT,
56                     clang::SourceManager &SourceMgr,
57                     bool AllowRSPrefix,
58                     bool IsFilterscript)
59  : Backend(DiagEngine, CodeGenOpts, TargetOpts, Pragmas, OS, OT),
60    mContext(Context),
61    mSourceMgr(SourceMgr),
62    mAllowRSPrefix(AllowRSPrefix),
63    mIsFilterscript(IsFilterscript),
64    mExportVarMetadata(NULL),
65    mExportFuncMetadata(NULL),
66    mExportForEachNameMetadata(NULL),
67    mExportForEachSignatureMetadata(NULL),
68    mExportTypeMetadata(NULL),
69    mRSObjectSlotsMetadata(NULL),
70    mRefCount(mContext->getASTContext()),
71    mASTChecker(Context, Context->getTargetAPI(), IsFilterscript) {
72}
73
74// 1) Add zero initialization of local RS object types
75void RSBackend::AnnotateFunction(clang::FunctionDecl *FD) {
76  if (FD &&
77      FD->hasBody() &&
78      !SlangRS::IsLocInRSHeaderFile(FD->getLocation(), mSourceMgr)) {
79    mRefCount.Init();
80    mRefCount.Visit(FD->getBody());
81  }
82}
83
84bool RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
85  // Disallow user-defined functions with prefix "rs"
86  if (!mAllowRSPrefix) {
87    // Iterate all function declarations in the program.
88    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
89         I != E; I++) {
90      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
91      if (FD == NULL)
92        continue;
93      if (!FD->getName().startswith("rs"))  // Check prefix
94        continue;
95      if (!SlangRS::IsLocInRSHeaderFile(FD->getLocation(), mSourceMgr))
96        mContext->ReportError(FD->getLocation(),
97                              "invalid function name prefix, "
98                              "\"rs\" is reserved: '%0'")
99            << FD->getName();
100    }
101  }
102
103  // Process any non-static function declarations
104  for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end(); I != E; I++) {
105    clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
106    if (FD && FD->isGlobal()) {
107      // Check that we don't have any array parameters being misintrepeted as
108      // kernel pointers due to the C type system's array to pointer decay.
109      size_t numParams = FD->getNumParams();
110      for (size_t i = 0; i < numParams; i++) {
111        const clang::ParmVarDecl *PVD = FD->getParamDecl(i);
112        clang::QualType QT = PVD->getOriginalType();
113        if (QT->isArrayType()) {
114          mContext->ReportError(
115              PVD->getTypeSpecStartLoc(),
116              "exported function parameters may not have array type: %0")
117              << QT;
118        }
119      }
120      AnnotateFunction(FD);
121    }
122  }
123
124  return Backend::HandleTopLevelDecl(D);
125}
126
127
128void RSBackend::HandleTranslationUnitPre(clang::ASTContext &C) {
129  clang::TranslationUnitDecl *TUDecl = C.getTranslationUnitDecl();
130
131  // If we have an invalid RS/FS AST, don't check further.
132  if (!mASTChecker.Validate()) {
133    return;
134  }
135
136  if (mIsFilterscript) {
137    mContext->addPragma("rs_fp_relaxed", "");
138  }
139
140  int version = mContext->getVersion();
141  if (version == 0) {
142    // Not setting a version is an error
143    mDiagEngine.Report(
144        mSourceMgr.getLocForEndOfFile(mSourceMgr.getMainFileID()),
145        mDiagEngine.getCustomDiagID(
146            clang::DiagnosticsEngine::Error,
147            "missing pragma for version in source file"));
148  } else {
149    slangAssert(version == 1);
150  }
151
152  if (mContext->getReflectJavaPackageName().empty()) {
153    mDiagEngine.Report(
154        mSourceMgr.getLocForEndOfFile(mSourceMgr.getMainFileID()),
155        mDiagEngine.getCustomDiagID(clang::DiagnosticsEngine::Error,
156                                    "missing \"#pragma rs "
157                                    "java_package_name(com.foo.bar)\" "
158                                    "in source file"));
159    return;
160  }
161
162  // Create a static global destructor if necessary (to handle RS object
163  // runtime cleanup).
164  clang::FunctionDecl *FD = mRefCount.CreateStaticGlobalDtor();
165  if (FD) {
166    HandleTopLevelDecl(clang::DeclGroupRef(FD));
167  }
168
169  // Process any static function declarations
170  for (clang::DeclContext::decl_iterator I = TUDecl->decls_begin(),
171          E = TUDecl->decls_end(); I != E; I++) {
172    if ((I->getKind() >= clang::Decl::firstFunction) &&
173        (I->getKind() <= clang::Decl::lastFunction)) {
174      clang::FunctionDecl *FD = llvm::dyn_cast<clang::FunctionDecl>(*I);
175      if (FD && !FD->isGlobal()) {
176        AnnotateFunction(FD);
177      }
178    }
179  }
180}
181
182///////////////////////////////////////////////////////////////////////////////
183void RSBackend::dumpExportVarInfo(llvm::Module *M) {
184  int slotCount = 0;
185  if (mExportVarMetadata == NULL)
186    mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
187
188  llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
189
190  // We emit slot information (#rs_object_slots) for any reference counted
191  // RS type or pointer (which can also be bound).
192
193  for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
194          E = mContext->export_vars_end();
195       I != E;
196       I++) {
197    const RSExportVar *EV = *I;
198    const RSExportType *ET = EV->getType();
199    bool countsAsRSObject = false;
200
201    // Variable name
202    ExportVarInfo.push_back(
203        llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
204
205    // Type name
206    switch (ET->getClass()) {
207      case RSExportType::ExportClassPrimitive: {
208        const RSExportPrimitiveType *PT =
209            static_cast<const RSExportPrimitiveType*>(ET);
210        ExportVarInfo.push_back(
211            llvm::MDString::get(
212              mLLVMContext, llvm::utostr_32(PT->getType())));
213        if (PT->isRSObjectType()) {
214          countsAsRSObject = true;
215        }
216        break;
217      }
218      case RSExportType::ExportClassPointer: {
219        ExportVarInfo.push_back(
220            llvm::MDString::get(
221              mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
222                ->getPointeeType()->getName()).c_str()));
223        break;
224      }
225      case RSExportType::ExportClassMatrix: {
226        ExportVarInfo.push_back(
227            llvm::MDString::get(
228              mLLVMContext, llvm::utostr_32(
229                  /* TODO Strange value.  This pushes just a number, quite
230                   * different than the other cases.  What is this used for?
231                   * These are the metadata values that some partner drivers
232                   * want to reference (for TBAA, etc.). We may want to look
233                   * at whether these provide any reasonable value (or have
234                   * distinct enough values to actually depend on).
235                   */
236                DataTypeRSMatrix2x2 +
237                static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
238        break;
239      }
240      case RSExportType::ExportClassVector:
241      case RSExportType::ExportClassConstantArray:
242      case RSExportType::ExportClassRecord: {
243        ExportVarInfo.push_back(
244            llvm::MDString::get(mLLVMContext,
245              EV->getType()->getName().c_str()));
246        break;
247      }
248    }
249
250    mExportVarMetadata->addOperand(
251        llvm::MDNode::get(mLLVMContext, ExportVarInfo));
252    ExportVarInfo.clear();
253
254    if (mRSObjectSlotsMetadata == NULL) {
255      mRSObjectSlotsMetadata =
256          M->getOrInsertNamedMetadata(RS_OBJECT_SLOTS_MN);
257    }
258
259    if (countsAsRSObject) {
260      mRSObjectSlotsMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
261          llvm::MDString::get(mLLVMContext, llvm::utostr_32(slotCount))));
262    }
263
264    slotCount++;
265  }
266}
267
268void RSBackend::dumpExportFunctionInfo(llvm::Module *M) {
269  if (mExportFuncMetadata == NULL)
270    mExportFuncMetadata =
271        M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
272
273  llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
274
275  for (RSContext::const_export_func_iterator
276          I = mContext->export_funcs_begin(),
277          E = mContext->export_funcs_end();
278       I != E;
279       I++) {
280    const RSExportFunc *EF = *I;
281
282    // Function name
283    if (!EF->hasParam()) {
284      ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
285                                                   EF->getName().c_str()));
286    } else {
287      llvm::Function *F = M->getFunction(EF->getName());
288      llvm::Function *HelperFunction;
289      const std::string HelperFunctionName(".helper_" + EF->getName());
290
291      slangAssert(F && "Function marked as exported disappeared in Bitcode");
292
293      // Create helper function
294      {
295        llvm::StructType *HelperFunctionParameterTy = NULL;
296        std::vector<bool> isStructInput;
297        if (!F->getArgumentList().empty()) {
298          std::vector<llvm::Type*> HelperFunctionParameterTys;
299          for (llvm::Function::arg_iterator AI = F->arg_begin(),
300                   AE = F->arg_end(); AI != AE; AI++) {
301              if (AI->getType()->isPointerTy() && AI->getType()->getPointerElementType()->isStructTy()) {
302                  HelperFunctionParameterTys.push_back(AI->getType()->getPointerElementType());
303                  isStructInput.push_back(true);
304              } else {
305                  HelperFunctionParameterTys.push_back(AI->getType());
306                  isStructInput.push_back(false);
307              }
308          }
309          HelperFunctionParameterTy =
310              llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
311        }
312
313        if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
314          fprintf(stderr, "Failed to export function %s: parameter type "
315                          "mismatch during creation of helper function.\n",
316                  EF->getName().c_str());
317
318          const RSExportRecordType *Expected = EF->getParamPacketType();
319          if (Expected) {
320            fprintf(stderr, "Expected:\n");
321            Expected->getLLVMType()->dump();
322          }
323          if (HelperFunctionParameterTy) {
324            fprintf(stderr, "Got:\n");
325            HelperFunctionParameterTy->dump();
326          }
327        }
328
329        std::vector<llvm::Type*> Params;
330        if (HelperFunctionParameterTy) {
331          llvm::PointerType *HelperFunctionParameterTyP =
332              llvm::PointerType::getUnqual(HelperFunctionParameterTy);
333          Params.push_back(HelperFunctionParameterTyP);
334        }
335
336        llvm::FunctionType * HelperFunctionType =
337            llvm::FunctionType::get(F->getReturnType(),
338                                    Params,
339                                    /* IsVarArgs = */false);
340
341        HelperFunction =
342            llvm::Function::Create(HelperFunctionType,
343                                   llvm::GlobalValue::ExternalLinkage,
344                                   HelperFunctionName,
345                                   M);
346
347        HelperFunction->addFnAttr(llvm::Attribute::NoInline);
348        HelperFunction->setCallingConv(F->getCallingConv());
349
350        // Create helper function body
351        {
352          llvm::Argument *HelperFunctionParameter =
353              &(*HelperFunction->arg_begin());
354          llvm::BasicBlock *BB =
355              llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
356          llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
357          llvm::SmallVector<llvm::Value*, 6> Params;
358          llvm::Value *Idx[2];
359
360          Idx[0] =
361              llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
362
363          // getelementptr and load instruction for all elements in
364          // parameter .p
365          for (size_t i = 0; i < EF->getNumParameters(); i++) {
366            // getelementptr
367            Idx[1] = llvm::ConstantInt::get(
368              llvm::Type::getInt32Ty(mLLVMContext), i);
369
370            llvm::Value *Ptr = NULL;
371
372            Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter, Idx);
373
374            // Load is only required for non-struct ptrs
375            if (isStructInput[i]) {
376                Params.push_back(Ptr);
377            } else {
378                llvm::Value *V = IB->CreateLoad(Ptr);
379                Params.push_back(V);
380            }
381          }
382
383          // Call and pass the all elements as parameter to F
384          llvm::CallInst *CI = IB->CreateCall(F, Params);
385
386          CI->setCallingConv(F->getCallingConv());
387
388          if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
389            IB->CreateRetVoid();
390          else
391            IB->CreateRet(CI);
392
393          delete IB;
394        }
395      }
396
397      ExportFuncInfo.push_back(
398          llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
399    }
400
401    mExportFuncMetadata->addOperand(
402        llvm::MDNode::get(mLLVMContext, ExportFuncInfo));
403    ExportFuncInfo.clear();
404  }
405}
406
407void RSBackend::dumpExportForEachInfo(llvm::Module *M) {
408  if (mExportForEachNameMetadata == NULL) {
409    mExportForEachNameMetadata =
410        M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_NAME_MN);
411  }
412  if (mExportForEachSignatureMetadata == NULL) {
413    mExportForEachSignatureMetadata =
414        M->getOrInsertNamedMetadata(RS_EXPORT_FOREACH_MN);
415  }
416
417  llvm::SmallVector<llvm::Value*, 1> ExportForEachName;
418  llvm::SmallVector<llvm::Value*, 1> ExportForEachInfo;
419
420  for (RSContext::const_export_foreach_iterator
421          I = mContext->export_foreach_begin(),
422          E = mContext->export_foreach_end();
423       I != E;
424       I++) {
425    const RSExportForEach *EFE = *I;
426
427    ExportForEachName.push_back(
428        llvm::MDString::get(mLLVMContext, EFE->getName().c_str()));
429
430    mExportForEachNameMetadata->addOperand(
431        llvm::MDNode::get(mLLVMContext, ExportForEachName));
432    ExportForEachName.clear();
433
434    ExportForEachInfo.push_back(
435        llvm::MDString::get(mLLVMContext,
436                            llvm::utostr_32(EFE->getSignatureMetadata())));
437
438    mExportForEachSignatureMetadata->addOperand(
439        llvm::MDNode::get(mLLVMContext, ExportForEachInfo));
440    ExportForEachInfo.clear();
441  }
442}
443
444void RSBackend::dumpExportTypeInfo(llvm::Module *M) {
445  llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
446
447  for (RSContext::const_export_type_iterator
448          I = mContext->export_types_begin(),
449          E = mContext->export_types_end();
450       I != E;
451       I++) {
452    // First, dump type name list to export
453    const RSExportType *ET = I->getValue();
454
455    ExportTypeInfo.clear();
456    // Type name
457    ExportTypeInfo.push_back(
458        llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
459
460    if (ET->getClass() == RSExportType::ExportClassRecord) {
461      const RSExportRecordType *ERT =
462          static_cast<const RSExportRecordType*>(ET);
463
464      if (mExportTypeMetadata == NULL)
465        mExportTypeMetadata =
466            M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
467
468      mExportTypeMetadata->addOperand(
469          llvm::MDNode::get(mLLVMContext, ExportTypeInfo));
470
471      // Now, export struct field information to %[struct name]
472      std::string StructInfoMetadataName("%");
473      StructInfoMetadataName.append(ET->getName());
474      llvm::NamedMDNode *StructInfoMetadata =
475          M->getOrInsertNamedMetadata(StructInfoMetadataName);
476      llvm::SmallVector<llvm::Value*, 3> FieldInfo;
477
478      slangAssert(StructInfoMetadata->getNumOperands() == 0 &&
479                  "Metadata with same name was created before");
480      for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
481              FE = ERT->fields_end();
482           FI != FE;
483           FI++) {
484        const RSExportRecordType::Field *F = *FI;
485
486        // 1. field name
487        FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
488                                                F->getName().c_str()));
489
490        // 2. field type name
491        FieldInfo.push_back(
492            llvm::MDString::get(mLLVMContext,
493                                F->getType()->getName().c_str()));
494
495        StructInfoMetadata->addOperand(
496            llvm::MDNode::get(mLLVMContext, FieldInfo));
497        FieldInfo.clear();
498      }
499    }   // ET->getClass() == RSExportType::ExportClassRecord
500  }
501}
502
503void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
504  if (!mContext->processExport()) {
505    return;
506  }
507
508  if (mContext->hasExportVar())
509    dumpExportVarInfo(M);
510
511  if (mContext->hasExportFunc())
512    dumpExportFunctionInfo(M);
513
514  if (mContext->hasExportForEach())
515    dumpExportForEachInfo(M);
516
517  if (mContext->hasExportType())
518    dumpExportTypeInfo(M);
519}
520
521RSBackend::~RSBackend() {
522}
523
524}  // namespace slang
525