1"""Generates the specialized gemv functions."""
2
3import mul_1x8_Mx8_neon
4import mul_Nx8_Mx8_neon
5import qnt_Nx8_neon
6import zip_Nx8_neon
7
# Identifiers of the gemv output types handled by this generator. The
# generator functions in this module branch on these values.
_QUANTIZED_8BIT = 'quantized_8bit'  # Result parameter is std::uint8_t*.
_FULL_32BIT = 'full_32bit'  # Result parameter is std::int32_t*.
_FULL_FLOAT = 'full_float'  # Result parameter is float*, plus result_scale.
11
12
class Error(Exception):
  """Base class for the exceptions defined in this module."""
15
16
class ConfigurationError(Error):
  """Signals an invalid generator configuration, e.g. an unknown output type."""
19
20
def GenerateCommonTempsCountersAndConsts(emitter):
  """Emits the local declarations shared by every gemv variant.

  Each entry is forwarded to `emitter.EmitDeclare` as (type, name,
  initializer) in the listed order, followed by one `EmitNewline`.

  Args:
    emitter: code emitter; only EmitDeclare and EmitNewline are used here.
  """
  shared_declarations = (
      ('const std::int32_t', 'col_chunks', 'n / 8'),
      ('const std::int32_t', 'padded_k', '((k + 7) / 8) * 8'),
      ('const std::int32_t', 'chunk_size', 'k * 4'),
      ('const std::int32_t', 'zipped_chunk_size', '(padded_k + 16) * 4'),
      ('const std::uint8_t*', 'rhs_chunk', 'rhs'),
      ('std::uint8_t*', 'zipped_lhs', 'scratch'),
      ('std::int32_t*', 'zipped_lhs_offsets',
       'reinterpret_cast<std::int32_t*>(zipped_lhs + padded_k)'),
      ('std::uint8_t*', 'zipped_rhs_1', 'scratch + padded_k + 16'),
      ('std::uint8_t*', 'zipped_rhs_2', 'zipped_rhs_1 + zipped_chunk_size'),
  )
  for cpp_type, variable, initializer in shared_declarations:
    emitter.EmitDeclare(cpp_type, variable, initializer)
  emitter.EmitNewline()
37
38
def GenerateQuantized8BitTempsCountersAndConsts(emitter):
  """Emits the local declarations used by the quantized 8-bit gemv body.

  The shared declarations are emitted first, then the declarations specific
  to the quantized variant (offsets and the int32 temporary result pointers),
  followed by one `EmitNewline`.

  Args:
    emitter: code emitter receiving the EmitDeclare/EmitNewline calls.
  """
  GenerateCommonTempsCountersAndConsts(emitter)
  quantized_declarations = (
      ('const std::int32_t', 'const_offset',
       'lhs_offset * rhs_offset * k + result_offset'),
      ('const std::int32_t', 'rounding_offset', '(1 << (shift - 1))'),
      ('std::int32_t*', 'temp_result',
       'reinterpret_cast<std::int32_t*>(zipped_rhs_2 + zipped_chunk_size)'),
      ('std::int32_t*', 'mul_result_chunk', 'temp_result'),
  )
  for declaration in quantized_declarations:
    emitter.EmitDeclare(*declaration)
  emitter.EmitNewline()
51
52
def GenerateFullTempsCountersAndConsts(emitter, result_type):
  """Emits the local declarations used by the int32 and float gemv bodies.

  Args:
    emitter: code emitter receiving the EmitDeclare/EmitNewline calls.
    result_type: C++ type string used to declare `mul_result_chunk`, which is
      initialised from `result`.
  """
  GenerateCommonTempsCountersAndConsts(emitter)
  offset_expression = 'lhs_offset * rhs_offset * k'
  emitter.EmitDeclare('const std::int32_t', 'const_offset', offset_expression)
  emitter.EmitDeclare(result_type, 'mul_result_chunk', 'result')
  emitter.EmitNewline()
60
61
def GenerateZipVector(emitter, aligned, leftovers):
  """Emits the zip_Nx8_neon call that takes `lhs` and `zipped_lhs`.

  Args:
    emitter: code emitter receiving the EmitCall.
    aligned: forwarded to zip_Nx8_neon.BuildName to pick the function name.
    leftovers: forwarded to zip_Nx8_neon.BuildName to pick the function name.
  """
  zip_function = zip_Nx8_neon.BuildName(1, leftovers, aligned)
  zip_arguments = ['lhs', 'k', 'k', 'zipped_lhs', 'rhs_offset', 0]
  emitter.EmitCall(zip_function, zip_arguments)
