efficient vim config
[dotfiles/.git] / .local / lib / python2.7 / site-packages / trollius / test_utils.py
1 """Utilities shared by tests."""
2
3 import collections
4 import contextlib
5 import io
6 import logging
7 import os
8 import re
9 import socket
10 import sys
11 import tempfile
12 import threading
13 import time
14
15 from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
16
17 import six
18
19 try:
20     import socketserver
21     from http.server import HTTPServer
22 except ImportError:
23     # Python 2
24     import SocketServer as socketserver
25     from BaseHTTPServer import HTTPServer
26
27 try:
28     from unittest import mock
29 except ImportError:
30     # Python < 3.3
31     import mock
32
33 try:
34     import ssl
35     from .py3_ssl import SSLContext, wrap_socket
36 except ImportError:  # pragma: no cover
37     # SSL support disabled in Python
38     ssl = None
39
40 from . import base_events
41 from . import compat
42 from . import events
43 from . import futures
44 from . import selectors
45 from . import tasks
46 from .coroutines import coroutine
47 from .log import logger
48
49
50 if sys.platform == 'win32':  # pragma: no cover
51     from .windows_utils import socketpair
52 else:
53     from socket import socketpair  # pragma: no cover
54
55 try:
56     # Prefer unittest2 if available (on Python 2)
57     import unittest2 as unittest
58 except ImportError:
59     import unittest
60
# Re-export the skip helpers of whichever unittest module was imported
# above (unittest2 when available, otherwise the stdlib unittest).
skipIf, skipUnless, SkipTest = (unittest.skipIf,
                                unittest.skipUnless,
                                unittest.SkipTest)
64
65
66 if not hasattr(unittest.TestCase, 'assertRaisesRegex'):
67     class _BaseTestCaseContext:
68
69         def __init__(self, test_case):
70             self.test_case = test_case
71
72         def _raiseFailure(self, standardMsg):
73             msg = self.test_case._formatMessage(self.msg, standardMsg)
74             raise self.test_case.failureException(msg)
75
76
77     class _AssertRaisesBaseContext(_BaseTestCaseContext):
78
79         def __init__(self, expected, test_case, callable_obj=None,
80                      expected_regex=None):
81             _BaseTestCaseContext.__init__(self, test_case)
82             self.expected = expected
83             self.test_case = test_case
84             if callable_obj is not None:
85                 try:
86                     self.obj_name = callable_obj.__name__
87                 except AttributeError:
88                     self.obj_name = str(callable_obj)
89             else:
90                 self.obj_name = None
91             if isinstance(expected_regex, (bytes, str)):
92                 expected_regex = re.compile(expected_regex)
93             self.expected_regex = expected_regex
94             self.msg = None
95
96         def handle(self, name, callable_obj, args, kwargs):
97             """
98             If callable_obj is None, assertRaises/Warns is being used as a
99             context manager, so check for a 'msg' kwarg and return self.
100             If callable_obj is not None, call it passing args and kwargs.
101             """
102             if callable_obj is None:
103                 self.msg = kwargs.pop('msg', None)
104                 return self
105             with self:
106                 callable_obj(*args, **kwargs)
107
108
109     class _AssertRaisesContext(_AssertRaisesBaseContext):
110         """A context manager used to implement TestCase.assertRaises* methods."""
111
112         def __enter__(self):
113             return self
114
115         def __exit__(self, exc_type, exc_value, tb):
116             if exc_type is None:
117                 try:
118                     exc_name = self.expected.__name__
119                 except AttributeError:
120                     exc_name = str(self.expected)
121                 if self.obj_name:
122                     self._raiseFailure("{0} not raised by {1}".format(exc_name,
123                                                                     self.obj_name))
124                 else:
125                     self._raiseFailure("{0} not raised".format(exc_name))
126             if not issubclass(exc_type, self.expected):
127                 # let unexpected exceptions pass through
128                 return False
129             self.exception = exc_value
130             if self.expected_regex is None:
131                 return True
132
133             expected_regex = self.expected_regex
134             if not expected_regex.search(str(exc_value)):
135                 self._raiseFailure('"{0}" does not match "{1}"'.format(
136                          expected_regex.pattern, str(exc_value)))
137             return True
138
139
def dummy_ssl_context():
    """Return an SSL context for tests, or None without SSL support.

    The context is created with ``ssl.PROTOCOL_SSLv23`` through the
    ``SSLContext`` class imported from ``py3_ssl``; when the ``ssl`` module
    could not be imported, ``ssl`` is None and so is the result.
    """
    if ssl is not None:
        return SSLContext(ssl.PROTOCOL_SSLv23)
    return None
