1//===- GVNExpression.h - GVN Expression classes -----------------*- C++ -*-===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10/// \file
11///
12/// The header file for the GVN pass that contains expression handling
13/// classes
14//
15//===----------------------------------------------------------------------===//
16
17#ifndef LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
18#define LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
19
20#include "llvm/ADT/Hashing.h"
21#include "llvm/ADT/iterator_range.h"
22#include "llvm/Analysis/MemorySSA.h"
23#include "llvm/IR/Constant.h"
24#include "llvm/IR/Instructions.h"
25#include "llvm/IR/Value.h"
26#include "llvm/Support/Allocator.h"
27#include "llvm/Support/ArrayRecycler.h"
28#include "llvm/Support/Casting.h"
29#include "llvm/Support/Compiler.h"
30#include "llvm/Support/raw_ostream.h"
31#include <algorithm>
32#include <cassert>
33#include <iterator>
34#include <utility>
35
36namespace llvm {
37
38class BasicBlock;
39class Type;
40
41namespace GVNExpression {
42
/// Discriminator for the Expression class hierarchy, consumed by the
/// classof() implementations (LLVM-style RTTI).
///
/// The *Start / *End enumerators are range markers rather than concrete
/// kinds: BasicExpression::classof accepts every value strictly between
/// ET_BasicStart and ET_BasicEnd, and MemoryExpression::classof every value
/// strictly between ET_MemoryStart and ET_MemoryEnd. New kinds must therefore
/// be inserted inside the range that matches their base class.
enum ExpressionType {
  ET_Base = 0,

  // Kinds whose classes derive directly from Expression.
  ET_Constant,
  ET_Variable,
  ET_Dead,
  ET_Unknown,

  // Kinds whose classes derive from BasicExpression.
  ET_BasicStart,
  ET_Basic,
  ET_AggregateValue,
  ET_Phi,

  // Kinds whose classes derive from MemoryExpression (nested inside the
  // BasicExpression range).
  ET_MemoryStart,
  ET_Call,
  ET_Load,
  ET_Store,
  ET_MemoryEnd,

  ET_BasicEnd
};
60
61class Expression {
62private:
63  ExpressionType EType;
64  unsigned Opcode;
65  mutable hash_code HashVal = 0;
66
67public:
68  Expression(ExpressionType ET = ET_Base, unsigned O = ~2U)
69      : EType(ET), Opcode(O) {}
70  Expression(const Expression &) = delete;
71  Expression &operator=(const Expression &) = delete;
72  virtual ~Expression();
73
74  static unsigned getEmptyKey() { return ~0U; }
75  static unsigned getTombstoneKey() { return ~1U; }
76
77  bool operator!=(const Expression &Other) const { return !(*this == Other); }
78  bool operator==(const Expression &Other) const {
79    if (getOpcode() != Other.getOpcode())
80      return false;
81    if (getOpcode() == getEmptyKey() || getOpcode() == getTombstoneKey())
82      return true;
83    // Compare the expression type for anything but load and store.
84    // For load and store we set the opcode to zero to make them equal.
85    if (getExpressionType() != ET_Load && getExpressionType() != ET_Store &&
86        getExpressionType() != Other.getExpressionType())
87      return false;
88
89    return equals(Other);
90  }
91
92  hash_code getComputedHash() const {
93    // It's theoretically possible for a thing to hash to zero.  In that case,
94    // we will just compute the hash a few extra times, which is no worse that
95    // we did before, which was to compute it always.
96    if (static_cast<unsigned>(HashVal) == 0)
97      HashVal = getHashValue();
98    return HashVal;
99  }
100
101  virtual bool equals(const Expression &Other) const { return true; }
102
103  // Return true if the two expressions are exactly the same, including the
104  // normally ignored fields.
105  virtual bool exactlyEquals(const Expression &Other) const {
106    return getExpressionType() == Other.getExpressionType() && equals(Other);
107  }
108
109  unsigned getOpcode() const { return Opcode; }
110  void setOpcode(unsigned opcode) { Opcode = opcode; }
111  ExpressionType getExpressionType() const { return EType; }
112
113  // We deliberately leave the expression type out of the hash value.
114  virtual hash_code getHashValue() const { return getOpcode(); }
115
116  // Debugging support
117  virtual void printInternal(raw_ostream &OS, bool PrintEType) const {
118    if (PrintEType)
119      OS << "etype = " << getExpressionType() << ",";
120    OS << "opcode = " << getOpcode() << ", ";
121  }
122
123  void print(raw_ostream &OS) const {
124    OS << "{ ";
125    printInternal(OS, true);
126    OS << "}";
127  }
128
129  LLVM_DUMP_METHOD void dump() const;
130};
131
132inline raw_ostream &operator<<(raw_ostream &OS, const Expression &E) {
133  E.print(OS);
134  return OS;
135}
136
137class BasicExpression : public Expression {
138private:
139  using RecyclerType = ArrayRecycler<Value *>;
140  using RecyclerCapacity = RecyclerType::Capacity;
141
142  Value **Operands = nullptr;
143  unsigned MaxOperands;
144  unsigned NumOperands = 0;
145  Type *ValueType = nullptr;
146
147public:
148  BasicExpression(unsigned NumOperands)
149      : BasicExpression(NumOperands, ET_Basic) {}
150  BasicExpression(unsigned NumOperands, ExpressionType ET)
151      : Expression(ET), MaxOperands(NumOperands) {}
152  BasicExpression() = delete;
153  BasicExpression(const BasicExpression &) = delete;
154  BasicExpression &operator=(const BasicExpression &) = delete;
155  ~BasicExpression() override;
156
157  static bool classof(const Expression *EB) {
158    ExpressionType ET = EB->getExpressionType();
159    return ET > ET_BasicStart && ET < ET_BasicEnd;
160  }
161
162  /// \brief Swap two operands. Used during GVN to put commutative operands in
163  /// order.
164  void swapOperands(unsigned First, unsigned Second) {
165    std::swap(Operands[First], Operands[Second]);
166  }
167
168  Value *getOperand(unsigned N) const {
169    assert(Operands && "Operands not allocated");
170    assert(N < NumOperands && "Operand out of range");
171    return Operands[N];
172  }
173
174  void setOperand(unsigned N, Value *V) {
175    assert(Operands && "Operands not allocated before setting");
176    assert(N < NumOperands && "Operand out of range");
177    Operands[N] = V;
178  }
179
180  unsigned getNumOperands() const { return NumOperands; }
181
182  using op_iterator = Value **;
183  using const_op_iterator = Value *const *;
184
185  op_iterator op_begin() { return Operands; }
186  op_iterator op_end() { return Operands + NumOperands; }
187  const_op_iterator op_begin() const { return Operands; }
188  const_op_iterator op_end() const { return Operands + NumOperands; }
189  iterator_range<op_iterator> operands() {
190    return iterator_range<op_iterator>(op_begin(), op_end());
191  }
192  iterator_range<const_op_iterator> operands() const {
193    return iterator_range<const_op_iterator>(op_begin(), op_end());
194  }
195
196  void op_push_back(Value *Arg) {
197    assert(NumOperands < MaxOperands && "Tried to add too many operands");
198    assert(Operands && "Operandss not allocated before pushing");
199    Operands[NumOperands++] = Arg;
200  }
201  bool op_empty() const { return getNumOperands() == 0; }
202
203  void allocateOperands(RecyclerType &Recycler, BumpPtrAllocator &Allocator) {
204    assert(!Operands && "Operands already allocated");
205    Operands = Recycler.allocate(RecyclerCapacity::get(MaxOperands), Allocator);
206  }
207  void deallocateOperands(RecyclerType &Recycler) {
208    Recycler.deallocate(RecyclerCapacity::get(MaxOperands), Operands);
209  }
210
211  void setType(Type *T) { ValueType = T; }
212  Type *getType() const { return ValueType; }
213
214  bool equals(const Expression &Other) const override {
215    if (getOpcode() != Other.getOpcode())
216      return false;
217
218    const auto &OE = cast<BasicExpression>(Other);
219    return getType() == OE.getType() && NumOperands == OE.NumOperands &&
220           std::equal(op_begin(), op_end(), OE.op_begin());
221  }
222
223  hash_code getHashValue() const override {
224    return hash_combine(this->Expression::getHashValue(), ValueType,
225                        hash_combine_range(op_begin(), op_end()));
226  }
227
228  // Debugging support
229  void printInternal(raw_ostream &OS, bool PrintEType) const override {
230    if (PrintEType)
231      OS << "ExpressionTypeBasic, ";
232
233    this->Expression::printInternal(OS, false);
234    OS << "operands = {";
235    for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
236      OS << "[" << i << "] = ";
237      Operands[i]->printAsOperand(OS);
238      OS << "  ";
239    }
240    OS << "} ";
241  }
242};
243
244class op_inserter
245    : public std::iterator<std::output_iterator_tag, void, void, void, void> {
246private:
247  using Container = BasicExpression;
248
249  Container *BE;
250
251public:
252  explicit op_inserter(BasicExpression &E) : BE(&E) {}
253  explicit op_inserter(BasicExpression *E) : BE(E) {}
254
255  op_inserter &operator=(Value *val) {
256    BE->op_push_back(val);
257    return *this;
258  }
259  op_inserter &operator*() { return *this; }
260  op_inserter &operator++() { return *this; }
261  op_inserter &operator++(int) { return *this; }
262};
263
264class MemoryExpression : public BasicExpression {
265private:
266  const MemoryAccess *MemoryLeader;
267
268public:
269  MemoryExpression(unsigned NumOperands, enum ExpressionType EType,
270                   const MemoryAccess *MemoryLeader)
271      : BasicExpression(NumOperands, EType), MemoryLeader(MemoryLeader) {}
272  MemoryExpression() = delete;
273  MemoryExpression(const MemoryExpression &) = delete;
274  MemoryExpression &operator=(const MemoryExpression &) = delete;
275
276  static bool classof(const Expression *EB) {
277    return EB->getExpressionType() > ET_MemoryStart &&
278           EB->getExpressionType() < ET_MemoryEnd;
279  }
280
281  hash_code getHashValue() const override {
282    return hash_combine(this->BasicExpression::getHashValue(), MemoryLeader);
283  }
284
285  bool equals(const Expression &Other) const override {
286    if (!this->BasicExpression::equals(Other))
287      return false;
288    const MemoryExpression &OtherMCE = cast<MemoryExpression>(Other);
289
290    return MemoryLeader == OtherMCE.MemoryLeader;
291  }
292
293  const MemoryAccess *getMemoryLeader() const { return MemoryLeader; }
294  void setMemoryLeader(const MemoryAccess *ML) { MemoryLeader = ML; }
295};
296
297class CallExpression final : public MemoryExpression {
298private:
299  CallInst *Call;
300
301public:
302  CallExpression(unsigned NumOperands, CallInst *C,
303                 const MemoryAccess *MemoryLeader)
304      : MemoryExpression(NumOperands, ET_Call, MemoryLeader), Call(C) {}
305  CallExpression() = delete;
306  CallExpression(const CallExpression &) = delete;
307  CallExpression &operator=(const CallExpression &) = delete;
308  ~CallExpression() override;
309
310  static bool classof(const Expression *EB) {
311    return EB->getExpressionType() == ET_Call;
312  }
313
314  // Debugging support
315  void printInternal(raw_ostream &OS, bool PrintEType) const override {
316    if (PrintEType)
317      OS << "ExpressionTypeCall, ";
318    this->BasicExpression::printInternal(OS, false);
319    OS << " represents call at ";
320    Call->printAsOperand(OS);
321  }
322};
323
324class LoadExpression final : public MemoryExpression {
325private:
326  LoadInst *Load;
327  unsigned Alignment;
328
329public:
330  LoadExpression(unsigned NumOperands, LoadInst *L,
331                 const MemoryAccess *MemoryLeader)
332      : LoadExpression(ET_Load, NumOperands, L, MemoryLeader) {}
333
334  LoadExpression(enum ExpressionType EType, unsigned NumOperands, LoadInst *L,
335                 const MemoryAccess *MemoryLeader)
336      : MemoryExpression(NumOperands, EType, MemoryLeader), Load(L) {
337    Alignment = L ? L->getAlignment() : 0;
338  }
339
340  LoadExpression() = delete;
341  LoadExpression(const LoadExpression &) = delete;
342  LoadExpression &operator=(const LoadExpression &) = delete;
343  ~LoadExpression() override;
344
345  static bool classof(const Expression *EB) {
346    return EB->getExpressionType() == ET_Load;
347  }
348
349  LoadInst *getLoadInst() const { return Load; }
350  void setLoadInst(LoadInst *L) { Load = L; }
351
352  unsigned getAlignment() const { return Alignment; }
353  void setAlignment(unsigned Align) { Alignment = Align; }
354
355  bool equals(const Expression &Other) const override;
356  bool exactlyEquals(const Expression &Other) const override {
357    return Expression::exactlyEquals(Other) &&
358           cast<LoadExpression>(Other).getLoadInst() == getLoadInst();
359  }
360
361  // Debugging support
362  void printInternal(raw_ostream &OS, bool PrintEType) const override {
363    if (PrintEType)
364      OS << "ExpressionTypeLoad, ";
365    this->BasicExpression::printInternal(OS, false);
366    OS << " represents Load at ";
367    Load->printAsOperand(OS);
368    OS << " with MemoryLeader " << *getMemoryLeader();
369  }
370};
371
372class StoreExpression final : public MemoryExpression {
373private:
374  StoreInst *Store;
375  Value *StoredValue;
376
377public:
378  StoreExpression(unsigned NumOperands, StoreInst *S, Value *StoredValue,
379                  const MemoryAccess *MemoryLeader)
380      : MemoryExpression(NumOperands, ET_Store, MemoryLeader), Store(S),
381        StoredValue(StoredValue) {}
382  StoreExpression() = delete;
383  StoreExpression(const StoreExpression &) = delete;
384  StoreExpression &operator=(const StoreExpression &) = delete;
385  ~StoreExpression() override;
386
387  static bool classof(const Expression *EB) {
388    return EB->getExpressionType() == ET_Store;
389  }
390
391  StoreInst *getStoreInst() const { return Store; }
392  Value *getStoredValue() const { return StoredValue; }
393
394  bool equals(const Expression &Other) const override;
395
396  bool exactlyEquals(const Expression &Other) const override {
397    return Expression::exactlyEquals(Other) &&
398           cast<StoreExpression>(Other).getStoreInst() == getStoreInst();
399  }
400
401  // Debugging support
402  void printInternal(raw_ostream &OS, bool PrintEType) const override {
403    if (PrintEType)
404      OS << "ExpressionTypeStore, ";
405    this->BasicExpression::printInternal(OS, false);
406    OS << " represents Store  " << *Store;
407    OS << " with StoredValue ";
408    StoredValue->printAsOperand(OS);
409    OS << " and MemoryLeader " << *getMemoryLeader();
410  }
411};
412
413class AggregateValueExpression final : public BasicExpression {
414private:
415  unsigned MaxIntOperands;
416  unsigned NumIntOperands = 0;
417  unsigned *IntOperands = nullptr;
418
419public:
420  AggregateValueExpression(unsigned NumOperands, unsigned NumIntOperands)
421      : BasicExpression(NumOperands, ET_AggregateValue),
422        MaxIntOperands(NumIntOperands) {}
423  AggregateValueExpression() = delete;
424  AggregateValueExpression(const AggregateValueExpression &) = delete;
425  AggregateValueExpression &
426  operator=(const AggregateValueExpression &) = delete;
427  ~AggregateValueExpression() override;
428
429  static bool classof(const Expression *EB) {
430    return EB->getExpressionType() == ET_AggregateValue;
431  }
432
433  using int_arg_iterator = unsigned *;
434  using const_int_arg_iterator = const unsigned *;
435
436  int_arg_iterator int_op_begin() { return IntOperands; }
437  int_arg_iterator int_op_end() { return IntOperands + NumIntOperands; }
438  const_int_arg_iterator int_op_begin() const { return IntOperands; }
439  const_int_arg_iterator int_op_end() const {
440    return IntOperands + NumIntOperands;
441  }
442  unsigned int_op_size() const { return NumIntOperands; }
443  bool int_op_empty() const { return NumIntOperands == 0; }
444  void int_op_push_back(unsigned IntOperand) {
445    assert(NumIntOperands < MaxIntOperands &&
446           "Tried to add too many int operands");
447    assert(IntOperands && "Operands not allocated before pushing");
448    IntOperands[NumIntOperands++] = IntOperand;
449  }
450
451  virtual void allocateIntOperands(BumpPtrAllocator &Allocator) {
452    assert(!IntOperands && "Operands already allocated");
453    IntOperands = Allocator.Allocate<unsigned>(MaxIntOperands);
454  }
455
456  bool equals(const Expression &Other) const override {
457    if (!this->BasicExpression::equals(Other))
458      return false;
459    const AggregateValueExpression &OE = cast<AggregateValueExpression>(Other);
460    return NumIntOperands == OE.NumIntOperands &&
461           std::equal(int_op_begin(), int_op_end(), OE.int_op_begin());
462  }
463
464  hash_code getHashValue() const override {
465    return hash_combine(this->BasicExpression::getHashValue(),
466                        hash_combine_range(int_op_begin(), int_op_end()));
467  }
468
469  // Debugging support
470  void printInternal(raw_ostream &OS, bool PrintEType) const override {
471    if (PrintEType)
472      OS << "ExpressionTypeAggregateValue, ";
473    this->BasicExpression::printInternal(OS, false);
474    OS << ", intoperands = {";
475    for (unsigned i = 0, e = int_op_size(); i != e; ++i) {
476      OS << "[" << i << "] = " << IntOperands[i] << "  ";
477    }
478    OS << "}";
479  }
480};
481
482class int_op_inserter
483    : public std::iterator<std::output_iterator_tag, void, void, void, void> {
484private:
485  using Container = AggregateValueExpression;
486
487  Container *AVE;
488
489public:
490  explicit int_op_inserter(AggregateValueExpression &E) : AVE(&E) {}
491  explicit int_op_inserter(AggregateValueExpression *E) : AVE(E) {}
492
493  int_op_inserter &operator=(unsigned int val) {
494    AVE->int_op_push_back(val);
495    return *this;
496  }
497  int_op_inserter &operator*() { return *this; }
498  int_op_inserter &operator++() { return *this; }
499  int_op_inserter &operator++(int) { return *this; }
500};
501
502class PHIExpression final : public BasicExpression {
503private:
504  BasicBlock *BB;
505
506public:
507  PHIExpression(unsigned NumOperands, BasicBlock *B)
508      : BasicExpression(NumOperands, ET_Phi), BB(B) {}
509  PHIExpression() = delete;
510  PHIExpression(const PHIExpression &) = delete;
511  PHIExpression &operator=(const PHIExpression &) = delete;
512  ~PHIExpression() override;
513
514  static bool classof(const Expression *EB) {
515    return EB->getExpressionType() == ET_Phi;
516  }
517
518  bool equals(const Expression &Other) const override {
519    if (!this->BasicExpression::equals(Other))
520      return false;
521    const PHIExpression &OE = cast<PHIExpression>(Other);
522    return BB == OE.BB;
523  }
524
525  hash_code getHashValue() const override {
526    return hash_combine(this->BasicExpression::getHashValue(), BB);
527  }
528
529  // Debugging support
530  void printInternal(raw_ostream &OS, bool PrintEType) const override {
531    if (PrintEType)
532      OS << "ExpressionTypePhi, ";
533    this->BasicExpression::printInternal(OS, false);
534    OS << "bb = " << BB;
535  }
536};
537
538class DeadExpression final : public Expression {
539public:
540  DeadExpression() : Expression(ET_Dead) {}
541  DeadExpression(const DeadExpression &) = delete;
542  DeadExpression &operator=(const DeadExpression &) = delete;
543
544  static bool classof(const Expression *E) {
545    return E->getExpressionType() == ET_Dead;
546  }
547};
548
549class VariableExpression final : public Expression {
550private:
551  Value *VariableValue;
552
553public:
554  VariableExpression(Value *V) : Expression(ET_Variable), VariableValue(V) {}
555  VariableExpression() = delete;
556  VariableExpression(const VariableExpression &) = delete;
557  VariableExpression &operator=(const VariableExpression &) = delete;
558
559  static bool classof(const Expression *EB) {
560    return EB->getExpressionType() == ET_Variable;
561  }
562
563  Value *getVariableValue() const { return VariableValue; }
564  void setVariableValue(Value *V) { VariableValue = V; }
565
566  bool equals(const Expression &Other) const override {
567    const VariableExpression &OC = cast<VariableExpression>(Other);
568    return VariableValue == OC.VariableValue;
569  }
570
571  hash_code getHashValue() const override {
572    return hash_combine(this->Expression::getHashValue(),
573                        VariableValue->getType(), VariableValue);
574  }
575
576  // Debugging support
577  void printInternal(raw_ostream &OS, bool PrintEType) const override {
578    if (PrintEType)
579      OS << "ExpressionTypeVariable, ";
580    this->Expression::printInternal(OS, false);
581    OS << " variable = " << *VariableValue;
582  }
583};
584
585class ConstantExpression final : public Expression {
586private:
587  Constant *ConstantValue = nullptr;
588
589public:
590  ConstantExpression() : Expression(ET_Constant) {}
591  ConstantExpression(Constant *constantValue)
592      : Expression(ET_Constant), ConstantValue(constantValue) {}
593  ConstantExpression(const ConstantExpression &) = delete;
594  ConstantExpression &operator=(const ConstantExpression &) = delete;
595
596  static bool classof(const Expression *EB) {
597    return EB->getExpressionType() == ET_Constant;
598  }
599
600  Constant *getConstantValue() const { return ConstantValue; }
601  void setConstantValue(Constant *V) { ConstantValue = V; }
602
603  bool equals(const Expression &Other) const override {
604    const ConstantExpression &OC = cast<ConstantExpression>(Other);
605    return ConstantValue == OC.ConstantValue;
606  }
607
608  hash_code getHashValue() const override {
609    return hash_combine(this->Expression::getHashValue(),
610                        ConstantValue->getType(), ConstantValue);
611  }
612
613  // Debugging support
614  void printInternal(raw_ostream &OS, bool PrintEType) const override {
615    if (PrintEType)
616      OS << "ExpressionTypeConstant, ";
617    this->Expression::printInternal(OS, false);
618    OS << " constant = " << *ConstantValue;
619  }
620};
621
622class UnknownExpression final : public Expression {
623private:
624  Instruction *Inst;
625
626public:
627  UnknownExpression(Instruction *I) : Expression(ET_Unknown), Inst(I) {}
628  UnknownExpression() = delete;
629  UnknownExpression(const UnknownExpression &) = delete;
630  UnknownExpression &operator=(const UnknownExpression &) = delete;
631
632  static bool classof(const Expression *EB) {
633    return EB->getExpressionType() == ET_Unknown;
634  }
635
636  Instruction *getInstruction() const { return Inst; }
637  void setInstruction(Instruction *I) { Inst = I; }
638
639  bool equals(const Expression &Other) const override {
640    const auto &OU = cast<UnknownExpression>(Other);
641    return Inst == OU.Inst;
642  }
643
644  hash_code getHashValue() const override {
645    return hash_combine(this->Expression::getHashValue(), Inst);
646  }
647
648  // Debugging support
649  void printInternal(raw_ostream &OS, bool PrintEType) const override {
650    if (PrintEType)
651      OS << "ExpressionTypeUnknown, ";
652    this->Expression::printInternal(OS, false);
653    OS << " inst = " << *Inst;
654  }
655};
656
657} // end namespace GVNExpression
658
659} // end namespace llvm
660
661#endif // LLVM_TRANSFORMS_SCALAR_GVNEXPRESSION_H
662