1# -*- coding: iso-8859-15 -*-
2"""Immutable integer set type.
3
4Integer set class.
5
6Copyright (C) 2006, Heiko Wundram.
7Released under the MIT license.
8"""
9import six
10
# Version information
# -------------------
# Module-level metadata: author contact, release version, revision counter
# and the date string associated with this release.

__author__ = "Heiko Wundram <me@modelnine.org>"
__version__ = "0.2"
__revision__ = "6"
__date__ = "2006-01-20"
18
19
20# Utility classes
21# ---------------
22
class _Infinity(object):
    """Internal type used to represent infinity values.

    An instance is either minus infinity (``neg`` true) or plus infinity
    (``neg`` false). Minus infinity orders below, and plus infinity above,
    every integer. Two instances are equal exactly when they have the same
    sign. Every comparison returns ``NotImplemented`` when the other operand
    is neither an integer nor another ``_Infinity``.
    """

    __slots__ = ["_neg"]

    def __init__(self,neg):
        """Create an infinity value; a true ``neg`` means minus infinity."""
        self._neg = neg

    def __lt__(self,value):
        if not isinstance(value, _VALID_TYPES):
            return NotImplemented
        # Minus infinity is below everything except minus infinity itself.
        return ( self._neg and
                 not ( isinstance(value,_Infinity) and value._neg ) )

    def __le__(self,value):
        if not isinstance(value, _VALID_TYPES):
            return NotImplemented
        # Minus infinity is <= everything; plus infinity is only <= plus
        # infinity (so that ``x <= x`` agrees with ``x == x``).
        return ( self._neg or
                 ( isinstance(value,_Infinity) and not value._neg ) )

    def __gt__(self,value):
        if not isinstance(value, _VALID_TYPES):
            return NotImplemented
        # Plus infinity is above everything except plus infinity itself.
        return not ( self._neg or
                     ( isinstance(value,_Infinity) and not value._neg ) )

    def __ge__(self,value):
        if not isinstance(value, _VALID_TYPES):
            return NotImplemented
        # Plus infinity is >= everything; minus infinity is only >= minus
        # infinity (so that ``x >= x`` agrees with ``x == x``).
        return ( not self._neg or
                 ( isinstance(value,_Infinity) and value._neg ) )

    def __eq__(self,value):
        if not isinstance(value, _VALID_TYPES):
            return NotImplemented
        return isinstance(value,_Infinity) and self._neg == value._neg

    def __ne__(self,value):
        if not isinstance(value, _VALID_TYPES):
            return NotImplemented
        return not isinstance(value,_Infinity) or self._neg != value._neg

    def __hash__(self):
        # Defining __eq__ removes the inherited __hash__ on Python 3, which
        # would make every tuple containing an infinity unhashable. Hash on
        # the sign so that equal infinities hash equally.
        return hash(("_Infinity", bool(self._neg)))

    def __repr__(self):
        # Rendered as "None" because that is how callers spell an unbounded
        # limit when constructing an integer set.
        return "None"
65
# Operand types the _Infinity comparison methods accept; any other operand
# makes them return NotImplemented.
_VALID_TYPES = six.integer_types + (_Infinity,)



# Constants
# ---------

