routing_table.cc revision bf14e94cbd47d6320eec846f1ca4def026840e14
1// Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "shill/routing_table.h"
6
7#include <arpa/inet.h>
8#include <fcntl.h>
9#include <linux/netlink.h>
10#include <linux/rtnetlink.h>
11#include <netinet/ether.h>
12#include <net/if.h>
13#include <net/if_arp.h>
14#include <string.h>
15#include <sys/socket.h>
16#include <time.h>
17#include <unistd.h>
18
19#include <string>
20
21#include <base/callback_old.h>
22#include <base/file_path.h>
23#include <base/file_util.h>
24#include <base/hash_tables.h>
25#include <base/logging.h>
26#include <base/memory/scoped_ptr.h>
27#include <base/stl_util-inl.h>
28#include <base/stringprintf.h>
29
30#include "shill/byte_string.h"
31#include "shill/routing_table_entry.h"
32#include "shill/rtnl_handler.h"
33#include "shill/rtnl_listener.h"
34#include "shill/rtnl_message.h"
35
36using std::string;
37using std::vector;
38
39namespace shill {
40
// Process-wide RoutingTable instance handed out by RoutingTable::GetInstance().
// base::LINKER_INITIALIZED selects LazyInstance's linker-initialized form (no
// static constructor); the RoutingTable itself is built on the first
// Pointer() call.
static base::LazyInstance<RoutingTable> g_routing_table(
    base::LINKER_INITIALIZED);

// procfs files that FlushCache() writes to in order to flush the kernel's
// IPv4 and IPv6 routing caches.
// static
const char RoutingTable::kRouteFlushPath4[] = "/proc/sys/net/ipv4/route/flush";
// static
const char RoutingTable::kRouteFlushPath6[] = "/proc/sys/net/ipv6/route/flush";
48
49RoutingTable::RoutingTable()
50    : route_callback_(NewCallback(this, &RoutingTable::RouteMsgHandler)),
51      route_listener_(NULL) {
52  VLOG(2) << __func__;
53}
54
55RoutingTable::~RoutingTable() {}
56
57RoutingTable* RoutingTable::GetInstance() {
58  return g_routing_table.Pointer();
59}
60
61void RoutingTable::Start() {
62  VLOG(2) << __func__;
63
64  route_listener_.reset(
65      new RTNLListener(RTNLHandler::kRequestRoute, route_callback_.get()));
66  RTNLHandler::GetInstance()->RequestDump(
67      RTNLHandler::kRequestRoute);
68}
69
70void RoutingTable::Stop() {
71  VLOG(2) << __func__;
72
73  route_listener_.reset();
74}
75
76bool RoutingTable::AddRoute(int interface_index,
77                            const RoutingTableEntry &entry) {
78  VLOG(2) << __func__ << " "
79          << "index " << interface_index << " "
80          << "gateway " << entry.gateway.ToString() << " "
81          << "metric " << entry.metric;
82
83  CHECK(!entry.from_rtnl);
84  if (!ApplyRoute(interface_index,
85                  entry,
86                  RTNLMessage::kModeAdd,
87                  NLM_F_CREATE | NLM_F_EXCL)) {
88    return false;
89  }
90  tables_[interface_index].push_back(entry);
91  return true;
92}
93
94bool RoutingTable::GetDefaultRoute(int interface_index,
95                                   IPAddress::Family family,
96                                   RoutingTableEntry *entry) {
97  RoutingTableEntry *found_entry;
98  bool ret = GetDefaultRouteInternal(interface_index, family, &found_entry);
99  if (ret) {
100    *entry = *found_entry;
101  }
102  return ret;
103}
104
105bool RoutingTable::GetDefaultRouteInternal(int interface_index,
106                                           IPAddress::Family family,
107                                           RoutingTableEntry **entry) {
108  VLOG(2) << __func__ << " index " << interface_index << " family " << family;
109
110  base::hash_map<int, vector<RoutingTableEntry> >::iterator table =
111    tables_.find(interface_index);
112
113  if (table == tables_.end()) {
114    VLOG(2) << __func__ << " no table";
115    return false;
116  }
117
118  vector<RoutingTableEntry>::iterator nent;
119
120  for (nent = table->second.begin(); nent != table->second.end(); ++nent) {
121    if (nent->dst.IsDefault() && nent->dst.family() == family) {
122      *entry = &(*nent);
123      VLOG(2) << __func__ << " found "
124              << "gateway " << nent->gateway.ToString() << " "
125              << "metric " << nent->metric;
126      return true;
127    }
128  }
129
130  VLOG(2) << __func__ << " no route";
131  return false;
132}
133
134bool RoutingTable::SetDefaultRoute(int interface_index,
135                                   const IPConfigRefPtr &ipconfig,
136                                   uint32 metric) {
137  VLOG(2) << __func__ << " index " << interface_index << " metric " << metric;
138
139  const IPConfig::Properties &ipconfig_props = ipconfig->properties();
140  RoutingTableEntry *old_entry;
141  IPAddress gateway_address(ipconfig_props.address_family);
142  if (!gateway_address.SetAddressFromString(ipconfig_props.gateway)) {
143    return false;
144  }
145
146  if (GetDefaultRouteInternal(interface_index,
147                              ipconfig_props.address_family,
148                              &old_entry)) {
149    if (old_entry->gateway.Equals(gateway_address)) {
150      if (old_entry->metric != metric) {
151        ReplaceMetric(interface_index, old_entry, metric);
152      }
153      return true;
154    } else {
155      // TODO(quiche): Update internal state as well?
156      ApplyRoute(interface_index,
157                 *old_entry,
158                 RTNLMessage::kModeDelete,
159                 0);
160    }
161  }
162
163  IPAddress default_address(ipconfig_props.address_family);
164  default_address.SetAddressToDefault();
165
166  return AddRoute(interface_index,
167                  RoutingTableEntry(default_address,
168                                    default_address,
169                                    gateway_address,
170                                    metric,
171                                    RT_SCOPE_UNIVERSE,
172                                    false));
173}
174
175void RoutingTable::FlushRoutes(int interface_index) {
176  VLOG(2) << __func__;
177
178  base::hash_map<int, vector<RoutingTableEntry> >::iterator table =
179    tables_.find(interface_index);
180
181  if (table == tables_.end()) {
182    return;
183  }
184
185  vector<RoutingTableEntry>::iterator nent;
186
187  for (nent = table->second.begin(); nent != table->second.end(); ++nent) {
188      ApplyRoute(interface_index, *nent, RTNLMessage::kModeDelete, 0);
189  }
190  table->second.clear();
191}
192
193void RoutingTable::ResetTable(int interface_index) {
194  tables_.erase(interface_index);
195}
196
197void RoutingTable::SetDefaultMetric(int interface_index, uint32 metric) {
198  VLOG(2) << __func__ << " "
199          << "index " << interface_index << " metric " << metric;
200
201  RoutingTableEntry *entry;
202  if (GetDefaultRouteInternal(
203          interface_index, IPAddress::kFamilyIPv4, &entry) &&
204      entry->metric != metric) {
205    ReplaceMetric(interface_index, entry, metric);
206  }
207
208  if (GetDefaultRouteInternal(
209          interface_index, IPAddress::kFamilyIPv6, &entry) &&
210      entry->metric != metric) {
211    ReplaceMetric(interface_index, entry, metric);
212  }
213}
214
215void RoutingTable::RouteMsgHandler(const RTNLMessage &msg) {
216  if (msg.type() != RTNLMessage::kTypeRoute ||
217      msg.family() == IPAddress::kFamilyUnknown ||
218      !msg.HasAttribute(RTA_OIF)) {
219    return;
220  }
221
222  const RTNLMessage::RouteStatus &route_status = msg.route_status();
223
224  if (route_status.type != RTN_UNICAST ||
225      route_status.protocol != RTPROT_BOOT ||
226      route_status.table != RT_TABLE_MAIN) {
227    return;
228  }
229
230  uint32 interface_index = 0;
231  if (!msg.GetAttribute(RTA_OIF).ConvertToCPUUInt32(&interface_index)) {
232    return;
233  }
234
235  uint32 metric = 0;
236  if (msg.HasAttribute(RTA_PRIORITY)) {
237    msg.GetAttribute(RTA_PRIORITY).ConvertToCPUUInt32(&metric);
238  }
239
240  IPAddress default_addr(msg.family());
241  default_addr.SetAddressToDefault();
242
243  ByteString dst_bytes(default_addr.address());
244  if (msg.HasAttribute(RTA_DST)) {
245    dst_bytes = msg.GetAttribute(RTA_DST);
246  }
247  ByteString src_bytes(default_addr.address());
248  if (msg.HasAttribute(RTA_SRC)) {
249    src_bytes = msg.GetAttribute(RTA_SRC);
250  }
251  ByteString gateway_bytes(default_addr.address());
252  if (msg.HasAttribute(RTA_GATEWAY)) {
253    gateway_bytes = msg.GetAttribute(RTA_GATEWAY);
254  }
255
256  RoutingTableEntry entry(
257      IPAddress(msg.family(), dst_bytes, route_status.dst_prefix),
258      IPAddress(msg.family(), src_bytes, route_status.src_prefix),
259      IPAddress(msg.family(), gateway_bytes),
260      metric,
261      route_status.scope,
262      true);
263
264  vector<RoutingTableEntry> &table = tables_[interface_index];
265  vector<RoutingTableEntry>::iterator nent;
266  for (nent = table.begin(); nent != table.end(); ++nent) {
267    if (nent->dst.Equals(entry.dst) &&
268        nent->src.Equals(entry.src) &&
269        nent->gateway.Equals(entry.gateway) &&
270        nent->scope == entry.scope) {
271      if (msg.mode() == RTNLMessage::kModeDelete &&
272          nent->metric == entry.metric) {
273        table.erase(nent);
274      } else if (msg.mode() == RTNLMessage::kModeAdd) {
275        nent->from_rtnl = true;
276        nent->metric = entry.metric;
277      }
278      return;
279    }
280  }
281
282  if (msg.mode() == RTNLMessage::kModeAdd) {
283    VLOG(2) << __func__ << " adding "
284            << "index " << interface_index
285            << "gateway " << entry.gateway.ToString() << " "
286            << "metric " << entry.metric;
287    table.push_back(entry);
288  }
289}
290
291bool RoutingTable::ApplyRoute(uint32 interface_index,
292                              const RoutingTableEntry &entry,
293                              RTNLMessage::Mode mode,
294                              unsigned int flags) {
295  VLOG(2) << base::StringPrintf("%s: dst %s index %d mode %d flags 0x%x",
296                                __func__, entry.dst.ToString().c_str(),
297                                interface_index, mode, flags);
298
299  RTNLMessage msg(
300      RTNLMessage::kTypeRoute,
301      mode,
302      NLM_F_REQUEST | flags,
303      0,
304      0,
305      0,
306      entry.dst.family());
307
308  msg.set_route_status(RTNLMessage::RouteStatus(
309      entry.dst.prefix(),
310      entry.src.prefix(),
311      RT_TABLE_MAIN,
312      RTPROT_BOOT,
313      entry.scope,
314      RTN_UNICAST,
315      0));
316
317  msg.SetAttribute(RTA_DST, entry.dst.address());
318  if (!entry.src.IsDefault()) {
319    msg.SetAttribute(RTA_SRC, entry.src.address());
320  }
321  if (!entry.gateway.IsDefault()) {
322    msg.SetAttribute(RTA_GATEWAY, entry.gateway.address());
323  }
324  msg.SetAttribute(RTA_PRIORITY, ByteString::CreateFromCPUUInt32(entry.metric));
325  msg.SetAttribute(RTA_OIF, ByteString::CreateFromCPUUInt32(interface_index));
326
327  return RTNLHandler::GetInstance()->SendMessage(&msg);
328}
329
330// Somewhat surprisingly, the kernel allows you to create multiple routes
331// to the same destination through the same interface with different metrics.
332// Therefore, to change the metric on a route, we can't just use the
333// NLM_F_REPLACE flag by itself.  We have to explicitly remove the old route.
334// We do so after creating the route at a new metric so there is no traffic
335// disruption to existing network streams.
336void RoutingTable::ReplaceMetric(uint32 interface_index,
337                                 RoutingTableEntry *entry,
338                                 uint32 metric) {
339  VLOG(2) << __func__ << " "
340          << "index " << interface_index << " metric " << metric;
341  RoutingTableEntry new_entry = *entry;
342  new_entry.metric = metric;
343  // First create the route at the new metric.
344  ApplyRoute(interface_index, new_entry, RTNLMessage::kModeAdd,
345             NLM_F_CREATE | NLM_F_REPLACE);
346  // Then delete the route at the old metric.
347  ApplyRoute(interface_index, *entry, RTNLMessage::kModeDelete, 0);
348  // Now, update our routing table (via |*entry|) from |new_entry|.
349  *entry = new_entry;
350}
351
352bool RoutingTable::FlushCache() {
353  static const char *kPaths[2] = { kRouteFlushPath4, kRouteFlushPath6 };
354  bool ret = true;
355
356  VLOG(2) << __func__;
357
358  for (size_t i = 0; i < arraysize(kPaths); ++i) {
359    if (file_util::WriteFile(FilePath(kPaths[i]), "-1", 2) != 2) {
360      LOG(ERROR) << base::StringPrintf("Cannot write to route flush file %s",
361                                       kPaths[i]);
362      ret = false;
363    }
364  }
365
366  return ret;
367}
368
369}  // namespace shill
370