1//===----------------------------------------------------------------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is dual licensed under the MIT and the University of Illinois Open
6// Source Licenses. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9
10#ifndef COUNT_NEW_HPP
11#define COUNT_NEW_HPP
12
13# include <cstdlib>
14# include <cassert>
15# include <new>
16
17#include "test_macros.h"
18
19#if defined(TEST_HAS_SANITIZERS)
20#define DISABLE_NEW_COUNT
21#endif
22
23namespace detail
24{
25   TEST_NORETURN
26   inline void throw_bad_alloc_helper() {
27#ifndef TEST_HAS_NO_EXCEPTIONS
28       throw std::bad_alloc();
29#else
30       std::abort();
31#endif
32   }
33}
34
35class MemCounter
36{
37public:
38    // Make MemCounter super hard to accidentally construct or copy.
39    class MemCounterCtorArg_ {};
40    explicit MemCounter(MemCounterCtorArg_) { reset(); }
41
42private:
43    MemCounter(MemCounter const &);
44    MemCounter & operator=(MemCounter const &);
45
46public:
47    // All checks return true when disable_checking is enabled.
48    static const bool disable_checking;
49
50    // Disallow any allocations from occurring. Useful for testing that
51    // code doesn't perform any allocations.
52    bool disable_allocations;
53
54    // number of allocations to throw after. Default (unsigned)-1. If
55    // throw_after has the default value it will never be decremented.
56    static const unsigned never_throw_value = static_cast<unsigned>(-1);
57    unsigned throw_after;
58
59    int outstanding_new;
60    int new_called;
61    int delete_called;
62    std::size_t last_new_size;
63
64    int outstanding_array_new;
65    int new_array_called;
66    int delete_array_called;
67    std::size_t last_new_array_size;
68
69public:
70    void newCalled(std::size_t s)
71    {
72        assert(disable_allocations == false);
73        assert(s);
74        if (throw_after == 0) {
75            throw_after = never_throw_value;
76            detail::throw_bad_alloc_helper();
77        } else if (throw_after != never_throw_value) {
78            --throw_after;
79        }
80        ++new_called;
81        ++outstanding_new;
82        last_new_size = s;
83    }
84
85    void deleteCalled(void * p)
86    {
87        assert(p);
88        --outstanding_new;
89        ++delete_called;
90    }
91
92    void newArrayCalled(std::size_t s)
93    {
94        assert(disable_allocations == false);
95        assert(s);
96        if (throw_after == 0) {
97            throw_after = never_throw_value;
98            detail::throw_bad_alloc_helper();
99        } else {
100            // don't decrement throw_after here. newCalled will end up doing that.
101        }
102        ++outstanding_array_new;
103        ++new_array_called;
104        last_new_array_size = s;
105    }
106
107    void deleteArrayCalled(void * p)
108    {
109        assert(p);
110        --outstanding_array_new;
111        ++delete_array_called;
112    }
113
114    void disableAllocations()
115    {
116        disable_allocations = true;
117    }
118
119    void enableAllocations()
120    {
121        disable_allocations = false;
122    }
123
124
125    void reset()
126    {
127        disable_allocations = false;
128        throw_after = never_throw_value;
129
130        outstanding_new = 0;
131        new_called = 0;
132        delete_called = 0;
133        last_new_size = 0;
134
135        outstanding_array_new = 0;
136        new_array_called = 0;
137        delete_array_called = 0;
138        last_new_array_size = 0;
139    }
140
141public:
142    bool checkOutstandingNewEq(int n) const
143    {
144        return disable_checking || n == outstanding_new;
145    }
146
147    bool checkOutstandingNewNotEq(int n) const
148    {
149        return disable_checking || n != outstanding_new;
150    }
151
152    bool checkNewCalledEq(int n) const
153    {
154        return disable_checking || n == new_called;
155    }
156
157    bool checkNewCalledNotEq(int n) const
158    {
159        return disable_checking || n != new_called;
160    }
161
162    bool checkNewCalledGreaterThan(int n) const
163    {
164        return disable_checking || new_called > n;
165    }
166
167    bool checkDeleteCalledEq(int n) const
168    {
169        return disable_checking || n == delete_called;
170    }
171
172    bool checkDeleteCalledNotEq(int n) const
173    {
174        return disable_checking || n != delete_called;
175    }
176
177    bool checkLastNewSizeEq(std::size_t n) const
178    {
179        return disable_checking || n == last_new_size;
180    }
181
182    bool checkLastNewSizeNotEq(std::size_t n) const
183    {
184        return disable_checking || n != last_new_size;
185    }
186
187    bool checkOutstandingArrayNewEq(int n) const
188    {
189        return disable_checking || n == outstanding_array_new;
190    }
191
192    bool checkOutstandingArrayNewNotEq(int n) const
193    {
194        return disable_checking || n != outstanding_array_new;
195    }
196
197    bool checkNewArrayCalledEq(int n) const
198    {
199        return disable_checking || n == new_array_called;
200    }
201
202    bool checkNewArrayCalledNotEq(int n) const
203    {
204        return disable_checking || n != new_array_called;
205    }
206
207    bool checkDeleteArrayCalledEq(int n) const
208    {
209        return disable_checking || n == delete_array_called;
210    }
211
212    bool checkDeleteArrayCalledNotEq(int n) const
213    {
214        return disable_checking || n != delete_array_called;
215    }
216
217    bool checkLastNewArraySizeEq(std::size_t n) const
218    {
219        return disable_checking || n == last_new_array_size;
220    }
221
222    bool checkLastNewArraySizeNotEq(std::size_t n) const
223    {
224        return disable_checking || n != last_new_array_size;
225    }
226};
227
228#ifdef DISABLE_NEW_COUNT
229  const bool MemCounter::disable_checking = true;
230#else
231  const bool MemCounter::disable_checking = false;
232#endif
233
234inline MemCounter* getGlobalMemCounter() {
235  static MemCounter counter((MemCounter::MemCounterCtorArg_()));
236  return &counter;
237}
238
239MemCounter &globalMemCounter = *getGlobalMemCounter();
240
241#ifndef DISABLE_NEW_COUNT
242void* operator new(std::size_t s) TEST_THROW_SPEC(std::bad_alloc)
243{
244    getGlobalMemCounter()->newCalled(s);
245    void* ret = std::malloc(s);
246    if (ret == nullptr)
247        detail::throw_bad_alloc_helper();
248    return ret;
249}
250
251void  operator delete(void* p) TEST_NOEXCEPT
252{
253    getGlobalMemCounter()->deleteCalled(p);
254    std::free(p);
255}
256
257
258void* operator new[](std::size_t s) TEST_THROW_SPEC(std::bad_alloc)
259{
260    getGlobalMemCounter()->newArrayCalled(s);
261    return operator new(s);
262}
263
264
265void operator delete[](void* p) TEST_NOEXCEPT
266{
267    getGlobalMemCounter()->deleteArrayCalled(p);
268    operator delete(p);
269}
270
271#endif // DISABLE_NEW_COUNT
272
273
274struct DisableAllocationGuard {
275    explicit DisableAllocationGuard(bool disable = true) : m_disabled(disable)
276    {
277        // Don't re-disable if already disabled.
278        if (globalMemCounter.disable_allocations == true) m_disabled = false;
279        if (m_disabled) globalMemCounter.disableAllocations();
280    }
281
282    void release() {
283        if (m_disabled) globalMemCounter.enableAllocations();
284        m_disabled = false;
285    }
286
287    ~DisableAllocationGuard() {
288        release();
289    }
290
291private:
292    bool m_disabled;
293
294    DisableAllocationGuard(DisableAllocationGuard const&);
295    DisableAllocationGuard& operator=(DisableAllocationGuard const&);
296};
297
298
299struct RequireAllocationGuard {
300    explicit RequireAllocationGuard(std::size_t RequireAtLeast = 1)
301            : m_req_alloc(RequireAtLeast),
302              m_new_count_on_init(globalMemCounter.new_called),
303              m_outstanding_new_on_init(globalMemCounter.outstanding_new),
304              m_exactly(false)
305    {
306    }
307
308    void requireAtLeast(std::size_t N) { m_req_alloc = N; m_exactly = false; }
309    void requireExactly(std::size_t N) { m_req_alloc = N; m_exactly = true; }
310
311    ~RequireAllocationGuard() {
312        assert(globalMemCounter.checkOutstandingNewEq(static_cast<int>(m_outstanding_new_on_init)));
313        std::size_t Expect = m_new_count_on_init + m_req_alloc;
314        assert(globalMemCounter.checkNewCalledEq(static_cast<int>(Expect)) ||
315               (!m_exactly && globalMemCounter.checkNewCalledGreaterThan(static_cast<int>(Expect))));
316    }
317
318private:
319    std::size_t m_req_alloc;
320    const std::size_t m_new_count_on_init;
321    const std::size_t m_outstanding_new_on_init;
322    bool m_exactly;
323    RequireAllocationGuard(RequireAllocationGuard const&);
324    RequireAllocationGuard& operator=(RequireAllocationGuard const&);
325};
326
327#endif /* COUNT_NEW_HPP */
328