1"""Qnt primitive used by the GEMM function.
2
3"""
4
5import neon_emitter
6
7
class Error(Exception):
  """Base class for the exceptions raised by this module."""
10
11
class ConfigurationError(Error):
  """Raised when a generator is asked for a configuration it cannot emit."""
14
15
class QntLane(object):
  """Bundle of the registers used to process one lane.

  Attributes are stored exactly as passed to the constructor:
    source: register holding the lane's input pointer.
    output: register holding the lane's output pointer.
    offset: register holding the lane's offset value.
    load_1: first register the lane's input is loaded into.
    load_2: second register the lane's input is loaded into.
  """

  def __init__(self, source, output, offset, load_1, load_2):
    # Pointer registers.
    self.source, self.output = source, output
    # Per-lane offset register.
    self.offset = offset
    # Registers receiving the loaded input.
    self.load_1, self.load_2 = load_1, load_2
24
25
def BuildName(lanes, leftovers, aligned):
  """Compose the name of a generated quantization function.

  Args:
    lanes: lane count, formatted into the 'qnt_<lanes>x8' prefix.
    leftovers: leftover count; '_<leftovers>' is appended only when truthy.
    aligned: when truthy, '_aligned' is appended.

  Returns:
    The composed name, e.g. 'qnt_3x8_5_aligned'.
  """
  prefix = 'qnt_%dx8' % lanes
  leftovers_suffix = ('_%d' % leftovers) if leftovers else ''
  aligned_suffix = '_aligned' if aligned else ''
  return prefix + leftovers_suffix + aligned_suffix
33
34
def LoadAndDuplicateOffsets(emitter, registers, lanes, offsets):
  """Load one offset per lane into all lanes of a fresh quad register.

  For every lane a quad register is allocated and a '1.32' load from
  `offsets` (post-incremented, alignment 32) is emitted targeting the
  all-lanes form of both halves of that register.

  Args:
    emitter: code emitter used to produce the load instructions.
    registers: register allocator.
    lanes: number of lanes; must be 1, 2 or 3.
    offsets: register holding the pointer to the offsets.

  Returns:
    List with one quad register per lane, in lane order.

  Raises:
    ConfigurationError: if `lanes` is not 1, 2 or 3.
  """
  if lanes not in (1, 2, 3):
    raise ConfigurationError('Unsupported number of lanes: %d' % lanes)

  duplicated = []
  for unused_lane in range(lanes):
    quad = registers.QuadRegister()
    low_all = emitter.AllLanes(registers.Low(quad))
    high_all = emitter.AllLanes(registers.High(quad))
    address = emitter.DereferenceIncrement(offsets, 32)
    emitter.EmitVLoadA('1.32', [low_all, high_all], address)
    duplicated.append(quad)
  return duplicated
48
49
def GenerateQntLanes(emitter,
                     registers,
                     qnt_lanes,
                     source,
                     stride,
                     destination,
                     destination_stride,
                     offsets):
  """Prepare lanes for reading unquantized multiplication results.

  The first lane reuses `source` and `destination` directly. Every further
  lane gets freshly allocated general registers whose values are computed by
  adding `stride` / `destination_stride` to the previous lane's pointers.

  Args:
    emitter: code emitter.
    registers: register allocator.
    qnt_lanes: number of lanes to prepare.
    source: register holding the first lane's input pointer.
    stride: register added to step from one input lane to the next.
    destination: register holding the first lane's output pointer.
    destination_stride: register added to step between output lanes.
    offsets: register holding the pointer to the per-lane offsets.

  Returns:
    List of QntLane objects, one per lane, in lane order.
  """
  offset_registers = LoadAndDuplicateOffsets(
      emitter, registers, qnt_lanes, offsets)

  # Lane 0 works straight off the caller supplied pointers.
  previous = QntLane(source,
                     destination,
                     offset_registers[0],
                     registers.QuadRegister(),  # load 1
                     registers.QuadRegister())  # load 2
  lanes = [previous]

  # Remaining lanes: pointer = previous lane's pointer + stride.
  for lane_offset in offset_registers[1:]:
    lane_source = registers.GeneralRegister()
    lane_output = registers.GeneralRegister()
    current = QntLane(lane_source,
                      lane_output,
                      lane_offset,
                      registers.QuadRegister(),  # load 1
                      registers.QuadRegister())  # load 2
    lanes.append(current)
    emitter.EmitAdd(lane_source, previous.source, stride)
    emitter.EmitAdd(lane_output, previous.output, destination_stride)
    previous = current
  return lanes
