routing_table.cc revision abf6d289b2d29487f0a51b6138a127707a38507a
1// Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "shill/routing_table.h"
6
7#include <arpa/inet.h>
8#include <fcntl.h>
9#include <linux/netlink.h>
10#include <linux/rtnetlink.h>
11#include <netinet/ether.h>
12#include <net/if.h>
13#include <net/if_arp.h>
14#include <string.h>
15#include <sys/socket.h>
16#include <time.h>
17#include <unistd.h>
18
19#include <string>
20
21#include <base/bind.h>
22#include <base/file_path.h>
23#include <base/file_util.h>
24#include <base/hash_tables.h>
25#include <base/logging.h>
26#include <base/memory/scoped_ptr.h>
27#include <base/stl_util.h>
28#include <base/stringprintf.h>
29
30#include "shill/byte_string.h"
31#include "shill/routing_table_entry.h"
32#include "shill/rtnl_handler.h"
33#include "shill/rtnl_listener.h"
34#include "shill/rtnl_message.h"
35#include "shill/scope_logger.h"
36
37using base::Bind;
38using base::Unretained;
39using std::deque;
40using std::string;
41using std::vector;
42
43namespace shill {
44
45namespace {
46base::LazyInstance<RoutingTable> g_routing_table = LAZY_INSTANCE_INITIALIZER;
47}  // namespace
48
49// static
50const char RoutingTable::kRouteFlushPath4[] = "/proc/sys/net/ipv4/route/flush";
51// static
52const char RoutingTable::kRouteFlushPath6[] = "/proc/sys/net/ipv6/route/flush";
53
54RoutingTable::RoutingTable()
55    : route_callback_(Bind(&RoutingTable::RouteMsgHandler, Unretained(this))),
56      route_listener_(NULL),
57      rtnl_handler_(RTNLHandler::GetInstance()) {
58  SLOG(Route, 2) << __func__;
59}
60
61RoutingTable::~RoutingTable() {}
62
63RoutingTable* RoutingTable::GetInstance() {
64  return g_routing_table.Pointer();
65}
66
67void RoutingTable::Start() {
68  SLOG(Route, 2) << __func__;
69
70  route_listener_.reset(
71      new RTNLListener(RTNLHandler::kRequestRoute, route_callback_));
72  rtnl_handler_->RequestDump(RTNLHandler::kRequestRoute);
73}
74
75void RoutingTable::Stop() {
76  SLOG(Route, 2) << __func__;
77
78  route_listener_.reset();
79}
80
81bool RoutingTable::AddRoute(int interface_index,
82                            const RoutingTableEntry &entry) {
83  SLOG(Route, 2) << __func__ << ": "
84                 << "destination " << entry.dst.ToString()
85                 << " index " << interface_index
86                 << " gateway " << entry.gateway.ToString()
87                 << " metric " << entry.metric;
88
89  CHECK(!entry.from_rtnl);
90  if (!ApplyRoute(interface_index,
91                  entry,
92                  RTNLMessage::kModeAdd,
93                  NLM_F_CREATE | NLM_F_EXCL)) {
94    return false;
95  }
96  tables_[interface_index].push_back(entry);
97  return true;
98}
99
100bool RoutingTable::GetDefaultRoute(int interface_index,
101                                   IPAddress::Family family,
102                                   RoutingTableEntry *entry) {
103  RoutingTableEntry *found_entry;
104  bool ret = GetDefaultRouteInternal(interface_index, family, &found_entry);
105  if (ret) {
106    *entry = *found_entry;
107  }
108  return ret;
109}
110
111bool RoutingTable::GetDefaultRouteInternal(int interface_index,
112                                           IPAddress::Family family,
113                                           RoutingTableEntry **entry) {
114  SLOG(Route, 2) << __func__ << " index " << interface_index
115                 << " family " << IPAddress::GetAddressFamilyName(family);
116
117  base::hash_map<int, vector<RoutingTableEntry> >::iterator table =
118    tables_.find(interface_index);
119
120  if (table == tables_.end()) {
121    SLOG(Route, 2) << __func__ << " no table";
122    return false;
123  }
124
125  vector<RoutingTableEntry>::iterator nent;
126
127  for (nent = table->second.begin(); nent != table->second.end(); ++nent) {
128    if (nent->dst.IsDefault() && nent->dst.family() == family) {
129      *entry = &(*nent);
130      SLOG(Route, 2) << __func__ << ": found"
131                     << " gateway " << nent->gateway.ToString()
132                     << " metric " << nent->metric;
133      return true;
134    }
135  }
136
137  SLOG(Route, 2) << __func__ << " no route";
138  return false;
139}
140
141bool RoutingTable::SetDefaultRoute(int interface_index,
142                                   const IPAddress &gateway_address,
143                                   uint32 metric) {
144  SLOG(Route, 2) << __func__ << " index " << interface_index
145                 << " metric " << metric;
146
147  RoutingTableEntry *old_entry;
148
149  if (GetDefaultRouteInternal(interface_index,
150                              gateway_address.family(),
151                              &old_entry)) {
152    if (old_entry->gateway.Equals(gateway_address)) {
153      if (old_entry->metric != metric) {
154        ReplaceMetric(interface_index, old_entry, metric);
155      }
156      return true;
157    } else {
158      // TODO(quiche): Update internal state as well?
159      ApplyRoute(interface_index,
160                 *old_entry,
161                 RTNLMessage::kModeDelete,
162                 0);
163    }
164  }
165
166  IPAddress default_address(gateway_address.family());
167  default_address.SetAddressToDefault();
168
169  return AddRoute(interface_index,
170                  RoutingTableEntry(default_address,
171                                    default_address,
172                                    gateway_address,
173                                    metric,
174                                    RT_SCOPE_UNIVERSE,
175                                    false));
176}
177
178bool RoutingTable::ConfigureRoutes(int interface_index,
179                                   const IPConfigRefPtr &ipconfig,
180                                   uint32 metric) {
181  bool ret = true;
182
183  IPAddress::Family address_family = ipconfig->properties().address_family;
184  const vector<IPConfig::Route> &routes = ipconfig->properties().routes;
185
186  for (vector<IPConfig::Route>::const_iterator it = routes.begin();
187       it != routes.end();
188       ++it) {
189    SLOG(Route, 3) << "Installing route:"
190                   << " Destination: " << it->host
191                   << " Netmask: " << it->netmask
192                   << " Gateway: " << it->gateway;
193    IPAddress destination_address(address_family);
194    IPAddress source_address(address_family);  // Left as default.
195    IPAddress gateway_address(address_family);
196    if (!destination_address.SetAddressFromString(it->host)) {
197      LOG(ERROR) << "Failed to parse host "
198                 << it->host;
199      ret = false;
200      continue;
201    }
202    if (!gateway_address.SetAddressFromString(it->gateway)) {
203      LOG(ERROR) << "Failed to parse gateway "
204                 << it->gateway;
205      ret = false;
206      continue;
207    }
208    destination_address.set_prefix(
209        IPAddress::GetPrefixLengthFromMask(address_family, it->netmask));
210    if (!AddRoute(interface_index,
211                  RoutingTableEntry(destination_address,
212                                    source_address,
213                                    gateway_address,
214                                    metric,
215                                    RT_SCOPE_UNIVERSE,
216                                    false))) {
217      ret = false;
218    }
219  }
220  return ret;
221}
222
223void RoutingTable::FlushRoutes(int interface_index) {
224  SLOG(Route, 2) << __func__;
225
226  base::hash_map<int, vector<RoutingTableEntry> >::iterator table =
227    tables_.find(interface_index);
228
229  if (table == tables_.end()) {
230    return;
231  }
232
233  vector<RoutingTableEntry>::iterator nent;
234
235  for (nent = table->second.begin(); nent != table->second.end(); ++nent) {
236    ApplyRoute(interface_index, *nent, RTNLMessage::kModeDelete, 0);
237  }
238  table->second.clear();
239}
240
241void RoutingTable::FlushRoutesWithTag(int tag) {
242  SLOG(Route, 2) << __func__;
243
244  base::hash_map<int, vector<RoutingTableEntry> >::iterator table;
245  for (table = tables_.begin(); table != tables_.end(); ++table) {
246    vector<RoutingTableEntry>::iterator nent;
247
248    for (nent = table->second.begin(); nent != table->second.end();) {
249      if (nent->tag == tag) {
250        ApplyRoute(table->first, *nent, RTNLMessage::kModeDelete, 0);
251        nent = table->second.erase(nent);
252      } else {
253        ++nent;
254      }
255    }
256  }
257}
258
259void RoutingTable::ResetTable(int interface_index) {
260  tables_.erase(interface_index);
261}
262
263void RoutingTable::SetDefaultMetric(int interface_index, uint32 metric) {
264  SLOG(Route, 2) << __func__ << " index " << interface_index
265                 << " metric " << metric;
266
267  RoutingTableEntry *entry;
268  if (GetDefaultRouteInternal(
269          interface_index, IPAddress::kFamilyIPv4, &entry) &&
270      entry->metric != metric) {
271    ReplaceMetric(interface_index, entry, metric);
272  }
273
274  if (GetDefaultRouteInternal(
275          interface_index, IPAddress::kFamilyIPv6, &entry) &&
276      entry->metric != metric) {
277    ReplaceMetric(interface_index, entry, metric);
278  }
279}
280
281// static
282bool RoutingTable::ParseRoutingTableMessage(const RTNLMessage &message,
283                                            int *interface_index,
284                                            RoutingTableEntry *entry) {
285  if (message.type() != RTNLMessage::kTypeRoute ||
286      message.family() == IPAddress::kFamilyUnknown ||
287      !message.HasAttribute(RTA_OIF)) {
288    return false;
289  }
290
291  const RTNLMessage::RouteStatus &route_status = message.route_status();
292
293  if (route_status.type != RTN_UNICAST ||
294      route_status.table != RT_TABLE_MAIN) {
295    return false;
296  }
297
298  uint32 interface_index_u32 = 0;
299  if (!message.GetAttribute(RTA_OIF).ConvertToCPUUInt32(&interface_index_u32)) {
300    return false;
301  }
302  *interface_index = interface_index_u32;
303
304  uint32 metric = 0;
305  if (message.HasAttribute(RTA_PRIORITY)) {
306    message.GetAttribute(RTA_PRIORITY).ConvertToCPUUInt32(&metric);
307  }
308
309  IPAddress default_addr(message.family());
310  default_addr.SetAddressToDefault();
311
312  ByteString dst_bytes(default_addr.address());
313  if (message.HasAttribute(RTA_DST)) {
314    dst_bytes = message.GetAttribute(RTA_DST);
315  }
316  ByteString src_bytes(default_addr.address());
317  if (message.HasAttribute(RTA_SRC)) {
318    src_bytes = message.GetAttribute(RTA_SRC);
319  }
320  ByteString gateway_bytes(default_addr.address());
321  if (message.HasAttribute(RTA_GATEWAY)) {
322    gateway_bytes = message.GetAttribute(RTA_GATEWAY);
323  }
324
325  entry->dst = IPAddress(message.family(), dst_bytes, route_status.dst_prefix);
326  entry->src = IPAddress(message.family(), src_bytes, route_status.src_prefix);
327  entry->gateway = IPAddress(message.family(), gateway_bytes);
328  entry->metric = metric;
329  entry->scope = route_status.scope;
330  entry->from_rtnl = true;
331
332  return true;
333}
334
335void RoutingTable::RouteMsgHandler(const RTNLMessage &message) {
336  int interface_index;
337  RoutingTableEntry entry;
338
339  if (!ParseRoutingTableMessage(message, &interface_index, &entry)) {
340    return;
341  }
342
343  if (!route_queries_.empty() &&
344      message.route_status().protocol == RTPROT_UNSPEC) {
345    SLOG(Route, 3) << __func__ << ": Message seq: " << message.seq()
346                   << " mode " << message.mode()
347                   << ", next query seq: " << route_queries_.front().sequence;
348
349    // Purge queries that have expired (sequence number of this message is
350    // greater than that of the head of the route query sequence).  Do the
351    // math in a way that's roll-over independent.
352    while (route_queries_.front().sequence - message.seq() > kuint32max / 2) {
353      LOG(ERROR) << __func__ << ": Purging un-replied route request sequence "
354                 << route_queries_.front().sequence
355                 << " (< " << message.seq() << ")";
356      route_queries_.pop_front();
357      if (route_queries_.empty())
358        return;
359    }
360
361    const Query &query = route_queries_.front();
362    if (query.sequence == message.seq()) {
363      RoutingTableEntry add_entry(entry);
364      add_entry.from_rtnl = false;
365      add_entry.tag = query.tag;
366      bool added = true;
367      if (add_entry.gateway.IsDefault()) {
368        SLOG(Route, 2) << __func__ << ": Ignoring route result with no gateway "
369                       << "since we don't need to plumb these.";
370      } else {
371        SLOG(Route, 2) << __func__ << ": Adding host route to "
372                       << add_entry.dst.ToString();
373        added = AddRoute(interface_index, add_entry);
374      }
375      if (added && !query.callback.is_null()) {
376        SLOG(Route, 2) << "Running query callback.";
377        query.callback.Run(interface_index, add_entry);
378      }
379      route_queries_.pop_front();
380    }
381    return;
382  } else if (message.route_status().protocol != RTPROT_BOOT) {
383    // Responses to route queries come back with a protocol of
384    // RTPROT_UNSPEC.  Otherwise, normal route updates that we are
385    // interested in come with a protocol of RTPROT_BOOT.
386    return;
387  }
388
389  vector<RoutingTableEntry> &table = tables_[interface_index];
390  vector<RoutingTableEntry>::iterator nent;
391  for (nent = table.begin(); nent != table.end(); ++nent) {
392    if (nent->dst.Equals(entry.dst) &&
393        nent->src.Equals(entry.src) &&
394        nent->gateway.Equals(entry.gateway) &&
395        nent->scope == entry.scope) {
396      if (message.mode() == RTNLMessage::kModeDelete &&
397          nent->metric == entry.metric) {
398        table.erase(nent);
399      } else if (message.mode() == RTNLMessage::kModeAdd) {
400        nent->from_rtnl = true;
401        nent->metric = entry.metric;
402      }
403      return;
404    }
405  }
406
407  if (message.mode() == RTNLMessage::kModeAdd) {
408    SLOG(Route, 2) << __func__ << " adding"
409                   << " destination " << entry.dst.ToString()
410                   << " index " << interface_index
411                   << " gateway " << entry.gateway.ToString()
412                   << " metric " << entry.metric;
413    table.push_back(entry);
414  }
415}
416
417bool RoutingTable::ApplyRoute(uint32 interface_index,
418                              const RoutingTableEntry &entry,
419                              RTNLMessage::Mode mode,
420                              unsigned int flags) {
421  SLOG(Route, 2) << base::StringPrintf(
422      "%s: dst %s/%d src %s/%d index %d mode %d flags 0x%x",
423      __func__, entry.dst.ToString().c_str(), entry.dst.prefix(),
424      entry.src.ToString().c_str(), entry.src.prefix(),
425      interface_index, mode, flags);
426
427  RTNLMessage message(
428      RTNLMessage::kTypeRoute,
429      mode,
430      NLM_F_REQUEST | flags,
431      0,
432      0,
433      0,
434      entry.dst.family());
435
436  message.set_route_status(RTNLMessage::RouteStatus(
437      entry.dst.prefix(),
438      entry.src.prefix(),
439      RT_TABLE_MAIN,
440      RTPROT_BOOT,
441      entry.scope,
442      RTN_UNICAST,
443      0));
444
445  message.SetAttribute(RTA_DST, entry.dst.address());
446  if (!entry.src.IsDefault()) {
447    message.SetAttribute(RTA_SRC, entry.src.address());
448  }
449  if (!entry.gateway.IsDefault()) {
450    message.SetAttribute(RTA_GATEWAY, entry.gateway.address());
451  }
452  message.SetAttribute(RTA_PRIORITY,
453                       ByteString::CreateFromCPUUInt32(entry.metric));
454  message.SetAttribute(RTA_OIF,
455                       ByteString::CreateFromCPUUInt32(interface_index));
456
457  return rtnl_handler_->SendMessage(&message);
458}
459
460// Somewhat surprisingly, the kernel allows you to create multiple routes
461// to the same destination through the same interface with different metrics.
462// Therefore, to change the metric on a route, we can't just use the
463// NLM_F_REPLACE flag by itself.  We have to explicitly remove the old route.
464// We do so after creating the route at a new metric so there is no traffic
465// disruption to existing network streams.
466void RoutingTable::ReplaceMetric(uint32 interface_index,
467                                 RoutingTableEntry *entry,
468                                 uint32 metric) {
469  SLOG(Route, 2) << __func__ << " index " << interface_index
470                 << " metric " << metric;
471  RoutingTableEntry new_entry = *entry;
472  new_entry.metric = metric;
473  // First create the route at the new metric.
474  ApplyRoute(interface_index, new_entry, RTNLMessage::kModeAdd,
475             NLM_F_CREATE | NLM_F_REPLACE);
476  // Then delete the route at the old metric.
477  ApplyRoute(interface_index, *entry, RTNLMessage::kModeDelete, 0);
478  // Now, update our routing table (via |*entry|) from |new_entry|.
479  *entry = new_entry;
480}
481
482bool RoutingTable::FlushCache() {
483  static const char *kPaths[2] = { kRouteFlushPath4, kRouteFlushPath6 };
484  bool ret = true;
485
486  SLOG(Route, 2) << __func__;
487
488  for (size_t i = 0; i < arraysize(kPaths); ++i) {
489    if (file_util::WriteFile(FilePath(kPaths[i]), "-1", 2) != 2) {
490      LOG(ERROR) << base::StringPrintf("Cannot write to route flush file %s",
491                                       kPaths[i]);
492      ret = false;
493    }
494  }
495
496  return ret;
497}
498
499bool RoutingTable::RequestRouteToHost(const IPAddress &address,
500                                      int interface_index,
501                                      int tag,
502                                      const Query::Callback &callback) {
503  RTNLMessage message(
504      RTNLMessage::kTypeRoute,
505      RTNLMessage::kModeQuery,
506      NLM_F_REQUEST,
507      0,
508      0,
509      interface_index,
510      address.family());
511
512  RTNLMessage::RouteStatus status;
513  status.dst_prefix = address.prefix();
514  message.set_route_status(status);
515  message.SetAttribute(RTA_DST, address.address());
516
517  if (interface_index != -1) {
518    message.SetAttribute(RTA_OIF,
519                         ByteString::CreateFromCPUUInt32(interface_index));
520  }
521
522  if (!rtnl_handler_->SendMessage(&message)) {
523    return false;
524  }
525
526  // Save the sequence number of the request so we can create a route for
527  // this host when we get a reply.
528  route_queries_.push_back(Query(message.seq(), tag, callback));
529
530  return true;
531}
532
533
534void RoutingTable::CancelQueryCallback(const Query::Callback &callback) {
535  SLOG(Route, 2) << __func__;
536  for (deque<Query>::iterator it = route_queries_.begin();
537       it != route_queries_.end(); ++it) {
538    if (it->callback.Equals(callback)) {
539      it->callback.Reset();
540    }
541  }
542}
543
544}  // namespace shill
545