1/*
2 * Copyright 2011 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
#include <arpa/inet.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <string.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <unistd.h>
#include "SkSockets.h"
#include "SkData.h"
13
14SkSocket::SkSocket() {
15    fMaxfd = 0;
16    FD_ZERO(&fMasterSet);
17    fConnected = false;
18    fReady = false;
19    fReadSuspended = false;
20    fWriteSuspended = false;
21    fSockfd = this->createSocket();
22}
23
24SkSocket::~SkSocket() {
25    this->closeSocket(fSockfd);
26    shutdown(fSockfd, 2); //stop sending/receiving
27}
28
29int SkSocket::createSocket() {
30    int sockfd = socket(AF_INET, SOCK_STREAM, 0);
31    if (sockfd < 0) {
32        SkDebugf("ERROR opening socket\n");
33        return -1;
34    }
35    int reuse = 1;
36
37    if (setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(int)) < 0) {
38        SkDebugf("error: %s\n", strerror(errno));
39        return -1;
40    }
41#ifdef NONBLOCKING_SOCKETS
42    this->setNonBlocking(sockfd);
43#endif
44    //SkDebugf("Opened fd:%d\n", sockfd);
45    fReady = true;
46    return sockfd;
47}
48
49void SkSocket::closeSocket(int sockfd) {
50    if (!fReady)
51        return;
52
53    close(sockfd);
54    //SkDebugf("Closed fd:%d\n", sockfd);
55
56    if (FD_ISSET(sockfd, &fMasterSet)) {
57        FD_CLR(sockfd, &fMasterSet);
58        if (sockfd >= fMaxfd) {
59            while (FD_ISSET(fMaxfd, &fMasterSet) == false && fMaxfd > 0)
60                fMaxfd -= 1;
61        }
62    }
63    if (0 == fMaxfd)
64        fConnected = false;
65}
66
67void SkSocket::onFailedConnection(int sockfd) {
68    this->closeSocket(sockfd);
69}
70
71void SkSocket::setNonBlocking(int sockfd) {
72    int flags = fcntl(sockfd, F_GETFL);
73    fcntl(sockfd, F_SETFL, flags | O_NONBLOCK);
74}
75
76void SkSocket::addToMasterSet(int sockfd) {
77    FD_SET(sockfd, &fMasterSet);
78    if (sockfd > fMaxfd)
79        fMaxfd = sockfd;
80}
81
82int SkSocket::readPacket(void (*onRead)(int, const void*, size_t, DataType,
83                                        void*), void* context) {
84    if (!fConnected || !fReady || NULL == onRead || NULL == context
85        || fReadSuspended)
86        return -1;
87
88    int totalBytesRead = 0;
89
90    char packet[PACKET_SIZE];
91    for (int i = 0; i <= fMaxfd; ++i) {
92        if (!FD_ISSET (i, &fMasterSet))
93            continue;
94
95        memset(packet, 0, PACKET_SIZE);
96        SkDynamicMemoryWStream stream;
97        int attempts = 0;
98        bool failure = false;
99        int bytesReadInTransfer = 0;
100        int bytesReadInPacket = 0;
101        header h;
102        h.done = false;
103        h.bytes = 0;
104        while (!h.done && fConnected && !failure) {
105            int retval = read(i, packet + bytesReadInPacket,
106                              PACKET_SIZE - bytesReadInPacket);
107
108            ++attempts;
109            if (retval < 0) {
110#ifdef NONBLOCKING_SOCKETS
111                if (errno == EWOULDBLOCK || errno == EAGAIN) {
112                    if (bytesReadInPacket > 0 || bytesReadInTransfer > 0)
113                        continue; //incomplete packet or frame, keep tring
114                    else
115                        break; //nothing to read
116                }
117#endif
118                //SkDebugf("Read() failed with error: %s\n", strerror(errno));
119                failure = true;
120                break;
121            }
122
123            if (retval == 0) {
124                //SkDebugf("Peer closed connection or connection failed\n");
125                failure = true;
126                break;
127            }
128
129            SkASSERT(retval > 0);
130            bytesReadInPacket += retval;
131            if (bytesReadInPacket < PACKET_SIZE) {
132                //SkDebugf("Read %d/%d\n", bytesReadInPacket, PACKET_SIZE);
133                continue; //incomplete packet, keep trying
134            }
135
136            SkASSERT((bytesReadInPacket == PACKET_SIZE) && !failure);
137            memcpy(&h.done, packet, sizeof(bool));
138            memcpy(&h.bytes, packet + sizeof(bool), sizeof(int));
139            memcpy(&h.type, packet + sizeof(bool) + sizeof(int), sizeof(DataType));
140            if (h.bytes > CONTENT_SIZE || h.bytes <= 0) {
141                //SkDebugf("bad packet\n");
142                failure = true;
143                break;
144            }
145            //SkDebugf("read packet(done:%d, bytes:%d) from fd:%d in %d tries\n",
146            //         h.done, h.bytes, fSockfd, attempts);
147            stream.write(packet + HEADER_SIZE, h.bytes);
148            bytesReadInPacket = 0;
149            attempts = 0;
150            bytesReadInTransfer += h.bytes;
151        }
152
153        if (failure) {
154            onRead(i, NULL, 0, h.type, context);
155            this->onFailedConnection(i);
156            continue;
157        }
158
159        if (bytesReadInTransfer > 0) {
160            SkData* data = stream.copyToData();
161            SkASSERT(data->size() == bytesReadInTransfer);
162            onRead(i, data->data(), data->size(), h.type, context);
163            data->unref();
164
165            totalBytesRead += bytesReadInTransfer;
166        }
167    }
168    return totalBytesRead;
169}
170
171int SkSocket::writePacket(void* data, size_t size, DataType type) {
172    if (size < 0|| NULL == data || !fConnected || !fReady || fWriteSuspended)
173        return -1;
174
175    int totalBytesWritten = 0;
176    header h;
177    char packet[PACKET_SIZE];
178    for (int i = 0; i <= fMaxfd; ++i) {
179        if (!FD_ISSET (i, &fMasterSet))
180            continue;
181
182        int bytesWrittenInTransfer = 0;
183        int bytesWrittenInPacket = 0;
184        int attempts = 0;
185        bool failure = false;
186        while (bytesWrittenInTransfer < size && fConnected && !failure) {
187            memset(packet, 0, PACKET_SIZE);
188            h.done = (size - bytesWrittenInTransfer <= CONTENT_SIZE);
189            h.bytes = (h.done) ? size - bytesWrittenInTransfer : CONTENT_SIZE;
190            h.type = type;
191            memcpy(packet, &h.done, sizeof(bool));
192            memcpy(packet + sizeof(bool), &h.bytes, sizeof(int));
193            memcpy(packet + sizeof(bool) + sizeof(int), &h.type, sizeof(DataType));
194            memcpy(packet + HEADER_SIZE, (char*)data + bytesWrittenInTransfer,
195                   h.bytes);
196
197            int retval = write(i, packet + bytesWrittenInPacket,
198                               PACKET_SIZE - bytesWrittenInPacket);
199            attempts++;
200
201            if (retval < 0) {
202                if (errno == EPIPE) {
203                    //SkDebugf("broken pipe, client closed connection");
204                    failure = true;
205                    break;
206                }
207#ifdef NONBLOCKING_SOCKETS
208                else if (errno == EWOULDBLOCK || errno == EAGAIN) {
209                    if (bytesWrittenInPacket > 0 || bytesWrittenInTransfer > 0)
210                        continue; //incomplete packet or frame, keep trying
211                    else
212                        break; //client not available, skip current transfer
213                }
214#endif
215                else {
216                    //SkDebugf("write(%d) failed with error:%s\n", i,
217                    //         strerror(errno));
218                    failure = true;
219                    break;
220                }
221            }
222
223            bytesWrittenInPacket += retval;
224            if (bytesWrittenInPacket < PACKET_SIZE)
225                continue; //incomplete packet, keep trying
226
227            SkASSERT(bytesWrittenInPacket == PACKET_SIZE);
228            //SkDebugf("wrote to packet(done:%d, bytes:%d) to fd:%d in %d tries\n",
229            //         h.done, h.bytes, i, attempts);
230            bytesWrittenInTransfer += h.bytes;
231            bytesWrittenInPacket = 0;
232            attempts = 0;
233        }
234
235        if (failure)
236            this->onFailedConnection(i);
237
238        totalBytesWritten += bytesWrittenInTransfer;
239    }
240    return totalBytesWritten;
241}
242
243////////////////////////////////////////////////////////////////////////////////
244SkTCPServer::SkTCPServer(int port) {
245    sockaddr_in serverAddr;
246    serverAddr.sin_family = AF_INET;
247    serverAddr.sin_addr.s_addr = INADDR_ANY;
248    serverAddr.sin_port = htons(port);
249
250    if (bind(fSockfd, (sockaddr*)&serverAddr, sizeof(serverAddr)) < 0) {
251        SkDebugf("ERROR on binding: %s\n", strerror(errno));
252        fReady = false;
253    }
254}
255
256SkTCPServer::~SkTCPServer() {
257    this->disconnectAll();
258}
259
260int SkTCPServer::acceptConnections() {
261    if (!fReady)
262        return -1;
263
264    listen(fSockfd, MAX_WAITING_CLIENTS);
265    int newfd;
266    for (int i = 0; i < MAX_WAITING_CLIENTS; ++i) {
267#ifdef NONBLOCKING_SOCKETS
268        fd_set workingSet;
269        FD_ZERO(&workingSet);
270        FD_SET(fSockfd, &workingSet);
271        timeval timeout;
272        timeout.tv_sec  = 0;
273        timeout.tv_usec = 0;
274        int sel = select(fSockfd + 1, &workingSet, NULL, NULL, &timeout);
275        if (sel < 0) {
276            SkDebugf("select() failed with error %s\n", strerror(errno));
277            continue;
278        }
279        if (sel == 0) //select() timed out
280            continue;
281#endif
282        sockaddr_in clientAddr;
283        socklen_t clientLen = sizeof(clientAddr);
284        newfd = accept(fSockfd, (struct sockaddr*)&clientAddr, &clientLen);
285        if (newfd< 0) {
286            SkDebugf("accept() failed with error %s\n", strerror(errno));
287            continue;
288        }
289        SkDebugf("New incoming connection - %d\n", newfd);
290        fConnected = true;
291#ifdef NONBLOCKING_SOCKETS
292        this->setNonBlocking(newfd);
293#endif
294        this->addToMasterSet(newfd);
295    }
296    return 0;
297}
298
299
300int SkTCPServer::disconnectAll() {
301    if (!fConnected || !fReady)
302        return -1;
303    for (int i = 0; i <= fMaxfd; ++i) {
304        if (FD_ISSET(i, &fMasterSet))
305            this->closeSocket(i);
306    }
307    fConnected = false;
308    return 0;
309}
310
311////////////////////////////////////////////////////////////////////////////////
312SkTCPClient::SkTCPClient(const char* hostname, int port) {
313    //Add fSockfd since the client will be using it to read/write
314    this->addToMasterSet(fSockfd);
315
316    hostent* server = gethostbyname(hostname);
317    if (server) {
318        fServerAddr.sin_family = AF_INET;
319        memcpy((char*)&fServerAddr.sin_addr.s_addr, (char*)server->h_addr,
320               server->h_length);
321        fServerAddr.sin_port = htons(port);
322    }
323    else {
324        //SkDebugf("ERROR, no such host\n");
325        fReady = false;
326    }
327}
328
329void SkTCPClient::onFailedConnection(int sockfd) { //cleanup and recreate socket
330    SkASSERT(sockfd == fSockfd);
331    this->closeSocket(fSockfd);
332    fSockfd = this->createSocket();
333    //Add fSockfd since the client will be using it to read/write
334    this->addToMasterSet(fSockfd);
335}
336
337int SkTCPClient::connectToServer() {
338    if (!fReady)
339        return -1;
340    if (fConnected)
341        return 0;
342
343    int conn = connect(fSockfd, (sockaddr*)&fServerAddr, sizeof(fServerAddr));
344    if (conn < 0) {
345#ifdef NONBLOCKING_SOCKETS
346        if (errno == EINPROGRESS || errno == EALREADY)
347            return conn;
348#endif
349        if (errno != EISCONN) {
350            //SkDebugf("error: %s\n", strerror(errno));
351            this->onFailedConnection(fSockfd);
352            return conn;
353        }
354    }
355    fConnected = true;
356    SkDebugf("Succesfully reached server\n");
357    return 0;
358}
359