"""Qnt primitive used by the GEMM function.

Emits NEON code (through neon_emitter) for the qnt_Nx8 family of functions.
Each generated function reads rows of std::int32_t values, adds a per-row
offset, multiplies, adds a rounding offset, shifts, and stores the narrowed
result as std::uint8_t.
"""

import neon_emitter


class Error(Exception):
  """Module level error."""


class ConfigurationError(Error):
  """Unsupported configuration."""


class QntLane(object):
  """Registers used to process one row ("lane") of the quantized output."""

  def __init__(self, source, output, offset, load_1, load_2):
    # Register holding the read pointer into this lane's int32 input row.
    self.source = source
    # Register holding the write pointer into this lane's uint8 output row.
    self.output = output
    # Quad register holding this lane's offset (see LoadAndDuplicateOffsets).
    self.offset = offset
    # Quad registers that receive the loaded input values; load_1 is filled
    # first, then load_2.
    self.load_1 = load_1
    self.load_2 = load_2


def BuildName(lanes, leftovers, aligned):
  """Returns the generated function name, e.g. 'qnt_3x8_5_aligned'.

  The leftovers suffix is omitted when leftovers is 0.
  """
  name = 'qnt_%dx8' % lanes
  if leftovers:
    name += '_%d' % leftovers
  if aligned:
    name += '_aligned'
  return name


def LoadAndDuplicateOffsets(emitter, registers, lanes, offsets):
  """Loads one offset per lane into its own quad register.

  Args:
    emitter: the code emitter.
    registers: the register allocator.
    lanes: number of lanes; must be 1, 2 or 3.
    offsets: register holding the offsets pointer. It is used with
      DereferenceIncrement, so it is advanced as the loads are emitted.

  Returns:
    A list of `lanes` quad registers, in load order.

  Raises:
    ConfigurationError: if lanes is not 1, 2 or 3.
  """
  if lanes not in (1, 2, 3):
    raise ConfigurationError('Unsupported number of lanes: %d' % lanes)

  offset_registers = []
  for unused_i in range(0, lanes):
    register = registers.QuadRegister()
    # A single VLD1.32 with AllLanes targets on both halves, presumably
    # replicating one int32 across the whole quad register.
    emitter.EmitVLoadA('1.32',
                       [emitter.AllLanes(registers.Low(register)),
                        emitter.AllLanes(registers.High(register))],
                       emitter.DereferenceIncrement(offsets, 32))
    offset_registers.append(register)
  return offset_registers


def GenerateQntLanes(emitter,
                     registers,
                     qnt_lanes,
                     source,
                     stride,
                     destination,
                     destination_stride,
                     offsets):
  """Prepare lanes for reading unquantized multiplication results.

  Lane 0 uses the `source` / `destination` registers directly. Each later
  lane gets two fresh general registers, initialized to the previous lane's
  pointers plus `stride` / `destination_stride`.

  Returns:
    A list of `qnt_lanes` QntLane objects.
  """
  offset_registers = LoadAndDuplicateOffsets(
      emitter, registers, qnt_lanes, offsets)

  lanes = []
  last_input_register = source
  last_output_register = destination
  for i in range(0, qnt_lanes):
    if not i:
      lanes.append(QntLane(source,
                           destination,
                           offset_registers[i],
                           registers.QuadRegister(),  # load 1
                           registers.QuadRegister()))  # load 2
    else:
      input_register = registers.GeneralRegister()
      output_register = registers.GeneralRegister()
      lanes.append(QntLane(input_register,
                           output_register,
                           offset_registers[i],
                           registers.QuadRegister(),  # load 1
                           registers.QuadRegister()))  # load 2
      # Row i starts one stride past row i - 1, for both input and output.
      emitter.EmitAdd(input_register, last_input_register, stride)
      emitter.EmitAdd(output_register, last_output_register,
                      destination_stride)
      last_input_register = input_register
      last_output_register = output_register
  return lanes


def DuplicateRegister(emitter, registers, value):
  """Allocates a quad register and emits a VDUP.32 of `value` into it."""
  register = registers.QuadRegister()
  emitter.EmitVDup('32', register, value)
  return register