145
146
def run_briefly(loop, steps=1):
    """Spin *loop* for *steps* short iterations.

    Each step wraps a no-op coroutine in a task and runs the loop until
    that task finishes; the coroutine object is always closed afterwards.
    """
    @coroutine
    def once():
        pass

    for _ in range(steps):
        coro = once()
        task = loop.create_task(coro)
        # Don't log a warning if the task is not done after run_until_complete().
        # It occurs if the loop is stopped or if a task raises a BaseException.
        task._log_destroy_pending = False
        try:
            loop.run_until_complete(task)
        finally:
            coro.close()
161
162
def run_until(loop, pred, timeout=30):
    """Run *loop* in short slices until ``pred()`` returns a true value.

    Between checks of *pred* the loop runs a ``tasks.sleep(0.001)``.

    :param loop: event loop to drive.
    :param pred: zero-argument callable checked before every slice.
    :param timeout: seconds to wait before giving up, or None to wait
        forever.
    :raises futures.TimeoutError: if *timeout* elapses before *pred*
        becomes true.
    """
    # Only compute a deadline when there is one: ``time.time() + None``
    # would raise TypeError before the loop even starts.
    deadline = None if timeout is None else time.time() + timeout
    while not pred():
        if deadline is not None and time.time() >= deadline:
            raise futures.TimeoutError()
        loop.run_until_complete(tasks.sleep(0.001, loop=loop))
171
172
def run_once(loop):
    """Run a single iteration of *loop*.

    loop.stop() schedules the stop callback (``_raise_stop_error()``) and
    run_forever() then runs until that callback executes. This won't work
    if the test waits for I/O events, because the stop callback runs before
    any of the I/O event callbacks.
    """
    loop.stop()
    loop.run_forever()
181
182
class SilentWSGIRequestHandler(WSGIRequestHandler):
    """WSGIRequestHandler that discards all diagnostic output."""

    def get_stderr(self):
        """Return a fresh in-memory stream used in place of sys.stderr.

        wsgiref's handler writes error tracebacks to this stream. On
        Python 2 those writes are byte strings, which io.StringIO rejects
        with TypeError, so a BytesIO is returned there instead.
        """
        # ``str is bytes`` only holds on Python 2.
        if str is bytes:
            return io.BytesIO()
        return io.StringIO()

    def log_message(self, format, *args):
        # Suppress the per-request log line.
        pass
190
191
class SilentWSGIServer(WSGIServer, object):
    """WSGIServer with a per-connection timeout that ignores handler errors."""

    # Timeout (seconds) applied with settimeout() to each accepted connection.
    request_timeout = 2

    def get_request(self):
        """Accept a connection and bound its blocking operations."""
        conn, addr = super(SilentWSGIServer, self).get_request()
        conn.settimeout(self.request_timeout)
        return conn, addr

    def handle_error(self, request, client_address):
        """Ignore errors raised while handling a request."""
        pass
