1/*
2 * Copyright 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <binder/SafeInterface.h>

#include <binder/IInterface.h>
#include <binder/IPCThreadState.h>
#include <binder/IServiceManager.h>
#include <binder/Parcel.h>
#include <binder/Parcelable.h>
#include <binder/ProcessState.h>

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
#include <gtest/gtest.h>
#pragma clang diagnostic pop

#include <utils/LightRefBase.h>
#include <utils/NativeHandle.h>

#include <cutils/native_handle.h>

#include <mutex>
#include <optional>

#include <errno.h>
#include <sys/eventfd.h>
#include <unistd.h>
39
40using namespace std::chrono_literals; // NOLINT - google-build-using-namespace
41
42namespace android {
43namespace tests {
44
// Enumeration round-tripped through ISafeInterfaceTest::modifyEnum to exercise enum parceling.
// The numeric values are spelled out explicitly because they are what crosses the binder boundary.
enum class TestEnum : uint32_t {
    INVALID = 0u, // what modifyEnum produces for any input other than INITIAL
    INITIAL = 1u, // the input modifyEnum maps to FINAL
    FINAL = 2u,   // the expected output for INITIAL
};
50
51// This class serves two purposes:
52//   1) It ensures that the implementation doesn't require copying or moving the data (for
53//      efficiency purposes)
54//   2) It tests that Parcelables can be passed correctly
55class NoCopyNoMove : public Parcelable {
56public:
57    NoCopyNoMove() = default;
58    explicit NoCopyNoMove(int32_t value) : mValue(value) {}
59    ~NoCopyNoMove() override = default;
60
61    // Not copyable
62    NoCopyNoMove(const NoCopyNoMove&) = delete;
63    NoCopyNoMove& operator=(const NoCopyNoMove&) = delete;
64
65    // Not movable
66    NoCopyNoMove(NoCopyNoMove&&) = delete;
67    NoCopyNoMove& operator=(NoCopyNoMove&&) = delete;
68
69    // Parcelable interface
70    status_t writeToParcel(Parcel* parcel) const override { return parcel->writeInt32(mValue); }
71    status_t readFromParcel(const Parcel* parcel) override { return parcel->readInt32(&mValue); }
72
73    int32_t getValue() const { return mValue; }
74    void setValue(int32_t value) { mValue = value; }
75
76private:
77    int32_t mValue = 0;
78    uint8_t mPadding[4] = {}; // Avoids a warning from -Wpadded
79};
80
81struct TestFlattenable : Flattenable<TestFlattenable> {
82    TestFlattenable() = default;
83    explicit TestFlattenable(int32_t v) : value(v) {}
84
85    // Flattenable protocol
86    size_t getFlattenedSize() const { return sizeof(value); }
87    size_t getFdCount() const { return 0; }
88    status_t flatten(void*& buffer, size_t& size, int*& /*fds*/, size_t& /*count*/) const {
89        FlattenableUtils::write(buffer, size, value);
90        return NO_ERROR;
91    }
92    status_t unflatten(void const*& buffer, size_t& size, int const*& /*fds*/, size_t& /*count*/) {
93        FlattenableUtils::read(buffer, size, value);
94        return NO_ERROR;
95    }
96
97    int32_t value = 0;
98};
99
100struct TestLightFlattenable : LightFlattenablePod<TestLightFlattenable> {
101    TestLightFlattenable() = default;
102    explicit TestLightFlattenable(int32_t v) : value(v) {}
103    int32_t value = 0;
104};
105
106// It seems like this should be able to inherit from TestFlattenable (to avoid duplicating code),
107// but the SafeInterface logic can't easily be extended to find an indirect Flattenable<T>
108// base class
109class TestLightRefBaseFlattenable : public Flattenable<TestLightRefBaseFlattenable>,
110                                    public LightRefBase<TestLightRefBaseFlattenable> {
111public:
112    TestLightRefBaseFlattenable() = default;
113    explicit TestLightRefBaseFlattenable(int32_t v) : value(v) {}
114
115    // Flattenable protocol
116    size_t getFlattenedSize() const { return sizeof(value); }
117    size_t getFdCount() const { return 0; }
118    status_t flatten(void*& buffer, size_t& size, int*& /*fds*/, size_t& /*count*/) const {
119        FlattenableUtils::write(buffer, size, value);
120        return NO_ERROR;
121    }
122    status_t unflatten(void const*& buffer, size_t& size, int const*& /*fds*/, size_t& /*count*/) {
123        FlattenableUtils::read(buffer, size, value);
124        return NO_ERROR;
125    }
126
127    int32_t value = 0;
128};
129
130class TestParcelable : public Parcelable {
131public:
132    TestParcelable() = default;
133    explicit TestParcelable(int32_t value) : mValue(value) {}
134    TestParcelable(const TestParcelable& other) : TestParcelable(other.mValue) {}
135    TestParcelable(TestParcelable&& other) : TestParcelable(other.mValue) {}
136
137    // Parcelable interface
138    status_t writeToParcel(Parcel* parcel) const override { return parcel->writeInt32(mValue); }
139    status_t readFromParcel(const Parcel* parcel) override { return parcel->readInt32(&mValue); }
140
141    int32_t getValue() const { return mValue; }
142    void setValue(int32_t value) { mValue = value; }
143
144private:
145    int32_t mValue = 0;
146};
147
148class ExitOnDeath : public IBinder::DeathRecipient {
149public:
150    ~ExitOnDeath() override = default;
151
152    void binderDied(const wp<IBinder>& /*who*/) override {
153        ALOG(LOG_INFO, "ExitOnDeath", "Exiting");
154        exit(0);
155    }
156};
157
// This callback class is used to test both one-way transactions and that sp<IInterface> can be
// passed correctly
class ICallback : public IInterface {
public:
    DECLARE_META_INTERFACE(Callback)

