1/* 2 * Copyright 2004 The WebRTC Project Authors. All rights reserved. 3 * 4 * Use of this source code is governed by a BSD-style license 5 * that can be found in the LICENSE file in the root of the source 6 * tree. An additional intellectual property rights grant can be found 7 * in the file PATENTS. All contributing project authors may 8 * be found in the AUTHORS file in the root of the source tree. 9 */ 10 11#include "webrtc/base/asynctcpsocket.h" 12 13#include <string.h> 14 15#include "webrtc/base/byteorder.h" 16#include "webrtc/base/common.h" 17#include "webrtc/base/logging.h" 18 19#if defined(WEBRTC_POSIX) 20#include <errno.h> 21#endif // WEBRTC_POSIX 22 23namespace rtc { 24 25static const size_t kMaxPacketSize = 64 * 1024; 26 27typedef uint16_t PacketLength; 28static const size_t kPacketLenSize = sizeof(PacketLength); 29 30static const size_t kBufSize = kMaxPacketSize + kPacketLenSize; 31 32static const int kListenBacklog = 5; 33 34// Binds and connects |socket| 35AsyncSocket* AsyncTCPSocketBase::ConnectSocket( 36 rtc::AsyncSocket* socket, 37 const rtc::SocketAddress& bind_address, 38 const rtc::SocketAddress& remote_address) { 39 rtc::scoped_ptr<rtc::AsyncSocket> owned_socket(socket); 40 if (socket->Bind(bind_address) < 0) { 41 LOG(LS_ERROR) << "Bind() failed with error " << socket->GetError(); 42 return NULL; 43 } 44 if (socket->Connect(remote_address) < 0) { 45 LOG(LS_ERROR) << "Connect() failed with error " << socket->GetError(); 46 return NULL; 47 } 48 return owned_socket.release(); 49} 50 51AsyncTCPSocketBase::AsyncTCPSocketBase(AsyncSocket* socket, bool listen, 52 size_t max_packet_size) 53 : socket_(socket), 54 listen_(listen), 55 insize_(max_packet_size), 56 inpos_(0), 57 outsize_(max_packet_size), 58 outpos_(0) { 59 inbuf_ = new char[insize_]; 60 outbuf_ = new char[outsize_]; 61 62 ASSERT(socket_.get() != NULL); 63 socket_->SignalConnectEvent.connect( 64 this, &AsyncTCPSocketBase::OnConnectEvent); 65 socket_->SignalReadEvent.connect(this, &AsyncTCPSocketBase::OnReadEvent); 66 socket_->SignalWriteEvent.connect(this, &AsyncTCPSocketBase::OnWriteEvent); 67 socket_->SignalCloseEvent.connect(this, &AsyncTCPSocketBase::OnCloseEvent); 68 69 if (listen_) { 70 if (socket_->Listen(kListenBacklog) < 0) { 71 LOG(LS_ERROR) << "Listen() failed with error " << socket_->GetError(); 72 } 73 } 74} 75 76AsyncTCPSocketBase::~AsyncTCPSocketBase() { 77 delete [] inbuf_; 78 delete [] outbuf_; 79} 80 81SocketAddress AsyncTCPSocketBase::GetLocalAddress() const { 82 return socket_->GetLocalAddress(); 83} 84 85SocketAddress AsyncTCPSocketBase::GetRemoteAddress() const { 86 return socket_->GetRemoteAddress(); 87} 88 89int AsyncTCPSocketBase::Close() { 90 return socket_->Close(); 91} 92 93AsyncTCPSocket::State AsyncTCPSocketBase::GetState() const { 94 switch (socket_->GetState()) { 95 case Socket::CS_CLOSED: 96 return STATE_CLOSED; 97 case Socket::CS_CONNECTING: 98 if (listen_) { 99 return STATE_BOUND; 100 } else { 101 return STATE_CONNECTING; 102 } 103 case Socket::CS_CONNECTED: 104 return STATE_CONNECTED; 105 default: 106 ASSERT(false); 107 return STATE_CLOSED; 108 } 109} 110 111int AsyncTCPSocketBase::GetOption(Socket::Option opt, int* value) { 112 return socket_->GetOption(opt, value); 113} 114 115int AsyncTCPSocketBase::SetOption(Socket::Option opt, int value) { 116 return socket_->SetOption(opt, value); 117} 118 119int AsyncTCPSocketBase::GetError() const { 120 return socket_->GetError(); 121} 122 123void AsyncTCPSocketBase::SetError(int error) { 124 return socket_->SetError(error); 125} 126 127int AsyncTCPSocketBase::SendTo(const void *pv, size_t cb, 128 const SocketAddress& addr, 129 const rtc::PacketOptions& options) { 130 const SocketAddress& remote_address = GetRemoteAddress(); 131 if (addr == remote_address) 132 return Send(pv, cb, options); 133 // Remote address may be empty if there is a sudden network change. 134 ASSERT(remote_address.IsNil()); 135 socket_->SetError(ENOTCONN); 136 return -1; 137} 138 139int AsyncTCPSocketBase::SendRaw(const void * pv, size_t cb) { 140 if (outpos_ + cb > outsize_) { 141 socket_->SetError(EMSGSIZE); 142 return -1; 143 } 144 145 memcpy(outbuf_ + outpos_, pv, cb); 146 outpos_ += cb; 147 148 return FlushOutBuffer(); 149} 150 151int AsyncTCPSocketBase::FlushOutBuffer() { 152 int res = socket_->Send(outbuf_, outpos_); 153 if (res <= 0) { 154 return res; 155 } 156 if (static_cast<size_t>(res) <= outpos_) { 157 outpos_ -= res; 158 } else { 159 ASSERT(false); 160 return -1; 161 } 162 if (outpos_ > 0) { 163 memmove(outbuf_, outbuf_ + res, outpos_); 164 } 165 return res; 166} 167 168void AsyncTCPSocketBase::AppendToOutBuffer(const void* pv, size_t cb) { 169 ASSERT(outpos_ + cb < outsize_); 170 memcpy(outbuf_ + outpos_, pv, cb); 171 outpos_ += cb; 172} 173 174void AsyncTCPSocketBase::OnConnectEvent(AsyncSocket* socket) { 175 SignalConnect(this); 176} 177 178void AsyncTCPSocketBase::OnReadEvent(AsyncSocket* socket) { 179 ASSERT(socket_.get() == socket); 180 181 if (listen_) { 182 rtc::SocketAddress address; 183 rtc::AsyncSocket* new_socket = socket->Accept(&address); 184 if (!new_socket) { 185 // TODO: Do something better like forwarding the error 186 // to the user. 187 LOG(LS_ERROR) << "TCP accept failed with error " << socket_->GetError(); 188 return; 189 } 190 191 HandleIncomingConnection(new_socket); 192 193 // Prime a read event in case data is waiting. 194 new_socket->SignalReadEvent(new_socket); 195 } else { 196 int len = socket_->Recv(inbuf_ + inpos_, insize_ - inpos_); 197 if (len < 0) { 198 // TODO: Do something better like forwarding the error to the user. 199 if (!socket_->IsBlocking()) { 200 LOG(LS_ERROR) << "Recv() returned error: " << socket_->GetError(); 201 } 202 return; 203 } 204 205 inpos_ += len; 206 207 ProcessInput(inbuf_, &inpos_); 208 209 if (inpos_ >= insize_) { 210 LOG(LS_ERROR) << "input buffer overflow"; 211 ASSERT(false); 212 inpos_ = 0; 213 } 214 } 215} 216 217void AsyncTCPSocketBase::OnWriteEvent(AsyncSocket* socket) { 218 ASSERT(socket_.get() == socket); 219 220 if (outpos_ > 0) { 221 FlushOutBuffer(); 222 } 223 224 if (outpos_ == 0) { 225 SignalReadyToSend(this); 226 } 227} 228 229void AsyncTCPSocketBase::OnCloseEvent(AsyncSocket* socket, int error) { 230 SignalClose(this, error); 231} 232 233// AsyncTCPSocket 234// Binds and connects |socket| and creates AsyncTCPSocket for 235// it. Takes ownership of |socket|. Returns NULL if bind() or 236// connect() fail (|socket| is destroyed in that case). 237AsyncTCPSocket* AsyncTCPSocket::Create( 238 AsyncSocket* socket, 239 const SocketAddress& bind_address, 240 const SocketAddress& remote_address) { 241 return new AsyncTCPSocket(AsyncTCPSocketBase::ConnectSocket( 242 socket, bind_address, remote_address), false); 243} 244 245AsyncTCPSocket::AsyncTCPSocket(AsyncSocket* socket, bool listen) 246 : AsyncTCPSocketBase(socket, listen, kBufSize) { 247} 248 249int AsyncTCPSocket::Send(const void *pv, size_t cb, 250 const rtc::PacketOptions& options) { 251 if (cb > kBufSize) { 252 SetError(EMSGSIZE); 253 return -1; 254 } 255 256 // If we are blocking on send, then silently drop this packet 257 if (!IsOutBufferEmpty()) 258 return static_cast<int>(cb); 259 260 PacketLength pkt_len = HostToNetwork16(static_cast<PacketLength>(cb)); 261 AppendToOutBuffer(&pkt_len, kPacketLenSize); 262 AppendToOutBuffer(pv, cb); 263 264 int res = FlushOutBuffer(); 265 if (res <= 0) { 266 // drop packet if we made no progress 267 ClearOutBuffer(); 268 return res; 269 } 270 271 rtc::SentPacket sent_packet(options.packet_id, rtc::Time()); 272 SignalSentPacket(this, sent_packet); 273 274 // We claim to have sent the whole thing, even if we only sent partial 275 return static_cast<int>(cb); 276} 277 278void AsyncTCPSocket::ProcessInput(char * data, size_t* len) { 279 SocketAddress remote_addr(GetRemoteAddress()); 280 281 while (true) { 282 if (*len < kPacketLenSize) 283 return; 284 285 PacketLength pkt_len = rtc::GetBE16(data); 286 if (*len < kPacketLenSize + pkt_len) 287 return; 288 289 SignalReadPacket(this, data + kPacketLenSize, pkt_len, remote_addr, 290 CreatePacketTime(0)); 291 292 *len -= kPacketLenSize + pkt_len; 293 if (*len > 0) { 294 memmove(data, data + kPacketLenSize + pkt_len, *len); 295 } 296 } 297} 298 299void AsyncTCPSocket::HandleIncomingConnection(AsyncSocket* socket) { 300 SignalNewConnection(this, new AsyncTCPSocket(socket, false)); 301} 302 303} // namespace rtc 304