1# Authors: Karl MacMillan <kmacmillan@mentalrootkit.com>
2#
3# Copyright (C) 2006 Red Hat
4# see file 'COPYING' for use and warranty information
5#
6# This program is free software; you can redistribute it and/or
7# modify it under the terms of the GNU General Public License as
8# published by the Free Software Foundation; version 2 only
9#
10# This program is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13# GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License
16# along with this program; if not, write to the Free Software
17# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
18#
19
20import unittest
21import sepolgen.access as access
22import sepolgen.interfaces as interfaces
23import sepolgen.policygen as policygen
24import sepolgen.refparser as refparser
25import sepolgen.refpolicy as refpolicy
26
class TestParam(unittest.TestCase):
    """Checks the basic attribute behavior of interfaces.Param."""

    def test(self):
        """Name "$1" is accepted; "$N" is rejected with ValueError."""
        param = interfaces.Param()
        param.name = "$1"
        self.assertEqual(param.name, "$1")

        # A placeholder without a numeric index must be refused.
        with self.assertRaises(ValueError):
            param.set_name("$N")

        # The failed set_name call must leave the earlier values in place.
        self.assertEqual(param.num, 1)
        self.assertEqual(param.type, refpolicy.SRC_TYPE)
35
class TestAVExtractPerms(unittest.TestCase):
    """Exercises interfaces.av_extract_params on a single access vector.

    The same AccessVector is mutated step by step, so the order of the
    steps below matters.
    """

    def _check_param(self, params, name, expected_type, obj_class):
        """Assert that params[name] has the given name, type and class set."""
        param = params[name]
        self.assertEqual(param.name, name)
        self.assertEqual(param.type, expected_type)
        self.assertEqual(param.obj_classes, refpolicy.IdSet([obj_class]))

    def test(self):
        av = access.AccessVector(['foo', 'bar', 'file', 'read'])

        # No "$N" placeholders at all: nothing is extracted.
        params = { }
        status = interfaces.av_extract_params(av, params)
        self.assertEqual(status, 0)
        self.assertEqual(params, { })

        # Placeholder as the source type.
        av.src_type = "$1"
        status = interfaces.av_extract_params(av, params)
        self.assertEqual(status, 0)
        self._check_param(params, "$1", refpolicy.SRC_TYPE, "file")

        # Same placeholder as source and target with class "process":
        # the expected return value is still 0.
        params = { }
        av.tgt_type = "$1"
        av.obj_class = "process"
        status = interfaces.av_extract_params(av, params)
        self.assertEqual(status, 0)
        self._check_param(params, "$1", refpolicy.SRC_TYPE, "process")

        # Same placeholder as source and target with class "dir":
        # the expected return value is 1, and the param is still recorded.
        params = { }
        av.tgt_type = "$1"
        av.obj_class = "dir"
        status = interfaces.av_extract_params(av, params)
        self.assertEqual(status, 1)
        self._check_param(params, "$1", refpolicy.SRC_TYPE, "dir")

        # Placeholder only as the target type; params is deliberately not
        # reset here, matching the step above.
        av.src_type = "bar"
        av.tgt_type = "$2"
        av.obj_class = "dir"
        status = interfaces.av_extract_params(av, params)
        self.assertEqual(status, 0)
        self._check_param(params, "$2", refpolicy.TGT_TYPE, "dir")
81
# Interface source text defining three interfaces (files_search_usr,
# files_list_usr and files_exec_usr_files). It mixes plain allow rules with
# set-valued rules, typeattribute, if/else blocks, optional_policy,
# tunable_policy and calls to other interfaces. Parsed by
# TestInterfaceSet.test_export.
interface_example = """
interface(`files_search_usr',`
	gen_require(`
		type usr_t;
	')

	allow $1 usr_t:dir search;
        allow { domain $1 } { usr_t usr_home_t }:{ file dir } { read write getattr };
        typeattribute $1 file_type;

        if (foo) {
           allow $1 foo:bar baz;
        }

        if (bar) {
           allow $1 foo:bar baz;
        } else {
           allow $1 foo:bar baz;
        }
')

interface(`files_list_usr',`
	gen_require(`
		type usr_t;
	')

	allow $1 usr_t:dir { read getattr };

        optional_policy(`
            search_usr($1)
        ')

        tunable_policy(`foo',`
            whatever($1)
        ')

')

interface(`files_exec_usr_files',`
	gen_require(`
		type usr_t;
	')

	allow $1 usr_t:dir read;
	allow $1 usr_t:lnk_file { read getattr };
	can_exec($1,usr_t)
        can_foo($1)

')
"""
132
# Minimal interface source: a single interface `foo` with one parameter ($1)
# and two allow rules against usr_t. Parsed by TestInterfaceSet.test_simple.
simple_interface = """
interface(`foo',`
   gen_require(`
       type usr_t;
   ')
   allow $1 usr_t:dir { create add_name };
   allow $1 usr_t:file { read write };
')
"""
142
# Interface source with nested interface calls: `map` calls foo($2), and
# `hard_map` calls map($1, $2), map($2, $3) and foo($2). Parsed by
# TestInterfaceSet.test_expansion, which expects the called interfaces'
# access to appear in the caller with the arguments substituted.
test_expansion = """
interface(`foo',`
   gen_require(`
       type usr_t;
   ')
   allow $1 usr_t:dir { create add_name };
   allow $1 usr_t:file { read write };
')

interface(`map', `
   gen_require(`
       type bar_t;
   ')
   allow $1 bar_t:file read;
   allow $2 bar_t:file write;

   foo($2)
')

interface(`hard_map', `
   gen_require(`
      type baz_t;
   ')
   allow $1 baz_t:file getattr;
   allow $2 baz_t:file read;
   allow $3 baz_t:file write;

   map($1, $2)
   map($2, $3)

   # This should have no effect
   foo($2)
')
"""
177
def compare_avsets(l, avs_b):
    """Return True if the access vectors in *l* match those in *avs_b*.

    *l* is a list of access vectors in list form; it is loaded into a
    fresh access.AccessVectorSet via from_list(). *avs_b* is any iterable
    of access vectors. Both sides are sorted and compared element by
    element.
    """
    expected_set = access.AccessVectorSet()
    expected_set.from_list(l)

    expected = sorted(expected_set)
    actual = sorted(avs_b)

    # Different sizes can never match; zip() below would hide the extras.
    if len(expected) != len(actual):
        return False

    return not any(want != got for want, got in zip(expected, actual))