    // Transaction codes. Enumerator order determines the codes sent over binder, so new methods
    // must be added immediately before Last. Last is only a range sentinel (BnCallback rejects
    // codes >= Last); it is not a real transaction.
    enum class Tag : uint32_t {
        OnCallback = IBinder::FIRST_CALL_TRANSACTION,
        Last,
    };

    // Receives a value from the service; BnSafeInterfaceTest::callMeBack passes its argument + 1.
    // BpCallback sends this as a one-way transaction (callRemoteAsync), so the caller does not
    // wait for it to complete.
    virtual void onCallback(int32_t aPlusOne) = 0;
};
171
172class BpCallback : public SafeBpInterface<ICallback> {
173public:
174    explicit BpCallback(const sp<IBinder>& impl) : SafeBpInterface<ICallback>(impl, getLogTag()) {}
175
176    void onCallback(int32_t aPlusOne) override {
177        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
178        return callRemoteAsync<decltype(&ICallback::onCallback)>(Tag::OnCallback, aPlusOne);
179    }
180
181private:
182    static constexpr const char* getLogTag() { return "BpCallback"; }
183};
184
// Provides the definitions for the members declared by DECLARE_META_INTERFACE(Callback), with
// "android.gfx.tests.ICallback" as the interface descriptor.
// NOTE(review): -Wexit-time-destructors is presumably silenced because the macro expansion
// defines objects with static storage duration and non-trivial destructors; confirm in
// IInterface.h.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wexit-time-destructors"
IMPLEMENT_META_INTERFACE(Callback, "android.gfx.tests.ICallback");
#pragma clang diagnostic pop
189
190class BnCallback : public SafeBnInterface<ICallback> {
191public:
192    BnCallback() : SafeBnInterface("BnCallback") {}
193
194    status_t onTransact(uint32_t code, const Parcel& data, Parcel* reply,
195                        uint32_t /*flags*/) override {
196        EXPECT_GE(code, IBinder::FIRST_CALL_TRANSACTION);
197        EXPECT_LT(code, static_cast<uint32_t>(ICallback::Tag::Last));
198        ICallback::Tag tag = static_cast<ICallback::Tag>(code);
199        switch (tag) {
200            case ICallback::Tag::OnCallback: {
201                return callLocalAsync(data, reply, &ICallback::onCallback);
202            }
203            case ICallback::Tag::Last:
204                // Should not be possible because of the asserts at the beginning of the method
205                [&]() { FAIL(); }();
206                return UNKNOWN_ERROR;
207        }
208    }
209};
210
// Interface under test. Each method exercises SafeInterface's marshalling of one category of
// parameter type; BpSafeInterfaceTest and BnSafeInterfaceTest are its proxy and implementation.
class ISafeInterfaceTest : public IInterface {
public:
    DECLARE_META_INTERFACE(SafeInterfaceTest)

    // Transaction codes, one per method (each increment() overload has its own). Enumerator order
    // determines the codes sent over binder, so new entries belong immediately before Last, which
    // is only a range sentinel used by BnSafeInterfaceTest::onTransact, not a real transaction.
    enum class Tag : uint32_t {
        SetDeathToken = IBinder::FIRST_CALL_TRANSACTION,
        ReturnsNoMemory,
        LogicalNot,
        ModifyEnum,
        IncrementFlattenable,
        IncrementLightFlattenable,
        IncrementLightRefBaseFlattenable,
        IncrementNativeHandle,
        IncrementNoCopyNoMove,
        IncrementParcelableVector,
        ToUpper,
        CallMeBack,
        IncrementInt32,
        IncrementUint32,
        IncrementInt64,
        IncrementUint64,
        IncrementTwo,
        Last,
    };

