1# Copyright 2016 The Gemmlowp Authors. All rights reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14"""64bit ARM/NEON assembly emitter. 15 16Used by code generators to produce ARM assembly with NEON simd code. 17Provides tools for easier register management: named register variable 18allocation/deallocation, and offers a more procedural/structured approach 19to generating assembly. 20 21""" 22 23_WIDE_TYPES = { 24 8: 16, 25 16: 32, 26 32: 64, 27 '8': '16', 28 '16': '32', 29 '32': '64', 30 'i8': 'i16', 31 'i16': 'i32', 32 'i32': 'i64', 33 'u8': 'u16', 34 'u16': 'u32', 35 'u32': 'u64', 36 's8': 's16', 37 's16': 's32', 38 's32': 's64' 39} 40 41_NARROW_TYPES = { 42 64: 32, 43 32: 16, 44 16: 8, 45 '64': '32', 46 '32': '16', 47 '16': '8', 48 'i64': 'i32', 49 'i32': 'i16', 50 'i16': 'i8', 51 'u64': 'u32', 52 'u32': 'u16', 53 'u16': 'u8', 54 's64': 's32', 55 's32': 's16', 56 's16': 's8' 57} 58 59_TYPE_BITS = { 60 8: 8, 61 16: 16, 62 32: 32, 63 64: 64, 64 '8': 8, 65 '16': 16, 66 '32': 32, 67 '64': 64, 68 'i8': 8, 69 'i16': 16, 70 'i32': 32, 71 'i64': 64, 72 'u8': 8, 73 'u16': 16, 74 'u32': 32, 75 'u64': 64, 76 's8': 8, 77 's16': 16, 78 's32': 32, 79 's64': 64, 80 'f32': 32, 81 'f64': 64, 82 'b': 8, 83 'h': 16, 84 's': 32, 85 'd': 64 86} 87 88 89class Error(Exception): 90 """Module level error.""" 91 92 93class RegisterAllocationError(Error): 94 """Cannot alocate registers.""" 95 96 97class LaneError(Error): 98 """Wrong lane number.""" 99 100 101class RegisterSubtypeError(Error): 102 """The register needs to be lane-typed.""" 103 104 105class ArgumentError(Error): 106 """Wrong argument.""" 107 108 109def _AppendType(type_name, register): 110 """Calculates sizes and attaches the type information to the register.""" 111 if register.register_type is not 'v': 112 raise ArgumentError('Only vector registers can have type appended.') 113 114 if type_name in set([8, '8', 'i8', 's8', 'u8']): 115 subtype = 'b' 116 subtype_bits = 8 117 elif type_name in set([16, '16', 'i16', 's16', 'u16']): 118 subtype = 'h' 119 subtype_bits = 16 120 elif type_name in set([32, '32', 'i32', 's32', 'u32', 'f32']): 121 subtype = 's' 122 subtype_bits = 32 123 elif type_name in set([64, '64', 'i64', 's64', 'u64', 'f64']): 124 subtype = 'd' 125 subtype_bits = 64 126 else: 127 raise ArgumentError('Unknown type: %s' % type_name) 128 129 new_register = register.Copy() 130 new_register.register_subtype = subtype 131 new_register.register_subtype_count = register.register_bits / subtype_bits 132 return new_register 133 134 135def _UnsignedType(type_name): 136 return type_name in set(['u8', 'u16', 'u32', 'u64']) 137 138 139def _FloatType(type_name): 140 return type_name in set(['f32', 'f64']) 141 142 143def _WideType(type_name): 144 if type_name in _WIDE_TYPES.keys(): 145 return _WIDE_TYPES[type_name] 146 else: 147 raise ArgumentError('No wide type for: %s' % type_name) 148 149 150def _NarrowType(type_name): 151 if type_name in _NARROW_TYPES.keys(): 152 return _NARROW_TYPES[type_name] 153 else: 154 raise ArgumentError('No narrow type for: %s' % type_name) 155 156 157def _LoadStoreSize(register): 158 if register.lane is None: 159 return register.register_bits 160 else: 161 return register.lane_bits 162 163 164def _MakeCompatibleDown(reg_1, reg_2, reg_3): 165 bits = min([reg_1.register_bits, reg_2.register_bits, reg_3.register_bits]) 166 return (_Cast(bits, reg_1), _Cast(bits, reg_2), _Cast(bits, reg_3)) 167 168 169def _MakeCompatibleUp(reg_1, reg_2, reg_3): 170 bits = max([reg_1.register_bits, reg_2.register_bits, reg_3.register_bits]) 171 return (_Cast(bits, reg_1), _Cast(bits, reg_2), _Cast(bits, reg_3)) 172 173 174def _Cast(bits, reg): 175 if reg.register_bits is bits: 176 return reg 177 else: 178 new_reg = reg.Copy() 179 new_reg.register_bits = bits 180 return new_reg 181 182 183def _TypeBits(type_name): 184 if type_name in _TYPE_BITS.keys(): 185 return _TYPE_BITS[type_name] 186 else: 187 raise ArgumentError('Unknown type: %s' % type_name) 188 189 190def _RegisterList(list_type, registers): 191 lanes = list(set([register.lane for register in registers])) 192 if len(lanes) > 1: 193 raise ArgumentError('Cannot mix lanes on a register list.') 194 typed_registers = [_AppendType(list_type, register) for register in registers] 195 196 if lanes[0] is None: 197 return '{%s}' % ', '.join(map(str, typed_registers)) 198 elif lanes[0] is -1: 199 raise ArgumentError('Cannot construct a list with all lane indexing.') 200 else: 201 typed_registers_nolane = [register.Copy() for register in typed_registers] 202 for register in typed_registers_nolane: 203 register.lane = None 204 register.register_subtype_count = None 205 return '{%s}[%d]' % (', '.join(map(str, typed_registers_nolane)), lanes[0]) 206 207 208class _GeneralRegister(object): 209 """Arm v8 general register: (x|w)n.""" 210 211 def __init__(self, 212 register_bits, 213 number, 214 dereference=False, 215 dereference_increment=False): 216 self.register_type = 'r' 217 self.register_bits = register_bits 218 self.number = number 219 self.dereference = dereference 220 self.dereference_increment = dereference_increment 221 222 def Copy(self): 223 return _GeneralRegister(self.register_bits, self.number, self.dereference, 224 self.dereference_increment) 225 226 def __repr__(self): 227 if self.register_bits is 64: 228 text = 'x%d' % self.number 229 elif self.register_bits <= 32: 230 text = 'w%d' % self.number 231 else: 232 raise RegisterSubtypeError('Wrong bits (%d) for general register: %d' % 233 (self.register_bits, self.number)) 234 if self.dereference: 235 return '[%s]' % text 236 else: 237 return text 238 239 240class _MappedParameter(object): 241 """Object representing a C variable mapped to a register.""" 242 243 def __init__(self, 244 name, 245 register_bits=64, 246 dereference=False, 247 dereference_increment=False): 248 self.name = name 249 self.register_bits = register_bits 250 self.dereference = dereference 251 self.dereference_increment = dereference_increment 252 253 def Copy(self): 254 return _MappedParameter(self.name, self.register_bits, self.dereference, 255 self.dereference_increment) 256 257 def __repr__(self): 258 if self.register_bits is None: 259 text = '%%[%s]' % self.name 260 elif self.register_bits is 64: 261 text = '%%x[%s]' % self.name 262 elif self.register_bits <= 32: 263 text = '%%w[%s]' % self.name 264 else: 265 raise RegisterSubtypeError('Wrong bits (%d) for mapped parameter: %s' % 266 (self.register_bits, self.name)) 267 if self.dereference: 268 return '[%s]' % text 269 else: 270 return text 271 272 273class _VectorRegister(object): 274 """Arm v8 vector register Vn.TT.""" 275 276 def __init__(self, 277 register_bits, 278 number, 279 register_subtype=None, 280 register_subtype_count=None, 281 lane=None, 282 lane_bits=None): 283 self.register_type = 'v' 284 self.register_bits = register_bits 285 self.number = number 286 self.register_subtype = register_subtype 287 self.register_subtype_count = register_subtype_count 288 self.lane = lane 289 self.lane_bits = lane_bits 290 291 def Copy(self): 292 return _VectorRegister(self.register_bits, self.number, 293 self.register_subtype, self.register_subtype_count, 294 self.lane, self.lane_bits) 295 296 def __repr__(self): 297 if self.register_subtype is None: 298 raise RegisterSubtypeError('Register: %s%d has no lane types defined.' % 299 (self.register_type, self.number)) 300 if (self.register_subtype_count is None or (self.lane is not None and 301 self.lane is not -1)): 302 typed_name = '%s%d.%s' % (self.register_type, self.number, 303 self.register_subtype) 304 else: 305 typed_name = '%s%d.%d%s' % (self.register_type, self.number, 306 self.register_subtype_count, 307 self.register_subtype) 308 309 if self.lane is None or self.lane is -1: 310 return typed_name 311 elif self.lane >= 0 and self.lane < self.register_subtype_count: 312 return '%s[%d]' % (typed_name, self.lane) 313 else: 314 raise LaneError('Wrong lane: %d for: %s' % (self.lane, typed_name)) 315 316 317class _ImmediateConstant(object): 318 319 def __init__(self, value): 320 self.register_type = 'i' 321 self.value = value 322 323 def Copy(self): 324 return _ImmediateConstant(self.value) 325 326 def __repr__(self): 327 return '#%d' % self.value 328 329 330class _NeonRegisters64Bit(object): 331 """Utility that keeps track of used 32bit ARM/NEON registers.""" 332 333 def __init__(self): 334 self.vector = set() 335 self.vector_ever = set() 336 self.general = set() 337 self.general_ever = set() 338 self.parameters = dict() 339 self.output_parameters = dict() 340 341 def MapParameter(self, parameter, parameter_value=None): 342 if not parameter_value: 343 parameter_value = parameter 344 self.parameters[parameter] = (parameter_value, 'r') 345 return _MappedParameter(parameter) 346 347 def MapMemoryParameter(self, parameter, parameter_value=None): 348 if not parameter_value: 349 parameter_value = parameter 350 self.parameters[parameter] = (parameter_value, 'm') 351 return _MappedParameter(parameter) 352 353 def MapOutputParameter(self, parameter, parameter_value=None): 354 if not parameter_value: 355 parameter_value = parameter 356 self.output_parameters[parameter] = (parameter_value, '+r') 357 return _MappedParameter(parameter) 358 359 def _VectorRegisterNum(self, min_val=0): 360 for i in range(min_val, 32): 361 if i not in self.vector: 362 self.vector.add(i) 363 self.vector_ever.add(i) 364 return i 365 raise RegisterAllocationError('Not enough vector registers.') 366 367 def DoubleRegister(self, min_val=0): 368 return _VectorRegister(64, self._VectorRegisterNum(min_val)) 369 370 def QuadRegister(self, min_val=0): 371 return _VectorRegister(128, self._VectorRegisterNum(min_val)) 372 373 def GeneralRegister(self): 374 for i in range(0, 30): 375 if i not in self.general: 376 self.general.add(i) 377 self.general_ever.add(i) 378 return _GeneralRegister(64, i) 379 raise RegisterAllocationError('Not enough general registers.') 380 381 def MappedParameters(self): 382 return [x for x in self.parameters.items()] 383 384 def MappedOutputParameters(self): 385 return [x for x in self.output_parameters.items()] 386 387 def Clobbers(self): 388 return ( 389 ['x%d' % i 390 for i in self.general_ever] + ['v%d' % i for i in self.vector_ever]) 391 392 def FreeRegister(self, register): 393 if isinstance(register, _MappedParameter): 394 return 395 396 if register.register_type == 'v': 397 assert register.number in self.vector 398 self.vector.remove(register.number) 399 elif register.register_type == 'r': 400 assert register.number in self.general 401 self.general.remove(register.number) 402 else: 403 raise RegisterAllocationError('Register not allocated: %s%d' % 404 (register.register_type, register.number)) 405 406 def FreeRegisters(self, registers): 407 for register in registers: 408 self.FreeRegister(register) 409 410 411class NeonEmitter64(object): 412 """Emits ARM/NEON 64bit assembly opcodes.""" 413 414 def __init__(self, debug=False): 415 self.ops = {} 416 self.indent = '' 417 self.debug = debug 418 419 def PushIndent(self, delta_indent=' '): 420 self.indent += delta_indent 421 422 def PopIndent(self, delta=2): 423 self.indent = self.indent[:-delta] 424 425 def EmitIndented(self, what): 426 print self.indent + what 427 428 def PushOp(self, op): 429 if op in self.ops.keys(): 430 self.ops[op] += 1 431 else: 432 self.ops[op] = 1 433 434 def ClearCounters(self): 435 self.ops.clear() 436 437 def EmitNewline(self): 438 print '' 439 440 def EmitPreprocessor1(self, op, param): 441 print '#%s %s' % (op, param) 442 443 def EmitPreprocessor(self, op): 444 print '#%s' % op 445 446 def EmitInclude(self, include): 447 self.EmitPreprocessor1('include', include) 448 449 def EmitCall1(self, function, param): 450 self.EmitIndented('%s(%s);' % (function, param)) 451 452 def EmitAssert(self, assert_expression): 453 if self.debug: 454 self.EmitCall1('assert', assert_expression) 455 456 def EmitHeaderBegin(self, header_name, includes): 457 self.EmitPreprocessor1('ifndef', (header_name + '_H_').upper()) 458 self.EmitPreprocessor1('define', (header_name + '_H_').upper()) 459 self.EmitNewline() 460 if includes: 461 for include in includes: 462 self.EmitInclude(include) 463 self.EmitNewline() 464 465 def EmitHeaderEnd(self): 466 self.EmitPreprocessor('endif') 467 468 def EmitCode(self, code): 469 self.EmitIndented('%s;' % code) 470 471 def EmitFunctionBeginA(self, function_name, params, return_type): 472 self.EmitIndented('%s %s(%s) {' % 473 (return_type, function_name, 474 ', '.join(['%s %s' % (t, n) for (t, n) in params]))) 475 self.PushIndent() 476 477 def EmitFunctionEnd(self): 478 self.PopIndent() 479 self.EmitIndented('}') 480 481 def EmitAsmBegin(self): 482 self.EmitIndented('asm volatile(') 483 self.PushIndent() 484 485 def EmitAsmMapping(self, elements): 486 if elements: 487 self.EmitIndented(': ' + ', '.join( 488 ['[%s] "%s"(%s)' % (k, v[1], v[0]) for (k, v) in elements])) 489 else: 490 self.EmitIndented(':') 491 492 def EmitClobbers(self, elements): 493 if elements: 494 self.EmitIndented(': ' + ', '.join(['"%s"' % c for c in elements])) 495 else: 496 self.EmitIndented(':') 497 498 def EmitAsmEnd(self, registers): 499 self.EmitAsmMapping(registers.MappedOutputParameters()) 500 self.EmitAsmMapping(registers.MappedParameters()) 501 self.EmitClobbers(registers.Clobbers() + ['cc', 'memory']) 502 self.PopIndent() 503 self.EmitIndented(');') 504 505 def EmitComment(self, comment): 506 self.EmitIndented('// ' + comment) 507 508 def EmitNumericalLabel(self, label): 509 self.EmitIndented('"%d:"' % label) 510 511 def EmitOp1(self, op, param1): 512 self.PushOp(op) 513 self.EmitIndented('"%s %s\\n"' % (op, param1)) 514 515 def EmitOp2(self, op, param1, param2): 516 self.PushOp(op) 517 self.EmitIndented('"%s %s, %s\\n"' % (op, param1, param2)) 518 519 def EmitOp3(self, op, param1, param2, param3): 520 self.PushOp(op) 521 self.EmitIndented('"%s %s, %s, %s\\n"' % (op, param1, param2, param3)) 522 523 def EmitAdd(self, destination, source, param): 524 self.EmitOp3('add', destination, source, param) 525 526 def EmitSubs(self, destination, source, param): 527 self.EmitOp3('subs', destination, source, param) 528 529 def EmitSub(self, destination, source, param): 530 self.EmitOp3('sub', destination, source, param) 531 532 def EmitMul(self, destination, source, param): 533 self.EmitOp3('mul', destination, source, param) 534 535 def EmitMov(self, param1, param2): 536 self.EmitOp2('mov', param1, param2) 537 538 def EmitVMovl(self, mov_type, destination, source): 539 wide_type = _WideType(mov_type) 540 destination = _AppendType(wide_type, destination) 541 source = _AppendType(mov_type, _Cast(source.register_bits / 2, source)) 542 if _UnsignedType(mov_type): 543 self.EmitOp2('uxtl', destination, source) 544 else: 545 self.EmitOp2('sxtl', destination, source) 546 547 def EmitVMovl2(self, mov_type, destination_1, destination_2, source): 548 wide_type = _WideType(mov_type) 549 if (destination_1.register_bits != source.register_bits or 550 destination_2.register_bits != source.register_bits): 551 raise ArgumentError('Register sizes do not match.') 552 if _UnsignedType(mov_type): 553 self.EmitOp2('uxtl2', 554 _AppendType(wide_type, destination_2), 555 _AppendType(mov_type, source)) 556 self.EmitOp2('uxtl', 557 _AppendType(wide_type, destination_1), 558 _AppendType(mov_type, 559 _Cast(source.register_bits / 2, source))) 560 else: 561 self.EmitOp2('sxtl2', 562 _AppendType(wide_type, destination_2), 563 _AppendType(mov_type, source)) 564 self.EmitOp2('sxtl', 565 _AppendType(wide_type, destination_1), 566 _AppendType(mov_type, 567 _Cast(source.register_bits / 2, source))) 568 569 def EmitVMax(self, max_type, destination, source_1, source_2): 570 if _UnsignedType(max_type): 571 self.EmitOp3('umax', 572 _AppendType(max_type, destination), 573 _AppendType(max_type, source_1), 574 _AppendType(max_type, source_2)) 575 else: 576 self.EmitOp3('smax', 577 _AppendType(max_type, destination), 578 _AppendType(max_type, source_1), 579 _AppendType(max_type, source_2)) 580 581 def EmitVMin(self, min_type, destination, source_1, source_2): 582 if _UnsignedType(min_type): 583 self.EmitOp3('umin', 584 _AppendType(min_type, destination), 585 _AppendType(min_type, source_1), 586 _AppendType(min_type, source_2)) 587 else: 588 self.EmitOp3('smin', 589 _AppendType(min_type, destination), 590 _AppendType(min_type, source_1), 591 _AppendType(min_type, source_2)) 592 593 def EmitBeqBack(self, label): 594 self.EmitOp1('beq', '%db' % label) 595 596 def EmitBeqFront(self, label): 597 self.EmitOp1('beq', '%df' % label) 598 599 def EmitBgtBack(self, label): 600 self.EmitOp1('bgt', '%db' % label) 601 602 def EmitBgtFront(self, label): 603 self.EmitOp1('bgt', '%df' % label) 604 605 def EmitBleBack(self, label): 606 self.EmitOp1('ble', '%db' % label) 607 608 def EmitBleFront(self, label): 609 self.EmitOp1('ble', '%df' % label) 610 611 def EmitBneBack(self, label): 612 self.EmitOp1('bne', '%db' % label) 613 614 def EmitBneFront(self, label): 615 self.EmitOp1('bne', '%df' % label) 616 617 def EmitVAdd(self, add_type, destination, source_1, source_2): 618 destination, source_1, source_2 = _MakeCompatibleDown(destination, source_1, 619 source_2) 620 if _FloatType(add_type): 621 self.EmitOp3('fadd', 622 _AppendType(add_type, destination), 623 _AppendType(add_type, source_1), 624 _AppendType(add_type, source_2)) 625 else: 626 self.EmitOp3('add', 627 _AppendType(add_type, destination), 628 _AppendType(add_type, source_1), 629 _AppendType(add_type, source_2)) 630 631 def EmitVAddw(self, add_type, destination, source_1, source_2): 632 wide_type = _WideType(add_type) 633 destination = _AppendType(wide_type, destination) 634 source_1 = _AppendType(wide_type, source_1) 635 source_2 = _AppendType(add_type, source_2) 636 if _UnsignedType(add_type): 637 self.EmitOp3('uaddw', destination, source_1, source_2) 638 else: 639 self.EmitOp3('saddw', destination, source_1, source_2) 640 641 def EmitVSub(self, sub_type, destination, source_1, source_2): 642 destination, source_1, source_2 = _MakeCompatibleDown(destination, source_1, 643 source_2) 644 if _FloatType(sub_type): 645 self.EmitOp3('fsub', 646 _AppendType(sub_type, destination), 647 _AppendType(sub_type, source_1), 648 _AppendType(sub_type, source_2)) 649 else: 650 self.EmitOp3('sub', 651 _AppendType(sub_type, destination), 652 _AppendType(sub_type, source_1), 653 _AppendType(sub_type, source_2)) 654 655 def EmitVCvt(self, cvt_to, cvt_from, destination, source): 656 if cvt_to == 'f32' and cvt_from == 's32': 657 self.EmitOp2('scvtf', 658 _AppendType('f32', destination), _AppendType('s32', source)) 659 elif cvt_to == 'f32' and cvt_from == 'u32': 660 self.EmitOp2('ucvtf', 661 _AppendType('f32', destination), _AppendType('u32', source)) 662 elif cvt_to == 's32' and cvt_from == 'f32': 663 self.EmitOp2('fcvtzs', 664 _AppendType('s32', destination), _AppendType('f32', source)) 665 else: 666 raise ArgumentError('Convert not supported, to: %s from: %s' % (cvt_to, 667 cvt_from)) 668 669 def EmitVDup(self, dup_type, destination, source): 670 if (isinstance(source, _GeneralRegister) or 671 isinstance(source, _MappedParameter)): 672 self.EmitOp2('dup', 673 _AppendType(dup_type, destination), 674 _Cast(_TypeBits(dup_type), source)) 675 else: 676 self.EmitOp2('dup', 677 _AppendType(dup_type, destination), 678 _AppendType(dup_type, source)) 679 680 def EmitVMov(self, mov_type, destination, source): 681 if isinstance(source, _ImmediateConstant): 682 self.EmitOp2('movi', _AppendType(mov_type, destination), source) 683 elif (isinstance(source, _GeneralRegister) or 684 isinstance(source, _MappedParameter)): 685 self.EmitOp2('mov', 686 _AppendType(mov_type, destination), 687 _Cast(_TypeBits(mov_type), source)) 688 else: 689 self.EmitOp2('mov', _AppendType(8, destination), _AppendType(8, source)) 690 691 def EmitVQmovn(self, mov_type, destination, source): 692 narrow_type = _NarrowType(mov_type) 693 if destination.register_bits * 2 == source.register_bits: 694 self.EmitOp2('sqxtn', 695 _AppendType(narrow_type, destination), 696 _AppendType(mov_type, source)) 697 elif destination.register_bits == source.register_bits: 698 self.EmitOp2('sqxtn', 699 _AppendType(narrow_type, 700 _Cast(destination.register_bits / 2, 701 destination)), 702 _AppendType(mov_type, source)) 703 704 def EmitVQmovn2(self, mov_type, destination, source_1, source_2): 705 narrow_type = _NarrowType(mov_type) 706 if (destination.register_bits != source_1.register_bits or 707 destination.register_bits != source_2.register_bits): 708 raise ArgumentError('Register sizes do not match.') 709 self.EmitOp2('sqxtn', 710 _AppendType(narrow_type, 711 _Cast(destination.register_bits / 2, destination)), 712 _AppendType(mov_type, source_1)) 713 self.EmitOp2('sqxtn2', 714 _AppendType(narrow_type, destination), 715 _AppendType(mov_type, source_2)) 716 717 def EmitVQmovun(self, mov_type, destination, source): 718 narrow_type = _NarrowType(mov_type) 719 if destination.register_bits * 2 == source.register_bits: 720 self.EmitOp2('sqxtun', 721 _AppendType(narrow_type, destination), 722 _AppendType(mov_type, source)) 723 elif destination.register_bits == source.register_bits: 724 self.EmitOp2('sqxtun', 725 _AppendType(narrow_type, 726 _Cast(destination.register_bits / 2, 727 destination)), 728 _AppendType(mov_type, source)) 729 730 def EmitVQmovun2(self, mov_type, destination, source_1, source_2): 731 narrow_type = _NarrowType(mov_type) 732 if (destination.register_bits != source_1.register_bits or 733 destination.register_bits != source_2.register_bits): 734 raise ArgumentError('Register sizes do not match.') 735 self.EmitOp2('sqxtun', 736 _AppendType(narrow_type, 737 _Cast(destination.register_bits / 2, destination)), 738 _AppendType(mov_type, source_1)) 739 self.EmitOp2('sqxtun2', 740 _AppendType(narrow_type, destination), 741 _AppendType(mov_type, source_2)) 742 743 def EmitVMul(self, mul_type, destination, source_1, source_2): 744 destination, source_1, source_2 = _MakeCompatibleDown(destination, source_1, 745 source_2) 746 if _FloatType(mul_type): 747 self.EmitOp3('fmul', 748 _AppendType(mul_type, destination), 749 _AppendType(mul_type, source_1), 750 _AppendType(mul_type, source_2)) 751 else: 752 self.EmitOp3('mul', 753 _AppendType(mul_type, destination), 754 _AppendType(mul_type, source_1), 755 _AppendType(mul_type, source_2)) 756 757 def EmitVMulScalar(self, mul_type, destination, source_1, source_2): 758 self.EmitOp3('mul', 759 _AppendType(mul_type, destination), 760 _AppendType(mul_type, source_1), 761 _AppendType(mul_type, source_2)) 762 763 def EmitVMull(self, mul_type, destination, source_1, source_2): 764 wide_type = _WideType(mul_type) 765 if _UnsignedType(mul_type): 766 self.EmitOp3('umull', 767 _AppendType(wide_type, destination), 768 _AppendType(mul_type, source_1), 769 _AppendType(mul_type, source_2)) 770 else: 771 self.EmitOp3('smull', 772 _AppendType(wide_type, destination), 773 _AppendType(mul_type, source_1), 774 _AppendType(mul_type, source_2)) 775 776 def EmitVPadd(self, add_type, destination, source_1, source_2): 777 self.EmitOp3('addp', 778 _AppendType(add_type, destination), 779 _AppendType(add_type, source_1), 780 _AppendType(add_type, source_2)) 781 782 def EmitVPaddl(self, add_type, destination, source): 783 wide_type = _WideType(add_type) 784 if _UnsignedType(add_type): 785 self.EmitOp2('uaddlp', 786 _AppendType(wide_type, destination), 787 _AppendType(add_type, source)) 788 else: 789 self.EmitOp2('saddlp', 790 _AppendType(wide_type, destination), 791 _AppendType(add_type, source)) 792 793 def EmitVPadal(self, add_type, destination, source): 794 wide_type = _WideType(add_type) 795 if _UnsignedType(add_type): 796 self.EmitOp2('uadalp', 797 _AppendType(wide_type, destination), 798 _AppendType(add_type, source)) 799 else: 800 self.EmitOp2('sadalp', 801 _AppendType(wide_type, destination), 802 _AppendType(add_type, source)) 803 804 def EmitLdr(self, register, value): 805 self.EmitOp2('ldr', _Cast(32, register), _Cast(None, value)) 806 807 def EmitVLoad(self, load_no, load_type, destination, source): 808 self.EmitVLoadA(load_no, load_type, [destination], source) 809 810 def EmitVLoadA(self, load_no, load_type, destinations, source): 811 if source.dereference_increment: 812 increment = sum( 813 [_LoadStoreSize(destination) for destination in destinations]) / 8 814 self.EmitVLoadAPostIncrement(load_no, load_type, destinations, source, 815 self.ImmediateConstant(increment)) 816 else: 817 self.EmitVLoadAPostIncrement(load_no, load_type, destinations, source, 818 None) 819 820 def EmitVLoadAPostIncrement(self, load_no, load_type, destinations, source, 821 increment): 822 """Generate assembly to load memory to registers and increment source.""" 823 if len(destinations) == 1 and destinations[0].lane is -1: 824 destination = '{%s}' % _AppendType(load_type, destinations[0]) 825 if increment: 826 self.EmitOp3('ld%dr' % load_no, destination, source, increment) 827 else: 828 self.EmitOp2('ld%dr' % load_no, destination, source) 829 return 830 831 destination_list = _RegisterList(load_type, destinations) 832 if increment: 833 self.EmitOp3('ld%d' % load_no, destination_list, source, increment) 834 else: 835 self.EmitOp2('ld%d' % load_no, destination_list, source) 836 837 def EmitVLoadAE(self, 838 load_type, 839 elem_count, 840 destinations, 841 source, 842 alignment=None): 843 """Generate assembly to load an array of elements of given size.""" 844 bits_to_load = load_type * elem_count 845 min_bits = min([destination.register_bits for destination in destinations]) 846 max_bits = max([destination.register_bits for destination in destinations]) 847 848 if min_bits is not max_bits: 849 raise ArgumentError('Cannot mix double and quad loads.') 850 851 if len(destinations) * min_bits < bits_to_load: 852 raise ArgumentError('To few destinations: %d to load %d bits.' % 853 (len(destinations), bits_to_load)) 854 855 leftover_loaded = 0 856 while bits_to_load > 0: 857 if bits_to_load >= 4 * min_bits: 858 self.EmitVLoadA(1, 32, destinations[:4], 859 self.DereferenceIncrement(source, alignment)) 860 bits_to_load -= 4 * min_bits 861 destinations = destinations[4:] 862 elif bits_to_load >= 3 * min_bits: 863 self.EmitVLoadA(1, 32, destinations[:3], 864 self.DereferenceIncrement(source, alignment)) 865 bits_to_load -= 3 * min_bits 866 destinations = destinations[3:] 867 elif bits_to_load >= 2 * min_bits: 868 self.EmitVLoadA(1, 32, destinations[:2], 869 self.DereferenceIncrement(source, alignment)) 870 bits_to_load -= 2 * min_bits 871 destinations = destinations[2:] 872 elif bits_to_load >= min_bits: 873 self.EmitVLoad(1, 32, destinations[0], 874 self.DereferenceIncrement(source, alignment)) 875 bits_to_load -= min_bits 876 destinations = destinations[1:] 877 elif bits_to_load >= 64: 878 self.EmitVLoad(1, 32, 879 _Cast(64, destinations[0]), 880 self.DereferenceIncrement(source)) 881 bits_to_load -= 64 882 leftover_loaded += 64 883 elif bits_to_load >= 32: 884 self.EmitVLoad(1, 32, 885 self.Lane(32, destinations[0], leftover_loaded / 32), 886 self.DereferenceIncrement(source)) 887 bits_to_load -= 32 888 leftover_loaded += 32 889 elif bits_to_load >= 16: 890 self.EmitVLoad(1, 16, 891 self.Lane(16, destinations[0], leftover_loaded / 16), 892 self.DereferenceIncrement(source)) 893 bits_to_load -= 16 894 leftover_loaded += 16 895 elif bits_to_load is 8: 896 self.EmitVLoad(1, 8, 897 self.Lane(8, destinations[0], leftover_loaded / 8), 898 self.DereferenceIncrement(source)) 899 bits_to_load -= 8 900 leftover_loaded += 8 901 else: 902 raise ArgumentError('Wrong leftover: %d' % bits_to_load) 903 904 def EmitVLoadE(self, load_type, count, destination, source, alignment=None): 905 self.EmitVLoadAE(load_type, count, [destination], source, alignment) 906 907 def EmitVLoadAllLanes(self, load_no, load_type, destination, source): 908 new_destination = destination.Copy() 909 new_destination.lane = -1 910 new_destination.lane_bits = load_type 911 self.EmitVLoad(load_no, load_type, new_destination, source) 912 913 def EmitVLoadOffset(self, load_no, load_type, destination, source, offset): 914 self.EmitVLoadOffsetA(load_no, load_type, [destination], source, offset) 915 916 def EmitVLoadOffsetA(self, load_no, load_type, destinations, source, offset): 917 assert len(destinations) <= 4 918 self.EmitOp3('ld%d' % load_no, 919 _RegisterList(load_type, destinations), source, offset) 920 921 def EmitPld(self, load_address_register): 922 self.EmitOp2('prfm', 'pldl1keep', '[%s]' % load_address_register) 923 924 def EmitPldOffset(self, load_address_register, offset): 925 self.EmitOp2('prfm', 'pldl1keep', 926 '[%s, %s]' % (load_address_register, offset)) 927 928 def EmitVShl(self, shift_type, destination, source, shift): 929 self.EmitOp3('sshl', 930 _AppendType(shift_type, destination), 931 _AppendType(shift_type, source), _AppendType('i32', shift)) 932 933 def EmitVStore(self, store_no, store_type, source, destination): 934 self.EmitVStoreA(store_no, store_type, [source], destination) 935 936 def EmitVStoreA(self, store_no, store_type, sources, destination): 937 if destination.dereference_increment: 938 increment = sum([_LoadStoreSize(source) for source in sources]) / 8 939 self.EmitVStoreAPostIncrement(store_no, store_type, sources, destination, 940 self.ImmediateConstant(increment)) 941 else: 942 self.EmitVStoreAPostIncrement(store_no, store_type, sources, destination, 943 None) 944 945 def EmitVStoreAPostIncrement(self, store_no, store_type, sources, destination, 946 increment): 947 source_list = _RegisterList(store_type, sources) 948 if increment: 949 self.EmitOp3('st%d' % store_no, source_list, destination, increment) 950 else: 951 self.EmitOp2('st%d' % store_no, source_list, destination) 952 953 def EmitVStoreAE(self, 954 store_type, 955 elem_count, 956 sources, 957 destination, 958 alignment=None): 959 """Generate assembly to store an array of elements of given size.""" 960 bits_to_store = store_type * elem_count 961 min_bits = min([source.register_bits for source in sources]) 962 max_bits = max([source.register_bits for source in sources]) 963 964 if min_bits is not max_bits: 965 raise ArgumentError('Cannot mix double and quad stores.') 966 967 if len(sources) * min_bits < bits_to_store: 968 raise ArgumentError('To few destinations: %d to store %d bits.' % 969 (len(sources), bits_to_store)) 970 971 leftover_stored = 0 972 while bits_to_store > 0: 973 if bits_to_store >= 4 * min_bits: 974 self.EmitVStoreA(1, 32, sources[:4], 975 self.DereferenceIncrement(destination, alignment)) 976 bits_to_store -= 4 * min_bits 977 sources = sources[4:] 978 elif bits_to_store >= 3 * min_bits: 979 self.EmitVStoreA(1, 32, sources[:3], 980 self.DereferenceIncrement(destination, alignment)) 981 bits_to_store -= 3 * min_bits 982 sources = sources[3:] 983 elif bits_to_store >= 2 * min_bits: 984 self.EmitVStoreA(1, 32, sources[:2], 985 self.DereferenceIncrement(destination, alignment)) 986 bits_to_store -= 2 * min_bits 987 sources = sources[2:] 988 elif bits_to_store >= min_bits: 989 self.EmitVStore(1, 32, sources[0], 990 self.DereferenceIncrement(destination, alignment)) 991 bits_to_store -= min_bits 992 sources = sources[1:] 993 elif bits_to_store >= 64: 994 self.EmitVStore(1, 32, 995 _Cast(64, sources[0]), 996 self.DereferenceIncrement(destination, alignment)) 997 bits_to_store -= 64 998 leftover_stored += 64 999 elif bits_to_store >= 32: 1000 self.EmitVStore(1, 32, 1001 self.Lane(32, sources[0], leftover_stored / 32), 1002 self.DereferenceIncrement(destination)) 1003 bits_to_store -= 32 1004 leftover_stored += 32 1005 elif bits_to_store >= 16: 1006 self.EmitVStore(1, 16, 1007 self.Lane(16, sources[0], leftover_stored / 16), 1008 self.DereferenceIncrement(destination)) 1009 bits_to_store -= 16 1010 leftover_stored += 16 1011 elif bits_to_store >= 8: 1012 self.EmitVStore(1, 8, 1013 self.Lane(8, sources[0], leftover_stored / 8), 1014 self.DereferenceIncrement(destination)) 1015 bits_to_store -= 8 1016 leftover_stored += 8 1017 else: 1018 raise ArgumentError('Wrong leftover: %d' % bits_to_store) 1019 1020 def EmitVStoreE(self, store_type, count, source, destination, alignment=None): 1021 self.EmitVStoreAE(store_type, count, [source], destination, alignment) 1022 1023 def EmitVStoreOffset(self, store_no, store_type, source, destination, offset): 1024 self.EmitVStoreOffsetA(store_no, store_type, [source], destination, offset) 1025 1026 def EmitVStoreOffsetA(self, store_no, store_type, sources, destination, 1027 offset): 1028 self.EmitOp3('st%d' % store_no, 1029 _RegisterList(store_type, sources), destination, offset) 1030 1031 def EmitVStoreOffsetE(self, store_type, count, source, destination, offset): 1032 if store_type is not 32: 1033 raise ArgumentError('Unsupported store_type: %d' % store_type) 1034 1035 if count == 1: 1036 self.EmitVStoreOffset(1, 32, 1037 self.Lane(32, source, 0), 1038 self.Dereference(destination, None), offset) 1039 elif count == 2: 1040 self.EmitVStoreOffset(1, 32, 1041 _Cast(64, source), 1042 self.Dereference(destination, None), offset) 1043 elif count == 3: 1044 self.EmitVStore(1, 32, 1045 _Cast(64, source), 1046 self.DereferenceIncrement(destination, None)) 1047 self.EmitVStoreOffset(1, 32, 1048 self.Lane(32, source, 2), 1049 self.Dereference(destination, None), offset) 1050 self.EmitSub(destination, destination, self.ImmediateConstant(8)) 1051 elif count == 4: 1052 self.EmitVStoreOffset(1, 32, source, 1053 self.Dereference(destination, None), offset) 1054 else: 1055 raise ArgumentError('To many elements: %d' % count) 1056 1057 def EmitVSumReduce(self, reduce_type, elem_count, reduce_count, destinations, 1058 sources): 1059 """Generate assembly to perform n-fold horizontal sum reduction.""" 1060 if reduce_type is not 'u32': 1061 raise ArgumentError('Unsupported reduce: %s' % reduce_type) 1062 1063 if (elem_count + 3) / 4 > len(destinations): 1064 raise ArgumentError('To few destinations: %d (%d needed)' % 1065 (len(destinations), (elem_count + 3) / 4)) 1066 1067 if elem_count * reduce_count > len(sources) * 4: 1068 raise ArgumentError('To few sources: %d' % len(sources)) 1069 1070 if reduce_count <= 1: 1071 raise ArgumentError('Unsupported reduce_count: %d' % reduce_count) 1072 1073 sources = [_Cast(128, source) for source in sources] 1074 destinations = [_Cast(128, destination) for destination in destinations] 1075 1076 while reduce_count > 1: 1077 if len(sources) % 2 == 1: 1078 sources.append(sources[-1]) 1079 1080 if reduce_count == 2: 1081 for i in range(len(destinations)): 1082 self.EmitVPadd(reduce_type, destinations[i], sources[2 * i], 1083 sources[2 * i + 1]) 1084 return 1085 else: 1086 sources_2 = [] 1087 for i in range(len(sources) / 2): 1088 self.EmitVPadd(reduce_type, sources[2 * i], sources[2 * i], 1089 sources[2 * i + 1]) 1090 sources_2.append(sources[2 * i]) 1091 reduce_count /= 2 1092 sources = sources_2 1093 1094 def EmitVUzp1(self, uzp_type, destination, source_1, source_2): 1095 self.EmitOp3('uzp1', 1096 _AppendType(uzp_type, destination), 1097 _AppendType(uzp_type, source_1), 1098 _AppendType(uzp_type, source_2)) 1099 1100 def EmitVUzp2(self, uzp_type, destination, source_1, source_2): 1101 self.EmitOp3('uzp2', 1102 _AppendType(uzp_type, destination), 1103 _AppendType(uzp_type, source_1), 1104 _AppendType(uzp_type, source_2)) 1105 1106 def EmitVUzp(self, uzp_type, destination_1, destination_2, source_1, 1107 source_2): 1108 self.EmitVUzp1(uzp_type, destination_1, source_1, source_2) 1109 self.EmitVUzp2(uzp_type, destination_2, source_1, source_2) 1110 1111 def EmitVTrn1(self, trn_type, destination, source_1, source_2): 1112 self.EmitOp3('trn1', 1113 _AppendType(trn_type, destination), 1114 _AppendType(trn_type, source_1), 1115 _AppendType(trn_type, source_2)) 1116 1117 def EmitVTrn2(self, trn_type, destination, source_1, source_2): 1118 self.EmitOp3('trn2', 1119 _AppendType(trn_type, destination), 1120 _AppendType(trn_type, source_1), 1121 _AppendType(trn_type, source_2)) 1122 1123 def EmitVTrn(self, trn_type, destination_1, destination_2, source_1, 1124 source_2): 1125 self.EmitVTrn1(trn_type, destination_1, source_1, source_2) 1126 self.EmitVTrn2(trn_type, destination_2, source_1, source_2) 1127 1128 def EmitColBlockStride(self, cols, stride, new_stride): 1129 assert cols in [1, 2, 3, 4, 5, 6, 7, 8] 1130 if cols in [5, 6, 7]: 1131 self.EmitSub(new_stride, stride, self.ImmediateConstant(4)) 1132 1133 def EmitLoadColBlock(self, registers, load_type, cols, elements, block, 1134 input_address, stride): 1135 assert cols is len(block) 1136 assert load_type is 8 1137 1138 input_deref = self.Dereference(input_address, None) 1139 input_deref_increment = self.DereferenceIncrement(input_address, None) 1140 1141 if cols is 1: 1142 for i in range(elements): 1143 self.EmitVLoadOffset(1, 8, 1144 self.Lane(8, block[0], i), input_deref, stride) 1145 self.EmitPld(input_address) 1146 return block 1147 elif cols is 2: 1148 temp = [registers.DoubleRegister() for unused_i in range(2)] 1149 for i in range(elements): 1150 self.EmitVLoadOffset(1, 16, 1151 self.Lane(16, block[i / 4], i % 4), input_deref, 1152 stride) 1153 self.EmitPld(input_address) 1154 self.EmitVUzp(8, temp[0], temp[1], block[0], block[1]) 1155 registers.FreeRegisters(block) 1156 return temp 1157 elif cols is 3: 1158 for i in range(elements): 1159 self.EmitVLoadOffsetA(3, 8, [self.Lane(8, row, i) for row in block], 1160 input_deref, stride) 1161 self.EmitPld(input_address) 1162 return block 1163 elif cols is 4: 1164 temp = [registers.DoubleRegister() for unused_i in range(4)] 1165 for i in range(elements): 1166 self.EmitVLoadOffset(1, 32, 1167 self.Lane(32, block[i % 4], i / 4), input_deref, 1168 stride) 1169 self.EmitPld(input_address) 1170 self.EmitVTrn(16, temp[0], temp[2], block[0], block[2]) 1171 self.EmitVTrn(16, temp[1], temp[3], block[1], block[3]) 1172 self.EmitVTrn(8, block[0], block[1], temp[0], temp[1]) 1173 self.EmitVTrn(8, block[2], block[3], temp[2], temp[3]) 1174 registers.FreeRegisters(temp) 1175 return block 1176 elif cols is 5: 1177 temp = [registers.DoubleRegister() for unused_i in range(4)] 1178 for i in range(elements): 1179 self.EmitVLoad(1, 32, 1180 self.Lane(32, block[i % 4], i / 4), 1181 input_deref_increment) 1182 self.EmitVLoadOffset(1, 8, 1183 self.Lane(8, block[4], i), input_deref, stride) 1184 self.EmitPld(input_address) 1185 self.EmitVTrn(16, temp[0], temp[2], block[0], block[2]) 1186 self.EmitVTrn(16, temp[1], temp[3], block[1], block[3]) 1187 self.EmitVTrn(8, block[0], block[1], temp[0], temp[1]) 1188 self.EmitVTrn(8, block[2], block[3], temp[2], temp[3]) 1189 registers.FreeRegisters(temp) 1190 return block 1191 elif cols is 6: 1192 temp = [registers.DoubleRegister() for unused_i in range(6)] 1193 for i in range(elements): 1194 self.EmitVLoad(1, 32, 1195 self.Lane(32, block[i % 4], i / 4), 1196 input_deref_increment) 1197 self.EmitVLoadOffset(1, 16, 1198 self.Lane(16, block[4 + i / 4], i % 4), 1199 input_deref, stride) 1200 self.EmitPld(input_address) 1201 self.EmitVTrn(16, temp[0], temp[2], block[0], block[2]) 1202 self.EmitVTrn(16, temp[1], temp[3], block[1], block[3]) 1203 self.EmitVUzp(8, temp[4], temp[5], block[4], block[5]) 1204 self.EmitVTrn(8, block[0], block[1], temp[0], temp[1]) 1205 self.EmitVTrn(8, block[2], block[3], temp[2], temp[3]) 1206 registers.FreeRegisters( 1207 [block[4], block[5], temp[0], temp[1], temp[2], temp[3]]) 1208 return [block[0], block[1], block[2], block[3], temp[4], temp[5]] 1209 elif cols is 7: 1210 temp = [registers.DoubleRegister() for unused_i in range(4)] 1211 for i in range(elements): 1212 self.EmitVLoad(1, 32, 1213 self.Lane(32, block[i % 4], i / 4), 1214 input_deref_increment) 1215 self.EmitVLoadOffsetA(3, 8, 1216 [self.Lane(8, row, i) for row in block[4:]], 1217 input_deref, stride) 1218 self.EmitPld(input_address) 1219 self.EmitVTrn1(16, temp[0], block[0], block[2]) 1220 self.EmitVTrn2(16, temp[2], block[0], block[2]) 1221 self.EmitVTrn1(16, temp[1], block[1], block[3]) 1222 self.EmitVTrn2(16, temp[3], block[1], block[3]) 1223 self.EmitVTrn1(8, block[0], temp[0], temp[1]) 1224 self.EmitVTrn2(8, block[1], temp[0], temp[1]) 1225 self.EmitVTrn1(8, block[2], temp[2], temp[3]) 1226 self.EmitVTrn2(8, block[3], temp[2], temp[3]) 1227 registers.FreeRegisters(temp) 1228 return block 1229 elif cols is 8: 1230 temp = [registers.DoubleRegister() for unused_i in range(8)] 1231 for i in range(elements): 1232 self.EmitVLoadOffset(1, 32, block[i], input_deref, stride) 1233 self.EmitPld(input_address) 1234 self.EmitVTrn(8, temp[0], temp[1], block[0], block[1]) 1235 self.EmitVTrn(8, temp[2], temp[3], block[2], block[3]) 1236 self.EmitVTrn(8, temp[4], temp[5], block[4], block[5]) 1237 self.EmitVTrn(8, temp[6], temp[7], block[6], block[7]) 1238 self.EmitVTrn(16, block[0], block[2], temp[0], temp[2]) 1239 self.EmitVTrn(16, block[1], block[3], temp[1], temp[3]) 1240 self.EmitVTrn(16, block[4], block[6], temp[4], temp[6]) 1241 self.EmitVTrn(16, block[5], block[7], temp[5], temp[7]) 1242 self.EmitVTrn(32, temp[0], temp[4], block[0], block[4]) 1243 self.EmitVTrn(32, temp[1], temp[5], block[1], block[5]) 1244 self.EmitVTrn(32, temp[2], temp[6], block[2], block[6]) 1245 self.EmitVTrn(32, temp[3], temp[7], block[3], block[7]) 1246 registers.FreeRegisters(block) 1247 return temp 1248 else: 1249 assert False 1250 1251 def Dereference(self, value, unused_alignment=None): 1252 new_value = value.Copy() 1253 new_value.dereference = True 1254 return new_value 1255 1256 def DereferenceIncrement(self, value, alignment=None): 1257 new_value = self.Dereference(value, alignment).Copy() 1258 new_value.dereference_increment = True 1259 return new_value 1260 1261 def ImmediateConstant(self, value): 1262 return _ImmediateConstant(value) 1263 1264 def AllLanes(self, value): 1265 return '%s[]' % value 1266 1267 def Lane(self, bits, value, lane): 1268 new_value = value.Copy() 1269 if bits * (lane + 1) > new_value.register_bits: 1270 raise ArgumentError('Lane to big: (%d + 1) x %d > %d' % 1271 (lane, bits, new_value.register_bits)) 1272 new_value.lane = lane 1273 new_value.lane_bits = bits 1274 return new_value 1275 1276 def CreateRegisters(self): 1277 return _NeonRegisters64Bit() 1278