1"""Zip primitive used by the GEMM function.
2
3Takes 1 to 3 rows of data and interleaves them in 8 byte chunks. Pads to
4multiply of 8 length with zeros. Calculates row sums and appends those at the
5end.
6"""
7
8
9import neon_emitter
10
11
class Error(Exception):
  """Base class for all errors raised by this module."""
14
15
class ConfigurationError(Error):
  """Raised when asked to generate an unsupported zip configuration."""
18
19
class ZipLane(object):
  """Per-row state used by the zip loops: address, load and sum registers."""

  def __init__(self, input_address, load, aggregator):
    # Register holding this row's current read address.
    self.input_address = input_address
    # Register the row's 8-byte chunks are loaded into.
    self.load = load
    # Register accumulating this row's running byte sum.
    self.aggregator = aggregator
26
27
def GenerateZipLanes(emitter, registers, zip_lanes, input_address, stride):
  """Prepares read lanes for the zip operation.

  Args:
    emitter: ARM/NEON emitter.
    registers: ARM/NEON registers state.
    zip_lanes: number of lanes to prepare.
    input_address: register that contains the input address for the first lane.
    stride: memory stride for lane inputs.

  Returns:
    Array of ZipLane objects.
  """
  lanes = []
  previous_address = input_address
  for index in range(zip_lanes):
    # Lane 0 reads straight from the supplied address; every later lane
    # gets a fresh register initialized to previous address + stride.
    address = input_address if index == 0 else registers.GeneralRegister()
    lanes.append(ZipLane(address,
                         registers.DoubleRegister(),
                         registers.QuadRegister(2)))
    if index:
      emitter.EmitAdd(address, previous_address, stride)
    previous_address = address
  return lanes
56
57
def BuildName(zip_lanes, leftovers, aligned):
  """Compose the emitted zip function's name from its parameters."""
  parts = ['zip_%dx8' % zip_lanes]
  if leftovers:
    parts.append('_%d' % leftovers)
  if aligned:
    parts.append('_aligned')
  return ''.join(parts)
65
66
def GenerateClearAggregators(emitter, lanes):
  """Zero every lane's sum aggregator register."""
  for zip_lane in lanes:
    emitter.EmitVMov('i16', zip_lane.aggregator, emitter.ImmediateConstant(0))
70
71
def GenerateLoadAggregateStore(emitter, lanes, output_address, alignment):
  """Emit inner loop code for reading N lanes and interweaving them."""
  emitter.EmitNewline()
  emitter.EmitComment('Load Aggregate Store.')

  # One 8-byte chunk per lane, advancing each lane's read address.
  for zip_lane in lanes:
    emitter.EmitVLoad(
        '1.8', zip_lane.load,
        emitter.DereferenceIncrement(zip_lane.input_address, alignment))

  # Widen-accumulate every freshly loaded chunk into its lane's sum.
  for zip_lane in lanes:
    emitter.EmitVAddw('u8', zip_lane.aggregator, zip_lane.aggregator,
                      zip_lane.load)

  # One multi-register store writes the interleaved chunks out.
  emitter.EmitVStoreA('1.8', [zip_lane.load for zip_lane in lanes],
                      emitter.DereferenceIncrement(output_address, 64))