    // This is primarily so that the remote service dies when the test does, but it also serves to
    // test the handling of sp<IBinder> and non-const methods
    virtual status_t setDeathToken(const sp<IBinder>& token) = 0;

    // This is the most basic test since it doesn't require parceling any arguments
    virtual status_t returnsNoMemory() const = 0;

    // These are ordered according to their corresponding methods in SafeInterface::ParcelHandler
    virtual status_t logicalNot(bool a, bool* notA) const = 0;
    virtual status_t modifyEnum(TestEnum a, TestEnum* b) const = 0;
    virtual status_t increment(const TestFlattenable& a, TestFlattenable* aPlusOne) const = 0;
    virtual status_t increment(const TestLightFlattenable& a,
                               TestLightFlattenable* aPlusOne) const = 0;
    virtual status_t increment(const sp<TestLightRefBaseFlattenable>& a,
                               sp<TestLightRefBaseFlattenable>* aPlusOne) const = 0;
    virtual status_t increment(const sp<NativeHandle>& a, sp<NativeHandle>* aPlusOne) const = 0;
    virtual status_t increment(const NoCopyNoMove& a, NoCopyNoMove* aPlusOne) const = 0;
    virtual status_t increment(const std::vector<TestParcelable>& a,
                               std::vector<TestParcelable>* aPlusOne) const = 0;
    virtual status_t toUpper(const String8& str, String8* upperStr) const = 0;
    // As mentioned above, sp<IBinder> is already tested by setDeathToken
    // One-way call: no status is returned; the result is delivered via callback->onCallback()
    virtual void callMeBack(const sp<ICallback>& callback, int32_t a) const = 0;
    virtual status_t increment(int32_t a, int32_t* aPlusOne) const = 0;
    virtual status_t increment(uint32_t a, uint32_t* aPlusOne) const = 0;
    virtual status_t increment(int64_t a, int64_t* aPlusOne) const = 0;
    virtual status_t increment(uint64_t a, uint64_t* aPlusOne) const = 0;

