routing_table.cc revision c1dec4d5cad7c6ee2cd8dbc4f47e4d30403dcca1
1// Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "shill/routing_table.h"
6
7#include <arpa/inet.h>
8#include <fcntl.h>
9#include <linux/netlink.h>
10#include <linux/rtnetlink.h>
11#include <netinet/ether.h>
12#include <net/if.h>
13#include <net/if_arp.h>
14#include <string.h>
15#include <sys/socket.h>
16#include <time.h>
17#include <unistd.h>
18
19#include <string>
20
21#include <base/callback_old.h>
22#include <base/file_path.h>
23#include <base/file_util.h>
24#include <base/hash_tables.h>
25#include <base/logging.h>
26#include <base/memory/scoped_ptr.h>
27#include <base/stl_util-inl.h>
28#include <base/stringprintf.h>
29
30#include "shill/byte_string.h"
31#include "shill/routing_table_entry.h"
32#include "shill/rtnl_handler.h"
33#include "shill/rtnl_listener.h"
34#include "shill/rtnl_message.h"
35
36using std::string;
37using std::vector;
38
39namespace shill {
40
41static base::LazyInstance<RoutingTable> g_routing_table(
42    base::LINKER_INITIALIZED);
43
44// static
45const char RoutingTable::kRouteFlushPath4[] = "/proc/sys/net/ipv4/route/flush";
46// static
47const char RoutingTable::kRouteFlushPath6[] = "/proc/sys/net/ipv6/route/flush";
48
49RoutingTable::RoutingTable()
50    : route_callback_(NewCallback(this, &RoutingTable::RouteMsgHandler)),
51      route_listener_(NULL) {
52  VLOG(2) << __func__;
53}
54
55RoutingTable::~RoutingTable() {}
56
57RoutingTable* RoutingTable::GetInstance() {
58  return g_routing_table.Pointer();
59}
60
61void RoutingTable::Start() {
62  VLOG(2) << __func__;
63
64  route_listener_.reset(
65      new RTNLListener(RTNLHandler::kRequestRoute, route_callback_.get()));
66  RTNLHandler::GetInstance()->RequestDump(
67      RTNLHandler::kRequestRoute);
68}
69
70void RoutingTable::Stop() {
71  VLOG(2) << __func__;
72
73  route_listener_.reset();
74}
75
76bool RoutingTable::AddRoute(int interface_index,
77                            const RoutingTableEntry &entry) {
78  VLOG(2) << __func__;
79
80  CHECK(!entry.from_rtnl);
81  if (!ApplyRoute(interface_index,
82                  entry,
83                  RTNLMessage::kModeAdd,
84                  NLM_F_CREATE | NLM_F_EXCL)) {
85    return false;
86  }
87  tables_[interface_index].push_back(entry);
88  return true;
89}
90
91bool RoutingTable::GetDefaultRoute(int interface_index,
92                                   IPAddress::Family family,
93                                   RoutingTableEntry *entry) {
94  VLOG(2) << __func__;
95
96  base::hash_map<int, vector<RoutingTableEntry> >::iterator table =
97    tables_.find(interface_index);
98
99  if (table == tables_.end()) {
100    return false;
101  }
102
103  vector<RoutingTableEntry>::iterator nent;
104
105  for (nent = table->second.begin(); nent != table->second.end(); ++nent) {
106    if (nent->dst.IsDefault() && nent->dst.family() == family) {
107      *entry = *nent;
108      return true;
109    }
110  }
111
112  return false;
113}
114
115bool RoutingTable::SetDefaultRoute(int interface_index,
116                                   const IPConfigRefPtr &ipconfig,
117                                   uint32 metric) {
118  const IPConfig::Properties &ipconfig_props = ipconfig->properties();
119  RoutingTableEntry old_entry;
120
121  VLOG(2) << __func__;
122
123  IPAddress gateway_address(ipconfig_props.address_family);
124  if (!gateway_address.SetAddressFromString(ipconfig_props.gateway)) {
125    return false;
126  }
127
128  if (GetDefaultRoute(interface_index,
129                      ipconfig_props.address_family,
130                      &old_entry)) {
131    if (old_entry.gateway.Equals(gateway_address)) {
132      if (old_entry.metric != metric) {
133        ReplaceMetric(interface_index, old_entry, metric);
134      }
135      return true;
136    } else {
137      ApplyRoute(interface_index,
138                 old_entry,
139                 RTNLMessage::kModeDelete,
140                 0);
141    }
142  }
143
144  IPAddress default_address(ipconfig_props.address_family);
145  default_address.SetAddressToDefault();
146
147  return AddRoute(interface_index,
148                  RoutingTableEntry(default_address,
149                                    default_address,
150                                    gateway_address,
151                                    metric,
152                                    RT_SCOPE_UNIVERSE,
153                                    false));
154}
155
156void RoutingTable::FlushRoutes(int interface_index) {
157  VLOG(2) << __func__;
158
159  base::hash_map<int, vector<RoutingTableEntry> >::iterator table =
160    tables_.find(interface_index);
161
162  if (table == tables_.end()) {
163    return;
164  }
165
166  vector<RoutingTableEntry>::iterator nent;
167
168  for (nent = table->second.begin(); nent != table->second.end(); ++nent) {
169    ApplyRoute(interface_index, *nent, RTNLMessage::kModeDelete, 0);
170  }
171}
172
173void RoutingTable::ResetTable(int interface_index) {
174  tables_.erase(interface_index);
175}
176
177void RoutingTable::SetDefaultMetric(int interface_index, uint32 metric) {
178  RoutingTableEntry entry;
179
180  VLOG(2) << __func__;
181
182  if (GetDefaultRoute(interface_index, IPAddress::kFamilyIPv4, &entry) &&
183      entry.metric != metric) {
184    ReplaceMetric(interface_index, entry, metric);
185  }
186
187  if (GetDefaultRoute(interface_index, IPAddress::kFamilyIPv6, &entry) &&
188      entry.metric != metric) {
189    ReplaceMetric(interface_index, entry, metric);
190  }
191}
192
193void RoutingTable::RouteMsgHandler(const RTNLMessage &msg) {
194  if (msg.type() != RTNLMessage::kTypeRoute ||
195      msg.family() == IPAddress::kFamilyUnknown ||
196      !msg.HasAttribute(RTA_OIF)) {
197    return;
198  }
199
200  const RTNLMessage::RouteStatus &route_status = msg.route_status();
201
202  if (route_status.type != RTN_UNICAST ||
203      route_status.protocol != RTPROT_BOOT ||
204      route_status.table != RT_TABLE_MAIN) {
205    return;
206  }
207
208  uint32 interface_index = 0;
209  if (!msg.GetAttribute(RTA_OIF).ConvertToCPUUInt32(&interface_index)) {
210    return;
211  }
212
213  uint32 metric = 0;
214  if (msg.HasAttribute(RTA_PRIORITY)) {
215    msg.GetAttribute(RTA_PRIORITY).ConvertToCPUUInt32(&metric);
216  }
217
218  IPAddress default_addr(msg.family());
219  default_addr.SetAddressToDefault();
220
221  ByteString dst_bytes(default_addr.address());
222  if (msg.HasAttribute(RTA_DST)) {
223    dst_bytes = msg.GetAttribute(RTA_DST);
224  }
225  ByteString src_bytes(default_addr.address());
226  if (msg.HasAttribute(RTA_SRC)) {
227    src_bytes = msg.GetAttribute(RTA_SRC);
228  }
229  ByteString gateway_bytes(default_addr.address());
230  if (msg.HasAttribute(RTA_GATEWAY)) {
231    gateway_bytes = msg.GetAttribute(RTA_GATEWAY);
232  }
233
234  RoutingTableEntry entry(
235      IPAddress(msg.family(), dst_bytes, route_status.dst_prefix),
236      IPAddress(msg.family(), src_bytes, route_status.src_prefix),
237      IPAddress(msg.family(), gateway_bytes),
238      metric,
239      route_status.scope,
240      true);
241
242  vector<RoutingTableEntry> &table = tables_[interface_index];
243  vector<RoutingTableEntry>::iterator nent;
244  for (nent = table.begin(); nent != table.end(); ++nent) {
245    if (nent->dst.Equals(entry.dst) &&
246        nent->src.Equals(entry.src) &&
247        nent->gateway.Equals(entry.gateway) &&
248        nent->scope == entry.scope) {
249      if (msg.mode() == RTNLMessage::kModeDelete &&
250          nent->metric == entry.metric) {
251        table.erase(nent);
252      } else if (msg.mode() == RTNLMessage::kModeAdd) {
253        nent->from_rtnl = true;
254        nent->metric = entry.metric;
255      }
256      return;
257    }
258  }
259
260  if (msg.mode() == RTNLMessage::kModeAdd) {
261    table.push_back(entry);
262  }
263}
264
265bool RoutingTable::ApplyRoute(uint32 interface_index,
266                              const RoutingTableEntry &entry,
267                              RTNLMessage::Mode mode,
268                              unsigned int flags) {
269  VLOG(2) << base::StringPrintf("%s: index %d mode %d flags 0x%x",
270                                __func__, interface_index, mode, flags);
271
272  RTNLMessage msg(
273      RTNLMessage::kTypeRoute,
274      mode,
275      NLM_F_REQUEST | flags,
276      0,
277      0,
278      0,
279      entry.dst.family());
280
281  msg.set_route_status(RTNLMessage::RouteStatus(
282      entry.dst.prefix(),
283      entry.src.prefix(),
284      RT_TABLE_MAIN,
285      RTPROT_BOOT,
286      entry.scope,
287      RTN_UNICAST,
288      0));
289
290  msg.SetAttribute(RTA_DST, entry.dst.address());
291  if (!entry.src.IsDefault()) {
292    msg.SetAttribute(RTA_SRC, entry.src.address());
293  }
294  if (!entry.gateway.IsDefault()) {
295    msg.SetAttribute(RTA_GATEWAY, entry.gateway.address());
296  }
297  msg.SetAttribute(RTA_PRIORITY, ByteString::CreateFromCPUUInt32(entry.metric));
298  msg.SetAttribute(RTA_OIF, ByteString::CreateFromCPUUInt32(interface_index));
299
300  return RTNLHandler::GetInstance()->SendMessage(&msg);
301}
302
303// Somewhat surprisingly, the kernel allows you to create multiple routes
304// to the same destination through the same interface with different metrics.
305// Therefore, to change the metric on a route, we can't just use the
306// NLM_F_REPLACE flag by itself.  We have to explicitly remove the old route.
307// We do so after creating the route at a new metric so there is no traffic
308// disruption to existing network streams.
309void RoutingTable::ReplaceMetric(uint32 interface_index,
310                                 const RoutingTableEntry &entry,
311                                 uint32 metric) {
312  RoutingTableEntry new_entry = entry;
313  new_entry.metric = metric;
314  // First create the route at the new metric.
315  ApplyRoute(interface_index, new_entry, RTNLMessage::kModeAdd,
316             NLM_F_CREATE | NLM_F_REPLACE);
317  // Then delete the route at the old metric.
318  ApplyRoute(interface_index, entry, RTNLMessage::kModeDelete, 0);
319}
320
321bool RoutingTable::FlushCache() {
322  static const char *kPaths[2] = { kRouteFlushPath4, kRouteFlushPath6 };
323  bool ret = true;
324
325  VLOG(2) << __func__;
326
327  for (size_t i = 0; i < arraysize(kPaths); ++i) {
328    if (file_util::WriteFile(FilePath(kPaths[i]), "-1", 2) != 2) {
329      LOG(ERROR) << base::StringPrintf("Cannot write to route flush file %s",
330                                       kPaths[i]);
331      ret = false;
332    }
333  }
334
335  return ret;
336}
337
338}  // namespace shill
339