66
67
def GetMul2Params(result_type):
  """Returns the argument list for a multiply call taking two zipped rhs buffers.

  Args:
    result_type: result type name; 'float' adds the `result_scale` argument.

  Returns:
    A new list of argument expressions.
  """
  params = ['zipped_lhs', 'zipped_rhs_1', 'zipped_rhs_2', 'padded_k',
            'mul_result_chunk']
  # Compare by value: `is` only matched when both strings happened to be the
  # same interned object.
  if result_type == 'float':
    params.append('result_scale')
  return params
74
75
def GetMulParams(result_type):
  """Returns the argument list for a multiply call taking one zipped rhs buffer.

  Args:
    result_type: result type name; 'float' adds the `result_scale` argument.

  Returns:
    A new list of argument expressions (the trailing 0 is passed through
    as-is to the emitted call).
  """
  params = ['zipped_lhs', 'zipped_rhs_1', 'padded_k', 'mul_result_chunk', 0]
  # Compare by value: `is` only matched when both strings happened to be the
  # same interned object.
  if result_type == 'float':
    params.append('result_scale')
  return params
81
82
def GenerateMulCols(emitter, result_type, lhs_add, rhs_add, aligned, cols,
                    leftovers):
  """Emits the rhs zip and multiply calls: a chunk loop plus a leftover tail.

  The emitted `for` loop over `col_chunks` zips two groups of 4 into
  `zipped_rhs_1` / `zipped_rhs_2` and multiplies them with the 8-wide
  mul_1x8_Mx8_neon function. After the loop, the remaining `cols` are handled
  by a second sequence of calls (nothing is emitted when `cols` is 0).

  Args:
    emitter: code emitter receiving the calls.
    result_type: result type name forwarded to the multiply name builders.
    lhs_add: forwarded to the multiply name builders.
    rhs_add: forwarded to the multiply name builders.
    aligned: forwarded to zip_Nx8_neon.BuildName.
    cols: number of leftover columns handled after the loop.
    leftovers: forwarded to zip_Nx8_neon.BuildName.
  """

  def ZipRhs(lanes, destination):
    # One zip call reading from rhs_chunk into the given zipped buffer.
    emitter.EmitCall(
        zip_Nx8_neon.BuildName(lanes, leftovers, aligned),
        ['rhs_chunk', 'k', 'k', destination, 'lhs_offset', 'const_offset'])

  def AdvanceRhs():
    emitter.EmitAssignIncrement('rhs_chunk', 'chunk_size')

  def MulTwoBuffers(width):
    # Multiply consuming both zipped_rhs_1 and zipped_rhs_2.
    emitter.EmitCall(
        mul_1x8_Mx8_neon.BuildName(result_type, lhs_add, rhs_add, width),
        GetMul2Params(result_type))

  # Main loop: full chunks of 8.
  emitter.EmitOpenBracket('for (int i = 0; i < col_chunks; ++i)')
  ZipRhs(4, 'zipped_rhs_1')
  AdvanceRhs()
  ZipRhs(4, 'zipped_rhs_2')
  AdvanceRhs()
  MulTwoBuffers(8)
  emitter.EmitAssignIncrement('mul_result_chunk', 8)
  emitter.EmitCloseBracket()

  # Tail: the leftover columns.
  if cols > 4:
    ZipRhs(4, 'zipped_rhs_1')
    AdvanceRhs()
    ZipRhs(cols - 4, 'zipped_rhs_2')
    MulTwoBuffers(cols)
  elif cols > 0:
    ZipRhs(cols, 'zipped_rhs_1')
    emitter.EmitCall(
        mul_Nx8_Mx8_neon.BuildName(result_type, lhs_add, rhs_add, 1, cols),
        GetMulParams(result_type))
125
126
def GenerateQuantized8BitMul(emitter, aligned, cols, leftovers):
  """Emits the int32 multiply calls, then the quantization call.

  Args:
    emitter: code emitter receiving the calls.
    aligned: forwarded to the zip and quantize name builders.
    cols: number of leftover columns.
    leftovers: forwarded to GenerateMulCols.
  """
  GenerateMulCols(emitter, 'int32', lhs_add=False, rhs_add=True,
                  aligned=aligned, cols=cols, leftovers=leftovers)
  quantize_function = qnt_Nx8_neon.BuildName(1, cols, aligned)
  quantize_arguments = [
      'temp_result', 'n', 0, 'zipped_lhs_offsets', 'result', 0,
      'multiplicative_offset', 'rounding_offset', '-shift'
  ]
  emitter.EmitCall(quantize_function, quantize_arguments)
134
135
def GenerateFullMul(emitter, result_type, aligned, cols, leftovers):
  """Emits the multiply calls for the non-quantized (int32 / float) variants.

  Delegates to GenerateMulCols with both `lhs_add` and `rhs_add` enabled.
  """
  GenerateMulCols(emitter, result_type, lhs_add=True, rhs_add=True,
                  aligned=aligned, cols=cols, leftovers=leftovers)
