1//===- SymbolRewriter.cpp - Symbol Rewriter ---------------------*- C++ -*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// SymbolRewriter is a LLVM pass which can rewrite symbols transparently within
11// existing code.  It is implemented as a compiler pass and is configured via a
12// YAML configuration file.
13//
14// The YAML configuration file format is as follows:
15//
16// RewriteMapFile := RewriteDescriptors
// RewriteDescriptors := RewriteDescriptor | RewriteDescriptor RewriteDescriptors
18// RewriteDescriptor := RewriteDescriptorType ':' '{' RewriteDescriptorFields '}'
// RewriteDescriptorFields := RewriteDescriptorField |
//                            RewriteDescriptorField RewriteDescriptorFields
20// RewriteDescriptorField := FieldIdentifier ':' FieldValue ','
21// RewriteDescriptorType := Identifier
22// FieldIdentifier := Identifier
23// FieldValue := Identifier
24// Identifier := [0-9a-zA-Z]+
25//
26// Currently, the following descriptor types are supported:
27//
28// - function:          (function rewriting)
29//      + Source        (original name of the function)
30//      + Target        (explicit transformation)
31//      + Transform     (pattern transformation)
32//      + Naked         (boolean, whether the function is undecorated)
33// - global variable:   (external linkage global variable rewriting)
34//      + Source        (original name of externally visible variable)
35//      + Target        (explicit transformation)
36//      + Transform     (pattern transformation)
37// - global alias:      (global alias rewriting)
38//      + Source        (original name of the aliased name)
39//      + Target        (explicit transformation)
40//      + Transform     (pattern transformation)
41//
// Note that Source and exactly one of [Target, Transform] must be provided.
// In the map file the field keys are spelled in lowercase (source, target,
// transform, naked).
43//
// New rewrite descriptors can be created.  Adding a new rewrite descriptor
// involves:
//
//  a) extending the rewrite descriptor kind enumeration
//     (SymbolRewriter::RewriteDescriptor::Type)
//  b) implementing the new descriptor
//     (cf. <anonymous>::ExplicitRewriteFunctionDescriptor)
//  c) extending the rewrite map parser
//     (SymbolRewriter::RewriteMapParser::parseEntry)
53//
//  Enable symbol rewriting with the `-rewrite-symbols` option, and specify
//  the map file(s) to use via the `-rewrite-map-file` option, which may be
//  given more than once.
57//
58//===----------------------------------------------------------------------===//
59
60#define DEBUG_TYPE "symbol-rewriter"
61#include "llvm/Pass.h"
62#include "llvm/ADT/SmallString.h"
63#include "llvm/IR/LegacyPassManager.h"
64#include "llvm/Support/CommandLine.h"
65#include "llvm/Support/Debug.h"
66#include "llvm/Support/MemoryBuffer.h"
67#include "llvm/Support/Regex.h"
68#include "llvm/Support/SourceMgr.h"
69#include "llvm/Support/YAMLParser.h"
70#include "llvm/Support/raw_ostream.h"
71#include "llvm/Transforms/Utils/SymbolRewriter.h"
72
73using namespace llvm;
74using namespace SymbolRewriter;
75
76static cl::list<std::string> RewriteMapFiles("rewrite-map-file",
77                                             cl::desc("Symbol Rewrite Map"),
78                                             cl::value_desc("filename"));
79
80static void rewriteComdat(Module &M, GlobalObject *GO,
81                          const std::string &Source,
82                          const std::string &Target) {
83  if (Comdat *CD = GO->getComdat()) {
84    auto &Comdats = M.getComdatSymbolTable();
85
86    Comdat *C = M.getOrInsertComdat(Target);
87    C->setSelectionKind(CD->getSelectionKind());
88    GO->setComdat(C);
89
90    Comdats.erase(Comdats.find(Source));
91  }
92}
93
94namespace {
95template <RewriteDescriptor::Type DT, typename ValueType,
96          ValueType *(llvm::Module::*Get)(StringRef) const>
97class ExplicitRewriteDescriptor : public RewriteDescriptor {
98public:
99  const std::string Source;
100  const std::string Target;
101
102  ExplicitRewriteDescriptor(StringRef S, StringRef T, const bool Naked)
103      : RewriteDescriptor(DT), Source(Naked ? StringRef("\01" + S.str()) : S),
104        Target(T) {}
105
106  bool performOnModule(Module &M) override;
107
108  static bool classof(const RewriteDescriptor *RD) {
109    return RD->getType() == DT;
110  }
111};
112
113template <RewriteDescriptor::Type DT, typename ValueType,
114          ValueType *(llvm::Module::*Get)(StringRef) const>
115bool ExplicitRewriteDescriptor<DT, ValueType, Get>::performOnModule(Module &M) {
116  bool Changed = false;
117  if (ValueType *S = (M.*Get)(Source)) {
118    if (GlobalObject *GO = dyn_cast<GlobalObject>(S))
119      rewriteComdat(M, GO, Source, Target);
120
121    if (Value *T = (M.*Get)(Target))
122      S->setValueName(T->getValueName());
123    else
124      S->setName(Target);
125
126    Changed = true;
127  }
128  return Changed;
129}
130
131template <RewriteDescriptor::Type DT, typename ValueType,
132          ValueType *(llvm::Module::*Get)(StringRef) const,
133          iterator_range<typename iplist<ValueType>::iterator>
134          (llvm::Module::*Iterator)()>
135class PatternRewriteDescriptor : public RewriteDescriptor {
136public:
137  const std::string Pattern;
138  const std::string Transform;
139
140  PatternRewriteDescriptor(StringRef P, StringRef T)
141    : RewriteDescriptor(DT), Pattern(P), Transform(T) { }
142
143  bool performOnModule(Module &M) override;
144
145  static bool classof(const RewriteDescriptor *RD) {
146    return RD->getType() == DT;
147  }
148};
149
150template <RewriteDescriptor::Type DT, typename ValueType,
151          ValueType *(llvm::Module::*Get)(StringRef) const,
152          iterator_range<typename iplist<ValueType>::iterator>
153          (llvm::Module::*Iterator)()>
154bool PatternRewriteDescriptor<DT, ValueType, Get, Iterator>::
155performOnModule(Module &M) {
156  bool Changed = false;
157  for (auto &C : (M.*Iterator)()) {
158    std::string Error;
159
160    std::string Name = Regex(Pattern).sub(Transform, C.getName(), &Error);
161    if (!Error.empty())
162      report_fatal_error("unable to transforn " + C.getName() + " in " +
163                         M.getModuleIdentifier() + ": " + Error);
164
165    if (C.getName() == Name)
166      continue;
167
168    if (GlobalObject *GO = dyn_cast<GlobalObject>(&C))
169      rewriteComdat(M, GO, C.getName(), Name);
170
171    if (Value *V = (M.*Get)(Name))
172      C.setValueName(V->getValueName());
173    else
174      C.setName(Name);
175
176    Changed = true;
177  }
178  return Changed;
179}
180
181/// Represents a rewrite for an explicitly named (function) symbol.  Both the
182/// source function name and target function name of the transformation are
183/// explicitly spelt out.
184typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::Function,
185                                  llvm::Function, &llvm::Module::getFunction>
186    ExplicitRewriteFunctionDescriptor;
187
188/// Represents a rewrite for an explicitly named (global variable) symbol.  Both
189/// the source variable name and target variable name are spelt out.  This
190/// applies only to module level variables.
191typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable,
192                                  llvm::GlobalVariable,
193                                  &llvm::Module::getGlobalVariable>
194    ExplicitRewriteGlobalVariableDescriptor;
195
196/// Represents a rewrite for an explicitly named global alias.  Both the source
197/// and target name are explicitly spelt out.
198typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::NamedAlias,
199                                  llvm::GlobalAlias,
200                                  &llvm::Module::getNamedAlias>
201    ExplicitRewriteNamedAliasDescriptor;
202
203/// Represents a rewrite for a regular expression based pattern for functions.
204/// A pattern for the function name is provided and a transformation for that
205/// pattern to determine the target function name create the rewrite rule.
206typedef PatternRewriteDescriptor<RewriteDescriptor::Type::Function,
207                                 llvm::Function, &llvm::Module::getFunction,
208                                 &llvm::Module::functions>
209    PatternRewriteFunctionDescriptor;
210
211/// Represents a rewrite for a global variable based upon a matching pattern.
212/// Each global variable matching the provided pattern will be transformed as
213/// described in the transformation pattern for the target.  Applies only to
214/// module level variables.
215typedef PatternRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable,
216                                 llvm::GlobalVariable,
217                                 &llvm::Module::getGlobalVariable,
218                                 &llvm::Module::globals>
219    PatternRewriteGlobalVariableDescriptor;
220
221/// PatternRewriteNamedAliasDescriptor - represents a rewrite for global
222/// aliases which match a given pattern.  The provided transformation will be
223/// applied to each of the matching names.
224typedef PatternRewriteDescriptor<RewriteDescriptor::Type::NamedAlias,
225                                 llvm::GlobalAlias,
226                                 &llvm::Module::getNamedAlias,
227                                 &llvm::Module::aliases>
228    PatternRewriteNamedAliasDescriptor;
229} // namespace
230
231bool RewriteMapParser::parse(const std::string &MapFile,
232                             RewriteDescriptorList *DL) {
233  ErrorOr<std::unique_ptr<MemoryBuffer>> Mapping =
234      MemoryBuffer::getFile(MapFile);
235
236  if (!Mapping)
237    report_fatal_error("unable to read rewrite map '" + MapFile + "': " +
238                       Mapping.getError().message());
239
240  if (!parse(*Mapping, DL))
241    report_fatal_error("unable to parse rewrite map '" + MapFile + "'");
242
243  return true;
244}
245
246bool RewriteMapParser::parse(std::unique_ptr<MemoryBuffer> &MapFile,
247                             RewriteDescriptorList *DL) {
248  SourceMgr SM;
249  yaml::Stream YS(MapFile->getBuffer(), SM);
250
251  for (auto &Document : YS) {
252    yaml::MappingNode *DescriptorList;
253
254    // ignore empty documents
255    if (isa<yaml::NullNode>(Document.getRoot()))
256      continue;
257
258    DescriptorList = dyn_cast<yaml::MappingNode>(Document.getRoot());
259    if (!DescriptorList) {
260      YS.printError(Document.getRoot(), "DescriptorList node must be a map");
261      return false;
262    }
263
264    for (auto &Descriptor : *DescriptorList)
265      if (!parseEntry(YS, Descriptor, DL))
266        return false;
267  }
268
269  return true;
270}
271
272bool RewriteMapParser::parseEntry(yaml::Stream &YS, yaml::KeyValueNode &Entry,
273                                  RewriteDescriptorList *DL) {
274  yaml::ScalarNode *Key;
275  yaml::MappingNode *Value;
276  SmallString<32> KeyStorage;
277  StringRef RewriteType;
278
279  Key = dyn_cast<yaml::ScalarNode>(Entry.getKey());
280  if (!Key) {
281    YS.printError(Entry.getKey(), "rewrite type must be a scalar");
282    return false;
283  }
284
285  Value = dyn_cast<yaml::MappingNode>(Entry.getValue());
286  if (!Value) {
287    YS.printError(Entry.getValue(), "rewrite descriptor must be a map");
288    return false;
289  }
290
291  RewriteType = Key->getValue(KeyStorage);
292  if (RewriteType.equals("function"))
293    return parseRewriteFunctionDescriptor(YS, Key, Value, DL);
294  else if (RewriteType.equals("global variable"))
295    return parseRewriteGlobalVariableDescriptor(YS, Key, Value, DL);
296  else if (RewriteType.equals("global alias"))
297    return parseRewriteGlobalAliasDescriptor(YS, Key, Value, DL);
298
299  YS.printError(Entry.getKey(), "unknown rewrite type");
300  return false;
301}
302
303bool RewriteMapParser::
304parseRewriteFunctionDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
305                               yaml::MappingNode *Descriptor,
306                               RewriteDescriptorList *DL) {
307  bool Naked = false;
308  std::string Source;
309  std::string Target;
310  std::string Transform;
311
312  for (auto &Field : *Descriptor) {
313    yaml::ScalarNode *Key;
314    yaml::ScalarNode *Value;
315    SmallString<32> KeyStorage;
316    SmallString<32> ValueStorage;
317    StringRef KeyValue;
318
319    Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
320    if (!Key) {
321      YS.printError(Field.getKey(), "descriptor key must be a scalar");
322      return false;
323    }
324
325    Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
326    if (!Value) {
327      YS.printError(Field.getValue(), "descriptor value must be a scalar");
328      return false;
329    }
330
331    KeyValue = Key->getValue(KeyStorage);
332    if (KeyValue.equals("source")) {
333      std::string Error;
334
335      Source = Value->getValue(ValueStorage);
336      if (!Regex(Source).isValid(Error)) {
337        YS.printError(Field.getKey(), "invalid regex: " + Error);
338        return false;
339      }
340    } else if (KeyValue.equals("target")) {
341      Target = Value->getValue(ValueStorage);
342    } else if (KeyValue.equals("transform")) {
343      Transform = Value->getValue(ValueStorage);
344    } else if (KeyValue.equals("naked")) {
345      std::string Undecorated;
346
347      Undecorated = Value->getValue(ValueStorage);
348      Naked = StringRef(Undecorated).lower() == "true" || Undecorated == "1";
349    } else {
350      YS.printError(Field.getKey(), "unknown key for function");
351      return false;
352    }
353  }
354
355  if (Transform.empty() == Target.empty()) {
356    YS.printError(Descriptor,
357                  "exactly one of transform or target must be specified");
358    return false;
359  }
360
361  // TODO see if there is a more elegant solution to selecting the rewrite
362  // descriptor type
363  if (!Target.empty())
364    DL->push_back(new ExplicitRewriteFunctionDescriptor(Source, Target, Naked));
365  else
366    DL->push_back(new PatternRewriteFunctionDescriptor(Source, Transform));
367
368  return true;
369}
370
371bool RewriteMapParser::
372parseRewriteGlobalVariableDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
373                                     yaml::MappingNode *Descriptor,
374                                     RewriteDescriptorList *DL) {
375  std::string Source;
376  std::string Target;
377  std::string Transform;
378
379  for (auto &Field : *Descriptor) {
380    yaml::ScalarNode *Key;
381    yaml::ScalarNode *Value;
382    SmallString<32> KeyStorage;
383    SmallString<32> ValueStorage;
384    StringRef KeyValue;
385
386    Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
387    if (!Key) {
388      YS.printError(Field.getKey(), "descriptor Key must be a scalar");
389      return false;
390    }
391
392    Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
393    if (!Value) {
394      YS.printError(Field.getValue(), "descriptor value must be a scalar");
395      return false;
396    }
397
398    KeyValue = Key->getValue(KeyStorage);
399    if (KeyValue.equals("source")) {
400      std::string Error;
401
402      Source = Value->getValue(ValueStorage);
403      if (!Regex(Source).isValid(Error)) {
404        YS.printError(Field.getKey(), "invalid regex: " + Error);
405        return false;
406      }
407    } else if (KeyValue.equals("target")) {
408      Target = Value->getValue(ValueStorage);
409    } else if (KeyValue.equals("transform")) {
410      Transform = Value->getValue(ValueStorage);
411    } else {
412      YS.printError(Field.getKey(), "unknown Key for Global Variable");
413      return false;
414    }
415  }
416
417  if (Transform.empty() == Target.empty()) {
418    YS.printError(Descriptor,
419                  "exactly one of transform or target must be specified");
420    return false;
421  }
422
423  if (!Target.empty())
424    DL->push_back(new ExplicitRewriteGlobalVariableDescriptor(Source, Target,
425                                                              /*Naked*/false));
426  else
427    DL->push_back(new PatternRewriteGlobalVariableDescriptor(Source,
428                                                             Transform));
429
430  return true;
431}
432
433bool RewriteMapParser::
434parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
435                                  yaml::MappingNode *Descriptor,
436                                  RewriteDescriptorList *DL) {
437  std::string Source;
438  std::string Target;
439  std::string Transform;
440
441  for (auto &Field : *Descriptor) {
442    yaml::ScalarNode *Key;
443    yaml::ScalarNode *Value;
444    SmallString<32> KeyStorage;
445    SmallString<32> ValueStorage;
446    StringRef KeyValue;
447
448    Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
449    if (!Key) {
450      YS.printError(Field.getKey(), "descriptor key must be a scalar");
451      return false;
452    }
453
454    Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
455    if (!Value) {
456      YS.printError(Field.getValue(), "descriptor value must be a scalar");
457      return false;
458    }
459
460    KeyValue = Key->getValue(KeyStorage);
461    if (KeyValue.equals("source")) {
462      std::string Error;
463
464      Source = Value->getValue(ValueStorage);
465      if (!Regex(Source).isValid(Error)) {
466        YS.printError(Field.getKey(), "invalid regex: " + Error);
467        return false;
468      }
469    } else if (KeyValue.equals("target")) {
470      Target = Value->getValue(ValueStorage);
471    } else if (KeyValue.equals("transform")) {
472      Transform = Value->getValue(ValueStorage);
473    } else {
474      YS.printError(Field.getKey(), "unknown key for Global Alias");
475      return false;
476    }
477  }
478
479  if (Transform.empty() == Target.empty()) {
480    YS.printError(Descriptor,
481                  "exactly one of transform or target must be specified");
482    return false;
483  }
484
485  if (!Target.empty())
486    DL->push_back(new ExplicitRewriteNamedAliasDescriptor(Source, Target,
487                                                          /*Naked*/false));
488  else
489    DL->push_back(new PatternRewriteNamedAliasDescriptor(Source, Transform));
490
491  return true;
492}
493
494namespace {
495class RewriteSymbols : public ModulePass {
496public:
497  static char ID; // Pass identification, replacement for typeid
498
499  RewriteSymbols();
500  RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL);
501
502  bool runOnModule(Module &M) override;
503
504private:
505  void loadAndParseMapFiles();
506
507  SymbolRewriter::RewriteDescriptorList Descriptors;
508};
509
510char RewriteSymbols::ID = 0;
511
512RewriteSymbols::RewriteSymbols() : ModulePass(ID) {
513  initializeRewriteSymbolsPass(*PassRegistry::getPassRegistry());
514  loadAndParseMapFiles();
515}
516
517RewriteSymbols::RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL)
518    : ModulePass(ID) {
519  Descriptors.splice(Descriptors.begin(), DL);
520}
521
522bool RewriteSymbols::runOnModule(Module &M) {
523  bool Changed;
524
525  Changed = false;
526  for (auto &Descriptor : Descriptors)
527    Changed |= Descriptor.performOnModule(M);
528
529  return Changed;
530}
531
532void RewriteSymbols::loadAndParseMapFiles() {
533  const std::vector<std::string> MapFiles(RewriteMapFiles);
534  SymbolRewriter::RewriteMapParser parser;
535
536  for (const auto &MapFile : MapFiles)
537    parser.parse(MapFile, &Descriptors);
538}
539}
540
541INITIALIZE_PASS(RewriteSymbols, "rewrite-symbols", "Rewrite Symbols", false,
542                false)
543
544ModulePass *llvm::createRewriteSymbolsPass() { return new RewriteSymbols(); }
545
546ModulePass *
547llvm::createRewriteSymbolsPass(SymbolRewriter::RewriteDescriptorList &DL) {
548  return new RewriteSymbols(DL);
549}
550