203
204
class SSLWSGIServerMixin:
    """Server mixin that serves each accepted connection over TLS."""

    def finish_request(self, request, client_address):
        """Wrap *request* in a server-side TLS socket and handle it.

        The key/certificate are ``ssl_key.pem``/``ssl_cert.pem`` from the
        ``tests`` directory next to this package, or from the stdlib
        ``test/test_asyncio`` directory when the former does not exist.
        The TLS socket is always closed, even if the handler raises.
        """
        # The relative location of our test directory (which
        # contains the ssl key and certificate files) differs
        # between the stdlib and stand-alone asyncio.
        # Prefer our own if we can find it.
        here = os.path.join(os.path.dirname(__file__), '..', 'tests')
        if not os.path.isdir(here):
            here = os.path.join(os.path.dirname(os.__file__),
                                'test', 'test_asyncio')
        keyfile = os.path.join(here, 'ssl_key.pem')
        certfile = os.path.join(here, 'ssl_cert.pem')
        ssock = wrap_socket(request,
                            keyfile=keyfile,
                            certfile=certfile,
                            server_side=True)
        try:
            self.RequestHandlerClass(ssock, client_address, self)
        except (OSError, socket.error):
            # maybe socket has been closed by peer. On Python 2,
            # socket.error (and ssl.SSLError) is not an OSError subclass,
            # so it has to be listed explicitly.
            pass
        finally:
            ssock.close()
228
229
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
    """SilentWSGIServer whose connections are wrapped by SSLWSGIServerMixin."""
232
233
def _run_test_server(address, use_ssl, server_cls, server_ssl_cls):
    """Serve a fixed plain-text WSGI app on *address* from a thread.

    Generator that yields the running server once; ``httpd.address`` holds
    its bound address. When resumed or closed, the server is shut down,
    its socket closed and the serving thread joined.

    :param use_ssl: pick *server_ssl_cls* when true, else *server_cls*.
    """
    def app(environ, start_response):
        start_response('200 OK', [('Content-type', 'text/plain')])
        return [b'Test message']

    # Run the test WSGI server in a separate thread in order not to
    # interfere with event handling in the main thread
    if use_ssl:
        server_class = server_ssl_cls
    else:
        server_class = server_cls
    httpd = server_class(address, SilentWSGIRequestHandler)
    httpd.set_app(app)
    httpd.address = httpd.server_address
    server_thread = threading.Thread(target=httpd.serve_forever,
                                     kwargs={'poll_interval': 0.05})
    server_thread.start()
    try:
        yield httpd
    finally:
        httpd.shutdown()
        httpd.server_close()
        server_thread.join()
257
258
if hasattr(socket, 'AF_UNIX'):

    class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer, object):
        """HTTPServer listening on a UNIX-domain socket path."""

        def server_bind(self):
            # Bind the UNIX socket directly instead of HTTPServer's
            # server_bind(), which expects a (host, port) address, and fill
            # in the name/port attributes with fixed placeholders.
            socketserver.UnixStreamServer.server_bind(self)
            self.server_name = '127.0.0.1'
            self.server_port = 80


    class UnixWSGIServer(UnixHTTPServer, WSGIServer, object):
        """WSGI server on a UNIX socket with a per-connection timeout."""

        # Timeout (seconds) applied with settimeout() to each connection.
        request_timeout = 2

        def server_bind(self):
            UnixHTTPServer.server_bind(self)
            # Build the base WSGI environ for requests.
            self.setup_environ()

        def get_request(self):
            request, client_addr = super(UnixWSGIServer, self).get_request()
            request.settimeout(self.request_timeout)
            # Code in the stdlib expects that get_request
            # will return a socket and a tuple (host, port).
            # However, this isn't true for UNIX sockets,
            # as the second return value will be a path;
            # hence we return some fake data sufficient
            # to get the tests going
            return request, ('127.0.0.1', '')


    class SilentUnixWSGIServer(UnixWSGIServer):
        """UnixWSGIServer that ignores errors raised by request handlers."""

        def handle_error(self, request, client_address):
            pass


    class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
        """SilentUnixWSGIServer serving its connections over TLS."""


    def gen_unix_socket_path():
        """Return a fresh temporary path for a UNIX socket.

        The temporary file is deleted when the ``with`` block exits, so the
        path is free but not reserved for the caller.
        """
        with tempfile.NamedTemporaryFile() as tmp:
            return tmp.name


    @contextlib.contextmanager
    def unix_socket_path():
        """Yield a temporary socket path and unlink it afterwards.

        A missing file at cleanup time (OSError) is ignored.
        """
        path = gen_unix_socket_path()
        try:
            yield path
        finally:
            try:
                os.unlink(path)
            except OSError:
                pass


    @contextlib.contextmanager
    def run_test_unix_server(use_ssl=False):
        """Run the test WSGI server on a temporary UNIX socket path.

        Yields the running server; it is shut down before the socket path
        is removed.
        """
        with unix_socket_path() as path:
            server_gen = _run_test_server(address=path, use_ssl=use_ssl,
                                          server_cls=SilentUnixWSGIServer,
                                          server_ssl_cls=UnixSSLWSGIServer)
            # Close the inner generator explicitly so its cleanup (shutdown,
            # server_close, thread join) runs promptly even when the
            # with-body raises, instead of whenever it is garbage-collected.
            with contextlib.closing(server_gen):
                for item in server_gen:
                    yield item
