1#!/usr/bin/python3
2
3# Copyright 2017, The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""NN model compiler
18
19Compile models and examples into NDK-based CTS unit tests
20"""
21
22from __future__ import absolute_import
23from __future__ import division
24from __future__ import print_function
25import argparse
26from functools import reduce
27import math
28import os
29import struct
30import sys
31import contextlib
32import pprint
33
@contextlib.contextmanager
def smart_open(filename=None):
  """Context manager yielding a writable handle for *filename*.

  A missing/empty filename or "-" yields sys.stdout, which is never closed;
  any other name is opened for writing and closed on exit.
  """
  use_stdout = not filename or filename == '-'
  fh = sys.stdout if use_stdout else open(filename, 'w')
  try:
    yield fh
  finally:
    if fh is not sys.stdout:
      fh.close()
46
class Phase(object):
  """Collects objects together with their C++ declaration strings.

  Also maintains an id -> object index for later lookup.
  """
  def __init__(self):
    self.__members = []
    self.__declarations = []
    self.__by_id = {}

  def append(self, obj, x):
    """Register *obj* and its declaration string *x*."""
    self.__members.append(obj)
    self.__declarations.append(x)
    self.__by_id[obj.ID()] = obj

  def dump(self, filename):
    """Print every registered declaration, one indented statement per line."""
    for declaration in self.__declarations:
      print("  " + declaration + ";", file=filename)

  def objects(self):
    """Return registered objects in registration order."""
    return self.__members

  def search(self, i):
    """Return the registered object with ID *i*."""
    return self.__by_id[i]
67
# Tracks objects inside a model with a not-necessarily-unique name and
# a unique number
class NamedObject(object):
  """Base for model-tracked objects: a name plus a unique serial number.

  The serial number doubles as the object's ID and hash.
  """
  # Global monotonically increasing serial counter shared by all instances.
  __serial = 0

  def __init__(self, name = "NamedObject"):
    self.__name = name
    self.__id = NamedObject.serial()
    NamedObject.__serial += 1

  def ID(self):
    """Return this object's unique serial number."""
    return self.__id

  @staticmethod
  def serial():
    """Return the next serial number that will be assigned."""
    # Was an implicit static; @staticmethod makes instance calls work too.
    return NamedObject.__serial

  def get_name(self):
    """Return the (possibly non-unique) name."""
    return self.__name

  def __str__(self):
    return self.get_name()

  def __hash__(self):
    return self.__id
92
93# Object that can be traversed during topological sorting phase
class Traversable(object):
  """Mixin marking nodes that the topological sort may pass through."""
  def traversable(self):
    return True
97
class Nontraversable(object):
  """Mixin marking nodes that the topological sort must not pass through."""
  def traversable(self):
    return False
101
102# Object that can take input from other objects
class Uses(object):
  """Mixin for nodes that consume values from other nodes.

  Wires this node into each producer's `outs` list and records every
  instance in `all_uses` (consumed later by TopologicalSort).
  """
  all_uses = set()

  def __init__(self, ins = None):
    # None sentinel instead of a mutable default argument.
    ins = [] if ins is None else ins
    self.ins = list(ins)
    Uses.all_uses.add(self)
    for i in ins:
      i.outs.append(self)
110
111# Object that other objects takes its definition from
class Definitions(object):
  """Mixin for nodes whose values other nodes take their definitions from.

  Wires this node into each consumer's `ins` list.
  """
  def __init__(self, outs = None):
    # None sentinel instead of a mutable default argument.
    outs = [] if outs is None else outs
    self.outs = list(outs)
    for o in outs:
      o.ins.append(self)
117
class TypeLookup:
  """Maps NNAPI operand type names to C++ element types and sizes."""
  __type_lookup = {
      "INT32": "int32_t",
      "UINT32": "uint32_t",
      "FLOAT32": "float",
      "TENSOR_INT32": "int32_t",
      "TENSOR_FLOAT32": "float",
      "TENSOR_QUANT8_ASYMM": "uint8_t",
#     "OEM_SCALAR": this is service-defined.
      "TENSOR_OEM_BYTE": "uint8_t",
    }

  @staticmethod
  def get_cpptype(nnapi_type):
    """Return the C++ element type for an NNAPI type name."""
    return TypeLookup.__type_lookup[nnapi_type]

  @staticmethod
  def is_float(nnapi_type):
    """Whether the NNAPI type has float elements."""
    return TypeLookup.get_cpptype(nnapi_type) == "float"

  @staticmethod
  def get_size(nnapi_type):
    """Element size in bytes: 1 for byte types, 4 for everything else."""
    return 1 if TypeLookup.get_cpptype(nnapi_type) == "uint8_t" else 4