85
86
def DuplicateRegister(emitter, registers, value):
  """Allocate a quad register and emit a '32' VDup of `value` into it.

  Args:
    emitter: code emitter.
    registers: register allocator.
    value: register whose content is duplicated.

  Returns:
    The newly allocated quad register.
  """
  duplicated = registers.QuadRegister()
  emitter.EmitVDup('32', duplicated, value)
  return duplicated
91
92
def GenerateQuantize(emitter,
                     registers,
                     lanes,
                     lane_temps,
                     multiplicative_offset,
                     rounding_offset,
                     shift):
  """Inner loop for quantization: add offsets, multiply, round, shift.

  Each step is emitted for every entry before the next step starts, so the
  instructions of different entries are interleaved.

  Args:
    emitter: code emitter.
    registers: register allocator (used to address the low half of temps).
    lanes: sequence of [value, offset, narrowed] register triples; `value`
      is updated in place and finally narrowed into `narrowed`.
    lane_temps: quad registers whose content is narrowed once more into
      their own low half.
    multiplicative_offset: register multiplied into every value.
    rounding_offset: register added to every value after the multiply.
    shift: register used as the shift operand.
  """
  # Step 1: per-entry offset.
  for entry in lanes:
    emitter.EmitVAdd('i32', entry[0], entry[0], entry[1])

  # Steps 2-4 apply the same operand to every entry.
  uniform_steps = (
      (emitter.EmitVMul, 'i32', multiplicative_offset),
      (emitter.EmitVAdd, 'i32', rounding_offset),
      (emitter.EmitVShl, 's32', shift),
  )
  for (emit, kind, operand) in uniform_steps:
    for entry in lanes:
      emit(kind, entry[0], entry[0], operand)

  # Step 5: narrow each value into its destination register.
  for entry in lanes:
    emitter.EmitVQmovn('s32', entry[2], entry[0])

  # Step 6: narrow each temp into its own low half.
  for temp in lane_temps:
    emitter.EmitVQmovun('s16', registers.Low(temp), temp)
118
119
def GenerateLoadQuantizeStore(emitter,
                              registers,
                              lanes,
                              multiplicative_offset,
                              rounding_offset,
                              shift,
                              alignment):
  """Load unquantized data from lanes, quantize, store final result.

  Args:
    emitter: code emitter.
    registers: register allocator.
    lanes: QntLane objects to process.
    multiplicative_offset: register forwarded to GenerateQuantize.
    rounding_offset: register forwarded to GenerateQuantize.
    shift: register forwarded to GenerateQuantize.
    alignment: alignment passed to the post-incremented output store.
  """
  # One scratch quad per lane; released again at the end of this function.
  lane_temps = [registers.QuadRegister() for unused_lane in lanes]

  # Fill both load registers of every lane and advance the source pointer.
  for lane in lanes:
    first, second = lane.load_1, lane.load_2
    targets = [registers.Low(first),
               registers.High(first),
               registers.Low(second),
               registers.High(second)]
    emitter.EmitVLoadA('1.32',
                       targets,
                       emitter.DereferenceIncrement(lane.source, 64))

  for lane in lanes:
    emitter.EmitPld(lane.source)

  # load_1 narrows into the low half of the temp, load_2 into the high half.
  quantize_setup = []
  for (lane_temp, lane) in zip(lane_temps, lanes):
    quantize_setup.extend([
        [lane.load_1, lane.offset, registers.Low(lane_temp)],
        [lane.load_2, lane.offset, registers.High(lane_temp)],
    ])

  GenerateQuantize(emitter,
                   registers,
                   quantize_setup,
                   lane_temps,
                   multiplicative_offset,
                   rounding_offset,
                   shift)

  # The final result sits in the low half of each temp.
  for (lane_temp, lane) in zip(lane_temps, lanes):
    result = registers.Low(lane_temp)
    emitter.EmitVStore('1.8',
                       result,
                       emitter.DereferenceIncrement(lane.output, alignment))

  for lane_temp in lane_temps:
    registers.FreeRegister(lane_temp)
