1//===-- llvm/CodeGen/GlobalISel/LegalizerCombiner.h --===========//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
// This file contains some helper functions which try to clean up artifacts
// such as G_TRUNCs/G_[ZSA]EXTENDS that were created during legalization to make
// the types match. This file also contains some combines of merges that happen
// at the end of the legalization.
13//===----------------------------------------------------------------------===//
14
15#include "llvm/CodeGen/GlobalISel/Legalizer.h"
16#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
17#include "llvm/CodeGen/GlobalISel/Utils.h"
18#include "llvm/CodeGen/MachineRegisterInfo.h"
19#include "llvm/Support/Debug.h"
20
21#define DEBUG_TYPE "legalizer"
22
23namespace llvm {
24class LegalizerCombiner {
25  MachineIRBuilder &Builder;
26  MachineRegisterInfo &MRI;
27
28public:
29  LegalizerCombiner(MachineIRBuilder &B, MachineRegisterInfo &MRI)
30      : Builder(B), MRI(MRI) {}
31
32  bool tryCombineAnyExt(MachineInstr &MI,
33                        SmallVectorImpl<MachineInstr *> &DeadInsts) {
34    if (MI.getOpcode() != TargetOpcode::G_ANYEXT)
35      return false;
36    MachineInstr *DefMI = MRI.getVRegDef(MI.getOperand(1).getReg());
37    if (DefMI->getOpcode() == TargetOpcode::G_TRUNC) {
38      DEBUG(dbgs() << ".. Combine MI: " << MI;);
39      unsigned DstReg = MI.getOperand(0).getReg();
40      unsigned SrcReg = DefMI->getOperand(1).getReg();
41      Builder.setInstr(MI);
42      // We get a copy/trunc/extend depending on the sizes
43      Builder.buildAnyExtOrTrunc(DstReg, SrcReg);
44      markInstAndDefDead(MI, *DefMI, DeadInsts);
45      return true;
46    }
47    return false;
48  }
49
50  bool tryCombineZExt(MachineInstr &MI,
51                      SmallVectorImpl<MachineInstr *> &DeadInsts) {
52
53    if (MI.getOpcode() != TargetOpcode::G_ZEXT)
54      return false;
55    MachineInstr *DefMI = MRI.getVRegDef(MI.getOperand(1).getReg());
56    if (DefMI->getOpcode() == TargetOpcode::G_TRUNC) {
57      DEBUG(dbgs() << ".. Combine MI: " << MI;);
58      Builder.setInstr(MI);
59      unsigned DstReg = MI.getOperand(0).getReg();
60      unsigned ZExtSrc = MI.getOperand(1).getReg();
61      LLT ZExtSrcTy = MRI.getType(ZExtSrc);
62      LLT DstTy = MRI.getType(DstReg);
63      APInt Mask = APInt::getAllOnesValue(ZExtSrcTy.getSizeInBits());
64      auto MaskCstMIB = Builder.buildConstant(DstTy, Mask.getZExtValue());
65      unsigned TruncSrc = DefMI->getOperand(1).getReg();
66      // We get a copy/trunc/extend depending on the sizes
67      auto SrcCopyOrTrunc = Builder.buildAnyExtOrTrunc(DstTy, TruncSrc);
68      Builder.buildAnd(DstReg, SrcCopyOrTrunc, MaskCstMIB);
69      markInstAndDefDead(MI, *DefMI, DeadInsts);
70      return true;
71    }
72    return false;
73  }
74
75  bool tryCombineSExt(MachineInstr &MI,
76                      SmallVectorImpl<MachineInstr *> &DeadInsts) {
77
78    if (MI.getOpcode() != TargetOpcode::G_SEXT)
79      return false;
80    MachineInstr *DefMI = MRI.getVRegDef(MI.getOperand(1).getReg());
81    if (DefMI->getOpcode() == TargetOpcode::G_TRUNC) {
82      DEBUG(dbgs() << ".. Combine MI: " << MI;);
83      Builder.setInstr(MI);
84      unsigned DstReg = MI.getOperand(0).getReg();
85      LLT DstTy = MRI.getType(DstReg);
86      unsigned SExtSrc = MI.getOperand(1).getReg();
87      LLT SExtSrcTy = MRI.getType(SExtSrc);
88      unsigned SizeDiff = DstTy.getSizeInBits() - SExtSrcTy.getSizeInBits();
89      auto SizeDiffMIB = Builder.buildConstant(DstTy, SizeDiff);
90      unsigned TruncSrcReg = DefMI->getOperand(1).getReg();
91      // We get a copy/trunc/extend depending on the sizes
92      auto SrcCopyExtOrTrunc = Builder.buildAnyExtOrTrunc(DstTy, TruncSrcReg);
93      auto ShlMIB = Builder.buildInstr(TargetOpcode::G_SHL, DstTy,
94                                       SrcCopyExtOrTrunc, SizeDiffMIB);
95      Builder.buildInstr(TargetOpcode::G_ASHR, DstReg, ShlMIB, SizeDiffMIB);
96      markInstAndDefDead(MI, *DefMI, DeadInsts);
97      return true;
98    }
99    return false;
100  }
101
102  bool tryCombineMerges(MachineInstr &MI,
103                        SmallVectorImpl<MachineInstr *> &DeadInsts) {
104
105    if (MI.getOpcode() != TargetOpcode::G_UNMERGE_VALUES)
106      return false;
107
108    unsigned NumDefs = MI.getNumOperands() - 1;
109    unsigned SrcReg = MI.getOperand(NumDefs).getReg();
110    MachineInstr *MergeI = MRI.getVRegDef(SrcReg);
111    if (!MergeI || (MergeI->getOpcode() != TargetOpcode::G_MERGE_VALUES))
112      return false;
113
114    const unsigned NumMergeRegs = MergeI->getNumOperands() - 1;
115
116    if (NumMergeRegs < NumDefs) {
117      if (NumDefs % NumMergeRegs != 0)
118        return false;
119
120      Builder.setInstr(MI);
121      // Transform to UNMERGEs, for example
122      //   %1 = G_MERGE_VALUES %4, %5
123      //   %9, %10, %11, %12 = G_UNMERGE_VALUES %1
124      // to
125      //   %9, %10 = G_UNMERGE_VALUES %4
126      //   %11, %12 = G_UNMERGE_VALUES %5
127
128      const unsigned NewNumDefs = NumDefs / NumMergeRegs;
129      for (unsigned Idx = 0; Idx < NumMergeRegs; ++Idx) {
130        SmallVector<unsigned, 2> DstRegs;
131        for (unsigned j = 0, DefIdx = Idx * NewNumDefs; j < NewNumDefs;
132             ++j, ++DefIdx)
133          DstRegs.push_back(MI.getOperand(DefIdx).getReg());
134
135        Builder.buildUnmerge(DstRegs, MergeI->getOperand(Idx + 1).getReg());
136      }
137
138    } else if (NumMergeRegs > NumDefs) {
139      if (NumMergeRegs % NumDefs != 0)
140        return false;
141
142      Builder.setInstr(MI);
143      // Transform to MERGEs
144      //   %6 = G_MERGE_VALUES %17, %18, %19, %20
145      //   %7, %8 = G_UNMERGE_VALUES %6
146      // to
147      //   %7 = G_MERGE_VALUES %17, %18
148      //   %8 = G_MERGE_VALUES %19, %20
149
150      const unsigned NumRegs = NumMergeRegs / NumDefs;
151      for (unsigned DefIdx = 0; DefIdx < NumDefs; ++DefIdx) {
152        SmallVector<unsigned, 2> Regs;
153        for (unsigned j = 0, Idx = NumRegs * DefIdx + 1; j < NumRegs;
154             ++j, ++Idx)
155          Regs.push_back(MergeI->getOperand(Idx).getReg());
156
157        Builder.buildMerge(MI.getOperand(DefIdx).getReg(), Regs);
158      }
159
160    } else {
161      // FIXME: is a COPY appropriate if the types mismatch? We know both
162      // registers are allocatable by now.
163      if (MRI.getType(MI.getOperand(0).getReg()) !=
164          MRI.getType(MergeI->getOperand(1).getReg()))
165        return false;
166
167      for (unsigned Idx = 0; Idx < NumDefs; ++Idx)
168        MRI.replaceRegWith(MI.getOperand(Idx).getReg(),
169                           MergeI->getOperand(Idx + 1).getReg());
170    }
171
172    markInstAndDefDead(MI, *MergeI, DeadInsts);
173    return true;
174  }
175
176  /// Try to combine away MI.
177  /// Returns true if it combined away the MI.
178  /// Adds instructions that are dead as a result of the combine
179  /// into DeadInsts, which can include MI.
180  bool tryCombineInstruction(MachineInstr &MI,
181                             SmallVectorImpl<MachineInstr *> &DeadInsts) {
182    switch (MI.getOpcode()) {
183    default:
184      return false;
185    case TargetOpcode::G_ANYEXT:
186      return tryCombineAnyExt(MI, DeadInsts);
187    case TargetOpcode::G_ZEXT:
188      return tryCombineZExt(MI, DeadInsts);
189    case TargetOpcode::G_SEXT:
190      return tryCombineSExt(MI, DeadInsts);
191    case TargetOpcode::G_UNMERGE_VALUES:
192      return tryCombineMerges(MI, DeadInsts);
193    }
194  }
195
196private:
197  /// Mark MI as dead. If a def of one of MI's operands, DefMI, would also be
198  /// dead due to MI being killed, then mark DefMI as dead too.
199  void markInstAndDefDead(MachineInstr &MI, MachineInstr &DefMI,
200                          SmallVectorImpl<MachineInstr *> &DeadInsts) {
201    DeadInsts.push_back(&MI);
202    if (MRI.hasOneUse(DefMI.getOperand(0).getReg()))
203      DeadInsts.push_back(&DefMI);
204  }
205};
206
207} // namespace llvm
208