1"""Mul primitive used by the GEMM function.
2
3The Mul primitive takes 1-3 zipped rows and 1-3 zipped columns and performs
4matrix multiplication on those resulting in a small 1x1 to 3x3 block of results.
5"""
6
7import neon_emitter
8
9
class Error(Exception):
  """Base class for the exceptions defined in this module."""
12
13
class ConfigurationError(Error):
  """Raised when a generator is asked for a configuration it cannot emit."""
16
17
class MulLanes(object):
  """A list of lane registers paired with the address they are loaded from.

  `input_address` is stored exactly as given to the constructor; `lanes` holds
  the registers in the order they were added with AddLane.
  """

  def __init__(self, input_address):
    """Start with an empty lane list for the given input address."""
    self.lanes = []
    self.input_address = input_address

  def AddLane(self, lane):
    """Append one register to the end of the lane list."""
    self.lanes.append(lane)

  def FreeRegisters(self, registers):
    """Return every lane to `registers` and replace its slot with None.

    The list keeps its length, so len(self.lanes) still reports the lane count
    after the registers have been released.
    """
    for slot, lane in enumerate(self.lanes):
      registers.FreeRegister(lane)
      self.lanes[slot] = None
31
32
def GenerateMulLanes(registers, lane_count, address):
  """Build a MulLanes for `address` holding `lane_count` new double registers.

  Each lane comes from a separate registers.DoubleRegister() call, in order.
  """
  group = MulLanes(address)
  for _ in range(lane_count):
    fresh_lane = registers.DoubleRegister()
    group.AddLane(fresh_lane)
  return group
38
39
def Generate3MulLanes(quad_register, registers, address):
  """Build a three-lane MulLanes that reuses both halves of `quad_register`.

  Lanes 0 and 1 are registers.Low / registers.High of `quad_register`; lane 2 is
  a newly allocated double register.
  """
  group = MulLanes(address)
  low_half = registers.Low(quad_register)
  group.AddLane(low_half)
  high_half = registers.High(quad_register)
  group.AddLane(high_half)
  extra_lane = registers.DoubleRegister()
  group.AddLane(extra_lane)
  return group
46
47
def GenerateAndClearAggregators(emitter, registers, aggregator_count):
  """Allocate quad-register aggregators and emit the code that zeroes them.

  The first three aggregators are cleared with an immediate zero. Every later
  one is initialised by copying the aggregator allocated three positions
  earlier, which has already been cleared at that point.

  Returns:
    The list of allocated aggregator registers, in allocation order.
  """
  emitter.EmitComment('Clear aggregators.')
  cleared = []
  for position in range(aggregator_count):
    current = registers.QuadRegister()
    cleared.append(current)
    if position >= 3:
      source = cleared[position - 3]
    else:
      source = emitter.ImmediateConstant(0)
    emitter.EmitVMov('i32', current, source)
  emitter.EmitNewline()
  return cleared