138
139
class Type(object):
  """An operand type: an NNAPI element type plus a textual shape string.

  Types are interned: identical (element type, shape) pairs share one id
  and therefore one generated C++ OperandType variable ("typeN").
  """
  # Registry of interned types keyed by str(type); shared across the run.
  __types =  {}
  __type_serial = 0 # types have their own numbering
  def __init__(self, vt = None, shape = None):
    # vt: NNAPI type name, e.g. "TENSOR_FLOAT32".
    # shape: textual shape such as "{2, 3}", possibly carrying quantization
    # parameters around the braces -- see get_parsed_shape().
    self.__vt = vt
    self.__shape = shape
    if vt is None or shape is None:
      # Incomplete placeholder type: unnamed and not interned.
      self.__name = None
      return

    key = str(self)
    if key not in Type.__types:
      self.__id = Type.__type_serial
      Type.__types[str(self)] = self
      Type.__type_serial += 1
    else:
      # Reuse the id of the identical, already-interned type.
      self.__id = Type.__types[key].__id
    self.__name = "type" + str(self.__id)

  def get_shape(self):
    """Return the raw shape string exactly as given in the spec."""
    return self.__shape

  def get_element_type(self):
    """Return the NNAPI element type name (e.g. "TENSOR_FLOAT32")."""
    return self.__vt

  def get_name(self):
    """Return the generated C++ variable name ("typeN"), or None."""
    return self.__name

  def __str__(self):
    # Canonical interning key: "<element type>, <shape>".
    return (", ".join([self.__vt, self.__shape]))

  def __hash__(self):
    return self.__id

  # NOTE: implicit static method -- called as Type.dump(file), no self.
  def dump(filename):
    """Print one C++ OperandType declaration per interned type."""
    for key, value in sorted(Type.__types.items()):
      print ("  OperandType " + str(value.__name) + "(Type::" + str(key) + ");", file=filename)

  def get_raw_shape(self):
    """Alias of get_shape(): the unparsed shape string."""
    return self.__shape

  def get_parsed_shape(self):
    """Split the shape string into (dimensions, scale, zero_point) strings.

    Returns the comma-separated dimensions between the braces, plus the
    quantization scale (with any 'f' suffix stripped) and zero point found
    after the closing brace; missing values default to "0".
    """
    # Parse shape
    if (self.__shape != "" and self.__shape != "{}"):
      left, sep, right = self.__shape.partition('{')
      real_shape, sep, right = right.partition('}')
      # Parsing here also validates that every dimension is an integer.
      shape = [int(x) for x in real_shape.split(",")]
      # left now looks like "0.0f, 127.5f, "
      scale, sep, zero_point = right.rpartition(',')
      if scale == "":
        if zero_point == "":
          # No quantization parameters at all.
          return real_shape, "0", "0"
        # NOTE(review): a single trailing value lands in `zero_point` but is
        # returned in the scale slot -- confirm this is the intended format.
        return real_shape, zero_point, "0"
      left, sep, scale = scale.partition(',')
      return real_shape, scale.replace("f", ""), zero_point
    else:
      return "", "0", "0"

  def get_nr_elements(self):
    """Return the element count: product of dimensions, or 1 for scalars."""
    # Parse shape
    nr_elements = 1
    real_shape, scale, zero_point = self.get_parsed_shape()

    if (real_shape != "" and real_shape != "{}"):
      shape = [int(x) for x in real_shape.split(",")]
      nr_elements = reduce((lambda x, y: x*y), shape)
    return nr_elements

  def get_size(self):
    """Return the total size in bytes of one value of this type."""
    element_size = TypeLookup.get_size(self.__vt)
    return self.get_nr_elements() * element_size
