1//
2// Copyright (C) 2013 The Android Open Source Project
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8//      http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15//
16
17#include "shill/net/netlink_socket.h"
18
19#include <linux/netlink.h>
20
21#include <algorithm>
22#include <string>
23
24#include <gmock/gmock.h>
25#include <gtest/gtest.h>
26
27#include "shill/net/byte_string.h"
28#include "shill/net/mock_sockets.h"
29#include "shill/net/netlink_message.h"
30
31using std::min;
32using std::string;
33using testing::_;
34using testing::Invoke;
35using testing::Return;
36using testing::Test;
37
38namespace shill {
39
// File descriptor returned by the mocked Socket() call in these tests.
const int kFakeFd = 99;
43
44class NetlinkSocketTest : public Test {
45 public:
46  NetlinkSocketTest() {}
47  virtual ~NetlinkSocketTest() {}
48
49  virtual void SetUp() {
50    mock_sockets_ = new MockSockets();
51    netlink_socket_.sockets_.reset(mock_sockets_);
52  }
53
54  virtual void InitializeSocket(int fd) {
55    EXPECT_CALL(*mock_sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_GENERIC))
56        .WillOnce(Return(fd));
57    EXPECT_CALL(*mock_sockets_, SetReceiveBuffer(
58        fd, NetlinkSocket::kReceiveBufferSize)).WillOnce(Return(0));
59    EXPECT_CALL(*mock_sockets_, Bind(fd, _, sizeof(struct sockaddr_nl)))
60        .WillOnce(Return(0));
61    EXPECT_TRUE(netlink_socket_.Init());
62  }
63
64 protected:
65  MockSockets* mock_sockets_;  // Owned by netlink_socket_.
66  NetlinkSocket netlink_socket_;
67};
68
69class FakeSocketRead {
70 public:
71  explicit FakeSocketRead(const ByteString& next_read_string) {
72    next_read_string_ = next_read_string;
73  }
74  // Copies |len| bytes of |next_read_string_| into |buf| and clears
75  // |next_read_string_|.
76  ssize_t FakeSuccessfulRead(int sockfd, void* buf, size_t len, int flags,
77                             struct sockaddr* src_addr, socklen_t* addrlen) {
78    if (!buf) {
79      return -1;
80    }
81    int read_bytes = min(len, next_read_string_.GetLength());
82    memcpy(buf, next_read_string_.GetConstData(), read_bytes);
83    next_read_string_.Clear();
84    return read_bytes;
85  }
86
87 private:
88  ByteString next_read_string_;
89};
90
91TEST_F(NetlinkSocketTest, InitWorkingTest) {
92  SetUp();
93  InitializeSocket(kFakeFd);
94  EXPECT_CALL(*mock_sockets_, Close(kFakeFd));
95}
96
97TEST_F(NetlinkSocketTest, InitBrokenSocketTest) {
98  SetUp();
99
100  const int kBadFd = -1;
101  EXPECT_CALL(*mock_sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_GENERIC))
102      .WillOnce(Return(kBadFd));
103  EXPECT_CALL(*mock_sockets_, SetReceiveBuffer(_, _)).Times(0);
104  EXPECT_CALL(*mock_sockets_, Bind(_, _, _)).Times(0);
105  EXPECT_FALSE(netlink_socket_.Init());
106}
107
108TEST_F(NetlinkSocketTest, InitBrokenBufferTest) {
109  SetUp();
110
111  EXPECT_CALL(*mock_sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_GENERIC))
112      .WillOnce(Return(kFakeFd));
113  EXPECT_CALL(*mock_sockets_, SetReceiveBuffer(
114      kFakeFd, NetlinkSocket::kReceiveBufferSize)).WillOnce(Return(-1));
115  EXPECT_CALL(*mock_sockets_, Bind(kFakeFd, _, sizeof(struct sockaddr_nl)))
116      .WillOnce(Return(0));
117  EXPECT_TRUE(netlink_socket_.Init());
118
119  // Destructor.
120  EXPECT_CALL(*mock_sockets_, Close(kFakeFd));
121}
122
123TEST_F(NetlinkSocketTest, InitBrokenBindTest) {
124  SetUp();
125
126  EXPECT_CALL(*mock_sockets_, Socket(PF_NETLINK, SOCK_DGRAM, NETLINK_GENERIC))
127      .WillOnce(Return(kFakeFd));
128  EXPECT_CALL(*mock_sockets_, SetReceiveBuffer(
129      kFakeFd, NetlinkSocket::kReceiveBufferSize)).WillOnce(Return(0));
130  EXPECT_CALL(*mock_sockets_, Bind(kFakeFd, _, sizeof(struct sockaddr_nl)))
131      .WillOnce(Return(-1));
132  EXPECT_CALL(*mock_sockets_, Close(kFakeFd)).WillOnce(Return(0));
133  EXPECT_FALSE(netlink_socket_.Init());
134}
135
136TEST_F(NetlinkSocketTest, SendMessageTest) {
137  SetUp();
138  InitializeSocket(kFakeFd);
139
140  string message_string("This text is really arbitrary");
141  ByteString message(message_string.c_str(), message_string.size());
142
143  // Good Send.
144  EXPECT_CALL(*mock_sockets_,
145              Send(kFakeFd, message.GetConstData(), message.GetLength(), 0))
146      .WillOnce(Return(message.GetLength()));
147  EXPECT_TRUE(netlink_socket_.SendMessage(message));
148
149  // Short Send.
150  EXPECT_CALL(*mock_sockets_,
151              Send(kFakeFd, message.GetConstData(), message.GetLength(), 0))
152      .WillOnce(Return(message.GetLength() - 3));
153  EXPECT_FALSE(netlink_socket_.SendMessage(message));
154
155  // Bad Send.
156  EXPECT_CALL(*mock_sockets_,
157              Send(kFakeFd, message.GetConstData(), message.GetLength(), 0))
158      .WillOnce(Return(-1));
159  EXPECT_FALSE(netlink_socket_.SendMessage(message));
160
161  // Destructor.
162  EXPECT_CALL(*mock_sockets_, Close(kFakeFd));
163}
164
165TEST_F(NetlinkSocketTest, SequenceNumberTest) {
166  SetUp();
167
168  // Just a sequence number.
169  const uint32_t arbitrary_number = 42;
170  netlink_socket_.sequence_number_ = arbitrary_number;
171  EXPECT_EQ(arbitrary_number+1, netlink_socket_.GetSequenceNumber());
172
173  // Make sure we don't go to |NetlinkMessage::kBroadcastSequenceNumber|.
174  netlink_socket_.sequence_number_ = NetlinkMessage::kBroadcastSequenceNumber;
175  EXPECT_NE(NetlinkMessage::kBroadcastSequenceNumber,
176            netlink_socket_.GetSequenceNumber());
177}
178
179TEST_F(NetlinkSocketTest, GoodRecvMessageTest) {
180  SetUp();
181  InitializeSocket(kFakeFd);
182
183  ByteString message;
184  static const string next_read_string(
185      "Random text may include things like 'freaking fracking foo'.");
186  static const size_t read_size = next_read_string.size();
187  ByteString expected_results(next_read_string.c_str(), read_size);
188  FakeSocketRead fake_socket_read(expected_results);
189
190  // Expect one call to get the size...
191  EXPECT_CALL(*mock_sockets_,
192              RecvFrom(kFakeFd, _, _, MSG_TRUNC | MSG_PEEK, _, _))
193      .WillOnce(Return(read_size));
194
195  // ...and expect a second call to get the data.
196  EXPECT_CALL(*mock_sockets_,
197              RecvFrom(kFakeFd, _, read_size, 0, _, _))
198      .WillOnce(Invoke(&fake_socket_read, &FakeSocketRead::FakeSuccessfulRead));
199
200  EXPECT_TRUE(netlink_socket_.RecvMessage(&message));
201  EXPECT_TRUE(message.Equals(expected_results));
202
203  // Destructor.
204  EXPECT_CALL(*mock_sockets_, Close(kFakeFd));
205}
206
207TEST_F(NetlinkSocketTest, BadRecvMessageTest) {
208  SetUp();
209  InitializeSocket(kFakeFd);
210
211  ByteString message;
212  EXPECT_CALL(*mock_sockets_, RecvFrom(kFakeFd, _, _, _, _, _))
213      .WillOnce(Return(-1));
214  EXPECT_FALSE(netlink_socket_.RecvMessage(&message));
215
216  EXPECT_CALL(*mock_sockets_, Close(kFakeFd));
217}
218
219}  // namespace shill.
220