netlink_message.cc revision e67a78539a05ea7fc68ed5ca18f6d1de333a3086
1// Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "shill/net/netlink_message.h"
6
7#include <limits.h>
8#include <netlink/msg.h>
9#include <netlink/netlink.h>
10
11#include <algorithm>
12#include <map>
13#include <memory>
14#include <string>
15
16#include <base/format_macros.h>
17#include <base/logging.h>
18#include <base/stl_util.h>
19#include <base/strings/stringprintf.h>
20
21using base::StringAppendF;
22using base::StringPrintf;
23using std::map;
24using std::min;
25using std::string;
26
27namespace shill {
28
29const uint32_t NetlinkMessage::kBroadcastSequenceNumber = 0;
30const uint16_t NetlinkMessage::kIllegalMessageType = UINT16_MAX;
31
32// NetlinkMessage
33
34ByteString NetlinkMessage::EncodeHeader(uint32_t sequence_number) {
35  ByteString result;
36  if (message_type_ == kIllegalMessageType) {
37    LOG(ERROR) << "Message type not set";
38    return result;
39  }
40  sequence_number_ = sequence_number;
41  if (sequence_number_ == kBroadcastSequenceNumber) {
42    LOG(ERROR) << "Couldn't get a legal sequence number";
43    return result;
44  }
45
46  // Build netlink header.
47  nlmsghdr header;
48  size_t nlmsghdr_with_pad = NLMSG_ALIGN(sizeof(header));
49  header.nlmsg_len = nlmsghdr_with_pad;
50  header.nlmsg_type = message_type_;
51  header.nlmsg_flags = NLM_F_REQUEST | flags_;
52  header.nlmsg_seq = sequence_number_;
53  header.nlmsg_pid = getpid();
54
55  // Netlink header + pad.
56  result.Append(ByteString(reinterpret_cast<unsigned char*>(&header),
57                           sizeof(header)));
58  result.Resize(nlmsghdr_with_pad);  // Zero-fill pad space (if any).
59  return result;
60}
61
62bool NetlinkMessage::InitAndStripHeader(ByteString* input) {
63  if (!input) {
64    LOG(ERROR) << "NULL input";
65    return false;
66  }
67  if (input->GetLength() < sizeof(nlmsghdr)) {
68    LOG(ERROR) << "Insufficient input to extract nlmsghdr";
69    return false;
70  }
71
72  // Read the nlmsghdr.
73  nlmsghdr* header = reinterpret_cast<nlmsghdr*>(input->GetData());
74  message_type_ = header->nlmsg_type;
75  flags_ = header->nlmsg_flags;
76  sequence_number_ = header->nlmsg_seq;
77
78  // Strip the nlmsghdr.
79  input->RemovePrefix(NLMSG_ALIGN(sizeof(struct nlmsghdr)));
80  return true;
81}
82
83bool NetlinkMessage::InitFromNlmsg(const nlmsghdr* const_msg,
84                                   NetlinkMessage::MessageContext context) {
85  if (!const_msg) {
86    LOG(ERROR) << "Null |const_msg| parameter";
87    return false;
88  }
89  ByteString message(reinterpret_cast<const unsigned char*>(const_msg),
90                     const_msg->nlmsg_len);
91  if (!InitAndStripHeader(&message)) {
92    return false;
93  }
94  return true;
95}
96
97// static
98void NetlinkMessage::PrintBytes(int log_level, const unsigned char* buf,
99                                size_t num_bytes) {
100  VLOG(log_level) << "Netlink Message -- Examining Bytes";
101  if (!buf) {
102    VLOG(log_level) << "<NULL Buffer>";
103    return;
104  }
105
106  if (num_bytes >= sizeof(nlmsghdr)) {
107      const nlmsghdr* header = reinterpret_cast<const nlmsghdr*>(buf);
108      VLOG(log_level) << StringPrintf(
109          "len:          %02x %02x %02x %02x = %u bytes",
110          buf[0], buf[1], buf[2], buf[3], header->nlmsg_len);
111
112      VLOG(log_level) << StringPrintf(
113          "type | flags: %02x %02x %02x %02x - type:%u flags:%s%s%s%s%s",
114          buf[4], buf[5], buf[6], buf[7], header->nlmsg_type,
115          ((header->nlmsg_flags & NLM_F_REQUEST) ? " REQUEST" : ""),
116          ((header->nlmsg_flags & NLM_F_MULTI) ? " MULTI" : ""),
117          ((header->nlmsg_flags & NLM_F_ACK) ? " ACK" : ""),
118          ((header->nlmsg_flags & NLM_F_ECHO) ? " ECHO" : ""),
119          ((header->nlmsg_flags & NLM_F_DUMP_INTR) ? " BAD-SEQ" : ""));
120
121      VLOG(log_level) << StringPrintf(
122          "sequence:     %02x %02x %02x %02x = %u",
123          buf[8], buf[9], buf[10], buf[11], header->nlmsg_seq);
124      VLOG(log_level) << StringPrintf(
125          "pid:          %02x %02x %02x %02x = %u",
126          buf[12], buf[13], buf[14], buf[15], header->nlmsg_pid);
127      buf += sizeof(nlmsghdr);
128      num_bytes -= sizeof(nlmsghdr);
129  } else {
130    VLOG(log_level) << "Not enough bytes (" << num_bytes
131                    << ") for a complete nlmsghdr (requires "
132                    << sizeof(nlmsghdr) << ").";
133  }
134
135  while (num_bytes) {
136    string output;
137    size_t bytes_this_row = min(num_bytes, static_cast<size_t>(32));
138    for (size_t i = 0; i < bytes_this_row; ++i) {
139      StringAppendF(&output, " %02x", *buf++);
140    }
141    VLOG(log_level) << output;
142    num_bytes -= bytes_this_row;
143  }
144}
145
146// ErrorAckMessage.
147
148const uint16_t ErrorAckMessage::kMessageType = NLMSG_ERROR;
149
150bool ErrorAckMessage::InitFromNlmsg(const nlmsghdr* const_msg,
151                                    NetlinkMessage::MessageContext context) {
152  if (!const_msg) {
153    LOG(ERROR) << "Null |const_msg| parameter";
154    return false;
155  }
156  ByteString message(reinterpret_cast<const unsigned char*>(const_msg),
157                     const_msg->nlmsg_len);
158  if (!InitAndStripHeader(&message)) {
159    return false;
160  }
161
162  // Get the error code from the payload.
163  error_ = *(reinterpret_cast<const uint32_t*>(message.GetConstData()));
164  return true;
165}
166
167ByteString ErrorAckMessage::Encode(uint32_t sequence_number) {
168  LOG(ERROR) << "We're not supposed to send errors or Acks to the kernel";
169  return ByteString();
170}
171
172string ErrorAckMessage::ToString() const {
173  string output;
174  if (error()) {
175    StringAppendF(&output, "NETLINK_ERROR 0x%" PRIx32 ": %s",
176                  -error_, strerror(-error_));
177  } else {
178    StringAppendF(&output, "ACK");
179  }
180  return output;
181}
182
183void ErrorAckMessage::Print(int header_log_level,
184                            int /*detail_log_level*/) const {
185  VLOG(header_log_level) << ToString();
186}
187
188// NoopMessage.
189
190const uint16_t NoopMessage::kMessageType = NLMSG_NOOP;
191
192ByteString NoopMessage::Encode(uint32_t sequence_number) {
193  LOG(ERROR) << "We're not supposed to send NOOP to the kernel";
194  return ByteString();
195}
196
197void NoopMessage::Print(int header_log_level, int /*detail_log_level*/) const {
198  VLOG(header_log_level) << ToString();
199}
200
201// DoneMessage.
202
203const uint16_t DoneMessage::kMessageType = NLMSG_DONE;
204
205ByteString DoneMessage::Encode(uint32_t sequence_number) {
206  return EncodeHeader(sequence_number);
207}
208
209void DoneMessage::Print(int header_log_level, int /*detail_log_level*/) const {
210  VLOG(header_log_level) << ToString();
211}
212
213// OverrunMessage.
214
215const uint16_t OverrunMessage::kMessageType = NLMSG_OVERRUN;
216
217ByteString OverrunMessage::Encode(uint32_t sequence_number) {
218  LOG(ERROR) << "We're not supposed to send Overruns to the kernel";
219  return ByteString();
220}
221
222void OverrunMessage::Print(int header_log_level,
223                           int /*detail_log_level*/) const {
224  VLOG(header_log_level) << ToString();
225}
226
227// UnknownMessage.
228
229ByteString UnknownMessage::Encode(uint32_t sequence_number) {
230  LOG(ERROR) << "We're not supposed to send UNKNOWN messages to the kernel";
231  return ByteString();
232}
233
234void UnknownMessage::Print(int header_log_level,
235                           int /*detail_log_level*/) const {
236  int total_bytes = message_body_.GetLength();
237  const uint8_t* const_data = message_body_.GetConstData();
238
239  string output = StringPrintf("%d bytes:", total_bytes);
240  for (int i = 0; i < total_bytes; ++i) {
241    StringAppendF(&output, " 0x%02x", const_data[i]);
242  }
243  VLOG(header_log_level) << output;
244}
245
246//
247// Factory class.
248//
249
250bool NetlinkMessageFactory::AddFactoryMethod(uint16_t message_type,
251                                             FactoryMethod factory) {
252  if (ContainsKey(factories_, message_type)) {
253    LOG(WARNING) << "Message type " << message_type << " already exists.";
254    return false;
255  }
256  if (message_type == NetlinkMessage::kIllegalMessageType) {
257    LOG(ERROR) << "Not installing factory for illegal message type.";
258    return false;
259  }
260  factories_[message_type] = factory;
261  return true;
262}
263
264NetlinkMessage* NetlinkMessageFactory::CreateMessage(
265    const nlmsghdr* const_msg, NetlinkMessage::MessageContext context) const {
266  if (!const_msg) {
267    LOG(ERROR) << "NULL |const_msg| parameter";
268    return nullptr;
269  }
270
271  std::unique_ptr<NetlinkMessage> message;
272
273  if (const_msg->nlmsg_type == NoopMessage::kMessageType) {
274    message.reset(new NoopMessage());
275  } else if (const_msg->nlmsg_type == DoneMessage::kMessageType) {
276    message.reset(new DoneMessage());
277  } else if (const_msg->nlmsg_type == OverrunMessage::kMessageType) {
278    message.reset(new OverrunMessage());
279  } else if (const_msg->nlmsg_type == ErrorAckMessage::kMessageType) {
280    message.reset(new ErrorAckMessage());
281  } else if (ContainsKey(factories_, const_msg->nlmsg_type)) {
282    map<uint16_t, FactoryMethod>::const_iterator factory;
283    factory = factories_.find(const_msg->nlmsg_type);
284    message.reset(factory->second.Run(const_msg));
285  }
286
287  // If no factory exists for this message _or_ if a factory exists but it
288  // failed, there'll be no message.  Handle either of those cases, by
289  // creating an |UnknownMessage|.
290  if (!message) {
291    // Casting away constness since, while nlmsg_data doesn't change its
292    // parameter, it also doesn't declare its paramenter as const.
293    nlmsghdr* msg = const_cast<nlmsghdr*>(const_msg);
294    ByteString payload(reinterpret_cast<char*>(nlmsg_data(msg)),
295                       nlmsg_datalen(msg));
296    message.reset(new UnknownMessage(msg->nlmsg_type, payload));
297  }
298
299  if (!message->InitFromNlmsg(const_msg, context)) {
300    LOG(ERROR) << "Message did not initialize properly";
301    return nullptr;
302  }
303
304  return message.release();
305}
306
307}  // namespace shill.
308