1#!/usr/bin/python2.4
2# Copyright (c) 2010 The Chromium Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5
6"""Tests exercising the various classes in xmppserver.py."""
7
8import unittest
9
10import base64
11import xmppserver
12
class XmlUtilsTest(unittest.TestCase):
  """Tests for the ParseXml() and CloneXml() helpers in xmppserver."""

  def testParseXml(self):
    # Parsing and re-serializing must round-trip the text unchanged,
    # including the empty xmlns attributes.
    xml_text = """<foo xmlns=""><bar xmlns=""><baz/></bar></foo>"""
    xml = xmppserver.ParseXml(xml_text)
    self.assertEqual(xml.toxml(), xml_text)

  def testCloneXml(self):
    # Mutating the clone must leave the original untouched.  minidom nodes
    # do not implement value equality, so comparing the node objects
    # themselves (as opposed to their serialized forms) can never fail;
    # check the serialized XML instead.
    xml = xmppserver.ParseXml('<foo/>')
    xml_clone = xmppserver.CloneXml(xml)
    self.assertEqual(xml_clone.toxml(), '<foo/>')
    xml_clone.setAttribute('bar', 'baz')
    self.assertEqual(xml.toxml(), '<foo/>')
    self.assertEqual(xml_clone.toxml(), '<foo bar="baz"/>')
    # The clone must also be a distinct object.
    self.assertNotEqual(xml, xml_clone)

  def testCloneXmlUnlink(self):
    # Unlinking the original must not detach or alter the clone.
    xml_text = '<foo/>'
    xml = xmppserver.ParseXml(xml_text)
    xml_clone = xmppserver.CloneXml(xml)
    xml.unlink()
    self.assertEqual(xml.parentNode, None)
    self.assertNotEqual(xml_clone.parentNode, None)
    self.assertEqual(xml_clone.toxml(), xml_text)
36
class StanzaParserTest(unittest.TestCase):
  """Feeds text fragments to xmppserver.StanzaParser and checks its output.

  The test case is the parser's delegate: every stanza the parser emits is
  delivered to FeedStanza() and recorded, serialized, in self.stanzas.
  """

  def setUp(self):
    self.stanzas = []

  def FeedStanza(self, stanza):
    # The stanza node is unlinked once this callback returns, so store its
    # serialized text rather than the node itself.
    serialized = stanza.toxml()
    self.stanzas.append(serialized)

  def testBasic(self):
    stanza_parser = xmppserver.StanzaParser(self)
    # An incomplete start tag must not produce anything yet.
    stanza_parser.FeedString('<foo')
    self.assertEqual(len(self.stanzas), 0)
    stanza_parser.FeedString('/><bar></bar>')
    for position, expected in enumerate(['<foo/>', '<bar/>']):
      self.assertEqual(self.stanzas[position], expected)

  def testStream(self):
    stanza_parser = xmppserver.StanzaParser(self)
    stanza_parser.FeedString('<stream')
    self.assertEqual(len(self.stanzas), 0)
    # The stream header is reported as soon as its start tag is complete.
    stanza_parser.FeedString(':stream foo="bar" xmlns:stream="baz">')
    expected_header = '<stream:stream foo="bar" xmlns:stream="baz"/>'
    self.assertEqual(self.stanzas[0], expected_header)

  def testNested(self):
    stanza_parser = xmppserver.StanzaParser(self)
    stanza_parser.FeedString('<foo')
    self.assertEqual(len(self.stanzas), 0)
    # The remainder arrives in two pieces; children stay inside <foo>.
    for fragment in (' bar="baz"', '><baz/><blah>meh</blah></foo>'):
      stanza_parser.FeedString(fragment)
    expected_stanza = '<foo bar="baz"><baz/><blah>meh</blah></foo>'
    self.assertEqual(self.stanzas[0], expected_stanza)
71
72
class JidTest(unittest.TestCase):
  """Checks the string forms produced by xmppserver.Jid."""

  def testBasic(self):
    # Without a resource the JID is just user@domain.
    self.assertEqual(str(xmppserver.Jid('foo', 'bar.com')), 'foo@bar.com')

  def testResource(self):
    # A resource is appended after a slash.
    full_jid = xmppserver.Jid('foo', 'bar.com', 'resource')
    self.assertEqual(str(full_jid), 'foo@bar.com/resource')

  def testGetBareJid(self):
    # GetBareJid() drops the resource part.
    bare_jid = xmppserver.Jid('foo', 'bar.com', 'resource').GetBareJid()
    self.assertEqual(str(bare_jid), 'foo@bar.com')
