routing_table.cc revision a6bfe87a2c0bcb68d789473ca10988243229667b
1// Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "shill/routing_table.h"
6
7#include <arpa/inet.h>
8#include <fcntl.h>
9#include <linux/netlink.h>
10#include <linux/rtnetlink.h>
11#include <netinet/ether.h>
12#include <net/if.h>
13#include <net/if_arp.h>
14#include <string.h>
15#include <sys/socket.h>
16#include <time.h>
17#include <unistd.h>
18
19#include <string>
20
21#include <base/bind.h>
22#include <base/file_path.h>
23#include <base/file_util.h>
24#include <base/hash_tables.h>
25#include <base/memory/scoped_ptr.h>
26#include <base/stl_util.h>
27#include <base/stringprintf.h>
28
29#include "shill/byte_string.h"
30#include "shill/logging.h"
31#include "shill/routing_table_entry.h"
32#include "shill/rtnl_handler.h"
33#include "shill/rtnl_listener.h"
34#include "shill/rtnl_message.h"
35
36using base::Bind;
37using base::Unretained;
38using std::deque;
39using std::string;
40using std::vector;
41
42namespace shill {
43
44namespace {
45base::LazyInstance<RoutingTable> g_routing_table = LAZY_INSTANCE_INITIALIZER;
46}  // namespace
47
48// static
49const char RoutingTable::kRouteFlushPath4[] = "/proc/sys/net/ipv4/route/flush";
50// static
51const char RoutingTable::kRouteFlushPath6[] = "/proc/sys/net/ipv6/route/flush";
52
53RoutingTable::RoutingTable()
54    : route_callback_(Bind(&RoutingTable::RouteMsgHandler, Unretained(this))),
55      route_listener_(NULL),
56      rtnl_handler_(RTNLHandler::GetInstance()) {
57  SLOG(Route, 2) << __func__;
58}
59
60RoutingTable::~RoutingTable() {}
61
62RoutingTable* RoutingTable::GetInstance() {
63  return g_routing_table.Pointer();
64}
65
66void RoutingTable::Start() {
67  SLOG(Route, 2) << __func__;
68
69  route_listener_.reset(
70      new RTNLListener(RTNLHandler::kRequestRoute, route_callback_));
71  rtnl_handler_->RequestDump(RTNLHandler::kRequestRoute);
72}
73
74void RoutingTable::Stop() {
75  SLOG(Route, 2) << __func__;
76
77  route_listener_.reset();
78}
79
80bool RoutingTable::AddRoute(int interface_index,
81                            const RoutingTableEntry &entry) {
82  SLOG(Route, 2) << __func__ << ": "
83                 << "destination " << entry.dst.ToString()
84                 << " index " << interface_index
85                 << " gateway " << entry.gateway.ToString()
86                 << " metric " << entry.metric;
87
88  CHECK(!entry.from_rtnl);
89  if (!ApplyRoute(interface_index,
90                  entry,
91                  RTNLMessage::kModeAdd,
92                  NLM_F_CREATE | NLM_F_EXCL)) {
93    return false;
94  }
95  tables_[interface_index].push_back(entry);
96  return true;
97}
98
99bool RoutingTable::GetDefaultRoute(int interface_index,
100                                   IPAddress::Family family,
101                                   RoutingTableEntry *entry) {
102  RoutingTableEntry *found_entry;
103  bool ret = GetDefaultRouteInternal(interface_index, family, &found_entry);
104  if (ret) {
105    *entry = *found_entry;
106  }
107  return ret;
108}
109
110bool RoutingTable::GetDefaultRouteInternal(int interface_index,
111                                           IPAddress::Family family,
112                                           RoutingTableEntry **entry) {
113  SLOG(Route, 2) << __func__ << " index " << interface_index
114                 << " family " << IPAddress::GetAddressFamilyName(family);
115
116  base::hash_map<int, vector<RoutingTableEntry> >::iterator table =
117    tables_.find(interface_index);
118
119  if (table == tables_.end()) {
120    SLOG(Route, 2) << __func__ << " no table";
121    return false;
122  }
123
124  vector<RoutingTableEntry>::iterator nent;
125
126  for (nent = table->second.begin(); nent != table->second.end(); ++nent) {
127    if (nent->dst.IsDefault() && nent->dst.family() == family) {
128      *entry = &(*nent);
129      SLOG(Route, 2) << __func__ << ": found"
130                     << " gateway " << nent->gateway.ToString()
131                     << " metric " << nent->metric;
132      return true;
133    }
134  }
135
136  SLOG(Route, 2) << __func__ << " no route";
137  return false;
138}
139
140bool RoutingTable::SetDefaultRoute(int interface_index,
141                                   const IPAddress &gateway_address,
142                                   uint32 metric) {
143  SLOG(Route, 2) << __func__ << " index " << interface_index
144                 << " metric " << metric;
145
146  RoutingTableEntry *old_entry;
147
148  if (GetDefaultRouteInternal(interface_index,
149                              gateway_address.family(),
150                              &old_entry)) {
151    if (old_entry->gateway.Equals(gateway_address)) {
152      if (old_entry->metric != metric) {
153        ReplaceMetric(interface_index, old_entry, metric);
154      }
155      return true;
156    } else {
157      // TODO(quiche): Update internal state as well?
158      ApplyRoute(interface_index,
159                 *old_entry,
160                 RTNLMessage::kModeDelete,
161                 0);
162    }
163  }
164
165  IPAddress default_address(gateway_address.family());
166  default_address.SetAddressToDefault();
167
168  return AddRoute(interface_index,
169                  RoutingTableEntry(default_address,
170                                    default_address,
171                                    gateway_address,
172                                    metric,
173                                    RT_SCOPE_UNIVERSE,
174                                    false));
175}
176
177bool RoutingTable::ConfigureRoutes(int interface_index,
178                                   const IPConfigRefPtr &ipconfig,
179                                   uint32 metric) {
180  bool ret = true;
181
182  IPAddress::Family address_family = ipconfig->properties().address_family;
183  const vector<IPConfig::Route> &routes = ipconfig->properties().routes;
184
185  for (vector<IPConfig::Route>::const_iterator it = routes.begin();
186       it != routes.end();
187       ++it) {
188    SLOG(Route, 3) << "Installing route:"
189                   << " Destination: " << it->host
190                   << " Netmask: " << it->netmask
191                   << " Gateway: " << it->gateway;
192    IPAddress destination_address(address_family);
193    IPAddress source_address(address_family);  // Left as default.
194    IPAddress gateway_address(address_family);
195    if (!destination_address.SetAddressFromString(it->host)) {
196      LOG(ERROR) << "Failed to parse host "
197                 << it->host;
198      ret = false;
199      continue;
200    }
201    if (!gateway_address.SetAddressFromString(it->gateway)) {
202      LOG(ERROR) << "Failed to parse gateway "
203                 << it->gateway;
204      ret = false;
205      continue;
206    }
207    destination_address.set_prefix(
208        IPAddress::GetPrefixLengthFromMask(address_family, it->netmask));
209    if (!AddRoute(interface_index,
210                  RoutingTableEntry(destination_address,
211                                    source_address,
212                                    gateway_address,
213                                    metric,
214                                    RT_SCOPE_UNIVERSE,
215                                    false))) {
216      ret = false;
217    }
218  }
219  return ret;
220}
221
222void RoutingTable::FlushRoutes(int interface_index) {
223  SLOG(Route, 2) << __func__;
224
225  base::hash_map<int, vector<RoutingTableEntry> >::iterator table =
226    tables_.find(interface_index);
227
228  if (table == tables_.end()) {
229    return;
230  }
231
232  vector<RoutingTableEntry>::iterator nent;
233
234  for (nent = table->second.begin(); nent != table->second.end(); ++nent) {
235    ApplyRoute(interface_index, *nent, RTNLMessage::kModeDelete, 0);
236  }
237  table->second.clear();
238}
239
240void RoutingTable::FlushRoutesWithTag(int tag) {
241  SLOG(Route, 2) << __func__;
242
243  base::hash_map<int, vector<RoutingTableEntry> >::iterator table;
244  for (table = tables_.begin(); table != tables_.end(); ++table) {
245    vector<RoutingTableEntry>::iterator nent;
246
247    for (nent = table->second.begin(); nent != table->second.end();) {
248      if (nent->tag == tag) {
249        ApplyRoute(table->first, *nent, RTNLMessage::kModeDelete, 0);
250        nent = table->second.erase(nent);
251      } else {
252        ++nent;
253      }
254    }
255  }
256}
257
258void RoutingTable::ResetTable(int interface_index) {
259  tables_.erase(interface_index);
260}
261
262void RoutingTable::SetDefaultMetric(int interface_index, uint32 metric) {
263  SLOG(Route, 2) << __func__ << " index " << interface_index
264                 << " metric " << metric;
265
266  RoutingTableEntry *entry;
267  if (GetDefaultRouteInternal(
268          interface_index, IPAddress::kFamilyIPv4, &entry) &&
269      entry->metric != metric) {
270    ReplaceMetric(interface_index, entry, metric);
271  }
272
273  if (GetDefaultRouteInternal(
274          interface_index, IPAddress::kFamilyIPv6, &entry) &&
275      entry->metric != metric) {
276    ReplaceMetric(interface_index, entry, metric);
277  }
278}
279
280// static
281bool RoutingTable::ParseRoutingTableMessage(const RTNLMessage &message,
282                                            int *interface_index,
283                                            RoutingTableEntry *entry) {
284  if (message.type() != RTNLMessage::kTypeRoute ||
285      message.family() == IPAddress::kFamilyUnknown ||
286      !message.HasAttribute(RTA_OIF)) {
287    return false;
288  }
289
290  const RTNLMessage::RouteStatus &route_status = message.route_status();
291
292  if (route_status.type != RTN_UNICAST ||
293      route_status.table != RT_TABLE_MAIN) {
294    return false;
295  }
296
297  uint32 interface_index_u32 = 0;
298  if (!message.GetAttribute(RTA_OIF).ConvertToCPUUInt32(&interface_index_u32)) {
299    return false;
300  }
301  *interface_index = interface_index_u32;
302
303  uint32 metric = 0;
304  if (message.HasAttribute(RTA_PRIORITY)) {
305    message.GetAttribute(RTA_PRIORITY).ConvertToCPUUInt32(&metric);
306  }
307
308  IPAddress default_addr(message.family());
309  default_addr.SetAddressToDefault();
310
311  ByteString dst_bytes(default_addr.address());
312  if (message.HasAttribute(RTA_DST)) {
313    dst_bytes = message.GetAttribute(RTA_DST);
314  }
315  ByteString src_bytes(default_addr.address());
316  if (message.HasAttribute(RTA_SRC)) {
317    src_bytes = message.GetAttribute(RTA_SRC);
318  }
319  ByteString gateway_bytes(default_addr.address());
320  if (message.HasAttribute(RTA_GATEWAY)) {
321    gateway_bytes = message.GetAttribute(RTA_GATEWAY);
322  }
323
324  entry->dst = IPAddress(message.family(), dst_bytes, route_status.dst_prefix);
325  entry->src = IPAddress(message.family(), src_bytes, route_status.src_prefix);
326  entry->gateway = IPAddress(message.family(), gateway_bytes);
327  entry->metric = metric;
328  entry->scope = route_status.scope;
329  entry->from_rtnl = true;
330
331  return true;
332}
333
334void RoutingTable::RouteMsgHandler(const RTNLMessage &message) {
335  int interface_index;
336  RoutingTableEntry entry;
337
338  if (!ParseRoutingTableMessage(message, &interface_index, &entry)) {
339    return;
340  }
341
342  if (!route_queries_.empty() &&
343      message.route_status().protocol == RTPROT_UNSPEC) {
344    SLOG(Route, 3) << __func__ << ": Message seq: " << message.seq()
345                   << " mode " << message.mode()
346                   << ", next query seq: " << route_queries_.front().sequence;
347
348    // Purge queries that have expired (sequence number of this message is
349    // greater than that of the head of the route query sequence).  Do the
350    // math in a way that's roll-over independent.
351    while (route_queries_.front().sequence - message.seq() > kuint32max / 2) {
352      LOG(ERROR) << __func__ << ": Purging un-replied route request sequence "
353                 << route_queries_.front().sequence
354                 << " (< " << message.seq() << ")";
355      route_queries_.pop_front();
356      if (route_queries_.empty())
357        return;
358    }
359
360    const Query &query = route_queries_.front();
361    if (query.sequence == message.seq()) {
362      RoutingTableEntry add_entry(entry);
363      add_entry.from_rtnl = false;
364      add_entry.tag = query.tag;
365      bool added = true;
366      if (add_entry.gateway.IsDefault()) {
367        SLOG(Route, 2) << __func__ << ": Ignoring route result with no gateway "
368                       << "since we don't need to plumb these.";
369      } else {
370        SLOG(Route, 2) << __func__ << ": Adding host route to "
371                       << add_entry.dst.ToString();
372        added = AddRoute(interface_index, add_entry);
373      }
374      if (added && !query.callback.is_null()) {
375        SLOG(Route, 2) << "Running query callback.";
376        query.callback.Run(interface_index, add_entry);
377      }
378      route_queries_.pop_front();
379    }
380    return;
381  } else if (message.route_status().protocol != RTPROT_BOOT) {
382    // Responses to route queries come back with a protocol of
383    // RTPROT_UNSPEC.  Otherwise, normal route updates that we are
384    // interested in come with a protocol of RTPROT_BOOT.
385    return;
386  }
387
388  vector<RoutingTableEntry> &table = tables_[interface_index];
389  vector<RoutingTableEntry>::iterator nent;
390  for (nent = table.begin(); nent != table.end(); ++nent) {
391    if (nent->dst.Equals(entry.dst) &&
392        nent->src.Equals(entry.src) &&
393        nent->gateway.Equals(entry.gateway) &&
394        nent->scope == entry.scope) {
395      if (message.mode() == RTNLMessage::kModeDelete &&
396          nent->metric == entry.metric) {
397        table.erase(nent);
398      } else if (message.mode() == RTNLMessage::kModeAdd) {
399        nent->from_rtnl = true;
400        nent->metric = entry.metric;
401      }
402      return;
403    }
404  }
405
406  if (message.mode() == RTNLMessage::kModeAdd) {
407    SLOG(Route, 2) << __func__ << " adding"
408                   << " destination " << entry.dst.ToString()
409                   << " index " << interface_index
410                   << " gateway " << entry.gateway.ToString()
411                   << " metric " << entry.metric;
412    table.push_back(entry);
413  }
414}
415
416bool RoutingTable::ApplyRoute(uint32 interface_index,
417                              const RoutingTableEntry &entry,
418                              RTNLMessage::Mode mode,
419                              unsigned int flags) {
420  SLOG(Route, 2) << base::StringPrintf(
421      "%s: dst %s/%d src %s/%d index %d mode %d flags 0x%x",
422      __func__, entry.dst.ToString().c_str(), entry.dst.prefix(),
423      entry.src.ToString().c_str(), entry.src.prefix(),
424      interface_index, mode, flags);
425
426  RTNLMessage message(
427      RTNLMessage::kTypeRoute,
428      mode,
429      NLM_F_REQUEST | flags,
430      0,
431      0,
432      0,
433      entry.dst.family());
434
435  message.set_route_status(RTNLMessage::RouteStatus(
436      entry.dst.prefix(),
437      entry.src.prefix(),
438      RT_TABLE_MAIN,
439      RTPROT_BOOT,
440      entry.scope,
441      RTN_UNICAST,
442      0));
443
444  message.SetAttribute(RTA_DST, entry.dst.address());
445  if (!entry.src.IsDefault()) {
446    message.SetAttribute(RTA_SRC, entry.src.address());
447  }
448  if (!entry.gateway.IsDefault()) {
449    message.SetAttribute(RTA_GATEWAY, entry.gateway.address());
450  }
451  message.SetAttribute(RTA_PRIORITY,
452                       ByteString::CreateFromCPUUInt32(entry.metric));
453  message.SetAttribute(RTA_OIF,
454                       ByteString::CreateFromCPUUInt32(interface_index));
455
456  return rtnl_handler_->SendMessage(&message);
457}
458
459// Somewhat surprisingly, the kernel allows you to create multiple routes
460// to the same destination through the same interface with different metrics.
461// Therefore, to change the metric on a route, we can't just use the
462// NLM_F_REPLACE flag by itself.  We have to explicitly remove the old route.
463// We do so after creating the route at a new metric so there is no traffic
464// disruption to existing network streams.
465void RoutingTable::ReplaceMetric(uint32 interface_index,
466                                 RoutingTableEntry *entry,
467                                 uint32 metric) {
468  SLOG(Route, 2) << __func__ << " index " << interface_index
469                 << " metric " << metric;
470  RoutingTableEntry new_entry = *entry;
471  new_entry.metric = metric;
472  // First create the route at the new metric.
473  ApplyRoute(interface_index, new_entry, RTNLMessage::kModeAdd,
474             NLM_F_CREATE | NLM_F_REPLACE);
475  // Then delete the route at the old metric.
476  ApplyRoute(interface_index, *entry, RTNLMessage::kModeDelete, 0);
477  // Now, update our routing table (via |*entry|) from |new_entry|.
478  *entry = new_entry;
479}
480
481bool RoutingTable::FlushCache() {
482  static const char *kPaths[2] = { kRouteFlushPath4, kRouteFlushPath6 };
483  bool ret = true;
484
485  SLOG(Route, 2) << __func__;
486
487  for (size_t i = 0; i < arraysize(kPaths); ++i) {
488    if (file_util::WriteFile(FilePath(kPaths[i]), "-1", 2) != 2) {
489      LOG(ERROR) << base::StringPrintf("Cannot write to route flush file %s",
490                                       kPaths[i]);
491      ret = false;
492    }
493  }
494
495  return ret;
496}
497
498bool RoutingTable::RequestRouteToHost(const IPAddress &address,
499                                      int interface_index,
500                                      int tag,
501                                      const Query::Callback &callback) {
502  // Make sure we don't get a cached response that is no longer valid.
503  FlushCache();
504
505  RTNLMessage message(
506      RTNLMessage::kTypeRoute,
507      RTNLMessage::kModeQuery,
508      NLM_F_REQUEST,
509      0,
510      0,
511      interface_index,
512      address.family());
513
514  RTNLMessage::RouteStatus status;
515  status.dst_prefix = address.prefix();
516  message.set_route_status(status);
517  message.SetAttribute(RTA_DST, address.address());
518
519  if (interface_index != -1) {
520    message.SetAttribute(RTA_OIF,
521                         ByteString::CreateFromCPUUInt32(interface_index));
522  }
523
524  if (!rtnl_handler_->SendMessage(&message)) {
525    return false;
526  }
527
528  // Save the sequence number of the request so we can create a route for
529  // this host when we get a reply.
530  route_queries_.push_back(Query(message.seq(), tag, callback));
531
532  return true;
533}
534
535bool RoutingTable::CreateBlackholeRoute(int interface_index,
536                                        IPAddress::Family family,
537                                        uint32 metric) {
538  SLOG(Route, 2) << base::StringPrintf(
539      "%s: index %d family %s metric %d",
540      __func__, interface_index,
541      IPAddress::GetAddressFamilyName(family).c_str(), metric);
542
543  RTNLMessage message(
544      RTNLMessage::kTypeRoute,
545      RTNLMessage::kModeAdd,
546      NLM_F_REQUEST | NLM_F_CREATE | NLM_F_EXCL,
547      0,
548      0,
549      0,
550      family);
551
552  message.set_route_status(RTNLMessage::RouteStatus(
553      0,
554      0,
555      RT_TABLE_MAIN,
556      RTPROT_BOOT,
557      RT_SCOPE_UNIVERSE,
558      RTN_BLACKHOLE,
559      0));
560
561  message.SetAttribute(RTA_PRIORITY,
562                       ByteString::CreateFromCPUUInt32(metric));
563  message.SetAttribute(RTA_OIF,
564                       ByteString::CreateFromCPUUInt32(interface_index));
565
566  return rtnl_handler_->SendMessage(&message);
567}
568
569bool RoutingTable::CreateLinkRoute(int interface_index,
570                                   const IPAddress &local_address,
571                                   const IPAddress &remote_address) {
572  if (!local_address.CanReachAddress(remote_address)) {
573    LOG(ERROR) << __func__ << " failed: "
574               << remote_address.ToString() << " is not reachable from "
575               << local_address.ToString();
576    return false;
577  }
578
579  IPAddress default_address(local_address.family());
580  default_address.SetAddressToDefault();
581  IPAddress destination_address(remote_address);
582  destination_address.set_prefix(
583      IPAddress::GetMaxPrefixLength(remote_address.family()));
584  SLOG(Route, 2) << "Creating link route to " << destination_address.ToString()
585                 << " from " << local_address.ToString()
586                 << " on interface index " << interface_index;
587  return AddRoute(interface_index,
588                  RoutingTableEntry(destination_address,
589                                    local_address,
590                                    default_address,
591                                    0,
592                                    RT_SCOPE_LINK,
593                                    false));
594}
595
596}  // namespace shill
597