def GenerateQuantize(emitter,
                     registers,
                     lanes,
                     lane_temps,
                     multiplicative_offset,
                     rounding_offset,
                     shift):
  """Inner loop for quantization: add offsets, multiply, round, shift.

  Each step is emitted for every entry of `lanes` before the next step
  starts, presumably to interleave independent instructions.

  Args:
    emitter: the code emitter.
    registers: the register allocator (used for Low() of lane_temps).
    lanes: list of [data, offset, narrow_destination] entries. `data` is a
      quad register of int32 values that is modified in place. `offset` is
      added to it. `narrow_destination` receives the VQMOVN.S32 result.
    lane_temps: quad registers whose contents are narrowed again with
      VQMOVUN.S16 into their own low halves.
    multiplicative_offset: quad register multiplied into the data.
    rounding_offset: quad register added after the multiply.
    shift: quad register used as the VSHL.S32 shift operand (presumably
      non-positive so that it shifts right; confirm against callers).
  """
  for lane in lanes:
    emitter.EmitVAdd('i32', lane[0], lane[0], lane[1])

  for lane in lanes:
    emitter.EmitVMul('i32', lane[0], lane[0], multiplicative_offset)

  for lane in lanes:
    emitter.EmitVAdd('i32', lane[0], lane[0], rounding_offset)

  for lane in lanes:
    emitter.EmitVShl('s32', lane[0], lane[0], shift)

  for lane in lanes:
    emitter.EmitVQmovn('s32', lane[2], lane[0])

  for lane_temp in lane_temps:
    emitter.EmitVQmovun('s16', registers.Low(lane_temp), lane_temp)


def GenerateLoadQuantizeStore(emitter,
                              registers,
                              lanes,
                              multiplicative_offset,
                              rounding_offset,
                              shift,
                              alignment):
  """Load unquantized data from lanes, quantize, store final result.

  Processes one full group per lane. Each group is loaded into
  load_1 + load_2 (four double registers), and the narrowed low half of a
  temporary quad register is stored. The temporaries are freed afterwards.

  Args:
    alignment: alignment hint for the output stores (64 or None).
  """
  lane_temps = [registers.QuadRegister() for _ in lanes]

  for lane in lanes:
    emitter.EmitVLoadA('1.32',
                       [registers.Low(lane.load_1),
                        registers.High(lane.load_1),
                        registers.Low(lane.load_2),
                        registers.High(lane.load_2)],
                       emitter.DereferenceIncrement(lane.source, 64))

  for lane in lanes:
    emitter.EmitPld(lane.source)

  # The first half of the data narrows into the temp's low half and the
  # second half into its high half.
  quantize_setup = []
  for (lane_temp, lane) in zip(lane_temps, lanes):
    quantize_setup.append([lane.load_1, lane.offset, registers.Low(lane_temp)])
    quantize_setup.append([lane.load_2, lane.offset, registers.High(lane_temp)])

  GenerateQuantize(emitter,
                   registers,
                   quantize_setup,
                   lane_temps,
                   multiplicative_offset,
                   rounding_offset,
                   shift)

  for (lane_temp, lane) in zip(lane_temps, lanes):
    emitter.EmitVStore('1.8',
                       registers.Low(lane_temp),
                       emitter.DereferenceIncrement(lane.output, alignment))

  for lane_temp in lane_temps:
    registers.FreeRegister(lane_temp)


def GenerateLoadLeftovers(emitter, registers, leftovers, lanes):
  """Handle leftover loading when the row size is not a multiple of 8.

  Loads `leftovers` int32 values per lane into load_1, then load_2. Counts
  that are not a whole double register are loaded in two steps: first with a
  post-incremented aligned load, then with a single-lane load.

  Raises:
    ConfigurationError: if leftovers is not in 1..7.
  """
  if leftovers == 1:
    for lane in lanes:
      emitter.EmitVLoad('1.32',
                        emitter.Lane(registers.Low(lane.load_1), 0),
                        emitter.Dereference(lane.source, None))
  elif leftovers == 2:
    for lane in lanes:
      emitter.EmitVLoad('1.32',
                        registers.Low(lane.load_1),
                        emitter.Dereference(lane.source, 64))
  elif leftovers == 3:
    for lane in lanes:
      emitter.EmitVLoad('1.32',
                        registers.Low(lane.load_1),
                        emitter.DereferenceIncrement(lane.source, 64))
    for lane in lanes:
      emitter.EmitVLoad('1.32',
                        emitter.Lane(registers.High(lane.load_1), 0),
                        emitter.Dereference(lane.source, None))
  elif leftovers == 4:
    for lane in lanes:
      emitter.EmitVLoadA('1.32',
                         [registers.Low(lane.load_1),
                          registers.High(lane.load_1)],
                         emitter.Dereference(lane.source, 64))
  elif leftovers == 5:
    for lane in lanes:
      emitter.EmitVLoadA('1.32',
                         [registers.Low(lane.load_1),
                          registers.High(lane.load_1)],
                         emitter.DereferenceIncrement(lane.source, 64))
    for lane in lanes:
      emitter.EmitVLoad('1.32',
                        emitter.Lane(registers.Low(lane.load_2), 0),
                        emitter.Dereference(lane.source, None))
  elif leftovers == 6:
    for lane in lanes:
      emitter.EmitVLoadA('1.32',
                         [registers.Low(lane.load_1),
                          registers.High(lane.load_1),
                          registers.Low(lane.load_2)],
                         emitter.Dereference(lane.source, 64))
  elif leftovers == 7:
    for lane in lanes:
      emitter.EmitVLoadA('1.32',
                         [registers.Low(lane.load_1),
                          registers.High(lane.load_1),
                          registers.Low(lane.load_2)],
                         emitter.DereferenceIncrement(lane.source, 64))
    for lane in lanes:
      emitter.EmitVLoad('1.32',
                        emitter.Lane(registers.High(lane.load_2), 0),
                        emitter.Dereference(lane.source, None))
  else:
    # Same wording as GenerateStoreLeftovers (previously misspelled).
    raise ConfigurationError('Unsupported leftovers count: %d' % leftovers)


