1# Authors: Karl MacMillan <kmacmillan@mentalrootkit.com>
2#
3# Copyright (C) 2006 Red Hat
4# see file 'COPYING' for use and warranty information
5#
6# This program is free software; you can redistribute it and/or
7# modify it under the terms of the GNU General Public License as
8# published by the Free Software Foundation; version 2 only
9#
10# This program is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13# GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License
16# along with this program; if not, write to the Free Software
17# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
18#
19
20import unittest
21import sepolgen.refpolicy as refpolicy
22import selinux
23
class TestIdSet(unittest.TestCase):
    """Tests for refpolicy.IdSet rendering."""

    def test_set_to_str(self):
        """Multi-member sets render braced; a single member renders bare."""
        # Member order inside the braces is not fixed, so compare the
        # space-separated tokens as sorted lists.
        several = refpolicy.IdSet(["read", "write", "getattr"])
        got = sorted(several.to_space_str().split(' '))
        want = sorted("{ read write getattr }".split(' '))
        self.assertEqual(got, want)

        lone = refpolicy.IdSet()
        lone.add("read")
        self.assertEqual(lone.to_space_str(), "read")
35
class TestSecurityContext(unittest.TestCase):
    """Tests for parsing, formatting and comparing refpolicy.SecurityContext."""

    def test_init(self):
        """Both the empty constructor and the string constructor work."""
        # The no-argument form must simply construct without error.
        refpolicy.SecurityContext()

        # Passing a context string parses it into its fields.
        sc = refpolicy.SecurityContext("user_u:object_r:foo_t")
        self.assertEqual(sc.user, "user_u")
        self.assertEqual(sc.role, "object_r")
        self.assertEqual(sc.type, "foo_t")
        self.assertIsNone(sc.level)

    def test_from_string(self):
        """from_string() splits user:role:type[:level] and rejects junk."""
        # Three-field context: no level recorded.
        context = "user_u:object_r:foo_t"
        sc = refpolicy.SecurityContext()
        sc.from_string(context)
        self.assertEqual(sc.user, "user_u")
        self.assertEqual(sc.role, "object_r")
        self.assertEqual(sc.type, "foo_t")
        self.assertIsNone(sc.level)
        # The default rendering depends on whether the host has MLS enabled.
        if selinux.is_selinux_mls_enabled():
            self.assertEqual(str(sc), context + ":s0")
        else:
            self.assertEqual(str(sc), context)
        # An explicit default_level is appended when no level was parsed.
        self.assertEqual(sc.to_string(default_level="s1"), context + ":s1")

        # Full MLS context: everything after the type is kept as the level.
        context = "user_u:object_r:foo_t:s0-s0:c0-c255"
        sc = refpolicy.SecurityContext()
        sc.from_string(context)
        self.assertEqual(sc.user, "user_u")
        self.assertEqual(sc.role, "object_r")
        self.assertEqual(sc.type, "foo_t")
        self.assertEqual(sc.level, "s0-s0:c0-c255")
        self.assertEqual(str(sc), context)
        self.assertEqual(sc.to_string(), context)

        # A string without the required fields is rejected.
        sc = refpolicy.SecurityContext()
        self.assertRaises(ValueError, sc.from_string, "abc")

    def test_equal(self):
        """Contexts compare equal only when every field matches."""
        sc1 = refpolicy.SecurityContext("user_u:object_r:foo_t")
        sc2 = refpolicy.SecurityContext("user_u:object_r:foo_t")
        sc3 = refpolicy.SecurityContext("user_u:object_r:foo_t:s0")
        sc4 = refpolicy.SecurityContext("user_u:object_r:bar_t")

        self.assertEqual(sc1, sc2)
        self.assertNotEqual(sc1, sc3)  # differs only by level
        self.assertNotEqual(sc1, sc4)  # differs only by type
77
class TestObjecClass(unittest.TestCase):
    """Tests for refpolicy.ObjectClass construction."""

    # NOTE: the class name's spelling is kept as-is so existing test IDs
    # and any references to it keep working.

    def test_init(self):
        """The name is stored and perms starts out as a set."""
        cls = refpolicy.ObjectClass(name="file")
        self.assertEqual(cls.name, "file")
        self.assertIsInstance(cls.perms, set)
