1//===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9// This file contains the AArch64 / Cortex-A57 specific register allocation
10// constraints for use by the PBQP register allocator.
11//
12// It is essentially a transcription of what is contained in
13// AArch64A57FPLoadBalancing, which tries to use a balanced
14// mix of odd and even D-registers when performing a critical sequence of
15// independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
16//===----------------------------------------------------------------------===//
17
18#define DEBUG_TYPE "aarch64-pbqp"
19
20#include "AArch64.h"
21#include "AArch64PBQPRegAlloc.h"
22#include "AArch64RegisterInfo.h"
23#include "llvm/CodeGen/LiveIntervalAnalysis.h"
24#include "llvm/CodeGen/MachineBasicBlock.h"
25#include "llvm/CodeGen/MachineFunction.h"
26#include "llvm/CodeGen/MachineRegisterInfo.h"
27#include "llvm/CodeGen/RegAllocPBQP.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/ErrorHandling.h"
30#include "llvm/Support/raw_ostream.h"
31
32using namespace llvm;
33
34namespace {
35
36#ifndef NDEBUG
37bool isFPReg(unsigned reg) {
38  return AArch64::FPR32RegClass.contains(reg) ||
39         AArch64::FPR64RegClass.contains(reg) ||
40         AArch64::FPR128RegClass.contains(reg);
41}
42#endif
43
44bool isOdd(unsigned reg) {
45  switch (reg) {
46  default:
47    llvm_unreachable("Register is not from the expected class !");
48  case AArch64::S1:
49  case AArch64::S3:
50  case AArch64::S5:
51  case AArch64::S7:
52  case AArch64::S9:
53  case AArch64::S11:
54  case AArch64::S13:
55  case AArch64::S15:
56  case AArch64::S17:
57  case AArch64::S19:
58  case AArch64::S21:
59  case AArch64::S23:
60  case AArch64::S25:
61  case AArch64::S27:
62  case AArch64::S29:
63  case AArch64::S31:
64  case AArch64::D1:
65  case AArch64::D3:
66  case AArch64::D5:
67  case AArch64::D7:
68  case AArch64::D9:
69  case AArch64::D11:
70  case AArch64::D13:
71  case AArch64::D15:
72  case AArch64::D17:
73  case AArch64::D19:
74  case AArch64::D21:
75  case AArch64::D23:
76  case AArch64::D25:
77  case AArch64::D27:
78  case AArch64::D29:
79  case AArch64::D31:
80  case AArch64::Q1:
81  case AArch64::Q3:
82  case AArch64::Q5:
83  case AArch64::Q7:
84  case AArch64::Q9:
85  case AArch64::Q11:
86  case AArch64::Q13:
87  case AArch64::Q15:
88  case AArch64::Q17:
89  case AArch64::Q19:
90  case AArch64::Q21:
91  case AArch64::Q23:
92  case AArch64::Q25:
93  case AArch64::Q27:
94  case AArch64::Q29:
95  case AArch64::Q31:
96    return true;
97  case AArch64::S0:
98  case AArch64::S2:
99  case AArch64::S4:
100  case AArch64::S6:
101  case AArch64::S8:
102  case AArch64::S10:
103  case AArch64::S12:
104  case AArch64::S14:
105  case AArch64::S16:
106  case AArch64::S18:
107  case AArch64::S20:
108  case AArch64::S22:
109  case AArch64::S24:
110  case AArch64::S26:
111  case AArch64::S28:
112  case AArch64::S30:
113  case AArch64::D0:
114  case AArch64::D2:
115  case AArch64::D4:
116  case AArch64::D6:
117  case AArch64::D8:
118  case AArch64::D10:
119  case AArch64::D12:
120  case AArch64::D14:
121  case AArch64::D16:
122  case AArch64::D18:
123  case AArch64::D20:
124  case AArch64::D22:
125  case AArch64::D24:
126  case AArch64::D26:
127  case AArch64::D28:
128  case AArch64::D30:
129  case AArch64::Q0:
130  case AArch64::Q2:
131  case AArch64::Q4:
132  case AArch64::Q6:
133  case AArch64::Q8:
134  case AArch64::Q10:
135  case AArch64::Q12:
136  case AArch64::Q14:
137  case AArch64::Q16:
138  case AArch64::Q18:
139  case AArch64::Q20:
140  case AArch64::Q22:
141  case AArch64::Q24:
142  case AArch64::Q26:
143  case AArch64::Q28:
144  case AArch64::Q30:
145    return false;
146
147  }
148}
149
150bool haveSameParity(unsigned reg1, unsigned reg2) {
151  assert(isFPReg(reg1) && "Expecting an FP register for reg1");
152  assert(isFPReg(reg2) && "Expecting an FP register for reg2");
153
154  return isOdd(reg1) == isOdd(reg2);
155}
156
157}
158
159bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
160                                                 unsigned Ra) {
161  if (Rd == Ra)
162    return false;
163
164  LiveIntervals &LIs = G.getMetadata().LIS;
165
166  if (TRI->isPhysicalRegister(Rd) || TRI->isPhysicalRegister(Ra)) {
167    DEBUG(dbgs() << "Rd is a physical reg:" << TRI->isPhysicalRegister(Rd)
168          << '\n');
169    DEBUG(dbgs() << "Ra is a physical reg:" << TRI->isPhysicalRegister(Ra)
170          << '\n');
171    return false;
172  }
173
174  PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
175  PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra);
176
177  const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
178    &G.getNodeMetadata(node1).getAllowedRegs();
179  const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed =
180    &G.getNodeMetadata(node2).getAllowedRegs();
181
182  PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
183
184  // The edge does not exist. Create one with the appropriate interference
185  // costs.
186  if (edge == G.invalidEdgeId()) {
187    const LiveInterval &ld = LIs.getInterval(Rd);
188    const LiveInterval &la = LIs.getInterval(Ra);
189    bool livesOverlap = ld.overlaps(la);
190
191    PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1,
192                                 vRaAllowed->size() + 1, 0);
193    for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
194      unsigned pRd = (*vRdAllowed)[i];
195      for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
196        unsigned pRa = (*vRaAllowed)[j];
197        if (livesOverlap && TRI->regsOverlap(pRd, pRa))
198          costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();
199        else
200          costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;
201      }
202    }
203    G.addEdge(node1, node2, std::move(costs));
204    return true;
205  }
206
207  if (G.getEdgeNode1Id(edge) == node2) {
208    std::swap(node1, node2);
209    std::swap(vRdAllowed, vRaAllowed);
210  }
211
212  // Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))
213  PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge));
214  for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
215    unsigned pRd = (*vRdAllowed)[i];
216
217    // Get the maximum cost (excluding unallocatable reg) for same parity
218    // registers
219    PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
220    for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
221      unsigned pRa = (*vRaAllowed)[j];
222      if (haveSameParity(pRd, pRa))
223        if (costs[i + 1][j + 1] !=
224                std::numeric_limits<PBQP::PBQPNum>::infinity() &&
225            costs[i + 1][j + 1] > sameParityMax)
226          sameParityMax = costs[i + 1][j + 1];
227    }
228
229    // Ensure all registers with a different parity have a higher cost
230    // than sameParityMax
231    for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {
232      unsigned pRa = (*vRaAllowed)[j];
233      if (!haveSameParity(pRd, pRa))
234        if (sameParityMax > costs[i + 1][j + 1])
235          costs[i + 1][j + 1] = sameParityMax + 1.0;
236    }
237  }
238  G.updateEdgeCosts(edge, std::move(costs));
239
240  return true;
241}
242
243void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
244                                                 unsigned Ra) {
245  LiveIntervals &LIs = G.getMetadata().LIS;
246
247  // Do some Chain management
248  if (Chains.count(Ra)) {
249    if (Rd != Ra) {
250      DEBUG(dbgs() << "Moving acc chain from " << PrintReg(Ra, TRI) << " to "
251                   << PrintReg(Rd, TRI) << '\n';);
252      Chains.remove(Ra);
253      Chains.insert(Rd);
254    }
255  } else {
256    DEBUG(dbgs() << "Creating new acc chain for " << PrintReg(Rd, TRI)
257                 << '\n';);
258    Chains.insert(Rd);
259  }
260
261  PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);
262
263  const LiveInterval &ld = LIs.getInterval(Rd);
264  for (auto r : Chains) {
265    // Skip self
266    if (r == Rd)
267      continue;
268
269    const LiveInterval &lr = LIs.getInterval(r);
270    if (ld.overlaps(lr)) {
271      const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =
272        &G.getNodeMetadata(node1).getAllowedRegs();
273
274      PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r);
275      const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed =
276        &G.getNodeMetadata(node2).getAllowedRegs();
277
278      PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);
279      assert(edge != G.invalidEdgeId() &&
280             "PBQP error ! The edge should exist !");
281
282      DEBUG(dbgs() << "Refining constraint !\n";);
283
284      if (G.getEdgeNode1Id(edge) == node2) {
285        std::swap(node1, node2);
286        std::swap(vRdAllowed, vRrAllowed);
287      }
288
289      // Enforce that cost is higher with all other Chains of the same parity
290      PBQP::Matrix costs(G.getEdgeCosts(edge));
291      for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {
292        unsigned pRd = (*vRdAllowed)[i];
293
294        // Get the maximum cost (excluding unallocatable reg) for all other
295        // parity registers
296        PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
297        for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
298          unsigned pRa = (*vRrAllowed)[j];
299          if (!haveSameParity(pRd, pRa))
300            if (costs[i + 1][j + 1] !=
301                    std::numeric_limits<PBQP::PBQPNum>::infinity() &&
302                costs[i + 1][j + 1] > sameParityMax)
303              sameParityMax = costs[i + 1][j + 1];
304        }
305
306        // Ensure all registers with same parity have a higher cost
307        // than sameParityMax
308        for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {
309          unsigned pRa = (*vRrAllowed)[j];
310          if (haveSameParity(pRd, pRa))
311            if (sameParityMax > costs[i + 1][j + 1])
312              costs[i + 1][j + 1] = sameParityMax + 1.0;
313        }
314      }
315      G.updateEdgeCosts(edge, std::move(costs));
316    }
317  }
318}
319
320static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,
321                                const MachineInstr &MI) {
322  const LiveInterval &LI = LIs.getInterval(reg);
323  SlotIndex SI = LIs.getInstructionIndex(MI);
324  return LI.expiredAt(SI);
325}
326
327void A57ChainingConstraint::apply(PBQPRAGraph &G) {
328  const MachineFunction &MF = G.getMetadata().MF;
329  LiveIntervals &LIs = G.getMetadata().LIS;
330
331  TRI = MF.getSubtarget().getRegisterInfo();
332  DEBUG(MF.dump());
333
334  for (const auto &MBB: MF) {
335    Chains.clear(); // FIXME: really needed ? Could not work at MF level ?
336
337    for (const auto &MI: MBB) {
338
339      // Forget Chains which have expired
340      for (auto r : Chains) {
341        SmallVector<unsigned, 8> toDel;
342        if(regJustKilledBefore(LIs, r, MI)) {
343          DEBUG(dbgs() << "Killing chain " << PrintReg(r, TRI) << " at ";
344                MI.print(dbgs()););
345          toDel.push_back(r);
346        }
347
348        while (!toDel.empty()) {
349          Chains.remove(toDel.back());
350          toDel.pop_back();
351        }
352      }
353
354      switch (MI.getOpcode()) {
355      case AArch64::FMSUBSrrr:
356      case AArch64::FMADDSrrr:
357      case AArch64::FNMSUBSrrr:
358      case AArch64::FNMADDSrrr:
359      case AArch64::FMSUBDrrr:
360      case AArch64::FMADDDrrr:
361      case AArch64::FNMSUBDrrr:
362      case AArch64::FNMADDDrrr: {
363        unsigned Rd = MI.getOperand(0).getReg();
364        unsigned Ra = MI.getOperand(3).getReg();
365
366        if (addIntraChainConstraint(G, Rd, Ra))
367          addInterChainConstraint(G, Rd, Ra);
368        break;
369      }
370
371      case AArch64::FMLAv2f32:
372      case AArch64::FMLSv2f32: {
373        unsigned Rd = MI.getOperand(0).getReg();
374        addInterChainConstraint(G, Rd, Rd);
375        break;
376      }
377
378      default:
379        break;
380      }
381    }
382  }
383}
384