1/*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "adb.h"
18
19#include "command.h"
20#include "print.h"
21#include "util.h"
22
23#include <errno.h>
24#include <string.h>
25#include <stdlib.h>
26#include <unistd.h>
27#include <sys/types.h>
28#include <sys/wait.h>
29#include <limits.h>
30
31#include <iostream>
32#include <istream>
33#include <streambuf>
34
35using namespace std;
36
/**
 * A streambuf whose get area is an existing, caller-supplied character range,
 * so those bytes can be consumed through a std::istream without copying them.
 * Buffer never frees the range (it has no destructor of its own) and sets up
 * no put area; the memory must stay valid for as long as the Buffer is used.
 */
struct Buffer: public streambuf
{
    Buffer(char* begin, size_t size)
    {
        char* const end = begin + size;
        // eback = gptr = begin, egptr = end: the whole range is unread input.
        setg(begin, begin, end);
    }
};
46
47int
48run_adb(const char* first, ...)
49{
50    Command cmd("adb");
51
52    if (first == NULL) {
53        return 0;
54    }
55
56    cmd.AddArg(first);
57
58    va_list args;
59    va_start(args, first);
60    while (true) {
61        const char* arg = va_arg(args, char*);
62        if (arg == NULL) {
63            break;
64        }
65        cmd.AddArg(arg);
66    }
67    va_end(args);
68
69    return run_command(cmd);
70}
71
72string
73get_system_property(const string& name, int* err)
74{
75    Command cmd("adb");
76    cmd.AddArg("shell");
77    cmd.AddArg("getprop");
78    cmd.AddArg(name);
79
80    return trim(get_command_output(cmd, err, false));
81}
82
83
/**
 * Reads one base-128 varint from fd, one byte at a time: 7 payload bits per
 * byte, least-significant group first, high bit set on every byte except the
 * last.
 *
 * *err and *done are only ever written on the paths below, so callers must
 * initialize them (to 0 / false) before calling:
 *  - end of stream (read() returns 0): sets *done = true and returns whatever
 *    bits were accumulated so far;
 *  - read() failure: sets *err = errno and returns that same value;
 *  - encoding longer than a uint64_t allows: sets *err = -1 and returns 0.
 * Reads interrupted by a signal (EINTR) are retried.
 */
