1 """Utilities shared by tests."""
15 from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
21 from http.server import HTTPServer
24 import SocketServer as socketserver
25 from BaseHTTPServer import HTTPServer
28 from unittest import mock
35 from .py3_ssl import SSLContext, wrap_socket
36 except ImportError: # pragma: no cover
37 # SSL support disabled in Python
40 from . import base_events
44 from . import selectors
46 from .coroutines import coroutine
47 from .log import logger
50 if sys.platform == 'win32': # pragma: no cover
51 from .windows_utils import socketpair
53 from socket import socketpair # pragma: no cover
56 # Prefer unittest2 if available (on Python 2)
57 import unittest2 as unittest
61 skipIf = unittest.skipIf
62 skipUnless = unittest.skipUnless
63 SkipTest = unittest.SkipTest
66 if not hasattr(unittest.TestCase, 'assertRaisesRegex'):
67 class _BaseTestCaseContext:
69 def __init__(self, test_case):
70 self.test_case = test_case
72 def _raiseFailure(self, standardMsg):
73 msg = self.test_case._formatMessage(self.msg, standardMsg)
74 raise self.test_case.failureException(msg)
77 class _AssertRaisesBaseContext(_BaseTestCaseContext):
79 def __init__(self, expected, test_case, callable_obj=None,
81 _BaseTestCaseContext.__init__(self, test_case)
82 self.expected = expected
83 self.test_case = test_case
84 if callable_obj is not None:
86 self.obj_name = callable_obj.__name__
87 except AttributeError:
88 self.obj_name = str(callable_obj)
91 if isinstance(expected_regex, (bytes, str)):
92 expected_regex = re.compile(expected_regex)
93 self.expected_regex = expected_regex
96 def handle(self, name, callable_obj, args, kwargs):
98 If callable_obj is None, assertRaises/Warns is being used as a
99 context manager, so check for a 'msg' kwarg and return self.
100 If callable_obj is not None, call it passing args and kwargs.
102 if callable_obj is None:
103 self.msg = kwargs.pop('msg', None)
106 callable_obj(*args, **kwargs)
109 class _AssertRaisesContext(_AssertRaisesBaseContext):
110 """A context manager used to implement TestCase.assertRaises* methods."""
115 def __exit__(self, exc_type, exc_value, tb):
118 exc_name = self.expected.__name__
119 except AttributeError:
120 exc_name = str(self.expected)
122 self._raiseFailure("{0} not raised by {1}".format(exc_name,
125 self._raiseFailure("{0} not raised".format(exc_name))
126 if not issubclass(exc_type, self.expected):
127 # let unexpected exceptions pass through
129 self.exception = exc_value
130 if self.expected_regex is None:
133 expected_regex = self.expected_regex
134 if not expected_regex.search(str(exc_value)):
135 self._raiseFailure('"{0}" does not match "{1}"'.format(
136 expected_regex.pattern, str(exc_value)))
140 def dummy_ssl_context():
144 return SSLContext(ssl.PROTOCOL_SSLv23)
147 def run_briefly(loop, steps=1):
151 for step in range(steps):
153 t = loop.create_task(gen)
154 # Don't log a warning if the task is not done after run_until_complete().
155 # It occurs if the loop is stopped or if a task raises a BaseException.
156 t._log_destroy_pending = False
158 loop.run_until_complete(t)
163 def run_until(loop, pred, timeout=30):
164 deadline = time.time() + timeout
166 if timeout is not None:
167 timeout = deadline - time.time()
169 raise futures.TimeoutError()
170 loop.run_until_complete(tasks.sleep(0.001, loop=loop))
174 """loop.stop() schedules _raise_stop_error()
and run_forever() runs until the _raise_stop_error() callback is called.
This won't work if the test waits for some IO events, because
_raise_stop_error() runs before any of the IO event callbacks.
183 class SilentWSGIRequestHandler(WSGIRequestHandler):
185 def get_stderr(self):
188 def log_message(self, format, *args):
192 class SilentWSGIServer(WSGIServer, object):
def get_request(self):
    """Accept a connection and apply the server's request timeout to it.

    Returns the ``(request, client_addr)`` pair obtained from the parent
    class, after ``settimeout(self.request_timeout)`` has been called on
    the request object.
    """
    accepted = super(SilentWSGIServer, self).get_request()
    conn, peer = accepted
    conn.settimeout(self.request_timeout)
    return conn, peer
201 def handle_error(self, request, client_address):
205 class SSLWSGIServerMixin:
207 def finish_request(self, request, client_address):
208 # The relative location of our test directory (which
209 # contains the ssl key and certificate files) differs
210 # between the stdlib and stand-alone asyncio.
211 # Prefer our own if we can find it.
212 here = os.path.join(os.path.dirname(__file__), '..', 'tests')
213 if not os.path.isdir(here):
214 here = os.path.join(os.path.dirname(os.__file__),
215 'test', 'test_asyncio')
216 keyfile = os.path.join(here, 'ssl_key.pem')
217 certfile = os.path.join(here, 'ssl_cert.pem')
218 ssock = wrap_socket(request,
223 self.RequestHandlerClass(ssock, client_address, self)
226 # maybe socket has been closed by peer
230 class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
234 def _run_test_server(address, use_ssl, server_cls, server_ssl_cls):
236 def app(environ, start_response):
238 headers = [('Content-type', 'text/plain')]
239 start_response(status, headers)
240 return [b'Test message']
242 # Run the test WSGI server in a separate thread in order not to
243 # interfere with event handling in the main thread
244 server_class = server_ssl_cls if use_ssl else server_cls
245 httpd = server_class(address, SilentWSGIRequestHandler)
247 httpd.address = httpd.server_address
248 server_thread = threading.Thread(
249 target=lambda: httpd.serve_forever(poll_interval=0.05))
250 server_thread.start()
259 if hasattr(socket, 'AF_UNIX'):
class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer, object):
    """HTTP server whose listening socket is a UNIX stream socket."""

    def server_bind(self):
        """Bind the socket, then record a fixed host name and port.

        Binding is done by calling ``UnixStreamServer.server_bind`` directly
        rather than through ``super()``; afterwards ``server_name`` and
        ``server_port`` are set to constant placeholder values.
        """
        socketserver.UnixStreamServer.server_bind(self)
        self.server_name, self.server_port = '127.0.0.1', 80
269 class UnixWSGIServer(UnixHTTPServer, WSGIServer, object):
273 def server_bind(self):
274 UnixHTTPServer.server_bind(self)
def get_request(self):
    """Accept a connection and return it with a placeholder client address.

    The request object gets ``settimeout(self.request_timeout)`` applied.
    The peer address reported by the parent class is discarded and a fixed
    ``('127.0.0.1', '')`` tuple is returned in its place.
    """
    conn, _unix_peer = super(UnixWSGIServer, self).get_request()
    conn.settimeout(self.request_timeout)
    # Code in the stdlib expects get_request() to return a socket and a
    # (host, port) tuple.  That does not hold for UNIX sockets, where the
    # second value is a path, so fake data that is sufficient to get the
    # tests going is returned instead.
    fake_client_addr = ('127.0.0.1', '')
    return conn, fake_client_addr
289 class SilentUnixWSGIServer(UnixWSGIServer):
291 def handle_error(self, request, client_address):
295 class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
299 def gen_unix_socket_path():
300 with tempfile.NamedTemporaryFile() as file:
304 @contextlib.contextmanager
305 def unix_socket_path():
306 path = gen_unix_socket_path()
316 @contextlib.contextmanager
317 def run_test_unix_server(use_ssl=False):
318 with unix_socket_path() as path:
319 for item in _run_test_server(address=path, use_ssl=use_ssl,
320 server_cls=SilentUnixWSGIServer,
321 server_ssl_cls=UnixSSLWSGIServer):
325 @contextlib.contextmanager
326 def run_test_server(host='127.0.0.1', port=0, use_ssl=False):
327 for item in _run_test_server(address=(host, port), use_ssl=use_ssl,
328 server_cls=SilentWSGIServer,
329 server_ssl_cls=SSLWSGIServer):
333 def make_test_protocol(base):
335 for name in dir(base):
336 if name.startswith('__') and name.endswith('__'):
339 dct[name] = MockCallback(return_value=None)
340 return type('TestProtocol', (base,) + base.__bases__, dct)()
343 class TestSelector(selectors.BaseSelector):
348 def register(self, fileobj, events, data=None):
349 key = selectors.SelectorKey(fileobj, 0, events, data)
350 self.keys[fileobj] = key
def unregister(self, fileobj):
    """Remove the entry stored for ``fileobj`` in ``self.keys`` and return it.

    Any error raised by ``self.keys.pop`` for an unknown ``fileobj``
    propagates to the caller (``KeyError`` if ``keys`` is a plain dict).
    """
    removed_key = self.keys.pop(fileobj)
    return removed_key
356 def select(self, timeout):
363 class TestLoop(base_events.BaseEventLoop):
364 """Loop for unittests.
366 It manages self time directly.
If something is scheduled to be executed later, then on the
next loop iteration, after all ready handlers are done, the
generator passed to __init__ is called.
371 Generator should be like this:
376 ... = yield time_advance
378 Value returned by yield is absolute time of next scheduled handler.
379 Value passed to yield is time advance to move loop's time forward.
382 def __init__(self, gen=None):
383 super(TestLoop, self).__init__()
388 self._check_on_close = False
390 self._check_on_close = True
395 self._clock_resolution = 1e-9
397 self._selector = TestSelector()
401 self.reset_counters()
406 def advance_time(self, advance):
407 """Move test time forward."""
409 self._time += advance
412 super(TestLoop, self).close()
413 if self._check_on_close:
416 except StopIteration:
418 else: # pragma: no cover
419 raise AssertionError("Time generator is not finished")
def add_reader(self, fd, callback, *args):
    """Store a handle for ``callback`` and ``args`` under ``fd`` in ``self.readers``.

    The callback and its positional arguments are wrapped in an
    ``events.Handle`` bound to this loop.
    """
    reader_handle = events.Handle(callback, args, self)
    self.readers[fd] = reader_handle
424 def remove_reader(self, fd):
425 self.remove_reader_count[fd] += 1
426 if fd in self.readers:
432 def assert_reader(self, fd, callback, *args):
433 assert fd in self.readers, 'fd {0} is not registered'.format(fd)
434 handle = self.readers[fd]
435 assert handle._callback == callback, '{0!r} != {1!r}'.format(
436 handle._callback, callback)
437 assert handle._args == args, '{0!r} != {1!r}'.format(
def add_writer(self, fd, callback, *args):
    """Store a handle for ``callback`` and ``args`` under ``fd`` in ``self.writers``.

    The callback and its positional arguments are wrapped in an
    ``events.Handle`` bound to this loop.
    """
    writer_handle = events.Handle(callback, args, self)
    self.writers[fd] = writer_handle
443 def remove_writer(self, fd):
444 self.remove_writer_count[fd] += 1
445 if fd in self.writers:
451 def assert_writer(self, fd, callback, *args):
452 assert fd in self.writers, 'fd {0} is not registered'.format(fd)
453 handle = self.writers[fd]
454 assert handle._callback == callback, '{0!r} != {1!r}'.format(
455 handle._callback, callback)
456 assert handle._args == args, '{0!r} != {1!r}'.format(
def reset_counters(self):
    """Replace both removal counters with fresh, empty ``defaultdict(int)`` maps.

    After the call ``remove_reader_count`` and ``remove_writer_count`` are
    two distinct mappings in which every key starts at zero.
    """
    for counter_name in ('remove_reader_count', 'remove_writer_count'):
        setattr(self, counter_name, collections.defaultdict(int))
464 super(TestLoop, self)._run_once()
465 for when in self._timers:
466 advance = self._gen.send(when)
467 self.advance_time(advance)
def call_at(self, when, callback, *args):
    """Record ``when`` in ``self._timers`` and schedule via the parent loop.

    Returns whatever the parent class's ``call_at`` returns for the same
    ``when``, ``callback`` and ``args``.
    """
    self._timers.append(when)
    scheduled = super(TestLoop, self).call_at(when, callback, *args)
    return scheduled
474 def _process_events(self, event_list):
477 def _write_to_self(self):
def MockCallback(**kwargs):
    """Build a ``mock.Mock`` whose spec is limited to ``__call__``.

    All keyword arguments are forwarded unchanged to ``mock.Mock``.
    """
    callable_only = ['__call__']
    return mock.Mock(spec=callable_only, **kwargs)
class MockPattern(str):
    """A regex based str with a fuzzy __eq__.

    Use this helper with 'mock.assert_called_with', or anywhere
    where a regex comparison between strings is needed.

    For instance:
       mock_call.assert_called_with(MockPattern('spam.*ham'))
    """

    def __eq__(self, other):
        """Return True if this pattern matches somewhere in ``other``.

        The pattern is searched with ``re.S`` so ``.`` also matches
        newlines.  Operands the pattern cannot be searched in (``None``,
        numbers, bytes versus text, ...) compare unequal instead of raising
        ``TypeError``, so a mismatching call argument shows up as a normal
        assertion failure.
        """
        try:
            return bool(re.search(str(self), other, re.S))
        except TypeError:
            # ``other`` is not a searchable string: treat it as "no match".
            return False

    def __ne__(self, other):
        """Return the negation of ``__eq__``.

        Defined explicitly so that ``!=`` agrees with the fuzzy ``==`` on
        Python 2 as well, where ``__ne__`` is not derived from ``__eq__``.
        """
        return not self.__eq__(other)
498 def get_function_source(func):
499 source = events._get_function_source(func)
501 raise ValueError("unable to get the source of %r" % (func,))
505 class TestCase(unittest.TestCase):
506 def set_event_loop(self, loop, cleanup=True):
507 assert loop is not None
508 # ensure that the event loop is passed explicitly in asyncio
509 events.set_event_loop(None)
511 self.addCleanup(loop.close)
513 def new_test_loop(self, gen=None):
515 self.set_event_loop(loop)
519 events.set_event_loop(None)
521 # Detect CPython bug #23353: ensure that yield/yield-from is not used
522 # in an except block of a generator
523 if sys.exc_info()[0] == SkipTest:
527 pass #self.assertEqual(sys.exc_info(), (None, None, None))
529 def check_soure_traceback(self, source_traceback, lineno_delta):
530 frame = sys._getframe(1)
531 filename = frame.f_code.co_filename
532 lineno = frame.f_lineno + lineno_delta
533 name = frame.f_code.co_name
534 self.assertIsInstance(source_traceback, list)
535 self.assertEqual(source_traceback[-1][:3],
541 @contextlib.contextmanager
542 def disable_logger():
543 """Context manager to disable asyncio logger.
545 For example, it can be used to ignore warnings in debug mode.
547 old_level = logger.level
549 logger.setLevel(logging.CRITICAL+1)
552 logger.setLevel(old_level)
554 def mock_nonblocking_socket():
555 """Create a mock of a non-blocking socket."""
556 sock = mock.Mock(socket.socket)
557 sock.gettimeout.return_value = 0.0
561 def force_legacy_ssl_support():
562 return mock.patch('trollius.sslproto._is_sslproto_available',