def GenerateStoreLeftovers(emitter, registers, leftovers, lane_temps, lanes):
  """Handle leftover storing when the row size is not a multiple of 8.

  Stores the first `leftovers` bytes of each lane temp's low half to the
  lane's output pointer. The stores are split into 32/16/8-bit single-lane
  stores; the lane index of each store is the byte offset divided by the
  element size.

  Raises:
    ConfigurationError: if leftovers is not in 1..7.
  """
  setup = [(registers.Low(temp), lane.output)
           for (temp, lane) in zip(lane_temps, lanes)]

  if leftovers == 1:
    for (low, output) in setup:
      emitter.EmitVStore('1.8', emitter.Lane(low, 0),
                         emitter.Dereference(output, None))
  elif leftovers == 2:
    for (low, output) in setup:
      emitter.EmitVStore('1.16', emitter.Lane(low, 0),
                         emitter.Dereference(output, None))
  elif leftovers == 3:
    for (low, output) in setup:
      emitter.EmitVStore('1.16', emitter.Lane(low, 0),
                         emitter.DereferenceIncrement(output, None))
    for (low, output) in setup:
      emitter.EmitVStore('1.8', emitter.Lane(low, 2),
                         emitter.Dereference(output, None))
  elif leftovers == 4:
    for (low, output) in setup:
      emitter.EmitVStore('1.32', emitter.Lane(low, 0),
                         emitter.Dereference(output, None))
  elif leftovers == 5:
    for (low, output) in setup:
      emitter.EmitVStore('1.32', emitter.Lane(low, 0),
                         emitter.DereferenceIncrement(output, None))
    for (low, output) in setup:
      emitter.EmitVStore('1.8', emitter.Lane(low, 4),
                         emitter.Dereference(output, None))
  elif leftovers == 6:
    for (low, output) in setup:
      emitter.EmitVStore('1.32', emitter.Lane(low, 0),
                         emitter.DereferenceIncrement(output, None))
    for (low, output) in setup:
      emitter.EmitVStore('1.16', emitter.Lane(low, 2),
                         emitter.Dereference(output, None))
  elif leftovers == 7:
    for (low, output) in setup:
      emitter.EmitVStore('1.32', emitter.Lane(low, 0),
                         emitter.DereferenceIncrement(output, None))
    for (low, output) in setup:
      emitter.EmitVStore('1.16', emitter.Lane(low, 2),
                         emitter.DereferenceIncrement(output, None))
    for (low, output) in setup:
      emitter.EmitVStore('1.8', emitter.Lane(low, 6),
                         emitter.DereferenceIncrement(output, None))
  else:
    raise ConfigurationError('Unsupported leftovers count: %d' % leftovers)


def GenerateLeftoverLoadQuantizeStore(emitter,
                                      registers,
                                      leftovers,
                                      lanes,
                                      multiplicative_offset,
                                      rounding_offset,
                                      shift):
  """Handle leftovers when the row size is not a multiple of 8.

  Same pipeline as GenerateLoadQuantizeStore, but with partial loads and
  stores. load_2 only takes part in quantization when more than 4 values
  were loaded.
  """
  lane_temps = [registers.QuadRegister() for _ in lanes]

  GenerateLoadLeftovers(emitter, registers, leftovers, lanes)

  quantize_setup = []
  for (lane_temp, lane) in zip(lane_temps, lanes):
    quantize_setup.append([lane.load_1, lane.offset, registers.Low(lane_temp)])
    if leftovers > 4:
      quantize_setup.append(
          [lane.load_2, lane.offset, registers.High(lane_temp)])

  GenerateQuantize(emitter,
                   registers,
                   quantize_setup,
                   lane_temps,
                   multiplicative_offset,
                   rounding_offset,
                   shift)

  GenerateStoreLeftovers(emitter, registers, leftovers, lane_temps, lanes)


