routing_table.cc revision 6fbf64f493a9aae7d743888039c61a57386203db
1// Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "shill/routing_table.h"
6
7#include <arpa/inet.h>
8#include <fcntl.h>
9#include <linux/netlink.h>
10#include <linux/rtnetlink.h>
11#include <netinet/ether.h>
12#include <net/if.h>
13#include <net/if_arp.h>
14#include <string.h>
15#include <sys/socket.h>
16#include <time.h>
17#include <unistd.h>
18
19#include <string>
20
21#include <base/bind.h>
22#include <base/containers/hash_tables.h>
23#include <base/files/file_path.h>
24#include <base/file_util.h>
25#include <base/memory/scoped_ptr.h>
26#include <base/stl_util.h>
27#include <base/strings/stringprintf.h>
28
29#include "shill/byte_string.h"
30#include "shill/logging.h"
31#include "shill/routing_table_entry.h"
32#include "shill/rtnl_handler.h"
33#include "shill/rtnl_listener.h"
34#include "shill/rtnl_message.h"
35
36using base::Bind;
37using base::FilePath;
38using base::Unretained;
39using std::deque;
40using std::string;
41using std::vector;
42
43namespace shill {
44
45namespace {
46base::LazyInstance<RoutingTable> g_routing_table = LAZY_INSTANCE_INITIALIZER;
47}  // namespace
48
49// static
50const char RoutingTable::kRouteFlushPath4[] = "/proc/sys/net/ipv4/route/flush";
51// static
52const char RoutingTable::kRouteFlushPath6[] = "/proc/sys/net/ipv6/route/flush";
53
54RoutingTable::RoutingTable()
55    : route_callback_(Bind(&RoutingTable::RouteMsgHandler, Unretained(this))),
56      rtnl_handler_(RTNLHandler::GetInstance()) {
57  SLOG(Route, 2) << __func__;
58}
59
60RoutingTable::~RoutingTable() {}
61
62RoutingTable* RoutingTable::GetInstance() {
63  return g_routing_table.Pointer();
64}
65
66void RoutingTable::Start() {
67  SLOG(Route, 2) << __func__;
68
69  route_listener_.reset(
70      new RTNLListener(RTNLHandler::kRequestRoute, route_callback_));
71  rtnl_handler_->RequestDump(RTNLHandler::kRequestRoute);
72}
73
74void RoutingTable::Stop() {
75  SLOG(Route, 2) << __func__;
76
77  route_listener_.reset();
78}
79
80bool RoutingTable::AddRoute(int interface_index,
81                            const RoutingTableEntry &entry) {
82  SLOG(Route, 2) << __func__ << ": "
83                 << "destination " << entry.dst.ToString()
84                 << " index " << interface_index
85                 << " gateway " << entry.gateway.ToString()
86                 << " metric " << entry.metric;
87
88  CHECK(!entry.from_rtnl);
89  if (!ApplyRoute(interface_index,
90                  entry,
91                  RTNLMessage::kModeAdd,
92                  NLM_F_CREATE | NLM_F_EXCL)) {
93    return false;
94  }
95  tables_[interface_index].push_back(entry);
96  return true;
97}
98
99bool RoutingTable::GetDefaultRoute(int interface_index,
100                                   IPAddress::Family family,
101                                   RoutingTableEntry *entry) {
102  RoutingTableEntry *found_entry;
103  bool ret = GetDefaultRouteInternal(interface_index, family, &found_entry);
104  if (ret) {
105    *entry = *found_entry;
106  }
107  return ret;
108}
109
110bool RoutingTable::GetDefaultRouteInternal(int interface_index,
111                                           IPAddress::Family family,
112                                           RoutingTableEntry **entry) {
113  SLOG(Route, 2) << __func__ << " index " << interface_index
114                 << " family " << IPAddress::GetAddressFamilyName(family);
115
116  Tables::iterator table = tables_.find(interface_index);
117  if (table == tables_.end()) {
118    SLOG(Route, 2) << __func__ << " no table";
119    return false;
120  }
121
122  for (auto &nent : table->second) {
123    if (nent.dst.IsDefault() && nent.dst.family() == family) {
124      *entry = &nent;
125      SLOG(Route, 2) << __func__ << ": found"
126                     << " gateway " << nent.gateway.ToString()
127                     << " metric " << nent.metric;
128      return true;
129    }
130  }
131
132  SLOG(Route, 2) << __func__ << " no route";
133  return false;
134}
135
136bool RoutingTable::SetDefaultRoute(int interface_index,
137                                   const IPAddress &gateway_address,
138                                   uint32 metric) {
139  SLOG(Route, 2) << __func__ << " index " << interface_index
140                 << " metric " << metric;
141
142  RoutingTableEntry *old_entry;
143
144  if (GetDefaultRouteInternal(interface_index,
145                              gateway_address.family(),
146                              &old_entry)) {
147    if (old_entry->gateway.Equals(gateway_address)) {
148      if (old_entry->metric != metric) {
149        ReplaceMetric(interface_index, old_entry, metric);
150      }
151      return true;
152    } else {
153      // TODO(quiche): Update internal state as well?
154      ApplyRoute(interface_index,
155                 *old_entry,
156                 RTNLMessage::kModeDelete,
157                 0);
158    }
159  }
160
161  IPAddress default_address(gateway_address.family());
162  default_address.SetAddressToDefault();
163
164  return AddRoute(interface_index,
165                  RoutingTableEntry(default_address,
166                                    default_address,
167                                    gateway_address,
168                                    metric,
169                                    RT_SCOPE_UNIVERSE,
170                                    false));
171}
172
173bool RoutingTable::ConfigureRoutes(int interface_index,
174                                   const IPConfigRefPtr &ipconfig,
175                                   uint32 metric) {
176  bool ret = true;
177
178  IPAddress::Family address_family = ipconfig->properties().address_family;
179  const vector<IPConfig::Route> &routes = ipconfig->properties().routes;
180
181  for (const auto &route : routes) {
182    SLOG(Route, 3) << "Installing route:"
183                   << " Destination: " << route.host
184                   << " Netmask: " << route.netmask
185                   << " Gateway: " << route.gateway;
186    IPAddress destination_address(address_family);
187    IPAddress source_address(address_family);  // Left as default.
188    IPAddress gateway_address(address_family);
189    if (!destination_address.SetAddressFromString(route.host)) {
190      LOG(ERROR) << "Failed to parse host "
191                 << route.host;
192      ret = false;
193      continue;
194    }
195    if (!gateway_address.SetAddressFromString(route.gateway)) {
196      LOG(ERROR) << "Failed to parse gateway "
197                 << route.gateway;
198      ret = false;
199      continue;
200    }
201    destination_address.set_prefix(
202        IPAddress::GetPrefixLengthFromMask(address_family, route.netmask));
203    if (!AddRoute(interface_index,
204                  RoutingTableEntry(destination_address,
205                                    source_address,
206                                    gateway_address,
207                                    metric,
208                                    RT_SCOPE_UNIVERSE,
209                                    false))) {
210      ret = false;
211    }
212  }
213  return ret;
214}
215
216void RoutingTable::FlushRoutes(int interface_index) {
217  SLOG(Route, 2) << __func__;
218
219  auto table = tables_.find(interface_index);
220  if (table == tables_.end()) {
221    return;
222  }
223
224  for (const auto &nent : table->second) {
225    ApplyRoute(interface_index, nent, RTNLMessage::kModeDelete, 0);
226  }
227  table->second.clear();
228}
229
230void RoutingTable::FlushRoutesWithTag(int tag) {
231  SLOG(Route, 2) << __func__;
232
233  for (auto &table : tables_) {
234    for (auto nent = table.second.begin(); nent != table.second.end();) {
235      if (nent->tag == tag) {
236        ApplyRoute(table.first, *nent, RTNLMessage::kModeDelete, 0);
237        nent = table.second.erase(nent);
238      } else {
239        ++nent;
240      }
241    }
242  }
243}
244
245void RoutingTable::ResetTable(int interface_index) {
246  tables_.erase(interface_index);
247}
248
249void RoutingTable::SetDefaultMetric(int interface_index, uint32 metric) {
250  SLOG(Route, 2) << __func__ << " index " << interface_index
251                 << " metric " << metric;
252
253  RoutingTableEntry *entry;
254  if (GetDefaultRouteInternal(
255          interface_index, IPAddress::kFamilyIPv4, &entry) &&
256      entry->metric != metric) {
257    ReplaceMetric(interface_index, entry, metric);
258  }
259
260  if (GetDefaultRouteInternal(
261          interface_index, IPAddress::kFamilyIPv6, &entry) &&
262      entry->metric != metric) {
263    ReplaceMetric(interface_index, entry, metric);
264  }
265}
266
267// static
268bool RoutingTable::ParseRoutingTableMessage(const RTNLMessage &message,
269                                            int *interface_index,
270                                            RoutingTableEntry *entry) {
271  if (message.type() != RTNLMessage::kTypeRoute ||
272      message.family() == IPAddress::kFamilyUnknown ||
273      !message.HasAttribute(RTA_OIF)) {
274    return false;
275  }
276
277  const RTNLMessage::RouteStatus &route_status = message.route_status();
278
279  if (route_status.type != RTN_UNICAST ||
280      route_status.table != RT_TABLE_MAIN) {
281    return false;
282  }
283
284  uint32 interface_index_u32 = 0;
285  if (!message.GetAttribute(RTA_OIF).ConvertToCPUUInt32(&interface_index_u32)) {
286    return false;
287  }
288  *interface_index = interface_index_u32;
289
290  uint32 metric = 0;
291  if (message.HasAttribute(RTA_PRIORITY)) {
292    message.GetAttribute(RTA_PRIORITY).ConvertToCPUUInt32(&metric);
293  }
294
295  IPAddress default_addr(message.family());
296  default_addr.SetAddressToDefault();
297
298  ByteString dst_bytes(default_addr.address());
299  if (message.HasAttribute(RTA_DST)) {
300    dst_bytes = message.GetAttribute(RTA_DST);
301  }
302  ByteString src_bytes(default_addr.address());
303  if (message.HasAttribute(RTA_SRC)) {
304    src_bytes = message.GetAttribute(RTA_SRC);
305  }
306  ByteString gateway_bytes(default_addr.address());
307  if (message.HasAttribute(RTA_GATEWAY)) {
308    gateway_bytes = message.GetAttribute(RTA_GATEWAY);
309  }
310
311  entry->dst = IPAddress(message.family(), dst_bytes, route_status.dst_prefix);
312  entry->src = IPAddress(message.family(), src_bytes, route_status.src_prefix);
313  entry->gateway = IPAddress(message.family(), gateway_bytes);
314  entry->metric = metric;
315  entry->scope = route_status.scope;
316  entry->from_rtnl = true;
317
318  return true;
319}
320
321void RoutingTable::RouteMsgHandler(const RTNLMessage &message) {
322  int interface_index;
323  RoutingTableEntry entry;
324
325  if (!ParseRoutingTableMessage(message, &interface_index, &entry)) {
326    return;
327  }
328
329  if (!route_queries_.empty() &&
330      message.route_status().protocol == RTPROT_UNSPEC) {
331    SLOG(Route, 3) << __func__ << ": Message seq: " << message.seq()
332                   << " mode " << message.mode()
333                   << ", next query seq: " << route_queries_.front().sequence;
334
335    // Purge queries that have expired (sequence number of this message is
336    // greater than that of the head of the route query sequence).  Do the
337    // math in a way that's roll-over independent.
338    while (route_queries_.front().sequence - message.seq() > kuint32max / 2) {
339      LOG(ERROR) << __func__ << ": Purging un-replied route request sequence "
340                 << route_queries_.front().sequence
341                 << " (< " << message.seq() << ")";
342      route_queries_.pop_front();
343      if (route_queries_.empty())
344        return;
345    }
346
347    const Query &query = route_queries_.front();
348    if (query.sequence == message.seq()) {
349      RoutingTableEntry add_entry(entry);
350      add_entry.from_rtnl = false;
351      add_entry.tag = query.tag;
352      bool added = true;
353      if (add_entry.gateway.IsDefault()) {
354        SLOG(Route, 2) << __func__ << ": Ignoring route result with no gateway "
355                       << "since we don't need to plumb these.";
356      } else {
357        SLOG(Route, 2) << __func__ << ": Adding host route to "
358                       << add_entry.dst.ToString();
359        added = AddRoute(interface_index, add_entry);
360      }
361      if (added && !query.callback.is_null()) {
362        SLOG(Route, 2) << "Running query callback.";
363        query.callback.Run(interface_index, add_entry);
364      }
365      route_queries_.pop_front();
366    }
367    return;
368  } else if (message.route_status().protocol != RTPROT_BOOT) {
369    // Responses to route queries come back with a protocol of
370    // RTPROT_UNSPEC.  Otherwise, normal route updates that we are
371    // interested in come with a protocol of RTPROT_BOOT.
372    return;
373  }
374
375  TableEntryVector &table = tables_[interface_index];
376  for (auto nent = table.begin(); nent != table.end(); ++nent)  {
377    if (nent->dst.Equals(entry.dst) &&
378        nent->src.Equals(entry.src) &&
379        nent->gateway.Equals(entry.gateway) &&
380        nent->scope == entry.scope) {
381      if (message.mode() == RTNLMessage::kModeDelete &&
382          nent->metric == entry.metric) {
383        table.erase(nent);
384      } else if (message.mode() == RTNLMessage::kModeAdd) {
385        nent->from_rtnl = true;
386        nent->metric = entry.metric;
387      }
388      return;
389    }
390  }
391
392  if (message.mode() == RTNLMessage::kModeAdd) {
393    SLOG(Route, 2) << __func__ << " adding"
394                   << " destination " << entry.dst.ToString()
395                   << " index " << interface_index
396                   << " gateway " << entry.gateway.ToString()
397                   << " metric " << entry.metric;
398    table.push_back(entry);
399  }
400}
401
402bool RoutingTable::ApplyRoute(uint32 interface_index,
403                              const RoutingTableEntry &entry,
404                              RTNLMessage::Mode mode,
405                              unsigned int flags) {
406  SLOG(Route, 2) << base::StringPrintf(
407      "%s: dst %s/%d src %s/%d index %d mode %d flags 0x%x",
408      __func__, entry.dst.ToString().c_str(), entry.dst.prefix(),
409      entry.src.ToString().c_str(), entry.src.prefix(),
410      interface_index, mode, flags);
411
412  RTNLMessage message(
413      RTNLMessage::kTypeRoute,
414      mode,
415      NLM_F_REQUEST | flags,
416      0,
417      0,
418      0,
419      entry.dst.family());
420
421  message.set_route_status(RTNLMessage::RouteStatus(
422      entry.dst.prefix(),
423      entry.src.prefix(),
424      RT_TABLE_MAIN,
425      RTPROT_BOOT,
426      entry.scope,
427      RTN_UNICAST,
428      0));
429
430  message.SetAttribute(RTA_DST, entry.dst.address());
431  if (!entry.src.IsDefault()) {
432    message.SetAttribute(RTA_SRC, entry.src.address());
433  }
434  if (!entry.gateway.IsDefault()) {
435    message.SetAttribute(RTA_GATEWAY, entry.gateway.address());
436  }
437  message.SetAttribute(RTA_PRIORITY,
438                       ByteString::CreateFromCPUUInt32(entry.metric));
439  message.SetAttribute(RTA_OIF,
440                       ByteString::CreateFromCPUUInt32(interface_index));
441
442  return rtnl_handler_->SendMessage(&message);
443}
444
445// Somewhat surprisingly, the kernel allows you to create multiple routes
446// to the same destination through the same interface with different metrics.
447// Therefore, to change the metric on a route, we can't just use the
448// NLM_F_REPLACE flag by itself.  We have to explicitly remove the old route.
449// We do so after creating the route at a new metric so there is no traffic
450// disruption to existing network streams.
451void RoutingTable::ReplaceMetric(uint32 interface_index,
452                                 RoutingTableEntry *entry,
453                                 uint32 metric) {
454  SLOG(Route, 2) << __func__ << " index " << interface_index
455                 << " metric " << metric;
456  RoutingTableEntry new_entry = *entry;
457  new_entry.metric = metric;
458  // First create the route at the new metric.
459  ApplyRoute(interface_index, new_entry, RTNLMessage::kModeAdd,
460             NLM_F_CREATE | NLM_F_REPLACE);
461  // Then delete the route at the old metric.
462  ApplyRoute(interface_index, *entry, RTNLMessage::kModeDelete, 0);
463  // Now, update our routing table (via |*entry|) from |new_entry|.
464  *entry = new_entry;
465}
466
467bool RoutingTable::FlushCache() {
468  static const char *kPaths[2] = { kRouteFlushPath4, kRouteFlushPath6 };
469  bool ret = true;
470
471  SLOG(Route, 2) << __func__;
472
473  for (size_t i = 0; i < arraysize(kPaths); ++i) {
474    if (base::WriteFile(FilePath(kPaths[i]), "-1", 2) != 2) {
475      LOG(ERROR) << base::StringPrintf("Cannot write to route flush file %s",
476                                       kPaths[i]);
477      ret = false;
478    }
479  }
480
481  return ret;
482}
483
484bool RoutingTable::RequestRouteToHost(const IPAddress &address,
485                                      int interface_index,
486                                      int tag,
487                                      const Query::Callback &callback) {
488  // Make sure we don't get a cached response that is no longer valid.
489  FlushCache();
490
491  RTNLMessage message(
492      RTNLMessage::kTypeRoute,
493      RTNLMessage::kModeQuery,
494      NLM_F_REQUEST,
495      0,
496      0,
497      interface_index,
498      address.family());
499
500  RTNLMessage::RouteStatus status;
501  status.dst_prefix = address.prefix();
502  message.set_route_status(status);
503  message.SetAttribute(RTA_DST, address.address());
504
505  if (interface_index != -1) {
506    message.SetAttribute(RTA_OIF,
507                         ByteString::CreateFromCPUUInt32(interface_index));
508  }
509
510  if (!rtnl_handler_->SendMessage(&message)) {
511    return false;
512  }
513
514  // Save the sequence number of the request so we can create a route for
515  // this host when we get a reply.
516  route_queries_.push_back(Query(message.seq(), tag, callback));
517
518  return true;
519}
520
521bool RoutingTable::CreateBlackholeRoute(int interface_index,
522                                        IPAddress::Family family,
523                                        uint32 metric) {
524  SLOG(Route, 2) << base::StringPrintf(
525      "%s: index %d family %s metric %d",
526      __func__, interface_index,
527      IPAddress::GetAddressFamilyName(family).c_str(), metric);
528
529  RTNLMessage message(
530      RTNLMessage::kTypeRoute,
531      RTNLMessage::kModeAdd,
532      NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL,
533      0,
534      0,
535      0,
536      family);
537
538  message.set_route_status(RTNLMessage::RouteStatus(
539      0,
540      0,
541      RT_TABLE_MAIN,
542      RTPROT_BOOT,
543      RT_SCOPE_UNIVERSE,
544      RTN_BLACKHOLE,
545      0));
546
547  message.SetAttribute(RTA_PRIORITY,
548                       ByteString::CreateFromCPUUInt32(metric));
549  message.SetAttribute(RTA_OIF,
550                       ByteString::CreateFromCPUUInt32(interface_index));
551
552  return rtnl_handler_->SendMessage(&message);
553}
554
555bool RoutingTable::CreateLinkRoute(int interface_index,
556                                   const IPAddress &local_address,
557                                   const IPAddress &remote_address) {
558  if (!local_address.CanReachAddress(remote_address)) {
559    LOG(ERROR) << __func__ << " failed: "
560               << remote_address.ToString() << " is not reachable from "
561               << local_address.ToString();
562    return false;
563  }
564
565  IPAddress default_address(local_address.family());
566  default_address.SetAddressToDefault();
567  IPAddress destination_address(remote_address);
568  destination_address.set_prefix(
569      IPAddress::GetMaxPrefixLength(remote_address.family()));
570  SLOG(Route, 2) << "Creating link route to " << destination_address.ToString()
571                 << " from " << local_address.ToString()
572                 << " on interface index " << interface_index;
573  return AddRoute(interface_index,
574                  RoutingTableEntry(destination_address,
575                                    local_address,
576                                    default_address,
577                                    0,
578                                    RT_SCOPE_LINK,
579                                    false));
580}
581
582}  // namespace shill
583