slang_rs_backend.cpp revision a65ec168e41e3ee9c6e8ac04cde694bbbfc2590a
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_backend.h"
18
19#include <vector>
20#include <string>
21
22#include "llvm/Metadata.h"
23#include "llvm/Constant.h"
24#include "llvm/Constants.h"
25#include "llvm/Module.h"
26#include "llvm/Function.h"
27#include "llvm/DerivedTypes.h"
28
29#include "llvm/System/Path.h"
30
31#include "llvm/Support/IRBuilder.h"
32
33#include "llvm/ADT/Twine.h"
34#include "llvm/ADT/StringExtras.h"
35
36#include "clang/AST/DeclGroup.h"
37
38#include "slang_rs.h"
39#include "slang_rs_context.h"
40#include "slang_rs_metadata.h"
41#include "slang_rs_export_var.h"
42#include "slang_rs_export_func.h"
43#include "slang_rs_export_type.h"
44
45using namespace slang;
46
47RSBackend::RSBackend(RSContext *Context,
48                     clang::Diagnostic &Diags,
49                     const clang::CodeGenOptions &CodeGenOpts,
50                     const clang::TargetOptions &TargetOpts,
51                     const PragmaList &Pragmas,
52                     llvm::raw_ostream *OS,
53                     Slang::OutputType OT,
54                     clang::SourceManager &SourceMgr,
55                     bool AllowRSPrefix)
56    : Backend(Diags,
57              CodeGenOpts,
58              TargetOpts,
59              Pragmas,
60              OS,
61              OT),
62      mContext(Context),
63      mSourceMgr(SourceMgr),
64      mAllowRSPrefix(AllowRSPrefix),
65      mExportVarMetadata(NULL),
66      mExportFuncMetadata(NULL),
67      mExportTypeMetadata(NULL) {
68  return;
69}
70
71void RSBackend::HandleTopLevelDecl(clang::DeclGroupRef D) {
72  // Disallow user-defined functions with prefix "rs"
73  if (!mAllowRSPrefix) {
74    // Iterate all function declarations in the program.
75    for (clang::DeclGroupRef::iterator I = D.begin(), E = D.end();
76         I != E; I++) {
77      clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
78      if (FD == NULL)
79        continue;
80      if (FD->getName().startswith("rs")) {  // Check prefix
81        clang::FullSourceLoc FSL(FD->getLocStart(), mSourceMgr);
82        clang::PresumedLoc PLoc = mSourceMgr.getPresumedLoc(FSL);
83        llvm::sys::Path HeaderFilename(PLoc.getFilename());
84
85        // Skip if that function declared in the RS default header.
86        if (SlangRS::IsRSHeaderFile(HeaderFilename.getLast().data()))
87          continue;
88        mDiags.Report(FSL, mDiags.getCustomDiagID(clang::Diagnostic::Error,
89                      "invalid function name prefix, \"rs\" is reserved: '%0'"))
90            << FD->getName();
91      }
92    }
93  }
94
95  Backend::HandleTopLevelDecl(D);
96  return;
97}
98
99void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
100  mContext->processExport();
101
102  // Dump export variable info
103  if (mContext->hasExportVar()) {
104    if (mExportVarMetadata == NULL)
105      mExportVarMetadata = M->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
106
107    llvm::SmallVector<llvm::Value*, 2> ExportVarInfo;
108
109    for (RSContext::const_export_var_iterator I = mContext->export_vars_begin(),
110            E = mContext->export_vars_end();
111         I != E;
112         I++) {
113      const RSExportVar *EV = *I;
114      const RSExportType *ET = EV->getType();
115
116      // Variable name
117      ExportVarInfo.push_back(
118          llvm::MDString::get(mLLVMContext, EV->getName().c_str()));
119
120      // Type name
121      switch (ET->getClass()) {
122        case RSExportType::ExportClassPrimitive: {
123          ExportVarInfo.push_back(
124              llvm::MDString::get(
125                mLLVMContext, llvm::utostr_32(
126                  static_cast<const RSExportPrimitiveType*>(ET)->getType())));
127          break;
128        }
129        case RSExportType::ExportClassPointer: {
130          ExportVarInfo.push_back(
131              llvm::MDString::get(
132                mLLVMContext, ("*" + static_cast<const RSExportPointerType*>(ET)
133                  ->getPointeeType()->getName()).c_str()));
134          break;
135        }
136        case RSExportType::ExportClassMatrix: {
137          ExportVarInfo.push_back(
138              llvm::MDString::get(
139                mLLVMContext, llvm::utostr_32(
140                  RSExportPrimitiveType::DataTypeRSMatrix2x2 +
141                  static_cast<const RSExportMatrixType*>(ET)->getDim() - 2)));
142          break;
143        }
144        case RSExportType::ExportClassVector:
145        case RSExportType::ExportClassConstantArray:
146        case RSExportType::ExportClassRecord: {
147          ExportVarInfo.push_back(
148              llvm::MDString::get(mLLVMContext,
149                EV->getType()->getName().c_str()));
150          break;
151        }
152      }
153
154      mExportVarMetadata->addOperand(
155          llvm::MDNode::get(mLLVMContext,
156                            ExportVarInfo.data(),
157                            ExportVarInfo.size()) );
158
159      ExportVarInfo.clear();
160    }
161  }
162
163  // Dump export function info
164  if (mContext->hasExportFunc()) {
165    if (mExportFuncMetadata == NULL)
166      mExportFuncMetadata =
167          M->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
168
169    llvm::SmallVector<llvm::Value*, 1> ExportFuncInfo;
170
171    for (RSContext::const_export_func_iterator
172            I = mContext->export_funcs_begin(),
173            E = mContext->export_funcs_end();
174         I != E;
175         I++) {
176      const RSExportFunc *EF = *I;
177
178      // Function name
179      if (!EF->hasParam()) {
180        ExportFuncInfo.push_back(llvm::MDString::get(mLLVMContext,
181                                                     EF->getName().c_str()));
182      } else {
183        llvm::Function *F = M->getFunction(EF->getName());
184        llvm::Function *HelperFunction;
185        const std::string HelperFunctionName(".helper_" + EF->getName());
186
187        assert(F && "Function marked as exported disappeared in Bitcode");
188
189        // Create helper function
190        {
191          llvm::StructType *HelperFunctionParameterTy = NULL;
192
193          if (!F->getArgumentList().empty()) {
194            std::vector<const llvm::Type*> HelperFunctionParameterTys;
195            for (llvm::Function::arg_iterator AI = F->arg_begin(),
196                 AE = F->arg_end(); AI != AE; AI++)
197              HelperFunctionParameterTys.push_back(AI->getType());
198
199            HelperFunctionParameterTy =
200                llvm::StructType::get(mLLVMContext, HelperFunctionParameterTys);
201          }
202
203          if (!EF->checkParameterPacketType(HelperFunctionParameterTy)) {
204            fprintf(stderr, "Failed to export function %s: parameter type "
205                            "mismatch during creation of helper function.\n",
206                    EF->getName().c_str());
207
208            const RSExportRecordType *Expected = EF->getParamPacketType();
209            if (Expected) {
210              fprintf(stderr, "Expected:\n");
211              Expected->getLLVMType()->dump();
212            }
213            if (HelperFunctionParameterTy) {
214              fprintf(stderr, "Got:\n");
215              HelperFunctionParameterTy->dump();
216            }
217          }
218
219          std::vector<const llvm::Type*> Params;
220          if (HelperFunctionParameterTy) {
221            llvm::PointerType *HelperFunctionParameterTyP =
222                llvm::PointerType::getUnqual(HelperFunctionParameterTy);
223            Params.push_back(HelperFunctionParameterTyP);
224          }
225
226          llvm::FunctionType * HelperFunctionType =
227              llvm::FunctionType::get(F->getReturnType(),
228                                      Params,
229                                      /* IsVarArgs = */false);
230
231          HelperFunction =
232              llvm::Function::Create(HelperFunctionType,
233                                     llvm::GlobalValue::ExternalLinkage,
234                                     HelperFunctionName,
235                                     M);
236
237          HelperFunction->addFnAttr(llvm::Attribute::NoInline);
238          HelperFunction->setCallingConv(F->getCallingConv());
239
240          // Create helper function body
241          {
242            llvm::Argument *HelperFunctionParameter =
243                &(*HelperFunction->arg_begin());
244            llvm::BasicBlock *BB =
245                llvm::BasicBlock::Create(mLLVMContext, "entry", HelperFunction);
246            llvm::IRBuilder<> *IB = new llvm::IRBuilder<>(BB);
247            llvm::SmallVector<llvm::Value*, 6> Params;
248            llvm::Value *Idx[2];
249
250            Idx[0] =
251                llvm::ConstantInt::get(llvm::Type::getInt32Ty(mLLVMContext), 0);
252
253            // getelementptr and load instruction for all elements in
254            // parameter .p
255            for (size_t i = 0; i < EF->getNumParameters(); i++) {
256              // getelementptr
257              Idx[1] =
258                  llvm::ConstantInt::get(
259                      llvm::Type::getInt32Ty(mLLVMContext), i);
260              llvm::Value *Ptr = IB->CreateInBoundsGEP(HelperFunctionParameter,
261                                                       Idx,
262                                                       Idx + 2);
263
264              // load
265              llvm::Value *V = IB->CreateLoad(Ptr);
266              Params.push_back(V);
267            }
268
269            // Call and pass the all elements as paramter to F
270            llvm::CallInst *CI = IB->CreateCall(F,
271                                                Params.data(),
272                                                Params.data() + Params.size());
273
274            CI->setCallingConv(F->getCallingConv());
275
276            if (F->getReturnType() == llvm::Type::getVoidTy(mLLVMContext))
277              IB->CreateRetVoid();
278            else
279              IB->CreateRet(CI);
280
281            delete IB;
282          }
283        }
284
285        ExportFuncInfo.push_back(
286            llvm::MDString::get(mLLVMContext, HelperFunctionName.c_str()));
287      }
288
289      mExportFuncMetadata->addOperand(
290          llvm::MDNode::get(mLLVMContext,
291                            ExportFuncInfo.data(),
292                            ExportFuncInfo.size()));
293
294      ExportFuncInfo.clear();
295    }
296  }
297
298  // Dump export type info
299  if (mContext->hasExportType()) {
300    llvm::SmallVector<llvm::Value*, 1> ExportTypeInfo;
301
302    for (RSContext::const_export_type_iterator
303            I = mContext->export_types_begin(),
304            E = mContext->export_types_end();
305         I != E;
306         I++) {
307      // First, dump type name list to export
308      const RSExportType *ET = I->getValue();
309
310      ExportTypeInfo.clear();
311      // Type name
312      ExportTypeInfo.push_back(
313          llvm::MDString::get(mLLVMContext, ET->getName().c_str()));
314
315      if (ET->getClass() == RSExportType::ExportClassRecord) {
316        const RSExportRecordType *ERT =
317            static_cast<const RSExportRecordType*>(ET);
318
319        if (mExportTypeMetadata == NULL)
320          mExportTypeMetadata =
321              M->getOrInsertNamedMetadata(RS_EXPORT_TYPE_MN);
322
323        mExportTypeMetadata->addOperand(
324            llvm::MDNode::get(mLLVMContext,
325                              ExportTypeInfo.data(),
326                              ExportTypeInfo.size()));
327
328        // Now, export struct field information to %[struct name]
329        std::string StructInfoMetadataName("%");
330        StructInfoMetadataName.append(ET->getName());
331        llvm::NamedMDNode *StructInfoMetadata =
332            M->getOrInsertNamedMetadata(StructInfoMetadataName);
333        llvm::SmallVector<llvm::Value*, 3> FieldInfo;
334
335        assert(StructInfoMetadata->getNumOperands() == 0 &&
336               "Metadata with same name was created before");
337        for (RSExportRecordType::const_field_iterator FI = ERT->fields_begin(),
338                FE = ERT->fields_end();
339             FI != FE;
340             FI++) {
341          const RSExportRecordType::Field *F = *FI;
342
343          // 1. field name
344          FieldInfo.push_back(llvm::MDString::get(mLLVMContext,
345                                                  F->getName().c_str()));
346
347          // 2. field type name
348          FieldInfo.push_back(
349              llvm::MDString::get(mLLVMContext,
350                                  F->getType()->getName().c_str()));
351
352          // 3. field kind
353          switch (F->getType()->getClass()) {
354            case RSExportType::ExportClassPrimitive:
355            case RSExportType::ExportClassVector: {
356              const RSExportPrimitiveType *EPT =
357                  static_cast<const RSExportPrimitiveType*>(F->getType());
358              FieldInfo.push_back(
359                  llvm::MDString::get(mLLVMContext,
360                                      llvm::itostr(EPT->getKind())));
361              break;
362            }
363
364            default: {
365              FieldInfo.push_back(
366                  llvm::MDString::get(mLLVMContext,
367                                      llvm::itostr(
368                                        RSExportPrimitiveType::DataKindUser)));
369              break;
370            }
371          }
372
373          StructInfoMetadata->addOperand(llvm::MDNode::get(mLLVMContext,
374                                                           FieldInfo.data(),
375                                                           FieldInfo.size()));
376
377          FieldInfo.clear();
378        }
379      }   // ET->getClass() == RSExportType::ExportClassRecord
380    }
381  }
382
383  return;
384}
385
386RSBackend::~RSBackend() {
387  return;
388}
389