61
62
def GenerateNxMLoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count):
  """Emit the generic inner loop multiplying N left lanes by M right lanes.

  The emitted loop body subtracts 8 from `count`, loads both lane groups,
  issues prefetches, multiplies every (left, right) lane pair into a scratch
  quad register and accumulates scratch register i into aggregators[i]. Pairs
  are numbered row-major: index = row * cols + col. The loop ends with a
  backward branch to numerical label 1.

  The scratch registers are allocated here and handed back to `registers`
  once the loop has been emitted.
  """
  emitter.EmitComment('General NxM lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
  emitter.EmitNewline()

  operands = (left_lanes, right_lanes)
  for operand in operands:
    emitter.EmitVLoadA(
        '1.8', operand.lanes,
        emitter.DereferenceIncrement(operand.input_address, 64))

  for operand in operands:
    emitter.EmitPldOffset(operand.input_address,
                          emitter.ImmediateConstant(64))

  left_regs = left_lanes.lanes
  right_regs = right_lanes.lanes
  pair_count = len(left_regs) * len(right_regs)

  products = [registers.QuadRegister() for _ in range(pair_count)]

  # Row-major walk over (left, right) pairs; note the right lane is passed to
  # EmitVMull before the left lane.
  lane_pairs = [(left_reg, right_reg)
                for left_reg in left_regs
                for right_reg in right_regs]
  for product, (left_reg, right_reg) in zip(products, lane_pairs):
    emitter.EmitVMull('u8', product, right_reg, left_reg)

  for slot, product in enumerate(products):
    emitter.EmitVPadal('u16', aggregators[slot], product)

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBneBack(1)
  emitter.EmitNewline()

  for product in products:
    registers.FreeRegister(product)
106
107
def Generate3x3LoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count,
                                     backup_register):
  """Emit the inner loop for the 3 rows x 3 cols case (register trick).

  Only four scratch quad registers are allocated for the nine products. The
  products are emitted in two waves: the first four are accumulated into
  aggregators[0..3], then the scratch registers are reused for the next four
  and `backup_register` receives the ninth product, feeding aggregators[4..8].

  The scratch registers are released at the end; `backup_register` is not.
  """
  emitter.EmitComment('3x3 lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
  emitter.EmitNewline()

  operands = (left_lanes, right_lanes)
  for operand in operands:
    emitter.EmitVLoadA(
        '1.8', operand.lanes,
        emitter.DereferenceIncrement(operand.input_address, 64))

  for operand in operands:
    emitter.EmitPldOffset(operand.input_address,
                          emitter.ImmediateConstant(64))

  scratch = [registers.QuadRegister() for _ in range(4)]

  # (left lane index, right lane index) for each product, in emission order.
  first_wave = ((0, 0), (0, 1), (0, 2), (1, 0))
  second_wave = ((1, 1), (1, 2), (2, 0), (2, 1), (2, 2))

  for target, (row, col) in zip(scratch, first_wave):
    emitter.EmitVMull('u8', target, left_lanes.lanes[row],
                      right_lanes.lanes[col])
  for slot, target in enumerate(scratch):
    emitter.EmitVPadal('u16', aggregators[slot], target)

  # backup_register is written last. In this file the caller builds left
  # lanes 0 and 1 from its halves (see Generate3MulLanes), and those lanes are
  # not read again after this point in the loop body.
  second_targets = scratch + [backup_register]
  for target, (row, col) in zip(second_targets, second_wave):
    emitter.EmitVMull('u8', target, left_lanes.lanes[row],
                      right_lanes.lanes[col])
  for slot, target in enumerate(second_targets, 4):
    emitter.EmitVPadal('u16', aggregators[slot], target)

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBneBack(1)
  emitter.EmitNewline()

  for target in scratch:
    registers.FreeRegister(target)
163
164
def ReadParams(emitter, registers, input_address, elements, min_reg):
  """Allocate a register and emit a '1.32' load of it from `input_address`.

  One or two elements use a double register requested with a minimum of
  `min_reg * 2`; three elements use a quad register requested with a minimum
  of `min_reg`.

  Returns:
    The register that receives the loaded values.

  Raises:
    ConfigurationError: if `elements` is not 1, 2 or 3.
  """
  if elements == 3:
    target = registers.QuadRegister(min_reg)
  elif elements in (1, 2):
    target = registers.DoubleRegister(min_reg * 2)
  else:
    raise ConfigurationError('Unsupported elements no: %d' % elements)

  source = emitter.Dereference(input_address, 64)
  emitter.EmitVLoad('1.32', target, source)
  return target
176
177
def Duplicate(emitter, registers, rows, cols, min_register, values):
  """Allocate one register per row and emit a VDup of that row's value.

  Register width depends on `cols`: double registers for 1 or 2 columns, quad
  registers for 3. The lane that gets duplicated depends on `rows`: for 1 or 2
  rows the lanes of `values` are used directly; for 3 rows the lanes come from
  the low half (rows 0 and 1) and the high half (row 2) of `values`. For any
  other row count the registers are allocated but nothing is emitted.

  Returns:
    The list of allocated registers, one per row.

  Raises:
    ConfigurationError: if `cols` is not 1, 2 or 3.
  """
  if cols == 3:
    allocate = registers.QuadRegister
  elif cols in (1, 2):
    allocate = registers.DoubleRegister
  else:
    raise ConfigurationError('Unsupported duplicate amount: %d' % cols)

  copies = [allocate(min_register) for _ in range(rows)]

  if rows == 3:
    # (half selector, lane index) feeding each of the three rows.
    sources = ((registers.Low, 0), (registers.Low, 1), (registers.High, 0))
    for copy, (half, lane) in zip(copies, sources):
      emitter.EmitVDup('32', copy, emitter.Lane(half(values), lane))
  elif rows in (1, 2):
    for lane, copy in enumerate(copies):
      emitter.EmitVDup('32', copy, emitter.Lane(values, lane))

  return copies
204
205
def DuplicateGeneralRegister(emitter, registers, cols, general_register,
                             min_register):
  """Allocate a vector register and emit a VDup of `general_register` into it.

  A double register is used for 1 or 2 columns and a quad register for 3.

  Returns:
    The newly allocated register.

  Raises:
    ConfigurationError: if `cols` is not 1, 2 or 3.
  """
  if cols == 3:
    target = registers.QuadRegister(min_register)
  elif cols in (1, 2):
    target = registers.DoubleRegister(min_register)
  else:
    raise ConfigurationError('Unsupported duplicate amount: %d' % cols)

  emitter.EmitVDup('32', target, general_register)
  return target
217
218
def ReduceAggregator(emitter, registers, aggregators, row, cols):
  """Emit pairwise adds that pack one row's aggregators into one register.

  The row's aggregators are aggregators[row * cols] onwards. Only their low
  halves are read; the result is written over the row's first aggregator.

  Returns:
    For 1 or 2 columns, the low half of the row's first aggregator; for 3
    columns, the row's whole first aggregator (low and high half written).

  Raises:
    ConfigurationError: if `cols` is not 1, 2 or 3.
  """
  if cols == 1:
    only = registers.Low(aggregators[row])
    emitter.EmitVPadd('u32', only, only, only)
    return only

  if cols == 2:
    first = row * 2
    packed = registers.Low(aggregators[first])
    emitter.EmitVPadd('u32', packed, packed,
                      registers.Low(aggregators[first + 1]))
    return packed

  if cols == 3:
    first = row * 3
    packed = aggregators[first]
    # Columns 0 and 1 land in the low half of the destination.
    emitter.EmitVPadd('u32', registers.Low(packed), registers.Low(packed),
                      registers.Low(aggregators[first + 1]))
    # Column 2 is added with itself and lands in the high half.
    emitter.EmitVPadd('u32', registers.High(packed),
                      registers.Low(aggregators[first + 2]),
                      registers.Low(aggregators[first + 2]))
    return packed

  raise ConfigurationError('Unsupported columns no: %d' % cols)
239
240
def StoreAggregator(emitter, registers, aggregator, cols, result_address,
                    result_stride):
  """Emit the store(s) that write one reduced row to `result_address`.

  One column stores lane 0 of `aggregator`; two columns store the whole
  register. Three columns take two stores: the low half through
  DereferenceIncrement, then lane 0 of the high half. In every case the last
  store is an EmitVStoreOffset that is passed `result_stride`.

  Raises:
    ConfigurationError: if `cols` is not 1, 2 or 3.
  """
  if cols == 1:
    payload = emitter.Lane(aggregator, 0)
  elif cols == 2:
    payload = aggregator
  elif cols == 3:
    emitter.EmitVStore('1.32', registers.Low(aggregator),
                       emitter.DereferenceIncrement(result_address, None))
    payload = emitter.Lane(registers.High(aggregator), 0)
  else:
    raise ConfigurationError('Unsupported columns no: %d' % cols)

  emitter.EmitVStoreOffset('1.32', payload,
                           emitter.Dereference(result_address, None),
                           result_stride)
  if cols == 3:
    emitter.EmitNewline()
260
261
def GenerateAggregatorReduceStore(emitter, registers, aggregators, result_type,
                                  lhs_add, rhs_add, left_lanes, right_lanes,
                                  results, results_stride):
  """Emit code that reduces 4 lane aggregators to 1 value, and stores them.

  Steps emitted, in order:
    1. Optional parameter loads: per-row lhs offsets (read from the left input
       address and duplicated per row), the rhs offset vector (read from the
       right input address) and, for 'float' results, the duplicated
       'result_scale' parameter.
    2. For 3 columns, the result stride is reduced by 8 because each row is
       stored with two instructions.
    3. Every aggregator is pairwise-added (low half += high half).
    4. Each row's aggregators are packed into one register.
    5. Optional lhs / rhs offset additions.
    6. For 'float' results, conversion to f32 and multiplication by the scale.
    7. One store sequence per row.

  Args:
    emitter: object providing the Emit* methods and operand helpers.
    registers: register allocator / parameter mapper.
    aggregators: aggregator registers, row-major (rows * cols entries).
    result_type: 'float' enables conversion and scaling; any other value
      stores the integer sums as they are.
    lhs_add: truthy to add per-row offsets read from the lhs input address.
    rhs_add: truthy to add the offset vector read from the rhs input address.
    left_lanes: lane group whose lane count gives the number of rows.
    right_lanes: lane group whose lane count gives the number of columns.
    results: address operand the rows are stored to.
    results_stride: stride operand passed to the row stores.

  Raises:
    ConfigurationError: propagated from the helpers when the row or column
      count is unsupported.
  """
  rows = len(left_lanes.lanes)
  cols = len(right_lanes.lanes)

  # Compare by value: an identity check against a string literal only works
  # when the interpreter happens to intern both strings.
  is_float = result_type == 'float'

  if lhs_add:
    left_offset = ReadParams(emitter, registers, left_lanes.input_address, rows,
                             4)
    left_offsets = Duplicate(emitter, registers, rows, cols, 4, left_offset)
  else:
    left_offsets = None

  if rhs_add:
    right_offset = ReadParams(emitter, registers, right_lanes.input_address,
                              cols, 4)
  else:
    right_offset = None

  if is_float:
    result_scale = DuplicateGeneralRegister(
        emitter, registers, cols, registers.MapParameter('result_scale'), 4)
  else:
    result_scale = None

  if cols == 3:
    emitter.EmitNewline()
    emitter.EmitComment('Change stride because storing in two ops.')
    emitter.EmitSub(results_stride, results_stride,
                    emitter.ImmediateConstant(8))

  emitter.EmitNewline()
  emitter.EmitComment('Horizontal reduce aggregators.')
  for aggregator in aggregators:
    emitter.EmitVPadd('u32', registers.Low(aggregator),
                      registers.Low(aggregator), registers.High(aggregator))

  emitter.EmitNewline()
  emitter.EmitComment('Reduce rows.')
  row_temps = [
      ReduceAggregator(emitter, registers, aggregators, row, cols)
      for row in range(rows)
  ]

  if lhs_add:
    emitter.EmitNewline()
    emitter.EmitComment('Add lhs offsets to aggregated rows.')
    for row_temp, row_offset in zip(row_temps, left_offsets):
      emitter.EmitVAdd('s32', row_temp, row_temp, row_offset)

  if rhs_add:
    emitter.EmitNewline()
    emitter.EmitComment('Add rhs offset to aggregated rows.')
    for row_temp in row_temps:
      emitter.EmitVAdd('s32', row_temp, row_temp, right_offset)

  if is_float:
    emitter.EmitNewline()
    emitter.EmitComment('Convert to float. Multiply by result scale.')
    # All conversions are emitted before any multiplication.
    for row_temp in row_temps:
      emitter.EmitVCvt('f32', 's32', row_temp, row_temp)
    for row_temp in row_temps:
      emitter.EmitVMul('f32', row_temp, row_temp, result_scale)

  emitter.EmitNewline()
  emitter.EmitComment('Store reduced rows.')
  for row_temp in row_temps:
    StoreAggregator(emitter, registers, row_temp, cols, results, results_stride)
330
331
def BuildName(result_type, lhs_add, rhs_add, left, right):
  """Compose the generated function name for one configuration.

  The name is 'mul_<left>x8_<right>x8_<result_type>', followed by '_lhsadd'
  and/or '_rhsadd' when the corresponding flag is truthy.
  """
  pieces = ['mul_%dx8_%dx8_%s' % (left, right, result_type)]
  if lhs_add:
    pieces.append('_lhsadd')
  if rhs_add:
    pieces.append('_rhsadd')
  return ''.join(pieces)
339
340
def CppResultType(result_type):
  """Map a result type name to the C++ pointer type of the result buffer.

  Args:
    result_type: 'int32' or 'float'.

  Returns:
    'std::int32_t*' for 'int32', 'float*' for 'float'.

  Raises:
    ConfigurationError: for any other result type.
  """
  # Compare by value; identity comparison against a string literal only works
  # when both strings happen to be interned, and fails for names built at
  # runtime (e.g. parsed from a command line).
  if result_type == 'int32':
    return 'std::int32_t*'
  elif result_type == 'float':
    return 'float*'
  else:
    raise ConfigurationError('Unsupported result type: %s' % result_type)
348
349
def GetParameters(result_type):
  """Return the [c++ type, name] parameter list of a generated function.

  The list always holds lhs, rhs, count, result and result_stride, in that
  order; 'float' results get an extra trailing 'result_scale' parameter.

  Raises:
    ConfigurationError: propagated from CppResultType for an unsupported
      result type.
  """
  params = [['const std::uint8_t*', 'lhs'], ['const std::uint8_t*', 'rhs'],
            ['std::int32_t', 'count'], [CppResultType(result_type), 'result'],
            ['std::int32_t', 'result_stride']]
  # Compare by value; identity comparison against a string literal depends on
  # string interning.
  if result_type == 'float':
    params.append(['float', 'result_scale'])
  return params
357
358
def GenerateMulNx8Mx8(emitter, result_type, lhs_add, rhs_add, left_lanes_count,
                      right_lanes_count):
  """Emit one complete multiply function for the given lane counts.

  The emitted function asserts that `count` is a multiple of 8 and at least 8,
  then contains an asm block with the aggregator setup, the inner loop and the
  reduce / store epilogue. Configurations with fewer than nine products use
  the general NxM loop; the 3x3 case uses the dedicated loop, with the left
  lanes built on top of an extra quad register.

  Raises:
    ConfigurationError: if either lane count is outside 1..3.
  """
  if left_lanes_count < 1 or left_lanes_count > 3:
    raise ConfigurationError('Left_lanes should be: 1, 2 or 3.')
  if right_lanes_count < 1 or right_lanes_count > 3:
    raise ConfigurationError('Right_lanes should be: 1, 2 or 3.')

  function_name = BuildName(result_type, lhs_add, rhs_add, left_lanes_count,
                            right_lanes_count)
  emitter.EmitFunctionBeginA(function_name, GetParameters(result_type),
                             'inline void')

  for condition in ('count % 8 == 0', 'count >= 8'):
    emitter.EmitAssert(condition)
  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()
  count = registers.MapParameter('count')

  size = left_lanes_count * right_lanes_count
  is_3x3 = size >= 9  # Only reachable with left == 3 and right == 3.

  aggregators = GenerateAndClearAggregators(emitter, registers, size)

  if is_3x3:
    backup_register = registers.QuadRegister()
    left_lanes = Generate3MulLanes(backup_register, registers,
                                   registers.MapParameter('lhs'))
  else:
    backup_register = None
    left_lanes = GenerateMulLanes(registers, left_lanes_count,
                                  registers.MapParameter('lhs'))
  right_lanes = GenerateMulLanes(registers, right_lanes_count,
                                 registers.MapParameter('rhs'))

  for lanes in (left_lanes, right_lanes):
    emitter.EmitPld(lanes.input_address)

  if is_3x3:
    Generate3x3LoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count,
                                     backup_register)
  else:
    GenerateNxMLoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count)

  # The lane registers are released here; the lane lists keep their length,
  # which the epilogue uses as the row and column counts.
  left_lanes.FreeRegisters(registers)
  right_lanes.FreeRegisters(registers)

  result = registers.MapParameter('result')
  result_stride = registers.MapParameter('result_stride')
  GenerateAggregatorReduceStore(emitter, registers, aggregators, result_type,
                                lhs_add, rhs_add, left_lanes, right_lanes,
                                result, result_stride)

  mapped = registers.MappedParameters()
  clobbers = registers.Clobbers() + ['cc', 'memory']
  emitter.EmitAsmEnd(mapped, [], clobbers)
  emitter.EmitFunctionEnd()
421
422
def GenerateFunctions(emitter, result_type, lhs_add, rhs_add):
  """Emit all nine multiply functions, for 1-3 left lanes x 1-3 right lanes.

  Functions are emitted with the left lane count as the outer loop, and each
  one is followed by a newline.
  """
  lane_counts = (1, 2, 3)
  for left_count in lane_counts:
    for right_count in lane_counts:
      GenerateMulNx8Mx8(emitter, result_type, lhs_add, rhs_add, left_count,
                        right_count)
      emitter.EmitNewline()
429