197
198
class TestInterfaceSet(unittest.TestCase):
    """Tests for interfaces.InterfaceSet built from parsed interface source."""

    def _load(self, source):
        """Parse *source* and return an InterfaceSet populated from it."""
        headers = refparser.parse(source)
        ifset = interfaces.InterfaceSet()
        ifset.add_headers(headers)
        return ifset

    def test_simple(self):
        """A single interface yields its two access vectors and one param."""
        i = self._load(simple_interface)

        self.assertEqual(len(i.interfaces), 1)
        for key, interface in i.interfaces.items():
            self.assertEqual(key, interface.name)
            self.assertEqual(key, "foo")
            self.assertEqual(len(interface.access), 2)

            # Check the access vectors
            comp_avs = [["$1", "usr_t", "dir", "create", "add_name"],
                        ["$1", "usr_t", "file", "read", "write"]]
            self.assertTrue(compare_avsets(comp_avs, interface.access))

            # Check the params
            self.assertEqual(len(interface.params), 1)
            for param in interface.params.values():
                self.assertEqual(param.type, refpolicy.SRC_TYPE)
                self.assertEqual(param.name, "$1")
                self.assertEqual(param.num, 1)
                self.assertEqual(param.required, True)

    def test_expansion(self):
        """Nested interface calls are expanded with arguments substituted."""
        i = self._load(test_expansion)

        # Expected access per interface name.
        expected = {
            "foo": [["$1", "usr_t", "dir", "create", "add_name"],
                    ["$1", "usr_t", "file", "read", "write"]],

            # Own rules plus foo($2).
            "map": [["$2", "usr_t", "dir", "create", "add_name"],
                    ["$2", "usr_t", "file", "read", "write"],
                    ["$1", "bar_t", "file", "read"],
                    ["$2", "bar_t", "file", "write"]],

            "hard_map": [["$1", "baz_t", "file", "getattr"],
                         ["$2", "baz_t", "file", "read"],
                         ["$3", "baz_t", "file", "write"],

                         # map($1, $2)
                         ["$2", "usr_t", "dir", "create", "add_name"],
                         ["$2", "usr_t", "file", "read", "write"],
                         ["$1", "bar_t", "file", "read"],
                         ["$2", "bar_t", "file", "write"],

                         # map($2, $3)
                         ["$3", "usr_t", "dir", "create", "add_name"],
                         ["$3", "usr_t", "file", "read", "write"],
                         ["$2", "bar_t", "file", "read"],
                         ["$3", "bar_t", "file", "write"]],
        }

        self.assertEqual(len(i.interfaces), 3)
        for key, interface in i.interfaces.items():
            self.assertEqual(key, interface.name)
            # An unexpected name is a failure rather than a silent skip.
            self.assertIn(key, expected)
            self.assertTrue(compare_avsets(expected[key], interface.access))

    def test_export(self):
        """An InterfaceSet written with to_file() is readable by from_file()."""
        i = self._load(interface_example)

        # NOTE: writes "output" in the current working directory, as before.
        # The with-blocks guarantee both handles are closed even on failure.
        with open("output", "w") as f:
            i.to_file(f)

        i2 = interfaces.InterfaceSet()
        with open("output") as f:
            i2.from_file(f)

        exported = {ifv.name for ifv in i2.interfaces.values()}
        for name in ("files_search_usr",
                     "files_list_usr",
                     "files_exec_usr_files"):
            self.assertIn(name, exported)
283