211
212# A value is a typed, named object
class Value(NamedObject):
  """A typed, named object; base class for operands."""
  def __init__(self, name, vt):
    # vt is a Type instance, kept for later element-type/size queries.
    NamedObject.__init__(self, name)
    self.type = vt
217
218# An operand that can be fed into operations. Also, an operand is always
219# declared before operations.
class Operand(Value):
  """An operand that can be fed into operations.

  Operands are always declared (Phase 1) before operations (Phase 2);
  the addOperand() declaration string is registered at construction time.
  """
  # All operand declarations, in declaration order.
  operands = Phase()

  def __init__(self, name, vt):
    Value.__init__(self, name, vt)
    def_string = (
        "auto " + self.get_name() + " = "
        "model->addOperand(&" + vt.get_name() + ")")
    Operand.operands.append(self, def_string)

  def Definition(self):
    """By default, produce nothing (when asked by the topological sort)."""
    pass

  def Reference(self):
    """Return the C++ identifier used to reference this operand."""
    return NamedObject.__str__(self)

  @staticmethod
  def print_operands(operands):
    """Return the list of C++ references for a collection of operands."""
    return [ x.Reference() for x in operands ]

  def is_weight(self):
    """Whether the operand's data is defined with the model (a constant)."""
    return False
245
246# A user-declared input operand
class Input(Operand, Definitions, Traversable):
  """A user-declared model input operand."""
  # for enumerating inputs
  __next_number = 0
  # Holds references to all Inputs; used by topological sort as starting nodes.
  __inputs = set()

  def __init__(self, name, vt, shape, increase_next_number=True):
    # increase_next_number=False lets subclasses (e.g. Parameter) register
    # here without consuming an input slot number.
    Operand.__init__(self, name, Type(vt, shape))
    Definitions.__init__(self)
    Input.__inputs.add(self)
    self.number = Input.__next_number
    if increase_next_number:
      Input.__next_number += 1

  def lifetime(self):
    return "MODEL_INPUT"

  def is_internal(self):
    """Whether this is a model-internal constant rather than a real input."""
    return False

  @staticmethod
  def get_inputs(exclude_internal = None):
    """Return all inputs; pass any non-None value to omit internal ones."""
    if exclude_internal is not None:
      return { x for x in Input.__inputs if not x.is_internal() }
    return Input.__inputs
273
274# A user-declared output operand
class Output(Operand, Uses, Nontraversable):
  """A user-declared model output operand."""
  # for enumerating outputs
  __next_number = 0
  # All outputs in declaration order; may contain duplicates.
  __outputs = []

  def __init__(self, name, vt, shape):
    Operand.__init__(self, name, Type(vt, shape))
    Uses.__init__(self)
    Output.__outputs.append(self)
    self.number = Output.__next_number
    Output.__next_number += 1

  def lifetime(self):
    return "MODEL_OUTPUT"

  @staticmethod
  def get_outputs():
    """Return all unique outputs, preserving first-seen order."""
    # dict preserves insertion order (Python 3.7+): O(n) ordered de-dup.
    return list(dict.fromkeys(Output.__outputs))
295
296# An output that we don't want to compare the results
class IgnoredOutput(Output):
  """An output whose values are not compared against the reference."""
  # All ignored outputs; drives the generated is_ignored() helper.
  __ignored = set()

  def __init__(self, name, vt, shape):
    Output.__init__(self, name, vt, shape)
    IgnoredOutput.__ignored.add(self)

  @staticmethod
  def gen_ignored():
    """Return the C++ is_ignored() function listing ignored output indices."""
    ignored_func = """
bool is_ignored(int i) {
  static std::set<int> ignore = {%s};
  return ignore.find(i) != ignore.end();
}""" % ", ".join([str(x.number) for x in IgnoredOutput.__ignored])
    return ignored_func
309
class ModelArgument:
  """An extra (type, name) parameter appended to CreateModel()'s signature."""
  # All argument declarations as "type name" strings, in creation order.
  __arguments = []

  def __init__(self, arg_type, arg_name):
    self.__arg_type = arg_type
    self.__arg_name = arg_name
    ModelArgument.__arguments.append(" ".join([arg_type, arg_name]))

  def get_arg_type(self):
    """Return the C++ type of this argument."""
    return self.__arg_type

  def get_arg_name(self):
    """Return the name of this argument."""
    return self.__arg_name

  @staticmethod
  def get_arguments():
    """Return all "type name" declaration strings."""
    return ModelArgument.__arguments

  def lifetime(self):
    """Operand lifetime for arguments (constants copied into the model)."""
    return "CONSTANT_COPY"
