1/* 2 * Copyright 2004 The WebRTC Project Authors. All rights reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11#include "webrtc/base/asynctcpsocket.h" 12 13#include <string.h> 14 15#include "webrtc/base/byteorder.h" 16#include "webrtc/base/common.h" 17#include "webrtc/base/logging.h" 18 19#if defined(WEBRTC_POSIX) 20#include <errno.h> 21#endif // WEBRTC_POSIX 22 23namespace rtc { 24 25static const size_t kMaxPacketSize = 64 * 1024; 26 27typedef uint16 PacketLength; 28static const size_t kPacketLenSize = sizeof(PacketLength); 29 30static const size_t kBufSize = kMaxPacketSize + kPacketLenSize; 31 32static const int kListenBacklog = 5; 33 34// Binds and connects |socket| 35AsyncSocket* AsyncTCPSocketBase::ConnectSocket( 36 rtc::AsyncSocket* socket, 37 const rtc::SocketAddress& bind_address, 38 const rtc::SocketAddress& remote_address) { 39 rtc::scoped_ptr<rtc::AsyncSocket> owned_socket(socket); 40 if (socket->Bind(bind_address) < 0) { 41 LOG(LS_ERROR) << "Bind() failed with error " << socket->GetError(); 42 return NULL; 43 } 44 if (socket->Connect(remote_address) < 0) { 45 LOG(LS_ERROR) << "Connect() failed with error " << socket->GetError(); 46 return NULL; 47 } 48 return owned_socket.release(); 49} 50 51AsyncTCPSocketBase::AsyncTCPSocketBase(AsyncSocket* socket, bool listen, 52 size_t max_packet_size) 53 : socket_(socket), 54 listen_(listen), 55 insize_(max_packet_size), 56 inpos_(0), 57 outsize_(max_packet_size), 58 outpos_(0) { 59 inbuf_ = new char[insize_]; 60 outbuf_ = new char[outsize_]; 61 62 ASSERT(socket_.get() != NULL); 63 socket_->SignalConnectEvent.connect( 64 this, &AsyncTCPSocketBase::OnConnectEvent); 65 socket_->SignalReadEvent.connect(this, &AsyncTCPSocketBase::OnReadEvent); 66 socket_->SignalWriteEvent.connect(this, &AsyncTCPSocketBase::OnWriteEvent); 67 socket_->SignalCloseEvent.connect(this, &AsyncTCPSocketBase::OnCloseEvent); 68 69 if (listen_) { 70 if (socket_->Listen(kListenBacklog) < 0) { 71 LOG(LS_ERROR) << "Listen() failed with error " << socket_->GetError(); 72 } 73 } 74} 75 76AsyncTCPSocketBase::~AsyncTCPSocketBase() { 77 delete [] inbuf_; 78 delete [] outbuf_; 79} 80 81SocketAddress AsyncTCPSocketBase::GetLocalAddress() const { 82 return socket_->GetLocalAddress(); 83} 84 85SocketAddress AsyncTCPSocketBase::GetRemoteAddress() const { 86 return socket_->GetRemoteAddress(); 87} 88 89int AsyncTCPSocketBase::Close() { 90 return socket_->Close(); 91} 92 93AsyncTCPSocket::State AsyncTCPSocketBase::GetState() const { 94 switch (socket_->GetState()) { 95 case Socket::CS_CLOSED: 96 return STATE_CLOSED; 97 case Socket::CS_CONNECTING: 98 if (listen_) { 99 return STATE_BOUND; 100 } else { 101 return STATE_CONNECTING; 102 } 103 case Socket::CS_CONNECTED: 104 return STATE_CONNECTED; 105 default: 106 ASSERT(false); 107 return STATE_CLOSED; 108 } 109} 110 111int AsyncTCPSocketBase::GetOption(Socket::Option opt, int* value) { 112 return socket_->GetOption(opt, value); 113} 114 115int AsyncTCPSocketBase::SetOption(Socket::Option opt, int value) { 116 return socket_->SetOption(opt, value); 117} 118 119int AsyncTCPSocketBase::GetError() const { 120 return socket_->GetError(); 121} 122 123void AsyncTCPSocketBase::SetError(int error) { 124 return socket_->SetError(error); 125} 126 127int AsyncTCPSocketBase::SendTo(const void *pv, size_t cb, 128 const SocketAddress& addr, 129 const rtc::PacketOptions& options) { 130 if (addr == GetRemoteAddress()) 131 return Send(pv, cb, options); 132 133 ASSERT(false); 134 socket_->SetError(ENOTCONN); 135 return -1; 136} 137 138int AsyncTCPSocketBase::SendRaw(const void * pv, size_t cb) { 139 if (outpos_ + cb > outsize_) { 140 socket_->SetError(EMSGSIZE); 141 return -1; 142 } 143 144 memcpy(outbuf_ + outpos_, pv, cb); 145 outpos_ += cb; 146 147 return FlushOutBuffer(); 148} 149 150int AsyncTCPSocketBase::FlushOutBuffer() { 151 int res = socket_->Send(outbuf_, outpos_); 152 if (res <= 0) { 153 return res; 154 } 155 if (static_cast<size_t>(res) <= outpos_) { 156 outpos_ -= res; 157 } else { 158 ASSERT(false); 159 return -1; 160 } 161 if (outpos_ > 0) { 162 memmove(outbuf_, outbuf_ + res, outpos_); 163 } 164 return res; 165} 166 167void AsyncTCPSocketBase::AppendToOutBuffer(const void* pv, size_t cb) { 168 ASSERT(outpos_ + cb < outsize_); 169 memcpy(outbuf_ + outpos_, pv, cb); 170 outpos_ += cb; 171} 172 173void AsyncTCPSocketBase::OnConnectEvent(AsyncSocket* socket) { 174 SignalConnect(this); 175} 176 177void AsyncTCPSocketBase::OnReadEvent(AsyncSocket* socket) { 178 ASSERT(socket_.get() == socket); 179 180 if (listen_) { 181 rtc::SocketAddress address; 182 rtc::AsyncSocket* new_socket = socket->Accept(&address); 183 if (!new_socket) { 184 // TODO: Do something better like forwarding the error 185 // to the user. 186 LOG(LS_ERROR) << "TCP accept failed with error " << socket_->GetError(); 187 return; 188 } 189 190 HandleIncomingConnection(new_socket); 191 192 // Prime a read event in case data is waiting. 193 new_socket->SignalReadEvent(new_socket); 194 } else { 195 int len = socket_->Recv(inbuf_ + inpos_, insize_ - inpos_); 196 if (len < 0) { 197 // TODO: Do something better like forwarding the error to the user. 198 if (!socket_->IsBlocking()) { 199 LOG(LS_ERROR) << "Recv() returned error: " << socket_->GetError(); 200 } 201 return; 202 } 203 204 inpos_ += len; 205 206 ProcessInput(inbuf_, &inpos_); 207 208 if (inpos_ >= insize_) { 209 LOG(LS_ERROR) << "input buffer overflow"; 210 ASSERT(false); 211 inpos_ = 0; 212 } 213 } 214} 215 216void AsyncTCPSocketBase::OnWriteEvent(AsyncSocket* socket) { 217 ASSERT(socket_.get() == socket); 218 219 if (outpos_ > 0) { 220 FlushOutBuffer(); 221 } 222 223 if (outpos_ == 0) { 224 SignalReadyToSend(this); 225 } 226} 227 228void AsyncTCPSocketBase::OnCloseEvent(AsyncSocket* socket, int error) { 229 SignalClose(this, error); 230} 231 232// AsyncTCPSocket 233// Binds and connects |socket| and creates AsyncTCPSocket for 234// it. Takes ownership of |socket|. Returns NULL if bind() or 235// connect() fail (|socket| is destroyed in that case). 236AsyncTCPSocket* AsyncTCPSocket::Create( 237 AsyncSocket* socket, 238 const SocketAddress& bind_address, 239 const SocketAddress& remote_address) { 240 return new AsyncTCPSocket(AsyncTCPSocketBase::ConnectSocket( 241 socket, bind_address, remote_address), false); 242} 243 244AsyncTCPSocket::AsyncTCPSocket(AsyncSocket* socket, bool listen) 245 : AsyncTCPSocketBase(socket, listen, kBufSize) { 246} 247 248int AsyncTCPSocket::Send(const void *pv, size_t cb, 249 const rtc::PacketOptions& options) { 250 if (cb > kBufSize) { 251 SetError(EMSGSIZE); 252 return -1; 253 } 254 255 // If we are blocking on send, then silently drop this packet 256 if (!IsOutBufferEmpty()) 257 return static_cast<int>(cb); 258 259 PacketLength pkt_len = HostToNetwork16(static_cast<PacketLength>(cb)); 260 AppendToOutBuffer(&pkt_len, kPacketLenSize); 261 AppendToOutBuffer(pv, cb); 262 263 int res = FlushOutBuffer(); 264 if (res <= 0) { 265 // drop packet if we made no progress 266 ClearOutBuffer(); 267 return res; 268 } 269 270 // We claim to have sent the whole thing, even if we only sent partial 271 return static_cast<int>(cb); 272} 273 274void AsyncTCPSocket::ProcessInput(char * data, size_t* len) { 275 SocketAddress remote_addr(GetRemoteAddress()); 276 277 while (true) { 278 if (*len < kPacketLenSize) 279 return; 280 281 PacketLength pkt_len = rtc::GetBE16(data); 282 if (*len < kPacketLenSize + pkt_len) 283 return; 284 285 SignalReadPacket(this, data + kPacketLenSize, pkt_len, remote_addr, 286 CreatePacketTime(0)); 287 288 *len -= kPacketLenSize + pkt_len; 289 if (*len > 0) { 290 memmove(data, data + kPacketLenSize + pkt_len, *len); 291 } 292 } 293} 294 295void AsyncTCPSocket::HandleIncomingConnection(AsyncSocket* socket) { 296 SignalNewConnection(this, new AsyncTCPSocket(socket, false)); 297} 298 299} // namespace rtc 300