323
324
@contextlib.contextmanager
def run_test_server(host='127.0.0.1', port=0, use_ssl=False):
    """Run the test WSGI server on the TCP address (*host*, *port*).

    Yields the running server (port 0 lets the OS choose the port); the
    server is shut down when the with-block exits.
    """
    server_gen = _run_test_server(address=(host, port), use_ssl=use_ssl,
                                  server_cls=SilentWSGIServer,
                                  server_ssl_cls=SSLWSGIServer)
    # Close the inner generator explicitly so its cleanup (shutdown,
    # server_close, thread join) runs promptly even when the with-body
    # raises, instead of whenever it is garbage-collected.
    with contextlib.closing(server_gen):
        for item in server_gen:
            yield item
331
332
def make_test_protocol(base):
    """Return an instance of a *base* subclass whose attributes are mocks.

    Every name in ``dir(base)`` except dunder names is overridden with
    ``MockCallback(return_value=None)``. The generated class lists *base*
    followed by *base*'s own direct bases.
    """
    def _is_magic(attr):
        return attr.startswith('__') and attr.endswith('__')

    overrides = {attr: MockCallback(return_value=None)
                 for attr in dir(base) if not _is_magic(attr)}
    bases = (base,) + base.__bases__
    return type('TestProtocol', bases, overrides)()
341
342
class TestSelector(selectors.BaseSelector):
    """In-memory selector that records registrations and reports no events."""

    def __init__(self):
        # Maps each registered file object to its SelectorKey.
        self.keys = {}

    def register(self, fileobj, events, data=None):
        # No real descriptor is looked up: the fd field is always 0.
        new_key = selectors.SelectorKey(fileobj, 0, events, data)
        self.keys[fileobj] = new_key
        return new_key

    def unregister(self, fileobj):
        # Raises KeyError for a file object that was never registered.
        removed = self.keys.pop(fileobj)
        return removed

    def select(self, timeout):
        # Nothing is ever ready.
        return []

    def get_map(self):
        return self.keys