329
330# Print in C float literal format
def pretty_print_as_float(x):
  """Render *x* as a C float literal (always carries an 'f' suffix)."""
  text = str(float(x))
  # repr already has a decimal point or exponent for normal floats;
  # otherwise add ".0" so the literal stays a float in C.
  return text + "f" if ("." in text or "e" in text) else text + ".0f"
337
class Parameter(Input):
  """A constant operand whose data is baked into the generated model.

  TODO seems wrong that's an Input.
  """
  def __init__(self, name, vt, shape, initializer):
    Input.__init__(self, name, vt, shape, False)
    self.initializer = initializer
    self.cpptype = TypeLookup.get_cpptype(vt)

  def is_internal(self):
    return True

  def Definition(self):
    """Emit the static C array plus the setOperandValue() call."""
    init_name = self.get_name() + "_init"
    values = [str(v) for v in self.initializer]
    if self.cpptype == "float":
      values = [pretty_print_as_float(v) for v in values]
    decl = "static {} {}[] = {{{}}};".format(
        self.cpptype, init_name, ", ".join(values))
    size_expr = "sizeof({}) * {}".format(self.cpptype, len(self.initializer))
    call = "model->setOperandValue({});".format(
        ", ".join([self.get_name(), init_name, size_expr]))
    return decl + "\n  " + call

  def is_weight(self):
    return True

  def lifetime(self):
    return "CONSTANT_REFERENCE" if Configuration.useSHM() else "CONSTANT_COPY"
365
class Int32Scalar(Parameter):
  """Convenience wrapper: a scalar INT32 constant operand."""
  def __init__(self, name, value):
    super().__init__(name, "INT32", "{}", [value])
369
class Float32Scalar(Parameter):
  """Convenience wrapper: a scalar FLOAT32 constant operand."""
  def __init__(self, name, value):
    super().__init__(name, "FLOAT32", "{}", [value])
373
374# A compiler-generated intermediate result from an operation
class IntermediateResult(Operand, Definitions, Uses, Traversable):
  """A compiler-generated temporary holding one operation's result."""
  def __init__(self, src: Value):
    # Name temporaries after the global serial so they never collide.
    Operand.__init__(self, "tmp%d" % NamedObject.serial(), src.type)
    Definitions.__init__(self)
    Uses.__init__(self, [src])

  def lifetime(self):
    return "TEMPORARY_VARIABLE"
384
385# An explicitly declared intermediate result
class Internal(Operand, Definitions, Uses, Traversable):
  """An explicitly declared intermediate (model-internal) operand."""
  def __init__(self, name, vt, shape):
    operand_type = Type(vt, shape)
    Operand.__init__(self, name, operand_type)
    Definitions.__init__(self)
    Uses.__init__(self)

  def lifetime(self):
    return "TEMPORARY_VARIABLE"
394
395# An operation in a model
class Operation(Definitions, Uses, Traversable):
  """An operation in a model: consumes `ins`, defines `outs`."""
  def __init__(self, optype, ins, outs):
    # The operation's type mirrors its first input's type.
    self.type = ins[0].type
    Definitions.__init__(self, outs)
    Uses.__init__(self, ins)
    self.optype = optype

  def __str__(self):
    return "Operation:%s %s" % (
        self.optype, ", ".join(str(i) for i in self.ins))

  def Reference(self):
    return "operation%d" % self.ID()

  def Definition(self):
    """Emit the C++ addOperation() statement for this op."""
    ins = ", ".join(Operand.print_operands(self.ins))
    outs = ", ".join(Operand.print_operands(self.outs))
    return ("model->addOperation(ANEURALNETWORKS_%s, {%s}, {%s});"
            % (self.optype, ins, outs))

  def PyDefinition(self):
    """Get a Python-ish dump for the op (used by the slicing tool)."""
    ins = ", ".join(str(x) for x in Operand.print_operands(self.ins))
    assert len(self.outs) <= 1
    out = str(Operand.print_operands(self.outs)[0])
    return 'Operation("%s", %s).To(%s)' % (self.optype, ins, out)