86
87
class IdGeneratorTest(unittest.TestCase):
  """Checks that IdGenerator hands out '<prefix>.<n>' ids in order."""

  def testBasic(self):
    generator = xmppserver.IdGenerator('foo')
    # The first 100 ids must be foo.0, foo.1, ..., foo.99 in sequence.
    counter = 0
    while counter < 100:
      self.assertEqual('foo.%d' % counter, generator.GetNextId())
      counter += 1
94
95
class HandshakeTaskTest(unittest.TestCase):
  """Drives xmppserver.HandshakeTask through a complete client handshake.

  The test case is the task's delegate: SendData() and SendStanza() count
  outgoing writes in self.data_received, and HandshakeDone() records the
  JID reported by the task in self.jid.
  """

  def setUp(self):
    self.data_received = 0
    # None until HandshakeDone() is called.
    self.jid = None

  def SendData(self, _):
    self.data_received += 1

  def SendStanza(self, _, unused=True):
    self.data_received += 1

  def HandshakeDone(self, jid):
    self.jid = jid

  def DoHandshake(self, resource_prefix, resource, username,
                  initial_stream_domain, auth_domain, auth_stream_domain):
    """Runs one full handshake and checks the JID it produces.

    Args:
      resource_prefix: Passed to HandshakeTask; expected as the prefix of the
        final JID's resource ('<resource_prefix>.<resource>').
      resource: Resource requested in the <bind> stanza.
      username: User name placed in the base64-encoded <auth> payload.
      initial_stream_domain: 'to' attribute of the first stream header.
      auth_domain: If non-empty, sent as 'username@auth_domain' in the
        <auth> payload; otherwise the bare user name is sent.
      auth_stream_domain: 'to' attribute of the stream header sent after
        authentication.

    The expected JID domain is the first non-empty value among
    auth_stream_domain, auth_domain and initial_stream_domain.
    """
    # Reset all per-handshake state.  This helper is called several times
    # from a single test, and a JID left over from an earlier call must not
    # be able to satisfy the checks at the end.
    self.data_received = 0
    self.jid = None
    handshake_task = xmppserver.HandshakeTask(self, resource_prefix)

    # First stream header: the task answers with two writes.
    stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>')
    stream_xml.setAttribute('to', initial_stream_domain)
    self.assertEqual(self.data_received, 0)
    handshake_task.FeedStanza(stream_xml)
    self.assertEqual(self.data_received, 2)

    # Authentication: payload is base64('\0<user>\0bar'); one reply.
    if auth_domain:
      username_domain = '%s@%s' % (username, auth_domain)
    else:
      username_domain = username
    auth_string = base64.b64encode('\0%s\0bar' % username_domain)
    auth_xml = xmppserver.ParseXml('<auth>%s</auth>' % auth_string)
    handshake_task.FeedStanza(auth_xml)
    self.assertEqual(self.data_received, 3)

    # Second stream header, sent after authentication: two more writes.
    stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>')
    stream_xml.setAttribute('to', auth_stream_domain)
    handshake_task.FeedStanza(stream_xml)
    self.assertEqual(self.data_received, 5)

    # Resource binding: one reply.
    bind_xml = xmppserver.ParseXml(
      '<iq type="set"><bind><resource>%s</resource></bind></iq>' % resource)
    handshake_task.FeedStanza(bind_xml)
    self.assertEqual(self.data_received, 6)

    # Session request: one reply.
    session_xml = xmppserver.ParseXml(
      '<iq type="set"><session></session></iq>')
    handshake_task.FeedStanza(session_xml)
    self.assertEqual(self.data_received, 7)

    # Fail clearly (not with an AttributeError) if the handshake never
    # completed.
    if self.jid is None:
      self.fail('HandshakeDone() was not called')
    self.assertEqual(self.jid.username, username)
    self.assertEqual(self.jid.domain,
                     auth_stream_domain or auth_domain or
                     initial_stream_domain)
    self.assertEqual(self.jid.resource,
                     '%s.%s' % (resource_prefix, resource))

  def testBasic(self):
    self.DoHandshake('resource_prefix', 'resource',
                     'foo', 'bar.com', 'baz.com', 'quux.com')

  def testDomainBehavior(self):
    # Each call blanks one more domain, so the JID domain falls back from
    # the post-auth stream domain, to the auth domain, to the initial
    # stream domain, and finally to ''.
    self.DoHandshake('resource_prefix', 'resource',
                     'foo', 'bar.com', 'baz.com', 'quux.com')
    self.DoHandshake('resource_prefix', 'resource',
                     'foo', 'bar.com', 'baz.com', '')
    self.DoHandshake('resource_prefix', 'resource',
                     'foo', 'bar.com', '', '')
    self.DoHandshake('resource_prefix', 'resource',
                     'foo', '', '', '')
