1// Copyright 2013 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "cloud_print/gcp20/prototype/dns_sd_server.h" 6 7#include <string.h> 8 9#include "base/basictypes.h" 10#include "base/bind.h" 11#include "base/command_line.h" 12#include "base/message_loop/message_loop.h" 13#include "base/strings/stringprintf.h" 14#include "cloud_print/gcp20/prototype/dns_packet_parser.h" 15#include "cloud_print/gcp20/prototype/dns_response_builder.h" 16#include "net/base/big_endian.h" 17#include "net/base/dns_util.h" 18#include "net/base/net_errors.h" 19#include "net/base/net_util.h" 20#include "net/dns/dns_protocol.h" 21 22namespace { 23 24const char kDefaultIpAddressMulticast[] = "224.0.0.251"; 25const uint16 kDefaultPortMulticast = 5353; 26 27const double kTimeToNextAnnouncement = 0.8; // relatively to TTL 28const int kDnsBufSize = 65537; 29 30const uint16 kSrvPriority = 0; 31const uint16 kSrvWeight = 0; 32 33void DoNothingAfterSendToSocket(int /*val*/) { 34 NOTREACHED(); 35 // TODO(maksymb): Delete this function once empty callback for SendTo() method 36 // will be allowed. 37} 38 39} // namespace 40 41DnsSdServer::DnsSdServer() 42 : recv_buf_(new net::IOBufferWithSize(kDnsBufSize)), 43 full_ttl_(0) { 44} 45 46DnsSdServer::~DnsSdServer() { 47 Shutdown(); 48} 49 50bool DnsSdServer::Start(const ServiceParameters& serv_params, uint32 full_ttl, 51 const std::vector<std::string>& metadata) { 52 if (IsOnline()) 53 return true; 54 55 if (!CreateSocket()) 56 return false; 57 58 // Initializing server with parameters from arguments. 59 serv_params_ = serv_params; 60 full_ttl_ = full_ttl; 61 metadata_ = metadata; 62 63 LOG(INFO) << "DNS server started"; 64 LOG(WARNING) << "DNS server does not support probing"; 65 66 SendAnnouncement(full_ttl_); 67 base::MessageLoop::current()->PostTask( 68 FROM_HERE, 69 base::Bind(&DnsSdServer::OnDatagramReceived, AsWeakPtr())); 70 71 return true; 72} 73 74void DnsSdServer::Update() { 75 if (!IsOnline()) 76 return; 77 78 SendAnnouncement(full_ttl_); 79} 80 81void DnsSdServer::Shutdown() { 82 if (!IsOnline()) 83 return; 84 85 SendAnnouncement(0); // TTL is 0 86 socket_->Close(); 87 socket_.reset(NULL); 88 LOG(INFO) << "DNS server stopped"; 89} 90 91void DnsSdServer::UpdateMetadata(const std::vector<std::string>& metadata) { 92 if (!IsOnline()) 93 return; 94 95 metadata_ = metadata; 96 97 // TODO(maksymb): If less than 20% of full TTL left before next announcement 98 // then send it now. 99 100 uint32 current_ttl = GetCurrentTLL(); 101 if (!CommandLine::ForCurrentProcess()->HasSwitch("no-announcement")) { 102 DnsResponseBuilder builder(current_ttl); 103 104 builder.AppendTxt(serv_params_.service_name_, current_ttl, metadata_); 105 scoped_refptr<net::IOBufferWithSize> buffer(builder.Build()); 106 107 DCHECK(buffer.get() != NULL); 108 109 socket_->SendTo(buffer.get(), buffer.get()->size(), multicast_address_, 110 base::Bind(&DoNothingAfterSendToSocket)); 111 } 112} 113 114bool DnsSdServer::CreateSocket() { 115 net::IPAddressNumber local_ip_any; 116 bool success = net::ParseIPLiteralToNumber("0.0.0.0", &local_ip_any); 117 DCHECK(success); 118 119 net::IPAddressNumber multicast_dns_ip_address; 120 success = net::ParseIPLiteralToNumber(kDefaultIpAddressMulticast, 121 &multicast_dns_ip_address); 122 DCHECK(success); 123 124 125 socket_.reset(new net::UDPSocket(net::DatagramSocket::DEFAULT_BIND, 126 net::RandIntCallback(), NULL, 127 net::NetLog::Source())); 128 129 net::IPEndPoint local_address = net::IPEndPoint(local_ip_any, 130 kDefaultPortMulticast); 131 multicast_address_ = net::IPEndPoint(multicast_dns_ip_address, 132 kDefaultPortMulticast); 133 134 socket_->AllowAddressReuse(); 135 136 int status = socket_->Bind(local_address); 137 if (status < 0) 138 return false; 139 140 socket_->SetMulticastLoopbackMode(false); 141 status = socket_->JoinGroup(multicast_dns_ip_address); 142 143 if (status < 0) 144 return false; 145 146 DCHECK(socket_->is_connected()); 147 148 return true; 149} 150 151void DnsSdServer::ProcessMessage(int len, net::IOBufferWithSize* buf) { 152 VLOG(1) << "Received new message with length: " << len; 153 154 // Parse the message. 155 DnsPacketParser parser(buf->data(), len); 156 157 if (!parser.IsValid()) // Was unable to parse header. 158 return; 159 160 // TODO(maksymb): Handle truncated messages. 161 if (parser.header().flags & net::dns_protocol::kFlagResponse) // Not a query. 162 return; 163 164 DnsResponseBuilder builder(parser.header().id); 165 166 uint32 current_ttl = GetCurrentTLL(); 167 168 DnsQueryRecord query; 169 // TODO(maksymb): Check known answers. 170 for (int query_idx = 0; query_idx < parser.header().qdcount; ++query_idx) { 171 bool success = parser.ReadRecord(&query); 172 if (success) { 173 ProccessQuery(current_ttl, query, &builder); 174 } else { // if (success) 175 LOG(INFO) << "Broken package"; 176 break; 177 } 178 } 179 180 scoped_refptr<net::IOBufferWithSize> buffer(builder.Build()); 181 if (buffer.get() == NULL) 182 return; // No answers. 183 184 VLOG(1) << "Current TTL for respond: " << current_ttl; 185 186 bool unicast_respond = 187 CommandLine::ForCurrentProcess()->HasSwitch("unicast-respond"); 188 socket_->SendTo(buffer.get(), buffer.get()->size(), 189 unicast_respond ? recv_address_ : multicast_address_, 190 base::Bind(&DoNothingAfterSendToSocket)); 191 VLOG(1) << "Responded to " 192 << (unicast_respond ? recv_address_ : multicast_address_).ToString(); 193} 194 195void DnsSdServer::ProccessQuery(uint32 current_ttl, const DnsQueryRecord& query, 196 DnsResponseBuilder* builder) const { 197 std::string log; 198 bool responded = false; 199 switch (query.qtype) { 200 // TODO(maksymb): Add IPv6 support. 201 case net::dns_protocol::kTypePTR: 202 log = "Processing PTR query"; 203 if (query.qname == serv_params_.service_type_) { 204 builder->AppendPtr(serv_params_.service_type_, current_ttl, 205 serv_params_.service_name_); 206 responded = true; 207 } 208 break; 209 case net::dns_protocol::kTypeSRV: 210 log = "Processing SRV query"; 211 if (query.qname == serv_params_.service_name_) { 212 builder->AppendSrv(serv_params_.service_name_, current_ttl, 213 kSrvPriority, kSrvWeight, serv_params_.http_port_, 214 serv_params_.service_domain_name_); 215 responded = true; 216 } 217 break; 218 case net::dns_protocol::kTypeA: 219 log = "Processing A query"; 220 if (query.qname == serv_params_.service_domain_name_) { 221 builder->AppendA(serv_params_.service_domain_name_, current_ttl, 222 serv_params_.http_ipv4_); 223 responded = true; 224 } 225 break; 226 case net::dns_protocol::kTypeTXT: 227 log = "Processing TXT query"; 228 if (query.qname == serv_params_.service_name_) { 229 builder->AppendTxt(serv_params_.service_name_, current_ttl, metadata_); 230 responded = true; 231 } 232 break; 233 default: 234 base::SStringPrintf(&log, "Unknown query type (%d)", query.qtype); 235 } 236 log += responded ? ": responded" : ": ignored"; 237 VLOG(1) << log; 238} 239 240void DnsSdServer::DoLoop(int rv) { 241 // TODO(maksymb): Check what happened if buffer will be overflowed 242 do { 243 if (rv > 0) 244 ProcessMessage(rv, recv_buf_.get()); 245 rv = socket_->RecvFrom(recv_buf_.get(), recv_buf_->size(), &recv_address_, 246 base::Bind(&DnsSdServer::DoLoop, AsWeakPtr())); 247 } while (rv > 0); 248 249 // TODO(maksymb): Add handler for errors 250 DCHECK(rv == net::ERR_IO_PENDING); 251} 252 253void DnsSdServer::OnDatagramReceived() { 254 DoLoop(0); 255} 256 257void DnsSdServer::SendAnnouncement(uint32 ttl) { 258 if (!CommandLine::ForCurrentProcess()->HasSwitch("no-announcement")) { 259 DnsResponseBuilder builder(ttl); 260 261 builder.AppendPtr(serv_params_.service_type_, ttl, 262 serv_params_.service_name_); 263 builder.AppendSrv(serv_params_.service_name_, ttl, kSrvPriority, kSrvWeight, 264 serv_params_.http_port_, 265 serv_params_.service_domain_name_); 266 builder.AppendA(serv_params_.service_domain_name_, ttl, 267 serv_params_.http_ipv4_); 268 builder.AppendTxt(serv_params_.service_name_, ttl, metadata_); 269 scoped_refptr<net::IOBufferWithSize> buffer(builder.Build()); 270 271 DCHECK(buffer.get() != NULL); 272 273 socket_->SendTo(buffer.get(), buffer.get()->size(), multicast_address_, 274 base::Bind(&DoNothingAfterSendToSocket)); 275 276 VLOG(1) << "Announcement was sent with TTL: " << ttl; 277 } 278 279 time_until_live_ = base::Time::Now() + 280 base::TimeDelta::FromSeconds(full_ttl_); 281 282 // Schedule next announcement. 283 base::MessageLoop::current()->PostDelayedTask( 284 FROM_HERE, 285 base::Bind(&DnsSdServer::Update, AsWeakPtr()), 286 base::TimeDelta::FromSeconds(static_cast<int64>( 287 kTimeToNextAnnouncement*full_ttl_))); 288} 289 290uint32 DnsSdServer::GetCurrentTLL() const { 291 uint32 current_ttl = (time_until_live_ - base::Time::Now()).InSeconds(); 292 if (time_until_live_ < base::Time::Now() || current_ttl == 0) { 293 // This should not be reachable. But still we don't need to fail. 294 current_ttl = 1; // Service is still alive. 295 LOG(ERROR) << "|current_ttl| was equal to zero."; 296 } 297 return current_ttl; 298} 299 300