1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "slang_rs_metadata_spec.h"

#include <cstdlib>
#include <list>
#include <map>
#include <string>
#include <vector>

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"

#include "llvm/Metadata.h"
#include "llvm/Module.h"

#include "slang_assert.h"
#include "slang_rs_type_spec.h"
32
// Names of the module-level named metadata nodes written by this encoder.
// Named MDNode holding the string table and the string index table.
#define RS_METADATA_STRTAB_MN   "#rs_metadata_strtab"
// Named MDNode holding the encoded RS type stream.
#define RS_TYPE_INFO_MN         "#rs_type_info"
// Named MDNode with one operand per exported variable.
#define RS_EXPORT_VAR_MN        "#rs_export_var"
// Named MDNode with one operand per exported function.
#define RS_EXPORT_FUNC_MN       "#rs_export_func"
// Prepended to a record's name to form the named MDNode that holds its field
// info; the prefixed name is also the string interned for the record.
#define RS_EXPORT_RECORD_TYPE_NAME_MN_PREFIX  "%"
38
39///////////////////////////////////////////////////////////////////////////////
40// Useful utility functions
41///////////////////////////////////////////////////////////////////////////////
42static bool EncodeInteger(llvm::LLVMContext &C,
43                          unsigned I,
44                          llvm::SmallVectorImpl<llvm::Value*> &Op) {
45  llvm::StringRef S(reinterpret_cast<const char*>(&I), sizeof(I));
46  llvm::MDString *MDS = llvm::MDString::get(C, S);
47
48  if (MDS == NULL)
49    return false;
50  Op.push_back(MDS);
51  return true;
52}
53
54///////////////////////////////////////////////////////////////////////////////
55// class RSMetadataEncoderInternal
56///////////////////////////////////////////////////////////////////////////////
57namespace {
58
59class RSMetadataEncoderInternal {
60 private:
61  llvm::Module *mModule;
62
63  typedef std::map</* key */unsigned, unsigned/* index */> TypesMapTy;
64  TypesMapTy mTypes;
65  std::list<unsigned> mEncodedRSTypeInfo;  // simply a sequece of integers
66  unsigned mCurTypeIndex;
67
68  // A special type for lookup created record type. It uses record name as key.
69  typedef std::map</* name */std::string, unsigned/* index */> RecordTypesMapTy;
70  RecordTypesMapTy mRecordTypes;
71
72  typedef std::map<std::string, unsigned/* index */> StringsMapTy;
73  StringsMapTy mStrings;
74  std::list<const char*> mEncodedStrings;
75  unsigned mCurStringIndex;
76
77  llvm::NamedMDNode *mVarInfoMetadata;
78  llvm::NamedMDNode *mFuncInfoMetadata;
79
80  // This function check the return value of function:
81  //   joinString, encodeTypeBase, encode*Type(), encodeRSType, encodeRSVar,
82  //   and encodeRSFunc. Return false if the value of Index indicates failure.
83  inline bool checkReturnIndex(unsigned *Index) {
84    if (*Index == 0)
85      return false;
86    else
87      (*Index)--;
88    return true;
89  }
90
91  unsigned joinString(const std::string &S);
92
93  unsigned encodeTypeBase(const struct RSTypeBase *Base);
94  unsigned encodeTypeBaseAsKey(const struct RSTypeBase *Base);
95#define ENUM_RS_DATA_TYPE_CLASS(x)  \
96  unsigned encode ## x ## Type(const union RSType *T);
97RS_DATA_TYPE_CLASS_ENUMS
98#undef ENUM_RS_DATA_TYPE_CLASS
99
100  unsigned encodeRSType(const union RSType *T);
101
102  int flushStringTable();
103  int flushTypeInfo();
104
105 public:
106  explicit RSMetadataEncoderInternal(llvm::Module *M);
107
108  int encodeRSVar(const RSVar *V);
109  int encodeRSFunc(const RSFunction *F);
110
111  int finalize();
112};
113
114}  // namespace
115
116RSMetadataEncoderInternal::RSMetadataEncoderInternal(llvm::Module *M)
117    : mModule(M),
118      mCurTypeIndex(0),
119      mCurStringIndex(0),
120      mVarInfoMetadata(NULL),
121      mFuncInfoMetadata(NULL) {
122  mTypes.clear();
123  mEncodedRSTypeInfo.clear();
124  mRecordTypes.clear();
125  mStrings.clear();
126
127  return;
128}
129
130// Return (StringIndex + 1) when successfully join the string and 0 if there's
131// any error.
132unsigned RSMetadataEncoderInternal::joinString(const std::string &S) {
133  StringsMapTy::const_iterator I = mStrings.find(S);
134
135  if (I != mStrings.end())
136    return (I->second + 1);
137
138  // Add S into mStrings
139  std::pair<StringsMapTy::iterator, bool> Res =
140      mStrings.insert(std::make_pair(S, mCurStringIndex));
141  // Insertion failed
142  if (!Res.second)
143    return 0;
144
145  // Add S into mEncodedStrings
146  mEncodedStrings.push_back(Res.first->first.c_str());
147  mCurStringIndex++;
148
149  // Return (StringIndex + 1)
150  return (Res.first->second + 1);
151}
152
153unsigned
154RSMetadataEncoderInternal::encodeTypeBase(const struct RSTypeBase *Base) {
155  mEncodedRSTypeInfo.push_back(Base->bits);
156  return ++mCurTypeIndex;
157}
158
159unsigned RSMetadataEncoderInternal::encodeTypeBaseAsKey(
160    const struct RSTypeBase *Base) {
161  TypesMapTy::const_iterator I = mTypes.find(Base->bits);
162  if (I != mTypes.end())
163    return (I->second + 1);
164
165  // Add Base into mTypes
166  std::pair<TypesMapTy::iterator, bool> Res =
167      mTypes.insert(std::make_pair(Base->bits, mCurTypeIndex));
168  // Insertion failed
169  if (!Res.second)
170    return 0;
171
172  // Push to mEncodedRSTypeInfo. This will also update mCurTypeIndex.
173  return encodeTypeBase(Base);
174}
175
176unsigned RSMetadataEncoderInternal::encodePrimitiveType(const union RSType *T) {
177  return encodeTypeBaseAsKey(RS_GET_TYPE_BASE(T));
178}
179
180unsigned RSMetadataEncoderInternal::encodePointerType(const union RSType *T) {
181  // Encode pointee type first
182  unsigned PointeeType = encodeRSType(RS_POINTER_TYPE_GET_POINTEE_TYPE(T));
183  if (!checkReturnIndex(&PointeeType))
184    return 0;
185
186  unsigned Res = encodeTypeBaseAsKey(RS_GET_TYPE_BASE(T));
187  // Push PointeeType after the base type
188  mEncodedRSTypeInfo.push_back(PointeeType);
189  return Res;
190}
191
192unsigned RSMetadataEncoderInternal::encodeVectorType(const union RSType *T) {
193  return encodeTypeBaseAsKey(RS_GET_TYPE_BASE(T));
194}
195
196unsigned RSMetadataEncoderInternal::encodeMatrixType(const union RSType *T) {
197  return encodeTypeBaseAsKey(RS_GET_TYPE_BASE(T));
198}
199
200unsigned
201RSMetadataEncoderInternal::encodeConstantArrayType(const union RSType *T) {
202  // Encode element type
203  unsigned ElementType =
204      encodeRSType(RS_CONSTANT_ARRAY_TYPE_GET_ELEMENT_TYPE(T));
205  if (!checkReturnIndex(&ElementType))
206    return 0;
207
208  unsigned Res = encodeTypeBase(RS_GET_TYPE_BASE(T));
209  // Push the ElementType after the type base
210  mEncodedRSTypeInfo.push_back(ElementType);
211  return Res;
212}
213
214unsigned RSMetadataEncoderInternal::encodeRecordType(const union RSType *T) {
215  // Construct record name
216  std::string RecordInfoMetadataName(RS_EXPORT_RECORD_TYPE_NAME_MN_PREFIX);
217  RecordInfoMetadataName.append(RS_RECORD_TYPE_GET_NAME(T));
218
219  // Try to find it in mRecordTypes
220  RecordTypesMapTy::const_iterator I =
221      mRecordTypes.find(RecordInfoMetadataName);
222
223  // This record type has been encoded before. Fast return its index here.
224  if (I != mRecordTypes.end())
225    return (I->second + 1);
226
227  // Encode this record type into mTypes. Encode record name string first.
228  unsigned RecordName = joinString(RecordInfoMetadataName);
229  if (!checkReturnIndex(&RecordName))
230    return 0;
231
232  unsigned Base = encodeTypeBase(RS_GET_TYPE_BASE(T));
233  if (!checkReturnIndex(&Base))
234    return 0;
235
236  // Push record name after encoding the type base
237  mEncodedRSTypeInfo.push_back(RecordName);
238
239  // Add this record type into the map
240  std::pair<StringsMapTy::iterator, bool> Res =
241      mRecordTypes.insert(std::make_pair(RecordInfoMetadataName, Base));
242  // Insertion failed
243  if (!Res.second)
244    return 0;
245
246  // Create a named MDNode for this record type. We cannot create this before
247  // encoding type base into Types and updating mRecordTypes. This is because
248  // we may have structure like:
249  //
250  //            struct foo {
251  //              ...
252  //              struct foo *bar;  // self type reference
253  //              ...
254  //            }
255  llvm::NamedMDNode *RecordInfoMetadata =
256      mModule->getOrInsertNamedMetadata(RecordInfoMetadataName);
257
258  slangAssert((RecordInfoMetadata->getNumOperands() == 0) &&
259              "Record created before!");
260
261  // Encode field info into this named MDNode
262  llvm::SmallVector<llvm::Value*, 3> FieldInfo;
263
264  for (unsigned i = 0; i < RS_RECORD_TYPE_GET_NUM_FIELDS(T); i++) {
265    // 1. field name
266    unsigned FieldName = joinString(RS_RECORD_TYPE_GET_FIELD_NAME(T, i));
267    if (!checkReturnIndex(&FieldName))
268      return 0;
269    if (!EncodeInteger(mModule->getContext(),
270                       FieldName,
271                       FieldInfo)) {
272      return 0;
273    }
274
275    // 2. field type
276    unsigned FieldType = encodeRSType(RS_RECORD_TYPE_GET_FIELD_TYPE(T, i));
277    if (!checkReturnIndex(&FieldType))
278      return 0;
279    if (!EncodeInteger(mModule->getContext(),
280                       FieldType,
281                       FieldInfo)) {
282      return 0;
283    }
284
285    RecordInfoMetadata->addOperand(llvm::MDNode::get(mModule->getContext(),
286                                                     FieldInfo));
287    FieldInfo.clear();
288  }
289
290  return (Res.first->second + 1);
291}
292
293unsigned RSMetadataEncoderInternal::encodeRSType(const union RSType *T) {
294  switch (static_cast<enum RSTypeClass>(RS_TYPE_GET_CLASS(T))) {
295#define ENUM_RS_DATA_TYPE_CLASS(x)  \
296    case RS_TC_ ## x: return encode ## x ## Type(T);
297    RS_DATA_TYPE_CLASS_ENUMS
298#undef ENUM_RS_DATA_TYPE_CLASS
299    default: return 0;
300  }
301  return 0;
302}
303
304int RSMetadataEncoderInternal::encodeRSVar(const RSVar *V) {
305  // check parameter
306  if ((V == NULL) || (V->name == NULL) || (V->type == NULL))
307    return -1;
308
309  // 1. var name
310  unsigned VarName = joinString(V->name);
311  if (!checkReturnIndex(&VarName)) {
312    return -2;
313  }
314
315  // 2. type
316  unsigned Type = encodeRSType(V->type);
317
318  llvm::SmallVector<llvm::Value*, 1> VarInfo;
319
320  if (!EncodeInteger(mModule->getContext(), VarName, VarInfo)) {
321    return -3;
322  }
323  if (!EncodeInteger(mModule->getContext(), Type, VarInfo)) {
324    return -4;
325  }
326
327  if (mVarInfoMetadata == NULL)
328    mVarInfoMetadata = mModule->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
329
330  mVarInfoMetadata->addOperand(llvm::MDNode::get(mModule->getContext(),
331                                                 VarInfo));
332
333  return 0;
334}
335
336int RSMetadataEncoderInternal::encodeRSFunc(const RSFunction *F) {
337  // check parameter
338  if ((F == NULL) || (F->name == NULL)) {
339    return -1;
340  }
341
342  // 1. var name
343  unsigned FuncName = joinString(F->name);
344  if (!checkReturnIndex(&FuncName)) {
345    return -2;
346  }
347
348  llvm::SmallVector<llvm::Value*, 1> FuncInfo;
349  if (!EncodeInteger(mModule->getContext(), FuncName, FuncInfo)) {
350    return -3;
351  }
352
353  if (mFuncInfoMetadata == NULL)
354    mFuncInfoMetadata = mModule->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
355
356  mFuncInfoMetadata->addOperand(llvm::MDNode::get(mModule->getContext(),
357                                                  FuncInfo));
358
359  return 0;
360}
361
362// Write string table and string index table
363int RSMetadataEncoderInternal::flushStringTable() {
364  slangAssert((mCurStringIndex == mEncodedStrings.size()));
365  slangAssert((mCurStringIndex == mStrings.size()));
366
367  if (mCurStringIndex == 0)
368    return 0;
369
370  // Prepare named MDNode for string table and string index table.
371  llvm::NamedMDNode *RSMetadataStrTab =
372      mModule->getOrInsertNamedMetadata(RS_METADATA_STRTAB_MN);
373  RSMetadataStrTab->dropAllReferences();
374
375  unsigned StrTabSize = 0;
376  unsigned *StrIdx = reinterpret_cast<unsigned*>(
377                        ::malloc((mStrings.size() + 1) * sizeof(unsigned)));
378
379  if (StrIdx == NULL)
380    return -1;
381
382  unsigned StrIdxI = 0;  // iterator for array StrIdx
383
384  // count StrTabSize and fill StrIdx by the way
385  for (std::list<const char*>::const_iterator I = mEncodedStrings.begin(),
386          E = mEncodedStrings.end();
387       I != E;
388       I++) {
389    StrIdx[StrIdxI++] = StrTabSize;
390    StrTabSize += ::strlen(*I) + 1 /* for '\0' */;
391  }
392  StrIdx[StrIdxI] = StrTabSize;
393
394  // Allocate
395  char *StrTab = reinterpret_cast<char*>(::malloc(StrTabSize));
396  if (StrTab == NULL) {
397    free(StrIdx);
398    return -1;
399  }
400
401  llvm::StringRef StrTabData(StrTab, StrTabSize);
402  llvm::StringRef StrIdxData(reinterpret_cast<const char*>(StrIdx),
403                             mStrings.size() * sizeof(unsigned));
404
405  // Copy
406  StrIdxI = 1;
407  for (std::list<const char*>::const_iterator I = mEncodedStrings.begin(),
408          E = mEncodedStrings.end();
409       I != E;
410       I++) {
411    // Get string length from StrIdx (O(1)) instead of call strlen again (O(n)).
412    unsigned CurStrLength = StrIdx[StrIdxI] - StrIdx[StrIdxI - 1];
413    ::memcpy(StrTab, *I, CurStrLength);
414    // Move forward the pointer
415    StrTab += CurStrLength;
416    StrIdxI++;
417  }
418
419  // Flush to metadata
420  llvm::Value *StrTabMDS =
421      llvm::MDString::get(mModule->getContext(), StrTabData);
422  llvm::Value *StrIdxMDS =
423      llvm::MDString::get(mModule->getContext(), StrIdxData);
424
425  if ((StrTabMDS == NULL) || (StrIdxMDS == NULL)) {
426    free(StrIdx);
427    free(StrTab);
428    return -1;
429  }
430
431  llvm::SmallVector<llvm::Value*, 2> StrTabVal;
432  StrTabVal.push_back(StrTabMDS);
433  StrTabVal.push_back(StrIdxMDS);
434  RSMetadataStrTab->addOperand(llvm::MDNode::get(mModule->getContext(),
435                                                 StrTabVal));
436
437  return 0;
438}
439
440// Write RS type stream
441int RSMetadataEncoderInternal::flushTypeInfo() {
442  unsigned TypeInfoCount = mEncodedRSTypeInfo.size();
443  if (TypeInfoCount <= 0) {
444    return 0;
445  }
446
447  llvm::NamedMDNode *RSTypeInfo =
448      mModule->getOrInsertNamedMetadata(RS_TYPE_INFO_MN);
449  RSTypeInfo->dropAllReferences();
450
451  unsigned *TypeInfos =
452      reinterpret_cast<unsigned*>(::malloc(TypeInfoCount * sizeof(unsigned)));
453  unsigned TypeInfosIdx = 0;  // iterator for array TypeInfos
454
455  if (TypeInfos == NULL)
456    return -1;
457
458  for (std::list<unsigned>::const_iterator I = mEncodedRSTypeInfo.begin(),
459          E = mEncodedRSTypeInfo.end();
460       I != E;
461       I++)
462    TypeInfos[TypeInfosIdx++] = *I;
463
464  llvm::StringRef TypeInfoData(reinterpret_cast<const char*>(TypeInfos),
465                               TypeInfoCount * sizeof(unsigned));
466  llvm::Value *TypeInfoMDS =
467      llvm::MDString::get(mModule->getContext(), TypeInfoData);
468  if (TypeInfoMDS == NULL) {
469    free(TypeInfos);
470    return -1;
471  }
472
473  llvm::SmallVector<llvm::Value*, 1> TypeInfo;
474  TypeInfo.push_back(TypeInfoMDS);
475
476  RSTypeInfo->addOperand(llvm::MDNode::get(mModule->getContext(),
477                                           TypeInfo));
478  free(TypeInfos);
479
480  return 0;
481}
482
483int RSMetadataEncoderInternal::finalize() {
484  int Res = flushStringTable();
485  if (Res != 0)
486    return Res;
487
488  Res = flushTypeInfo();
489  if (Res != 0)
490    return Res;
491
492  return 0;
493}
494
495///////////////////////////////////////////////////////////////////////////////
496// APIs
497///////////////////////////////////////////////////////////////////////////////
498RSMetadataEncoder *CreateRSMetadataEncoder(llvm::Module *M) {
499  return reinterpret_cast<RSMetadataEncoder*>(new RSMetadataEncoderInternal(M));
500}
501
502int RSEncodeVarMetadata(RSMetadataEncoder *E, const RSVar *V) {
503  return reinterpret_cast<RSMetadataEncoderInternal*>(E)->encodeRSVar(V);
504}
505
506int RSEncodeFunctionMetadata(RSMetadataEncoder *E, const RSFunction *F) {
507  return reinterpret_cast<RSMetadataEncoderInternal*>(E)->encodeRSFunc(F);
508}
509
510void DestroyRSMetadataEncoder(RSMetadataEncoder *E) {
511  RSMetadataEncoderInternal *C =
512      reinterpret_cast<RSMetadataEncoderInternal*>(E);
513  delete C;
514  return;
515}
516
517int FinalizeRSMetadataEncoder(RSMetadataEncoder *E) {
518  RSMetadataEncoderInternal *C =
519      reinterpret_cast<RSMetadataEncoderInternal*>(E);
520  int Res = C->finalize();
521  DestroyRSMetadataEncoder(E);
522  return Res;
523}
524