1/*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "wificond/net/nl80211_packet.h"
18
19#include <android-base/logging.h>
20
21using std::vector;
22
23namespace android {
24namespace wificond {
25
26NL80211Packet::NL80211Packet(const vector<uint8_t>& data)
27    : data_(data) {
28  data_ = data;
29}
30
31NL80211Packet::NL80211Packet(const NL80211Packet& packet) {
32  data_ = packet.data_;
33  LOG(WARNING) << "Copy constructor is only used for unit tests";
34}
35
36NL80211Packet::NL80211Packet(uint16_t type,
37                             uint8_t command,
38                             uint32_t sequence,
39                             uint32_t pid) {
40  // Initialize the netlink header and generic netlink header.
41  // NLMSG_HDRLEN and GENL_HDRLEN already include the padding size.
42  data_.resize(NLMSG_HDRLEN + GENL_HDRLEN, 0);
43  // Initialize length field.
44  nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
45  nl_header->nlmsg_len = data_.size();
46  // Add NLM_F_REQUEST flag.
47  nl_header->nlmsg_flags = nl_header->nlmsg_flags | NLM_F_REQUEST;
48  nl_header->nlmsg_type = type;
49  nl_header->nlmsg_seq = sequence;
50  nl_header->nlmsg_pid = pid;
51
52  genlmsghdr* genl_header =
53      reinterpret_cast<genlmsghdr*>(data_.data() + NLMSG_HDRLEN);
54  genl_header->version = 1;
55  genl_header->cmd = command;
56  // genl_header->reserved is aready 0.
57}
58
59bool NL80211Packet::IsValid() const {
60  // Verify the size of packet.
61  if (data_.size() < NLMSG_HDRLEN) {
62    LOG(ERROR) << "Cannot retrieve netlink header.";
63    return false;
64  }
65
66  const nlmsghdr* nl_header = reinterpret_cast<const nlmsghdr*>(data_.data());
67
68  // If type < NLMSG_MIN_TYPE, this should be a reserved control message,
69  // which doesn't carry a generic netlink header.
70  if (GetMessageType() >= NLMSG_MIN_TYPE) {
71    if (data_.size() < NLMSG_HDRLEN + GENL_HDRLEN ||
72        nl_header->nlmsg_len < NLMSG_HDRLEN + GENL_HDRLEN) {
73      LOG(ERROR) << "Cannot retrieve generic netlink header.";
74      return false;
75    }
76  }
77  // If it is an ERROR message, it should be long enough to carry an extra error
78  // code field.
79  // Kernel uses int for this field.
80  if (GetMessageType() == NLMSG_ERROR) {
81    if (data_.size() < NLMSG_HDRLEN + sizeof(int) ||
82        nl_header->nlmsg_len < NLMSG_HDRLEN + sizeof(int)) {
83     LOG(ERROR) << "Broken error message.";
84     return false;
85    }
86  }
87
88  // Verify the netlink header.
89  if (data_.size() < nl_header->nlmsg_len ||
90      nl_header->nlmsg_len < sizeof(nlmsghdr)) {
91    LOG(ERROR) << "Discarding incomplete / invalid message.";
92    return false;
93  }
94  return true;
95}
96
97bool NL80211Packet::IsDump() const {
98  return GetFlags() & NLM_F_DUMP;
99}
100
101bool NL80211Packet::IsMulti() const {
102  return GetFlags() & NLM_F_MULTI;
103}
104
105uint8_t NL80211Packet::GetCommand() const {
106  const genlmsghdr* genl_header = reinterpret_cast<const genlmsghdr*>(
107      data_.data() + NLMSG_HDRLEN);
108  return genl_header->cmd;
109}
110
111uint16_t NL80211Packet::GetFlags() const {
112  const nlmsghdr* nl_header = reinterpret_cast<const nlmsghdr*>(data_.data());
113  return nl_header->nlmsg_flags;
114}
115
116uint16_t NL80211Packet::GetMessageType() const {
117  const nlmsghdr* nl_header = reinterpret_cast<const nlmsghdr*>(data_.data());
118  return nl_header->nlmsg_type;
119}
120
121uint32_t NL80211Packet::GetMessageSequence() const {
122  const nlmsghdr* nl_header = reinterpret_cast<const nlmsghdr*>(data_.data());
123  return nl_header->nlmsg_seq;
124}
125
126uint32_t NL80211Packet::GetPortId() const {
127  const nlmsghdr* nl_header = reinterpret_cast<const nlmsghdr*>(data_.data());
128  return nl_header->nlmsg_pid;
129}
130
131int NL80211Packet::GetErrorCode() const {
132  return -*reinterpret_cast<const int*>(data_.data() + NLMSG_HDRLEN);
133}
134
135const vector<uint8_t>& NL80211Packet::GetConstData() const {
136  return data_;
137}
138
139void NL80211Packet::SetCommand(uint8_t command) {
140  genlmsghdr* genl_header = reinterpret_cast<genlmsghdr*>(
141      data_.data() + NLMSG_HDRLEN);
142  genl_header->cmd = command;
143}
144
145void NL80211Packet::AddFlag(uint16_t flag) {
146  nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
147  nl_header->nlmsg_flags |= flag;
148}
149
150void NL80211Packet::SetFlags(uint16_t flags) {
151  nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
152  nl_header->nlmsg_flags = flags;
153}
154
155void NL80211Packet::SetMessageType(uint16_t message_type) {
156  nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
157  nl_header->nlmsg_type = message_type;
158}
159
160void NL80211Packet::SetMessageSequence(uint32_t message_sequence) {
161  nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
162  nl_header->nlmsg_seq = message_sequence;
163}
164
165void NL80211Packet::SetPortId(uint32_t pid) {
166  nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
167  nl_header->nlmsg_pid = pid;
168}
169
170void NL80211Packet::AddAttribute(const BaseNL80211Attr& attribute) {
171  const vector<uint8_t>& append_data = attribute.GetConstData();
172  // Append the data of |attribute| to |this|.
173  data_.insert(data_.end(), append_data.begin(), append_data.end());
174  nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
175  // We don't need to worry about padding for a nl80211 packet.
176  // Because as long as all sub attributes have padding, the payload is aligned.
177  nl_header->nlmsg_len += append_data.size();
178}
179
180void NL80211Packet::AddFlagAttribute(int attribute_id) {
181  // We only need to append a header for flag attribute.
182  // Make space for the new attribute.
183  data_.resize(data_.size() + NLA_HDRLEN, 0);
184  nlattr* flag_header =
185      reinterpret_cast<nlattr*>(data_.data() + data_.size() - NLA_HDRLEN);
186  flag_header->nla_type = attribute_id;
187  flag_header->nla_len = NLA_HDRLEN;
188  nlmsghdr* nl_header = reinterpret_cast<nlmsghdr*>(data_.data());
189  nl_header->nlmsg_len += NLA_HDRLEN;
190}
191
192bool NL80211Packet::HasAttribute(int id) const {
193  return BaseNL80211Attr::GetAttributeImpl(
194      data_.data() + NLMSG_HDRLEN + GENL_HDRLEN,
195      data_.size() - NLMSG_HDRLEN - GENL_HDRLEN,
196      id, nullptr, nullptr);
197}
198
199bool NL80211Packet::GetAttribute(int id,
200    NL80211NestedAttr* attribute) const {
201  uint8_t* start = nullptr;
202  uint8_t* end = nullptr;
203  if (!BaseNL80211Attr::GetAttributeImpl(
204          data_.data() + NLMSG_HDRLEN + GENL_HDRLEN,
205          data_.size() - NLMSG_HDRLEN - GENL_HDRLEN,
206          id, &start, &end) ||
207      start == nullptr ||
208      end == nullptr) {
209    return false;
210  }
211  *attribute = NL80211NestedAttr(vector<uint8_t>(start, end));
212  if (!attribute->IsValid()) {
213    return false;
214  }
215  return true;
216}
217
218void NL80211Packet::DebugLog() const {
219  const uint8_t* ptr = data_.data() + NLMSG_HDRLEN + GENL_HDRLEN;
220  const uint8_t* end_ptr = data_.data() + data_.size();
221  while (ptr + NLA_HDRLEN <= end_ptr) {
222    const nlattr* header = reinterpret_cast<const nlattr*>(ptr);
223    if (ptr + NLA_ALIGN(header->nla_len) > end_ptr) {
224      LOG(ERROR) << "broken nl80211 atrribute.";
225      return;
226    }
227    LOG(INFO) << "Have attribute with nla_type=" << header->nla_type
228              << " and nla_len=" << header->nla_len;
229    if (header->nla_len == 0) {
230      LOG(ERROR) << "0 is a bad nla_len";
231      return;
232    }
233    ptr += NLA_ALIGN(header->nla_len);
234  }
235}
236
237}  // namespace wificond
238}  // namespace android
239