static uint64_t
read_varint(int fd, int* err, bool* done)
{
    uint32_t bits = 0;
    uint64_t result = 0;
    while (true) {
        uint8_t byte;
        ssize_t amt = read(fd, &byte, 1);
        if (amt < 0 && errno == EINTR) {
            // Interrupted before any data arrived; try again.
            continue;
        }
        if (amt == 0) {
            *done = true;
            return result;
        }
        if (amt < 0) {
            *err = errno;
            return *err;
        }
        result |= uint64_t(byte & 0x7F) << bits;
        if ((byte & 0x80) == 0) {
            return result;
        }
        bits += 7;
        // Ten 7-bit groups (shifts 0..63) cover a uint64_t; anything longer
        // is malformed.
        if (bits > 64) {
            *err = -1;
            return 0;
        }
    }
}
109
110static char*
111read_sized_buffer(int fd, int* err, size_t* resultSize)
112{
113    bool done = false;
114    uint64_t size = read_varint(fd, err, &done);
115    if (*err != 0 || done) {
116        return NULL;
117    }
118    if (size == 0) {
119        *resultSize = 0;
120        return NULL;
121    }
122    // 10 MB seems like a reasonable limit.
123    if (size > 10*1024*1024) {
124        print_error("result buffer too large: %llu", size);
125        return NULL;
126    }
127    char* buf = (char*)malloc(size);
128    if (buf == NULL) {
129        print_error("Can't allocate a buffer of size for test results: %llu", size);
130        return NULL;
131    }
132    int pos = 0;
133    while (size - pos > 0) {
134        ssize_t amt = read(fd, buf+pos, size-pos);
135        if (amt == 0) {
136            // early end of pipe
137            print_error("Early end of pipe.");
138            *err = -1;
139            free(buf);
140            return NULL;
141        } else if (amt < 0) {
142            // error
143            *err = errno;
144            free(buf);
145            return NULL;
146        }
147        pos += amt;
148    }
149    *resultSize = (size_t)size;
150    return buf;
151}
152
153static int
154read_sized_proto(int fd, Message* message)
155{
156    int err = 0;
157    size_t size;
158    char* buf = read_sized_buffer(fd, &err, &size);
159    if (err != 0) {
160        if (buf != NULL) {
161            free(buf);
162        }
163        return err;
164    } else if (size == 0) {
165        if (buf != NULL) {
166            free(buf);
167        }
168        return 0;
169    } else if (buf == NULL) {
170        return -1;
171    }
172    Buffer buffer(buf, size);
173    istream in(&buffer);
174
175    err = message->ParseFromIstream(&in) ? 0 : -1;
176
177    free(buf);
178    return err;
179}
180
181static int
182skip_bytes(int fd, ssize_t size, char* scratch, int scratchSize)
183{
184    while (size > 0) {
185        ssize_t amt = size < scratchSize ? size : scratchSize;
186        fprintf(stderr, "skipping %lu/%ld bytes\n", size, amt);
187        amt = read(fd, scratch, amt);
188        if (amt == 0) {
189            // early end of pipe
190            print_error("Early end of pipe.");
191            return -1;
192        } else if (amt < 0) {
193            // error
194            return errno;
195        }
196        size -= amt;
197    }
198    return 0;
199}
200
201static int
202skip_unknown_field(int fd, uint64_t tag, char* scratch, int scratchSize) {
203    bool done;
204    int err;
205    uint64_t size;
206    switch (tag & 0x7) {
207        case 0: // varint
208            read_varint(fd, &err, &done);
209            if (err != 0) {
210                return err;
211            } else if (done) {
212                return -1;
213            } else {
214                return 0;
215            }
216        case 1:
217            return skip_bytes(fd, 8, scratch, scratchSize);
218        case 2:
219            size = read_varint(fd, &err, &done);
220            if (err != 0) {
221                return err;
222            } else if (done) {
223                return -1;
224            }
225            if (size > INT_MAX) {
226                // we'll be here a long time but this keeps it from overflowing
227                return -1;
228            }
229            return skip_bytes(fd, (ssize_t)size, scratch, scratchSize);
230        case 5:
231            return skip_bytes(fd, 4, scratch, scratchSize);
232        default:
233            print_error("bad wire type for tag 0x%lx\n", tag);
234            return -1;
235    }
236}
237
238static int
239read_instrumentation_results(int fd, char* scratch, int scratchSize,
240        InstrumentationCallbacks* callbacks)
241{
242    bool done = false;
243    int err = 0;
244    string result;
245    while (true) {
246        uint64_t tag = read_varint(fd, &err, &done);
247        if (done) {
248            // Done reading input (this is the only place that a stream end isn't an error).
249            return 0;
250        } else if (err != 0) {
251            return err;
252        } else if (tag == 0xa) { // test_status
253            TestStatus status;
254            err = read_sized_proto(fd, &status);
255            if (err != 0) {
256                return err;
257            }
258            callbacks->OnTestStatus(status);
259        } else if (tag == 0x12) { // session_status
260            SessionStatus status;
261            err = read_sized_proto(fd, &status);
262            if (err != 0) {
263                return err;
264            }
265            callbacks->OnSessionStatus(status);
266        } else {
267            err = skip_unknown_field(fd, tag, scratch, scratchSize);
268            if (err != 0) {
269                return err;
270            }
271        }
272    }
273    return 0;
274}
275
276int
277run_instrumentation_test(const string& packageName, const string& runner, const string& className,
278        InstrumentationCallbacks* callbacks)
279{
280    Command cmd("adb");
281    cmd.AddArg("shell");
282    cmd.AddArg("am");
283    cmd.AddArg("instrument");
284    cmd.AddArg("-w");
285    cmd.AddArg("-m");
286    if (className.length() > 0) {
287        cmd.AddArg("-e");
288        cmd.AddArg("class");
289        cmd.AddArg(className);
290    }
291    cmd.AddArg(packageName + "/" + runner);
292
293    print_command(cmd);
294
295    int fds[2];
296    pipe(fds);
297
298    pid_t pid = fork();
299
300    if (pid == -1) {
301        // fork error
302        return errno;
303    } else if (pid == 0) {
304        // child
305        while ((dup2(fds[1], STDOUT_FILENO) == -1) && (errno == EINTR)) {}
306        close(fds[1]);
307        close(fds[0]);
308        const char* prog = cmd.GetProg();
309        char* const* argv = cmd.GetArgv();
310        char* const* env = cmd.GetEnv();
311        exec_with_path_search(prog, argv, env);
312        print_error("Unable to run command: %s", prog);
313        exit(1);
314    } else {
315        // parent
316        close(fds[1]);
317        string result;
318        const int size = 16*1024;
319        char* buf = (char*)malloc(size);
320        int err = read_instrumentation_results(fds[0], buf, size, callbacks);
321        free(buf);
322        int status;
323        waitpid(pid, &status, 0);
324        if (err != 0) {
325            return err;
326        }
327        if (WIFEXITED(status)) {
328            return WEXITSTATUS(status);
329        } else {
330            return -1;
331        }
332    }
333}
334
335/**
336 * Get the second to last bundle in the args list. Stores the last name found
337 * in last. If the path is not found or if the args list is empty, returns NULL.
338 */
339static const ResultsBundleEntry *
340find_penultimate_entry(const ResultsBundle& bundle, va_list args)
341{
342    const ResultsBundle* b = &bundle;
343    const char* arg = va_arg(args, char*);
344    while (arg) {
345        string last = arg;
346        arg = va_arg(args, char*);
347        bool found = false;
348        for (int i=0; i<b->entries_size(); i++) {
349            const ResultsBundleEntry& e = b->entries(i);
350            if (e.key() == last) {
351                if (arg == NULL) {
352                    return &e;
353                } else if (e.has_value_bundle()) {
354                    b = &e.value_bundle();
355                    found = true;
356                }
357            }
358        }
359        if (!found) {
360            return NULL;
361        }
362        if (arg == NULL) {
363            return NULL;
364        }
365    }
366    return NULL;
367}
368
369string
370get_bundle_string(const ResultsBundle& bundle, bool* found, ...)
371{
372    va_list args;
373    va_start(args, found);
374    const ResultsBundleEntry* entry = find_penultimate_entry(bundle, args);
375    va_end(args);
376    if (entry == NULL) {
377        *found = false;
378        return string();
379    }
380    if (entry->has_value_string()) {
381        *found = true;
382        return entry->value_string();
383    }
384    *found = false;
385    return string();
386}
387
388int32_t
389get_bundle_int(const ResultsBundle& bundle, bool* found, ...)
390{
391    va_list args;
392    va_start(args, found);
393    const ResultsBundleEntry* entry = find_penultimate_entry(bundle, args);
394    va_end(args);
395    if (entry == NULL) {
396        *found = false;
397        return 0;
398    }
399    if (entry->has_value_int()) {
400        *found = true;
401        return entry->value_int();
402    }
403    *found = false;
404    return 0;
405}
406
407float
408get_bundle_float(const ResultsBundle& bundle, bool* found, ...)
409{
410    va_list args;
411    va_start(args, found);
412    const ResultsBundleEntry* entry = find_penultimate_entry(bundle, args);
413    va_end(args);
414    if (entry == NULL) {
415        *found = false;
416        return 0;
417    }
418    if (entry->has_value_float()) {
419        *found = true;
420        return entry->value_float();
421    }
422    *found = false;
423    return 0;
424}
425
426double
427get_bundle_double(const ResultsBundle& bundle, bool* found, ...)
428{
429    va_list args;
430    va_start(args, found);
431    const ResultsBundleEntry* entry = find_penultimate_entry(bundle, args);
432    va_end(args);
433    if (entry == NULL) {
434        *found = false;
435        return 0;
436    }
437    if (entry->has_value_double()) {
438        *found = true;
439        return entry->value_double();
440    }
441    *found = false;
442    return 0;
443}
444
445int64_t
446get_bundle_long(const ResultsBundle& bundle, bool* found, ...)
447{
448    va_list args;
449    va_start(args, found);
450    const ResultsBundleEntry* entry = find_penultimate_entry(bundle, args);
451    va_end(args);
452    if (entry == NULL) {
453        *found = false;
454        return 0;
455    }
456    if (entry->has_value_long()) {
457        *found = true;
458        return entry->value_long();
459    }
460    *found = false;
461    return 0;
462}
463
464