+++ /dev/null
-"""Stream-related things."""
-
-__all__ = ['StreamReader', 'StreamWriter', 'StreamReaderProtocol',
- 'open_connection', 'start_server',
- 'IncompleteReadError',
- ]
-
-import socket
-
-if hasattr(socket, 'AF_UNIX'):
- __all__.extend(['open_unix_connection', 'start_unix_server'])
-
-from . import coroutines
-from . import compat
-from . import events
-from . import futures
-from . import protocols
-from .coroutines import coroutine, From, Return
-from .py33_exceptions import ConnectionResetError
-from .log import logger
-
-
-_DEFAULT_LIMIT = 2**16
-
-
-class IncompleteReadError(EOFError):
- """
- Incomplete read error. Attributes:
-
- - partial: read bytes string before the end of stream was reached
- - expected: total number of expected bytes
- """
- def __init__(self, partial, expected):
- EOFError.__init__(self, "%s bytes read on a total of %s expected bytes"
- % (len(partial), expected))
- self.partial = partial
- self.expected = expected
-
-
-@coroutine
-def open_connection(host=None, port=None,
- loop=None, limit=_DEFAULT_LIMIT, **kwds):
- """A wrapper for create_connection() returning a (reader, writer) pair.
-
- The reader returned is a StreamReader instance; the writer is a
- StreamWriter instance.
-
- The arguments are all the usual arguments to create_connection()
- except protocol_factory; most common are positional host and port,
- with various optional keyword arguments following.
-
- Additional optional keyword arguments are loop (to set the event loop
- instance to use) and limit (to set the buffer limit passed to the
- StreamReader).
-
- (If you want to customize the StreamReader and/or
- StreamReaderProtocol classes, just copy the code -- there's
- really nothing special here except some convenience.)
- """
- if loop is None:
- loop = events.get_event_loop()
- reader = StreamReader(limit=limit, loop=loop)
- protocol = StreamReaderProtocol(reader, loop=loop)
- transport, _ = yield From(loop.create_connection(
- lambda: protocol, host, port, **kwds))
- writer = StreamWriter(transport, protocol, reader, loop)
- raise Return(reader, writer)
-
-
-@coroutine
-def start_server(client_connected_cb, host=None, port=None,
- loop=None, limit=_DEFAULT_LIMIT, **kwds):
- """Start a socket server, call back for each client connected.
-
- The first parameter, `client_connected_cb`, takes two parameters:
- client_reader, client_writer. client_reader is a StreamReader
- object, while client_writer is a StreamWriter object. This
- parameter can either be a plain callback function or a coroutine;
- if it is a coroutine, it will be automatically converted into a
- Task.
-
- The rest of the arguments are all the usual arguments to
- loop.create_server() except protocol_factory; most common are
- positional host and port, with various optional keyword arguments
- following. The return value is the same as loop.create_server().
-
- Additional optional keyword arguments are loop (to set the event loop
- instance to use) and limit (to set the buffer limit passed to the
- StreamReader).
-
- The return value is the same as loop.create_server(), i.e. a
- Server object which can be used to stop the service.
- """
- if loop is None:
- loop = events.get_event_loop()
-
- def factory():
- reader = StreamReader(limit=limit, loop=loop)
- protocol = StreamReaderProtocol(reader, client_connected_cb,
- loop=loop)
- return protocol
-
- server = yield From(loop.create_server(factory, host, port, **kwds))
- raise Return(server)
-
-
-if hasattr(socket, 'AF_UNIX'):
- # UNIX Domain Sockets are supported on this platform
-
- @coroutine
- def open_unix_connection(path=None,
- loop=None, limit=_DEFAULT_LIMIT, **kwds):
- """Similar to `open_connection` but works with UNIX Domain Sockets."""
- if loop is None:
- loop = events.get_event_loop()
- reader = StreamReader(limit=limit, loop=loop)
- protocol = StreamReaderProtocol(reader, loop=loop)
- transport, _ = yield From(loop.create_unix_connection(
- lambda: protocol, path, **kwds))
- writer = StreamWriter(transport, protocol, reader, loop)
- raise Return(reader, writer)
-
-
- @coroutine
- def start_unix_server(client_connected_cb, path=None,
- loop=None, limit=_DEFAULT_LIMIT, **kwds):
- """Similar to `start_server` but works with UNIX Domain Sockets."""
- if loop is None:
- loop = events.get_event_loop()
-
- def factory():
- reader = StreamReader(limit=limit, loop=loop)
- protocol = StreamReaderProtocol(reader, client_connected_cb,
- loop=loop)
- return protocol
-
- server = (yield From(loop.create_unix_server(factory, path, **kwds)))
- raise Return(server)
-
-
-class FlowControlMixin(protocols.Protocol):
- """Reusable flow control logic for StreamWriter.drain().
-
- This implements the protocol methods pause_writing(),
- resume_reading() and connection_lost(). If the subclass overrides
- these it must call the super methods.
-
- StreamWriter.drain() must wait for _drain_helper() coroutine.
- """
-
- def __init__(self, loop=None):
- if loop is None:
- self._loop = events.get_event_loop()
- else:
- self._loop = loop
- self._paused = False
- self._drain_waiter = None
- self._connection_lost = False
-
- def pause_writing(self):
- assert not self._paused
- self._paused = True
- if self._loop.get_debug():
- logger.debug("%r pauses writing", self)
-
- def resume_writing(self):
- assert self._paused
- self._paused = False
- if self._loop.get_debug():
- logger.debug("%r resumes writing", self)
-
- waiter = self._drain_waiter
- if waiter is not None:
- self._drain_waiter = None
- if not waiter.done():
- waiter.set_result(None)
-
- def connection_lost(self, exc):
- self._connection_lost = True
- # Wake up the writer if currently paused.
- if not self._paused:
- return
- waiter = self._drain_waiter
- if waiter is None:
- return
- self._drain_waiter = None
- if waiter.done():
- return
- if exc is None:
- waiter.set_result(None)
- else:
- waiter.set_exception(exc)
-
- @coroutine
- def _drain_helper(self):
- if self._connection_lost:
- raise ConnectionResetError('Connection lost')
- if not self._paused:
- return
- waiter = self._drain_waiter
- assert waiter is None or waiter.cancelled()
- waiter = futures.Future(loop=self._loop)
- self._drain_waiter = waiter
- yield From(waiter)
-
-
-class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
- """Helper class to adapt between Protocol and StreamReader.
-
- (This is a helper class instead of making StreamReader itself a
- Protocol subclass, because the StreamReader has other potential
- uses, and to prevent the user of the StreamReader to accidentally
- call inappropriate methods of the protocol.)
- """
-
- def __init__(self, stream_reader, client_connected_cb=None, loop=None):
- super(StreamReaderProtocol, self).__init__(loop=loop)
- self._stream_reader = stream_reader
- self._stream_writer = None
- self._client_connected_cb = client_connected_cb
-
- def connection_made(self, transport):
- self._stream_reader.set_transport(transport)
- if self._client_connected_cb is not None:
- self._stream_writer = StreamWriter(transport, self,
- self._stream_reader,
- self._loop)
- res = self._client_connected_cb(self._stream_reader,
- self._stream_writer)
- if coroutines.iscoroutine(res):
- self._loop.create_task(res)
-
- def connection_lost(self, exc):
- if exc is None:
- self._stream_reader.feed_eof()
- else:
- self._stream_reader.set_exception(exc)
- super(StreamReaderProtocol, self).connection_lost(exc)
-
- def data_received(self, data):
- self._stream_reader.feed_data(data)
-
- def eof_received(self):
- self._stream_reader.feed_eof()
- return True
-
-
-class StreamWriter(object):
- """Wraps a Transport.
-
- This exposes write(), writelines(), [can_]write_eof(),
- get_extra_info() and close(). It adds drain() which returns an
- optional Future on which you can wait for flow control. It also
- adds a transport property which references the Transport
- directly.
- """
-
- def __init__(self, transport, protocol, reader, loop):
- self._transport = transport
- self._protocol = protocol
- # drain() expects that the reader has a exception() method
- assert reader is None or isinstance(reader, StreamReader)
- self._reader = reader
- self._loop = loop
-
- def __repr__(self):
- info = [self.__class__.__name__, 'transport=%r' % self._transport]
- if self._reader is not None:
- info.append('reader=%r' % self._reader)
- return '<%s>' % ' '.join(info)
-
- @property
- def transport(self):
- return self._transport
-
- def write(self, data):
- self._transport.write(data)
-
- def writelines(self, data):
- self._transport.writelines(data)
-
- def write_eof(self):
- return self._transport.write_eof()
-
- def can_write_eof(self):
- return self._transport.can_write_eof()
-
- def close(self):
- return self._transport.close()
-
- def get_extra_info(self, name, default=None):
- return self._transport.get_extra_info(name, default)
-
- @coroutine
- def drain(self):
- """Flush the write buffer.
-
- The intended use is to write
-
- w.write(data)
- yield From(w.drain())
- """
- if self._reader is not None:
- exc = self._reader.exception()
- if exc is not None:
- raise exc
- yield From(self._protocol._drain_helper())
-
-
-class StreamReader(object):
-
- def __init__(self, limit=_DEFAULT_LIMIT, loop=None):
- # The line length limit is a security feature;
- # it also doubles as half the buffer limit.
- self._limit = limit
- if loop is None:
- self._loop = events.get_event_loop()
- else:
- self._loop = loop
- self._buffer = bytearray()
- self._eof = False # Whether we're done.
- self._waiter = None # A future used by _wait_for_data()
- self._exception = None
- self._transport = None
- self._paused = False
-
- def __repr__(self):
- info = ['StreamReader']
- if self._buffer:
- info.append('%d bytes' % len(info))
- if self._eof:
- info.append('eof')
- if self._limit != _DEFAULT_LIMIT:
- info.append('l=%d' % self._limit)
- if self._waiter:
- info.append('w=%r' % self._waiter)
- if self._exception:
- info.append('e=%r' % self._exception)
- if self._transport:
- info.append('t=%r' % self._transport)
- if self._paused:
- info.append('paused')
- return '<%s>' % ' '.join(info)
-
- def exception(self):
- return self._exception
-
- def set_exception(self, exc):
- self._exception = exc
-
- waiter = self._waiter
- if waiter is not None:
- self._waiter = None
- if not waiter.cancelled():
- waiter.set_exception(exc)
-
- def _wakeup_waiter(self):
- """Wakeup read() or readline() function waiting for data or EOF."""
- waiter = self._waiter
- if waiter is not None:
- self._waiter = None
- if not waiter.cancelled():
- waiter.set_result(None)
-
- def set_transport(self, transport):
- assert self._transport is None, 'Transport already set'
- self._transport = transport
-
- def _maybe_resume_transport(self):
- if self._paused and len(self._buffer) <= self._limit:
- self._paused = False
- self._transport.resume_reading()
-
- def feed_eof(self):
- self._eof = True
- self._wakeup_waiter()
-
- def at_eof(self):
- """Return True if the buffer is empty and 'feed_eof' was called."""
- return self._eof and not self._buffer
-
- def feed_data(self, data):
- assert not self._eof, 'feed_data after feed_eof'
-
- if not data:
- return
-
- self._buffer.extend(data)
- self._wakeup_waiter()
-
- if (self._transport is not None and
- not self._paused and
- len(self._buffer) > 2*self._limit):
- try:
- self._transport.pause_reading()
- except NotImplementedError:
- # The transport can't be paused.
- # We'll just have to buffer all data.
- # Forget the transport so we don't keep trying.
- self._transport = None
- else:
- self._paused = True
-
- @coroutine
- def _wait_for_data(self, func_name):
- """Wait until feed_data() or feed_eof() is called."""
- # StreamReader uses a future to link the protocol feed_data() method
- # to a read coroutine. Running two read coroutines at the same time
- # would have an unexpected behaviour. It would not possible to know
- # which coroutine would get the next data.
- if self._waiter is not None:
- raise RuntimeError('%s() called while another coroutine is '
- 'already waiting for incoming data' % func_name)
-
- # In asyncio, there is no need to recheck if we got data or EOF thanks
- # to "yield from". In trollius, a StreamReader method can be called
- # after the _wait_for_data() coroutine is scheduled and before it is
- # really executed.
- if self._buffer or self._eof:
- return
-
- self._waiter = futures.Future(loop=self._loop)
- try:
- yield From(self._waiter)
- finally:
- self._waiter = None
-
- @coroutine
- def readline(self):
- if self._exception is not None:
- raise self._exception
-
- line = bytearray()
- not_enough = True
-
- while not_enough:
- while self._buffer and not_enough:
- ichar = self._buffer.find(b'\n')
- if ichar < 0:
- line.extend(self._buffer)
- del self._buffer[:]
- else:
- ichar += 1
- line.extend(self._buffer[:ichar])
- del self._buffer[:ichar]
- not_enough = False
-
- if len(line) > self._limit:
- self._maybe_resume_transport()
- raise ValueError('Line is too long')
-
- if self._eof:
- break
-
- if not_enough:
- yield From(self._wait_for_data('readline'))
-
- self._maybe_resume_transport()
- raise Return(bytes(line))
-
- @coroutine
- def read(self, n=-1):
- if self._exception is not None:
- raise self._exception
-
- if not n:
- raise Return(b'')
-
- if n < 0:
- # This used to just loop creating a new waiter hoping to
- # collect everything in self._buffer, but that would
- # deadlock if the subprocess sends more than self.limit
- # bytes. So just call self.read(self._limit) until EOF.
- blocks = []
- while True:
- block = yield From(self.read(self._limit))
- if not block:
- break
- blocks.append(block)
- raise Return(b''.join(blocks))
- else:
- if not self._buffer and not self._eof:
- yield From(self._wait_for_data('read'))
-
- if n < 0 or len(self._buffer) <= n:
- data = bytes(self._buffer)
- del self._buffer[:]
- else:
- # n > 0 and len(self._buffer) > n
- data = bytes(self._buffer[:n])
- del self._buffer[:n]
-
- self._maybe_resume_transport()
- raise Return(data)
-
- @coroutine
- def readexactly(self, n):
- if self._exception is not None:
- raise self._exception
-
- # There used to be "optimized" code here. It created its own
- # Future and waited until self._buffer had at least the n
- # bytes, then called read(n). Unfortunately, this could pause
- # the transport if the argument was larger than the pause
- # limit (which is twice self._limit). So now we just read()
- # into a local buffer.
-
- blocks = []
- while n > 0:
- block = yield From(self.read(n))
- if not block:
- partial = b''.join(blocks)
- raise IncompleteReadError(partial, len(partial) + n)
- blocks.append(block)
- n -= len(block)
-
- raise Return(b''.join(blocks))
-
- # FIXME: should we support __aiter__ and __anext__ in Trollius?
- #if compat.PY35:
- # @coroutine
- # def __aiter__(self):
- # return self
- #
- # @coroutine
- # def __anext__(self):
- # val = yield from self.readline()
- # if val == b'':
- # raise StopAsyncIteration
- # return val