1#!/usr/bin/python
2#
3# Copyright (C) 2018 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#      http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17# Make sure that simpleperf's inferno is on the PYTHONPATH, e.g., run as
18# PYTHONPATH=$PYTHONPATH:$ANDROID_BUILD_TOP/system/extras/simpleperf/scripts/inferno python ..
19
20import argparse
21import itertools
22import sqlite3
23
class Callsite(object):
    """A node of a call tree aggregated from sampled stacks.

    A node is identified by its ``(dso_id, sym_id)`` pair and records in
    ``count`` how many samples passed through it. Children live in
    ``child_map``, keyed by their own ``(dso_id, sym_id)`` pair. Every instance
    gets a unique, increasing ``id`` from a class-wide counter.
    """

    def __init__(self, dso_id, sym_id):
        self.dso_id = dso_id
        self.sym_id = sym_id
        self.count = 0
        self.child_map = {}  # (dso_id, sym_id) -> Callsite
        self.id = self._get_next_callsite_id()

    def add(self, dso_id, sym_id):
        """Return the child for ``(dso_id, sym_id)``, creating it if missing.

        The child's count is not touched; callers increment it themselves.
        """
        key = (dso_id, sym_id)
        child = self.child_map.get(key)
        if child is None:
            child = Callsite(dso_id, sym_id)
            self.child_map[key] = child
        return child

    def child_count_to_self(self):
        """Set ``count`` to the sum of the direct children's counts.

        Useful for a node whose own count is never incremented directly,
        such as a synthetic root.
        """
        self.count = sum(child.count for child in self.child_map.values())

    def trim(self, local_threshold_in_percent, global_threshold):
        """Recursively drop children that have too few samples.

        A child is removed when its count is below the larger of
        ``local_threshold_in_percent`` percent of this node's count and the
        absolute sample count ``global_threshold``. The same thresholds are
        then applied to the surviving children.
        """
        local_threshold = local_threshold_in_percent * 0.01 * self.count
        threshold = max(local_threshold, global_threshold)
        # Collect the keys first: deleting from a dict while iterating over it
        # raises RuntimeError on Python 3.
        doomed = [key for key, child in self.child_map.items() if child.count < threshold]
        for key in doomed:
            del self.child_map[key]
        for child in self.child_map.values():
            child.trim(local_threshold_in_percent, global_threshold)

    def _get_str(self, key, names):
        """Return ``names[key]`` if present, otherwise ``str(key)``."""
        if key in names:
            return names[key]
        return str(key)

    def print_callsite_ascii(self, depth, indent, dsos, syms):
        """Print this subtree as indented text, heaviest children first.

        Each line reads ``<symbol> (<dso>) [<count>]``, indented by two spaces
        per ``indent`` level. Printing descends ``depth`` levels below this
        node; a negative ``depth`` prints the whole subtree.
        """
        line = "%s (%s) [%d]" % (self._get_str(self.sym_id, syms),
                                 self._get_str(self.dso_id, dsos),
                                 self.count)
        print(('  ' * indent) + line)
        if depth == 0:
            return
        for child in sorted(self.child_map.values(), key=lambda c: c.count, reverse=True):
            child.print_callsite_ascii(depth - 1, indent + 1, dsos, syms)

    # Functions for flamegraph compatibility.

    callsite_counter = 0  # last id handed out; shared by all instances

    @classmethod
    def _get_next_callsite_id(cls):
        """Return a fresh id, one larger than the last one issued."""
        cls.callsite_counter += 1
        return cls.callsite_counter

    def create_children_list(self):
        """Cache the children in ``children``, sorted by descending count."""
        self.children = sorted(self.child_map.values(), key=lambda x: x.count, reverse=True)

    def generate_offset(self, start_offset):
        """Lay out this subtree horizontally starting at ``start_offset``.

        Sets ``offset`` on this node and, recursively, on each child in
        ``children`` order, with each child starting where the previous one
        ended. Returns ``start_offset + count``. ``create_children_list`` (or
        ``svgrenderer_compat``) must have been called first.
        """
        self.offset = start_offset
        child_offset = start_offset
        for child in self.children:
            child_offset = child.generate_offset(child_offset)
        return self.offset + self.count

    def svgrenderer_compat(self, dsos, syms):
        """Recursively populate ``children``, ``method``, ``dso`` and ``offset``.

        ``method`` and ``dso`` are the names resolved through ``syms`` and
        ``dsos``, falling back to ``str`` of the id. ``offset`` is reset to 0.
        """
        self.create_children_list()
        self.method = self._get_str(self.sym_id, syms)
        self.dso = self._get_str(self.dso_id, dsos)
        self.offset = 0
        for c in self.children:
            c.svgrenderer_compat(dsos, syms)

    def weight(self):
        """Return the sample count as a float."""
        return float(self.count)

    def get_max_depth(self):
        """Return the number of levels in this subtree (a leaf counts as 1)."""
        if self.child_map:
            return max(c.get_max_depth() for c in self.child_map.values()) + 1
        return 1