361
362
class TestLoop(base_events.BaseEventLoop):
    """Event loop for unit tests that manages its own clock.

    Time lives in ``self._time`` and only moves when told to. Every
    deadline passed to call_at() is remembered; after the next loop
    iteration each of them is sent to the time generator given to
    __init__, and the value it yields back is added to the loop time.

    A time generator looks like this::

        def gen():
            ...
            when = yield ...
            ... = yield time_advance

    The value received from ``yield`` is the absolute time of a scheduled
    handler; the value yielded is how far to move the loop's time forward.
    """

    def __init__(self, gen=None):
        super(TestLoop, self).__init__()

        # Only a caller-supplied generator must be exhausted by close().
        self._check_on_close = gen is not None
        if gen is None:
            def gen():
                yield

        self._gen = gen()
        next(self._gen)  # advance the generator to its first yield
        self._time = 0
        self._clock_resolution = 1e-9
        self._timers = []
        self._selector = TestSelector()

        self.readers = {}
        self.writers = {}
        self.reset_counters()

    def time(self):
        """Return the simulated loop time."""
        return self._time

    def advance_time(self, advance):
        """Move test time forward by *advance* (ignored when falsy)."""
        if advance:
            self._time += advance

    def close(self):
        """Close the loop, then verify the time generator has finished."""
        super(TestLoop, self).close()
        if not self._check_on_close:
            return
        try:
            self._gen.send(0)
        except StopIteration:
            return
        raise AssertionError("Time generator is not finished")  # pragma: no cover

    @staticmethod
    def _forget_fd(registry, fd):
        """Remove *fd* from *registry*; return True if it was present."""
        registered = fd in registry
        if registered:
            del registry[fd]
        return registered

    @staticmethod
    def _assert_fd_handle(registry, fd, callback, args):
        """Assert *fd* is in *registry* with the given callback and args."""
        assert fd in registry, 'fd {0} is not registered'.format(fd)
        handle = registry[fd]
        assert handle._callback == callback, '{0!r} != {1!r}'.format(
            handle._callback, callback)
        assert handle._args == args, '{0!r} != {1!r}'.format(
            handle._args, args)

    def add_reader(self, fd, callback, *args):
        self.readers[fd] = events.Handle(callback, args, self)

    def remove_reader(self, fd):
        self.remove_reader_count[fd] += 1
        return self._forget_fd(self.readers, fd)

    def assert_reader(self, fd, callback, *args):
        self._assert_fd_handle(self.readers, fd, callback, args)

    def add_writer(self, fd, callback, *args):
        self.writers[fd] = events.Handle(callback, args, self)

    def remove_writer(self, fd):
        self.remove_writer_count[fd] += 1
        return self._forget_fd(self.writers, fd)

    def assert_writer(self, fd, callback, *args):
        self._assert_fd_handle(self.writers, fd, callback, args)

    def reset_counters(self):
        """Reset the per-fd remove_reader()/remove_writer() call counts."""
        self.remove_reader_count = collections.defaultdict(int)
        self.remove_writer_count = collections.defaultdict(int)

    def _run_once(self):
        super(TestLoop, self)._run_once()
        # Feed every deadline recorded since the last iteration to the
        # time generator and apply the advance it answers with.
        for when in self._timers:
            self.advance_time(self._gen.send(when))
        self._timers = []

    def call_at(self, when, callback, *args):
        # Record the deadline so _run_once() can report it to the generator.
        self._timers.append(when)
        return super(TestLoop, self).call_at(when, callback, *args)

    def _process_events(self, event_list):
        # No I/O events are processed by the test loop.
        return None

    def _write_to_self(self):
        # Intentionally a no-op in the test loop.
        pass
479
480
def MockCallback(**kwargs):
    """Return a Mock restricted to being called.

    The spec only allows ``__call__``, so accessing any other attribute
    raises AttributeError. *kwargs* are forwarded to ``mock.Mock``
    (e.g. ``return_value``).
    """
    callable_only = ['__call__']
    return mock.Mock(spec=callable_only, **kwargs)