163
164
def GenerateLoadLeftovers(emitter, registers, leftovers, lanes):
  """Handle non multiply of 8 leftover loading.

  Values fill the register halves in the order low(load_1), high(load_1),
  low(load_2), high(load_2). Every pair of leftover values is loaded into a
  whole half; an odd trailing value is loaded into lane 0 of the next half.
  The source pointer is post-incremented only when a trailing value follows
  the bulk load.

  Args:
    emitter: code emitter.
    registers: register allocator (used to address register halves).
    leftovers: number of values to load; must be between 1 and 7.
    lanes: QntLane objects to load for.

  Raises:
    ConfigurationError: if `leftovers` is not between 1 and 7.
  """
  if leftovers not in (1, 2, 3, 4, 5, 6, 7):
    raise ConfigurationError('Unsuported leftover count: %d' % leftovers)

  # Accessors for the destination halves, in fill order.
  halves = (
      lambda lane: registers.Low(lane.load_1),
      lambda lane: registers.High(lane.load_1),
      lambda lane: registers.Low(lane.load_2),
      lambda lane: registers.High(lane.load_2),
  )
  whole, trailing = divmod(int(leftovers), 2)

  if whole:
    # Advance the pointer only if a single trailing value must be read next.
    if trailing:
      address = emitter.DereferenceIncrement
    else:
      address = emitter.Dereference
    for lane in lanes:
      filled = [half(lane) for half in halves[:whole]]
      if whole == 1:
        # A single half is loaded with the plain (non list) load form.
        emitter.EmitVLoad('1.32', filled[0], address(lane.source, 64))
      else:
        emitter.EmitVLoadA('1.32', filled, address(lane.source, 64))

  if trailing:
    for lane in lanes:
      emitter.EmitVLoad('1.32',
                        emitter.Lane(halves[whole](lane), 0),
                        emitter.Dereference(lane.source, None))
222
223
def GenerateStoreLeftovers(emitter, registers, leftovers, lane_temps, lanes):
  """Handle non multiply of 8 leftover storing.

  The result of every lane is taken from the low half of its temp register
  and written with one to three lane stores of decreasing width.

  Args:
    emitter: code emitter.
    registers: register allocator (used to address the low half of temps).
    leftovers: number of bytes to store per lane; must be between 1 and 7.
    lane_temps: quad registers holding the results, one per lane.
    lanes: QntLane objects providing the output pointers.

  Raises:
    ConfigurationError: if `leftovers` is not between 1 and 7.
  """
  targets = [(registers.Low(temp), lane.output)
             for (temp, lane) in zip(lane_temps, lanes)]

  keep = emitter.Dereference
  advance = emitter.DereferenceIncrement
  # leftovers -> ordered (store size, lane index, addressing) steps.
  # NOTE(review): 7 is the only count whose final store also post-increments
  # the output pointer; kept as is to leave the emitted code unchanged.
  plans = {
      1: [('1.8', 0, keep)],
      2: [('1.16', 0, keep)],
      3: [('1.16', 0, advance), ('1.8', 2, keep)],
      4: [('1.32', 0, keep)],
      5: [('1.32', 0, advance), ('1.8', 4, keep)],
      6: [('1.32', 0, advance), ('1.16', 2, keep)],
      7: [('1.32', 0, advance), ('1.16', 2, advance), ('1.8', 6, advance)],
  }

  steps = plans.get(leftovers)
  if steps is None:
    raise ConfigurationError('Unsupported leftovers count: %d' % leftovers)

  # Each step is emitted for all lanes before the next step starts.
  for (size, index, address) in steps:
    for (result, output) in targets:
      emitter.EmitVStore(size, emitter.Lane(result, index),
                         address(output, None))
275
276
def GenerateLeftoverLoadQuantizeStore(emitter,
                                      registers,
                                      leftovers,
                                      lanes,
                                      multiplicative_offset,
                                      rounding_offset,
                                      shift):
  """Handle leftovers if row size not a multiply of 8.

  Loads the `leftovers` remaining values of every lane, quantizes them and
  stores the result. The scratch registers allocated here are released
  before returning, mirroring GenerateLoadQuantizeStore.

  Args:
    emitter: code emitter.
    registers: register allocator.
    leftovers: number of remaining values per lane (1 to 7).
    lanes: QntLane objects to process.
    multiplicative_offset: register forwarded to GenerateQuantize.
    rounding_offset: register forwarded to GenerateQuantize.
    shift: register forwarded to GenerateQuantize.

  Raises:
    ConfigurationError: propagated from the leftover load/store helpers when
      `leftovers` is outside 1 to 7.
  """
  # One scratch quad per lane to hold the narrowed result.
  lane_temps = [registers.QuadRegister() for unused_lane in lanes]

  GenerateLoadLeftovers(emitter, registers, leftovers, lanes)

  quantize_setup = []
  for (lane_temp, lane) in zip(lane_temps, lanes):
    quantize_setup.append([lane.load_1, lane.offset, registers.Low(lane_temp)])
    # load_2 only receives data when more than four values are left.
    if leftovers > 4:
      quantize_setup.append(
          [lane.load_2, lane.offset, registers.High(lane_temp)])

  GenerateQuantize(emitter,
                   registers,
                   quantize_setup,
                   lane_temps,
                   multiplicative_offset,
                   rounding_offset,
                   shift)

  GenerateStoreLeftovers(emitter, registers, leftovers, lane_temps, lanes)

  # Hand the scratch registers back to the allocator; previously they were
  # left allocated, unlike the temps of the main loop body.
  for lane_temp in lane_temps:
    registers.FreeRegister(lane_temp)