138
139
def BuildName(output_type, aligned, cols, leftover):
  """Returns the name of one specialized gemv function.

  The name is `<main name>_<cols>_<leftover>`, with `_aligned` appended when
  `aligned` is truthy.
  """
  main_name = BuildMainGemvName(output_type)
  alignment_suffix = '_aligned' if aligned else ''
  return main_name + '_%d_%d' % (cols, leftover) + alignment_suffix
145
146
def GetCommonGemvParameters():
  """Returns the [type, name] parameter pairs shared by all gemv functions.

  A fresh list of fresh two-element lists is built on every call, so callers
  may extend or modify the result.
  """
  shared_parameters = (
      ('std::uint8_t*', 'scratch'),
      ('const std::uint8_t*', 'lhs'),
      ('const std::uint8_t*', 'rhs'),
      ('std::int32_t', 'n'),
      ('std::int32_t', 'k'),
      ('std::int32_t', 'lhs_offset'),
      ('std::int32_t', 'rhs_offset'),
  )
  return [[cpp_type, name] for cpp_type, name in shared_parameters]
152
153
def GetGemvParameters(output_type):
  """Prepares a [type, parameter] array for the gemv functions.

  Args:
    output_type: one of the module's output type identifiers.

  Returns:
    The common gemv parameters followed by the parameters specific to
    `output_type`.

  Raises:
    ConfigurationError: if `output_type` is not a supported output type.
  """
  params = GetCommonGemvParameters()
  # Compare by value: identity (`is`) rejects equal strings that are not the
  # very same object as the module constant.
  if output_type == _QUANTIZED_8BIT:
    params += [['std::int32_t', 'result_offset'],
               ['std::int32_t', 'multiplicative_offset'],
               ['std::int32_t', 'shift'], ['std::uint8_t*', 'result']]
  elif output_type == _FULL_32BIT:
    params += [['std::int32_t*', 'result']]
  elif output_type == _FULL_FLOAT:
    params += [['float', 'result_scale'], ['float*', 'result']]
  else:
    raise ConfigurationError('Unsupported output type: %s' % output_type)
  return params
168
169
def GenerateGemv(emitter, output_type, aligned, cols, leftovers):
  """Builds one gemv function for given column and depth leftovers.

  The emitted function asserts `n % 8 == cols` and `k % 8 == leftovers`, then
  contains the declarations, the lhs zip call and the multiply calls matching
  `output_type`.

  Args:
    emitter: code emitter receiving the generated function.
    output_type: one of the module's output type identifiers.
    aligned: whether to generate the aligned variant.
    cols: expected value of `n % 8` in the emitted function.
    leftovers: expected value of `k % 8` in the emitted function.

  Raises:
    ConfigurationError: if `output_type` is not a supported output type.
  """
  emitter.EmitFunctionBeginA(
      BuildName(output_type, aligned, cols, leftovers),
      GetGemvParameters(output_type), 'void')

  emitter.EmitAssert('n %% 8 == %d' % cols)
  emitter.EmitAssert('k %% 8 == %d' % leftovers)

  # Compare by value: identity (`is`) rejects equal strings that are not the
  # very same object as the module constant.
  if output_type == _QUANTIZED_8BIT:
    GenerateQuantized8BitTempsCountersAndConsts(emitter)
    GenerateZipVector(emitter, aligned, leftovers)
    GenerateQuantized8BitMul(emitter, aligned, cols, leftovers)
  elif output_type == _FULL_32BIT:
    GenerateFullTempsCountersAndConsts(emitter, 'std::int32_t*')
    GenerateZipVector(emitter, aligned, leftovers)
    GenerateFullMul(emitter, 'int32', aligned, cols, leftovers)
  elif output_type == _FULL_FLOAT:
    GenerateFullTempsCountersAndConsts(emitter, 'float*')
    GenerateZipVector(emitter, aligned, leftovers)
    GenerateFullMul(emitter, 'float', aligned, cols, leftovers)
  else:
    raise ConfigurationError('Unknown output type: %s' % output_type)

  emitter.EmitFunctionEnd()
195
196
def GenerateGemvCall(emitter, output_type, aligned, m_mod, leftovers):
  """Emits a call to one specialized gemv function in the `internal` scope.

  The call forwards every parameter of the gemv signature by name.
  """
  scoped_name = emitter.Scope(
      'internal', BuildName(output_type, aligned, m_mod, leftovers))
  argument_names = [
      parameter[1] for parameter in
      (tuple(pair) for pair in GetGemvParameters(output_type))
  ]
  emitter.EmitCall(scoped_name, argument_names)
