1//
2// Copyright (C) 2012 The Android Open Source Project
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//      http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15//
16
17#include "shill/net/netlink_message.h"
18
19#include <limits.h>
20
21#include <algorithm>
22#include <map>
23#include <memory>
24#include <string>
25
26#include <base/format_macros.h>
27#include <base/logging.h>
28#include <base/stl_util.h>
29#include <base/strings/stringprintf.h>
30
31#include "shill/net/netlink_packet.h"
32
33using base::StringAppendF;
34using base::StringPrintf;
35using std::map;
36using std::min;
37using std::string;
38
39namespace shill {
40
41const uint32_t NetlinkMessage::kBroadcastSequenceNumber = 0;
42const uint16_t NetlinkMessage::kIllegalMessageType = UINT16_MAX;
43
44// NetlinkMessage
45
46ByteString NetlinkMessage::EncodeHeader(uint32_t sequence_number) {
47  ByteString result;
48  if (message_type_ == kIllegalMessageType) {
49    LOG(ERROR) << "Message type not set";
50    return result;
51  }
52  sequence_number_ = sequence_number;
53  if (sequence_number_ == kBroadcastSequenceNumber) {
54    LOG(ERROR) << "Couldn't get a legal sequence number";
55    return result;
56  }
57
58  // Build netlink header.
59  nlmsghdr header;
60  size_t nlmsghdr_with_pad = NLMSG_ALIGN(sizeof(header));
61  header.nlmsg_len = nlmsghdr_with_pad;
62  header.nlmsg_type = message_type_;
63  header.nlmsg_flags = NLM_F_REQUEST | flags_;
64  header.nlmsg_seq = sequence_number_;
65  header.nlmsg_pid = getpid();
66
67  // Netlink header + pad.
68  result.Append(ByteString(reinterpret_cast<unsigned char*>(&header),
69                           sizeof(header)));
70  result.Resize(nlmsghdr_with_pad);  // Zero-fill pad space (if any).
71  return result;
72}
73
74bool NetlinkMessage::InitAndStripHeader(NetlinkPacket* packet) {
75  const nlmsghdr& header = packet->GetNlMsgHeader();
76  message_type_ = header.nlmsg_type;
77  flags_ = header.nlmsg_flags;
78  sequence_number_ = header.nlmsg_seq;
79
80  return true;
81}
82
83bool NetlinkMessage::InitFromPacket(NetlinkPacket* packet,
84                                    NetlinkMessage::MessageContext context) {
85  if (!packet) {
86    LOG(ERROR) << "Null |packet| parameter";
87    return false;
88  }
89  if (!InitAndStripHeader(packet)) {
90    return false;
91  }
92  return true;
93}
94
95// static
96void NetlinkMessage::PrintBytes(int log_level, const unsigned char* buf,
97                                size_t num_bytes) {
98  VLOG(log_level) << "Netlink Message -- Examining Bytes";
99  if (!buf) {
100    VLOG(log_level) << "<NULL Buffer>";
101    return;
102  }
103
104  if (num_bytes >= sizeof(nlmsghdr)) {
105      PrintHeader(log_level, reinterpret_cast<const nlmsghdr*>(buf));
106      buf += sizeof(nlmsghdr);
107      num_bytes -= sizeof(nlmsghdr);
108  } else {
109    VLOG(log_level) << "Not enough bytes (" << num_bytes
110                    << ") for a complete nlmsghdr (requires "
111                    << sizeof(nlmsghdr) << ").";
112  }
113
114  PrintPayload(log_level, buf, num_bytes);
115}
116
117// static
118void NetlinkMessage::PrintPacket(int log_level, const NetlinkPacket& packet) {
119  VLOG(log_level) << "Netlink Message -- Examining Packet";
120  if (!packet.IsValid()) {
121    VLOG(log_level) << "<Invalid Buffer>";
122    return;
123  }
124
125  PrintHeader(log_level, &packet.GetNlMsgHeader());
126  const ByteString& payload = packet.GetPayload();
127  PrintPayload(log_level, payload.GetConstData(), payload.GetLength());
128}
129
130// static
131void NetlinkMessage::PrintHeader(int log_level, const nlmsghdr* header) {
132  const unsigned char* buf = reinterpret_cast<const unsigned char*>(header);
133  VLOG(log_level) << StringPrintf(
134      "len:          %02x %02x %02x %02x = %u bytes",
135      buf[0], buf[1], buf[2], buf[3], header->nlmsg_len);
136
137  VLOG(log_level) << StringPrintf(
138      "type | flags: %02x %02x %02x %02x - type:%u flags:%s%s%s%s%s",
139      buf[4], buf[5], buf[6], buf[7], header->nlmsg_type,
140      ((header->nlmsg_flags & NLM_F_REQUEST) ? " REQUEST" : ""),
141      ((header->nlmsg_flags & NLM_F_MULTI) ? " MULTI" : ""),
142      ((header->nlmsg_flags & NLM_F_ACK) ? " ACK" : ""),
143      ((header->nlmsg_flags & NLM_F_ECHO) ? " ECHO" : ""),
144      ((header->nlmsg_flags & NLM_F_DUMP_INTR) ? " BAD-SEQ" : ""));
145
146  VLOG(log_level) << StringPrintf(
147      "sequence:     %02x %02x %02x %02x = %u",
148      buf[8], buf[9], buf[10], buf[11], header->nlmsg_seq);
149  VLOG(log_level) << StringPrintf(
150      "pid:          %02x %02x %02x %02x = %u",
151      buf[12], buf[13], buf[14], buf[15], header->nlmsg_pid);
152}
153
154// static
155void NetlinkMessage::PrintPayload(int log_level, const unsigned char* buf,
156                                  size_t num_bytes) {
157  while (num_bytes) {
158    string output;
159    size_t bytes_this_row = min(num_bytes, static_cast<size_t>(32));
160    for (size_t i = 0; i < bytes_this_row; ++i) {
161      StringAppendF(&output, " %02x", *buf++);
162    }
163    VLOG(log_level) << output;
164    num_bytes -= bytes_this_row;
165  }
166}
167
168// ErrorAckMessage.
169
170const uint16_t ErrorAckMessage::kMessageType = NLMSG_ERROR;
171
172bool ErrorAckMessage::InitFromPacket(NetlinkPacket* packet,
173                                     NetlinkMessage::MessageContext context) {
174  if (!packet) {
175    LOG(ERROR) << "Null |const_msg| parameter";
176    return false;
177  }
178  if (!InitAndStripHeader(packet)) {
179    return false;
180  }
181
182  // Get the error code from the payload.
183  return packet->ConsumeData(sizeof(error_), &error_);
184}
185
186ByteString ErrorAckMessage::Encode(uint32_t sequence_number) {
187  LOG(ERROR) << "We're not supposed to send errors or Acks to the kernel";
188  return ByteString();
189}
190
191string ErrorAckMessage::ToString() const {
192  string output;
193  if (error()) {
194    StringAppendF(&output, "NETLINK_ERROR 0x%" PRIx32 ": %s",
195                  -error_, strerror(-error_));
196  } else {
197    StringAppendF(&output, "ACK");
198  }
199  return output;
200}
201
202void ErrorAckMessage::Print(int header_log_level,
203                            int /*detail_log_level*/) const {
204  VLOG(header_log_level) << ToString();
205}
206
207// NoopMessage.
208
209const uint16_t NoopMessage::kMessageType = NLMSG_NOOP;
210
211ByteString NoopMessage::Encode(uint32_t sequence_number) {
212  LOG(ERROR) << "We're not supposed to send NOOP to the kernel";
213  return ByteString();
214}
215
216void NoopMessage::Print(int header_log_level, int /*detail_log_level*/) const {
217  VLOG(header_log_level) << ToString();
218}
219
220// DoneMessage.
221
222const uint16_t DoneMessage::kMessageType = NLMSG_DONE;
223
224ByteString DoneMessage::Encode(uint32_t sequence_number) {
225  return EncodeHeader(sequence_number);
226}
227
228void DoneMessage::Print(int header_log_level, int /*detail_log_level*/) const {
229  VLOG(header_log_level) << ToString();
230}
231
232// OverrunMessage.
233
234const uint16_t OverrunMessage::kMessageType = NLMSG_OVERRUN;
235
236ByteString OverrunMessage::Encode(uint32_t sequence_number) {
237  LOG(ERROR) << "We're not supposed to send Overruns to the kernel";
238  return ByteString();
239}
240
241void OverrunMessage::Print(int header_log_level,
242                           int /*detail_log_level*/) const {
243  VLOG(header_log_level) << ToString();
244}
245
246// UnknownMessage.
247
248ByteString UnknownMessage::Encode(uint32_t sequence_number) {
249  LOG(ERROR) << "We're not supposed to send UNKNOWN messages to the kernel";
250  return ByteString();
251}
252
253void UnknownMessage::Print(int header_log_level,
254                           int /*detail_log_level*/) const {
255  int total_bytes = message_body_.GetLength();
256  const uint8_t* const_data = message_body_.GetConstData();
257
258  string output = StringPrintf("%d bytes:", total_bytes);
259  for (int i = 0; i < total_bytes; ++i) {
260    StringAppendF(&output, " 0x%02x", const_data[i]);
261  }
262  VLOG(header_log_level) << output;
263}
264
265//
266// Factory class.
267//
268
269bool NetlinkMessageFactory::AddFactoryMethod(uint16_t message_type,
270                                             FactoryMethod factory) {
271  if (ContainsKey(factories_, message_type)) {
272    LOG(WARNING) << "Message type " << message_type << " already exists.";
273    return false;
274  }
275  if (message_type == NetlinkMessage::kIllegalMessageType) {
276    LOG(ERROR) << "Not installing factory for illegal message type.";
277    return false;
278  }
279  factories_[message_type] = factory;
280  return true;
281}
282
283NetlinkMessage* NetlinkMessageFactory::CreateMessage(
284    NetlinkPacket* packet, NetlinkMessage::MessageContext context) const {
285  std::unique_ptr<NetlinkMessage> message;
286
287  auto message_type = packet->GetMessageType();
288  if (message_type == NoopMessage::kMessageType) {
289    message.reset(new NoopMessage());
290  } else if (message_type == DoneMessage::kMessageType) {
291    message.reset(new DoneMessage());
292  } else if (message_type == OverrunMessage::kMessageType) {
293    message.reset(new OverrunMessage());
294  } else if (message_type == ErrorAckMessage::kMessageType) {
295    message.reset(new ErrorAckMessage());
296  } else if (ContainsKey(factories_, message_type)) {
297    map<uint16_t, FactoryMethod>::const_iterator factory;
298    factory = factories_.find(message_type);
299    message.reset(factory->second.Run(*packet));
300  }
301
302  // If no factory exists for this message _or_ if a factory exists but it
303  // failed, there'll be no message.  Handle either of those cases, by
304  // creating an |UnknownMessage|.
305  if (!message) {
306    message.reset(new UnknownMessage(message_type, packet->GetPayload()));
307  }
308
309  if (!message->InitFromPacket(packet, context)) {
310    LOG(ERROR) << "Message did not initialize properly";
311    return nullptr;
312  }
313
314  return message.release();
315}
316
317}  // namespace shill.
318