# Copyright 2016 The Gemmlowp Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""32bit ARM/NEON assembly emitter.
15
16Used by code generators to produce ARM assembly with NEON simd code.
17Provides tools for easier register management: named register variable
18allocation/deallocation, and offers a more procedural/structured approach
19to generating assembly.
20
21TODO: right now neon emitter prints out assembly instructions immediately,
22it might be beneficial to keep the whole structure and emit the assembly after
23applying some optimizations like: instruction reordering or register reuse.
24
25TODO: NeonRegister object assigns explicit registers at allocation time.
26Similarily to emiting code, register mapping and reuse can be performed and
27optimized lazily.
28"""


class Error(Exception):
  """Module level error."""


class RegisterAllocationError(Error):
  """Cannot allocate registers."""


class RegisterDeallocationError(Error):
  """Cannot deallocate registers."""


class LaneError(Error):
  """Wrong lane number."""


class ArgumentError(Error):
  """Wrong argument."""


def _Low(register):
  assert register[0] == 'q'
  num = int(register[1:])
  return 'd%d' % (num * 2)


def _High(register):
  assert register[0] == 'q'
  num = int(register[1:])
  return 'd%d' % (num * 2 + 1)


def _ExpandQuads(registers):
  doubles = []
  for register in registers:
    if register[0] == 'q':
      doubles.append(_Low(register))
      doubles.append(_High(register))
    else:
      doubles.append(register)
  return doubles
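
# A NEON quad (q) register aliases a pair of double (d) registers: qN covers
# d(2N) as its low half and d(2N+1) as its high half. For example,
# _ExpandQuads(['q1', 'd7']) returns ['d2', 'd3', 'd7'].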


def _MakeCompatible(op1, op2, op3):
  if op1[0] == 'd' or op2[0] == 'd' or op3[0] == 'd':
    if op1[0] == 'q':
      op1 = _Low(op1)
    if op2[0] == 'q':
      op2 = _Low(op2)
    if op3[0] == 'q':
      op3 = _Low(op3)
  return (op1, op2, op3)
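
# _MakeCompatible narrows q operands to their low d halves whenever any
# operand is a d register, so that all three operands of an emitted
# instruction have the same width. For example,
# _MakeCompatible('q0', 'd3', 'q2') returns ('d0', 'd3', 'd4').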


class _NeonRegisters32Bit(object):
  """Utility that keeps track of used 32-bit ARM/NEON registers."""

  def __init__(self):
    self.double = set()
    self.double_ever = set()
    self.general = set()
    self.general_ever = set()
    self.parameters = dict()
    self.output_parameters = dict()

  def MapParameter(self, parameter, parameter_value=None):
    if not parameter_value:
      parameter_value = parameter
    self.parameters[parameter] = (parameter_value, 'r')
    return '%%[%s]' % parameter

  def MapMemoryParameter(self, parameter, parameter_value=None):
    if not parameter_value:
      parameter_value = parameter
    self.parameters[parameter] = (parameter_value, 'm')
    return '%%[%s]' % parameter

  def MapOutputParameter(self, parameter, parameter_value=None):
    if not parameter_value:
      parameter_value = parameter
    self.output_parameters[parameter] = (parameter_value, '+r')
    return '%%[%s]' % parameter
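
  # Mapped parameters surface in the final 'asm volatile' operand lists: a
  # parameter mapped with MapParameter('x') becomes [x] "r"(x), a memory
  # parameter becomes [x] "m"(x), and an output parameter [x] "+r"(x).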

  def DoubleRegister(self, min_val=0):
    for i in range(min_val, 32):
      if i not in self.double:
        self.double.add(i)
        self.double_ever.add(i)
        return 'd%d' % i
    raise RegisterAllocationError('Not enough double registers.')

  def QuadRegister(self, min_val=0):
    for i in range(min_val, 16):
      if ((i * 2) not in self.double) and ((i * 2 + 1) not in self.double):
        self.double.add(i * 2)
        self.double.add(i * 2 + 1)
        self.double_ever.add(i * 2)
        self.double_ever.add(i * 2 + 1)
        return 'q%d' % i
    raise RegisterAllocationError('Not enough quad registers.')

  def GeneralRegister(self):
    for i in range(0, 16):
      if i not in self.general:
        self.general.add(i)
        self.general_ever.add(i)
        return 'r%d' % i
    raise RegisterAllocationError('Not enough general registers.')

  def MappedParameters(self):
    return list(self.parameters.items())

  def MappedOutputParameters(self):
    return list(self.output_parameters.items())

  def Clobbers(self):
    return (['r%d' % i for i in self.general_ever] +
            ['d%d' % i for i in self.DoubleClobbers()])

  def DoubleClobbers(self):
    return sorted(self.double_ever)

  def FreeRegister(self, register):
    assert len(register) > 1
    if register[0] not in ['r', 'd', 'q']:
      return

    num = int(register[1:])

    if register[0] == 'r':
      assert num in self.general
      self.general.remove(num)
    elif register[0] == 'd':
      assert num in self.double
      self.double.remove(num)
    elif register[0] == 'q':
      assert num * 2 in self.double
      assert num * 2 + 1 in self.double
      self.double.remove(num * 2)
      self.double.remove(num * 2 + 1)
    else:
      raise RegisterDeallocationError('Register not allocated: %s' % register)

  def FreeRegisters(self, registers):
    for register in registers:
      self.FreeRegister(register)


class NeonEmitter(object):
  """Emits ARM/NEON assembly opcodes."""

  def __init__(self, debug=False):
    self.ops = {}
    self.indent = ''
    self.debug = debug

  def PushIndent(self, delta='  '):
    self.indent += delta

  def PopIndent(self, delta=2):
    self.indent = self.indent[:-delta]

  def EmitIndented(self, what):
    print(self.indent + what)

  def PushOp(self, op):
    if op in self.ops:
      self.ops[op] += 1
    else:
      self.ops[op] = 1

  def ClearCounters(self):
    self.ops.clear()

  def EmitNewline(self):
    print('')

  def EmitPreprocessor1(self, op, param):
    print('#%s %s' % (op, param))

  def EmitPreprocessor(self, op):
    print('#%s' % op)

  def EmitInclude(self, include):
    self.EmitPreprocessor1('include', include)

  def EmitCall1(self, function, param):
    self.EmitIndented('%s(%s);' % (function, param))

  def EmitAssert(self, assert_expression):
    if self.debug:
      self.EmitCall1('assert', assert_expression)

  def EmitHeaderBegin(self, header_name, includes):
    self.EmitPreprocessor1('ifndef', (header_name + '_H_').upper())
    self.EmitPreprocessor1('define', (header_name + '_H_').upper())
    self.EmitNewline()
    if includes:
      for include in includes:
        self.EmitInclude(include)
      self.EmitNewline()

  def EmitHeaderEnd(self):
    self.EmitPreprocessor('endif')

  def EmitCode(self, code):
    self.EmitIndented('%s;' % code)

  def EmitFunctionBeginA(self, function_name, params, return_type):
    self.EmitIndented('%s %s(%s) {' %
                      (return_type, function_name,
                       ', '.join(['%s %s' % (t, n) for (t, n) in params])))
    self.PushIndent()

  def EmitFunctionEnd(self):
    self.PopIndent()
    self.EmitIndented('}')

  def EmitAsmBegin(self):
    self.EmitIndented('asm volatile(')
    self.PushIndent()

  def EmitAsmMapping(self, elements):
    if elements:
      self.EmitIndented(': ' + ', '.join(
          ['[%s] "%s"(%s)' % (d, v[1], v[0]) for (d, v) in elements]))
    else:
      self.EmitIndented(':')

  def EmitClobbers(self, elements):
    if elements:
      self.EmitIndented(': ' + ', '.join(['"%s"' % c for c in elements]))
    else:
      self.EmitIndented(':')

  def EmitAsmEnd(self, registers):
    self.EmitAsmMapping(registers.MappedOutputParameters())
    self.EmitAsmMapping(registers.MappedParameters())
    self.EmitClobbers(registers.Clobbers() + ['cc', 'memory'])
    self.PopIndent()
    self.EmitIndented(');')

  def EmitComment(self, comment):
    self.EmitIndented('// ' + comment)

  def EmitNumericalLabel(self, label):
    self.EmitIndented('"%d:"' % label)

  def EmitOp1(self, op, param1):
    self.PushOp(op)
    self.EmitIndented('"%s %s\\n"' % (op, param1))

  def EmitOp2(self, op, param1, param2):
    self.PushOp(op)
    self.EmitIndented('"%s %s, %s\\n"' % (op, param1, param2))

  def EmitOp3(self, op, param1, param2, param3):
    self.PushOp(op)
    self.EmitIndented('"%s %s, %s, %s\\n"' % (op, param1, param2, param3))

  def EmitAdd(self, destination, source, param):
    self.EmitOp3('add', destination, source, param)

  def EmitSubs(self, destination, source, param):
    self.EmitOp3('subs', destination, source, param)

  def EmitSub(self, destination, source, param):
    self.EmitOp3('sub', destination, source, param)

  def EmitMul(self, destination, source, param):
    self.EmitOp3('mul', destination, source, param)

  def EmitMov(self, param1, param2):
    self.EmitOp2('mov', param1, param2)

  def EmitBeqBack(self, label):
    self.EmitOp1('beq', '%db' % label)

  def EmitBeqFront(self, label):
    self.EmitOp1('beq', '%df' % label)

  def EmitBgtBack(self, label):
    self.EmitOp1('bgt', '%db' % label)

  def EmitBgtFront(self, label):
    self.EmitOp1('bgt', '%df' % label)

  def EmitBleBack(self, label):
    self.EmitOp1('ble', '%db' % label)

  def EmitBleFront(self, label):
    self.EmitOp1('ble', '%df' % label)

  def EmitBneBack(self, label):
    self.EmitOp1('bne', '%db' % label)

  def EmitBneFront(self, label):
    self.EmitOp1('bne', '%df' % label)

  def EmitVAdd(self, add_type, destination, source_1, source_2):
    destination, source_1, source_2 = _MakeCompatible(destination, source_1,
                                                      source_2)
    self.EmitOp3('vadd.%s' % add_type, destination, source_1, source_2)

  def EmitVAddw(self, add_type, destination, source_1, source_2):
    self.EmitOp3('vaddw.%s' % add_type, destination, source_1, source_2)

  def EmitVSub(self, sub_type, destination, source_1, source_2):
    destination, source_1, source_2 = _MakeCompatible(destination, source_1,
                                                      source_2)
    self.EmitOp3('vsub.%s' % sub_type, destination, source_1, source_2)

  def EmitVCvt(self, cvt_to, cvt_from, destination, source):
    self.EmitOp2('vcvt.%s.%s' % (cvt_to, cvt_from), destination, source)

  def EmitVDup(self, dup_type, destination, source):
    self.EmitOp2('vdup.%s' % dup_type, destination, source)

  def EmitVMax(self, size, destination, source_1, source_2):
    self.EmitOp3('vmax.%s' % size, destination, source_1, source_2)

  def EmitVMin(self, size, destination, source_1, source_2):
    self.EmitOp3('vmin.%s' % size, destination, source_1, source_2)

  def EmitVMov(self, mov_type, destination, source):
    self.EmitOp2('vmov.%s' % mov_type, destination, source)

  def EmitVMovl(self, mov_type, destination, source):
    if source[0] == 'q':
      source = _Low(source)
    self.EmitOp2('vmovl.%s' % mov_type, destination, source)

  def EmitVMovl2(self, mov_type, destination_1, destination_2, source):
    self.EmitVMovl(mov_type, destination_2, _High(source))
    self.EmitVMovl(mov_type, destination_1, _Low(source))

  def EmitVQmovn(self, mov_type, destination, source):
    if destination[0] == 'q':
      destination = _Low(destination)
    self.EmitOp2('vqmovn.%s' % mov_type, destination, source)

  def EmitVQmovn2(self, mov_type, destination, source_1, source_2):
    self.EmitVQmovn(mov_type, _Low(destination), source_1)
    self.EmitVQmovn(mov_type, _High(destination), source_2)

  def EmitVQmovun(self, mov_type, destination, source):
    if destination[0] == 'q':
      destination = _Low(destination)
    self.EmitOp2('vqmovun.%s' % mov_type, destination, source)

  def EmitVQmovun2(self, mov_type, destination, source_1, source_2):
    self.EmitVQmovun(mov_type, _Low(destination), source_1)
    self.EmitVQmovun(mov_type, _High(destination), source_2)

  def EmitVMul(self, mul_type, destination, source_1, source_2):
    destination, source_1, source_2 = _MakeCompatible(destination, source_1,
                                                      source_2)
    self.EmitOp3('vmul.%s' % mul_type, destination, source_1, source_2)

  def EmitVMulScalar(self, mul_type, destination, source_1, source_2):
    self.EmitOp3('vmul.%s' % mul_type, destination, source_1, source_2)

  def EmitVMull(self, mul_type, destination, source_1, source_2):
    self.EmitOp3('vmull.%s' % mul_type, destination, source_1, source_2)

  def EmitVPadd(self, add_type, destination, source_1, source_2):
    self.EmitOp3('vpadd.%s' % add_type, destination, source_1, source_2)

  def EmitVPaddl(self, add_type, destination, source):
    self.EmitOp2('vpaddl.%s' % add_type, destination, source)

  def EmitVPadal(self, add_type, destination, source):
    self.EmitOp2('vpadal.%s' % add_type, destination, source)

  def EmitLdr(self, register, value):
    self.EmitOp2('ldr', register, value)

  def EmitVLoad(self, load_no, load_type, destination, source):
    self.EmitVLoadA(load_no, load_type, [destination], source)

  def EmitVLoadA(self, load_no, load_type, destinations, source):
    self.EmitOp2('vld%d.%d' % (load_no, load_type),
                 '{%s}' % ', '.join(_ExpandQuads(destinations)), source)

  def EmitVLoadAE(self,
                  load_type,
                  elem_count,
                  destinations,
                  source,
                  alignment=None):
    bits_to_load = load_type * elem_count
    destinations = _ExpandQuads(destinations)
    if len(destinations) * 64 < bits_to_load:
      raise ArgumentError('Too few destinations: %d to load %d bits.' %
                          (len(destinations), bits_to_load))

    while bits_to_load > 0:
      if bits_to_load >= 256:
        self.EmitVLoadA(1, 32, destinations[:4],
                        self.DereferenceIncrement(source, alignment))
        bits_to_load -= 256
        destinations = destinations[4:]
      elif bits_to_load >= 192:
        self.EmitVLoadA(1, 32, destinations[:3],
                        self.DereferenceIncrement(source, alignment))
        bits_to_load -= 192
        destinations = destinations[3:]
      elif bits_to_load >= 128:
        self.EmitVLoadA(1, 32, destinations[:2],
                        self.DereferenceIncrement(source, alignment))
        bits_to_load -= 128
        destinations = destinations[2:]
      elif bits_to_load >= 64:
        self.EmitVLoad(1, 32, destinations[0],
                       self.DereferenceIncrement(source, alignment))
        bits_to_load -= 64
        destinations = destinations[1:]
      else:
        destination = destinations[0]
        if bits_to_load == 56:
          self.EmitVLoad(1, 32,
                         self.Lane(32, destination, 0),
                         self.DereferenceIncrement(source))
          self.EmitVLoad(1, 16,
                         self.Lane(16, destination, 2),
                         self.DereferenceIncrement(source))
          self.EmitVLoad(1, 8,
                         self.Lane(8, destination, 6),
                         self.DereferenceIncrement(source))
        elif bits_to_load == 48:
          self.EmitVLoad(1, 32,
                         self.Lane(32, destination, 0),
                         self.DereferenceIncrement(source))
          self.EmitVLoad(1, 16,
                         self.Lane(16, destination, 2),
                         self.DereferenceIncrement(source))
        elif bits_to_load == 40:
          self.EmitVLoad(1, 32,
                         self.Lane(32, destination, 0),
                         self.DereferenceIncrement(source))
          self.EmitVLoad(1, 8,
                         self.Lane(8, destination, 4),
                         self.DereferenceIncrement(source))
        elif bits_to_load == 32:
          self.EmitVLoad(1, 32,
                         self.Lane(32, destination, 0),
                         self.DereferenceIncrement(source))
        elif bits_to_load == 24:
          self.EmitVLoad(1, 16,
                         self.Lane(16, destination, 0),
                         self.DereferenceIncrement(source))
          self.EmitVLoad(1, 8,
                         self.Lane(8, destination, 2),
                         self.DereferenceIncrement(source))
        elif bits_to_load == 16:
          self.EmitVLoad(1, 16,
                         self.Lane(16, destination, 0),
                         self.DereferenceIncrement(source))
        elif bits_to_load == 8:
          self.EmitVLoad(1, 8,
                         self.Lane(8, destination, 0),
                         self.DereferenceIncrement(source))
        else:
          raise ArgumentError('Wrong leftover: %d' % bits_to_load)
        return
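
  # For example, EmitVLoadAE(8, 7, ['d0'], 'r0') loads 56 bits as one 32-bit,
  # one 16-bit and one 8-bit lane load:
  #   "vld1.32 {d0[0]}, [r0]!"
  #   "vld1.16 {d0[2]}, [r0]!"
  #   "vld1.8 {d0[6]}, [r0]!"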

  def EmitVLoadE(self, load_type, count, destination, source, alignment=None):
    self.EmitVLoadAE(load_type, count, [destination], source, alignment)

  def EmitVLoadAllLanes(self, load_no, load_type, destination, source):
    destinations = []
    if destination[0] == 'q':
      destinations.append(self.AllLanes(_Low(destination)))
      destinations.append(self.AllLanes(_High(destination)))
    else:
      destinations.append(self.AllLanes(destination))
    self.EmitVLoadA(load_no, load_type, destinations, source)

  def EmitVLoadOffset(self, load_no, load_type, destination, source, offset):
    self.EmitVLoadOffsetA(load_no, load_type, [destination], source, offset)

  def EmitVLoadOffsetA(self, load_no, load_type, destinations, source, offset):
    assert len(destinations) <= 4
    self.EmitOp3('vld%d.%d' % (load_no, load_type),
                 '{%s}' % ', '.join(_ExpandQuads(destinations)), source, offset)

  def EmitPld(self, load_address_register):
    self.EmitOp1('pld', '[%s]' % load_address_register)

  def EmitPldw(self, store_address_register):
    self.EmitOp1('pldw', '[%s]' % store_address_register)

  def EmitPldOffset(self, load_address_register, offset):
    self.EmitOp1('pld', '[%s, %s]' % (load_address_register, offset))

  def EmitPldwOffset(self, store_address_register, offset):
    self.EmitOp1('pldw', '[%s, %s]' % (store_address_register, offset))

  def EmitVShl(self, shift_type, destination, source, shift):
    self.EmitOp3('vshl.%s' % shift_type, destination, source, shift)

  def EmitVStore(self, store_no, store_type, source, destination):
    self.EmitVStoreA(store_no, store_type, [source], destination)

  def EmitVStoreA(self, store_no, store_type, sources, destination):
    self.EmitOp2('vst%d.%d' % (store_no, store_type),
                 '{%s}' % ', '.join(_ExpandQuads(sources)), destination)

  def EmitVStoreAE(self,
                   store_type,
                   elem_count,
                   sources,
                   destination,
                   alignment=None):
    bits_to_store = store_type * elem_count
    sources = _ExpandQuads(sources)
    if len(sources) * 64 < bits_to_store:
      raise ArgumentError('Too few sources: %d to store %d bits.' %
                          (len(sources), bits_to_store))

    while bits_to_store > 0:
      if bits_to_store >= 256:
        self.EmitVStoreA(1, 32, sources[:4],
                         self.DereferenceIncrement(destination, alignment))
        bits_to_store -= 256
        sources = sources[4:]
      elif bits_to_store >= 192:
        self.EmitVStoreA(1, 32, sources[:3],
                         self.DereferenceIncrement(destination, alignment))
        bits_to_store -= 192
        sources = sources[3:]
      elif bits_to_store >= 128:
        self.EmitVStoreA(1, 32, sources[:2],
                         self.DereferenceIncrement(destination, alignment))
        bits_to_store -= 128
        sources = sources[2:]
      elif bits_to_store >= 64:
        self.EmitVStore(1, 32, sources[0],
                        self.DereferenceIncrement(destination, alignment))
        bits_to_store -= 64
        sources = sources[1:]
      else:
        source = sources[0]
        if bits_to_store == 56:
          self.EmitVStore(1, 32,
                          self.Lane(32, source, 0),
                          self.DereferenceIncrement(destination))
          self.EmitVStore(1, 16,
                          self.Lane(16, source, 2),
                          self.DereferenceIncrement(destination))
          self.EmitVStore(1, 8,
                          self.Lane(8, source, 6),
                          self.DereferenceIncrement(destination))
        elif bits_to_store == 48:
          self.EmitVStore(1, 32,
                          self.Lane(32, source, 0),
                          self.DereferenceIncrement(destination))
          self.EmitVStore(1, 16,
                          self.Lane(16, source, 2),
                          self.DereferenceIncrement(destination))
        elif bits_to_store == 40:
          self.EmitVStore(1, 32,
                          self.Lane(32, source, 0),
                          self.DereferenceIncrement(destination))
          self.EmitVStore(1, 8,
                          self.Lane(8, source, 4),
                          self.DereferenceIncrement(destination))
        elif bits_to_store == 32:
          self.EmitVStore(1, 32,
                          self.Lane(32, source, 0),
                          self.DereferenceIncrement(destination))
        elif bits_to_store == 24:
          self.EmitVStore(1, 16,
                          self.Lane(16, source, 0),
                          self.DereferenceIncrement(destination))
          self.EmitVStore(1, 8,
                          self.Lane(8, source, 2),
                          self.DereferenceIncrement(destination))
        elif bits_to_store == 16:
          self.EmitVStore(1, 16,
                          self.Lane(16, source, 0),
                          self.DereferenceIncrement(destination))
        elif bits_to_store == 8:
          self.EmitVStore(1, 8,
                          self.Lane(8, source, 0),
                          self.DereferenceIncrement(destination))
        else:
          raise ArgumentError('Wrong leftover: %d' % bits_to_store)
        return

  def EmitVStoreE(self, store_type, count, source, destination, alignment=None):
    self.EmitVStoreAE(store_type, count, [source], destination, alignment)

  def EmitVStoreOffset(self, store_no, store_type, source, destination, offset):
    self.EmitVStoreOffsetA(store_no, store_type, [source], destination, offset)

  def EmitVStoreOffsetA(self, store_no, store_type, sources, destination,
                        offset):
    self.EmitOp3('vst%d.%d' % (store_no, store_type),
                 '{%s}' % ', '.join(_ExpandQuads(sources)), destination, offset)

  def EmitVStoreOffsetE(self, store_type, count, source, destination, offset):
    """Emit assembly to store a number of elements from the source registers."""
    if store_type != 32:
      raise ArgumentError('Unsupported store_type: %d' % store_type)

    sources = []
    if source[0] == 'q':
      sources.append(_Low(source))
      sources.append(_High(source))
      if count * store_type > 128:
        raise ArgumentError('Too many %d-bit elements in a q register: %d' %
                            (store_type, count))
    else:
      sources.append(source)
      if count * store_type > 64:
        raise ArgumentError('Too many %d-bit elements in a d register: %d' %
                            (store_type, count))

    if count == 1:
      self.EmitVStoreOffset(1, store_type,
                            self.Lane(store_type, sources[0], 0),
                            self.Dereference(destination, None), offset)
    elif count == 2:
      self.EmitVStoreOffset(1, store_type, sources[0],
                            self.Dereference(destination, None), offset)
    elif count == 3:
      self.EmitVStore(1, store_type, sources[0],
                      self.DereferenceIncrement(destination, None))
      self.EmitVStoreOffset(1, store_type,
                            self.Lane(store_type, sources[1], 0),
                            self.Dereference(destination, None), offset)
      self.EmitSub(destination, destination, self.ImmediateConstant(8))
    elif count == 4:
      self.EmitVStoreOffsetA(1, store_type, sources,
                             self.Dereference(destination, None), offset)
    else:
      raise ArgumentError('Too many elements: %d' % count)

  def EmitVSumReduce(self, reduce_type, elem_count, reduce_count, destinations,
                     sources):
    """Emit assembly for n-fold horizontal sum reduction."""
    if reduce_type != 'u32':
      raise ArgumentError('Unsupported reduce: %s' % reduce_type)

    sources = _ExpandQuads(sources)
    destinations = _ExpandQuads(destinations)

    if len(destinations) * 2 < elem_count:
      raise ArgumentError('Not enough space in destination: %d vs %d' %
                          (len(destinations) * 2, elem_count))

    if len(sources) * 2 != elem_count * reduce_count:
      raise ArgumentError('Wrong number of sources: %d vs %d' %
                          (len(sources) * 2, elem_count * reduce_count))

    if reduce_count <= 1:
      raise ArgumentError('Unsupported reduce_count: %d' % reduce_count)

    while reduce_count > 1:
      if len(sources) % 2 == 1:
        sources.append(sources[-1])

      if reduce_count == 2:
        for i in range(len(sources) // 2):
          self.EmitVPadd(reduce_type, destinations[i], sources[2 * i],
                         sources[2 * i + 1])
        return
      else:
        sources_2 = []
        for i in range(len(sources) // 2):
          self.EmitVPadd(reduce_type, sources[2 * i], sources[2 * i],
                         sources[2 * i + 1])
          sources_2.append(sources[2 * i])
        reduce_count //= 2
        sources = sources_2
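
  # For example, EmitVSumReduce('u32', 1, 4, ['d0'], ['d0', 'd1']) reduces
  # four u32 values to a single sum via pairwise additions:
  #   "vpadd.u32 d0, d0, d1"
  #   "vpadd.u32 d0, d0, d0"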

  def EmitVUzp(self, uzp_type, operand_1, operand_2):
    self.EmitOp2('vuzp.%d' % uzp_type, operand_1, operand_2)

  def EmitVTrn(self, trn_type, operand_1, operand_2):
    self.EmitOp2('vtrn.%d' % trn_type, operand_1, operand_2)

  def EmitColBlockStride(self, cols, stride, new_stride):
    assert cols in [1, 2, 3, 4, 5, 6, 7, 8]
    if cols in [5, 6, 7]:
      self.EmitSub(new_stride, stride, self.ImmediateConstant(4))

  def EmitLoadColBlock(self, unused_registers, load_type, cols, elements, block,
                       input_address, stride):
    """Load a block of column major data."""
    assert cols == len(block)
    assert load_type == 8

    input_deref = self.Dereference(input_address, None)
    input_deref_increment = self.DereferenceIncrement(input_address, None)

    if cols == 1:
      for i in range(elements):
        self.EmitVLoadOffset(1, 8,
                             self.Lane(8, block[0], i), input_deref, stride)
      self.EmitPld(input_address)
    elif cols == 2:
      for i in range(elements):
        self.EmitVLoadOffset(1, 16,
                             self.Lane(16, block[i // 4], i % 4), input_deref,
                             stride)
      self.EmitPld(input_address)
      self.EmitVUzp(8, block[0], block[1])
    elif cols == 3:
      for i in range(elements):
        self.EmitVLoadOffsetA(3, 8, [self.Lane(8, row, i) for row in block],
                              input_deref, stride)
    elif cols == 4:
      for i in range(elements):
        self.EmitVLoadOffset(1, 32,
                             self.Lane(32, block[i % 4], i // 4), input_deref,
                             stride)
      self.EmitPld(input_address)
      self.EmitVTrn(16, block[0], block[2])
      self.EmitVTrn(16, block[1], block[3])
      self.EmitVTrn(8, block[0], block[1])
      self.EmitVTrn(8, block[2], block[3])
    elif cols == 5:
      for i in range(elements):
        self.EmitVLoad(1, 32,
                       self.Lane(32, block[i % 4], i // 4),
                       input_deref_increment)
        self.EmitVLoadOffset(1, 8,
                             self.Lane(8, block[4], i), input_deref, stride)
      self.EmitPld(input_address)
      self.EmitVTrn(16, block[0], block[2])
      self.EmitVTrn(16, block[1], block[3])
      self.EmitVTrn(8, block[0], block[1])
      self.EmitVTrn(8, block[2], block[3])
    elif cols == 6:
      for i in range(elements):
        self.EmitVLoad(1, 32,
                       self.Lane(32, block[i % 4], i // 4),
                       input_deref_increment)
        self.EmitVLoadOffset(1, 16,
                             self.Lane(16, block[4 + i // 4], i % 4),
                             input_deref, stride)
      self.EmitPld(input_address)
      self.EmitVTrn(16, block[0], block[2])
      self.EmitVTrn(16, block[1], block[3])
      self.EmitVUzp(8, block[4], block[5])
      self.EmitVTrn(8, block[0], block[1])
      self.EmitVTrn(8, block[2], block[3])
    elif cols == 7:
      for i in range(elements):
        self.EmitVLoad(1, 32,
                       self.Lane(32, block[i % 4], i // 4),
                       input_deref_increment)
        self.EmitVLoadOffsetA(3, 8,
                              [self.Lane(8, row, i) for row in block[4:]],
                              input_deref, stride)
      self.EmitPld(input_address)
      self.EmitVTrn(16, block[0], block[2])
      self.EmitVTrn(16, block[1], block[3])
      self.EmitVTrn(8, block[0], block[1])
      self.EmitVTrn(8, block[2], block[3])
    elif cols == 8:
      for i in range(elements):
        self.EmitVLoadOffset(1, 32, block[i], input_deref, stride)
      self.EmitPld(input_address)
      self.EmitVTrn(8, block[0], block[1])
      self.EmitVTrn(8, block[2], block[3])
      self.EmitVTrn(8, block[4], block[5])
      self.EmitVTrn(8, block[6], block[7])
      self.EmitVTrn(16, block[0], block[2])
      self.EmitVTrn(16, block[1], block[3])
      self.EmitVTrn(16, block[4], block[6])
      self.EmitVTrn(16, block[5], block[7])
      self.EmitVTrn(32, block[0], block[4])
      self.EmitVTrn(32, block[1], block[5])
      self.EmitVTrn(32, block[2], block[6])
      self.EmitVTrn(32, block[3], block[7])
    else:
      assert False
    return block

  def Dereference(self, value, alignment=None):
    if alignment:
      return '[%s:%d]' % (value, alignment)
    else:
      return '[%s]' % value

  def DereferenceIncrement(self, value, alignment=None):
    return '%s!' % self.Dereference(value, alignment)

  def ImmediateConstant(self, value):
    return '#%d' % value

  def AllLanes(self, value):
    return '%s[]' % value

  def Lane(self, bits, value, lane):
    """Get the proper n-bit lane from the given register."""
    registers = []
    if value[0] == 'q':
      registers.append(_Low(value))
      registers.append(_High(value))
    else:
      registers.append(value)

    elems_per_register = 64 // bits
    if lane < 0 or lane >= len(registers) * elems_per_register:
      raise LaneError('Wrong lane number: %d-bit lane %d of %s.' %
                      (bits, lane, value))
    register = lane // elems_per_register
    lane %= elems_per_register

    return '%s[%d]' % (registers[register], lane)
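
  # For example, Lane(16, 'q1', 5) returns 'd3[1]': q1 aliases d2:d3, each d
  # register holds four 16-bit lanes, so lane 5 is lane 1 of d3.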

  def CreateRegisters(self):
    return _NeonRegisters32Bit()