1/*
2 *  Copyright (c) 2012 The WebRTC project authors. All Rights Reserved.
3 *
4 *  Use of this source code is governed by a BSD-style license
5 *  that can be found in the LICENSE file in the root of the source
6 *  tree. An additional intellectual property rights grant can be found
7 *  in the file PATENTS.  All contributing project authors may
8 *  be found in the AUTHORS file in the root of the source tree.
9 */
10
11#include "webrtc/test/channel_transport/udp_socket2_manager_win.h"
12
13#include <assert.h>
14#include <stdio.h>
15
16#include "webrtc/system_wrappers/include/aligned_malloc.h"
17#include "webrtc/test/channel_transport/udp_socket2_win.h"
18
19namespace webrtc {
20namespace test {
21
22uint32_t UdpSocket2ManagerWindows::_numOfActiveManagers = 0;
23bool UdpSocket2ManagerWindows::_wsaInit = false;
24
25UdpSocket2ManagerWindows::UdpSocket2ManagerWindows()
26    : UdpSocketManager(),
27      _id(-1),
28      _stopped(false),
29      _init(false),
30      _pCrit(CriticalSectionWrapper::CreateCriticalSection()),
31      _ioCompletionHandle(NULL),
32      _numActiveSockets(0),
33      _event(EventWrapper::Create())
34{
35    _managerNumber = _numOfActiveManagers++;
36
37    if(_numOfActiveManagers == 1)
38    {
39        WORD wVersionRequested = MAKEWORD(2, 2);
40        WSADATA wsaData;
41        _wsaInit = WSAStartup(wVersionRequested, &wsaData) == 0;
42        // TODO (hellner): seems safer to use RAII for this. E.g. what happens
43        //                 if a UdpSocket2ManagerWindows() created and destroyed
44        //                 without being initialized.
45    }
46}
47
48UdpSocket2ManagerWindows::~UdpSocket2ManagerWindows()
49{
50    WEBRTC_TRACE(kTraceDebug, kTraceTransport, _id,
51                 "UdpSocket2ManagerWindows(%d)::~UdpSocket2ManagerWindows()",
52                 _managerNumber);
53
54    if(_init)
55    {
56        _pCrit->Enter();
57        if(_numActiveSockets)
58        {
59            _pCrit->Leave();
60            _event->Wait(INFINITE);
61        }
62        else
63        {
64            _pCrit->Leave();
65        }
66        StopWorkerThreads();
67
68        for (WorkerList::iterator iter = _workerThreadsList.begin();
69             iter != _workerThreadsList.end(); ++iter) {
70          delete *iter;
71        }
72        _workerThreadsList.clear();
73        _ioContextPool.Free();
74
75        _numOfActiveManagers--;
76        if(_ioCompletionHandle)
77        {
78            CloseHandle(_ioCompletionHandle);
79        }
80        if (_numOfActiveManagers == 0)
81        {
82            if(_wsaInit)
83            {
84                WSACleanup();
85            }
86        }
87    }
88    if(_pCrit)
89    {
90        delete _pCrit;
91    }
92    if(_event)
93    {
94        delete _event;
95    }
96}
97
98bool UdpSocket2ManagerWindows::Init(int32_t id,
99                                    uint8_t& numOfWorkThreads) {
100  CriticalSectionScoped cs(_pCrit);
101  if ((_id != -1) || (_numOfWorkThreads != 0)) {
102      assert(_id != -1);
103      assert(_numOfWorkThreads != 0);
104      return false;
105  }
106  _id = id;
107  _numOfWorkThreads = numOfWorkThreads;
108  return true;
109}
110
111bool UdpSocket2ManagerWindows::Start()
112{
113    WEBRTC_TRACE(kTraceDebug, kTraceTransport, _id,
114                 "UdpSocket2ManagerWindows(%d)::Start()",_managerNumber);
115    if(!_init)
116    {
117        StartWorkerThreads();
118    }
119
120    if(!_init)
121    {
122        return false;
123    }
124    _pCrit->Enter();
125    // Start worker threads.
126    _stopped = false;
127    int32_t error = 0;
128    for (WorkerList::iterator iter = _workerThreadsList.begin();
129         iter != _workerThreadsList.end() && !error; ++iter) {
130      if(!(*iter)->Start())
131        error = 1;
132    }
133    if(error)
134    {
135        WEBRTC_TRACE(
136            kTraceError,
137            kTraceTransport,
138            _id,
139            "UdpSocket2ManagerWindows(%d)::Start() error starting worker\
140 threads",
141            _managerNumber);
142        _pCrit->Leave();
143        return false;
144    }
145    _pCrit->Leave();
146    return true;
147}
148
149bool UdpSocket2ManagerWindows::StartWorkerThreads()
150{
151    if(!_init)
152    {
153        _pCrit->Enter();
154
155        _ioCompletionHandle = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL,
156                                                     0, 0);
157        if(_ioCompletionHandle == NULL)
158        {
159            int32_t error = GetLastError();
160            WEBRTC_TRACE(
161                kTraceError,
162                kTraceTransport,
163                _id,
164                "UdpSocket2ManagerWindows(%d)::StartWorkerThreads()"
165                "_ioCompletioHandle == NULL: error:%d",
166                _managerNumber,error);
167            _pCrit->Leave();
168            return false;
169        }
170
171        // Create worker threads.
172        uint32_t i = 0;
173        bool error = false;
174        while(i < _numOfWorkThreads && !error)
175        {
176            UdpSocket2WorkerWindows* pWorker =
177                new UdpSocket2WorkerWindows(_ioCompletionHandle);
178            if(pWorker->Init() != 0)
179            {
180                error = true;
181                delete pWorker;
182                break;
183            }
184            _workerThreadsList.push_front(pWorker);
185            i++;
186        }
187        if(error)
188        {
189            WEBRTC_TRACE(
190                kTraceError,
191                kTraceTransport,
192                _id,
193                "UdpSocket2ManagerWindows(%d)::StartWorkerThreads() error "
194                "creating work threads",
195                _managerNumber);
196            // Delete worker threads.
197            for (WorkerList::iterator iter = _workerThreadsList.begin();
198                 iter != _workerThreadsList.end(); ++iter) {
199              delete *iter;
200            }
201            _workerThreadsList.clear();
202            _pCrit->Leave();
203            return false;
204        }
205        if(_ioContextPool.Init())
206        {
207            WEBRTC_TRACE(
208                kTraceError,
209                kTraceTransport,
210                _id,
211                "UdpSocket2ManagerWindows(%d)::StartWorkerThreads() error "
212                "initiating _ioContextPool",
213                _managerNumber);
214            _pCrit->Leave();
215            return false;
216        }
217        _init = true;
218        WEBRTC_TRACE(
219            kTraceDebug,
220            kTraceTransport,
221            _id,
222            "UdpSocket2ManagerWindows::StartWorkerThreads %d number of work "
223            "threads created and initialized",
224            _numOfWorkThreads);
225        _pCrit->Leave();
226    }
227    return true;
228}
229
230bool UdpSocket2ManagerWindows::Stop()
231{
232    WEBRTC_TRACE(kTraceDebug, kTraceTransport, _id,
233                 "UdpSocket2ManagerWindows(%d)::Stop()",_managerNumber);
234
235    if(!_init)
236    {
237        return false;
238    }
239    _pCrit->Enter();
240    _stopped = true;
241    if(_numActiveSockets)
242    {
243        WEBRTC_TRACE(
244            kTraceError,
245            kTraceTransport,
246            _id,
247            "UdpSocket2ManagerWindows(%d)::Stop() there is still active\
248 sockets",
249            _managerNumber);
250        _pCrit->Leave();
251        return false;
252    }
253    // No active sockets. Stop all worker threads.
254    bool result = StopWorkerThreads();
255    _pCrit->Leave();
256    return result;
257}
258
259bool UdpSocket2ManagerWindows::StopWorkerThreads()
260{
261    int32_t error = 0;
262    WEBRTC_TRACE(
263        kTraceDebug,
264        kTraceTransport,
265        _id,
266        "UdpSocket2ManagerWindows(%d)::StopWorkerThreads() Worker\
267 threadsStoped, numActicve Sockets=%d",
268        _managerNumber,
269        _numActiveSockets);
270
271    // Release all threads waiting for GetQueuedCompletionStatus(..).
272    if(_ioCompletionHandle)
273    {
274        uint32_t i = 0;
275        for(i = 0; i < _workerThreadsList.size(); i++)
276        {
277            PostQueuedCompletionStatus(_ioCompletionHandle, 0 ,0 , NULL);
278        }
279    }
280    for (WorkerList::iterator iter = _workerThreadsList.begin();
281         iter != _workerThreadsList.end(); ++iter) {
282        if((*iter)->Stop() == false)
283        {
284            error = -1;
285            WEBRTC_TRACE(kTraceWarning,  kTraceTransport, -1,
286                         "failed to stop worker thread");
287        }
288    }
289
290    if(error)
291    {
292        WEBRTC_TRACE(
293            kTraceError,
294            kTraceTransport,
295            _id,
296            "UdpSocket2ManagerWindows(%d)::StopWorkerThreads() error stopping\
297 worker threads",
298            _managerNumber);
299        return false;
300    }
301    return true;
302}
303
304bool UdpSocket2ManagerWindows::AddSocketPrv(UdpSocket2Windows* s)
305{
306    WEBRTC_TRACE(kTraceDebug, kTraceTransport, _id,
307                 "UdpSocket2ManagerWindows(%d)::AddSocketPrv()",_managerNumber);
308    if(!_init)
309    {
310        WEBRTC_TRACE(
311            kTraceError,
312            kTraceTransport,
313            _id,
314            "UdpSocket2ManagerWindows(%d)::AddSocketPrv() manager not\
315 initialized",
316            _managerNumber);
317        return false;
318    }
319    _pCrit->Enter();
320    if(s == NULL)
321    {
322        WEBRTC_TRACE(
323            kTraceError,
324            kTraceTransport,
325            _id,
326            "UdpSocket2ManagerWindows(%d)::AddSocketPrv() socket == NULL",
327            _managerNumber);
328        _pCrit->Leave();
329        return false;
330    }
331    if(s->GetFd() == NULL || s->GetFd() == INVALID_SOCKET)
332    {
333        WEBRTC_TRACE(
334            kTraceError,
335            kTraceTransport,
336            _id,
337            "UdpSocket2ManagerWindows(%d)::AddSocketPrv() socket->GetFd() ==\
338 %d",
339            _managerNumber,
340            (int32_t)s->GetFd());
341        _pCrit->Leave();
342        return false;
343
344    }
345    _ioCompletionHandle = CreateIoCompletionPort((HANDLE)s->GetFd(),
346                                                 _ioCompletionHandle,
347                                                 (ULONG_PTR)(s), 0);
348    if(_ioCompletionHandle == NULL)
349    {
350        int32_t error = GetLastError();
351        WEBRTC_TRACE(
352            kTraceError,
353            kTraceTransport,
354            _id,
355            "UdpSocket2ManagerWindows(%d)::AddSocketPrv() Error adding to IO\
356 completion: %d",
357            _managerNumber,
358            error);
359        _pCrit->Leave();
360        return false;
361    }
362    _numActiveSockets++;
363    _pCrit->Leave();
364    return true;
365}
366bool UdpSocket2ManagerWindows::RemoveSocketPrv(UdpSocket2Windows* s)
367{
368    if(!_init)
369    {
370        return false;
371    }
372    _pCrit->Enter();
373    _numActiveSockets--;
374    if(_numActiveSockets == 0)
375    {
376        _event->Set();
377    }
378    _pCrit->Leave();
379    return true;
380}
381
382PerIoContext* UdpSocket2ManagerWindows::PopIoContext()
383{
384    if(!_init)
385    {
386        return NULL;
387    }
388
389    PerIoContext* pIoC = NULL;
390    if(!_stopped)
391    {
392        pIoC = _ioContextPool.PopIoContext();
393    }else
394    {
395        WEBRTC_TRACE(
396            kTraceError,
397            kTraceTransport,
398            _id,
399            "UdpSocket2ManagerWindows(%d)::PopIoContext() Manager Not started",
400            _managerNumber);
401    }
402    return pIoC;
403}
404
405int32_t UdpSocket2ManagerWindows::PushIoContext(PerIoContext* pIoContext)
406{
407    return _ioContextPool.PushIoContext(pIoContext);
408}
409
410IoContextPool::IoContextPool()
411    : _pListHead(NULL),
412      _init(false),
413      _size(0),
414      _inUse(0)
415{
416}
417
418IoContextPool::~IoContextPool()
419{
420    Free();
421    assert(_size.Value() == 0);
422    AlignedFree(_pListHead);
423}
424
425int32_t IoContextPool::Init(uint32_t /*increaseSize*/)
426{
427    if(_init)
428    {
429        return 0;
430    }
431
432    _pListHead = (PSLIST_HEADER)AlignedMalloc(sizeof(SLIST_HEADER),
433                                              MEMORY_ALLOCATION_ALIGNMENT);
434    if(_pListHead == NULL)
435    {
436        return -1;
437    }
438    InitializeSListHead(_pListHead);
439    _init = true;
440    return 0;
441}
442
443PerIoContext* IoContextPool::PopIoContext()
444{
445    if(!_init)
446    {
447        return NULL;
448    }
449
450    PSLIST_ENTRY pListEntry = InterlockedPopEntrySList(_pListHead);
451    if(pListEntry == NULL)
452    {
453        IoContextPoolItem* item = (IoContextPoolItem*)
454            AlignedMalloc(
455                sizeof(IoContextPoolItem),
456                MEMORY_ALLOCATION_ALIGNMENT);
457        if(item == NULL)
458        {
459            return NULL;
460        }
461        memset(&item->payload.ioContext,0,sizeof(PerIoContext));
462        item->payload.base = item;
463        pListEntry = &(item->itemEntry);
464        ++_size;
465    }
466    ++_inUse;
467    return &((IoContextPoolItem*)pListEntry)->payload.ioContext;
468}
469
470int32_t IoContextPool::PushIoContext(PerIoContext* pIoContext)
471{
472    // TODO (hellner): Overlapped IO should be completed at this point. Perhaps
473    //                 add an assert?
474    const bool overlappedIOCompleted = HasOverlappedIoCompleted(
475        (LPOVERLAPPED)pIoContext);
476
477    IoContextPoolItem* item = ((IoContextPoolItemPayload*)pIoContext)->base;
478
479    const int32_t usedItems = --_inUse;
480    const int32_t totalItems = _size.Value();
481    const int32_t freeItems = totalItems - usedItems;
482    if(freeItems < 0)
483    {
484        assert(false);
485        AlignedFree(item);
486        return -1;
487    }
488    if((freeItems >= totalItems>>1) &&
489        overlappedIOCompleted)
490    {
491        AlignedFree(item);
492        --_size;
493        return 0;
494    }
495    InterlockedPushEntrySList(_pListHead, &(item->itemEntry));
496    return 0;
497}
498
499int32_t IoContextPool::Free()
500{
501    if(!_init)
502    {
503        return 0;
504    }
505
506    int32_t itemsFreed = 0;
507    PSLIST_ENTRY pListEntry = InterlockedPopEntrySList(_pListHead);
508    while(pListEntry != NULL)
509    {
510        IoContextPoolItem* item = ((IoContextPoolItem*)pListEntry);
511        AlignedFree(item);
512        --_size;
513        itemsFreed++;
514        pListEntry = InterlockedPopEntrySList(_pListHead);
515    }
516    return itemsFreed;
517}
518
519int32_t UdpSocket2WorkerWindows::_numOfWorkers = 0;
520
521UdpSocket2WorkerWindows::UdpSocket2WorkerWindows(HANDLE ioCompletionHandle)
522    : _ioCompletionHandle(ioCompletionHandle),
523      _pThread(Run, this, "UdpSocket2ManagerWindows_thread"),
524      _init(false) {
525    _workerNumber = _numOfWorkers++;
526    WEBRTC_TRACE(kTraceMemory,  kTraceTransport, -1,
527                 "UdpSocket2WorkerWindows created");
528}
529
530UdpSocket2WorkerWindows::~UdpSocket2WorkerWindows()
531{
532    WEBRTC_TRACE(kTraceMemory,  kTraceTransport, -1,
533                 "UdpSocket2WorkerWindows deleted");
534}
535
536bool UdpSocket2WorkerWindows::Start()
537{
538    WEBRTC_TRACE(kTraceStateInfo,  kTraceTransport, -1,
539                 "Start UdpSocket2WorkerWindows");
540    _pThread.Start();
541
542    _pThread.SetPriority(rtc::kRealtimePriority);
543    return true;
544}
545
546bool UdpSocket2WorkerWindows::Stop()
547{
548    WEBRTC_TRACE(kTraceStateInfo,  kTraceTransport, -1,
549                 "Stop UdpSocket2WorkerWindows");
550    _pThread.Stop();
551    return true;
552}
553
554int32_t UdpSocket2WorkerWindows::Init()
555{
556  _init = true;
557  return 0;
558}
559
560bool UdpSocket2WorkerWindows::Run(void* obj)
561{
562    UdpSocket2WorkerWindows* pWorker =
563        static_cast<UdpSocket2WorkerWindows*>(obj);
564    return pWorker->Process();
565}
566
567// Process should always return true. Stopping the worker threads is done in
568// the UdpSocket2ManagerWindows::StopWorkerThreads() function.
569bool UdpSocket2WorkerWindows::Process()
570{
571    int32_t success = 0;
572    DWORD ioSize = 0;
573    UdpSocket2Windows* pSocket = NULL;
574    PerIoContext* pIOContext = 0;
575    OVERLAPPED* pOverlapped = 0;
576    success = GetQueuedCompletionStatus(_ioCompletionHandle,
577                                        &ioSize,
578                                       (ULONG_PTR*)&pSocket, &pOverlapped, 200);
579
580    uint32_t error = 0;
581    if(!success)
582    {
583        error = GetLastError();
584        if(error == WAIT_TIMEOUT)
585        {
586            return true;
587        }
588        // This may happen if e.g. PostQueuedCompletionStatus() has been called.
589        // The IO context still needs to be reclaimed or re-used which is done
590        // in UdpSocket2Windows::IOCompleted(..).
591    }
592    if(pSocket == NULL)
593    {
594        WEBRTC_TRACE(
595            kTraceDebug,
596            kTraceTransport,
597            -1,
598            "UdpSocket2WorkerWindows(%d)::Process(), pSocket == 0, end thread",
599            _workerNumber);
600        return true;
601    }
602    pIOContext = (PerIoContext*)pOverlapped;
603    pSocket->IOCompleted(pIOContext,ioSize,error);
604    return true;
605}
606
607}  // namespace test
608}  // namespace webrtc
609