483
484
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed. The pattern is
    searched with re.S, so ``.`` also matches newlines.

    For instance:
       mock_call.assert_called_with(MockPattern('spam.*ham'))
    """
    def __eq__(self, other):
        try:
            return bool(re.search(str(self), other, re.S))
        except TypeError:
            # Non-string operand (None, ints, bytes on Python 3...): let
            # Python fall back to its default comparison instead of
            # blowing up inside e.g. mock's call comparison.
            return NotImplemented

    def __ne__(self, other):
        # Without this, ``!=`` would use the inherited str.__ne__ and do a
        # literal comparison that disagrees with the regex-based __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result
496
497
def get_function_source(func):
    """Return the source location of *func* from events._get_function_source.

    :raises ValueError: when no source location can be determined.
    """
    location = events._get_function_source(func)
    if location is not None:
        return location
    raise ValueError("unable to get the source of %r" % (func,))
503
504
class TestCase(unittest.TestCase):
    """Base test case with event-loop helpers."""

    def set_event_loop(self, loop, cleanup=True):
        """Register *loop* for this test without making it the global loop.

        The global event loop is reset to None so that code under test has
        to receive the loop explicitly. When *cleanup* is true,
        ``loop.close`` is registered with addCleanup().
        """
        assert loop is not None
        # ensure that the event loop is passed explicitly in asyncio
        events.set_event_loop(None)
        if cleanup:
            self.addCleanup(loop.close)

    def new_test_loop(self, gen=None):
        """Create a TestLoop driven by *gen* and register it for cleanup."""
        loop = TestLoop(gen)
        self.set_event_loop(loop)
        return loop

    if not hasattr(unittest.TestCase, 'assertRaisesRegex'):
        # Same guard as the module-level _AssertRaisesContext fallback.
        def assertRaisesRegex(self, expected_exception, expected_regex,
                              callable_obj=None, *args, **kwargs):
            """Assert that an exception whose message matches a regex is raised.

            Fallback for unittest versions without assertRaisesRegex. When
            *callable_obj* is None it returns a context manager (an optional
            ``msg`` kwarg customises the failure message); otherwise it calls
            ``callable_obj(*args, **kwargs)`` and checks the exception.
            """
            context = _AssertRaisesContext(expected_exception, self,
                                           callable_obj, expected_regex)
            return context.handle('assertRaisesRegex', callable_obj,
                                  args, kwargs)

    def tearDown(self):
        events.set_event_loop(None)

        # An assertion that sys.exc_info() is empty here (detecting CPython
        # bug #23353: yield/yield-from used in an except block of a
        # generator) is disabled in this copy. On Python 2 a SkipTest still
        # recorded in sys.exc_info() is cleared.
        if sys.exc_info()[0] == SkipTest and six.PY2:
            sys.exc_clear()

    def check_soure_traceback(self, source_traceback, lineno_delta):
        """Assert the last *source_traceback* entry points at the caller.

        The first three fields of ``source_traceback[-1]`` must equal the
        calling frame's (filename, line number + *lineno_delta*, function
        name).
        """
        frame = sys._getframe(1)
        filename = frame.f_code.co_filename
        lineno = frame.f_lineno + lineno_delta
        name = frame.f_code.co_name
        self.assertIsInstance(source_traceback, list)
        self.assertEqual(source_traceback[-1][:3],
                         (filename,
                          lineno,
                          name))

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    check_source_traceback = check_soure_traceback
539
540
@contextlib.contextmanager
def disable_logger():
    """Silence the asyncio logger for the duration of a with-block.

    The level is raised above CRITICAL and restored afterwards, even if the
    block raises; useful e.g. to ignore warnings in debug mode.
    """
    silenced_level = logging.CRITICAL + 1
    saved_level = logger.level
    try:
        logger.setLevel(silenced_level)
        yield
    finally:
        logger.setLevel(saved_level)
553
def mock_nonblocking_socket():
    """Return a Mock specced on socket.socket whose gettimeout() is 0.0."""
    return mock.Mock(spec=socket.socket,
                     **{'gettimeout.return_value': 0.0})
559
560
def force_legacy_ssl_support():
    """Return a patcher forcing the legacy SSL code path.

    While active, ``trollius.sslproto._is_sslproto_available`` is replaced
    by a mock returning False.
    """
    target = 'trollius.sslproto._is_sslproto_available'
    return mock.patch(target, return_value=False)