1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "GraphDump.h"
18
19#include "HalInterfaces.h"
20
21#include <set>
22#include <iostream>
23
24namespace android {
25namespace nn {
26
27// Provide short name for OperandType value.
28static std::string translate(OperandType type) {
29    switch (type) {
30        case OperandType::FLOAT32:             return "F32";
31        case OperandType::INT32:               return "I32";
32        case OperandType::UINT32:              return "U32";
33        case OperandType::TENSOR_FLOAT32:      return "TF32";
34        case OperandType::TENSOR_INT32:        return "TI32";
35        case OperandType::TENSOR_QUANT8_ASYMM: return "TQ8A";
36        case OperandType::OEM:                 return "OEM";
37        case OperandType::TENSOR_OEM_BYTE:     return "TOEMB";
38        default:                               return toString(type);
39    }
40}
41
42void graphDump(const char* name, const Model& model, std::ostream& outStream) {
43    // Operand nodes are named "d" (operanD) followed by operand index.
44    // Operation nodes are named "n" (operatioN) followed by operation index.
45    // (These names are not the names that are actually displayed -- those
46    //  names are given by the "label" attribute.)
47
48    outStream << "// " << name << std::endl;
49    outStream << "digraph {" << std::endl;
50
51    // model inputs and outputs
52    std::set<uint32_t> modelIO;
53    for (unsigned i = 0, e = model.inputIndexes.size(); i < e; i++) {
54        modelIO.insert(model.inputIndexes[i]);
55    }
56    for (unsigned i = 0, e = model.outputIndexes.size(); i < e; i++) {
57        modelIO.insert(model.outputIndexes[i]);
58    }
59
60    // model operands
61    for (unsigned i = 0, e = model.operands.size(); i < e; i++) {
62        outStream << "    d" << i << " [";
63        if (modelIO.count(i)) {
64            outStream << "style=filled fillcolor=black fontcolor=white ";
65        }
66        outStream << "label=\"" << i;
67        const Operand& opnd = model.operands[i];
68        const char* kind = nullptr;
69        switch (opnd.lifetime) {
70            case OperandLifeTime::CONSTANT_COPY:
71                kind = "COPY";
72                break;
73            case OperandLifeTime::CONSTANT_REFERENCE:
74                kind = "REF";
75                break;
76            case OperandLifeTime::NO_VALUE:
77                kind = "NO";
78                break;
79            default:
80                // nothing interesting
81                break;
82        }
83        if (kind) {
84            outStream << ": " << kind;
85        }
86        outStream << "\\n" << translate(opnd.type);
87        if (opnd.dimensions.size()) {
88            outStream << "(";
89            for (unsigned i = 0, e = opnd.dimensions.size(); i < e; i++) {
90                if (i > 0) {
91                    outStream << "x";
92                }
93                outStream << opnd.dimensions[i];
94            }
95            outStream << ")";
96        }
97        outStream << "\"]" << std::endl;
98    }
99
100    // model operations
101    for (unsigned i = 0, e = model.operations.size(); i < e; i++) {
102        const Operation& operation = model.operations[i];
103        outStream << "    n" << i << " [shape=box";
104        const uint32_t maxArity = std::max(operation.inputs.size(), operation.outputs.size());
105        if (maxArity > 1) {
106            if (maxArity == operation.inputs.size()) {
107                outStream << " ordering=in";
108            } else {
109                outStream << " ordering=out";
110            }
111        }
112        outStream << " label=\"" << i << ": "
113                  << toString(operation.type) << "\"]" << std::endl;
114        {
115            // operation inputs
116            for (unsigned in = 0, inE = operation.inputs.size(); in < inE; in++) {
117                outStream << "    d" << operation.inputs[in] << " -> n" << i;
118                if (inE > 1) {
119                    outStream << " [label=" << in << "]";
120                }
121                outStream << std::endl;
122            }
123        }
124
125        {
126            // operation outputs
127            for (unsigned out = 0, outE = operation.outputs.size(); out < outE; out++) {
128                outStream << "    n" << i << " -> d" << operation.outputs[out];
129                if (outE > 1) {
130                    outStream << " [label=" << out << "]";
131                }
132                outStream << std::endl;
133            }
134        }
135    }
136    outStream << "}" << std::endl;
137}
138
139}  // namespace nn
140}  // namespace android
141