1//
2// Copyright (C) 2012 The Android Open Source Project
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//      http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15//
16
17#include "shill/connectivity_trial.h"
18
19#include <string>
20
21#include <base/bind.h>
22#include <base/strings/pattern.h>
23#include <base/strings/string_number_conversions.h>
24#include <base/strings/string_util.h>
25#include <base/strings/stringprintf.h>
26#if defined(__ANDROID__)
27#include <dbus/service_constants.h>
28#else
29#include <chromeos/dbus/service_constants.h>
30#endif  // __ANDROID__
31
32#include "shill/async_connection.h"
33#include "shill/connection.h"
34#include "shill/dns_client.h"
35#include "shill/event_dispatcher.h"
36#include "shill/http_request.h"
37#include "shill/http_url.h"
38#include "shill/logging.h"
39#include "shill/net/ip_address.h"
40#include "shill/net/sockets.h"
41
42using base::Bind;
43using base::Callback;
44using base::StringPrintf;
45using std::string;
46
47namespace shill {
48
49namespace Logging {
50static auto kModuleLogScope = ScopeLogger::kPortal;
51static string ObjectID(Connection* c) { return c->interface_name(); }
52}
53
54const char ConnectivityTrial::kDefaultURL[] =
55    "http://www.gstatic.com/generate_204";
56const char ConnectivityTrial::kResponseExpected[] = "HTTP/?.? 204";
57
58ConnectivityTrial::ConnectivityTrial(
59    ConnectionRefPtr connection,
60    EventDispatcher* dispatcher,
61    int trial_timeout_seconds,
62    const Callback<void(Result)>& callback)
63    : connection_(connection),
64      dispatcher_(dispatcher),
65      trial_timeout_seconds_(trial_timeout_seconds),
66      trial_callback_(callback),
67      weak_ptr_factory_(this),
68      request_read_callback_(
69          Bind(&ConnectivityTrial::RequestReadCallback,
70               weak_ptr_factory_.GetWeakPtr())),
71      request_result_callback_(
72          Bind(&ConnectivityTrial::RequestResultCallback,
73               weak_ptr_factory_.GetWeakPtr())),
74      is_active_(false) { }
75
76ConnectivityTrial::~ConnectivityTrial() {
77  Stop();
78}
79
80bool ConnectivityTrial::Retry(int start_delay_milliseconds) {
81  SLOG(connection_.get(), 3) << "In " << __func__;
82  if (request_.get())
83    CleanupTrial(false);
84  else
85    return false;
86  StartTrialAfterDelay(start_delay_milliseconds);
87  return true;
88}
89
90bool ConnectivityTrial::Start(const string& url_string,
91                              int start_delay_milliseconds) {
92  SLOG(connection_.get(), 3) << "In " << __func__;
93
94  if (!url_.ParseFromString(url_string)) {
95    LOG(ERROR) << "Failed to parse URL string: " << url_string;
96    return false;
97  }
98  if (request_.get()) {
99    CleanupTrial(false);
100  } else {
101    request_.reset(new HTTPRequest(connection_, dispatcher_, &sockets_));
102  }
103  StartTrialAfterDelay(start_delay_milliseconds);
104  return true;
105}
106
107void ConnectivityTrial::Stop() {
108  SLOG(connection_.get(), 3) << "In " << __func__;
109
110  if (!request_.get()) {
111    return;
112  }
113
114  CleanupTrial(true);
115}
116
117void ConnectivityTrial::StartTrialAfterDelay(int start_delay_milliseconds) {
118  SLOG(connection_.get(), 4) << "In " << __func__
119                             << " delay = " << start_delay_milliseconds
120                             << "ms.";
121  trial_.Reset(Bind(&ConnectivityTrial::StartTrialTask,
122                    weak_ptr_factory_.GetWeakPtr()));
123  dispatcher_->PostDelayedTask(trial_.callback(), start_delay_milliseconds);
124}
125
126void ConnectivityTrial::StartTrialTask() {
127  HTTPRequest::Result result =
128      request_->Start(url_, request_read_callback_, request_result_callback_);
129  if (result != HTTPRequest::kResultInProgress) {
130    CompleteTrial(ConnectivityTrial::GetPortalResultForRequestResult(result));
131    return;
132  }
133  is_active_ = true;
134
135  trial_timeout_.Reset(Bind(&ConnectivityTrial::TimeoutTrialTask,
136                            weak_ptr_factory_.GetWeakPtr()));
137  dispatcher_->PostDelayedTask(trial_timeout_.callback(),
138                               trial_timeout_seconds_ * 1000);
139}
140
141bool ConnectivityTrial::IsActive() {
142  return is_active_;
143}
144
145void ConnectivityTrial::RequestReadCallback(const ByteString& response_data) {
146  const string response_expected(kResponseExpected);
147  bool expected_length_received = false;
148  int compare_length = 0;
149  if (response_data.GetLength() < response_expected.length()) {
150    // There isn't enough data yet for a final decision, but we can still
151    // test to see if the partial string matches so far.
152    expected_length_received = false;
153    compare_length = response_data.GetLength();
154  } else {
155    expected_length_received = true;
156    compare_length = response_expected.length();
157  }
158
159  if (base::MatchPattern(
160          string(reinterpret_cast<const char*>(response_data.GetConstData()),
161                 compare_length),
162          response_expected.substr(0, compare_length))) {
163    if (expected_length_received) {
164      CompleteTrial(Result(kPhaseContent, kStatusSuccess));
165    }
166    // Otherwise, we wait for more data from the server.
167  } else {
168    CompleteTrial(Result(kPhaseContent, kStatusFailure));
169  }
170}
171
172void ConnectivityTrial::RequestResultCallback(
173    HTTPRequest::Result result, const ByteString& /*response_data*/) {
174  CompleteTrial(GetPortalResultForRequestResult(result));
175}
176
177void ConnectivityTrial::CompleteTrial(Result result) {
178  SLOG(connection_.get(), 3)
179      << StringPrintf("Connectivity Trial completed with phase==%s, status==%s",
180                      PhaseToString(result.phase).c_str(),
181                      StatusToString(result.status).c_str());
182  CleanupTrial(false);
183  trial_callback_.Run(result);
184}
185
186void ConnectivityTrial::CleanupTrial(bool reset_request) {
187  trial_timeout_.Cancel();
188
189  if (request_.get())
190    request_->Stop();
191
192  is_active_ = false;
193
194  if (!reset_request || !request_.get())
195    return;
196
197  request_.reset();
198}
199
200void ConnectivityTrial::TimeoutTrialTask() {
201  LOG(ERROR) << "Connectivity Trial - Request timed out";
202  if (request_->response_data().GetLength()) {
203    CompleteTrial(ConnectivityTrial::Result(ConnectivityTrial::kPhaseContent,
204                                            ConnectivityTrial::kStatusTimeout));
205  } else {
206    CompleteTrial(ConnectivityTrial::Result(ConnectivityTrial::kPhaseUnknown,
207                                            ConnectivityTrial::kStatusTimeout));
208  }
209}
210
// static
212const string ConnectivityTrial::PhaseToString(Phase phase) {
213  switch (phase) {
214    case kPhaseConnection:
215      return kPortalDetectionPhaseConnection;
216    case kPhaseDNS:
217      return kPortalDetectionPhaseDns;
218    case kPhaseHTTP:
219      return kPortalDetectionPhaseHttp;
220    case kPhaseContent:
221      return kPortalDetectionPhaseContent;
222    case kPhaseUnknown:
223    default:
224      return kPortalDetectionPhaseUnknown;
225  }
226}
227
228// static
229const string ConnectivityTrial::StatusToString(Status status) {
230  switch (status) {
231    case kStatusSuccess:
232      return kPortalDetectionStatusSuccess;
233    case kStatusTimeout:
234      return kPortalDetectionStatusTimeout;
235    case kStatusFailure:
236    default:
237      return kPortalDetectionStatusFailure;
238  }
239}
240
241ConnectivityTrial::Result ConnectivityTrial::GetPortalResultForRequestResult(
242    HTTPRequest::Result result) {
243  switch (result) {
244    case HTTPRequest::kResultSuccess:
245      // The request completed without receiving the expected payload.
246      return Result(kPhaseContent, kStatusFailure);
247    case HTTPRequest::kResultDNSFailure:
248      return Result(kPhaseDNS, kStatusFailure);
249    case HTTPRequest::kResultDNSTimeout:
250      return Result(kPhaseDNS, kStatusTimeout);
251    case HTTPRequest::kResultConnectionFailure:
252      return Result(kPhaseConnection, kStatusFailure);
253    case HTTPRequest::kResultConnectionTimeout:
254      return Result(kPhaseConnection, kStatusTimeout);
255    case HTTPRequest::kResultRequestFailure:
256    case HTTPRequest::kResultResponseFailure:
257      return Result(kPhaseHTTP, kStatusFailure);
258    case HTTPRequest::kResultRequestTimeout:
259    case HTTPRequest::kResultResponseTimeout:
260      return Result(kPhaseHTTP, kStatusTimeout);
261    case HTTPRequest::kResultUnknown:
262    default:
263      return Result(kPhaseUnknown, kStatusFailure);
264  }
265}
266
267}  // namespace shill
268