83
class TestAVRule(unittest.TestCase):
    """Tests for refpolicy.AVRule construction and rendering."""

    @staticmethod
    def _sort_set_members(rule):
        """Return *rule* with the members of every ``{ ... }`` group sorted.

        Set members may be rendered in any order.  Sorting within each
        braced group makes comparisons order-independent while still
        checking that every name appears in the correct group.
        """
        import re

        return re.sub(
            r"\{ ([^}]*) \}",
            lambda m: "{ " + " ".join(sorted(m.group(1).split())) + " }",
            rule,
        )

    def test_init(self):
        """A fresh rule is an allow rule with empty set-valued fields."""
        a = refpolicy.AVRule()
        self.assertEqual(a.rule_type, a.ALLOW)
        self.assertIsInstance(a.src_types, set)
        self.assertIsInstance(a.tgt_types, set)
        self.assertIsInstance(a.obj_classes, set)
        self.assertIsInstance(a.perms, set)

    def test_to_string(self):
        """Single members render bare; multiple members render braced."""
        a = refpolicy.AVRule()
        a.src_types.add("foo_t")
        a.tgt_types.add("bar_t")
        a.obj_classes.add("file")
        a.perms.add("read")
        self.assertEqual(a.to_string(), "allow foo_t bar_t:file read;")

        a.rule_type = a.DONTAUDIT
        a.src_types.add("user_t")
        a.tgt_types.add("user_home_t")
        a.obj_classes.add("lnk_file")
        a.perms.add("write")
        # Normalise member order inside each group.  Unlike comparing a bag
        # of whitespace tokens, this still fails if a name lands in the
        # wrong group (e.g. user_t rendered among the target types).
        rendered = self._sort_set_members(a.to_string())
        expected = self._sort_set_members(
            "dontaudit { foo_t user_t } { user_home_t bar_t }:"
            "{ lnk_file file } { read write };"
        )
        self.assertEqual(rendered, expected)
112
class TestTypeRule(unittest.TestCase):
    """Tests for refpolicy.TypeRule construction and rendering."""

    def test_init(self):
        """A fresh rule is a type_transition with empty fields."""
        rule = refpolicy.TypeRule()
        self.assertEqual(rule.rule_type, rule.TYPE_TRANSITION)
        for field in (rule.src_types, rule.tgt_types, rule.obj_classes):
            self.assertTrue(isinstance(field, set))
        self.assertEqual(rule.dest_type, "")

    def test_to_string(self):
        """A fully populated rule renders as a single policy statement."""
        rule = refpolicy.TypeRule()
        rule.dest_type = "bar_t"
        rule.src_types.add("foo_t")
        rule.tgt_types.add("bar_exec_t")
        rule.obj_classes.add("process")
        expected = "type_transition foo_t bar_exec_t:process bar_t;"
        self.assertEqual(rule.to_string(), expected)
129
130
class TestParseNode(unittest.TestCase):
    """Tests for walking a tree of refpolicy nodes."""

    def test_walktree(self):
        """walktree() yields the root, then each subtree, depth first."""
        # Construct a small tree:
        #   Headers
        #   +-- Interface "foo"
        #   |   +-- AVRule  allow foo_t bar_t:file read
        #   |   +-- InterfaceCall allow_foobar(foo_t, { file dir })
        #   +-- Interface "bar"
        #       +-- AVRule  dontaudit user_t user_home_t:lnk_file write
        h = refpolicy.Headers()

        allow_rule = refpolicy.AVRule()
        allow_rule.src_types.add("foo_t")
        allow_rule.tgt_types.add("bar_t")
        allow_rule.obj_classes.add("file")
        allow_rule.perms.add("read")

        ifcall = refpolicy.InterfaceCall(ifname="allow_foobar")
        ifcall.args.append("foo_t")
        ifcall.args.append("{ file dir }")

        foo = refpolicy.Interface(name="foo")
        foo.children.append(allow_rule)
        foo.children.append(ifcall)
        h.children.append(foo)

        dontaudit_rule = refpolicy.AVRule()
        dontaudit_rule.rule_type = dontaudit_rule.DONTAUDIT
        dontaudit_rule.src_types.add("user_t")
        dontaudit_rule.tgt_types.add("user_home_t")
        dontaudit_rule.obj_classes.add("lnk_file")
        dontaudit_rule.perms.add("write")

        bar = refpolicy.Interface(name="bar")
        bar.children.append(dontaudit_rule)
        h.children.append(bar)

        # Pre-order depth-first walk: each node comes before its children,
        # and siblings are visited in the order they were appended.
        # Compare identities so no node __eq__ can mask a wrong visit.
        expected = [h, foo, allow_rule, ifcall, bar, dontaudit_rule]
        walked = list(refpolicy.walktree(h))
        self.assertEqual([id(n) for n in walked], [id(n) for n in expected])

        # The interface filter sees both interfaces in the tree.
        self.assertEqual(len(list(h.interfaces())), 2)
159
class TestHeaders(unittest.TestCase):
    """Tests for iterating over a refpolicy.Headers container."""

    def test_iter(self):
        """Plain iteration sees every child; interfaces() only interfaces."""
        headers = refpolicy.Headers()
        for child in (refpolicy.Interface(name="foo"),
                      refpolicy.Interface(name="bar"),
                      refpolicy.ClassMap("file", "read write")):
            headers.children.append(child)

        total = sum(1 for _ in headers)
        self.assertEqual(total, 3)

        only_interfaces = sum(1 for _ in headers.interfaces())
        self.assertEqual(only_interfaces, 2)
175
176