99
class SqliteReader(object):
    """Builds an aggregated Callsite tree from a perfprofd sqlite database.

    The queries below read three tables: ``dsos`` and ``syms`` (``id`` and
    ``name`` columns, used to resolve ids to display strings) and ``stacks``
    (``sample_id``, ``depth``, ``dso_id`` and ``sym_id`` columns, one row per
    frame of a sample).
    """

    def __init__(self):
        # Synthetic root; read() fills in its count from its children.
        self.root = Callsite("root", "root")
        self.dsos = {}  # dso id -> name, filled by read()
        self.syms = {}  # symbol id -> name, filled by read()

    def open(self, f):
        """Connect to the sqlite database at path ``f``."""
        self._conn = sqlite3.connect(f)
        self._c = self._conn.cursor()

    def close(self):
        """Close the connection created by open()."""
        self._conn.close()

    def _read_table(self, name, dest_table):
        """Store every ``(id, name)`` row of table ``name`` into ``dest_table``."""
        # ``name`` is only ever a literal table name chosen in read(), never
        # external input, so interpolating it into the query is safe.
        for row_id, row_name in self._c.execute('select id, name from %s' % (name)):
            dest_table[row_id] = row_name

    def read(self, local_threshold_in_percent, global_threshold_in_percent, limit):
        """Load names and samples, build the call tree, then trim it.

        Args:
            local_threshold_in_percent: a child is pruned when its count is
                below this percentage of its parent's count (see Callsite.trim).
            global_threshold_in_percent: a node is pruned when its count is
                below this percentage of the total sample count.
            limit: maximum number of ``stacks`` rows to consume, or None for
                no limit. The last sample may be cut short when it is hit.
        """
        # Read aux tables first, as we need to find the kernel symbols.
        print('Reading DSOs')
        self._read_table('dsos', self.dsos)

        print('Reading symbol strings')
        self._read_table('syms', self.syms)

        kernel_sym_id = next((i for i, v in self.syms.items() if v == '[kernel]'), None)

        print('Reading samples')
        self._c.execute('''select sample_id, depth, dso_id, sym_id from stacks
                           order by sample_id asc, depth desc''')

        last_sample_id = None
        chain = None
        count = 0
        while True:
            if limit is not None and count >= limit:
                print('Breaking as limit is reached')
                break
            # Never fetch past the remaining budget, so --limit is honoured
            # exactly instead of being rounded up to the batch size.
            batch_size = 100 if limit is None else min(100, limit - count)
            rows = self._c.fetchmany(batch_size)
            if not rows:
                break
            for sample_id, depth, dso_id, sym_id in rows:
                if sym_id == kernel_sym_id and depth == 0:
                    # Skip kernel.
                    continue
                if sample_id != last_sample_id:
                    # Frames of one sample arrive in descending depth order,
                    # so a new sample restarts the walk at the root.
                    last_sample_id = sample_id
                    chain = self.root
                chain = chain.add(dso_id, sym_id)
                chain.count += 1
            count += len(rows)

        self.root.child_count_to_self()
        global_threshold = global_threshold_in_percent * 0.01 * self.root.count
        self.root.trim(local_threshold_in_percent, global_threshold)

    def print_data_ascii(self, depth):
        """Print the call tree as indented text, ``depth`` levels deep."""
        self.root.print_callsite_ascii(depth, 0, self.dsos, self.syms)

    def print_svg(self, filename, depth):
        """Write the call tree to ``filename`` as an HTML page with a flame graph.

        Requires the ``svg_renderer`` module (from simpleperf's inferno) to be
        importable. ``depth`` is currently unused. If inferno's ``script.js`` is
        found at the expected path relative to this file, it is inlined to make
        the graph navigable. The page is closed properly whether or not it is
        found.
        """
        import os.path
        from svg_renderer import renderSVG

        self.root.svgrenderer_compat(self.dsos, self.syms)
        self.root.generate_offset(0)

        class FakeProcess:
            # Stand-in for the process object renderSVG takes; only ``props``
            # is provided.
            def __init__(self):
                self.props = { 'trace_offcpu': False }

        script_js_rel = "../../simpleperf/scripts/inferno/script.js"
        script_js = os.path.join(os.path.dirname(__file__), script_js_rel)

        with open(filename, 'w') as f:
            f.write('''
<html>
<body>
<div id='flamegraph_id' style='font-family: Monospace;'>
<style type="text/css"> .s { stroke:black; stroke-width:0.5; cursor:pointer;} </style>
<style type="text/css"> .t:hover { cursor:pointer; } </style>
''')
            renderSVG(FakeProcess(), self.root, f, 'hot')
            f.write('''
</div>
''')

            # Emit script.js, if we can find it.
            if os.path.exists(script_js):
                f.write('<script>\n')
                with open(script_js, 'r') as script_f:
                    f.write(script_f.read())
                f.write('''
</script>
<br/><br/>
<div>Navigate with WASD, zoom in with SPACE, zoom out with BACKSPACE.</div>
<script>document.addEventListener('DOMContentLoaded', flamegraphInit);</script>
''')

            # Close the document unconditionally so the page stays well-formed
            # even without script.js.
            f.write('''</body>
</html>
''')
211
if __name__ == "__main__":
    # Command-line entry point: build the call tree from the database, then
    # either print it as ASCII or write an HTML flame graph.
    parser = argparse.ArgumentParser(description='''Translate a perfprofd database into a flame
                                                    representation''')

    parser.add_argument('file', help='the sqlite database to use', metavar='file', type=str)

    parser.add_argument('--html-out', help='output file for HTML flame graph', type=str)
    parser.add_argument('--threshold', help='child threshold in percent', type=float, default=5)
    parser.add_argument('--global-threshold', help='global threshold in percent', type=float,
                        default=.1)
    parser.add_argument('--depth', help='depth to print to', type=int, default=10)
    parser.add_argument('--limit', help='limit to given number of stack trace entries', type=int)

    # parse_args() exits on invalid arguments and never returns None.
    args = parser.parse_args()
    sql_out = SqliteReader()
    sql_out.open(args.file)
    try:
        sql_out.read(args.threshold, args.global_threshold, args.limit)
        if args.html_out is None:
            sql_out.print_data_ascii(args.depth)
        else:
            sql_out.print_svg(args.html_out, args.depth)
    finally:
        # Release the database even if reading or rendering fails.
        sql_out.close()
235