1// 2// Copyright (C) 2012 The Android Open Source Project 3// 4// Licensed under the Apache License, Version 2.0 (the "License"); 5// you may not use this file except in compliance with the License. 6// You may obtain a copy of the License at 7// 8// http://www.apache.org/licenses/LICENSE-2.0 9// 10// Unless required by applicable law or agreed to in writing, software 11// distributed under the License is distributed on an "AS IS" BASIS, 12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13// See the License for the specific language governing permissions and 14// limitations under the License. 15// 16 17#include "shill/net/netlink_message.h" 18 19#include <limits.h> 20 21#include <algorithm> 22#include <map> 23#include <memory> 24#include <string> 25 26#include <base/format_macros.h> 27#include <base/logging.h> 28#include <base/stl_util.h> 29#include <base/strings/stringprintf.h> 30 31#include "shill/net/netlink_packet.h" 32 33using base::StringAppendF; 34using base::StringPrintf; 35using std::map; 36using std::min; 37using std::string; 38 39namespace shill { 40 41const uint32_t NetlinkMessage::kBroadcastSequenceNumber = 0; 42const uint16_t NetlinkMessage::kIllegalMessageType = UINT16_MAX; 43 44// NetlinkMessage 45 46ByteString NetlinkMessage::EncodeHeader(uint32_t sequence_number) { 47 ByteString result; 48 if (message_type_ == kIllegalMessageType) { 49 LOG(ERROR) << "Message type not set"; 50 return result; 51 } 52 sequence_number_ = sequence_number; 53 if (sequence_number_ == kBroadcastSequenceNumber) { 54 LOG(ERROR) << "Couldn't get a legal sequence number"; 55 return result; 56 } 57 58 // Build netlink header. 59 nlmsghdr header; 60 size_t nlmsghdr_with_pad = NLMSG_ALIGN(sizeof(header)); 61 header.nlmsg_len = nlmsghdr_with_pad; 62 header.nlmsg_type = message_type_; 63 header.nlmsg_flags = NLM_F_REQUEST | flags_; 64 header.nlmsg_seq = sequence_number_; 65 header.nlmsg_pid = getpid(); 66 67 // Netlink header + pad. 68 result.Append(ByteString(reinterpret_cast<unsigned char*>(&header), 69 sizeof(header))); 70 result.Resize(nlmsghdr_with_pad); // Zero-fill pad space (if any). 71 return result; 72} 73 74bool NetlinkMessage::InitAndStripHeader(NetlinkPacket* packet) { 75 const nlmsghdr& header = packet->GetNlMsgHeader(); 76 message_type_ = header.nlmsg_type; 77 flags_ = header.nlmsg_flags; 78 sequence_number_ = header.nlmsg_seq; 79 80 return true; 81} 82 83bool NetlinkMessage::InitFromPacket(NetlinkPacket* packet, 84 NetlinkMessage::MessageContext context) { 85 if (!packet) { 86 LOG(ERROR) << "Null |packet| parameter"; 87 return false; 88 } 89 if (!InitAndStripHeader(packet)) { 90 return false; 91 } 92 return true; 93} 94 95// static 96void NetlinkMessage::PrintBytes(int log_level, const unsigned char* buf, 97 size_t num_bytes) { 98 VLOG(log_level) << "Netlink Message -- Examining Bytes"; 99 if (!buf) { 100 VLOG(log_level) << "<NULL Buffer>"; 101 return; 102 } 103 104 if (num_bytes >= sizeof(nlmsghdr)) { 105 PrintHeader(log_level, reinterpret_cast<const nlmsghdr*>(buf)); 106 buf += sizeof(nlmsghdr); 107 num_bytes -= sizeof(nlmsghdr); 108 } else { 109 VLOG(log_level) << "Not enough bytes (" << num_bytes 110 << ") for a complete nlmsghdr (requires " 111 << sizeof(nlmsghdr) << ")."; 112 } 113 114 PrintPayload(log_level, buf, num_bytes); 115} 116 117// static 118void NetlinkMessage::PrintPacket(int log_level, const NetlinkPacket& packet) { 119 VLOG(log_level) << "Netlink Message -- Examining Packet"; 120 if (!packet.IsValid()) { 121 VLOG(log_level) << "<Invalid Buffer>"; 122 return; 123 } 124 125 PrintHeader(log_level, &packet.GetNlMsgHeader()); 126 const ByteString& payload = packet.GetPayload(); 127 PrintPayload(log_level, payload.GetConstData(), payload.GetLength()); 128} 129 130// static 131void NetlinkMessage::PrintHeader(int log_level, const nlmsghdr* header) { 132 const unsigned char* buf = reinterpret_cast<const unsigned char*>(header); 133 VLOG(log_level) << StringPrintf( 134 "len: %02x %02x %02x %02x = %u bytes", 135 buf[0], buf[1], buf[2], buf[3], header->nlmsg_len); 136 137 VLOG(log_level) << StringPrintf( 138 "type | flags: %02x %02x %02x %02x - type:%u flags:%s%s%s%s%s", 139 buf[4], buf[5], buf[6], buf[7], header->nlmsg_type, 140 ((header->nlmsg_flags & NLM_F_REQUEST) ? " REQUEST" : ""), 141 ((header->nlmsg_flags & NLM_F_MULTI) ? " MULTI" : ""), 142 ((header->nlmsg_flags & NLM_F_ACK) ? " ACK" : ""), 143 ((header->nlmsg_flags & NLM_F_ECHO) ? " ECHO" : ""), 144 ((header->nlmsg_flags & NLM_F_DUMP_INTR) ? " BAD-SEQ" : "")); 145 146 VLOG(log_level) << StringPrintf( 147 "sequence: %02x %02x %02x %02x = %u", 148 buf[8], buf[9], buf[10], buf[11], header->nlmsg_seq); 149 VLOG(log_level) << StringPrintf( 150 "pid: %02x %02x %02x %02x = %u", 151 buf[12], buf[13], buf[14], buf[15], header->nlmsg_pid); 152} 153 154// static 155void NetlinkMessage::PrintPayload(int log_level, const unsigned char* buf, 156 size_t num_bytes) { 157 while (num_bytes) { 158 string output; 159 size_t bytes_this_row = min(num_bytes, static_cast<size_t>(32)); 160 for (size_t i = 0; i < bytes_this_row; ++i) { 161 StringAppendF(&output, " %02x", *buf++); 162 } 163 VLOG(log_level) << output; 164 num_bytes -= bytes_this_row; 165 } 166} 167 168// ErrorAckMessage. 169 170const uint16_t ErrorAckMessage::kMessageType = NLMSG_ERROR; 171 172bool ErrorAckMessage::InitFromPacket(NetlinkPacket* packet, 173 NetlinkMessage::MessageContext context) { 174 if (!packet) { 175 LOG(ERROR) << "Null |const_msg| parameter"; 176 return false; 177 } 178 if (!InitAndStripHeader(packet)) { 179 return false; 180 } 181 182 // Get the error code from the payload. 183 return packet->ConsumeData(sizeof(error_), &error_); 184} 185 186ByteString ErrorAckMessage::Encode(uint32_t sequence_number) { 187 LOG(ERROR) << "We're not supposed to send errors or Acks to the kernel"; 188 return ByteString(); 189} 190 191string ErrorAckMessage::ToString() const { 192 string output; 193 if (error()) { 194 StringAppendF(&output, "NETLINK_ERROR 0x%" PRIx32 ": %s", 195 -error_, strerror(-error_)); 196 } else { 197 StringAppendF(&output, "ACK"); 198 } 199 return output; 200} 201 202void ErrorAckMessage::Print(int header_log_level, 203 int /*detail_log_level*/) const { 204 VLOG(header_log_level) << ToString(); 205} 206 207// NoopMessage. 208 209const uint16_t NoopMessage::kMessageType = NLMSG_NOOP; 210 211ByteString NoopMessage::Encode(uint32_t sequence_number) { 212 LOG(ERROR) << "We're not supposed to send NOOP to the kernel"; 213 return ByteString(); 214} 215 216void NoopMessage::Print(int header_log_level, int /*detail_log_level*/) const { 217 VLOG(header_log_level) << ToString(); 218} 219 220// DoneMessage. 221 222const uint16_t DoneMessage::kMessageType = NLMSG_DONE; 223 224ByteString DoneMessage::Encode(uint32_t sequence_number) { 225 return EncodeHeader(sequence_number); 226} 227 228void DoneMessage::Print(int header_log_level, int /*detail_log_level*/) const { 229 VLOG(header_log_level) << ToString(); 230} 231 232// OverrunMessage. 233 234const uint16_t OverrunMessage::kMessageType = NLMSG_OVERRUN; 235 236ByteString OverrunMessage::Encode(uint32_t sequence_number) { 237 LOG(ERROR) << "We're not supposed to send Overruns to the kernel"; 238 return ByteString(); 239} 240 241void OverrunMessage::Print(int header_log_level, 242 int /*detail_log_level*/) const { 243 VLOG(header_log_level) << ToString(); 244} 245 246// UnknownMessage. 247 248ByteString UnknownMessage::Encode(uint32_t sequence_number) { 249 LOG(ERROR) << "We're not supposed to send UNKNOWN messages to the kernel"; 250 return ByteString(); 251} 252 253void UnknownMessage::Print(int header_log_level, 254 int /*detail_log_level*/) const { 255 int total_bytes = message_body_.GetLength(); 256 const uint8_t* const_data = message_body_.GetConstData(); 257 258 string output = StringPrintf("%d bytes:", total_bytes); 259 for (int i = 0; i < total_bytes; ++i) { 260 StringAppendF(&output, " 0x%02x", const_data[i]); 261 } 262 VLOG(header_log_level) << output; 263} 264 265// 266// Factory class. 267// 268 269bool NetlinkMessageFactory::AddFactoryMethod(uint16_t message_type, 270 FactoryMethod factory) { 271 if (ContainsKey(factories_, message_type)) { 272 LOG(WARNING) << "Message type " << message_type << " already exists."; 273 return false; 274 } 275 if (message_type == NetlinkMessage::kIllegalMessageType) { 276 LOG(ERROR) << "Not installing factory for illegal message type."; 277 return false; 278 } 279 factories_[message_type] = factory; 280 return true; 281} 282 283NetlinkMessage* NetlinkMessageFactory::CreateMessage( 284 NetlinkPacket* packet, NetlinkMessage::MessageContext context) const { 285 std::unique_ptr<NetlinkMessage> message; 286 287 auto message_type = packet->GetMessageType(); 288 if (message_type == NoopMessage::kMessageType) { 289 message.reset(new NoopMessage()); 290 } else if (message_type == DoneMessage::kMessageType) { 291 message.reset(new DoneMessage()); 292 } else if (message_type == OverrunMessage::kMessageType) { 293 message.reset(new OverrunMessage()); 294 } else if (message_type == ErrorAckMessage::kMessageType) { 295 message.reset(new ErrorAckMessage()); 296 } else if (ContainsKey(factories_, message_type)) { 297 map<uint16_t, FactoryMethod>::const_iterator factory; 298 factory = factories_.find(message_type); 299 message.reset(factory->second.Run(*packet)); 300 } 301 302 // If no factory exists for this message _or_ if a factory exists but it 303 // failed, there'll be no message. Handle either of those cases, by 304 // creating an |UnknownMessage|. 305 if (!message) { 306 message.reset(new UnknownMessage(message_type, packet->GetPayload())); 307 } 308 309 if (!message->InitFromPacket(packet, context)) { 310 LOG(ERROR) << "Message did not initialize properly"; 311 return nullptr; 312 } 313 314 return message.release(); 315} 316 317} // namespace shill. 318