    // This tests that input/output parameter interleaving works correctly
    virtual status_t increment(int32_t a, int32_t* aPlusOne, int32_t b,
                               int32_t* bPlusOne) const = 0;
};
267
268class BpSafeInterfaceTest : public SafeBpInterface<ISafeInterfaceTest> {
269public:
270    explicit BpSafeInterfaceTest(const sp<IBinder>& impl)
271          : SafeBpInterface<ISafeInterfaceTest>(impl, getLogTag()) {}
272
273    status_t setDeathToken(const sp<IBinder>& token) override {
274        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
275        return callRemote<decltype(&ISafeInterfaceTest::setDeathToken)>(Tag::SetDeathToken, token);
276    }
277    status_t returnsNoMemory() const override {
278        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
279        return callRemote<decltype(&ISafeInterfaceTest::returnsNoMemory)>(Tag::ReturnsNoMemory);
280    }
281    status_t logicalNot(bool a, bool* notA) const override {
282        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
283        return callRemote<decltype(&ISafeInterfaceTest::logicalNot)>(Tag::LogicalNot, a, notA);
284    }
285    status_t modifyEnum(TestEnum a, TestEnum* b) const override {
286        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
287        return callRemote<decltype(&ISafeInterfaceTest::modifyEnum)>(Tag::ModifyEnum, a, b);
288    }
289    status_t increment(const TestFlattenable& a, TestFlattenable* aPlusOne) const override {
290        using Signature =
291                status_t (ISafeInterfaceTest::*)(const TestFlattenable&, TestFlattenable*) const;
292        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
293        return callRemote<Signature>(Tag::IncrementFlattenable, a, aPlusOne);
294    }
295    status_t increment(const TestLightFlattenable& a,
296                       TestLightFlattenable* aPlusOne) const override {
297        using Signature = status_t (ISafeInterfaceTest::*)(const TestLightFlattenable&,
298                                                           TestLightFlattenable*) const;
299        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
300        return callRemote<Signature>(Tag::IncrementLightFlattenable, a, aPlusOne);
301    }
302    status_t increment(const sp<TestLightRefBaseFlattenable>& a,
303                       sp<TestLightRefBaseFlattenable>* aPlusOne) const override {
304        using Signature = status_t (ISafeInterfaceTest::*)(const sp<TestLightRefBaseFlattenable>&,
305                                                           sp<TestLightRefBaseFlattenable>*) const;
306        return callRemote<Signature>(Tag::IncrementLightRefBaseFlattenable, a, aPlusOne);
307    }
308    status_t increment(const sp<NativeHandle>& a, sp<NativeHandle>* aPlusOne) const override {
309        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
310        using Signature =
311                status_t (ISafeInterfaceTest::*)(const sp<NativeHandle>&, sp<NativeHandle>*) const;
312        return callRemote<Signature>(Tag::IncrementNativeHandle, a, aPlusOne);
313    }
314    status_t increment(const NoCopyNoMove& a, NoCopyNoMove* aPlusOne) const override {
315        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
316        using Signature = status_t (ISafeInterfaceTest::*)(const NoCopyNoMove& a,
317                                                           NoCopyNoMove* aPlusOne) const;
318        return callRemote<Signature>(Tag::IncrementNoCopyNoMove, a, aPlusOne);
319    }
320    status_t increment(const std::vector<TestParcelable>& a,
321                       std::vector<TestParcelable>* aPlusOne) const override {
322        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
323        using Signature = status_t (ISafeInterfaceTest::*)(const std::vector<TestParcelable>&,
324                                                           std::vector<TestParcelable>*);
325        return callRemote<Signature>(Tag::IncrementParcelableVector, a, aPlusOne);
326    }
327    status_t toUpper(const String8& str, String8* upperStr) const override {
328        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
329        return callRemote<decltype(&ISafeInterfaceTest::toUpper)>(Tag::ToUpper, str, upperStr);
330    }
331    void callMeBack(const sp<ICallback>& callback, int32_t a) const override {
332        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
333        return callRemoteAsync<decltype(&ISafeInterfaceTest::callMeBack)>(Tag::CallMeBack, callback,
334                                                                          a);
335    }
336    status_t increment(int32_t a, int32_t* aPlusOne) const override {
337        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
338        using Signature = status_t (ISafeInterfaceTest::*)(int32_t, int32_t*) const;
339        return callRemote<Signature>(Tag::IncrementInt32, a, aPlusOne);
340    }
341    status_t increment(uint32_t a, uint32_t* aPlusOne) const override {
342        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
343        using Signature = status_t (ISafeInterfaceTest::*)(uint32_t, uint32_t*) const;
344        return callRemote<Signature>(Tag::IncrementUint32, a, aPlusOne);
345    }
346    status_t increment(int64_t a, int64_t* aPlusOne) const override {
347        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
348        using Signature = status_t (ISafeInterfaceTest::*)(int64_t, int64_t*) const;
349        return callRemote<Signature>(Tag::IncrementInt64, a, aPlusOne);
350    }
351    status_t increment(uint64_t a, uint64_t* aPlusOne) const override {
352        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
353        using Signature = status_t (ISafeInterfaceTest::*)(uint64_t, uint64_t*) const;
354        return callRemote<Signature>(Tag::IncrementUint64, a, aPlusOne);
355    }
356    status_t increment(int32_t a, int32_t* aPlusOne, int32_t b, int32_t* bPlusOne) const override {
357        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
358        using Signature =
359                status_t (ISafeInterfaceTest::*)(int32_t, int32_t*, int32_t, int32_t*) const;
360        return callRemote<Signature>(Tag::IncrementTwo, a, aPlusOne, b, bPlusOne);
361    }
362
363private:
364    static constexpr const char* getLogTag() { return "BpSafeInterfaceTest"; }
365};
366
367#pragma clang diagnostic push
368#pragma clang diagnostic ignored "-Wexit-time-destructors"
369IMPLEMENT_META_INTERFACE(SafeInterfaceTest, "android.gfx.tests.ISafeInterfaceTest");
370
371static sp<IBinder::DeathRecipient> getDeathRecipient() {
372    static sp<IBinder::DeathRecipient> recipient = new ExitOnDeath;
373    return recipient;
374}
375#pragma clang diagnostic pop
376
377class BnSafeInterfaceTest : public SafeBnInterface<ISafeInterfaceTest> {
378public:
379    BnSafeInterfaceTest() : SafeBnInterface(getLogTag()) {}
380
381    status_t setDeathToken(const sp<IBinder>& token) override {
382        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
383        token->linkToDeath(getDeathRecipient());
384        return NO_ERROR;
385    }
386    status_t returnsNoMemory() const override {
387        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
388        return NO_MEMORY;
389    }
390    status_t logicalNot(bool a, bool* notA) const override {
391        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
392        *notA = !a;
393        return NO_ERROR;
394    }
395    status_t modifyEnum(TestEnum a, TestEnum* b) const override {
396        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
397        *b = (a == TestEnum::INITIAL) ? TestEnum::FINAL : TestEnum::INVALID;
398        return NO_ERROR;
399    }
400    status_t increment(const TestFlattenable& a, TestFlattenable* aPlusOne) const override {
401        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
402        aPlusOne->value = a.value + 1;
403        return NO_ERROR;
404    }
405    status_t increment(const TestLightFlattenable& a,
406                       TestLightFlattenable* aPlusOne) const override {
407        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
408        aPlusOne->value = a.value + 1;
409        return NO_ERROR;
410    }
411    status_t increment(const sp<TestLightRefBaseFlattenable>& a,
412                       sp<TestLightRefBaseFlattenable>* aPlusOne) const override {
413        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
414        *aPlusOne = new TestLightRefBaseFlattenable(a->value + 1);
415        return NO_ERROR;
416    }
417    status_t increment(const sp<NativeHandle>& a, sp<NativeHandle>* aPlusOne) const override {
418        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
419        native_handle* rawHandle = native_handle_create(1 /*numFds*/, 1 /*numInts*/);
420        if (rawHandle == nullptr) return NO_MEMORY;
421
422        // Copy the fd over directly
423        rawHandle->data[0] = dup(a->handle()->data[0]);
424
425        // Increment the int
426        rawHandle->data[1] = a->handle()->data[1] + 1;
427
428        // This cannot fail, as it is just the sp<NativeHandle> taking responsibility for closing
429        // the native_handle when it goes out of scope
430        *aPlusOne = NativeHandle::create(rawHandle, true);
431        return NO_ERROR;
432    }
433    status_t increment(const NoCopyNoMove& a, NoCopyNoMove* aPlusOne) const override {
434        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
435        aPlusOne->setValue(a.getValue() + 1);
436        return NO_ERROR;
437    }
438    status_t increment(const std::vector<TestParcelable>& a,
439                       std::vector<TestParcelable>* aPlusOne) const override {
440        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
441        aPlusOne->resize(a.size());
442        for (size_t i = 0; i < a.size(); ++i) {
443            (*aPlusOne)[i].setValue(a[i].getValue() + 1);
444        }
445        return NO_ERROR;
446    }
447    status_t toUpper(const String8& str, String8* upperStr) const override {
448        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
449        *upperStr = str;
450        upperStr->toUpper();
451        return NO_ERROR;
452    }
453    void callMeBack(const sp<ICallback>& callback, int32_t a) const override {
454        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
455        callback->onCallback(a + 1);
456    }
457    status_t increment(int32_t a, int32_t* aPlusOne) const override {
458        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
459        *aPlusOne = a + 1;
460        return NO_ERROR;
461    }
462    status_t increment(uint32_t a, uint32_t* aPlusOne) const override {
463        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
464        *aPlusOne = a + 1;
465        return NO_ERROR;
466    }
467    status_t increment(int64_t a, int64_t* aPlusOne) const override {
468        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
469        *aPlusOne = a + 1;
470        return NO_ERROR;
471    }
472    status_t increment(uint64_t a, uint64_t* aPlusOne) const override {
473        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
474        *aPlusOne = a + 1;
475        return NO_ERROR;
476    }
477    status_t increment(int32_t a, int32_t* aPlusOne, int32_t b, int32_t* bPlusOne) const override {
478        ALOG(LOG_INFO, getLogTag(), "%s", __PRETTY_FUNCTION__);
479        *aPlusOne = a + 1;
480        *bPlusOne = b + 1;
481        return NO_ERROR;
482    }
483
484    // BnInterface
485    status_t onTransact(uint32_t code, const Parcel& data, Parcel* reply,
486                        uint32_t /*flags*/) override {
487        EXPECT_GE(code, IBinder::FIRST_CALL_TRANSACTION);
488        EXPECT_LT(code, static_cast<uint32_t>(Tag::Last));
489        ISafeInterfaceTest::Tag tag = static_cast<ISafeInterfaceTest::Tag>(code);
490        switch (tag) {
491            case ISafeInterfaceTest::Tag::SetDeathToken: {
492                return callLocal(data, reply, &ISafeInterfaceTest::setDeathToken);
493            }
494            case ISafeInterfaceTest::Tag::ReturnsNoMemory: {
495                return callLocal(data, reply, &ISafeInterfaceTest::returnsNoMemory);
496            }
497            case ISafeInterfaceTest::Tag::LogicalNot: {
498                return callLocal(data, reply, &ISafeInterfaceTest::logicalNot);
499            }
500            case ISafeInterfaceTest::Tag::ModifyEnum: {
501                return callLocal(data, reply, &ISafeInterfaceTest::modifyEnum);
502            }
503            case ISafeInterfaceTest::Tag::IncrementFlattenable: {
504                using Signature = status_t (ISafeInterfaceTest::*)(const TestFlattenable& a,
505                                                                   TestFlattenable* aPlusOne) const;
506                return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
507            }
508            case ISafeInterfaceTest::Tag::IncrementLightFlattenable: {
509                using Signature =
510                        status_t (ISafeInterfaceTest::*)(const TestLightFlattenable& a,
511                                                         TestLightFlattenable* aPlusOne) const;
512                return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
513            }
514            case ISafeInterfaceTest::Tag::IncrementLightRefBaseFlattenable: {
515                using Signature =
516                        status_t (ISafeInterfaceTest::*)(const sp<TestLightRefBaseFlattenable>&,
517                                                         sp<TestLightRefBaseFlattenable>*) const;
518                return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
519            }
520            case ISafeInterfaceTest::Tag::IncrementNativeHandle: {
521                using Signature = status_t (ISafeInterfaceTest::*)(const sp<NativeHandle>&,
522                                                                   sp<NativeHandle>*) const;
523                return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
524            }
525            case ISafeInterfaceTest::Tag::IncrementNoCopyNoMove: {
526                using Signature = status_t (ISafeInterfaceTest::*)(const NoCopyNoMove& a,
527                                                                   NoCopyNoMove* aPlusOne) const;
528                return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
529            }
530            case ISafeInterfaceTest::Tag::IncrementParcelableVector: {
531                using Signature =
532                        status_t (ISafeInterfaceTest::*)(const std::vector<TestParcelable>&,
533                                                         std::vector<TestParcelable>*) const;
534                return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
535            }
536            case ISafeInterfaceTest::Tag::ToUpper: {
537                return callLocal(data, reply, &ISafeInterfaceTest::toUpper);
538            }
539            case ISafeInterfaceTest::Tag::CallMeBack: {
540                return callLocalAsync(data, reply, &ISafeInterfaceTest::callMeBack);
541            }
542            case ISafeInterfaceTest::Tag::IncrementInt32: {
543                using Signature = status_t (ISafeInterfaceTest::*)(int32_t, int32_t*) const;
544                return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
545            }
546            case ISafeInterfaceTest::Tag::IncrementUint32: {
547                using Signature = status_t (ISafeInterfaceTest::*)(uint32_t, uint32_t*) const;
548                return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
549            }
550            case ISafeInterfaceTest::Tag::IncrementInt64: {
551                using Signature = status_t (ISafeInterfaceTest::*)(int64_t, int64_t*) const;
552                return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
553            }
554            case ISafeInterfaceTest::Tag::IncrementUint64: {
555                using Signature = status_t (ISafeInterfaceTest::*)(uint64_t, uint64_t*) const;
556                return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
557            }
558            case ISafeInterfaceTest::Tag::IncrementTwo: {
559                using Signature = status_t (ISafeInterfaceTest::*)(int32_t, int32_t*, int32_t,
560                                                                   int32_t*) const;
561                return callLocal<Signature>(data, reply, &ISafeInterfaceTest::increment);
562            }
563            case ISafeInterfaceTest::Tag::Last:
564                // Should not be possible because of the asserts at the beginning of the method
565                [&]() { FAIL(); }();
566                return UNKNOWN_ERROR;
567        }
568    }
569
570private:
571    static constexpr const char* getLogTag() { return "BnSafeInterfaceTest"; }
572};
573
574class SafeInterfaceTest : public ::testing::Test {
575public:
576    SafeInterfaceTest() : mSafeInterfaceTest(getRemoteService()) {
577        ProcessState::self()->startThreadPool();
578    }
579    ~SafeInterfaceTest() override = default;
580
581protected:
582    sp<ISafeInterfaceTest> mSafeInterfaceTest;
583
584private:
585    static constexpr const char* getLogTag() { return "SafeInterfaceTest"; }
586
587    sp<ISafeInterfaceTest> getRemoteService() {
588#pragma clang diagnostic push
589#pragma clang diagnostic ignored "-Wexit-time-destructors"
590        static std::mutex sMutex;
591        static sp<ISafeInterfaceTest> sService;
592        static sp<IBinder> sDeathToken = new BBinder;
593#pragma clang diagnostic pop
594
595        std::unique_lock<decltype(sMutex)> lock;
596        if (sService == nullptr) {
597            ALOG(LOG_INFO, getLogTag(), "Forking remote process");
598            pid_t forkPid = fork();
599            EXPECT_NE(forkPid, -1);
600
601            const String16 serviceName("SafeInterfaceTest");
602
603            if (forkPid == 0) {
604                ALOG(LOG_INFO, getLogTag(), "Remote process checking in");
605                sp<ISafeInterfaceTest> nativeService = new BnSafeInterfaceTest;
606                defaultServiceManager()->addService(serviceName,
607                                                    IInterface::asBinder(nativeService));
608                ProcessState::self()->startThreadPool();
609                IPCThreadState::self()->joinThreadPool();
610                // We shouldn't get to this point
611                [&]() { FAIL(); }();
612            }
613
614            sp<IBinder> binder = defaultServiceManager()->getService(serviceName);
615            sService = interface_cast<ISafeInterfaceTest>(binder);
616            EXPECT_TRUE(sService != nullptr);
617
618            sService->setDeathToken(sDeathToken);
619        }
620
621        return sService;
622    }
623};
624
625TEST_F(SafeInterfaceTest, TestReturnsNoMemory) {
626    status_t result = mSafeInterfaceTest->returnsNoMemory();
627    ASSERT_EQ(NO_MEMORY, result);
628}
629
630TEST_F(SafeInterfaceTest, TestLogicalNot) {
631    const bool a = true;
632    bool notA = true;
633    status_t result = mSafeInterfaceTest->logicalNot(a, &notA);
634    ASSERT_EQ(NO_ERROR, result);
635    ASSERT_EQ(!a, notA);
636    // Test both since we don't want to accidentally catch a default false somewhere
637    const bool b = false;
638    bool notB = false;
639    result = mSafeInterfaceTest->logicalNot(b, &notB);
640    ASSERT_EQ(NO_ERROR, result);
641    ASSERT_EQ(!b, notB);
642}
643
644TEST_F(SafeInterfaceTest, TestModifyEnum) {
645    const TestEnum a = TestEnum::INITIAL;
646    TestEnum b = TestEnum::INVALID;
647    status_t result = mSafeInterfaceTest->modifyEnum(a, &b);
648    ASSERT_EQ(NO_ERROR, result);
649    ASSERT_EQ(TestEnum::FINAL, b);
650}
651
652TEST_F(SafeInterfaceTest, TestIncrementFlattenable) {
653    const TestFlattenable a{1};
654    TestFlattenable aPlusOne{0};
655    status_t result = mSafeInterfaceTest->increment(a, &aPlusOne);
656    ASSERT_EQ(NO_ERROR, result);
657    ASSERT_EQ(a.value + 1, aPlusOne.value);
658}
659
660TEST_F(SafeInterfaceTest, TestIncrementLightFlattenable) {
661    const TestLightFlattenable a{1};
662    TestLightFlattenable aPlusOne{0};
663    status_t result = mSafeInterfaceTest->increment(a, &aPlusOne);
664    ASSERT_EQ(NO_ERROR, result);
665    ASSERT_EQ(a.value + 1, aPlusOne.value);
666}
667
668TEST_F(SafeInterfaceTest, TestIncrementLightRefBaseFlattenable) {
669    sp<TestLightRefBaseFlattenable> a = new TestLightRefBaseFlattenable{1};
670    sp<TestLightRefBaseFlattenable> aPlusOne;
671    status_t result = mSafeInterfaceTest->increment(a, &aPlusOne);
672    ASSERT_EQ(NO_ERROR, result);
673    ASSERT_NE(nullptr, aPlusOne.get());
674    ASSERT_EQ(a->value + 1, aPlusOne->value);
675}
676
namespace { // Anonymous namespace

// Returns true when both descriptors can be fstat()ed and refer to the same file, i.e. the same
// inode on the same device. Returns false if either fstat() fails.
bool fdsAreEquivalent(int a, int b) {
    struct stat first {};
    struct stat second {};
    const bool bothStatted = fstat(a, &first) == 0 && fstat(b, &second) == 0;
    if (!bothStatted) {
        return false;
    }
    const bool sameDevice = first.st_dev == second.st_dev;
    const bool sameInode = first.st_ino == second.st_ino;
    return sameDevice && sameInode;
}

} // Anonymous namespace
688
689TEST_F(SafeInterfaceTest, TestIncrementNativeHandle) {
690    // Create an fd we can use to send and receive from the remote process
691    base::unique_fd eventFd{eventfd(0 /*initval*/, 0 /*flags*/)};
692    ASSERT_NE(-1, eventFd);
693
694    // Determine the maximum number of fds this process can have open
695    struct rlimit limit {};
696    ASSERT_EQ(0, getrlimit(RLIMIT_NOFILE, &limit));
697    uint32_t maxFds = static_cast<uint32_t>(limit.rlim_cur);
698
699    // Perform this test enough times to rule out fd leaks
700    for (uint32_t iter = 0; iter < (2 * maxFds); ++iter) {
701        native_handle* handle = native_handle_create(1 /*numFds*/, 1 /*numInts*/);
702        ASSERT_NE(nullptr, handle);
703        handle->data[0] = dup(eventFd.get());
704        handle->data[1] = 1;
705
706        // This cannot fail, as it is just the sp<NativeHandle> taking responsibility for closing
707        // the native_handle when it goes out of scope
708        sp<NativeHandle> a = NativeHandle::create(handle, true);
709
710        sp<NativeHandle> aPlusOne;
711        status_t result = mSafeInterfaceTest->increment(a, &aPlusOne);
712        ASSERT_EQ(NO_ERROR, result);
713        ASSERT_TRUE(fdsAreEquivalent(a->handle()->data[0], aPlusOne->handle()->data[0]));
714        ASSERT_EQ(a->handle()->data[1] + 1, aPlusOne->handle()->data[1]);
715    }
716}
717
718TEST_F(SafeInterfaceTest, TestIncrementNoCopyNoMove) {
719    const NoCopyNoMove a{1};
720    NoCopyNoMove aPlusOne{0};
721    status_t result = mSafeInterfaceTest->increment(a, &aPlusOne);
722    ASSERT_EQ(NO_ERROR, result);
723    ASSERT_EQ(a.getValue() + 1, aPlusOne.getValue());
724}
725
726TEST_F(SafeInterfaceTest, TestIncremementParcelableVector) {
727    const std::vector<TestParcelable> a{TestParcelable{1}, TestParcelable{2}};
728    std::vector<TestParcelable> aPlusOne;
729    status_t result = mSafeInterfaceTest->increment(a, &aPlusOne);
730    ASSERT_EQ(a.size(), aPlusOne.size());
731    for (size_t i = 0; i < a.size(); ++i) {
732        ASSERT_EQ(a[i].getValue() + 1, aPlusOne[i].getValue());
733    }
734}
735
736TEST_F(SafeInterfaceTest, TestToUpper) {
737    const String8 str{"Hello, world!"};
738    String8 upperStr;
739    status_t result = mSafeInterfaceTest->toUpper(str, &upperStr);
740    ASSERT_EQ(NO_ERROR, result);
741    ASSERT_TRUE(upperStr == String8{"HELLO, WORLD!"});
742}
743
744TEST_F(SafeInterfaceTest, TestCallMeBack) {
745    class CallbackReceiver : public BnCallback {
746    public:
747        void onCallback(int32_t aPlusOne) override {
748            ALOG(LOG_INFO, "CallbackReceiver", "%s", __PRETTY_FUNCTION__);
749            std::unique_lock<decltype(mMutex)> lock(mMutex);
750            mValue = aPlusOne;
751            mCondition.notify_one();
752        }
753
754        std::optional<int32_t> waitForCallback() {
755            std::unique_lock<decltype(mMutex)> lock(mMutex);
756            bool success =
757                    mCondition.wait_for(lock, 100ms, [&]() { return static_cast<bool>(mValue); });
758            return success ? mValue : std::nullopt;
759        }
760
761    private:
762        std::mutex mMutex;
763        std::condition_variable mCondition;
764        std::optional<int32_t> mValue;
765    };
766
767    sp<CallbackReceiver> receiver = new CallbackReceiver;
768    const int32_t a = 1;
769    mSafeInterfaceTest->callMeBack(receiver, a);
770    auto result = receiver->waitForCallback();
771    ASSERT_TRUE(result);
772    ASSERT_EQ(a + 1, *result);
773}
774
775TEST_F(SafeInterfaceTest, TestIncrementInt32) {
776    const int32_t a = 1;
777    int32_t aPlusOne = 0;
778    status_t result = mSafeInterfaceTest->increment(a, &aPlusOne);
779    ASSERT_EQ(NO_ERROR, result);
780    ASSERT_EQ(a + 1, aPlusOne);
781}
782
783TEST_F(SafeInterfaceTest, TestIncrementUint32) {
784    const uint32_t a = 1;
785    uint32_t aPlusOne = 0;
786    status_t result = mSafeInterfaceTest->increment(a, &aPlusOne);
787    ASSERT_EQ(NO_ERROR, result);
788    ASSERT_EQ(a + 1, aPlusOne);
789}
790
791TEST_F(SafeInterfaceTest, TestIncrementInt64) {
792    const int64_t a = 1;
793    int64_t aPlusOne = 0;
794    status_t result = mSafeInterfaceTest->increment(a, &aPlusOne);
795    ASSERT_EQ(NO_ERROR, result);
796    ASSERT_EQ(a + 1, aPlusOne);
797}
798
799TEST_F(SafeInterfaceTest, TestIncrementUint64) {
800    const uint64_t a = 1;
801    uint64_t aPlusOne = 0;
802    status_t result = mSafeInterfaceTest->increment(a, &aPlusOne);
803    ASSERT_EQ(NO_ERROR, result);
804    ASSERT_EQ(a + 1, aPlusOne);
805}
806
807TEST_F(SafeInterfaceTest, TestIncrementTwo) {
808    const int32_t a = 1;
809    int32_t aPlusOne = 0;
810    const int32_t b = 2;
811    int32_t bPlusOne = 0;
812    status_t result = mSafeInterfaceTest->increment(1, &aPlusOne, 2, &bPlusOne);
813    ASSERT_EQ(NO_ERROR, result);
814    ASSERT_EQ(a + 1, aPlusOne);
815    ASSERT_EQ(b + 1, bPlusOne);
816}
817
818} // namespace tests
819} // namespace android
820