1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "cloud_print/gcp20/prototype/dns_sd_server.h"
6
7#include <string.h>
8
9#include "base/basictypes.h"
10#include "base/bind.h"
11#include "base/command_line.h"
12#include "base/message_loop/message_loop.h"
13#include "base/strings/stringprintf.h"
14#include "cloud_print/gcp20/prototype/dns_packet_parser.h"
15#include "cloud_print/gcp20/prototype/dns_response_builder.h"
16#include "cloud_print/gcp20/prototype/gcp20_switches.h"
17#include "net/base/dns_util.h"
18#include "net/base/net_errors.h"
19#include "net/base/net_util.h"
20#include "net/dns/dns_protocol.h"
21
22namespace {
23
24const char kDefaultIpAddressMulticast[] = "224.0.0.251";
25const uint16 kDefaultPortMulticast = 5353;
26
27const double kTimeToNextAnnouncement = 0.8;  // relatively to TTL
28const int kDnsBufSize = 65537;
29
30const uint16 kSrvPriority = 0;
31const uint16 kSrvWeight = 0;
32
33void DoNothingAfterSendToSocket(int /*val*/) {
34  NOTREACHED();
35  // TODO(maksymb): Delete this function once empty callback for SendTo() method
36  // will be allowed.
37}
38
39}  // namespace
40
41DnsSdServer::DnsSdServer()
42    : recv_buf_(new net::IOBufferWithSize(kDnsBufSize)),
43      full_ttl_(0) {
44}
45
46DnsSdServer::~DnsSdServer() {
47  Shutdown();
48}
49
50bool DnsSdServer::Start(const ServiceParameters& serv_params, uint32 full_ttl,
51                        const std::vector<std::string>& metadata) {
52  if (IsOnline())
53    return true;
54
55  if (!CreateSocket())
56    return false;
57
58  // Initializing server with parameters from arguments.
59  serv_params_ = serv_params;
60  full_ttl_ = full_ttl;
61  metadata_ = metadata;
62
63  VLOG(0) << "DNS server started";
64  LOG(WARNING) << "DNS server does not support probing";
65
66  SendAnnouncement(full_ttl_);
67  base::MessageLoop::current()->PostTask(
68      FROM_HERE,
69      base::Bind(&DnsSdServer::OnDatagramReceived, AsWeakPtr()));
70
71  return true;
72}
73
74void DnsSdServer::Update() {
75  if (!IsOnline())
76    return;
77
78  SendAnnouncement(full_ttl_);
79}
80
81void DnsSdServer::Shutdown() {
82  if (!IsOnline())
83    return;
84
85  SendAnnouncement(0);  // TTL is 0
86  socket_->Close();
87  socket_.reset(NULL);
88  VLOG(0) << "DNS server stopped";
89}
90
91void DnsSdServer::UpdateMetadata(const std::vector<std::string>& metadata) {
92  if (!IsOnline())
93    return;
94
95  metadata_ = metadata;
96
97  // TODO(maksymb): If less than 20% of full TTL left before next announcement
98  // then send it now.
99
100  uint32 current_ttl = GetCurrentTLL();
101  if (!CommandLine::ForCurrentProcess()->HasSwitch(switches::kNoAnnouncement)) {
102    DnsResponseBuilder builder(current_ttl);
103
104    builder.AppendTxt(serv_params_.service_name_, current_ttl, metadata_, true);
105    scoped_refptr<net::IOBufferWithSize> buffer(builder.Build());
106
107    DCHECK(buffer.get() != NULL);
108
109    socket_->SendTo(buffer.get(), buffer.get()->size(), multicast_address_,
110                    base::Bind(&DoNothingAfterSendToSocket));
111  }
112}
113
114bool DnsSdServer::CreateSocket() {
115  net::IPAddressNumber local_ip_any;
116  bool success = net::ParseIPLiteralToNumber("0.0.0.0", &local_ip_any);
117  DCHECK(success);
118
119  net::IPAddressNumber multicast_dns_ip_address;
120  success = net::ParseIPLiteralToNumber(kDefaultIpAddressMulticast,
121                                        &multicast_dns_ip_address);
122  DCHECK(success);
123
124
125  socket_.reset(new net::UDPSocket(net::DatagramSocket::DEFAULT_BIND,
126                                   net::RandIntCallback(), NULL,
127                                   net::NetLog::Source()));
128
129  net::IPEndPoint local_address = net::IPEndPoint(local_ip_any,
130                                                  kDefaultPortMulticast);
131  multicast_address_ = net::IPEndPoint(multicast_dns_ip_address,
132                                       kDefaultPortMulticast);
133
134  socket_->AllowAddressReuse();
135
136  int status = socket_->Bind(local_address);
137  if (status < 0)
138    return false;
139
140  socket_->SetMulticastLoopbackMode(false);
141  status = socket_->JoinGroup(multicast_dns_ip_address);
142
143  if (status < 0)
144    return false;
145
146  DCHECK(socket_->is_connected());
147
148  return true;
149}
150
151void DnsSdServer::ProcessMessage(int len, net::IOBufferWithSize* buf) {
152  VLOG(1) << "Received new message with length: " << len;
153
154  // Parse the message.
155  DnsPacketParser parser(buf->data(), len);
156
157  if (!parser.IsValid())  // Was unable to parse header.
158    return;
159
160  // TODO(maksymb): Handle truncated messages.
161  if (parser.header().flags & net::dns_protocol::kFlagResponse)  // Not a query.
162    return;
163
164  DnsResponseBuilder builder(parser.header().id);
165
166  uint32 current_ttl = GetCurrentTLL();
167
168  DnsQueryRecord query;
169  // TODO(maksymb): Check known answers.
170  for (int query_idx = 0; query_idx < parser.header().qdcount; ++query_idx) {
171    bool success = parser.ReadRecord(&query);
172    if (success) {
173      ProccessQuery(current_ttl, query, &builder);
174    } else {  // if (success)
175      VLOG(0) << "Broken package";
176      break;
177    }
178  }
179
180  scoped_refptr<net::IOBufferWithSize> buffer(builder.Build());
181  if (buffer.get() == NULL)
182    return;  // No answers.
183
184  VLOG(1) << "Current TTL for respond: " << current_ttl;
185
186  bool unicast_respond =
187      CommandLine::ForCurrentProcess()->HasSwitch(switches::kUnicastRespond);
188  socket_->SendTo(buffer.get(), buffer.get()->size(),
189                  unicast_respond ? recv_address_ : multicast_address_,
190                  base::Bind(&DoNothingAfterSendToSocket));
191  VLOG(1) << "Responded to "
192      << (unicast_respond ? recv_address_ : multicast_address_).ToString();
193}
194
195void DnsSdServer::ProccessQuery(uint32 current_ttl, const DnsQueryRecord& query,
196                                DnsResponseBuilder* builder) const {
197  std::string log;
198  bool responded = false;
199  switch (query.qtype) {
200    // TODO(maksymb): Add IPv6 support.
201    case net::dns_protocol::kTypePTR:
202      log = "Processing PTR query";
203      if (query.qname == serv_params_.service_type_ ||
204          query.qname == serv_params_.secondary_service_type_) {
205        builder->AppendPtr(query.qname, current_ttl,
206                           serv_params_.service_name_, true);
207
208        if (CommandLine::ForCurrentProcess()->HasSwitch(
209                switches::kExtendedResponce)) {
210          builder->AppendSrv(serv_params_.service_name_, current_ttl,
211                             kSrvPriority, kSrvWeight, serv_params_.http_port_,
212                             serv_params_.service_domain_name_, false);
213          builder->AppendA(serv_params_.service_domain_name_, current_ttl,
214                           serv_params_.http_ipv4_, false);
215          builder->AppendAAAA(serv_params_.service_domain_name_, current_ttl,
216                              serv_params_.http_ipv6_, false);
217          builder->AppendTxt(serv_params_.service_name_, current_ttl, metadata_,
218                             false);
219        }
220
221        responded = true;
222      }
223
224      break;
225    case net::dns_protocol::kTypeSRV:
226      log = "Processing SRV query";
227      if (query.qname == serv_params_.service_name_) {
228        builder->AppendSrv(serv_params_.service_name_, current_ttl,
229                           kSrvPriority, kSrvWeight, serv_params_.http_port_,
230                           serv_params_.service_domain_name_, true);
231        responded = true;
232      }
233      break;
234    case net::dns_protocol::kTypeA:
235      log = "Processing A query";
236      if (query.qname == serv_params_.service_domain_name_) {
237        builder->AppendA(serv_params_.service_domain_name_, current_ttl,
238                         serv_params_.http_ipv4_, true);
239        responded = true;
240      }
241      break;
242    case net::dns_protocol::kTypeAAAA:
243      log = "Processing AAAA query";
244      if (query.qname == serv_params_.service_domain_name_) {
245        builder->AppendAAAA(serv_params_.service_domain_name_, current_ttl,
246                            serv_params_.http_ipv6_, true);
247        responded = true;
248      }
249      break;
250    case net::dns_protocol::kTypeTXT:
251      log = "Processing TXT query";
252      if (query.qname == serv_params_.service_name_) {
253        builder->AppendTxt(serv_params_.service_name_, current_ttl, metadata_,
254                           true);
255        responded = true;
256      }
257      break;
258    default:
259      base::SStringPrintf(&log, "Unknown query type (%d)", query.qtype);
260  }
261  log += responded ? ": responded" : ": ignored";
262  VLOG(1) << log;
263}
264
265void DnsSdServer::DoLoop(int rv) {
266  // TODO(maksymb): Check what happened if buffer will be overflowed
267  do {
268    if (rv > 0)
269      ProcessMessage(rv, recv_buf_.get());
270    rv = socket_->RecvFrom(recv_buf_.get(), recv_buf_->size(), &recv_address_,
271                           base::Bind(&DnsSdServer::DoLoop, AsWeakPtr()));
272  } while (rv > 0);
273
274  // TODO(maksymb): Add handler for errors
275  DCHECK(rv == net::ERR_IO_PENDING);
276}
277
278void DnsSdServer::OnDatagramReceived() {
279  DoLoop(0);
280}
281
282void DnsSdServer::SendAnnouncement(uint32 ttl) {
283  if (!CommandLine::ForCurrentProcess()->HasSwitch(switches::kNoAnnouncement)) {
284    DnsResponseBuilder builder(ttl);
285
286    builder.AppendPtr(serv_params_.service_type_, ttl,
287                     serv_params_.service_name_, true);
288    builder.AppendPtr(serv_params_.secondary_service_type_, ttl,
289                      serv_params_.service_name_, true);
290    builder.AppendSrv(serv_params_.service_name_, ttl, kSrvPriority,
291                      kSrvWeight, serv_params_.http_port_,
292                      serv_params_.service_domain_name_, true);
293    builder.AppendA(serv_params_.service_domain_name_, ttl,
294                    serv_params_.http_ipv4_, true);
295    builder.AppendAAAA(serv_params_.service_domain_name_, ttl,
296                       serv_params_.http_ipv6_, true);
297    builder.AppendTxt(serv_params_.service_name_, ttl, metadata_, true);
298
299    scoped_refptr<net::IOBufferWithSize> buffer(builder.Build());
300
301    DCHECK(buffer.get() != NULL);
302
303    socket_->SendTo(buffer.get(), buffer.get()->size(), multicast_address_,
304                    base::Bind(&DoNothingAfterSendToSocket));
305
306    VLOG(1) << "Announcement was sent with TTL: " << ttl;
307  }
308
309  time_until_live_ = base::Time::Now() +
310      base::TimeDelta::FromSeconds(full_ttl_);
311
312  // Schedule next announcement.
313  base::MessageLoop::current()->PostDelayedTask(
314      FROM_HERE,
315      base::Bind(&DnsSdServer::Update, AsWeakPtr()),
316      base::TimeDelta::FromSeconds(static_cast<int64>(
317          kTimeToNextAnnouncement*full_ttl_)));
318}
319
320uint32 DnsSdServer::GetCurrentTLL() const {
321  uint32 current_ttl = (time_until_live_ - base::Time::Now()).InSeconds();
322  if (time_until_live_ < base::Time::Now() || current_ttl == 0) {
323    // This should not be reachable. But still we don't need to fail.
324    current_ttl = 1;  // Service is still alive.
325    LOG(ERROR) << "|current_ttl| was equal to zero.";
326  }
327  return current_ttl;
328}
329