1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "cloud_print/gcp20/prototype/dns_sd_server.h"
6
7#include <string.h>
8
9#include "base/basictypes.h"
10#include "base/bind.h"
11#include "base/command_line.h"
12#include "base/message_loop/message_loop.h"
13#include "base/strings/stringprintf.h"
14#include "cloud_print/gcp20/prototype/dns_packet_parser.h"
15#include "cloud_print/gcp20/prototype/dns_response_builder.h"
16#include "net/base/big_endian.h"
17#include "net/base/dns_util.h"
18#include "net/base/net_errors.h"
19#include "net/base/net_util.h"
20#include "net/dns/dns_protocol.h"
21
22namespace {
23
24const char kDefaultIpAddressMulticast[] = "224.0.0.251";
25const uint16 kDefaultPortMulticast = 5353;
26
27const double kTimeToNextAnnouncement = 0.8;  // relatively to TTL
28const int kDnsBufSize = 65537;
29
30const uint16 kSrvPriority = 0;
31const uint16 kSrvWeight = 0;
32
33void DoNothingAfterSendToSocket(int /*val*/) {
34  NOTREACHED();
35  // TODO(maksymb): Delete this function once empty callback for SendTo() method
36  // will be allowed.
37}
38
39}  // namespace
40
41DnsSdServer::DnsSdServer()
42    : recv_buf_(new net::IOBufferWithSize(kDnsBufSize)),
43      full_ttl_(0) {
44}
45
46DnsSdServer::~DnsSdServer() {
47  Shutdown();
48}
49
50bool DnsSdServer::Start(const ServiceParameters& serv_params, uint32 full_ttl,
51                        const std::vector<std::string>& metadata) {
52  if (IsOnline())
53    return true;
54
55  if (!CreateSocket())
56    return false;
57
58  // Initializing server with parameters from arguments.
59  serv_params_ = serv_params;
60  full_ttl_ = full_ttl;
61  metadata_ = metadata;
62
63  LOG(INFO) << "DNS server started";
64  LOG(WARNING) << "DNS server does not support probing";
65
66  SendAnnouncement(full_ttl_);
67  base::MessageLoop::current()->PostTask(
68      FROM_HERE,
69      base::Bind(&DnsSdServer::OnDatagramReceived, AsWeakPtr()));
70
71  return true;
72}
73
74void DnsSdServer::Update() {
75  if (!IsOnline())
76    return;
77
78  SendAnnouncement(full_ttl_);
79}
80
81void DnsSdServer::Shutdown() {
82  if (!IsOnline())
83    return;
84
85  SendAnnouncement(0);  // TTL is 0
86  socket_->Close();
87  socket_.reset(NULL);
88  LOG(INFO) << "DNS server stopped";
89}
90
91void DnsSdServer::UpdateMetadata(const std::vector<std::string>& metadata) {
92  if (!IsOnline())
93    return;
94
95  metadata_ = metadata;
96
97  // TODO(maksymb): If less than 20% of full TTL left before next announcement
98  // then send it now.
99
100  uint32 current_ttl = GetCurrentTLL();
101  if (!CommandLine::ForCurrentProcess()->HasSwitch("no-announcement")) {
102    DnsResponseBuilder builder(current_ttl);
103
104    builder.AppendTxt(serv_params_.service_name_, current_ttl, metadata_);
105    scoped_refptr<net::IOBufferWithSize> buffer(builder.Build());
106
107    DCHECK(buffer.get() != NULL);
108
109    socket_->SendTo(buffer.get(), buffer.get()->size(), multicast_address_,
110                    base::Bind(&DoNothingAfterSendToSocket));
111  }
112}
113
114bool DnsSdServer::CreateSocket() {
115  net::IPAddressNumber local_ip_any;
116  bool success = net::ParseIPLiteralToNumber("0.0.0.0", &local_ip_any);
117  DCHECK(success);
118
119  net::IPAddressNumber multicast_dns_ip_address;
120  success = net::ParseIPLiteralToNumber(kDefaultIpAddressMulticast,
121                                        &multicast_dns_ip_address);
122  DCHECK(success);
123
124
125  socket_.reset(new net::UDPSocket(net::DatagramSocket::DEFAULT_BIND,
126                                   net::RandIntCallback(), NULL,
127                                   net::NetLog::Source()));
128
129  net::IPEndPoint local_address = net::IPEndPoint(local_ip_any,
130                                                  kDefaultPortMulticast);
131  multicast_address_ = net::IPEndPoint(multicast_dns_ip_address,
132                                       kDefaultPortMulticast);
133
134  socket_->AllowAddressReuse();
135
136  int status = socket_->Bind(local_address);
137  if (status < 0)
138    return false;
139
140  socket_->SetMulticastLoopbackMode(false);
141  status = socket_->JoinGroup(multicast_dns_ip_address);
142
143  if (status < 0)
144    return false;
145
146  DCHECK(socket_->is_connected());
147
148  return true;
149}
150
151void DnsSdServer::ProcessMessage(int len, net::IOBufferWithSize* buf) {
152  VLOG(1) << "Received new message with length: " << len;
153
154  // Parse the message.
155  DnsPacketParser parser(buf->data(), len);
156
157  if (!parser.IsValid())  // Was unable to parse header.
158    return;
159
160  // TODO(maksymb): Handle truncated messages.
161  if (parser.header().flags & net::dns_protocol::kFlagResponse)  // Not a query.
162    return;
163
164  DnsResponseBuilder builder(parser.header().id);
165
166  uint32 current_ttl = GetCurrentTLL();
167
168  DnsQueryRecord query;
169  // TODO(maksymb): Check known answers.
170  for (int query_idx = 0; query_idx < parser.header().qdcount; ++query_idx) {
171    bool success = parser.ReadRecord(&query);
172    if (success) {
173      ProccessQuery(current_ttl, query, &builder);
174    } else {  // if (success)
175      LOG(INFO) << "Broken package";
176      break;
177    }
178  }
179
180  scoped_refptr<net::IOBufferWithSize> buffer(builder.Build());
181  if (buffer.get() == NULL)
182    return;  // No answers.
183
184  VLOG(1) << "Current TTL for respond: " << current_ttl;
185
186  bool unicast_respond =
187      CommandLine::ForCurrentProcess()->HasSwitch("unicast-respond");
188  socket_->SendTo(buffer.get(), buffer.get()->size(),
189                  unicast_respond ? recv_address_ : multicast_address_,
190                  base::Bind(&DoNothingAfterSendToSocket));
191  VLOG(1) << "Responded to "
192      << (unicast_respond ? recv_address_ : multicast_address_).ToString();
193}
194
195void DnsSdServer::ProccessQuery(uint32 current_ttl, const DnsQueryRecord& query,
196                                DnsResponseBuilder* builder) const {
197  std::string log;
198  bool responded = false;
199  switch (query.qtype) {
200    // TODO(maksymb): Add IPv6 support.
201    case net::dns_protocol::kTypePTR:
202      log = "Processing PTR query";
203      if (query.qname == serv_params_.service_type_) {
204        builder->AppendPtr(serv_params_.service_type_, current_ttl,
205                           serv_params_.service_name_);
206        responded = true;
207      }
208      break;
209    case net::dns_protocol::kTypeSRV:
210      log = "Processing SRV query";
211      if (query.qname == serv_params_.service_name_) {
212        builder->AppendSrv(serv_params_.service_name_, current_ttl,
213                           kSrvPriority, kSrvWeight, serv_params_.http_port_,
214                           serv_params_.service_domain_name_);
215        responded = true;
216      }
217      break;
218    case net::dns_protocol::kTypeA:
219      log = "Processing A query";
220      if (query.qname == serv_params_.service_domain_name_) {
221        builder->AppendA(serv_params_.service_domain_name_, current_ttl,
222                         serv_params_.http_ipv4_);
223        responded = true;
224      }
225      break;
226    case net::dns_protocol::kTypeTXT:
227      log = "Processing TXT query";
228      if (query.qname == serv_params_.service_name_) {
229        builder->AppendTxt(serv_params_.service_name_, current_ttl, metadata_);
230        responded = true;
231      }
232      break;
233    default:
234      base::SStringPrintf(&log, "Unknown query type (%d)", query.qtype);
235  }
236  log += responded ? ": responded" : ": ignored";
237  VLOG(1) << log;
238}
239
240void DnsSdServer::DoLoop(int rv) {
241  // TODO(maksymb): Check what happened if buffer will be overflowed
242  do {
243    if (rv > 0)
244      ProcessMessage(rv, recv_buf_.get());
245    rv = socket_->RecvFrom(recv_buf_.get(), recv_buf_->size(), &recv_address_,
246                           base::Bind(&DnsSdServer::DoLoop, AsWeakPtr()));
247  } while (rv > 0);
248
249  // TODO(maksymb): Add handler for errors
250  DCHECK(rv == net::ERR_IO_PENDING);
251}
252
253void DnsSdServer::OnDatagramReceived() {
254  DoLoop(0);
255}
256
257void DnsSdServer::SendAnnouncement(uint32 ttl) {
258  if (!CommandLine::ForCurrentProcess()->HasSwitch("no-announcement")) {
259    DnsResponseBuilder builder(ttl);
260
261    builder.AppendPtr(serv_params_.service_type_, ttl,
262                      serv_params_.service_name_);
263    builder.AppendSrv(serv_params_.service_name_, ttl, kSrvPriority, kSrvWeight,
264                      serv_params_.http_port_,
265                      serv_params_.service_domain_name_);
266    builder.AppendA(serv_params_.service_domain_name_, ttl,
267                    serv_params_.http_ipv4_);
268    builder.AppendTxt(serv_params_.service_name_, ttl, metadata_);
269    scoped_refptr<net::IOBufferWithSize> buffer(builder.Build());
270
271    DCHECK(buffer.get() != NULL);
272
273    socket_->SendTo(buffer.get(), buffer.get()->size(), multicast_address_,
274                    base::Bind(&DoNothingAfterSendToSocket));
275
276    VLOG(1) << "Announcement was sent with TTL: " << ttl;
277  }
278
279  time_until_live_ = base::Time::Now() +
280      base::TimeDelta::FromSeconds(full_ttl_);
281
282  // Schedule next announcement.
283  base::MessageLoop::current()->PostDelayedTask(
284      FROM_HERE,
285      base::Bind(&DnsSdServer::Update, AsWeakPtr()),
286      base::TimeDelta::FromSeconds(static_cast<int64>(
287          kTimeToNextAnnouncement*full_ttl_)));
288}
289
290uint32 DnsSdServer::GetCurrentTLL() const {
291  uint32 current_ttl = (time_until_live_ - base::Time::Now()).InSeconds();
292  if (time_until_live_ < base::Time::Now() || current_ttl == 0) {
293    // This should not be reachable. But still we don't need to fail.
294    current_ttl = 1;  // Service is still alive.
295    LOG(ERROR) << "|current_ttl| was equal to zero.";
296  }
297  return current_ttl;
298}
299
300