1// Copyright 2014 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <brillo/dbus/dbus_method_invoker.h>
6
7#include <string>
8
9#include <brillo/bind_lambda.h>
10#include <dbus/mock_bus.h>
11#include <dbus/mock_object_proxy.h>
12#include <dbus/scoped_dbus_error.h>
13#include <gmock/gmock.h>
14#include <gtest/gtest.h>
15
16#include "brillo/dbus/test.pb.h"
17
18using testing::AnyNumber;
19using testing::InSequence;
20using testing::Invoke;
21using testing::Return;
22using testing::_;
23
24using dbus::MessageReader;
25using dbus::MessageWriter;
26using dbus::Response;
27
28namespace brillo {
29namespace dbus_utils {
30
// Identity of the fake remote object that every test below talks to.
constexpr char kTestPath[] = "/test/path";
constexpr char kTestServiceName[] = "org.test.Object";
constexpr char kTestInterface[] = "org.test.Object.TestInterface";

// Methods of kTestInterface; the fixtures' fake handlers dispatch on these.
constexpr char kTestMethod1[] = "TestMethod1";
constexpr char kTestMethod2[] = "TestMethod2";
constexpr char kTestMethod3[] = "TestMethod3";
constexpr char kTestMethod4[] = "TestMethod4";
38
39class DBusMethodInvokerTest : public testing::Test {
40 public:
41  void SetUp() override {
42    dbus::Bus::Options options;
43    options.bus_type = dbus::Bus::SYSTEM;
44    bus_ = new dbus::MockBus(options);
45    // By default, don't worry about threading assertions.
46    EXPECT_CALL(*bus_, AssertOnOriginThread()).Times(AnyNumber());
47    EXPECT_CALL(*bus_, AssertOnDBusThread()).Times(AnyNumber());
48    // Use a mock exported object.
49    mock_object_proxy_ = new dbus::MockObjectProxy(
50        bus_.get(), kTestServiceName, dbus::ObjectPath(kTestPath));
51    EXPECT_CALL(*bus_,
52                GetObjectProxy(kTestServiceName, dbus::ObjectPath(kTestPath)))
53        .WillRepeatedly(Return(mock_object_proxy_.get()));
54    int def_timeout_ms = dbus::ObjectProxy::TIMEOUT_USE_DEFAULT;
55    EXPECT_CALL(*mock_object_proxy_,
56                MockCallMethodAndBlockWithErrorDetails(_, def_timeout_ms, _))
57        .WillRepeatedly(Invoke(this, &DBusMethodInvokerTest::CreateResponse));
58  }
59
60  void TearDown() override { bus_ = nullptr; }
61
62  Response* CreateResponse(dbus::MethodCall* method_call,
63                           int /* timeout_ms */,
64                           dbus::ScopedDBusError* dbus_error) {
65    if (method_call->GetInterface() == kTestInterface) {
66      if (method_call->GetMember() == kTestMethod1) {
67        MessageReader reader(method_call);
68        int v1, v2;
69        // Input: two ints.
70        // Output: sum of the ints converted to string.
71        if (reader.PopInt32(&v1) && reader.PopInt32(&v2)) {
72          auto response = Response::CreateEmpty();
73          MessageWriter writer(response.get());
74          writer.AppendString(std::to_string(v1 + v2));
75          return response.release();
76        }
77      } else if (method_call->GetMember() == kTestMethod2) {
78        method_call->SetSerial(123);
79        dbus_set_error(dbus_error->get(), "org.MyError", "My error message");
80        return nullptr;
81      } else if (method_call->GetMember() == kTestMethod3) {
82        MessageReader reader(method_call);
83        dbus_utils_test::TestMessage msg;
84        if (PopValueFromReader(&reader, &msg)) {
85          auto response = Response::CreateEmpty();
86          MessageWriter writer(response.get());
87          AppendValueToWriter(&writer, msg);
88          return response.release();
89        }
90      } else if (method_call->GetMember() == kTestMethod4) {
91        method_call->SetSerial(123);
92        MessageReader reader(method_call);
93        dbus::FileDescriptor fd;
94        if (reader.PopFileDescriptor(&fd)) {
95          auto response = Response::CreateEmpty();
96          MessageWriter writer(response.get());
97          fd.CheckValidity();
98          writer.AppendFileDescriptor(fd);
99          return response.release();
100        }
101      }
102    }
103
104    LOG(ERROR) << "Unexpected method call: " << method_call->ToString();
105    return nullptr;
106  }
107
108  std::string CallTestMethod(int v1, int v2) {
109    std::unique_ptr<dbus::Response> response =
110        brillo::dbus_utils::CallMethodAndBlock(mock_object_proxy_.get(),
111                                               kTestInterface, kTestMethod1,
112                                               nullptr, v1, v2);
113    EXPECT_NE(nullptr, response.get());
114    std::string result;
115    using brillo::dbus_utils::ExtractMethodCallResults;
116    EXPECT_TRUE(ExtractMethodCallResults(response.get(), nullptr, &result));
117    return result;
118  }
119
120  dbus_utils_test::TestMessage CallProtobufTestMethod(
121      const dbus_utils_test::TestMessage& message) {
122    std::unique_ptr<dbus::Response> response =
123        brillo::dbus_utils::CallMethodAndBlock(mock_object_proxy_.get(),
124                                               kTestInterface, kTestMethod3,
125                                               nullptr, message);
126    EXPECT_NE(nullptr, response.get());
127    dbus_utils_test::TestMessage result;
128    using brillo::dbus_utils::ExtractMethodCallResults;
129    EXPECT_TRUE(ExtractMethodCallResults(response.get(), nullptr, &result));
130    return result;
131  }
132
133  // Sends a file descriptor received over D-Bus back to the caller.
134  dbus::FileDescriptor EchoFD(const dbus::FileDescriptor& fd_in) {
135    std::unique_ptr<dbus::Response> response =
136        brillo::dbus_utils::CallMethodAndBlock(mock_object_proxy_.get(),
137                                               kTestInterface, kTestMethod4,
138                                               nullptr, fd_in);
139    EXPECT_NE(nullptr, response.get());
140    dbus::FileDescriptor fd_out;
141    using brillo::dbus_utils::ExtractMethodCallResults;
142    EXPECT_TRUE(ExtractMethodCallResults(response.get(), nullptr, &fd_out));
143    return fd_out;
144  }
145
146  scoped_refptr<dbus::MockBus> bus_;
147  scoped_refptr<dbus::MockObjectProxy> mock_object_proxy_;
148};
149
150TEST_F(DBusMethodInvokerTest, TestSuccess) {
151  EXPECT_EQ("4", CallTestMethod(2, 2));
152  EXPECT_EQ("10", CallTestMethod(3, 7));
153  EXPECT_EQ("-4", CallTestMethod(13, -17));
154}
155
156TEST_F(DBusMethodInvokerTest, TestFailure) {
157  brillo::ErrorPtr error;
158  std::unique_ptr<dbus::Response> response =
159      brillo::dbus_utils::CallMethodAndBlock(
160          mock_object_proxy_.get(), kTestInterface, kTestMethod2, &error);
161  EXPECT_EQ(nullptr, response.get());
162  EXPECT_EQ(brillo::errors::dbus::kDomain, error->GetDomain());
163  EXPECT_EQ("org.MyError", error->GetCode());
164  EXPECT_EQ("My error message", error->GetMessage());
165}
166
167TEST_F(DBusMethodInvokerTest, TestProtobuf) {
168  dbus_utils_test::TestMessage test_message;
169  test_message.set_foo(123);
170  test_message.set_bar("bar");
171
172  dbus_utils_test::TestMessage resp = CallProtobufTestMethod(test_message);
173
174  EXPECT_EQ(123, resp.foo());
175  EXPECT_EQ("bar", resp.bar());
176}
177
178TEST_F(DBusMethodInvokerTest, TestFileDescriptors) {
179  // Passing a file descriptor over D-Bus would effectively duplicate the fd.
180  // So the resulting file descriptor value would be different but it still
181  // should be valid.
182  dbus::FileDescriptor fd_stdin(0);
183  fd_stdin.CheckValidity();
184  EXPECT_NE(fd_stdin.value(), EchoFD(fd_stdin).value());
185  dbus::FileDescriptor fd_stdout(1);
186  fd_stdout.CheckValidity();
187  EXPECT_NE(fd_stdout.value(), EchoFD(fd_stdout).value());
188  dbus::FileDescriptor fd_stderr(2);
189  fd_stderr.CheckValidity();
190  EXPECT_NE(fd_stderr.value(), EchoFD(fd_stderr).value());
191}
192
193//////////////////////////////////////////////////////////////////////////////
194// Asynchronous method invocation support
195
196class AsyncDBusMethodInvokerTest : public testing::Test {
197 public:
198  void SetUp() override {
199    dbus::Bus::Options options;
200    options.bus_type = dbus::Bus::SYSTEM;
201    bus_ = new dbus::MockBus(options);
202    // By default, don't worry about threading assertions.
203    EXPECT_CALL(*bus_, AssertOnOriginThread()).Times(AnyNumber());
204    EXPECT_CALL(*bus_, AssertOnDBusThread()).Times(AnyNumber());
205    // Use a mock exported object.
206    mock_object_proxy_ = new dbus::MockObjectProxy(
207        bus_.get(), kTestServiceName, dbus::ObjectPath(kTestPath));
208    EXPECT_CALL(*bus_,
209                GetObjectProxy(kTestServiceName, dbus::ObjectPath(kTestPath)))
210        .WillRepeatedly(Return(mock_object_proxy_.get()));
211    int def_timeout_ms = dbus::ObjectProxy::TIMEOUT_USE_DEFAULT;
212    EXPECT_CALL(*mock_object_proxy_,
213                CallMethodWithErrorCallback(_, def_timeout_ms, _, _))
214        .WillRepeatedly(Invoke(this, &AsyncDBusMethodInvokerTest::HandleCall));
215  }
216
217  void TearDown() override { bus_ = nullptr; }
218
219  void HandleCall(dbus::MethodCall* method_call,
220                  int /* timeout_ms */,
221                  dbus::ObjectProxy::ResponseCallback success_callback,
222                  dbus::ObjectProxy::ErrorCallback error_callback) {
223    if (method_call->GetInterface() == kTestInterface) {
224      if (method_call->GetMember() == kTestMethod1) {
225        MessageReader reader(method_call);
226        int v1, v2;
227        // Input: two ints.
228        // Output: sum of the ints converted to string.
229        if (reader.PopInt32(&v1) && reader.PopInt32(&v2)) {
230          auto response = Response::CreateEmpty();
231          MessageWriter writer(response.get());
232          writer.AppendString(std::to_string(v1 + v2));
233          success_callback.Run(response.get());
234        }
235        return;
236      } else if (method_call->GetMember() == kTestMethod2) {
237        method_call->SetSerial(123);
238        auto error_response = dbus::ErrorResponse::FromMethodCall(
239            method_call, "org.MyError", "My error message");
240        error_callback.Run(error_response.get());
241        return;
242      }
243    }
244
245    LOG(FATAL) << "Unexpected method call: " << method_call->ToString();
246  }
247
248  struct SuccessCallback {
249    SuccessCallback(const std::string& in_result, int* in_counter)
250        : result(in_result), counter(in_counter) {}
251
252    explicit SuccessCallback(int* in_counter) : counter(in_counter) {}
253
254    void operator()(const std::string& actual_result) {
255      (*counter)++;
256      EXPECT_EQ(result, actual_result);
257    }
258    std::string result;
259    int* counter;
260  };
261
262  struct ErrorCallback {
263    ErrorCallback(const std::string& in_domain,
264                  const std::string& in_code,
265                  const std::string& in_message,
266                  int* in_counter)
267        : domain(in_domain),
268          code(in_code),
269          message(in_message),
270          counter(in_counter) {}
271
272    explicit ErrorCallback(int* in_counter) : counter(in_counter) {}
273
274    void operator()(brillo::Error* error) {
275      (*counter)++;
276      EXPECT_NE(nullptr, error);
277      EXPECT_EQ(domain, error->GetDomain());
278      EXPECT_EQ(code, error->GetCode());
279      EXPECT_EQ(message, error->GetMessage());
280    }
281
282    std::string domain;
283    std::string code;
284    std::string message;
285    int* counter;
286  };
287
288  scoped_refptr<dbus::MockBus> bus_;
289  scoped_refptr<dbus::MockObjectProxy> mock_object_proxy_;
290};
291
292TEST_F(AsyncDBusMethodInvokerTest, TestSuccess) {
293  int error_count = 0;
294  int success_count = 0;
295  brillo::dbus_utils::CallMethod(
296      mock_object_proxy_.get(),
297      kTestInterface,
298      kTestMethod1,
299      base::Bind(SuccessCallback{"4", &success_count}),
300      base::Bind(ErrorCallback{&error_count}),
301      2, 2);
302  brillo::dbus_utils::CallMethod(
303      mock_object_proxy_.get(),
304      kTestInterface,
305      kTestMethod1,
306      base::Bind(SuccessCallback{"10", &success_count}),
307      base::Bind(ErrorCallback{&error_count}),
308      3, 7);
309  brillo::dbus_utils::CallMethod(
310      mock_object_proxy_.get(),
311      kTestInterface,
312      kTestMethod1,
313      base::Bind(SuccessCallback{"-4", &success_count}),
314      base::Bind(ErrorCallback{&error_count}),
315      13, -17);
316  EXPECT_EQ(0, error_count);
317  EXPECT_EQ(3, success_count);
318}
319
320TEST_F(AsyncDBusMethodInvokerTest, TestFailure) {
321  int error_count = 0;
322  int success_count = 0;
323  brillo::dbus_utils::CallMethod(
324      mock_object_proxy_.get(),
325      kTestInterface,
326      kTestMethod2,
327      base::Bind(SuccessCallback{&success_count}),
328      base::Bind(ErrorCallback{brillo::errors::dbus::kDomain,
329                               "org.MyError",
330                               "My error message",
331                               &error_count}),
332      2, 2);
333  EXPECT_EQ(1, error_count);
334  EXPECT_EQ(0, success_count);
335}
336
337}  // namespace dbus_utils
338}  // namespace brillo
339