89
90
def GenerateLeftoverLoadAggregateStore(
    emitter, leftovers, lanes, output_address):
  """Handle leftovers when count is not a multiple of 8.

  Emits zero-padded loads of the 1-7 remaining bytes per lane, accumulates
  them into the lane aggregators and stores the padded 8-byte rows.

  Args:
    emitter: ARM/NEON emitter.
    leftovers: number of remaining bytes per lane, 1 to 7 inclusive.
    lanes: array of ZipLane objects to read from.
    output_address: register holding the destination address.

  Raises:
    ConfigurationError: leftovers is outside the 1-7 range.
  """
  # Fail fast before emitting anything.
  if leftovers < 1 or leftovers > 7:
    raise ConfigurationError('Unsupported leftover num: %d' % leftovers)

  emitter.EmitNewline()
  emitter.EmitComment('Leftover Load Aggregate Store.')

  # Clear load registers so the unread tail bytes stay zero (the padding).
  for lane in lanes:
    emitter.EmitVMov('i8', lane.load, emitter.ImmediateConstant(0))

  # A leftover count of 1-7 decomposes into at most one 4-, 2- and 1-byte
  # load (its binary representation), emitted widest-first so each chunk
  # lands at its natural byte offset.  Within a '1.N' view the destination
  # lane index is the byte offset divided by the chunk size.  Only the last
  # chunk skips the address post-increment, matching the main-loop pattern.
  chunks = [size for size in (4, 2, 1) if leftovers & size]
  offset = 0
  for i, size in enumerate(chunks):
    is_last_chunk = (i == len(chunks) - 1)
    mode = '1.%d' % (size * 8)
    lane_index = offset // size
    for lane in lanes:
      if is_last_chunk:
        address = emitter.Dereference(lane.input_address, None)
      else:
        address = emitter.DereferenceIncrement(lane.input_address, None)
      emitter.EmitVLoad(mode, emitter.Lane(lane.load, lane_index), address)
    offset += size

  # Aggregate.
  store_registers = []
  for lane in lanes:
    emitter.EmitVAddw('u8', lane.aggregator, lane.aggregator, lane.load)
    store_registers.append(lane.load)

  # Store.
  emitter.EmitVStoreA('1.8', store_registers,
                      emitter.DereferenceIncrement(output_address, 64))
