# Copyright 2016 The Gemmlowp Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""64bit ARM/NEON assembly emitter.
15
16Used by code generators to produce ARM assembly with NEON simd code.
17Provides tools for easier register management: named register variable
18allocation/deallocation, and offers a more procedural/structured approach
19to generating assembly.
20
21"""
22
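# Example (illustrative only; the variable names are this comment's own):
# a code generator might drive the emitter roughly like this:
#
#   emitter = NeonEmitter64()
#   registers = emitter.CreateRegisters()
#   emitter.EmitAsmBegin()
#   count = registers.MapParameter('count')    # maps C variable 'count'
#   acc = registers.QuadRegister()             # allocates v0
#   emitter.EmitVMov('i32', acc, emitter.ImmediateConstant(0))
#                                              # emits: "movi v0.4s, #0\n"
#   emitter.EmitAsmEnd(registers)              # operand and clobber lists
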
_WIDE_TYPES = {
    8: 16,
    16: 32,
    32: 64,
    '8': '16',
    '16': '32',
    '32': '64',
    'i8': 'i16',
    'i16': 'i32',
    'i32': 'i64',
    'u8': 'u16',
    'u16': 'u32',
    'u32': 'u64',
    's8': 's16',
    's16': 's32',
    's32': 's64'
}

_NARROW_TYPES = {
    64: 32,
    32: 16,
    16: 8,
    '64': '32',
    '32': '16',
    '16': '8',
    'i64': 'i32',
    'i32': 'i16',
    'i16': 'i8',
    'u64': 'u32',
    'u32': 'u16',
    'u16': 'u8',
    's64': 's32',
    's32': 's16',
    's16': 's8'
}

_TYPE_BITS = {
    8: 8,
    16: 16,
    32: 32,
    64: 64,
    '8': 8,
    '16': 16,
    '32': 32,
    '64': 64,
    'i8': 8,
    'i16': 16,
    'i32': 32,
    'i64': 64,
    'u8': 8,
    'u16': 16,
    'u32': 32,
    'u64': 64,
    's8': 8,
    's16': 16,
    's32': 32,
    's64': 64,
    'f32': 32,
    'f64': 64,
    'b': 8,
    'h': 16,
    's': 32,
    'd': 64
}
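
# Example (illustrative): these maps answer the type queries made by the
# helpers below, e.g. _WIDE_TYPES['u8'] == 'u16',
# _NARROW_TYPES['s32'] == 's16' and _TYPE_BITS['h'] == 16.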


class Error(Exception):
  """Module level error."""


class RegisterAllocationError(Error):
  """Cannot allocate registers."""


class LaneError(Error):
  """Wrong lane number."""


class RegisterSubtypeError(Error):
  """The register needs to be lane-typed."""


class ArgumentError(Error):
  """Wrong argument."""


def _AppendType(type_name, register):
  """Calculates sizes and attaches the type information to the register."""
  if register.register_type != 'v':
    raise ArgumentError('Only vector registers can have type appended.')

  if type_name in (8, '8', 'i8', 's8', 'u8'):
    subtype = 'b'
    subtype_bits = 8
  elif type_name in (16, '16', 'i16', 's16', 'u16'):
    subtype = 'h'
    subtype_bits = 16
  elif type_name in (32, '32', 'i32', 's32', 'u32', 'f32'):
    subtype = 's'
    subtype_bits = 32
  elif type_name in (64, '64', 'i64', 's64', 'u64', 'f64'):
    subtype = 'd'
    subtype_bits = 64
  else:
    raise ArgumentError('Unknown type: %s' % type_name)

  new_register = register.Copy()
  new_register.register_subtype = subtype
  new_register.register_subtype_count = register.register_bits // subtype_bits
  return new_register
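
# Example (illustrative): appending a type to a 128-bit vector register
# yields a lane-typed copy, e.g. _AppendType('s16', v0_quad) renders as
# 'v0.8h' (128 bits / 16-bit lanes = 8 lanes).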


def _UnsignedType(type_name):
  return type_name in ('u8', 'u16', 'u32', 'u64')


def _FloatType(type_name):
  return type_name in ('f32', 'f64')


def _WideType(type_name):
  if type_name in _WIDE_TYPES:
    return _WIDE_TYPES[type_name]
  else:
    raise ArgumentError('No wide type for: %s' % type_name)


def _NarrowType(type_name):
  if type_name in _NARROW_TYPES:
    return _NARROW_TYPES[type_name]
  else:
    raise ArgumentError('No narrow type for: %s' % type_name)


def _LoadStoreSize(register):
  if register.lane is None:
    return register.register_bits
  else:
    return register.lane_bits


def _MakeCompatibleDown(reg_1, reg_2, reg_3):
  bits = min([reg_1.register_bits, reg_2.register_bits, reg_3.register_bits])
  return (_Cast(bits, reg_1), _Cast(bits, reg_2), _Cast(bits, reg_3))


def _MakeCompatibleUp(reg_1, reg_2, reg_3):
  bits = max([reg_1.register_bits, reg_2.register_bits, reg_3.register_bits])
  return (_Cast(bits, reg_1), _Cast(bits, reg_2), _Cast(bits, reg_3))


def _Cast(bits, reg):
  if reg.register_bits == bits:
    return reg
  else:
    new_reg = reg.Copy()
    new_reg.register_bits = bits
    return new_reg


def _TypeBits(type_name):
  if type_name in _TYPE_BITS:
    return _TYPE_BITS[type_name]
  else:
    raise ArgumentError('Unknown type: %s' % type_name)


def _RegisterList(list_type, registers):
  lanes = list(set([register.lane for register in registers]))
  if len(lanes) > 1:
    raise ArgumentError('Cannot mix lanes on a register list.')
  typed_registers = [_AppendType(list_type, register) for register in registers]

  if lanes[0] is None:
    return '{%s}' % ', '.join(map(str, typed_registers))
  elif lanes[0] == -1:
    raise ArgumentError('Cannot construct a list with all lane indexing.')
  else:
    typed_registers_nolane = [register.Copy() for register in typed_registers]
    for register in typed_registers_nolane:
      register.lane = None
      register.register_subtype_count = None
    return '{%s}[%d]' % (', '.join(map(str, typed_registers_nolane)), lanes[0])
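
# Example (illustrative): for two 128-bit registers v0 and v1,
# _RegisterList(8, [v0, v1]) renders as '{v0.16b, v1.16b}'; with both
# registers lane-indexed at lane 2 it renders as '{v0.b, v1.b}[2]'.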


class _GeneralRegister(object):
  """ARMv8 general register: (x|w)n."""

  def __init__(self,
               register_bits,
               number,
               dereference=False,
               dereference_increment=False):
    self.register_type = 'r'
    self.register_bits = register_bits
    self.number = number
    self.dereference = dereference
    self.dereference_increment = dereference_increment

  def Copy(self):
    return _GeneralRegister(self.register_bits, self.number, self.dereference,
                            self.dereference_increment)

  def __repr__(self):
    if self.register_bits == 64:
      text = 'x%d' % self.number
    elif self.register_bits <= 32:
      text = 'w%d' % self.number
    else:
      raise RegisterSubtypeError('Wrong bits (%d) for general register: %d' %
                                 (self.register_bits, self.number))
    if self.dereference:
      return '[%s]' % text
    else:
      return text
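
# Example (illustrative): _GeneralRegister(64, 3) renders as 'x3',
# _Cast(32, _GeneralRegister(64, 3)) as 'w3', and the dereferencing form
# (dereference=True) as '[x3]'.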


class _MappedParameter(object):
  """Object representing a C variable mapped to a register."""

  def __init__(self,
               name,
               register_bits=64,
               dereference=False,
               dereference_increment=False):
    self.name = name
    self.register_bits = register_bits
    self.dereference = dereference
    self.dereference_increment = dereference_increment

  def Copy(self):
    return _MappedParameter(self.name, self.register_bits, self.dereference,
                            self.dereference_increment)

  def __repr__(self):
    if self.register_bits is None:
      text = '%%[%s]' % self.name
    elif self.register_bits == 64:
      text = '%%x[%s]' % self.name
    elif self.register_bits <= 32:
      text = '%%w[%s]' % self.name
    else:
      raise RegisterSubtypeError('Wrong bits (%d) for mapped parameter: %s' %
                                 (self.register_bits, self.name))
    if self.dereference:
      return '[%s]' % text
    else:
      return text
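
# Example (illustrative): _MappedParameter('count') renders as the inline-asm
# operand '%x[count]', the 32-bit cast as '%w[count]', and the dereferencing
# form as '[%x[count]]'.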


class _VectorRegister(object):
  """ARMv8 vector register: Vn.TT."""

  def __init__(self,
               register_bits,
               number,
               register_subtype=None,
               register_subtype_count=None,
               lane=None,
               lane_bits=None):
    self.register_type = 'v'
    self.register_bits = register_bits
    self.number = number
    self.register_subtype = register_subtype
    self.register_subtype_count = register_subtype_count
    self.lane = lane
    self.lane_bits = lane_bits

  def Copy(self):
    return _VectorRegister(self.register_bits, self.number,
                           self.register_subtype, self.register_subtype_count,
                           self.lane, self.lane_bits)

  def __repr__(self):
    if self.register_subtype is None:
      raise RegisterSubtypeError('Register: %s%d has no lane types defined.' %
                                 (self.register_type, self.number))
    if (self.register_subtype_count is None or
        (self.lane is not None and self.lane != -1)):
      typed_name = '%s%d.%s' % (self.register_type, self.number,
                                self.register_subtype)
    else:
      typed_name = '%s%d.%d%s' % (self.register_type, self.number,
                                  self.register_subtype_count,
                                  self.register_subtype)

    if self.lane is None or self.lane == -1:
      return typed_name
    elif 0 <= self.lane < self.register_subtype_count:
      return '%s[%d]' % (typed_name, self.lane)
    else:
      raise LaneError('Wrong lane: %d for: %s' % (self.lane, typed_name))
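
# Example (illustrative): _VectorRegister(128, 0) typed to 32-bit lanes
# renders as 'v0.4s'; with lane=2 it renders as 'v0.s[2]'; lane=-1 marks an
# all-lanes (ld1r-style) operand and also renders as 'v0.4s'.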


class _ImmediateConstant(object):

  def __init__(self, value):
    self.register_type = 'i'
    self.value = value

  def Copy(self):
    return _ImmediateConstant(self.value)

  def __repr__(self):
    return '#%d' % self.value


class _NeonRegisters64Bit(object):
  """Utility that keeps track of used 64-bit ARM/NEON registers."""

  def __init__(self):
    self.vector = set()
    self.vector_ever = set()
    self.general = set()
    self.general_ever = set()
    self.parameters = dict()
    self.output_parameters = dict()

  def MapParameter(self, parameter, parameter_value=None):
    if not parameter_value:
      parameter_value = parameter
    self.parameters[parameter] = (parameter_value, 'r')
    return _MappedParameter(parameter)

  def MapMemoryParameter(self, parameter, parameter_value=None):
    if not parameter_value:
      parameter_value = parameter
    self.parameters[parameter] = (parameter_value, 'm')
    return _MappedParameter(parameter)

  def MapOutputParameter(self, parameter, parameter_value=None):
    if not parameter_value:
      parameter_value = parameter
    self.output_parameters[parameter] = (parameter_value, '+r')
    return _MappedParameter(parameter)

  def _VectorRegisterNum(self, min_val=0):
    for i in range(min_val, 32):
      if i not in self.vector:
        self.vector.add(i)
        self.vector_ever.add(i)
        return i
    raise RegisterAllocationError('Not enough vector registers.')

  def DoubleRegister(self, min_val=0):
    return _VectorRegister(64, self._VectorRegisterNum(min_val))

  def QuadRegister(self, min_val=0):
    return _VectorRegister(128, self._VectorRegisterNum(min_val))

  def GeneralRegister(self):
    for i in range(0, 30):
      if i not in self.general:
        self.general.add(i)
        self.general_ever.add(i)
        return _GeneralRegister(64, i)
    raise RegisterAllocationError('Not enough general registers.')

  def MappedParameters(self):
    return list(self.parameters.items())

  def MappedOutputParameters(self):
    return list(self.output_parameters.items())

  def Clobbers(self):
    return (
        ['x%d' % i
         for i in self.general_ever] + ['v%d' % i for i in self.vector_ever])

  def FreeRegister(self, register):
    if isinstance(register, _MappedParameter):
      return

    if register.register_type == 'v':
      assert register.number in self.vector
      self.vector.remove(register.number)
    elif register.register_type == 'r':
      assert register.number in self.general
      self.general.remove(register.number)
    else:
      raise RegisterAllocationError('Register not allocated: %s%d' %
                                    (register.register_type, register.number))

  def FreeRegisters(self, registers):
    for register in registers:
      self.FreeRegister(register)
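
# Example (illustrative): registers are allocated by kind and freed when a
# generator is done with them:
#
#   registers = _NeonRegisters64Bit()
#   acc = registers.QuadRegister()        # v0, 128-bit
#   tmp = registers.DoubleRegister()      # v1, 64-bit
#   addr = registers.GeneralRegister()    # x0
#   registers.FreeRegisters([acc, tmp])   # v0/v1 reusable; still clobbered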


class NeonEmitter64(object):
  """Emits 64-bit ARM/NEON assembly opcodes."""

  def __init__(self, debug=False):
    self.ops = {}
    self.indent = ''
    self.debug = debug

  def PushIndent(self, delta_indent='  '):
    self.indent += delta_indent

  def PopIndent(self, delta=2):
    self.indent = self.indent[:-delta]

  def EmitIndented(self, what):
    print(self.indent + what)

  def PushOp(self, op):
    if op in self.ops:
      self.ops[op] += 1
    else:
      self.ops[op] = 1

  def ClearCounters(self):
    self.ops.clear()

  def EmitNewline(self):
    print('')

  def EmitPreprocessor1(self, op, param):
    print('#%s %s' % (op, param))

  def EmitPreprocessor(self, op):
    print('#%s' % op)

  def EmitInclude(self, include):
    self.EmitPreprocessor1('include', include)

  def EmitCall1(self, function, param):
    self.EmitIndented('%s(%s);' % (function, param))

  def EmitAssert(self, assert_expression):
    if self.debug:
      self.EmitCall1('assert', assert_expression)

  def EmitHeaderBegin(self, header_name, includes):
    self.EmitPreprocessor1('ifndef', (header_name + '_H_').upper())
    self.EmitPreprocessor1('define', (header_name + '_H_').upper())
    self.EmitNewline()
    if includes:
      for include in includes:
        self.EmitInclude(include)
      self.EmitNewline()

  def EmitHeaderEnd(self):
    self.EmitPreprocessor('endif')

  def EmitCode(self, code):
    self.EmitIndented('%s;' % code)

  def EmitFunctionBeginA(self, function_name, params, return_type):
    self.EmitIndented('%s %s(%s) {' %
                      (return_type, function_name,
                       ', '.join(['%s %s' % (t, n) for (t, n) in params])))
    self.PushIndent()

  def EmitFunctionEnd(self):
    self.PopIndent()
    self.EmitIndented('}')

  def EmitAsmBegin(self):
    self.EmitIndented('asm volatile(')
    self.PushIndent()

  def EmitAsmMapping(self, elements):
    if elements:
      self.EmitIndented(': ' + ', '.join(
          ['[%s] "%s"(%s)' % (k, v[1], v[0]) for (k, v) in elements]))
    else:
      self.EmitIndented(':')

  def EmitClobbers(self, elements):
    if elements:
      self.EmitIndented(': ' + ', '.join(['"%s"' % c for c in elements]))
    else:
      self.EmitIndented(':')

  def EmitAsmEnd(self, registers):
    self.EmitAsmMapping(registers.MappedOutputParameters())
    self.EmitAsmMapping(registers.MappedParameters())
    self.EmitClobbers(registers.Clobbers() + ['cc', 'memory'])
    self.PopIndent()
    self.EmitIndented(');')

  def EmitComment(self, comment):
    self.EmitIndented('// ' + comment)

  def EmitNumericalLabel(self, label):
    self.EmitIndented('"%d:"' % label)

  def EmitOp1(self, op, param1):
    self.PushOp(op)
    self.EmitIndented('"%s %s\\n"' % (op, param1))

  def EmitOp2(self, op, param1, param2):
    self.PushOp(op)
    self.EmitIndented('"%s %s, %s\\n"' % (op, param1, param2))

  def EmitOp3(self, op, param1, param2, param3):
    self.PushOp(op)
    self.EmitIndented('"%s %s, %s, %s\\n"' % (op, param1, param2, param3))

  def EmitAdd(self, destination, source, param):
    self.EmitOp3('add', destination, source, param)

  def EmitSubs(self, destination, source, param):
    self.EmitOp3('subs', destination, source, param)

  def EmitSub(self, destination, source, param):
    self.EmitOp3('sub', destination, source, param)

  def EmitMul(self, destination, source, param):
    self.EmitOp3('mul', destination, source, param)

  def EmitMov(self, param1, param2):
    self.EmitOp2('mov', param1, param2)

  def EmitVMovl(self, mov_type, destination, source):
    wide_type = _WideType(mov_type)
    destination = _AppendType(wide_type, destination)
    source = _AppendType(mov_type, _Cast(source.register_bits // 2, source))
    if _UnsignedType(mov_type):
      self.EmitOp2('uxtl', destination, source)
    else:
      self.EmitOp2('sxtl', destination, source)
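
  # Example (illustrative): for 128-bit registers v0 and v1,
  # EmitVMovl('s8', v0, v1) emits "sxtl v0.8h, v1.8b\n" (sign-extend the
  # low 8 bytes of v1 into eight 16-bit lanes); 'u8' would emit uxtl.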

  def EmitVMovl2(self, mov_type, destination_1, destination_2, source):
    wide_type = _WideType(mov_type)
    if (destination_1.register_bits != source.register_bits or
        destination_2.register_bits != source.register_bits):
      raise ArgumentError('Register sizes do not match.')
    if _UnsignedType(mov_type):
      self.EmitOp2('uxtl2',
                   _AppendType(wide_type, destination_2),
                   _AppendType(mov_type, source))
      self.EmitOp2('uxtl',
                   _AppendType(wide_type, destination_1),
                   _AppendType(mov_type,
                               _Cast(source.register_bits // 2, source)))
    else:
      self.EmitOp2('sxtl2',
                   _AppendType(wide_type, destination_2),
                   _AppendType(mov_type, source))
      self.EmitOp2('sxtl',
                   _AppendType(wide_type, destination_1),
                   _AppendType(mov_type,
                               _Cast(source.register_bits // 2, source)))

  def EmitVMax(self, max_type, destination, source_1, source_2):
    if _UnsignedType(max_type):
      self.EmitOp3('umax',
                   _AppendType(max_type, destination),
                   _AppendType(max_type, source_1),
                   _AppendType(max_type, source_2))
    else:
      self.EmitOp3('smax',
                   _AppendType(max_type, destination),
                   _AppendType(max_type, source_1),
                   _AppendType(max_type, source_2))

  def EmitVMin(self, min_type, destination, source_1, source_2):
    if _UnsignedType(min_type):
      self.EmitOp3('umin',
                   _AppendType(min_type, destination),
                   _AppendType(min_type, source_1),
                   _AppendType(min_type, source_2))
    else:
      self.EmitOp3('smin',
                   _AppendType(min_type, destination),
                   _AppendType(min_type, source_1),
                   _AppendType(min_type, source_2))

  def EmitBeqBack(self, label):
    self.EmitOp1('beq', '%db' % label)

  def EmitBeqFront(self, label):
    self.EmitOp1('beq', '%df' % label)

  def EmitBgtBack(self, label):
    self.EmitOp1('bgt', '%db' % label)

  def EmitBgtFront(self, label):
    self.EmitOp1('bgt', '%df' % label)

  def EmitBleBack(self, label):
    self.EmitOp1('ble', '%db' % label)

  def EmitBleFront(self, label):
    self.EmitOp1('ble', '%df' % label)

  def EmitBneBack(self, label):
    self.EmitOp1('bne', '%db' % label)

  def EmitBneFront(self, label):
    self.EmitOp1('bne', '%df' % label)

  def EmitVAdd(self, add_type, destination, source_1, source_2):
    destination, source_1, source_2 = _MakeCompatibleDown(destination, source_1,
                                                          source_2)
    if _FloatType(add_type):
      self.EmitOp3('fadd',
                   _AppendType(add_type, destination),
                   _AppendType(add_type, source_1),
                   _AppendType(add_type, source_2))
    else:
      self.EmitOp3('add',
                   _AppendType(add_type, destination),
                   _AppendType(add_type, source_1),
                   _AppendType(add_type, source_2))

  def EmitVAddw(self, add_type, destination, source_1, source_2):
    wide_type = _WideType(add_type)
    destination = _AppendType(wide_type, destination)
    source_1 = _AppendType(wide_type, source_1)
    source_2 = _AppendType(add_type, source_2)
    if _UnsignedType(add_type):
      self.EmitOp3('uaddw', destination, source_1, source_2)
    else:
      self.EmitOp3('saddw', destination, source_1, source_2)

  def EmitVSub(self, sub_type, destination, source_1, source_2):
    destination, source_1, source_2 = _MakeCompatibleDown(destination, source_1,
                                                          source_2)
    if _FloatType(sub_type):
      self.EmitOp3('fsub',
                   _AppendType(sub_type, destination),
                   _AppendType(sub_type, source_1),
                   _AppendType(sub_type, source_2))
    else:
      self.EmitOp3('sub',
                   _AppendType(sub_type, destination),
                   _AppendType(sub_type, source_1),
                   _AppendType(sub_type, source_2))

  def EmitVCvt(self, cvt_to, cvt_from, destination, source):
    if cvt_to == 'f32' and cvt_from == 's32':
      self.EmitOp2('scvtf',
                   _AppendType('f32', destination), _AppendType('s32', source))
    elif cvt_to == 'f32' and cvt_from == 'u32':
      self.EmitOp2('ucvtf',
                   _AppendType('f32', destination), _AppendType('u32', source))
    elif cvt_to == 's32' and cvt_from == 'f32':
      self.EmitOp2('fcvtzs',
                   _AppendType('s32', destination), _AppendType('f32', source))
    else:
      raise ArgumentError('Convert not supported, to: %s from: %s' % (cvt_to,
                                                                      cvt_from))
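
  # Example (illustrative): EmitVCvt('f32', 's32', v0, v1) with 128-bit
  # registers emits "scvtf v0.4s, v1.4s\n"; only the three conversions
  # spelled out above are supported.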

  def EmitVDup(self, dup_type, destination, source):
    if (isinstance(source, _GeneralRegister) or
        isinstance(source, _MappedParameter)):
      self.EmitOp2('dup',
                   _AppendType(dup_type, destination),
                   _Cast(_TypeBits(dup_type), source))
    else:
      self.EmitOp2('dup',
                   _AppendType(dup_type, destination),
                   _AppendType(dup_type, source))

  def EmitVMov(self, mov_type, destination, source):
    if isinstance(source, _ImmediateConstant):
      self.EmitOp2('movi', _AppendType(mov_type, destination), source)
    elif (isinstance(source, _GeneralRegister) or
          isinstance(source, _MappedParameter)):
      self.EmitOp2('mov',
                   _AppendType(mov_type, destination),
                   _Cast(_TypeBits(mov_type), source))
    else:
      self.EmitOp2('mov', _AppendType(8, destination), _AppendType(8, source))

  def EmitVQmovn(self, mov_type, destination, source):
    narrow_type = _NarrowType(mov_type)
    if destination.register_bits * 2 == source.register_bits:
      self.EmitOp2('sqxtn',
                   _AppendType(narrow_type, destination),
                   _AppendType(mov_type, source))
    elif destination.register_bits == source.register_bits:
      self.EmitOp2('sqxtn',
                   _AppendType(narrow_type,
                               _Cast(destination.register_bits // 2,
                                     destination)),
                   _AppendType(mov_type, source))

  def EmitVQmovn2(self, mov_type, destination, source_1, source_2):
    narrow_type = _NarrowType(mov_type)
    if (destination.register_bits != source_1.register_bits or
        destination.register_bits != source_2.register_bits):
      raise ArgumentError('Register sizes do not match.')
    self.EmitOp2('sqxtn',
                 _AppendType(narrow_type,
                             _Cast(destination.register_bits // 2, destination)),
                 _AppendType(mov_type, source_1))
    self.EmitOp2('sqxtn2',
                 _AppendType(narrow_type, destination),
                 _AppendType(mov_type, source_2))
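
  # Example (illustrative): with three 128-bit registers,
  # EmitVQmovn2('s32', v0, v1, v2) emits "sqxtn v0.4h, v1.4s\n" followed by
  # "sqxtn2 v0.8h, v2.4s\n", saturating-narrowing both sources into v0.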

  def EmitVQmovun(self, mov_type, destination, source):
    narrow_type = _NarrowType(mov_type)
    if destination.register_bits * 2 == source.register_bits:
      self.EmitOp2('sqxtun',
                   _AppendType(narrow_type, destination),
                   _AppendType(mov_type, source))
    elif destination.register_bits == source.register_bits:
      self.EmitOp2('sqxtun',
                   _AppendType(narrow_type,
                               _Cast(destination.register_bits // 2,
                                     destination)),
                   _AppendType(mov_type, source))

  def EmitVQmovun2(self, mov_type, destination, source_1, source_2):
    narrow_type = _NarrowType(mov_type)
    if (destination.register_bits != source_1.register_bits or
        destination.register_bits != source_2.register_bits):
      raise ArgumentError('Register sizes do not match.')
    self.EmitOp2('sqxtun',
                 _AppendType(narrow_type,
                             _Cast(destination.register_bits // 2, destination)),
                 _AppendType(mov_type, source_1))
    self.EmitOp2('sqxtun2',
                 _AppendType(narrow_type, destination),
                 _AppendType(mov_type, source_2))

  def EmitVMul(self, mul_type, destination, source_1, source_2):
    destination, source_1, source_2 = _MakeCompatibleDown(destination, source_1,
                                                          source_2)
    if _FloatType(mul_type):
      self.EmitOp3('fmul',
                   _AppendType(mul_type, destination),
                   _AppendType(mul_type, source_1),
                   _AppendType(mul_type, source_2))
    else:
      self.EmitOp3('mul',
                   _AppendType(mul_type, destination),
                   _AppendType(mul_type, source_1),
                   _AppendType(mul_type, source_2))

  def EmitVMulScalar(self, mul_type, destination, source_1, source_2):
    self.EmitOp3('mul',
                 _AppendType(mul_type, destination),
                 _AppendType(mul_type, source_1),
                 _AppendType(mul_type, source_2))

  def EmitVMull(self, mul_type, destination, source_1, source_2):
    wide_type = _WideType(mul_type)
    if _UnsignedType(mul_type):
      self.EmitOp3('umull',
                   _AppendType(wide_type, destination),
                   _AppendType(mul_type, source_1),
                   _AppendType(mul_type, source_2))
    else:
      self.EmitOp3('smull',
                   _AppendType(wide_type, destination),
                   _AppendType(mul_type, source_1),
                   _AppendType(mul_type, source_2))

  def EmitVPadd(self, add_type, destination, source_1, source_2):
    self.EmitOp3('addp',
                 _AppendType(add_type, destination),
                 _AppendType(add_type, source_1),
                 _AppendType(add_type, source_2))

  def EmitVPaddl(self, add_type, destination, source):
    wide_type = _WideType(add_type)
    if _UnsignedType(add_type):
      self.EmitOp2('uaddlp',
                   _AppendType(wide_type, destination),
                   _AppendType(add_type, source))
    else:
      self.EmitOp2('saddlp',
                   _AppendType(wide_type, destination),
                   _AppendType(add_type, source))

  def EmitVPadal(self, add_type, destination, source):
    wide_type = _WideType(add_type)
    if _UnsignedType(add_type):
      self.EmitOp2('uadalp',
                   _AppendType(wide_type, destination),
                   _AppendType(add_type, source))
    else:
      self.EmitOp2('sadalp',
                   _AppendType(wide_type, destination),
                   _AppendType(add_type, source))

  def EmitLdr(self, register, value):
    self.EmitOp2('ldr', _Cast(32, register), _Cast(None, value))

  def EmitVLoad(self, load_no, load_type, destination, source):
    self.EmitVLoadA(load_no, load_type, [destination], source)

  def EmitVLoadA(self, load_no, load_type, destinations, source):
    if source.dereference_increment:
      increment = sum(
          [_LoadStoreSize(destination) for destination in destinations]) // 8
      self.EmitVLoadAPostIncrement(load_no, load_type, destinations, source,
                                   self.ImmediateConstant(increment))
    else:
      self.EmitVLoadAPostIncrement(load_no, load_type, destinations, source,
                                   None)

  def EmitVLoadAPostIncrement(self, load_no, load_type, destinations, source,
                              increment):
    """Generate assembly to load memory to registers and increment source."""
    if len(destinations) == 1 and destinations[0].lane == -1:
      destination = '{%s}' % _AppendType(load_type, destinations[0])
      if increment:
        self.EmitOp3('ld%dr' % load_no, destination, source, increment)
      else:
        self.EmitOp2('ld%dr' % load_no, destination, source)
      return

    destination_list = _RegisterList(load_type, destinations)
    if increment:
      self.EmitOp3('ld%d' % load_no, destination_list, source, increment)
    else:
      self.EmitOp2('ld%d' % load_no, destination_list, source)
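
  # Example (illustrative): with a post-incrementing source built via
  # DereferenceIncrement over x0, EmitVLoadA(1, 32, [v0], source) emits
  # "ld1 {v0.4s}, [x0], #16\n"; without the increment it emits
  # "ld1 {v0.4s}, [x0]\n".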

  def EmitVLoadAE(self,
                  load_type,
                  elem_count,
                  destinations,
                  source,
                  alignment=None):
    """Generate assembly to load an array of elements of given size."""
    bits_to_load = load_type * elem_count
    min_bits = min([destination.register_bits for destination in destinations])
    max_bits = max([destination.register_bits for destination in destinations])

    if min_bits != max_bits:
      raise ArgumentError('Cannot mix double and quad loads.')

    if len(destinations) * min_bits < bits_to_load:
      raise ArgumentError('Too few destinations: %d to load %d bits.' %
                          (len(destinations), bits_to_load))

    leftover_loaded = 0
    while bits_to_load > 0:
      if bits_to_load >= 4 * min_bits:
        self.EmitVLoadA(1, 32, destinations[:4],
                        self.DereferenceIncrement(source, alignment))
        bits_to_load -= 4 * min_bits
        destinations = destinations[4:]
      elif bits_to_load >= 3 * min_bits:
        self.EmitVLoadA(1, 32, destinations[:3],
                        self.DereferenceIncrement(source, alignment))
        bits_to_load -= 3 * min_bits
        destinations = destinations[3:]
      elif bits_to_load >= 2 * min_bits:
        self.EmitVLoadA(1, 32, destinations[:2],
                        self.DereferenceIncrement(source, alignment))
        bits_to_load -= 2 * min_bits
        destinations = destinations[2:]
      elif bits_to_load >= min_bits:
        self.EmitVLoad(1, 32, destinations[0],
                       self.DereferenceIncrement(source, alignment))
        bits_to_load -= min_bits
        destinations = destinations[1:]
      elif bits_to_load >= 64:
        self.EmitVLoad(1, 32,
                       _Cast(64, destinations[0]),
                       self.DereferenceIncrement(source))
        bits_to_load -= 64
        leftover_loaded += 64
      elif bits_to_load >= 32:
        self.EmitVLoad(1, 32,
                       self.Lane(32, destinations[0], leftover_loaded // 32),
                       self.DereferenceIncrement(source))
        bits_to_load -= 32
        leftover_loaded += 32
      elif bits_to_load >= 16:
        self.EmitVLoad(1, 16,
                       self.Lane(16, destinations[0], leftover_loaded // 16),
                       self.DereferenceIncrement(source))
        bits_to_load -= 16
        leftover_loaded += 16
      elif bits_to_load == 8:
        self.EmitVLoad(1, 8,
                       self.Lane(8, destinations[0], leftover_loaded // 8),
                       self.DereferenceIncrement(source))
        bits_to_load -= 8
        leftover_loaded += 8
      else:
        raise ArgumentError('Wrong leftover: %d' % bits_to_load)
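
  # Example (illustrative): EmitVLoadAE(32, 12, [v0, v1, v2], addr) has
  # 384 bits to load into three quad registers, so it emits a single
  # "ld1 {v0.4s, v1.4s, v2.4s}, [x0], #48\n"; any tail narrower than a
  # register is loaded 64/32/16/8 bits at a time into individual lanes.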

  def EmitVLoadE(self, load_type, count, destination, source, alignment=None):
    self.EmitVLoadAE(load_type, count, [destination], source, alignment)

  def EmitVLoadAllLanes(self, load_no, load_type, destination, source):
    new_destination = destination.Copy()
    new_destination.lane = -1
    new_destination.lane_bits = load_type
    self.EmitVLoad(load_no, load_type, new_destination, source)

  def EmitVLoadOffset(self, load_no, load_type, destination, source, offset):
    self.EmitVLoadOffsetA(load_no, load_type, [destination], source, offset)

  def EmitVLoadOffsetA(self, load_no, load_type, destinations, source, offset):
    assert len(destinations) <= 4
    self.EmitOp3('ld%d' % load_no,
                 _RegisterList(load_type, destinations), source, offset)

  def EmitPld(self, load_address_register):
    self.EmitOp2('prfm', 'pldl1keep', '[%s]' % load_address_register)

  def EmitPldOffset(self, load_address_register, offset):
    self.EmitOp2('prfm', 'pldl1keep',
                 '[%s, %s]' % (load_address_register, offset))

  def EmitVShl(self, shift_type, destination, source, shift):
    self.EmitOp3('sshl',
                 _AppendType(shift_type, destination),
                 _AppendType(shift_type, source), _AppendType('i32', shift))

  def EmitVStore(self, store_no, store_type, source, destination):
    self.EmitVStoreA(store_no, store_type, [source], destination)

  def EmitVStoreA(self, store_no, store_type, sources, destination):
    if destination.dereference_increment:
      increment = sum([_LoadStoreSize(source) for source in sources]) // 8
      self.EmitVStoreAPostIncrement(store_no, store_type, sources, destination,
                                    self.ImmediateConstant(increment))
    else:
      self.EmitVStoreAPostIncrement(store_no, store_type, sources, destination,
                                    None)

  def EmitVStoreAPostIncrement(self, store_no, store_type, sources, destination,
                               increment):
    source_list = _RegisterList(store_type, sources)
    if increment:
      self.EmitOp3('st%d' % store_no, source_list, destination, increment)
    else:
      self.EmitOp2('st%d' % store_no, source_list, destination)

  def EmitVStoreAE(self,
                   store_type,
                   elem_count,
                   sources,
                   destination,
                   alignment=None):
    """Generate assembly to store an array of elements of given size."""
    bits_to_store = store_type * elem_count
    min_bits = min([source.register_bits for source in sources])
    max_bits = max([source.register_bits for source in sources])

    if min_bits != max_bits:
      raise ArgumentError('Cannot mix double and quad stores.')

    if len(sources) * min_bits < bits_to_store:
      raise ArgumentError('Too few sources: %d to store %d bits.' %
                          (len(sources), bits_to_store))

    leftover_stored = 0
    while bits_to_store > 0:
      if bits_to_store >= 4 * min_bits:
        self.EmitVStoreA(1, 32, sources[:4],
                         self.DereferenceIncrement(destination, alignment))
        bits_to_store -= 4 * min_bits
        sources = sources[4:]
      elif bits_to_store >= 3 * min_bits:
        self.EmitVStoreA(1, 32, sources[:3],
                         self.DereferenceIncrement(destination, alignment))
        bits_to_store -= 3 * min_bits
        sources = sources[3:]
      elif bits_to_store >= 2 * min_bits:
        self.EmitVStoreA(1, 32, sources[:2],
                         self.DereferenceIncrement(destination, alignment))
        bits_to_store -= 2 * min_bits
        sources = sources[2:]
      elif bits_to_store >= min_bits:
        self.EmitVStore(1, 32, sources[0],
                        self.DereferenceIncrement(destination, alignment))
        bits_to_store -= min_bits
        sources = sources[1:]
      elif bits_to_store >= 64:
        self.EmitVStore(1, 32,
                        _Cast(64, sources[0]),
                        self.DereferenceIncrement(destination, alignment))
        bits_to_store -= 64
        leftover_stored += 64
      elif bits_to_store >= 32:
        self.EmitVStore(1, 32,
                        self.Lane(32, sources[0], leftover_stored // 32),
                        self.DereferenceIncrement(destination))
        bits_to_store -= 32
        leftover_stored += 32
      elif bits_to_store >= 16:
        self.EmitVStore(1, 16,
                        self.Lane(16, sources[0], leftover_stored // 16),
                        self.DereferenceIncrement(destination))
        bits_to_store -= 16
        leftover_stored += 16
      elif bits_to_store >= 8:
        self.EmitVStore(1, 8,
                        self.Lane(8, sources[0], leftover_stored // 8),
                        self.DereferenceIncrement(destination))
        bits_to_store -= 8
        leftover_stored += 8
      else:
        raise ArgumentError('Wrong leftover: %d' % bits_to_store)

  def EmitVStoreE(self, store_type, count, source, destination, alignment=None):
    self.EmitVStoreAE(store_type, count, [source], destination, alignment)

  def EmitVStoreOffset(self, store_no, store_type, source, destination, offset):
    self.EmitVStoreOffsetA(store_no, store_type, [source], destination, offset)

  def EmitVStoreOffsetA(self, store_no, store_type, sources, destination,
                        offset):
    self.EmitOp3('st%d' % store_no,
                 _RegisterList(store_type, sources), destination, offset)

  def EmitVStoreOffsetE(self, store_type, count, source, destination, offset):
    if store_type != 32:
      raise ArgumentError('Unsupported store_type: %d' % store_type)

    if count == 1:
      self.EmitVStoreOffset(1, 32,
                            self.Lane(32, source, 0),
                            self.Dereference(destination, None), offset)
    elif count == 2:
      self.EmitVStoreOffset(1, 32,
                            _Cast(64, source),
                            self.Dereference(destination, None), offset)
    elif count == 3:
      self.EmitVStore(1, 32,
                      _Cast(64, source),
                      self.DereferenceIncrement(destination, None))
      self.EmitVStoreOffset(1, 32,
                            self.Lane(32, source, 2),
                            self.Dereference(destination, None), offset)
      self.EmitSub(destination, destination, self.ImmediateConstant(8))
    elif count == 4:
      self.EmitVStoreOffset(1, 32, source,
                            self.Dereference(destination, None), offset)
    else:
      raise ArgumentError('Too many elements: %d' % count)

  def EmitVSumReduce(self, reduce_type, elem_count, reduce_count, destinations,
                     sources):
    """Generate assembly to perform n-fold horizontal sum reduction."""
    if reduce_type != 'u32':
      raise ArgumentError('Unsupported reduce: %s' % reduce_type)

    if (elem_count + 3) // 4 > len(destinations):
      raise ArgumentError('Too few destinations: %d (%d needed)' %
                          (len(destinations), (elem_count + 3) // 4))

    if elem_count * reduce_count > len(sources) * 4:
      raise ArgumentError('Too few sources: %d' % len(sources))

    if reduce_count <= 1:
      raise ArgumentError('Unsupported reduce_count: %d' % reduce_count)

    sources = [_Cast(128, source) for source in sources]
    destinations = [_Cast(128, destination) for destination in destinations]

    while reduce_count > 1:
      if len(sources) % 2 == 1:
        sources.append(sources[-1])

      if reduce_count == 2:
        for i in range(len(destinations)):
          self.EmitVPadd(reduce_type, destinations[i], sources[2 * i],
                         sources[2 * i + 1])
        return
      else:
        sources_2 = []
        for i in range(len(sources) // 2):
          self.EmitVPadd(reduce_type, sources[2 * i], sources[2 * i],
                         sources[2 * i + 1])
          sources_2.append(sources[2 * i])
        reduce_count //= 2
        sources = sources_2
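
  # Example (illustrative): EmitVSumReduce('u32', 4, 4, [v4], [v0, v1, v2, v3])
  # pairwise-adds down to one register: "addp v0.4s, v0.4s, v1.4s\n",
  # "addp v2.4s, v2.4s, v3.4s\n", then "addp v4.4s, v0.4s, v2.4s\n",
  # leaving four 32-bit sums in v4.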

  def EmitVUzp1(self, uzp_type, destination, source_1, source_2):
    self.EmitOp3('uzp1',
                 _AppendType(uzp_type, destination),
                 _AppendType(uzp_type, source_1),
                 _AppendType(uzp_type, source_2))

  def EmitVUzp2(self, uzp_type, destination, source_1, source_2):
    self.EmitOp3('uzp2',
                 _AppendType(uzp_type, destination),
                 _AppendType(uzp_type, source_1),
                 _AppendType(uzp_type, source_2))

  def EmitVUzp(self, uzp_type, destination_1, destination_2, source_1,
               source_2):
    self.EmitVUzp1(uzp_type, destination_1, source_1, source_2)
    self.EmitVUzp2(uzp_type, destination_2, source_1, source_2)

  def EmitVTrn1(self, trn_type, destination, source_1, source_2):
    self.EmitOp3('trn1',
                 _AppendType(trn_type, destination),
                 _AppendType(trn_type, source_1),
                 _AppendType(trn_type, source_2))

  def EmitVTrn2(self, trn_type, destination, source_1, source_2):
    self.EmitOp3('trn2',
                 _AppendType(trn_type, destination),
                 _AppendType(trn_type, source_1),
                 _AppendType(trn_type, source_2))

  def EmitVTrn(self, trn_type, destination_1, destination_2, source_1,
               source_2):
    self.EmitVTrn1(trn_type, destination_1, source_1, source_2)
    self.EmitVTrn2(trn_type, destination_2, source_1, source_2)
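
  # Example (illustrative): for 128-bit registers, EmitVTrn(16, d0, d1, s0,
  # s1) emits a trn1/trn2 pair ("trn1 v0.8h, v2.8h, v3.8h\n" etc.) that
  # interleaves even and odd 16-bit lanes of the two sources;
  # EmitLoadColBlock below builds its block transposes out of these.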

  def EmitColBlockStride(self, cols, stride, new_stride):
    assert cols in [1, 2, 3, 4, 5, 6, 7, 8]
    if cols in [5, 6, 7]:
      self.EmitSub(new_stride, stride, self.ImmediateConstant(4))

  def EmitLoadColBlock(self, registers, load_type, cols, elements, block,
                       input_address, stride):
    assert cols == len(block)
    assert load_type == 8

    input_deref = self.Dereference(input_address, None)
    input_deref_increment = self.DereferenceIncrement(input_address, None)

    if cols == 1:
      for i in range(elements):
        self.EmitVLoadOffset(1, 8,
                             self.Lane(8, block[0], i), input_deref, stride)
      self.EmitPld(input_address)
      return block
    elif cols == 2:
      temp = [registers.DoubleRegister() for unused_i in range(2)]
      for i in range(elements):
        self.EmitVLoadOffset(1, 16,
                             self.Lane(16, block[i // 4], i % 4), input_deref,
                             stride)
      self.EmitPld(input_address)
      self.EmitVUzp(8, temp[0], temp[1], block[0], block[1])
      registers.FreeRegisters(block)
      return temp
    elif cols == 3:
      for i in range(elements):
        self.EmitVLoadOffsetA(3, 8, [self.Lane(8, row, i) for row in block],
                              input_deref, stride)
      self.EmitPld(input_address)
      return block
    elif cols == 4:
      temp = [registers.DoubleRegister() for unused_i in range(4)]
      for i in range(elements):
        self.EmitVLoadOffset(1, 32,
                             self.Lane(32, block[i % 4], i // 4), input_deref,
                             stride)
      self.EmitPld(input_address)
      self.EmitVTrn(16, temp[0], temp[2], block[0], block[2])
      self.EmitVTrn(16, temp[1], temp[3], block[1], block[3])
      self.EmitVTrn(8, block[0], block[1], temp[0], temp[1])
      self.EmitVTrn(8, block[2], block[3], temp[2], temp[3])
      registers.FreeRegisters(temp)
      return block
    elif cols == 5:
      temp = [registers.DoubleRegister() for unused_i in range(4)]
      for i in range(elements):
        self.EmitVLoad(1, 32,
                       self.Lane(32, block[i % 4], i // 4),
                       input_deref_increment)
        self.EmitVLoadOffset(1, 8,
                             self.Lane(8, block[4], i), input_deref, stride)
      self.EmitPld(input_address)
      self.EmitVTrn(16, temp[0], temp[2], block[0], block[2])
      self.EmitVTrn(16, temp[1], temp[3], block[1], block[3])
      self.EmitVTrn(8, block[0], block[1], temp[0], temp[1])
      self.EmitVTrn(8, block[2], block[3], temp[2], temp[3])
      registers.FreeRegisters(temp)
      return block
    elif cols == 6:
      temp = [registers.DoubleRegister() for unused_i in range(6)]
      for i in range(elements):
        self.EmitVLoad(1, 32,
                       self.Lane(32, block[i % 4], i // 4),
                       input_deref_increment)
        self.EmitVLoadOffset(1, 16,
                             self.Lane(16, block[4 + i // 4], i % 4),
                             input_deref, stride)
      self.EmitPld(input_address)
      self.EmitVTrn(16, temp[0], temp[2], block[0], block[2])
      self.EmitVTrn(16, temp[1], temp[3], block[1], block[3])
      self.EmitVUzp(8, temp[4], temp[5], block[4], block[5])
      self.EmitVTrn(8, block[0], block[1], temp[0], temp[1])
      self.EmitVTrn(8, block[2], block[3], temp[2], temp[3])
      registers.FreeRegisters(
          [block[4], block[5], temp[0], temp[1], temp[2], temp[3]])
      return [block[0], block[1], block[2], block[3], temp[4], temp[5]]
    elif cols == 7:
      temp = [registers.DoubleRegister() for unused_i in range(4)]
      for i in range(elements):
        self.EmitVLoad(1, 32,
                       self.Lane(32, block[i % 4], i // 4),
                       input_deref_increment)
        self.EmitVLoadOffsetA(3, 8,
                              [self.Lane(8, row, i) for row in block[4:]],
                              input_deref, stride)
      self.EmitPld(input_address)
      self.EmitVTrn1(16, temp[0], block[0], block[2])
      self.EmitVTrn2(16, temp[2], block[0], block[2])
      self.EmitVTrn1(16, temp[1], block[1], block[3])
      self.EmitVTrn2(16, temp[3], block[1], block[3])
      self.EmitVTrn1(8, block[0], temp[0], temp[1])
      self.EmitVTrn2(8, block[1], temp[0], temp[1])
      self.EmitVTrn1(8, block[2], temp[2], temp[3])
      self.EmitVTrn2(8, block[3], temp[2], temp[3])
      registers.FreeRegisters(temp)
      return block
    elif cols == 8:
      temp = [registers.DoubleRegister() for unused_i in range(8)]
      for i in range(elements):
        self.EmitVLoadOffset(1, 32, block[i], input_deref, stride)
      self.EmitPld(input_address)
      self.EmitVTrn(8, temp[0], temp[1], block[0], block[1])
      self.EmitVTrn(8, temp[2], temp[3], block[2], block[3])
      self.EmitVTrn(8, temp[4], temp[5], block[4], block[5])
      self.EmitVTrn(8, temp[6], temp[7], block[6], block[7])
      self.EmitVTrn(16, block[0], block[2], temp[0], temp[2])
      self.EmitVTrn(16, block[1], block[3], temp[1], temp[3])
      self.EmitVTrn(16, block[4], block[6], temp[4], temp[6])
      self.EmitVTrn(16, block[5], block[7], temp[5], temp[7])
      self.EmitVTrn(32, temp[0], temp[4], block[0], block[4])
      self.EmitVTrn(32, temp[1], temp[5], block[1], block[5])
      self.EmitVTrn(32, temp[2], temp[6], block[2], block[6])
      self.EmitVTrn(32, temp[3], temp[7], block[3], block[7])
      registers.FreeRegisters(block)
      return temp
    else:
      assert False

  def Dereference(self, value, unused_alignment=None):
    new_value = value.Copy()
    new_value.dereference = True
    return new_value

  def DereferenceIncrement(self, value, alignment=None):
    new_value = self.Dereference(value, alignment).Copy()
    new_value.dereference_increment = True
    return new_value

  def ImmediateConstant(self, value):
    return _ImmediateConstant(value)

  def AllLanes(self, value):
    return '%s[]' % value

  def Lane(self, bits, value, lane):
    new_value = value.Copy()
    if bits * (lane + 1) > new_value.register_bits:
      raise ArgumentError('Lane too big: (%d + 1) x %d > %d' %
                          (lane, bits, new_value.register_bits))
    new_value.lane = lane
    new_value.lane_bits = bits
    return new_value
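
  # Example (illustrative): Lane(32, v0, 2) marks 32-bit lane 2 of v0, so a
  # subsequent load or store renders it as 'v0.s[2]'; lanes past the end of
  # the register (bits * (lane + 1) > register_bits) raise ArgumentError.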

  def CreateRegisters(self):
    return _NeonRegisters64Bit()