1# Copyright 2016, Tresys Technology, LLC
2#
3# This file is part of SETools.
4#
5# SETools is free software: you can redistribute it and/or modify
6# it under the terms of the GNU Lesser General Public License as
7# published by the Free Software Foundation, either version 2.1 of
8# the License, or (at your option) any later version.
9#
10# SETools is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13# GNU Lesser General Public License for more details.
14#
15# You should have received a copy of the GNU Lesser General Public
16# License along with SETools.  If not, see
17# <http://www.gnu.org/licenses/>.
18#
19from collections import namedtuple
20
21from .descriptors import DiffResultDescriptor
22from .difference import Difference, SymbolWrapper, Wrapper
23
24
25modified_default_record = namedtuple("modified_default", ["rule",
26                                                          "added_default",
27                                                          "removed_default",
28                                                          "added_default_range",
29                                                          "removed_default_range"])
30
31
32class DefaultsDifference(Difference):
33
34    """Determine the difference in default_* between two policies."""
35
36    added_defaults = DiffResultDescriptor("diff_defaults")
37    removed_defaults = DiffResultDescriptor("diff_defaults")
38    modified_defaults = DiffResultDescriptor("diff_defaults")
39
40    def diff_defaults(self):
41        """Generate the difference in type defaults between the policies."""
42
43        self.log.info(
44            "Generating default_* differences from {0.left_policy} to {0.right_policy}".
45            format(self))
46
47        self.added_defaults, self.removed_defaults, matched_defaults = self._set_diff(
48            (DefaultWrapper(d) for d in self.left_policy.defaults()),
49            (DefaultWrapper(d) for d in self.right_policy.defaults()))
50
51        self.modified_defaults = []
52
53        for left_default, right_default in matched_defaults:
54            # Criteria for modified defaults
55            # 1. change to default setting
56            # 2. change to default range
57
58            if left_default.default != right_default.default:
59                removed_default = left_default.default
60                added_default = right_default.default
61            else:
62                removed_default = None
63                added_default = None
64
65            try:
66                if left_default.default_range != right_default.default_range:
67                    removed_default_range = left_default.default_range
68                    added_default_range = right_default.default_range
69                else:
70                    removed_default_range = None
71                    added_default_range = None
72            except AttributeError:
73                removed_default_range = None
74                added_default_range = None
75
76            if removed_default or removed_default_range:
77                self.modified_defaults.append(
78                    modified_default_record(left_default,
79                                            added_default,
80                                            removed_default,
81                                            added_default_range,
82                                            removed_default_range))
83
84    #
85    # Internal functions
86    #
87    def _reset_diff(self):
88        """Reset diff results on policy changes."""
89        self.log.debug("Resetting default_* differences")
90        self.added_defaults = None
91        self.removed_defaults = None
92        self.modified_defaults = None
93
94
95class DefaultWrapper(Wrapper):
96
97    """Wrap default_* to allow comparisons."""
98
99    def __init__(self, default):
100        self.origin = default
101        self.ruletype = default.ruletype
102        self.tclass = SymbolWrapper(default.tclass)
103        self.key = hash(default)
104
105    def __hash__(self):
106        return self.key
107
108    def __lt__(self, other):
109        return self.key < other.key
110
111    def __eq__(self, other):
112        return self.ruletype == other.ruletype and \
113               self.tclass == other.tclass
114