def GenerateQntNx8(emitter, qnt_lanes, leftovers, aligned):
  """Emits optimized quantization code for given lanes and row size.

  Args:
    emitter: the code emitter.
    qnt_lanes: number of rows processed together (1, 2 or 3).
    leftovers: count % 8 for the generated function (0..7).
    aligned: if True, emit asserts on destination alignment and use a 64-bit
      alignment hint for the main-loop stores.

  Raises:
    ConfigurationError: if qnt_lanes or leftovers is out of range.
  """
  if leftovers < 0 or leftovers > 7:
    raise ConfigurationError('Leftovers should be between 0 and 7 inclusive.')
  if qnt_lanes < 1 or qnt_lanes > 3:
    raise ConfigurationError('Qnt_lanes should be 1, 2 or 3.')

  name = BuildName(qnt_lanes, leftovers, aligned)

  emitter.EmitFunctionBeginA(name,
                             [['const std::int32_t*', 'source'],
                              ['std::int32_t', 'count'],
                              ['std::int32_t', 'stride'],
                              ['const std::int32_t*', 'offsets'],
                              ['std::uint8_t*', 'destination'],
                              ['std::int32_t', 'destination_stride'],
                              ['std::int32_t', 'multiplicative_offset'],
                              ['std::int32_t', 'rounding_offset'],
                              ['std::int32_t', 'shift']],
                             'void')
  emitter.EmitAssert('count %% 8 == %d' % leftovers)
  emitter.EmitAssert('count >= 8')
  emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(source) % 8 == 0')
  if aligned:
    emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(destination) % 8 == 0')
    if qnt_lanes > 1:
      emitter.EmitAssert(
          'destination_stride % 8 == 0')
  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()

  count = registers.MapParameter('count')

  # Parameters are mapped in this exact order; keep it stable so the
  # generated code does not change.
  multiplicative_offset = DuplicateRegister(
      emitter, registers, registers.MapParameter('multiplicative_offset'))
  rounding_offset = DuplicateRegister(
      emitter, registers, registers.MapParameter('rounding_offset'))
  shift = DuplicateRegister(emitter, registers, registers.MapParameter('shift'))

  lanes = GenerateQntLanes(
      emitter, registers, qnt_lanes,
      registers.MapParameter('source'),
      registers.MapParameter('stride'),
      registers.MapParameter('destination'),
      registers.MapParameter('destination_stride'),
      registers.MapParameter('offsets'))

  if leftovers:
    # Strip the leftovers so the main loop runs over whole groups of 8. With
    # the asserts above (count >= 8, count % 8 == leftovers) the result is
    # still >= 8, so the jump to the leftovers section is not taken then.
    emitter.EmitSubs(count, count, emitter.ImmediateConstant(leftovers))
    emitter.EmitBeqFront(2)

  emitter.EmitNewline()
  emitter.EmitNumericalLabel(1)
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))

  GenerateLoadQuantizeStore(emitter,
                            registers,
                            lanes,
                            multiplicative_offset,
                            rounding_offset,
                            shift,
                            64 if aligned else None)

  emitter.EmitNewline()
  emitter.EmitBneBack(1)

  if leftovers:
    emitter.EmitNumericalLabel(2)
    GenerateLeftoverLoadQuantizeStore(emitter,
                                      registers,
                                      leftovers,
                                      lanes,
                                      multiplicative_offset,
                                      rounding_offset,
                                      shift)

  emitter.EmitAsmEnd(registers.MappedParameters(),
                     [],
                     registers.Clobbers() + ['cc', 'memory'])
  emitter.EmitFunctionEnd()


def GenerateFunctions(emitter):
  """Emits every qnt_Nx8 variant.

  Covers aligned and unaligned variants, 1 to 3 lanes, and 0 to 7 leftovers
  (48 functions). Each function is followed by a newline.
  """
  for aligned in [True, False]:
    for lanes in range(1, 4):
      for leftovers in range(0, 8):
        GenerateQntNx8(emitter, lanes, leftovers, aligned)
        emitter.EmitNewline()