173
174
def GenerateAggregatorReduction(emitter,
                                registers,
                                lanes,
                                output_address,
                                multiplicative_offset,
                                additive_offset):
  """Reduce 4 lane sum aggregators to 1 value and store the sums."""
  emitter.EmitNewline()
  emitter.EmitComment('Aggregator Reduction.')

  # Stage the scalar multiplier in one lane and broadcast the offset.
  multiplier = registers.DoubleRegister()
  emitter.EmitVMov('32', emitter.Lane(multiplier, 0), multiplicative_offset)
  offset = registers.QuadRegister()
  emitter.EmitVDup('32', offset, additive_offset)

  # Pairwise-widen each aggregator's 16-bit sums to 32-bit partial sums.
  for zip_lane in lanes:
    emitter.EmitVPaddl('u16', zip_lane.aggregator, zip_lane.aggregator)

  # Fold each aggregator's two halves down to one double register.
  lane_temps = []
  for zip_lane in lanes:
    folded = registers.DoubleRegister()
    emitter.EmitVPadd('u32',
                      folded,
                      registers.Low(zip_lane.aggregator),
                      registers.High(zip_lane.aggregator))
    lane_temps.append(folded)

  temp = registers.QuadRegister()
  low = registers.Low(temp)
  high = registers.High(temp)

  if not 1 <= len(lanes) <= 4:
    raise ConfigurationError(
        'Unexpected number of aggregators to reduce: %d' % len(lanes))

  # Pair the per-lane sums up (an odd count pairs the last with itself)
  # and reduce each pair into one 32-bit slot of temp, low half first.
  paired = list(lane_temps)
  if len(paired) % 2:
    paired.append(paired[-1])
  halves = [low, high]
  for index in range(0, len(paired), 2):
    emitter.EmitVPadd('u32', halves[index // 2], paired[index],
                      paired[index + 1])

  # Apply sums * multiplicative_offset + additive_offset in one pass.
  emitter.EmitVMul('i32', temp, temp, emitter.Lane(multiplier, 0))
  emitter.EmitVAdd('i32', temp, temp, offset)

  # Store exactly len(lanes) 32-bit sums; odd counts need a trailing
  # single-lane store, and single-word stores cannot claim 64-bit alignment.
  if len(lanes) == 1:
    emitter.EmitVStore(
        '1.32', emitter.Lane(low, 0), emitter.Dereference(output_address, None))
  elif len(lanes) == 2:
    emitter.EmitVStore('1.32', low, emitter.Dereference(output_address, 64))
  elif len(lanes) == 3:
    emitter.EmitVStore(
        '1.32', low, emitter.DereferenceIncrement(output_address, 64))
    emitter.EmitVStore(
        '1.32', emitter.Lane(high, 0),
        emitter.Dereference(output_address, None))
  else:
    emitter.EmitVStore(
        '1.32', low, emitter.DereferenceIncrement(output_address, 64))
    emitter.EmitVStore('1.32', high, emitter.Dereference(output_address, 64))
238
239
def GenerateZipNx8(emitter, zip_lanes, leftovers, aligned):
  """Emit the zip function for a given number of rows and row size leftovers.

  Args:
    emitter: ARM/NEON emitter.
    zip_lanes: number of rows to interleave, 1 to 3 inclusive.
    leftovers: count % 8, i.e. bytes per row beyond the last full 8-byte
      chunk, 0 to 7 inclusive.
    aligned: when True, emit 64-bit-aligned loads and assert alignment of
      the source (and the stride when zipping more than one row).

  Raises:
    ConfigurationError: leftovers or zip_lanes is out of range.
  """
  if leftovers < 0 or leftovers > 7:
    raise ConfigurationError('Leftovers should be between 0 and 7 inclusive.')
  if zip_lanes < 1 or zip_lanes > 3:
    # Fixed duplicated word in the error message ("should should").
    raise ConfigurationError('Zip_lanes should be 1, 2 or 3.')

  name = BuildName(zip_lanes, leftovers, aligned)

  emitter.EmitFunctionBeginA(name,
                             [['const std::uint8_t*', 'source'],
                              ['std::int32_t', 'count'],
                              ['std::int32_t', 'stride'],
                              ['std::uint8_t*', 'destination'],
                              ['std::int32_t', 'multiplicative_offset'],
                              ['std::int32_t', 'additive_offset']],
                             'void')
  emitter.EmitAssert('count %% 8 == %d' % leftovers)
  emitter.EmitAssert('count <= 2048')
  emitter.EmitAssert('count >= 8')
  emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(destination) % 8 == 0')
  if aligned:
    emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(source) % 8 == 0')
    if zip_lanes > 1:
      emitter.EmitAssert('stride % 8 == 0')
  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()

  count = registers.MapParameter('count')
  output_address = registers.MapParameter('destination')

  lanes = GenerateZipLanes(emitter,
                           registers,
                           zip_lanes,
                           registers.MapParameter('source'),
                           registers.MapParameter('stride'))

  # Only full 8-byte chunks go through the main loop; set the leftover
  # bytes aside for the dedicated epilogue below.
  if leftovers:
    emitter.EmitSub(count, count, emitter.ImmediateConstant(leftovers))

  GenerateClearAggregators(emitter, lanes)

  # Main loop: one 8-byte chunk per lane per iteration, looping while the
  # decremented count is non-zero.
  emitter.EmitNewline()
  emitter.EmitNumericalLabel(1)
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))

  GenerateLoadAggregateStore(
      emitter, lanes, output_address, 64 if aligned else None)

  emitter.EmitNewline()
  emitter.EmitBneBack(1)

  if leftovers:
    GenerateLeftoverLoadAggregateStore(
        emitter, leftovers, lanes, output_address)

  GenerateAggregatorReduction(emitter,
                              registers,
                              lanes,
                              output_address,
                              registers.MapParameter('multiplicative_offset'),
                              registers.MapParameter('additive_offset'))

  emitter.EmitAsmEnd(registers.MappedParameters(),
                     [],
                     registers.Clobbers() + ['cc', 'memory'])
  emitter.EmitFunctionEnd()
308
309
def GenerateFunctions(emitter):
  """Emit every zip variant: 1-3 lanes x 0-7 leftovers x aligned/unaligned."""
  for use_aligned in (True, False):
    for lane_count in (1, 2, 3):
      for leftover in range(8):
        GenerateZipNx8(emitter, lane_count, leftover, use_aligned)
        emitter.EmitNewline()
316