425
426# Main interface
class Model(object):
  """Fluent builder for a test model.

  Each op method creates an Operation with no outputs yet, remembers it
  as the "current" op, and returns self so calls can be chained; outputs
  are attached later via Out()/To().
  """
  # Whether generated models request relaxed fp32->fp16 execution.
  __isRelaxed = False

  def __init__(self):
    # The operation whose outputs have not been attached yet.
    self.__currentOp = None

  def __chain(self, optype, ins):
    # Shared helper: build the op, make it current, return self.
    op = Operation(optype, ins, [])
    self.__currentOp = op
    return self

  # TODO turn this into generic binary operations
  def Add(self, i1: "Value", i2 = None) -> "Model":
    """Chainable ADD; feeds any previous op through an IntermediateResult."""
    ins = [i1]
    if i2 is not None:
      ins.append(i2)
    if self.__currentOp is not None:
      ir = IntermediateResult(self.__currentOp)
      self.__currentOp = ir
      ins.append(self.__currentOp)
    return self.__chain("ADD", ins)

  def Operation(self, op_name, *args):
    """Generic operation with arbitrary inputs."""
    return self.__chain(op_name, list(args))

  def RawAdd(self, i1: "Value", i2: "Value", o = None) -> "Model":
    """ADD without Add()'s implicit chaining; optional explicit output."""
    op = Operation("ADD", [i1, i2], [] if o is None else [o])
    self.__currentOp = op
    return self

  # See CpuExecutor::executeOperation() for the arguments of each op
  def AveragePool(self, input, padding, stride_width, stride_height,
                  filter_width, filter_height, activation):
    return self.__chain("AVERAGE_POOL_2D",
                        [input, padding, stride_width, stride_height,
                         filter_width, filter_height, activation])

  def Concatenation(self, *args):
    return self.__chain("CONCATENATION", list(args))

  def Conv(self, filter, bias, input, padding, stride_width, stride_height,
           activation):
    return self.__chain("CONV_2D",
                        [filter, bias, input, padding, stride_width,
                         stride_height, activation])

  def DepthWiseConv(self, filter, bias, input, padding, stride_width,
                    stride_height, depth_multiplier, activation):
    return self.__chain("DEPTHWISE_CONV_2D",
                        [filter, bias, input, padding, stride_width,
                         stride_height, depth_multiplier, activation])

  def FullyConnected(self, input, weights, bias, activation):
    return self.__chain("FULLY_CONNECTED",
                        [input, weights, bias, activation])

  def Logistic(self, input):
    return self.__chain("LOGISTIC", [input])

  def L2Pool(self, input, padding, stride_width, stride_height,
             filter_width, filter_height, activation):
    return self.__chain("L2_POOL_2D",
                        [input, padding, stride_width, stride_height,
                         filter_width, filter_height, activation])

  def MaxPool(self, input, padding, stride_width, stride_height,
              filter_width, filter_height, activation):
    return self.__chain("MAX_POOL_2D",
                        [input, padding, stride_width, stride_height,
                         filter_width, filter_height, activation])

  def SoftMax(self, input, beta):
    return self.__chain("SOFTMAX", [input, beta])

  def Reshape(self, input, shape):
    return self.__chain("RESHAPE", [input, shape])

  def Out(self, o):
    """Attach one output operand (or a list/tuple of them) to the current op."""
    outputs = o if isinstance(o, (list, tuple)) else [o]
    for out in outputs:
      self.__currentOp.outs.append(out)
      out.ins.append(self.__currentOp)
    return self

  def To(self, o: "Value"):
    """Attach outputs and close out the current operation."""
    Model.Out(self, o)
    self.__currentOp = None
    return self

  def RelaxedExecution(self, isRelaxed):
    """Request (or clear) relaxed fp16 computation for generated models."""
    Model.__isRelaxed = isRelaxed
    return self

  @staticmethod
  def isRelaxed():
    """Whether relaxed execution was requested."""
    return Model.__isRelaxed
562
563
class FileNames:
  """Holds file-name state shared between model and example generation."""
  # Basename of the input spec file; set by import_source() and shown in
  # the "Generated file (from: ...)" banners.
  SpecFile = ""
