slang_rs_metadata_spec_encoder.cpp revision 18c8829f2bd3cbe0d02471588c6643c0a8c6ca3c
1/*
2 * Copyright 2010, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "slang_rs_metadata_spec.h"
18
#include <cstdlib>
#include <list>
#include <map>
#include <string>
#include <vector>
23
24#include "llvm/ADT/SmallVector.h"
25#include "llvm/ADT/StringRef.h"
26
27#include "llvm/Metadata.h"
28#include "llvm/Module.h"
29
30#include "slang_assert.h"
31#include "slang_rs_type_spec.h"
32
33#define RS_METADATA_STRTAB_MN   "#rs_metadata_strtab"
34#define RS_TYPE_INFO_MN         "#rs_type_info"
35#define RS_EXPORT_VAR_MN        "#rs_export_var"
36#define RS_EXPORT_FUNC_MN       "#rs_export_func"
37#define RS_EXPORT_RECORD_TYPE_NAME_MN_PREFIX  "%"
38
39///////////////////////////////////////////////////////////////////////////////
40// Useful utility functions
41///////////////////////////////////////////////////////////////////////////////
42static bool EncodeInteger(llvm::LLVMContext &C,
43                          unsigned I,
44                          llvm::SmallVectorImpl<llvm::Value*> &Op) {
45  llvm::StringRef S(reinterpret_cast<const char*>(&I), sizeof(I));
46  llvm::MDString *MDS = llvm::MDString::get(C, S);
47
48  if (MDS == NULL)
49    return false;
50  Op.push_back(MDS);
51  return true;
52}
53
54///////////////////////////////////////////////////////////////////////////////
55// class RSMetadataEncoderInternal
56///////////////////////////////////////////////////////////////////////////////
57namespace {
58
59class RSMetadataEncoderInternal {
60 private:
61  llvm::Module *mModule;
62
63  typedef std::map</* key */unsigned, unsigned/* index */> TypesMapTy;
64  TypesMapTy mTypes;
65  std::list<unsigned> mEncodedRSTypeInfo;  // simply a sequece of integers
66  unsigned mCurTypeIndex;
67
68  // A special type for lookup created record type. It uses record name as key.
69  typedef std::map</* name */std::string, unsigned/* index */> RecordTypesMapTy;
70  RecordTypesMapTy mRecordTypes;
71
72  typedef std::map<std::string, unsigned/* index */> StringsMapTy;
73  StringsMapTy mStrings;
74  std::list<const char*> mEncodedStrings;
75  unsigned mCurStringIndex;
76
77  llvm::NamedMDNode *mVarInfoMetadata;
78  llvm::NamedMDNode *mFuncInfoMetadata;
79
80  // This function check the return value of function:
81  //   joinString, encodeTypeBase, encode*Type(), encodeRSType, encodeRSVar,
82  //   and encodeRSFunc. Return false if the value of Index indicates failure.
83  inline bool checkReturnIndex(unsigned *Index) {
84    if (*Index == 0)
85      return false;
86    else
87      (*Index)--;
88    return true;
89  }
90
91  unsigned joinString(const std::string &S);
92
93  unsigned encodeTypeBase(const struct RSTypeBase *Base);
94  unsigned encodeTypeBaseAsKey(const struct RSTypeBase *Base);
95#define ENUM_RS_DATA_TYPE_CLASS(x)  \
96  unsigned encode ## x ## Type(const union RSType *T);
97RS_DATA_TYPE_CLASS_ENUMS
98#undef ENUM_RS_DATA_TYPE_CLASS
99
100  unsigned encodeRSType(const union RSType *T);
101
102  int flushStringTable();
103  int flushTypeInfo();
104
105 public:
106  explicit RSMetadataEncoderInternal(llvm::Module *M);
107
108  int encodeRSVar(const RSVar *V);
109  int encodeRSFunc(const RSFunction *F);
110
111  int finalize();
112};
113}
114
115RSMetadataEncoderInternal::RSMetadataEncoderInternal(llvm::Module *M)
116    : mModule(M),
117      mCurTypeIndex(0),
118      mCurStringIndex(0),
119      mVarInfoMetadata(NULL),
120      mFuncInfoMetadata(NULL) {
121  mTypes.clear();
122  mEncodedRSTypeInfo.clear();
123  mRecordTypes.clear();
124  mStrings.clear();
125
126  return;
127}
128
129// Return (StringIndex + 1) when successfully join the string and 0 if there's
130// any error.
131unsigned RSMetadataEncoderInternal::joinString(const std::string &S) {
132  StringsMapTy::const_iterator I = mStrings.find(S);
133
134  if (I != mStrings.end())
135    return (I->second + 1);
136
137  // Add S into mStrings
138  std::pair<StringsMapTy::iterator, bool> Res =
139      mStrings.insert(std::make_pair(S, mCurStringIndex));
140  // Insertion failed
141  if (!Res.second)
142    return 0;
143
144  // Add S into mEncodedStrings
145  mEncodedStrings.push_back(Res.first->first.c_str());
146  mCurStringIndex++;
147
148  // Return (StringIndex + 1)
149  return (Res.first->second + 1);
150}
151
152unsigned
153RSMetadataEncoderInternal::encodeTypeBase(const struct RSTypeBase *Base) {
154  mEncodedRSTypeInfo.push_back(Base->bits);
155  return ++mCurTypeIndex;
156}
157
158unsigned RSMetadataEncoderInternal::encodeTypeBaseAsKey(
159    const struct RSTypeBase *Base) {
160  TypesMapTy::const_iterator I = mTypes.find(Base->bits);
161  if (I != mTypes.end())
162    return (I->second + 1);
163
164  // Add Base into mTypes
165  std::pair<TypesMapTy::iterator, bool> Res =
166      mTypes.insert(std::make_pair(Base->bits, mCurTypeIndex));
167  // Insertion failed
168  if (!Res.second)
169    return 0;
170
171  // Push to mEncodedRSTypeInfo. This will also update mCurTypeIndex.
172  return encodeTypeBase(Base);
173}
174
175unsigned RSMetadataEncoderInternal::encodePrimitiveType(const union RSType *T) {
176  return encodeTypeBaseAsKey(RS_GET_TYPE_BASE(T));
177}
178
179unsigned RSMetadataEncoderInternal::encodePointerType(const union RSType *T) {
180  // Encode pointee type first
181  unsigned PointeeType = encodeRSType(RS_POINTER_TYPE_GET_POINTEE_TYPE(T));
182  if (!checkReturnIndex(&PointeeType))
183    return 0;
184
185  unsigned Res = encodeTypeBaseAsKey(RS_GET_TYPE_BASE(T));
186  // Push PointeeType after the base type
187  mEncodedRSTypeInfo.push_back(PointeeType);
188  return Res;
189}
190
191unsigned RSMetadataEncoderInternal::encodeVectorType(const union RSType *T) {
192  return encodeTypeBaseAsKey(RS_GET_TYPE_BASE(T));
193}
194
195unsigned RSMetadataEncoderInternal::encodeMatrixType(const union RSType *T) {
196  return encodeTypeBaseAsKey(RS_GET_TYPE_BASE(T));
197}
198
199unsigned
200RSMetadataEncoderInternal::encodeConstantArrayType(const union RSType *T) {
201  // Encode element type
202  unsigned ElementType =
203      encodeRSType(RS_CONSTANT_ARRAY_TYPE_GET_ELEMENT_TYPE(T));
204  if (!checkReturnIndex(&ElementType))
205    return 0;
206
207  unsigned Res = encodeTypeBase(RS_GET_TYPE_BASE(T));
208  // Push the ElementType after the type base
209  mEncodedRSTypeInfo.push_back(ElementType);
210  return Res;
211}
212
213unsigned RSMetadataEncoderInternal::encodeRecordType(const union RSType *T) {
214  // Construct record name
215  std::string RecordInfoMetadataName(RS_EXPORT_RECORD_TYPE_NAME_MN_PREFIX);
216  RecordInfoMetadataName.append(RS_RECORD_TYPE_GET_NAME(T));
217
218  // Try to find it in mRecordTypes
219  RecordTypesMapTy::const_iterator I =
220      mRecordTypes.find(RecordInfoMetadataName);
221
222  // This record type has been encoded before. Fast return its index here.
223  if (I != mRecordTypes.end())
224    return (I->second + 1);
225
226  // Encode this record type into mTypes. Encode record name string first.
227  unsigned RecordName = joinString(RecordInfoMetadataName);
228  if (!checkReturnIndex(&RecordName))
229    return 0;
230
231  unsigned Base = encodeTypeBase(RS_GET_TYPE_BASE(T));
232  if (!checkReturnIndex(&Base))
233    return 0;
234
235  // Push record name after encoding the type base
236  mEncodedRSTypeInfo.push_back(RecordName);
237
238  // Add this record type into the map
239  std::pair<StringsMapTy::iterator, bool> Res =
240      mRecordTypes.insert(std::make_pair(RecordInfoMetadataName, Base));
241  // Insertion failed
242  if (!Res.second)
243    return 0;
244
245  // Create a named MDNode for this record type. We cannot create this before
246  // encoding type base into Types and updating mRecordTypes. This is because
247  // we may have structure like:
248  //
249  //            struct foo {
250  //              ...
251  //              struct foo *bar;  // self type reference
252  //              ...
253  //            }
254  llvm::NamedMDNode *RecordInfoMetadata =
255      mModule->getOrInsertNamedMetadata(RecordInfoMetadataName);
256
257  slangAssert((RecordInfoMetadata->getNumOperands() == 0) &&
258              "Record created before!");
259
260  // Encode field info into this named MDNode
261  llvm::SmallVector<llvm::Value*, 3> FieldInfo;
262
263  for (unsigned i = 0; i < RS_RECORD_TYPE_GET_NUM_FIELDS(T); i++) {
264    // 1. field name
265    unsigned FieldName = joinString(RS_RECORD_TYPE_GET_FIELD_NAME(T, i));
266    if (!checkReturnIndex(&FieldName))
267      return 0;
268    if (!EncodeInteger(mModule->getContext(),
269                       FieldName,
270                       FieldInfo)) {
271      return 0;
272    }
273
274    // 2. field type
275    unsigned FieldType = encodeRSType(RS_RECORD_TYPE_GET_FIELD_TYPE(T, i));
276    if (!checkReturnIndex(&FieldType))
277      return 0;
278    if (!EncodeInteger(mModule->getContext(),
279                       FieldType,
280                       FieldInfo)) {
281      return 0;
282    }
283
284    // 3. field data kind
285    if (!EncodeInteger(mModule->getContext(),
286                       RS_RECORD_TYPE_GET_FIELD_DATA_KIND(T, i),
287                       FieldInfo)) {
288      return 0;
289    }
290
291    RecordInfoMetadata->addOperand(llvm::MDNode::get(mModule->getContext(),
292                                                     FieldInfo));
293    FieldInfo.clear();
294  }
295
296  return (Res.first->second + 1);
297}
298
299unsigned RSMetadataEncoderInternal::encodeRSType(const union RSType *T) {
300  switch (static_cast<enum RSTypeClass>(RS_TYPE_GET_CLASS(T))) {
301#define ENUM_RS_DATA_TYPE_CLASS(x)  \
302    case RS_TC_ ## x: return encode ## x ## Type(T);
303    RS_DATA_TYPE_CLASS_ENUMS
304#undef ENUM_RS_DATA_TYPE_CLASS
305    default: return 0;
306  }
307  return 0;
308}
309
310int RSMetadataEncoderInternal::encodeRSVar(const RSVar *V) {
311  // check parameter
312  if ((V == NULL) || (V->name == NULL) || (V->type == NULL))
313    return -1;
314
315  // 1. var name
316  unsigned VarName = joinString(V->name);
317  if (!checkReturnIndex(&VarName)) {
318    return -2;
319  }
320
321  // 2. type
322  unsigned Type = encodeRSType(V->type);
323
324  llvm::SmallVector<llvm::Value*, 1> VarInfo;
325
326  if (!EncodeInteger(mModule->getContext(), VarName, VarInfo)) {
327    return -3;
328  }
329  if (!EncodeInteger(mModule->getContext(), Type, VarInfo)) {
330    return -4;
331  }
332
333  if (mVarInfoMetadata == NULL)
334    mVarInfoMetadata = mModule->getOrInsertNamedMetadata(RS_EXPORT_VAR_MN);
335
336  mVarInfoMetadata->addOperand(llvm::MDNode::get(mModule->getContext(),
337                                                 VarInfo));
338
339  return 0;
340}
341
342int RSMetadataEncoderInternal::encodeRSFunc(const RSFunction *F) {
343  // check parameter
344  if ((F == NULL) || (F->name == NULL)) {
345    return -1;
346  }
347
348  // 1. var name
349  unsigned FuncName = joinString(F->name);
350  if (!checkReturnIndex(&FuncName)) {
351    return -2;
352  }
353
354  llvm::SmallVector<llvm::Value*, 1> FuncInfo;
355  if (!EncodeInteger(mModule->getContext(), FuncName, FuncInfo)) {
356    return -3;
357  }
358
359  if (mFuncInfoMetadata == NULL)
360    mFuncInfoMetadata = mModule->getOrInsertNamedMetadata(RS_EXPORT_FUNC_MN);
361
362  mFuncInfoMetadata->addOperand(llvm::MDNode::get(mModule->getContext(),
363                                                  FuncInfo));
364
365  return 0;
366}
367
368// Write string table and string index table
369int RSMetadataEncoderInternal::flushStringTable() {
370  slangAssert((mCurStringIndex == mEncodedStrings.size()));
371  slangAssert((mCurStringIndex == mStrings.size()));
372
373  if (mCurStringIndex == 0)
374    return 0;
375
376  // Prepare named MDNode for string table and string index table.
377  llvm::NamedMDNode *RSMetadataStrTab =
378      mModule->getOrInsertNamedMetadata(RS_METADATA_STRTAB_MN);
379  RSMetadataStrTab->dropAllReferences();
380
381  unsigned StrTabSize = 0;
382  unsigned *StrIdx = reinterpret_cast<unsigned*>(
383                        ::malloc((mStrings.size() + 1) * sizeof(unsigned)));
384
385  if (StrIdx == NULL)
386    return -1;
387
388  unsigned StrIdxI = 0;  // iterator for array StrIdx
389
390  // count StrTabSize and fill StrIdx by the way
391  for (std::list<const char*>::const_iterator I = mEncodedStrings.begin(),
392          E = mEncodedStrings.end();
393       I != E;
394       I++) {
395    StrIdx[StrIdxI++] = StrTabSize;
396    StrTabSize += ::strlen(*I) + 1 /* for '\0' */;
397  }
398  StrIdx[StrIdxI] = StrTabSize;
399
400  // Allocate
401  char *StrTab = reinterpret_cast<char*>(::malloc(StrTabSize));
402  if (StrTab == NULL) {
403    free(StrIdx);
404    return -1;
405  }
406
407  llvm::StringRef StrTabData(StrTab, StrTabSize);
408  llvm::StringRef StrIdxData(reinterpret_cast<const char*>(StrIdx),
409                             mStrings.size() * sizeof(unsigned));
410
411  // Copy
412  StrIdxI = 1;
413  for (std::list<const char*>::const_iterator I = mEncodedStrings.begin(),
414          E = mEncodedStrings.end();
415       I != E;
416       I++) {
417    // Get string length from StrIdx (O(1)) instead of call strlen again (O(n)).
418    unsigned CurStrLength = StrIdx[StrIdxI] - StrIdx[StrIdxI - 1];
419    ::memcpy(StrTab, *I, CurStrLength);
420    // Move forward the pointer
421    StrTab += CurStrLength;
422    StrIdxI++;
423  }
424
425  // Flush to metadata
426  llvm::Value *StrTabMDS =
427      llvm::MDString::get(mModule->getContext(), StrTabData);
428  llvm::Value *StrIdxMDS =
429      llvm::MDString::get(mModule->getContext(), StrIdxData);
430
431  if ((StrTabMDS == NULL) || (StrIdxMDS == NULL)) {
432    free(StrIdx);
433    free(StrTab);
434    return -1;
435  }
436
437  llvm::SmallVector<llvm::Value*, 2> StrTabVal;
438  StrTabVal.push_back(StrTabMDS);
439  StrTabVal.push_back(StrIdxMDS);
440  RSMetadataStrTab->addOperand(llvm::MDNode::get(mModule->getContext(),
441                                                 StrTabVal));
442
443  return 0;
444}
445
446// Write RS type stream
447int RSMetadataEncoderInternal::flushTypeInfo() {
448  unsigned TypeInfoCount = mEncodedRSTypeInfo.size();
449  if (TypeInfoCount <= 0) {
450    return 0;
451  }
452
453  llvm::NamedMDNode *RSTypeInfo =
454      mModule->getOrInsertNamedMetadata(RS_TYPE_INFO_MN);
455  RSTypeInfo->dropAllReferences();
456
457  unsigned *TypeInfos =
458      reinterpret_cast<unsigned*>(::malloc(TypeInfoCount * sizeof(unsigned)));
459  unsigned TypeInfosIdx = 0;  // iterator for array TypeInfos
460
461  if (TypeInfos == NULL)
462    return -1;
463
464  for (std::list<unsigned>::const_iterator I = mEncodedRSTypeInfo.begin(),
465          E = mEncodedRSTypeInfo.end();
466       I != E;
467       I++)
468    TypeInfos[TypeInfosIdx++] = *I;
469
470  llvm::StringRef TypeInfoData(reinterpret_cast<const char*>(TypeInfos),
471                               TypeInfoCount * sizeof(unsigned));
472  llvm::Value *TypeInfoMDS =
473      llvm::MDString::get(mModule->getContext(), TypeInfoData);
474  if (TypeInfoMDS == NULL) {
475    free(TypeInfos);
476    return -1;
477  }
478
479  llvm::SmallVector<llvm::Value*, 1> TypeInfo;
480  TypeInfo.push_back(TypeInfoMDS);
481
482  RSTypeInfo->addOperand(llvm::MDNode::get(mModule->getContext(),
483                                           TypeInfo));
484  free(TypeInfos);
485
486  return 0;
487}
488
489int RSMetadataEncoderInternal::finalize() {
490  int Res = flushStringTable();
491  if (Res != 0)
492    return Res;
493
494  Res = flushTypeInfo();
495  if (Res != 0)
496    return Res;
497
498  return 0;
499}
500
501///////////////////////////////////////////////////////////////////////////////
502// APIs
503///////////////////////////////////////////////////////////////////////////////
504RSMetadataEncoder *CreateRSMetadataEncoder(llvm::Module *M) {
505  return reinterpret_cast<RSMetadataEncoder*>(new RSMetadataEncoderInternal(M));
506}
507
508int RSEncodeVarMetadata(RSMetadataEncoder *E, const RSVar *V) {
509  return reinterpret_cast<RSMetadataEncoderInternal*>(E)->encodeRSVar(V);
510}
511
512int RSEncodeFunctionMetadata(RSMetadataEncoder *E, const RSFunction *F) {
513  return reinterpret_cast<RSMetadataEncoderInternal*>(E)->encodeRSFunc(F);
514}
515
516void DestroyRSMetadataEncoder(RSMetadataEncoder *E) {
517  RSMetadataEncoderInternal *C =
518      reinterpret_cast<RSMetadataEncoderInternal*>(E);
519  delete C;
520  return;
521}
522
523int FinalizeRSMetadataEncoder(RSMetadataEncoder *E) {
524  RSMetadataEncoderInternal *C =
525      reinterpret_cast<RSMetadataEncoderInternal*>(E);
526  int Res = C->finalize();
527  DestroyRSMetadataEncoder(E);
528  return Res;
529}
530