165
166
class XmppConnectionTest(unittest.TestCase):
  """Tests xmppserver.XmppConnection against an in-memory fake socket.

  The test case plays two roles at once: it is the socket handed to the
  connection (recording every send() call in self.data), and it is the
  connection's delegate (tracking handshaken connections in
  self.connections).
  """

  def setUp(self):
    # Connections reported via OnXmppHandshakeDone() and not yet closed.
    self.connections = set()
    # One entry per send() call made on the fake socket.
    self.data = []

  # socket-like methods.
  def fileno(self):
    return 0

  def setblocking(self, int):
    # Blocking mode is irrelevant for the fake socket.  (The parameter name
    # shadows the builtin but is kept so the signature is unchanged.)
    pass

  def getpeername(self):
    return ('', 0)

  def send(self, data):
    """Records data and reports it as fully sent, like socket.send().

    Returns:
      len(data), the number of bytes "sent".  XmppConnection appears to be
      an asynchat channel (it exposes collect_incoming_data), and asynchat
      treats a falsy send() result as "nothing written".  It would then keep
      the data queued and resend it on the next write, so self.data would
      not reflect what was actually sent.
    """
    self.data.append(data)
    return len(data)

  def close(self):
    pass

  # XmppConnection delegate methods.
  def OnXmppHandshakeDone(self, xmpp_connection):
    self.connections.add(xmpp_connection)

  def OnXmppConnectionClosed(self, xmpp_connection):
    self.connections.discard(xmpp_connection)

  def ForwardNotification(self, unused_xmpp_connection, notification_stanza):
    # Fan the notification out to every handshaken connection.
    for connection in self.connections:
      connection.ForwardNotification(notification_stanza)

  def testBasic(self):
    socket_map = {}
    xmpp_connection = xmppserver.XmppConnection(
      self, socket_map, self, ('', 0))
    self.assertEqual(len(socket_map), 1)
    self.assertEqual(len(self.connections), 0)
    xmpp_connection.HandshakeDone(xmppserver.Jid('foo', 'bar'))
    self.assertEqual(len(socket_map), 1)
    self.assertEqual(len(self.connections), 1)

    # Test subscription request: exactly one write in response.
    self.assertEqual(len(self.data), 0)
    xmpp_connection.collect_incoming_data(
      '<iq><subscribe xmlns="google:push"></subscribe></iq>')
    self.assertEqual(len(self.data), 1)

    # Test acks: nothing is written back.
    xmpp_connection.collect_incoming_data('<iq type="result"/>')
    self.assertEqual(len(self.data), 1)

    # Test notification: presumably routed through ForwardNotification()
    # above back to this (only) connection, producing one more write.
    xmpp_connection.collect_incoming_data(
      '<message><push xmlns="google:push"/></message>')
    self.assertEqual(len(self.data), 2)

    # Test unexpected stanza.
    def SendUnexpectedStanza():
      xmpp_connection.collect_incoming_data('<foo/>')
    self.assertRaises(xmppserver.UnexpectedXml, SendUnexpectedStanza)

    # Test unexpected notifier command.
    def SendUnexpectedNotifierCommand():
      xmpp_connection.collect_incoming_data(
        '<iq><foo xmlns="google:notifier"/></iq>')
    self.assertRaises(xmppserver.UnexpectedXml,
                      SendUnexpectedNotifierCommand)

    # Test close: the connection leaves both the socket map and the
    # delegate's set of live connections.
    xmpp_connection.close()
    self.assertEqual(len(socket_map), 0)
    self.assertEqual(len(self.connections), 0)
242
class XmppServerTest(unittest.TestCase):
  """Tests socket-map bookkeeping in xmppserver.XmppServer.

  The test case doubles as the accepted client socket: accept() on the fake
  server defined in testBasic hands back this object.
  """

  # Minimal socket interface for the accepted client.
  def fileno(self):
    return 0

  def setblocking(self, int):
    pass

  def getpeername(self):
    return ('', 0)

  def close(self):
    pass

  def testBasic(self):
    client_socket = self

    class FakeXmppServer(xmppserver.XmppServer):
      # Skip the real accept() and return the test case as the client.
      def accept(self):
        return (client_socket, ('', 0))

    socket_map = {}
    self.assertEqual(len(socket_map), 0)
    # The listening server registers itself in the map...
    xmpp_server = FakeXmppServer(socket_map, ('', 0))
    self.assertEqual(len(socket_map), 1)
    # ...each accepted connection adds one more entry...
    xmpp_server.handle_accept()
    self.assertEqual(len(socket_map), 2)
    # ...and closing the server empties the map.
    xmpp_server.close()
    self.assertEqual(len(socket_map), 0)
271
272
if __name__ == '__main__':
  # Run every TestCase defined in this module.
  unittest.main(module='__main__')
275