netlink_message.cc revision e67a78539a05ea7fc68ed5ca18f6d1de333a3086
1// Copyright (c) 2012 The Chromium OS Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "shill/net/netlink_message.h" 6 7#include <limits.h> 8#include <netlink/msg.h> 9#include <netlink/netlink.h> 10 11#include <algorithm> 12#include <map> 13#include <memory> 14#include <string> 15 16#include <base/format_macros.h> 17#include <base/logging.h> 18#include <base/stl_util.h> 19#include <base/strings/stringprintf.h> 20 21using base::StringAppendF; 22using base::StringPrintf; 23using std::map; 24using std::min; 25using std::string; 26 27namespace shill { 28 29const uint32_t NetlinkMessage::kBroadcastSequenceNumber = 0; 30const uint16_t NetlinkMessage::kIllegalMessageType = UINT16_MAX; 31 32// NetlinkMessage 33 34ByteString NetlinkMessage::EncodeHeader(uint32_t sequence_number) { 35 ByteString result; 36 if (message_type_ == kIllegalMessageType) { 37 LOG(ERROR) << "Message type not set"; 38 return result; 39 } 40 sequence_number_ = sequence_number; 41 if (sequence_number_ == kBroadcastSequenceNumber) { 42 LOG(ERROR) << "Couldn't get a legal sequence number"; 43 return result; 44 } 45 46 // Build netlink header. 47 nlmsghdr header; 48 size_t nlmsghdr_with_pad = NLMSG_ALIGN(sizeof(header)); 49 header.nlmsg_len = nlmsghdr_with_pad; 50 header.nlmsg_type = message_type_; 51 header.nlmsg_flags = NLM_F_REQUEST | flags_; 52 header.nlmsg_seq = sequence_number_; 53 header.nlmsg_pid = getpid(); 54 55 // Netlink header + pad. 56 result.Append(ByteString(reinterpret_cast<unsigned char*>(&header), 57 sizeof(header))); 58 result.Resize(nlmsghdr_with_pad); // Zero-fill pad space (if any). 59 return result; 60} 61 62bool NetlinkMessage::InitAndStripHeader(ByteString* input) { 63 if (!input) { 64 LOG(ERROR) << "NULL input"; 65 return false; 66 } 67 if (input->GetLength() < sizeof(nlmsghdr)) { 68 LOG(ERROR) << "Insufficient input to extract nlmsghdr"; 69 return false; 70 } 71 72 // Read the nlmsghdr. 73 nlmsghdr* header = reinterpret_cast<nlmsghdr*>(input->GetData()); 74 message_type_ = header->nlmsg_type; 75 flags_ = header->nlmsg_flags; 76 sequence_number_ = header->nlmsg_seq; 77 78 // Strip the nlmsghdr. 79 input->RemovePrefix(NLMSG_ALIGN(sizeof(struct nlmsghdr))); 80 return true; 81} 82 83bool NetlinkMessage::InitFromNlmsg(const nlmsghdr* const_msg, 84 NetlinkMessage::MessageContext context) { 85 if (!const_msg) { 86 LOG(ERROR) << "Null |const_msg| parameter"; 87 return false; 88 } 89 ByteString message(reinterpret_cast<const unsigned char*>(const_msg), 90 const_msg->nlmsg_len); 91 if (!InitAndStripHeader(&message)) { 92 return false; 93 } 94 return true; 95} 96 97// static 98void NetlinkMessage::PrintBytes(int log_level, const unsigned char* buf, 99 size_t num_bytes) { 100 VLOG(log_level) << "Netlink Message -- Examining Bytes"; 101 if (!buf) { 102 VLOG(log_level) << "<NULL Buffer>"; 103 return; 104 } 105 106 if (num_bytes >= sizeof(nlmsghdr)) { 107 const nlmsghdr* header = reinterpret_cast<const nlmsghdr*>(buf); 108 VLOG(log_level) << StringPrintf( 109 "len: %02x %02x %02x %02x = %u bytes", 110 buf[0], buf[1], buf[2], buf[3], header->nlmsg_len); 111 112 VLOG(log_level) << StringPrintf( 113 "type | flags: %02x %02x %02x %02x - type:%u flags:%s%s%s%s%s", 114 buf[4], buf[5], buf[6], buf[7], header->nlmsg_type, 115 ((header->nlmsg_flags & NLM_F_REQUEST) ? " REQUEST" : ""), 116 ((header->nlmsg_flags & NLM_F_MULTI) ? " MULTI" : ""), 117 ((header->nlmsg_flags & NLM_F_ACK) ? " ACK" : ""), 118 ((header->nlmsg_flags & NLM_F_ECHO) ? " ECHO" : ""), 119 ((header->nlmsg_flags & NLM_F_DUMP_INTR) ? " BAD-SEQ" : "")); 120 121 VLOG(log_level) << StringPrintf( 122 "sequence: %02x %02x %02x %02x = %u", 123 buf[8], buf[9], buf[10], buf[11], header->nlmsg_seq); 124 VLOG(log_level) << StringPrintf( 125 "pid: %02x %02x %02x %02x = %u", 126 buf[12], buf[13], buf[14], buf[15], header->nlmsg_pid); 127 buf += sizeof(nlmsghdr); 128 num_bytes -= sizeof(nlmsghdr); 129 } else { 130 VLOG(log_level) << "Not enough bytes (" << num_bytes 131 << ") for a complete nlmsghdr (requires " 132 << sizeof(nlmsghdr) << ")."; 133 } 134 135 while (num_bytes) { 136 string output; 137 size_t bytes_this_row = min(num_bytes, static_cast<size_t>(32)); 138 for (size_t i = 0; i < bytes_this_row; ++i) { 139 StringAppendF(&output, " %02x", *buf++); 140 } 141 VLOG(log_level) << output; 142 num_bytes -= bytes_this_row; 143 } 144} 145 146// ErrorAckMessage. 147 148const uint16_t ErrorAckMessage::kMessageType = NLMSG_ERROR; 149 150bool ErrorAckMessage::InitFromNlmsg(const nlmsghdr* const_msg, 151 NetlinkMessage::MessageContext context) { 152 if (!const_msg) { 153 LOG(ERROR) << "Null |const_msg| parameter"; 154 return false; 155 } 156 ByteString message(reinterpret_cast<const unsigned char*>(const_msg), 157 const_msg->nlmsg_len); 158 if (!InitAndStripHeader(&message)) { 159 return false; 160 } 161 162 // Get the error code from the payload. 163 error_ = *(reinterpret_cast<const uint32_t*>(message.GetConstData())); 164 return true; 165} 166 167ByteString ErrorAckMessage::Encode(uint32_t sequence_number) { 168 LOG(ERROR) << "We're not supposed to send errors or Acks to the kernel"; 169 return ByteString(); 170} 171 172string ErrorAckMessage::ToString() const { 173 string output; 174 if (error()) { 175 StringAppendF(&output, "NETLINK_ERROR 0x%" PRIx32 ": %s", 176 -error_, strerror(-error_)); 177 } else { 178 StringAppendF(&output, "ACK"); 179 } 180 return output; 181} 182 183void ErrorAckMessage::Print(int header_log_level, 184 int /*detail_log_level*/) const { 185 VLOG(header_log_level) << ToString(); 186} 187 188// NoopMessage. 189 190const uint16_t NoopMessage::kMessageType = NLMSG_NOOP; 191 192ByteString NoopMessage::Encode(uint32_t sequence_number) { 193 LOG(ERROR) << "We're not supposed to send NOOP to the kernel"; 194 return ByteString(); 195} 196 197void NoopMessage::Print(int header_log_level, int /*detail_log_level*/) const { 198 VLOG(header_log_level) << ToString(); 199} 200 201// DoneMessage. 202 203const uint16_t DoneMessage::kMessageType = NLMSG_DONE; 204 205ByteString DoneMessage::Encode(uint32_t sequence_number) { 206 return EncodeHeader(sequence_number); 207} 208 209void DoneMessage::Print(int header_log_level, int /*detail_log_level*/) const { 210 VLOG(header_log_level) << ToString(); 211} 212 213// OverrunMessage. 214 215const uint16_t OverrunMessage::kMessageType = NLMSG_OVERRUN; 216 217ByteString OverrunMessage::Encode(uint32_t sequence_number) { 218 LOG(ERROR) << "We're not supposed to send Overruns to the kernel"; 219 return ByteString(); 220} 221 222void OverrunMessage::Print(int header_log_level, 223 int /*detail_log_level*/) const { 224 VLOG(header_log_level) << ToString(); 225} 226 227// UnknownMessage. 228 229ByteString UnknownMessage::Encode(uint32_t sequence_number) { 230 LOG(ERROR) << "We're not supposed to send UNKNOWN messages to the kernel"; 231 return ByteString(); 232} 233 234void UnknownMessage::Print(int header_log_level, 235 int /*detail_log_level*/) const { 236 int total_bytes = message_body_.GetLength(); 237 const uint8_t* const_data = message_body_.GetConstData(); 238 239 string output = StringPrintf("%d bytes:", total_bytes); 240 for (int i = 0; i < total_bytes; ++i) { 241 StringAppendF(&output, " 0x%02x", const_data[i]); 242 } 243 VLOG(header_log_level) << output; 244} 245 246// 247// Factory class. 248// 249 250bool NetlinkMessageFactory::AddFactoryMethod(uint16_t message_type, 251 FactoryMethod factory) { 252 if (ContainsKey(factories_, message_type)) { 253 LOG(WARNING) << "Message type " << message_type << " already exists."; 254 return false; 255 } 256 if (message_type == NetlinkMessage::kIllegalMessageType) { 257 LOG(ERROR) << "Not installing factory for illegal message type."; 258 return false; 259 } 260 factories_[message_type] = factory; 261 return true; 262} 263 264NetlinkMessage* NetlinkMessageFactory::CreateMessage( 265 const nlmsghdr* const_msg, NetlinkMessage::MessageContext context) const { 266 if (!const_msg) { 267 LOG(ERROR) << "NULL |const_msg| parameter"; 268 return nullptr; 269 } 270 271 std::unique_ptr<NetlinkMessage> message; 272 273 if (const_msg->nlmsg_type == NoopMessage::kMessageType) { 274 message.reset(new NoopMessage()); 275 } else if (const_msg->nlmsg_type == DoneMessage::kMessageType) { 276 message.reset(new DoneMessage()); 277 } else if (const_msg->nlmsg_type == OverrunMessage::kMessageType) { 278 message.reset(new OverrunMessage()); 279 } else if (const_msg->nlmsg_type == ErrorAckMessage::kMessageType) { 280 message.reset(new ErrorAckMessage()); 281 } else if (ContainsKey(factories_, const_msg->nlmsg_type)) { 282 map<uint16_t, FactoryMethod>::const_iterator factory; 283 factory = factories_.find(const_msg->nlmsg_type); 284 message.reset(factory->second.Run(const_msg)); 285 } 286 287 // If no factory exists for this message _or_ if a factory exists but it 288 // failed, there'll be no message. Handle either of those cases, by 289 // creating an |UnknownMessage|. 290 if (!message) { 291 // Casting away constness since, while nlmsg_data doesn't change its 292 // parameter, it also doesn't declare its paramenter as const. 293 nlmsghdr* msg = const_cast<nlmsghdr*>(const_msg); 294 ByteString payload(reinterpret_cast<char*>(nlmsg_data(msg)), 295 nlmsg_datalen(msg)); 296 message.reset(new UnknownMessage(msg->nlmsg_type, payload)); 297 } 298 299 if (!message->InitFromNlmsg(const_msg, context)) { 300 LOG(ERROR) << "Message did not initialize properly"; 301 return nullptr; 302 } 303 304 return message.release(); 305} 306 307} // namespace shill. 308