1//
2// Copyright (C) 2012 The Android Open Source Project
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//      http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15//
16
17#include "shill/dns_client.h"
18
19#include <netdb.h>
20
21#include <memory>
22#include <string>
23#include <vector>
24
25#include <base/bind.h>
26
27#include "shill/error.h"
28#include "shill/event_dispatcher.h"
29#include "shill/mock_ares.h"
30#include "shill/mock_control.h"
31#include "shill/mock_event_dispatcher.h"
32#include "shill/net/io_handler.h"
33#include "shill/net/mock_time.h"
34#include "shill/testing.h"
35
36using base::Bind;
37using base::Unretained;
38using std::string;
39using std::vector;
40using testing::_;
41using testing::DoAll;
42using testing::Not;
43using testing::Return;
44using testing::ReturnArg;
45using testing::ReturnNew;
46using testing::Test;
47using testing::SetArgumentPointee;
48using testing::StrEq;
49using testing::StrictMock;
50
51namespace shill {
52
53namespace {
54const char kGoodName[] = "all-systems.mcast.net";
55const char kResult[] = "224.0.0.1";
56const char kGoodServer[] = "8.8.8.8";
57const char kBadServer[] = "10.9xx8.7";
58const char kNetworkInterface[] = "eth0";
59char kReturnAddressList0[] = { static_cast<char>(224), 0, 0, 1 };
60char* kReturnAddressList[] = { kReturnAddressList0, nullptr };
61char kFakeAresChannelData = 0;
62const ares_channel kAresChannel =
63    reinterpret_cast<ares_channel>(&kFakeAresChannelData);
64const int kAresFd = 10203;
65const int kAresTimeoutMS = 2000;  // ARES transaction timeout
66const int kAresWaitMS = 1000;     // Time period ARES asks caller to wait
67}  // namespace
68
69class DNSClientTest : public Test {
70 public:
71  DNSClientTest()
72      : ares_result_(ARES_SUCCESS), address_result_(IPAddress::kFamilyUnknown) {
73    time_val_.tv_sec = 0;
74    time_val_.tv_usec = 0;
75    ares_timeout_.tv_sec = kAresWaitMS / 1000;
76    ares_timeout_.tv_usec = (kAresWaitMS % 1000) * 1000;
77    hostent_.h_addrtype = IPAddress::kFamilyIPv4;
78    hostent_.h_length = sizeof(kReturnAddressList0);
79    hostent_.h_addr_list = kReturnAddressList;
80  }
81
82  virtual void SetUp() {
83    EXPECT_CALL(time_, GetTimeMonotonic(_))
84        .WillRepeatedly(DoAll(SetArgumentPointee<0>(time_val_), Return(0)));
85    SetInActive();
86  }
87
88  virtual void TearDown() {
89    // We need to make sure the dns_client instance releases ares_
90    // before the destructor for DNSClientTest deletes ares_.
91    if (dns_client_.get()) {
92      dns_client_->Stop();
93    }
94  }
95
96  void AdvanceTime(int time_ms) {
97    struct timeval adv_time = { time_ms/1000, (time_ms % 1000) * 1000 };
98    timeradd(&time_val_, &adv_time, &time_val_);
99    EXPECT_CALL(time_, GetTimeMonotonic(_))
100        .WillRepeatedly(DoAll(SetArgumentPointee<0>(time_val_), Return(0)));
101  }
102
103  void CallReplyCB() {
104    dns_client_->ReceiveDNSReplyCB(dns_client_.get(), ares_result_, 0,
105                                   &hostent_);
106  }
107
108  void CallDNSRead() {
109    dns_client_->HandleDNSRead(kAresFd);
110  }
111
112  void CallDNSWrite() {
113    dns_client_->HandleDNSWrite(kAresFd);
114  }
115
116  void CallTimeout() {
117    dns_client_->HandleTimeout();
118  }
119
120  void CallCompletion() {
121    dns_client_->HandleCompletion();
122  }
123
124  void CreateClient(const vector<string>& dns_servers, int timeout_ms) {
125    dns_client_.reset(new DNSClient(IPAddress::kFamilyIPv4,
126                                    kNetworkInterface,
127                                    dns_servers,
128                                    timeout_ms,
129                                    &dispatcher_,
130                                    callback_target_.callback()));
131    dns_client_->ares_ = &ares_;
132    dns_client_->time_ = &time_;
133  }
134
135  void SetActive() {
136    // Returns that socket kAresFd is readable.
137    EXPECT_CALL(ares_, GetSock(_, _, _))
138        .WillRepeatedly(DoAll(SetArgumentPointee<1>(kAresFd), Return(1)));
139    EXPECT_CALL(ares_, Timeout(_, _, _))
140        .WillRepeatedly(
141            DoAll(SetArgumentPointee<2>(ares_timeout_), ReturnArg<2>()));
142  }
143
144  void SetInActive() {
145    EXPECT_CALL(ares_, GetSock(_, _, _))
146        .WillRepeatedly(Return(0));
147    EXPECT_CALL(ares_, Timeout(_, _, _))
148        .WillRepeatedly(ReturnArg<1>());
149  }
150
151  void SetupRequest(const string& name, const string& server) {
152    vector<string> dns_servers;
153    dns_servers.push_back(server);
154    CreateClient(dns_servers, kAresTimeoutMS);
155    // These expectations are fulfilled when dns_client_->Start() is called.
156    EXPECT_CALL(ares_, InitOptions(_, _, _))
157        .WillOnce(DoAll(SetArgumentPointee<0>(kAresChannel),
158                        Return(ARES_SUCCESS)));
159    EXPECT_CALL(ares_, SetServersCsv(_, _))
160        .WillOnce(Return(ARES_SUCCESS));
161    EXPECT_CALL(ares_, SetLocalDev(kAresChannel, StrEq(kNetworkInterface)))
162        .Times(1);
163    EXPECT_CALL(ares_, GetHostByName(kAresChannel, StrEq(name), _, _, _));
164  }
165
166  void StartValidRequest() {
167    SetupRequest(kGoodName, kGoodServer);
168    EXPECT_CALL(dispatcher_,
169                CreateReadyHandler(kAresFd, IOHandler::kModeInput, _))
170        .WillOnce(ReturnNew<IOHandler>());
171    SetActive();
172    EXPECT_CALL(dispatcher_, PostDelayedTask(_, kAresWaitMS));
173    Error error;
174    ASSERT_TRUE(dns_client_->Start(kGoodName, &error));
175    EXPECT_TRUE(error.IsSuccess());
176    EXPECT_CALL(ares_, Destroy(kAresChannel));
177  }
178
179  void TestValidCompletion() {
180    EXPECT_CALL(ares_, ProcessFd(kAresChannel, kAresFd, ARES_SOCKET_BAD))
181        .WillOnce(InvokeWithoutArgs(this, &DNSClientTest::CallReplyCB));
182    ExpectPostCompletionTask();
183    CallDNSRead();
184
185    // Make sure that the address value is correct as held in the DNSClient.
186    ASSERT_TRUE(dns_client_->address_.IsValid());
187    IPAddress ipaddr(dns_client_->address_.family());
188    ASSERT_TRUE(ipaddr.SetAddressFromString(kResult));
189    EXPECT_TRUE(ipaddr.Equals(dns_client_->address_));
190
191    // Make sure the callback gets called with a success result, and save
192    // the callback address argument in |address_result_|.
193    EXPECT_CALL(callback_target_, CallTarget(IsSuccess(), _))
194        .WillOnce(Invoke(this, &DNSClientTest::SaveCallbackArgs));
195    CallCompletion();
196
197    // Make sure the address was successfully passed to the callback.
198    EXPECT_TRUE(ipaddr.Equals(address_result_));
199    EXPECT_TRUE(dns_client_->address_.IsDefault());
200  }
201
202  void SaveCallbackArgs(const Error& error, const IPAddress& address)  {
203    error_result_.CopyFrom(error);
204    address_result_ = address;
205  }
206
207  void ExpectPostCompletionTask() {
208    EXPECT_CALL(dispatcher_, PostTask(_));
209  }
210
211  void ExpectReset() {
212    EXPECT_TRUE(dns_client_->address_.family() == IPAddress::kFamilyIPv4);
213    EXPECT_TRUE(dns_client_->address_.IsDefault());
214    EXPECT_FALSE(dns_client_->resolver_state_.get());
215  }
216
217 protected:
218  class DNSCallbackTarget {
219   public:
220    DNSCallbackTarget()
221        : callback_(Bind(&DNSCallbackTarget::CallTarget, Unretained(this))) {}
222
223    MOCK_METHOD2(CallTarget, void(const Error& error,
224                                  const IPAddress& address));
225    const DNSClient::ClientCallback& callback() { return callback_; }
226
227   private:
228    DNSClient::ClientCallback callback_;
229  };
230
231  std::unique_ptr<DNSClient> dns_client_;
232  StrictMock<MockEventDispatcher> dispatcher_;
233  string queued_request_;
234  StrictMock<DNSCallbackTarget> callback_target_;
235  StrictMock<MockAres> ares_;
236  StrictMock<MockTime> time_;
237  struct timeval time_val_;
238  struct timeval ares_timeout_;
239  struct hostent hostent_;
240  int ares_result_;
241  Error error_result_;
242  IPAddress address_result_;
243};
244
245class SentinelIOHandler : public IOHandler {
246 public:
247  MOCK_METHOD0(Die, void());
248  virtual ~SentinelIOHandler() { Die(); }
249};
250
251TEST_F(DNSClientTest, Constructor) {
252  vector<string> dns_servers;
253  dns_servers.push_back(kGoodServer);
254  CreateClient(dns_servers, kAresTimeoutMS);
255  ExpectReset();
256}
257
258// Receive error because no DNS servers were specified.
259TEST_F(DNSClientTest, NoServers) {
260  CreateClient(vector<string>(), kAresTimeoutMS);
261  Error error;
262  EXPECT_FALSE(dns_client_->Start(kGoodName, &error));
263  EXPECT_EQ(Error::kInvalidArguments, error.type());
264}
265
266// Setup error because SetServersCsv failed due to invalid DNS servers.
267TEST_F(DNSClientTest, SetServersCsvInvalidServer) {
268  vector<string> dns_servers;
269  dns_servers.push_back(kBadServer);
270  CreateClient(dns_servers, kAresTimeoutMS);
271  EXPECT_CALL(ares_, InitOptions(_, _, _))
272      .WillOnce(Return(ARES_SUCCESS));
273  EXPECT_CALL(ares_, SetServersCsv(_, _))
274      .WillOnce(Return(ARES_EBADSTR));
275  Error error;
276  EXPECT_FALSE(dns_client_->Start(kGoodName, &error));
277  EXPECT_EQ(Error::kOperationFailed, error.type());
278}
279
280// Setup error because InitOptions failed.
281TEST_F(DNSClientTest, InitOptionsFailure) {
282  vector<string> dns_servers;
283  dns_servers.push_back(kGoodServer);
284  CreateClient(dns_servers, kAresTimeoutMS);
285  EXPECT_CALL(ares_, InitOptions(_, _, _))
286      .WillOnce(Return(ARES_EBADFLAGS));
287  Error error;
288  EXPECT_FALSE(dns_client_->Start(kGoodName, &error));
289  EXPECT_EQ(Error::kOperationFailed, error.type());
290}
291
292// Fail a second request because one is already in progress.
293TEST_F(DNSClientTest, MultipleRequest) {
294  StartValidRequest();
295  Error error;
296  ASSERT_FALSE(dns_client_->Start(kGoodName, &error));
297  EXPECT_EQ(Error::kInProgress, error.type());
298}
299
// A valid request whose first read on kAresFd yields a successful ARES reply
// completes with the expected address (see TestValidCompletion()).
TEST_F(DNSClientTest, GoodRequest) {
  StartValidRequest();
  TestValidCompletion();
}
304
// A valid request still succeeds after an intermediate timer tick that
// occurs before the overall kAresTimeoutMS deadline.
TEST_F(DNSClientTest, GoodRequestWithTimeout) {
  StartValidRequest();
  // Insert an intermediate HandleTimeout callback.
  // Only kAresWaitMS has elapsed: expect ARES to be handed no ready fds and
  // another kAresWaitMS delayed task to be posted.
  AdvanceTime(kAresWaitMS);
  EXPECT_CALL(ares_, ProcessFd(kAresChannel, ARES_SOCKET_BAD, ARES_SOCKET_BAD));
  EXPECT_CALL(dispatcher_, PostDelayedTask(_, kAresWaitMS));
  CallTimeout();
  AdvanceTime(kAresWaitMS);
  TestValidCompletion();
}
315
// A valid request still succeeds after an intermediate read event that
// produces no reply.
TEST_F(DNSClientTest, GoodRequestWithDNSRead) {
  StartValidRequest();
  // Insert an intermediate HandleDNSRead callback.
  // ARES processes kAresFd as readable without replying; the client is
  // expected to re-arm a kAresWaitMS delayed task.
  AdvanceTime(kAresWaitMS);
  EXPECT_CALL(ares_, ProcessFd(kAresChannel, kAresFd, ARES_SOCKET_BAD));
  EXPECT_CALL(dispatcher_, PostDelayedTask(_, kAresWaitMS));
  CallDNSRead();
  AdvanceTime(kAresWaitMS);
  TestValidCompletion();
}
326
// A valid request still succeeds after an intermediate write event that
// produces no reply.
TEST_F(DNSClientTest, GoodRequestWithDNSWrite) {
  StartValidRequest();
  // Insert an intermediate HandleDNSWrite callback.
  // ARES processes kAresFd as writable (third ProcessFd argument); the
  // client is expected to re-arm a kAresWaitMS delayed task.
  AdvanceTime(kAresWaitMS);
  EXPECT_CALL(ares_, ProcessFd(kAresChannel, ARES_SOCKET_BAD, kAresFd));
  EXPECT_CALL(dispatcher_, PostDelayedTask(_, kAresWaitMS));
  CallDNSWrite();
  AdvanceTime(kAresWaitMS);
  TestValidCompletion();
}
337
// Failure due to the timeout occurring during first call to RefreshHandles.
// Start() itself must fail with kOperationTimeout and leave the client reset.
TEST_F(DNSClientTest, TimeoutFirstRefresh) {
  SetupRequest(kGoodName, kGoodServer);
  struct timeval init_time_val = time_val_;
  AdvanceTime(kAresTimeoutMS);
  // The first clock read (presumably the request start time) sees the
  // initial time; every later read sees kAresTimeoutMS later, so the full
  // timeout has already elapsed by the next clock check.
  EXPECT_CALL(time_, GetTimeMonotonic(_))
      .WillOnce(DoAll(SetArgumentPointee<0>(init_time_val), Return(0)))
      .WillRepeatedly(DoAll(SetArgumentPointee<0>(time_val_), Return(0)));
  // The failure is reported through Start()'s |error|, never through the
  // client callback.
  EXPECT_CALL(callback_target_, CallTarget(Not(IsSuccess()), _))
      .Times(0);
  EXPECT_CALL(ares_, Destroy(kAresChannel));
  Error error;
  // Expect the DNSClient to post a completion task.  The mocked dispatcher
  // never runs posted tasks, so the completion handler is not reached here.
  // ExpectReset() below checks that Start() returned with the client back
  // in its idle state (default address, no resolver state).
  ExpectPostCompletionTask();
  ASSERT_FALSE(dns_client_->Start(kGoodName, &error));
  EXPECT_EQ(Error::kOperationTimeout, error.type());
  EXPECT_EQ(string(DNSClient::kErrorTimedOut), error.message());
  ExpectReset();
}
359
// Failed request due to timeout within the dns_client.
// The timer fires after the full kAresTimeoutMS has elapsed: ARES is handed
// no ready fds, and a completion task is posted instead of another timer.
TEST_F(DNSClientTest, TimeoutDispatcherEvent) {
  StartValidRequest();
  EXPECT_CALL(ares_, ProcessFd(kAresChannel,
                               ARES_SOCKET_BAD, ARES_SOCKET_BAD));
  AdvanceTime(kAresTimeoutMS);
  ExpectPostCompletionTask();
  CallTimeout();
  // Running the completion reports the timeout to the client callback.
  EXPECT_CALL(callback_target_, CallTarget(
      ErrorIs(Error::kOperationTimeout, DNSClient::kErrorTimedOut), _));
  CallCompletion();
}
372
// Failed request due to timeout reported by ARES.
// Only kAresWaitMS has elapsed, but ARES delivers ARES_ETIMEOUT from inside
// ProcessFd() during the timer tick; the client must report kOperationTimeout.
TEST_F(DNSClientTest, TimeoutFromARES) {
  StartValidRequest();
  AdvanceTime(kAresWaitMS);
  ares_result_ = ARES_ETIMEOUT;
  EXPECT_CALL(ares_, ProcessFd(kAresChannel, ARES_SOCKET_BAD, ARES_SOCKET_BAD))
        .WillOnce(InvokeWithoutArgs(this, &DNSClientTest::CallReplyCB));
  ExpectPostCompletionTask();
  CallTimeout();
  EXPECT_CALL(callback_target_, CallTarget(
      ErrorIs(Error::kOperationTimeout, DNSClient::kErrorTimedOut), _));
  CallCompletion();
}
386
// Failed request due to "host not found" reported by ARES.
// ARES delivers ARES_ENOTFOUND while processing a readable kAresFd; the
// callback must receive kOperationFailed with DNSClient::kErrorNotFound.
TEST_F(DNSClientTest, HostNotFound) {
  StartValidRequest();
  AdvanceTime(kAresWaitMS);
  ares_result_ = ARES_ENOTFOUND;
  EXPECT_CALL(ares_, ProcessFd(kAresChannel, kAresFd, ARES_SOCKET_BAD))
      .WillOnce(InvokeWithoutArgs(this, &DNSClientTest::CallReplyCB));
  ExpectPostCompletionTask();
  CallDNSRead();
  EXPECT_CALL(callback_target_, CallTarget(
      ErrorIs(Error::kOperationFailed, DNSClient::kErrorNotFound), _));
  CallCompletion();
}
400
// Make sure IOHandles are deallocated when GetSock() reports them gone.
TEST_F(DNSClientTest, IOHandleDeallocGetSock) {
  SetupRequest(kGoodName, kGoodServer);
  // This isn't any kind of scoped/ref pointer because we are tracking dealloc.
  // The pointer is handed to the client via the mocked CreateReadyHandler();
  // the Die() expectation below checks that the client deletes it.
  SentinelIOHandler* io_handler = new SentinelIOHandler();
  EXPECT_CALL(dispatcher_,
              CreateReadyHandler(kAresFd, IOHandler::kModeInput, _))
      .WillOnce(Return(io_handler));
  EXPECT_CALL(dispatcher_, PostDelayedTask(_, kAresWaitMS));
  SetActive();
  Error error;
  ASSERT_TRUE(dns_client_->Start(kGoodName, &error));
  AdvanceTime(kAresWaitMS);
  // From now on GetSock() reports no sockets, so the refresh that follows
  // the next read must destroy the kAresFd handler.
  SetInActive();
  EXPECT_CALL(*io_handler, Die());
  EXPECT_CALL(ares_, ProcessFd(kAresChannel, kAresFd, ARES_SOCKET_BAD));
  EXPECT_CALL(dispatcher_, PostDelayedTask(_, kAresWaitMS));
  CallDNSRead();
  // Satisfied by the Stop() call in TearDown().
  EXPECT_CALL(ares_, Destroy(kAresChannel));
}
421
// Make sure IOHandles are deallocated when Stop() is called.
TEST_F(DNSClientTest, IOHandleDeallocStop) {
  SetupRequest(kGoodName, kGoodServer);
  // This isn't any kind of scoped/ref pointer because we are tracking dealloc.
  // The pointer is handed to the client via the mocked CreateReadyHandler();
  // the Die() expectation below checks that the client deletes it.
  SentinelIOHandler* io_handler = new SentinelIOHandler();
  EXPECT_CALL(dispatcher_,
              CreateReadyHandler(kAresFd, IOHandler::kModeInput, _))
      .WillOnce(Return(io_handler));
  EXPECT_CALL(dispatcher_, PostDelayedTask(_, kAresWaitMS));
  SetActive();
  Error error;
  ASSERT_TRUE(dns_client_->Start(kGoodName, &error));
  // An explicit Stop() must release both the read handler and the channel.
  EXPECT_CALL(*io_handler, Die());
  EXPECT_CALL(ares_, Destroy(kAresChannel));
  dns_client_->Stop();
}
438
439}  // namespace shill
440