566
class Example():
  """Collects (input, output) example pairs from spec files and renders
  them as C++ (for CTS tests) or Python (for the slicing tool)."""
  # All registered examples, each an (input_dict, output_dict) tuple.
  __examples = []

  def __init__(self, list_of_examples):
    Example.__examples.append(list_of_examples)

  @staticmethod
  def dump_dict(d):
    """Render {operand-or-int: values} as C++ "{key, {v0, ...}}" entries.

    Values containing a '.' get an 'f' suffix when the operand (or a plain
    int key, assumed float) holds float data.
    """
    ret = []
    for k, v in d.items():
      key = str(k)
      suffix = "f"
      if type(k) is not int:
        key = str(k.number)
        if not TypeLookup.is_float(k.type.get_element_type()):
          suffix = ""
      init = ", ".join(
          [str(i) + (suffix if str(i).find(".") != -1 else "") for i in v])
      ret.append("{%s, {%s}}" % (key, init))
    return ", ".join(ret)

  @staticmethod
  def dump_mixed_types(d):
    """Bucket example data by element type into the MixedTyped aggregate."""
    float32_dict = {}
    int32_dict = {}
    uint8_dict = {}

    for k, v in d.items():
      # find out type of the operand addressed by the key
      key_id = k.ID() if type(k) is not int else k
      ty = Operand.operands.search(key_id).type.get_element_type()
      if ty == "TENSOR_FLOAT32":
        float32_dict[k] = v
      elif ty == "TENSOR_INT32":
        int32_dict[k] = v
      elif ty in ("TENSOR_OEM_BYTE", "TENSOR_QUANT8_ASYMM"):
        uint8_dict[k] = v
      else:
        print ("Unhandled type %s"%ty,  file = sys.stderr)
        assert 0 and "unsupported example type"

    tuple_init = """\
{{ // See tools/test_generator/include/TestHarness.h:MixedTyped
  // int -> FLOAT32 map
  {{{float32_dict}}},
  // int -> INT32 map
  {{{int32_dict}}},
  // int -> QUANT8_ASYMM map
  {{{uint8_dict}}}
}}"""
    tuple_contents = {
        'float32_dict': Example.dump_dict(float32_dict),
        'int32_dict': Example.dump_dict(int32_dict),
        'uint8_dict': Example.dump_dict(uint8_dict)
    }
    return tuple_init.format(**tuple_contents)

  @staticmethod
  def dump(example_file):
    """Write every collected example as a C++ initializer list."""
    if len(Example.__examples) > 0:
      spec_file = " (from: %s)" % (FileNames.SpecFile)
      print ('// Generated file%s. Do not edit' % (spec_file),
             file = example_file)
    for i, o in Example.__examples:
      print ('// Begin of an example', file = example_file)
      print ('{', file = example_file)
      inputs = Example.dump_mixed_types(i)
      outputs = Example.dump_mixed_types(o)
      print ('//Input(s)\n%s,' % inputs , file = example_file)
      print ('//Output(s)\n%s' % outputs, file = example_file)
      print ('}, // End of an example', file = example_file)

  @staticmethod
  def py_dump_dict(d, referenced):
    """Like dump_dict, but Python syntax; used by the slicing tool.

    If referenced is not None, only operands present in it are printed.
    """
    ret = []
    for k, v in d.items():
      if referenced is not None and k not in referenced:
        continue
      ret.append("%s: %s" % (str(k), pprint.pformat(v)))
    return ", ".join(ret)

  @staticmethod
  def py_dump(example_file, override, referenced):
    """Like dump, but Python syntax; used by the slicing tool.

    override maps output keys to element counts (outputs are zero-filled);
    if referenced is not None, only operands present in it are printed.
    """
    if len(Example.__examples) > 0:
      example_no = 0
      example_template = """\
input{no} = {{{inputs}}}
# Only executed during data collection phase
if collecting_data is True:
  Example((input{no}, {{{outputs}}}))
"""
      for i, o in Example.__examples:
        print ('# Begin of an example', file = example_file)
        inputs = Example.py_dump_dict(i, referenced)
        outputs = ",".join(
            "%s: [0] * %d" % (k, v) for k, v in override.items())

        # TODO: handle >1 outputs
        for k, v in o.items():
          assert k.number == 0
        example_contents = {
            'no': example_no,
            'inputs': inputs,
            'outputs': outputs
        }
        print (example_template.format(**example_contents), file = example_file)
        # Bug fix: number each example so generated variable names are unique.
        example_no += 1