202
203
def GenerateGemvSwitch2(emitter, output_type, aligned, n_mod):
  """Emits the inner `k % 8` switch, one case per depth leftover (0..7).

  Each case calls the specialized gemv function for (`n_mod`, depth leftover)
  and then breaks.
  """
  emitter.EmitSwitch('k % 8')
  for depth_leftover in range(8):
    emitter.EmitCase(depth_leftover)
    emitter.PushIndent()
    GenerateGemvCall(emitter, output_type, aligned, n_mod, depth_leftover)
    emitter.EmitBreak()
    emitter.PopIndent()
  emitter.EmitSwitchEnd()
216
217
def GenerateGemvSwitch1(emitter, output_type, aligned):
  """Emits the outer `n % 8` switch, one case per column leftover (0..7).

  Each case contains the nested depth-leftover switch and then breaks.
  """
  emitter.EmitSwitch('n % 8')
  for column_leftover in range(8):
    emitter.EmitCase(column_leftover)
    emitter.PushIndent()
    GenerateGemvSwitch2(emitter, output_type, aligned, column_leftover)
    emitter.EmitBreak()
    emitter.PopIndent()
  emitter.EmitSwitchEnd()
230
231
def BuildMainGemvName(output_type):
  """Returns the name of the public gemv function for `output_type`.

  Args:
    output_type: one of the module's output type identifiers.

  Raises:
    ConfigurationError: if `output_type` is not a supported output type.
  """
  # Compare by value: identity (`is`) rejects equal strings that are not the
  # very same object as the module constant.
  if output_type == _QUANTIZED_8BIT:
    return 'gemv_q8'
  elif output_type == _FULL_32BIT:
    return 'gemv_i32'
  elif output_type == _FULL_FLOAT:
    return 'gemv_f'
  else:
    raise ConfigurationError('Unsupported output type: %s' % output_type)
241
242
def GenerateMainGemvFunction(emitter, output_type):
  """Emits the high level gemv function that switches between optimized versions.

  The emitted function computes an `aligned` flag from the lhs/rhs pointers
  and `k` (plus the result pointer for the quantized output type), then
  dispatches to the aligned or unaligned specializations through the
  `n % 8` / `k % 8` switches.

  Args:
    emitter: code emitter receiving the generated function.
    output_type: one of the module's output type identifiers.

  Raises:
    ConfigurationError: if `output_type` is not a supported output type.
  """
  emitter.EmitFunctionBeginA(
      BuildMainGemvName(output_type), GetGemvParameters(output_type), 'void')

  emitter.EmitDeclare('const bool', 'lhs_aligned',
                      '((reinterpret_cast<std::uintptr_t>(lhs) % 8) == 0)')
  emitter.EmitDeclare('const bool', 'rhs_aligned',
                      '((reinterpret_cast<std::uintptr_t>(rhs) % 8) == 0)')
  emitter.EmitDeclare('const bool', 'k_aligned', '((k % 8) == 0)')

  # Compare by value: with `is`, an equal but distinct string would silently
  # take the branch that omits the result alignment check.
  if output_type == _QUANTIZED_8BIT:
    emitter.EmitDeclare('const bool', 'result_aligned',
                        '((reinterpret_cast<std::uintptr_t>(result) % 8) == 0)')
    emitter.EmitDeclare('const bool', 'aligned',
                        'lhs_aligned && rhs_aligned && result_aligned '
                        '&& k_aligned')
  else:
    emitter.EmitDeclare('const bool', 'aligned',
                        'lhs_aligned && rhs_aligned && k_aligned')

  emitter.EmitIf('aligned')
  GenerateGemvSwitch1(emitter, output_type, True)
  emitter.EmitElse()
  GenerateGemvSwitch1(emitter, output_type, False)
  emitter.EmitEndif()
  emitter.EmitFunctionEnd()
270
271
def GenerateInternalFunctions(emitter):
  """Emits every specialized gemv function meant for the internal namespace.

  One function is generated per combination of output type, alignment
  (aligned first), column leftover 0..7 and depth leftover 0..7, each followed
  by a newline.
  """
  output_types = (_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT)
  variants = ((kind, is_aligned, column_leftover, depth_leftover)
              for kind in output_types
              for is_aligned in (True, False)
              for column_leftover in range(8)
              for depth_leftover in range(8))
  for variant in variants:
    GenerateGemv(emitter, *variant)
    emitter.EmitNewline()
280
281
def GeneratePublicFunctions(emitter):
  """Emits the main gemv function for each output type, newline-separated."""
  supported_output_types = (_QUANTIZED_8BIT, _FULL_32BIT, _FULL_FLOAT)
  for kind in supported_output_types:
    GenerateMainGemvFunction(emitter, kind)
    emitter.EmitNewline()
286