307
308
def GenerateQntNx8(emitter, qnt_lanes, leftovers, aligned):
  """Emits optimized quantization code for given lanes and row size.

  The emitted function consists of a main loop (numerical label 1) that
  handles 8 values per lane and iteration, followed, when `leftovers` is non
  zero, by a tail (numerical label 2) handling the remaining values.

  Args:
    emitter: code emitter.
    qnt_lanes: number of lanes processed together; 1, 2 or 3.
    leftovers: count % 8 the function is specialized for; 0 to 7.
    aligned: when truthy, asserts on the destination alignment are emitted
      and the main loop stores use an alignment of 64.

  Raises:
    ConfigurationError: if `leftovers` or `qnt_lanes` is out of range.
  """
  if leftovers < 0 or leftovers > 7:
    raise ConfigurationError('Leftovers should be between 0 and 7 inclusive.')
  if qnt_lanes < 1 or qnt_lanes > 3:
    raise ConfigurationError('Qnt_lanes should should be 1, 2 or 3.')

  signature = [
      ['const std::int32_t*', 'source'],
      ['std::int32_t', 'count'],
      ['std::int32_t', 'stride'],
      ['const std::int32_t*', 'offsets'],
      ['std::uint8_t*', 'destination'],
      ['std::int32_t', 'destination_stride'],
      ['std::int32_t', 'multiplicative_offset'],
      ['std::int32_t', 'rounding_offset'],
      ['std::int32_t', 'shift'],
  ]
  emitter.EmitFunctionBeginA(
      BuildName(qnt_lanes, leftovers, aligned), signature, 'void')

  # Preconditions checked by the generated function.
  assertions = [
      'count %% 8 == %d' % leftovers,
      'count >= 8',
      'reinterpret_cast<std::uintptr_t>(source) % 8 == 0',
  ]
  if aligned:
    assertions.append('reinterpret_cast<std::uintptr_t>(destination) % 8 == 0')
    if qnt_lanes > 1:
      assertions.append('destination_stride % 8 == 0')
  for assertion in assertions:
    emitter.EmitAssert(assertion)

  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()

  count = registers.MapParameter('count')

  # Each scalar parameter is mapped and then duplicated into a quad register,
  # one parameter at a time.
  duplicated = [
      DuplicateRegister(emitter, registers, registers.MapParameter(parameter))
      for parameter in ('multiplicative_offset', 'rounding_offset', 'shift')
  ]
  multiplicative_offset, rounding_offset, shift = duplicated

  source = registers.MapParameter('source')
  stride = registers.MapParameter('stride')
  destination = registers.MapParameter('destination')
  destination_stride = registers.MapParameter('destination_stride')
  offsets = registers.MapParameter('offsets')
  lanes = GenerateQntLanes(emitter, registers, qnt_lanes, source, stride,
                           destination, destination_stride, offsets)

  if leftovers:
    # Take the leftovers out of count up front; the emitted branch targets
    # label 2 (presumably taken when nothing remains for the main loop).
    emitter.EmitSubs(count, count, emitter.ImmediateConstant(leftovers))
    emitter.EmitBeqFront(2)

  # Main loop: 8 values per lane and iteration.
  emitter.EmitNewline()
  emitter.EmitNumericalLabel(1)
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))

  store_alignment = 64 if aligned else None
  GenerateLoadQuantizeStore(emitter, registers, lanes, multiplicative_offset,
                            rounding_offset, shift, store_alignment)

  emitter.EmitNewline()
  emitter.EmitBneBack(1)

  if leftovers:
    # Tail: the remaining `leftovers` values of every lane.
    emitter.EmitNumericalLabel(2)
    GenerateLeftoverLoadQuantizeStore(emitter, registers, leftovers, lanes,
                                      multiplicative_offset, rounding_offset,
                                      shift)

  clobbers = registers.Clobbers() + ['cc', 'memory']
  emitter.EmitAsmEnd(registers.MappedParameters(), [], clobbers)
  emitter.EmitFunctionEnd()
390
391
def GenerateFunctions(emitter):
  """Emit every quantization variant, each followed by a newline.

  Variants are generated aligned first, then unaligned; within each, for 1
  to 3 lanes and, per lane count, for 0 to 7 leftovers.

  Args:
    emitter: code emitter receiving the generated functions.
  """
  for aligned in (True, False):
    for lane_count in (1, 2, 3):
      for leftover_count in range(8):
        GenerateQntNx8(emitter, lane_count, leftover_count, aligned)
        emitter.EmitNewline()
398