680
681
def TopologicalSort(format_op):
  """Visit model nodes in dependency order (Kahn's algorithm).

  Starts from the model inputs and calls format_op(node) on each ready
  node; stops early if format_op returns False.
  """
  ready = Input.get_inputs().copy()
  pending = {node: set(node.ins) for node in Uses.all_uses}

  while ready:
    node = ready.pop()
    if format_op(node) is False:
      return
    for successor in set(node.outs):
      pending[successor].remove(node)
      if not pending[successor] and successor.traversable():
        ready.add(successor)
695
class Configuration:
  """Global code-generation switches."""
  # When True, constant weights are referenced from shared memory
  # instead of copied into the model.
  use_shm_for_weights = False

  @staticmethod
  def useSHM():
    """Whether constant weights should use shared-memory references."""
    return Configuration.use_shm_for_weights
700
701# Take a model from command line
def import_source():
  """Parse the command line and execute the spec file.

  The spec registers its model/examples via side effects on the module's
  classes. Returns (model_path, example_path); "-" means stdout.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument("spec", help="the spec file")
  parser.add_argument(
      "-m", "--model", help="the output model file", default="-")
  parser.add_argument(
      "-e", "--example", help="the output example file", default="-")
  args = parser.parse_args()

  if os.path.exists(args.spec):
    FileNames.SpecFile = os.path.basename(args.spec)
    # SECURITY NOTE: the spec is executed as arbitrary Python; only ever
    # run trusted spec files.
    with open(args.spec) as spec_fh:  # fix: close the handle (was leaked)
      exec (spec_fh.read())

  return (args.model, args.example)
716
717
def print_cts_op(model_file, op):
  """Write op's C++ statement (if any) to model_file.

  Always returns True so TopologicalSort keeps traversing.
  """
  definition = op.Definition()
  if definition is not None:
    print("  " + definition, file=model_file)
  return True
723
if __name__ == '__main__':
  # Parse the command line and execute the spec, which populates the
  # module-level registries (operands, operations, examples).
  (model, example) = import_source()
  # Boilerplate: extra CreateModel() parameters collected from the spec.
  args = ""
  if len(ModelArgument.get_arguments()) > 0:
    args = ", " + ", ".join(ModelArgument.get_arguments())

  print("Output CTS model: %s" % model, file=sys.stderr)
  print("Output example:" + example, file=sys.stderr)

  with smart_open(model) as model_file:
    spec_file = " (from: %s)" % (FileNames.SpecFile)

    print ('// Generated file%s. Do not edit'%(spec_file), file = model_file)
    print ("void CreateModel(Model *model" + args + ") {", file=model_file)

    # Phase 0: types (one OperandType declaration per interned Type)
    Type.dump(model_file)
    # Phase 1: add operands
    print ("  // Phase 1, operands", file=model_file)
    Operand.operands.dump(model_file)

    # Phase 2: operations, emitted in topologically sorted order
    print ("  // Phase 2, operations", file=model_file)
    TopologicalSort(lambda x: print_cts_op(model_file, x))

    # Phase 3: add inputs and outputs (internal Parameters are excluded)
    print ("  // Phase 3, inputs and outputs", file=model_file)
    inputs = Operand.print_operands(Input.get_inputs(True));
    outputs = Operand.print_operands(Output.get_outputs());
    print ("  model->identifyInputsAndOutputs(\n" +
           "    {"+", ".join(inputs)+"},\n    {" + ", ".join(outputs) + "});",
           file=model_file)

    # Phase 4: set relaxed execution if needed
    if (Model.isRelaxed()):
      print ("  // Phase 4: set relaxed execution", file=model_file)
      print ("  model->relaxComputationFloat32toFloat16(true);", file=model_file)

    # Boilerplate: closing brace plus the generated is_ignored() helper.
    print ("  assert(model->isValid());", file=model_file);
    print ("}", file=model_file)
    print (IgnoredOutput.gen_ignored(), file=model_file)

  with smart_open(example) as example_file:
    Example.dump(example_file)
770