# The two infinity sentinels. IntSet tests for them by identity
# ("is _MININF" / "is _MAXINF"), so these exact instances must be used.
_MININF = _Infinity(True)
_MAXINF = _Infinity(False)
75
76
77# Integer set class
78# -----------------
79
class IntSet(object):
    """Integer set class with efficient storage in a RLE format of ranges.
    Supports minus and plus infinity in the range.

    Internal representation (all attributes are private):

    * ``_ranges`` -- sorted tuple of ``(start, stop)`` pairs that neither
      overlap nor touch; ``start`` is inclusive and ``stop`` is exclusive.
      ``start`` may be ``_MININF`` and ``stop`` may be ``_MAXINF``.
    * ``_min`` -- smallest value the set may contain, or ``_MININF``.
    * ``_max`` -- one more than the largest value the set may contain, or
      ``_MAXINF``.
    * ``_hash`` -- cached hash of ``_ranges``.
    """

    __slots__ = ["_ranges","_min","_max","_hash"]

    def __init__(self,*args,**kwargs):
        """Initialize an integer set. The constructor accepts an unlimited
        number of arguments that may either be integers or tuples in the
        form of (start,stop) where either start or stop may be a number or
        None to represent maximum/minimum in that direction. The range
        specified by (start,stop) is always inclusive (differing from the
        builtin range operator). Tuples with start > stop are ignored.

        Keyword arguments that can be passed to an integer set are min and
        max, which specify the minimum and maximum number in the set,
        respectively. You can also pass None here to represent minus or plus
        infinity, which is also the default. Values outside of min/max are
        silently clipped away.

        A single IntSet argument (without keyword arguments) copies that set.

        Raises ValueError for unknown keyword arguments, for keyword
        arguments passed to the copy constructor, for min > max and for
        tuples that do not have exactly two items. Raises TypeError for
        arguments, limits or tuple items of an unsupported type.
        """

        # Special case copy constructor.
        if len(args) == 1 and isinstance(args[0],IntSet):
            if kwargs:
                raise ValueError("No keyword arguments for copy constructor.")
            source = args[0]
            self._min = source._min
            self._max = source._max
            self._ranges = source._ranges
            self._hash = source._hash
            return

        # Initialize set.
        self._ranges = []

        # Process keyword arguments.
        self._min = kwargs.pop("min",_MININF)
        self._max = kwargs.pop("max",_MAXINF)
        if self._min is None:
            self._min = _MININF
        if self._max is None:
            self._max = _MAXINF

        # Check keyword arguments.
        if kwargs:
            raise ValueError("Invalid keyword argument.")
        if not ( isinstance(self._min, six.integer_types) or self._min is _MININF ):
            raise TypeError("Invalid type of min argument.")
        if not ( isinstance(self._max, six.integer_types) or self._max is _MAXINF ):
            raise TypeError("Invalid type of max argument.")
        if ( self._min is not _MININF and self._max is not _MAXINF and
             self._min > self._max ):
            raise ValueError("Minimum is not smaller than maximum.")
        # From here on the maximum is stored exclusive, like range stops.
        if isinstance(self._max, six.integer_types):
            self._max += 1

        # Process arguments.
        for arg in args:
            if isinstance(arg, six.integer_types):
                start, stop = arg, arg+1
            elif isinstance(arg,tuple):
                if len(arg) != 2:
                    raise ValueError("Invalid tuple, must be (start,stop).")

                # Process argument.
                start, stop = arg
                if start is None:
                    start = self._min
                if stop is None:
                    stop = self._max

                # Check arguments.
                if not ( isinstance(start, six.integer_types) or start is _MININF ):
                    raise TypeError("Invalid type of tuple start.")
                if not ( isinstance(stop, six.integer_types) or stop is _MAXINF ):
                    raise TypeError("Invalid type of tuple stop.")
                if ( start is not _MININF and stop is not _MAXINF and
                     start > stop ):
                    continue
                # Convert the inclusive stop into an exclusive one. A stop
                # that was defaulted from an integer maximum overshoots by
                # one here; the clipping below brings it back.
                if isinstance(stop, six.integer_types):
                    stop += 1
            else:
                raise TypeError("Invalid argument.")

            # Clip the half-open range [start,stop) to [_min,_max). A range
            # that starts at or after the exclusive maximum, or that ends at
            # or before the minimum, has nothing left and must be dropped;
            # otherwise an empty range would end up in the set.
            if start >= self._max:
                continue
            elif start < self._min:
                start = self._min
            if stop <= self._min:
                continue
            elif stop > self._max:
                stop = self._max
            self._ranges.append((start,stop))

        # Normalize set.
        self._normalize()

    # Utility functions for set operations
    # ------------------------------------

    def _iterranges(self,r1,r2,minval=_MININF,maxval=_MAXINF):
        """Walk the boundaries of two normalized range sequences in ascending
        order and yield ``(states, (start, stop))`` for every non-empty
        segment between minval and maxval.

        states is a dict with the keys 'r1' and 'r2' telling whether the
        segment lies inside a range of r1 and/or r2. The very same dict
        object is yielded each time and is updated as iteration proceeds, so
        it has to be evaluated before the generator is advanced.
        """

        curval = minval
        curstates = {"r1":False,"r2":False}
        # Each range contributes two boundaries; for a boundary index k,
        # k>>1 selects the range and k&1 selects start (0) or stop (1).
        imax, jmax = 2*len(r1), 2*len(r2)
        i, j = 0, 0
        while i < imax or j < jmax:
            # Take the next boundary from r1 only when it is strictly smaller
            # than the next one of r2, or when r2 is exhausted.
            if i < imax and ( ( j < jmax and
                                r1[i>>1][i&1] < r2[j>>1][j&1] ) or
                              j == jmax ):
                # Passing a start boundary enters the range, a stop leaves it.
                cur_r, newname, newstate = r1[i>>1][i&1], "r1", not (i&1)
                i += 1
            else:
                cur_r, newname, newstate = r2[j>>1][j&1], "r2", not (j&1)
                j += 1
            if curval < cur_r:
                if cur_r > maxval:
                    break
                yield curstates, (curval,cur_r)
                curval = cur_r
            curstates[newname] = newstate
        # Trailing segment up to the upper limit.
        if curval < maxval:
            yield curstates, (curval,maxval)

    def _normalize(self):
        """Sort the collected ranges, merge ranges that overlap or touch,
        freeze the result into a tuple and cache its hash."""

        merged = []
        for start, stop in sorted(self._ranges):
            # Ranges are half-open, so a range starting exactly at the
            # previous stop is adjacent and is merged as well. This keeps
            # the stored form (and with it the hash) unique per set content.
            if merged and not merged[-1][1] < start:
                if merged[-1][1] < stop:
                    merged[-1] = (merged[-1][0],stop)
            else:
                merged.append((start,stop))
        self._ranges = tuple(merged)
        self._hash = hash(self._ranges)

    def __coerce__(self,other):
        """Return ``(self, other_as_set)``, building a set of the same class
        from an integer, a (start,stop) tuple or a list of constructor
        arguments. Returns NotImplemented when other cannot be converted."""

        if isinstance(other,IntSet):
            return self, other
        elif isinstance(other, six.integer_types + (tuple,)):
            try:
                return self, self.__class__(other)
            except TypeError:
                # Catch a type error, in that case the structure specified by
                # other is something we can't coerce, return NotImplemented.
                # ValueErrors are not caught, they signal that the data was
                # invalid for the constructor. This is appropriate to signal
                # as a ValueError to the caller.
                return NotImplemented
        elif isinstance(other,list):
            try:
                return self, self.__class__(*other)
            except TypeError:
                # See above.
                return NotImplemented
        return NotImplemented

    # Set function definitions
    # ------------------------

    def _make_function(name,kind,doc,pall,pany=None,fallback=NotImplemented):
        """Makes a function to match two ranges. Accepts two kinds: either
        'set', which defines a function which returns a set with all ranges
        matching pall (pany is ignored), or 'bool', which returns True if pall
        matches for all ranges and pany matches for any one range. doc is the
        docstring to give this function. pany may be None to ignore the any
        match. fallback is the value the function returns when the other
        operand cannot be coerced to an integer set.

        The predicates get a dict with two keys, 'r1', 'r2', which denote
        whether the current range is present in range1 (self) and/or range2
        (other) or none of the two, respectively."""

        if kind == "set":
            def f(self,other):
                coerced = self.__coerce__(other)
                if coerced is NotImplemented:
                    return fallback
                other = coerced[1]
                newset = self.__class__.__new__(self.__class__)
                newset._min = min(self._min,other._min)
                newset._max = max(self._max,other._max)
                ranges = []
                for states, (start,stop) in \
                        self._iterranges(self._ranges,other._ranges,
                                         newset._min,newset._max):
                    if pall(states):
                        # Glue directly adjacent segments together so the
                        # result stays normalized.
                        if ranges and ranges[-1][1] == start:
                            ranges[-1] = (ranges[-1][0],stop)
                        else:
                            ranges.append((start,stop))
                newset._ranges = tuple(ranges)
                # The cached hash has to describe the new set's own ranges.
                newset._hash = hash(newset._ranges)
                return newset
        elif kind == "bool":
            def f(self,other):
                coerced = self.__coerce__(other)
                if coerced is NotImplemented:
                    return fallback
                other = coerced[1]
                _min = min(self._min,other._min)
                _max = max(self._max,other._max)
                # Without an "any" predicate only the "all" predicate counts.
                found = not pany
                for states, (start,stop) in \
                        self._iterranges(self._ranges,other._ranges,_min,_max):
                    if not pall(states):
                        return False
                    found = found or pany(states)
                return found
        else:
            raise ValueError("Invalid type of function to create.")
        # __name__/__doc__ are valid on Python 2 and 3 alike, whereas the
        # func_name/func_doc spellings only exist on Python 2.
        try:
            f.__name__ = name
        except TypeError:
            pass
        f.__doc__ = doc
        return f

    # Intersection.
    __and__ = _make_function("__and__","set",
                             "Intersection of two sets as a new set.",
                             lambda s: s["r1"] and s["r2"])
    __rand__ = _make_function("__rand__","set",
                              "Intersection of two sets as a new set.",
                              lambda s: s["r1"] and s["r2"])
    intersection = _make_function("intersection","set",
                                  "Intersection of two sets as a new set.",
                                  lambda s: s["r1"] and s["r2"])

    # Union.
    __or__ = _make_function("__or__","set",
                            "Union of two sets as a new set.",
                            lambda s: s["r1"] or s["r2"])
    __ror__ = _make_function("__ror__","set",
                             "Union of two sets as a new set.",
                             lambda s: s["r1"] or s["r2"])
    union = _make_function("union","set",
                           "Union of two sets as a new set.",
                           lambda s: s["r1"] or s["r2"])

    # Difference.
    __sub__ = _make_function("__sub__","set",
                             "Difference of two sets as a new set.",
                             lambda s: s["r1"] and not s["r2"])
    __rsub__ = _make_function("__rsub__","set",
                              "Difference of two sets as a new set.",
                              lambda s: s["r2"] and not s["r1"])
    difference = _make_function("difference","set",
                                "Difference of two sets as a new set.",
                                lambda s: s["r1"] and not s["r2"])

    # Symmetric difference.
    __xor__ = _make_function("__xor__","set",
                             "Symmetric difference of two sets as a new set.",
                             lambda s: s["r1"] ^ s["r2"])
    __rxor__ = _make_function("__rxor__","set",
                              "Symmetric difference of two sets as a new set.",
                              lambda s: s["r1"] ^ s["r2"])
    symmetric_difference = _make_function("symmetric_difference","set",
                                          "Symmetric difference of two sets as a new set.",
                                          lambda s: s["r1"] ^ s["r2"])

    # Containership testing.
    # The "in" operator converts the result to a bool and NotImplemented is
    # not a usable truth value, so something that cannot be coerced to an
    # integer set is reported as not contained.
    __contains__ = _make_function("__contains__","bool",
                                  "Returns true if self is superset of other.",
                                  lambda s: s["r1"] or not s["r2"],
                                  fallback=False)
    issubset = _make_function("issubset","bool",
                              "Returns true if self is subset of other.",
                              lambda s: s["r2"] or not s["r1"])
    istruesubset = _make_function("istruesubset","bool",
                                  "Returns true if self is true subset of other.",
                                  lambda s: s["r2"] or not s["r1"],
                                  lambda s: s["r2"] and not s["r1"])
    issuperset = _make_function("issuperset","bool",
                                "Returns true if self is superset of other.",
                                lambda s: s["r1"] or not s["r2"])
    istruesuperset = _make_function("istruesuperset","bool",
                                    "Returns true if self is true superset of other.",
                                    lambda s: s["r1"] or not s["r2"],
                                    lambda s: s["r1"] and not s["r2"])
    overlaps = _make_function("overlaps","bool",
                              "Returns true if self overlaps with other.",
                              lambda s: True,
                              lambda s: s["r1"] and s["r2"])

    # Comparison.
    __eq__ = _make_function("__eq__","bool",
                            "Returns true if self is equal to other.",
                            lambda s: not ( s["r1"] ^ s["r2"] ))
    __ne__ = _make_function("__ne__","bool",
                            "Returns true if self is different to other.",
                            lambda s: True,
                            lambda s: s["r1"] ^ s["r2"])

    # Clean up namespace.
    del _make_function

    # Define other functions.
    def inverse(self):
        """Inverse of set as a new set. The new set has the same minimum and
        maximum and contains exactly the values in between that are not part
        of this set."""

        newset = self.__class__.__new__(self.__class__)
        newset._min = self._min
        newset._max = self._max
        gaps = []
        laststop = self._min
        for start, stop in self._ranges:
            if laststop < start:
                gaps.append((laststop,start))
            # Always move past the current range, also when it begins right
            # at the lower limit and leaves no gap in front of it.
            laststop = stop
        if laststop < self._max:
            gaps.append((laststop,self._max))
        newset._ranges = tuple(gaps)
        newset._hash = hash(newset._ranges)
        return newset

    __invert__ = inverse

    # Hashing
    # -------

    def __hash__(self):
        """Returns a hash value representing this integer set. As the set is
        always stored normalized, the hash value is guaranteed to match for
        matching ranges."""

        return self._hash

    # Iterating
    # ---------

    def __len__(self):
        """Get length of this integer set. In case the length is larger than
        2**31 (including infinitely sized integer sets), it raises an
        OverflowError. This is due to len() restricting the size to
        0 <= len < 2**31."""

        if not self._ranges:
            return 0
        # Ranges are sorted, so only the outermost bounds can be infinite.
        if self._ranges[0][0] is _MININF or self._ranges[-1][1] is _MAXINF:
            raise OverflowError("Infinitely sized integer set.")
        rlen = sum(stop-start for start, stop in self._ranges)
        if rlen >= 2**31:
            raise OverflowError("Integer set bigger than 2**31.")
        return rlen

    def len(self):
        """Returns the length of this integer set as an integer. In case the
        length is infinite, returns -1. This function exists because of a
        limitation of the builtin len() function which expects values in
        the range 0 <= len < 2**31. Use this function in case your integer
        set might be larger."""

        if not self._ranges:
            return 0
        if self._ranges[0][0] is _MININF or self._ranges[-1][1] is _MAXINF:
            return -1
        return sum(stop-start for start, stop in self._ranges)

    def __nonzero__(self):
        """Returns true if this integer set contains at least one item."""

        return bool(self._ranges)

    # Python 3 looks up __bool__; without it truth testing would fall back to
    # __len__, which raises OverflowError for infinite sets.
    __bool__ = __nonzero__

    def __iter__(self):
        """Iterate over all values in this integer set. Iteration always starts
        by iterating from lowest to highest over the ranges that are bounded.
        After processing these, all ranges that are unbounded (maximum 2) are
        yielded intermixed. Iteration never ends if there is an unbounded
        range."""

        # Each entry is a mutable [next value, step] pair for one unbounded
        # direction.
        ubranges = []
        for start, stop in self._ranges:
            if start is _MININF:
                if stop is _MAXINF:
                    # Everything: count up from 0 and down from -1.
                    ubranges.extend(([0,1],[-1,-1]))
                else:
                    ubranges.append([stop-1,-1])
            elif stop is _MAXINF:
                ubranges.append([start,1])
            else:
                # six.moves.range is lazy on Python 2 and 3 alike.
                for val in six.moves.range(start,stop):
                    yield val
        if ubranges:
            while True:
                for ubrange in ubranges:
                    yield ubrange[0]
                    ubrange[0] += ubrange[1]

    # Printing
    # --------

    def __repr__(self):
        """Return a representation of this integer set. The representation is
        executable to get an equal integer set."""

        rv = []
        for start, stop in self._ranges:
            if ( isinstance(start, six.integer_types) and isinstance(stop, six.integer_types)
                 and stop-start == 1 ):
                rv.append("%r" % start)
            elif isinstance(stop, six.integer_types):
                # Stored stops are exclusive, constructor stops inclusive.
                rv.append("(%r,%r)" % (start,stop-1))
            else:
                rv.append("(%r,%r)" % (start,stop))
        if self._min is not _MININF:
            rv.append("min=%r" % self._min)
        if self._max is not _MAXINF:
            # The maximum is stored exclusive as well; print it the way the
            # constructor expects it so that the output round-trips.
            rv.append("max=%r" % (self._max-1))
        return "%s(%s)" % (self.__class__.__name__,",".join(rv))
490
if __name__ == "__main__":
    # Small demonstration of the public interface; it only prints results.

    # Sample sets: two ranges, a single range, and a bounded variant.
    sample = IntSet((10,20),30)
    single = IntSet((10,20))
    bounded = IntSet((10,20),30,(15,19),min=0,max=40)

    # Representation and the set operators with coerced right operands.
    print(sample)
    print(sample&110)
    print(sample|110)
    print(sample^(15,25))
    print(sample-12)

    # Membership and subset tests.
    print(12 in sample)
    print(sample.issubset(sample))
    print(single.issubset(sample))
    print(sample.istruesubset(sample))
    print(single.istruesubset(sample))

    # Iteration over a finite set.
    for item in sample:
        print(item)

    # Inversion, comparison, hashing and size.
    print(sample.inverse())
    print(sample == bounded)
    print(sample == single)
    print(sample != single)
    print(hash(sample))
    print(hash(bounded))
    print(len(sample))
    print(sample.len())
516