1/*
2 * Copyright 2012, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
//#define LOG_NDEBUG 0
18#define LOG_TAG "udptest"
19#include <utils/Log.h>
20
21#include "ANetworkSession.h"
22
23#include <binder/ProcessState.h>
24#include <media/stagefright/foundation/ABuffer.h>
25#include <media/stagefright/foundation/ADebug.h>
26#include <media/stagefright/foundation/AHandler.h>
27#include <media/stagefright/foundation/ALooper.h>
28#include <media/stagefright/foundation/AMessage.h>
29#include <media/stagefright/Utils.h>
30
31namespace android {
32
33struct TestHandler : public AHandler {
34    TestHandler(const sp<ANetworkSession> &netSession);
35
36    void startServer(unsigned localPort);
37    void startClient(const char *remoteHost, unsigned remotePort);
38
39protected:
40    virtual ~TestHandler();
41
42    virtual void onMessageReceived(const sp<AMessage> &msg);
43
44private:
45    enum {
46        kWhatStartServer,
47        kWhatStartClient,
48        kWhatUDPNotify,
49        kWhatSendPacket,
50    };
51
52    sp<ANetworkSession> mNetSession;
53
54    bool mIsServer;
55    bool mConnected;
56    int32_t mUDPSession;
57    uint32_t mSeqNo;
58    double mTotalTimeUs;
59    int32_t mCount;
60
61    void postSendPacket(int64_t delayUs = 0ll);
62
63    DISALLOW_EVIL_CONSTRUCTORS(TestHandler);
64};
65
66TestHandler::TestHandler(const sp<ANetworkSession> &netSession)
67    : mNetSession(netSession),
68      mIsServer(false),
69      mConnected(false),
70      mUDPSession(0),
71      mSeqNo(0),
72      mTotalTimeUs(0.0),
73      mCount(0) {
74}
75
76TestHandler::~TestHandler() {
77}
78
79void TestHandler::startServer(unsigned localPort) {
80    sp<AMessage> msg = new AMessage(kWhatStartServer, id());
81    msg->setInt32("localPort", localPort);
82    msg->post();
83}
84
85void TestHandler::startClient(const char *remoteHost, unsigned remotePort) {
86    sp<AMessage> msg = new AMessage(kWhatStartClient, id());
87    msg->setString("remoteHost", remoteHost);
88    msg->setInt32("remotePort", remotePort);
89    msg->post();
90}
91
92void TestHandler::onMessageReceived(const sp<AMessage> &msg) {
93    switch (msg->what()) {
94        case kWhatStartClient:
95        {
96            AString remoteHost;
97            CHECK(msg->findString("remoteHost", &remoteHost));
98
99            int32_t remotePort;
100            CHECK(msg->findInt32("remotePort", &remotePort));
101
102            sp<AMessage> notify = new AMessage(kWhatUDPNotify, id());
103
104            CHECK_EQ((status_t)OK,
105                     mNetSession->createUDPSession(
106                         0 /* localPort */,
107                         remoteHost.c_str(),
108                         remotePort,
109                         notify,
110                         &mUDPSession));
111
112            postSendPacket();
113            break;
114        }
115
116        case kWhatStartServer:
117        {
118            mIsServer = true;
119
120            int32_t localPort;
121            CHECK(msg->findInt32("localPort", &localPort));
122
123            sp<AMessage> notify = new AMessage(kWhatUDPNotify, id());
124
125            CHECK_EQ((status_t)OK,
126                     mNetSession->createUDPSession(
127                         localPort, notify, &mUDPSession));
128
129            break;
130        }
131
132        case kWhatSendPacket:
133        {
134            char buffer[12];
135            memset(buffer, 0, sizeof(buffer));
136
137            buffer[0] = mSeqNo >> 24;
138            buffer[1] = (mSeqNo >> 16) & 0xff;
139            buffer[2] = (mSeqNo >> 8) & 0xff;
140            buffer[3] = mSeqNo & 0xff;
141            ++mSeqNo;
142
143            int64_t nowUs = ALooper::GetNowUs();
144            buffer[4] = nowUs >> 56;
145            buffer[5] = (nowUs >> 48) & 0xff;
146            buffer[6] = (nowUs >> 40) & 0xff;
147            buffer[7] = (nowUs >> 32) & 0xff;
148            buffer[8] = (nowUs >> 24) & 0xff;
149            buffer[9] = (nowUs >> 16) & 0xff;
150            buffer[10] = (nowUs >> 8) & 0xff;
151            buffer[11] = nowUs & 0xff;
152
153            CHECK_EQ((status_t)OK,
154                     mNetSession->sendRequest(
155                         mUDPSession, buffer, sizeof(buffer)));
156
157            postSendPacket(20000ll);
158            break;
159        }
160
161        case kWhatUDPNotify:
162        {
163            int32_t reason;
164            CHECK(msg->findInt32("reason", &reason));
165
166            switch (reason) {
167                case ANetworkSession::kWhatError:
168                {
169                    int32_t sessionID;
170                    CHECK(msg->findInt32("sessionID", &sessionID));
171
172                    int32_t err;
173                    CHECK(msg->findInt32("err", &err));
174
175                    AString detail;
176                    CHECK(msg->findString("detail", &detail));
177
178                    ALOGE("An error occurred in session %d (%d, '%s/%s').",
179                          sessionID,
180                          err,
181                          detail.c_str(),
182                          strerror(-err));
183
184                    mNetSession->destroySession(sessionID);
185                    break;
186                }
187
188                case ANetworkSession::kWhatDatagram:
189                {
190                    int32_t sessionID;
191                    CHECK(msg->findInt32("sessionID", &sessionID));
192
193                    sp<ABuffer> data;
194                    CHECK(msg->findBuffer("data", &data));
195
196                    if (mIsServer) {
197                        if (!mConnected) {
198                            AString fromAddr;
199                            CHECK(msg->findString("fromAddr", &fromAddr));
200
201                            int32_t fromPort;
202                            CHECK(msg->findInt32("fromPort", &fromPort));
203
204                            CHECK_EQ((status_t)OK,
205                                     mNetSession->connectUDPSession(
206                                         mUDPSession, fromAddr.c_str(), fromPort));
207
208                            mConnected = true;
209                        }
210
211                        int64_t nowUs = ALooper::GetNowUs();
212
213                        sp<ABuffer> buffer = new ABuffer(data->size() + 8);
214                        memcpy(buffer->data(), data->data(), data->size());
215
216                        uint8_t *ptr = buffer->data() + data->size();
217
218                        *ptr++ = nowUs >> 56;
219                        *ptr++ = (nowUs >> 48) & 0xff;
220                        *ptr++ = (nowUs >> 40) & 0xff;
221                        *ptr++ = (nowUs >> 32) & 0xff;
222                        *ptr++ = (nowUs >> 24) & 0xff;
223                        *ptr++ = (nowUs >> 16) & 0xff;
224                        *ptr++ = (nowUs >> 8) & 0xff;
225                        *ptr++ = nowUs & 0xff;
226
227                        CHECK_EQ((status_t)OK,
228                                 mNetSession->sendRequest(
229                                     mUDPSession, buffer->data(), buffer->size()));
230                    } else {
231                        CHECK_EQ(data->size(), 20u);
232
233                        uint32_t seqNo = U32_AT(data->data());
234                        int64_t t1 = U64_AT(data->data() + 4);
235                        int64_t t2 = U64_AT(data->data() + 12);
236
237                        int64_t t3;
238                        CHECK(data->meta()->findInt64("arrivalTimeUs", &t3));
239
240#if 0
241                        printf("roundtrip seqNo %u, time = %lld us\n",
242                               seqNo, t3 - t1);
243#else
244                        mTotalTimeUs += t3 - t1;
245                        ++mCount;
246                        printf("avg. roundtrip time %.2f us\n", mTotalTimeUs / mCount);
247#endif
248                    }
249                    break;
250                }
251
252                default:
253                    TRESPASS();
254            }
255
256            break;
257        }
258
259        default:
260            TRESPASS();
261    }
262}
263
264void TestHandler::postSendPacket(int64_t delayUs) {
265    (new AMessage(kWhatSendPacket, id()))->post(delayUs);
266}
267
268}  // namespace android
269
// Prints command-line help to stderr. |me| is the program name (argv[0]).
// Both -c and -l take an argument, matching the "hc:l:" getopt spec in main().
static void usage(const char *me) {
    fprintf(stderr,
            "usage: %s -c host[:port]\tconnect to test server\n"
            "           -l port       \tcreate a test server\n",
            me);
}
276
277int main(int argc, char **argv) {
278    using namespace android;
279
280    ProcessState::self()->startThreadPool();
281
282    int32_t localPort = -1;
283    int32_t connectToPort = -1;
284    AString connectToHost;
285
286    int res;
287    while ((res = getopt(argc, argv, "hc:l:")) >= 0) {
288        switch (res) {
289            case 'c':
290            {
291                const char *colonPos = strrchr(optarg, ':');
292
293                if (colonPos == NULL) {
294                    connectToHost = optarg;
295                    connectToPort = 49152;
296                } else {
297                    connectToHost.setTo(optarg, colonPos - optarg);
298
299                    char *end;
300                    connectToPort = strtol(colonPos + 1, &end, 10);
301
302                    if (*end != '\0' || end == colonPos + 1
303                            || connectToPort < 1 || connectToPort > 65535) {
304                        fprintf(stderr, "Illegal port specified.\n");
305                        exit(1);
306                    }
307                }
308                break;
309            }
310
311            case 'l':
312            {
313                char *end;
314                localPort = strtol(optarg, &end, 10);
315
316                if (*end != '\0' || end == optarg
317                        || localPort < 1 || localPort > 65535) {
318                    fprintf(stderr, "Illegal port specified.\n");
319                    exit(1);
320                }
321                break;
322            }
323
324            case '?':
325            case 'h':
326                usage(argv[0]);
327                exit(1);
328        }
329    }
330
331    if (localPort < 0 && connectToPort < 0) {
332        fprintf(stderr,
333                "You need to select either client or server mode.\n");
334        exit(1);
335    }
336
337    sp<ANetworkSession> netSession = new ANetworkSession;
338    netSession->start();
339
340    sp<ALooper> looper = new ALooper;
341
342    sp<TestHandler> handler = new TestHandler(netSession);
343    looper->registerHandler(handler);
344
345    if (localPort >= 0) {
346        handler->startServer(localPort);
347    } else {
348        handler->startClient(connectToHost.c_str(), connectToPort);
349    }
350
351    looper->start(true /* runOnCallingThread */);
352
353    return 0;
354}
355
356