Source code for mpservice.socket

"""The module ``mpservice.socket`` provides tools to use sockets to communicate between
two Python processes on the same machine.
"""

from __future__ import annotations

__all__ = [
    'SocketApplication',
    'SocketServer',
    'SocketClient',
    'make_server',
]

import asyncio
import concurrent.futures
import functools
import inspect
import logging
import os
import queue
import stat
import subprocess
import threading
import time
from collections.abc import Iterable
from pickle import dumps as pickle_dumps
from pickle import loads as pickle_loads
from time import perf_counter
from typing import Any, Awaitable, Callable

from ._queues import SingleLane
from .concurrent.futures import ThreadPoolExecutor
from .multiprocessing.remote_exception import RemoteException

logger = logging.getLogger(__name__)


def is_async(func):
    """
    Tell whether calling ``func`` produces a coroutine.

    ``functools.partial`` wrappers are peeled off first. After that, the
    answer is ``True`` if the target is a coroutine function, or if it is
    a non-function callable object whose ``__call__`` is a coroutine function.
    """
    target = func
    while isinstance(target, functools.partial):
        target = target.func

    if inspect.iscoroutinefunction(target):
        return True
    # A plain (sync) function, or something that is not callable at all,
    # cannot be async.
    if inspect.isfunction(target) or not hasattr(target, '__call__'):
        return False
    return inspect.iscoroutinefunction(target.__call__)


def get_docker_host_ip():
    """
    From within a Docker container, find the IP address of the host machine.

    The address is taken from the output of ``ip -4 route list match 0/0``:
    the leading ``'default via '`` is skipped and everything up to the next
    space is returned.
    """
    # Equivalent shell one-liners:
    #   INTERNAL_HOST_IP=$(ip route show default | awk '/default/ {print $3}')
    #   ip -4 route list match 0/0 | cut -d' ' -f3
    #
    # Usually the result is '172.17.0.1'.
    #
    # The command `ip` is provided by the Linux package `iproute2`.
    prefix = 'default via '
    command = ['ip', '-4', 'route', 'list', 'match', '0/0']
    route_line = subprocess.check_output(command).decode()  # noqa: S603, S607
    after_prefix = route_line[len(prefix) :]
    end = after_prefix.find(' ')
    return after_prefix[:end]


def put_in_queue(q, x, stop_event, timeout=0.1):
    """
    Keep trying to put ``x`` in ``q`` until it succeeds or a stop is requested.

    ``q`` must have a ``put`` method that accepts ``timeout`` and raises
    ``queue.Full`` when the wait expires, e.g. ``queue.Queue`` or
    ``multiprocessing.queues.Queue``. An ``asyncio.Queue`` does not qualify
    because its ``put`` does not take ``timeout``.

    ``stop_event`` is a ``threading.Event`` or
    ``multiprocessing.synchronize.Event``. It is consulted only after an
    attempt has timed out, so ``timeout`` is the granularity of the
    early-stop check. There is usually no need to customize it.

    Return ``True`` if ``x`` was put in the queue; ``False`` if early-stop
    was detected.
    """
    delivered = None
    while delivered is None:
        try:
            q.put(x, timeout=timeout)
        except queue.Full:
            # Give up only if a stop was requested; otherwise try again.
            if stop_event.is_set():
                delivered = False
        else:
            delivered = True
    return delivered


# References about Docker networking:
#
#  Docker Containers: IPC using Sockets
#  https://medium.com/technanic/docker-containers-ipc-using-sockets-part-1-2ee90885602c
#  https://medium.com/technanic/docker-containers-ipc-using-sockets-part-2-834e8ea00768
#
#  Connection refused? Docker networking and how it impacts your image
#  https://pythonspeed.com/articles/docker-connection-refused/

# Unix domain sockets
#   https://pymotw.com/2/socket/uds.html

# Python socket programming
#  https://realpython.com/python-sockets/
#  https://docs.python.org/3/howto/sockets.html


def encode(data, encoder):
    """
    Turn ``data`` into bytes according to ``encoder``.

    - ``'none'``: ``data`` must already be bytes; it is returned unchanged.
    - ``'pickle'``: ``data`` is pickled.
    - ``'utf8'``: ``data`` is a string and is UTF-8 encoded.

    Any other ``encoder`` raises ``ValueError``.
    """
    if encoder == 'none':
        return data
    if encoder == 'pickle':
        return pickle_dumps(data)
    if encoder == 'utf8':
        return data.encode('utf8')
    raise ValueError(f"expecting 'none' but got: {encoder}")


def decode(data, encoder):
    """
    Inverse of ``encode``: turn the bytes ``data`` back into a Python object.

    - ``'pickle'``: ``data`` is unpickled.
    - ``'utf8'``: ``data`` is decoded into a string.
    - ``'none'``: ``data`` is returned unchanged (bytes).

    Raises
    ------
    ValueError
        If ``encoder`` is not one of the names above. This mirrors ``encode``
        and, unlike an ``assert``, is not stripped when Python runs with ``-O``.
    """
    if encoder == 'pickle':
        # NOTE(security): unpickling executes arbitrary code; this is only safe
        # when the peer on the other end of the socket is trusted.
        return pickle_loads(data)
    if encoder == 'utf8':
        return data.decode('utf8')
    if encoder != 'none':
        raise ValueError(f"expecting 'none' but got: {encoder}")
    return data  # bytes unchanged


# A record on the wire is laid out like this:
#
#    b'adf23d 24 pickle\naskadfka23kdkda'
#
# The 'header' part contains, separated by spaces,
#    - the request ID (whatever string the client decides to use; no whitespace),
#    - the length of the payload in bytes,
#    - the name of the encoder,
# and ends with '\n'. After the header comes exactly the stated number
# of bytes, which are decoded by `decode` according to the encoder.


async def write_record(writer, request_id, data, *, encoder: str = 'pickle'):
    """
    Send one record: a header line followed by the encoded payload.

    The header is ``'<request_id> <payload length in bytes> <encoder>\\n'``.
    The call returns after the writer's buffer has been drained.
    """
    payload = encode(data, encoder)
    header = f'{request_id} {len(payload)} {encoder}\n'.encode()
    writer.write(header)
    writer.write(payload)
    await writer.drain()
    # TODO: add timeout?


async def read_record(reader, *, timeout=None):
    """
    Read one record and return ``(request_id, decoded_payload)``.

    ``timeout`` should be ``None`` or ``> 0``; it applies to waiting for the
    header line only, not to reading the payload.

    This may raise ``asyncio.TimeoutError`` (nothing to read at the moment)
    or ``asyncio.IncompleteReadError`` (the connection has been closed by
    the other side).

    ``request_id`` is returned as a string exactly as found in the header.
    On the server side it should be included as is in the response.
    On the client side it may need to be converted back per the client's
    design, e.g. to an ``int`` if that is what the client originally sent.
    """
    header = await asyncio.wait_for(reader.readuntil(b'\n'), timeout)
    fields = header[:-1].decode().split()
    request_id, num_bytes, encoder = fields
    payload = await reader.readexactly(int(num_bytes))
    return request_id, decode(payload, encoder)


async def run_tcp_server(conn_handler: Callable, host: str, port: int):
    """
    Start a TCP server on ``host:port`` and serve until cancelled.

    ``conn_handler`` is passed to ``asyncio.start_server`` as the
    per-connection callback.
    """
    server = await asyncio.start_server(conn_handler, host, port)
    async with server:
        bound = [str(sock.getsockname()) for sock in server.sockets]
        logger.info('serving on %s', ', '.join(bound))
        await server.serve_forever()


async def run_unix_server(conn_handler: Callable, path: str):
    """
    Start a Unix-domain-socket server at ``path`` and serve until cancelled.

    Parameters
    ----------
    conn_handler
        Passed to ``asyncio.start_unix_server`` as the per-connection callback.
    path
        A file path accessible by both the server and the client,
        which run on the same machine. This is not a regular file,
        and it (including parent directories) does not need to exist;
        missing parent directories are created.

        If something already exists at ``path``, it must be a socket file.
        This function does not remove it itself; presumably
        ``asyncio.start_unix_server`` replaces a stale socket file when
        binding (TODO confirm for the Python versions in use).
    """
    if os.path.exists(path):
        assert stat.S_ISSOCK(
            os.stat(path).st_mode
        ), f"file '{path}' exists but is not a socket file"
    else:
        # A bare file name such as 'my.sock' has an empty dirname, and
        # `os.makedirs('')` raises FileNotFoundError; in that case the socket
        # goes in the current directory and there is nothing to create.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
    server = await asyncio.start_unix_server(conn_handler, path)
    async with server:
        addrs = ', '.join(str(sock.getsockname()) for sock in server.sockets)
        logger.info('serving on %s', addrs)
        await server.serve_forever()


async def open_tcp_connection(host, port, *, timeout=None):
    """
    Connect to ``host:port``, retrying while the connection is refused.

    A refused attempt is retried after 0.1 second. If ``timeout`` (seconds)
    is not ``None`` and more than that has elapsed since the call started,
    the ``ConnectionRefusedError`` is re-raised instead. With ``timeout=None``
    the retries go on indefinitely.

    Returns the ``(reader, writer)`` pair from ``asyncio.open_connection``.
    """
    started = perf_counter()
    while True:
        try:
            reader, writer = await asyncio.open_connection(host, port)
        except ConnectionRefusedError:
            out_of_time = timeout is not None and perf_counter() - started > timeout
            if out_of_time:
                raise
        else:
            return reader, writer
        await asyncio.sleep(0.1)


async def open_unix_connection(path: str, *, timeout=None):
    """
    Connect to the Unix socket at ``path``, retrying until the server is up.

    An attempt that fails with ``ConnectionRefusedError`` or
    ``FileNotFoundError`` is retried after 0.1 second. If ``timeout``
    (seconds) is not ``None`` and more than that has elapsed since the call
    started, the error is re-raised instead.

    Parameters
    ----------
    path
        The same string that has been used by ``run_unix_server``.

    Returns the ``(reader, writer)`` pair from ``asyncio.open_unix_connection``.
    """
    started = perf_counter()
    while True:
        try:
            reader, writer = await asyncio.open_unix_connection(path)
        except (ConnectionRefusedError, FileNotFoundError):
            out_of_time = timeout is not None and perf_counter() - started > timeout
            if out_of_time:
                raise
        else:
            return reader, writer
        await asyncio.sleep(0.1)


class SocketApplication:
    """
    ``SocketApplication`` is designed to be used similarly to the "application"
    in an HTTP framework.

    The main API is to register "endpoint" functions with the method
    :meth:`add_route`, so that a socket service can be backed by multiple
    functions for different purposes.

    Usually there is only one main function, which involves transmitting
    a substantial amount of data between the server and the client.
    For simplicity, one may use ``'/'`` for the ``path`` of this route.
    The other routes are usually supportive, for example, getting server
    info or setting options. For example::

        app.add_route('/', make_prediction)
        app.add_route('/server-info', get_server_info)
        app.add_route('/set-option', set_server_option)

    This class is the intended interface between a socket server and a
    particular application (functions). Usually, user should not customize
    the class ``SocketServer``.
    """

    def __init__(
        self,
    ):
        # Maps a route's path string to its async endpoint function.
        self._routes = {}

    def add_route(
        self,
        path: str,
        route: Callable[[Any], Awaitable[Any]] | Callable[[], Awaitable[Any]],
    ):
        """
        Register ``route`` under ``path``, replacing any earlier registration
        for the same ``path``.

        ``route`` is an *async* function that takes a single positional arg
        (or none; see :meth:`handle_request`), and returns a response (which
        could be ``None`` if so desired). The response should be serializable
        by the encoder. To be safe, return an object of Python native types.

        If an exception is raised in ``route``, an appropriate
        :class:`~mpservice.multiprocessing.remote_exception.RemoteException`
        object will be sent in the response.
        ``route`` could also proactively return a :class:`RemoteException` object.

        ``path`` is any string. The route is identified by this string.
        For familiarity, it may be a good idea to start the string with ``'/'``,
        although this is in no way necessary.

        There is no GET/POST distinction like in the case of HTTP.
        """
        # TODO: support sync "endpoint" functions, if they are I/O bound.
        self._routes[path] = route

    async def handle_request(self, path: str, data: Any = None):
        """
        Dispatch the request to the route function registered under ``path``.

        The function is called with ``data`` as its only argument, or with
        no argument at all when ``data`` is ``None``.
        """
        endpoint = self._routes[path]
        call_args = () if data is None else (data,)
        return await endpoint(*call_args)
class SocketServer:
    """
    A socket server (Unix or TCP) that answers requests by dispatching them
    to the routes of a :class:`SocketApplication`.
    """

    def __init__(
        self,
        app: SocketApplication,
        *,
        path: str | None = None,
        host: str | None = None,
        port: int | None = None,
        backlog: int | None = None,
        shutdown_path: str = '/shutdown',
    ):
        """
        Parameters
        ----------
        app
            The application whose ``handle_request`` serves the requests.
        path, host, port
            These determine the type of the service, 'unix' or 'tcp'.
            If ``path`` is given, a Unix socket is used and ``host`` and
            ``port`` must not be given. Otherwise ``port`` is required and
            ``host`` defaults to ``'0.0.0.0'``.
        backlog
            The max concurrent in-progress requests per connection; defaults
            to 256. (Note, the client may open many connections.)
            This "concurrency" is in terms of concurrent calls to
            :meth:`SocketApplication.handle_request`.
        shutdown_path
            A request for this path is not dispatched to ``app``; it asks the
            server to shut down.
        """
        self.app = app
        if path:
            assert not host
            assert not port
            self._path = os.path.abspath(path)
            self._host = None
            self._port = None
        else:
            # If the server runs within a Docker container,
            # `host` should be '0.0.0.0'. Outside of Docker,
            # it should be '127.0.0.1' (I think but did not verify).
            assert port
            if not host:
                host = '0.0.0.0'  # in Docker
            self._path = None
            self._host = host
            self._port = int(port)
        self._backlog = backlog or 256
        self._encoder = 'pickle'  # encoder when sending responses
        self._n_connections = 0  # number of connections currently open
        self._shutdown_path = shutdown_path
        self.to_shutdown = False

    def __repr__(self):
        if self._path:
            return f"{self.__class__.__name__}('{self._path}')"
        return f"{self.__class__.__name__}('{self._host}:{self._port}')"

    def __str__(self):
        return self.__repr__()

    async def serve(self):
        """
        Start the server and let it stay up until shutdown conditions are met.

        The method returns normally once shutdown has been requested and no
        connection remains open. If the listening task fails (for example
        because the address cannot be bound), its exception is raised from
        here instead of being silently ignored. Cancellation and other
        interruptions propagate after cleanup.
        """
        if self._path:
            listener = run_unix_server(self._handle_connection, self._path)
        else:
            listener = run_tcp_server(self._handle_connection, self._host, self._port)
        server_task = asyncio.create_task(listener)
        logger.info('server %s is ready', self)
        try:
            while True:
                if server_task.done():
                    # The listener is supposed to run until cancelled, so
                    # this means it failed or was stopped; do not keep polling
                    # for a shutdown request that can no longer arrive.
                    if not server_task.cancelled():
                        server_task.result()  # re-raises the listener's error
                    break
                if self.to_shutdown and not self._n_connections:
                    break
                await asyncio.sleep(0.1)
        finally:
            # This also runs on cancellation, keyboard interrupt and such.
            server_task.cancel()
            if self._path:
                # Remove only a socket file, and tolerate its absence (the
                # listener may have failed before creating it).
                try:
                    if stat.S_ISSOCK(os.stat(self._path).st_mode):
                        os.unlink(self._path)
                except FileNotFoundError:
                    pass
            logger.info('server %s is stopped', self)

    async def _handle_connection(self, reader, writer):
        """
        This is called upon a new connection that is opened at the request
        from a client to the server. This method handles requests in that
        connection until the client closes it or the server shuts down.
        """
        if self._path:
            addr = writer.get_extra_info('sockname')
        else:
            addr = writer.get_extra_info('peername')
        self._n_connections += 1
        # Remember this connection's number; the live counter changes as
        # other connections come and go.
        conn_id = self._n_connections
        logger.info('connection %d is openned from client %r', conn_id, addr)
        reqs = asyncio.Queue(self._backlog)

        async def _keep_receiving():
            # Infinite loop to handle requests on this connection
            # until the connection is closed by the client.
            loop = asyncio.get_running_loop()
            while True:
                try:
                    req_id, data = await read_record(reader, timeout=0.1)
                except asyncio.TimeoutError:
                    if self.to_shutdown:
                        return
                    continue
                # If `asyncio.IncompleteReadError` is raised, let it propagate.
                path = data[0]
                data = data[1]
                if path == self._shutdown_path:
                    # Answer the shutdown request with `None` and stop
                    # receiving on this connection.
                    t = loop.create_future()
                    await reqs.put((req_id, t))
                    self.to_shutdown = True
                    t.set_result(None)
                    break
                f = self.app.handle_request(path, data)
                t = asyncio.create_task(f)
                await reqs.put((req_id, t))
                # The queue size will restrict how many concurrent calls
                # to `handle_request` can be in progress.

        # `write_record` needs to be called sequentially because it's not atomic;
        # that's why we don't use `add_done_callback` on the Futures to do
        # response writing.
        async def _keep_responding():
            while True:
                try:
                    req_id, t = await asyncio.wait_for(reqs.get(), 0.1)
                except asyncio.TimeoutError:
                    if self.to_shutdown:
                        return
                    continue
                try:
                    z = await t
                except Exception as e:
                    z = RemoteException(e)
                await write_record(writer, req_id, z, encoder=self._encoder)
                # TODO: what if client has closed the connection?

        trec = asyncio.create_task(_keep_receiving())
        tres = asyncio.create_task(_keep_responding())
        try:
            try:
                await trec
            except asyncio.IncompleteReadError:
                # The client has closed the connection.
                tres.cancel()
            else:
                await tres
        finally:
            # Whatever happened above, stop both loops (no-op if finished),
            # close the connection, and give back the connection count;
            # otherwise `serve` could never see the count drop to zero.
            trec.cancel()
            tres.cancel()
            try:
                writer.close()
                await writer.wait_closed()
            finally:
                logger.info('connection %d from client %r is closed', conn_id, addr)
                self._n_connections -= 1
def make_server(app: SocketApplication, **kwargs):
    """
    Create a :class:`SocketServer` for ``app``; ``kwargs`` are passed on
    to the ``SocketServer`` constructor.

    Example::

        async def double(data):
            await asyncio.sleep(0.01)
            return data * 2

        app = SocketApplication()
        app.add_route('/', double)
        server = make_server(app, path='/tmp/sock_abc')
        asyncio.run(server.serve())
    """
    server = SocketServer(app, **kwargs)
    return server
class SocketClient:
    """
    Client side of the socket service. It must be used as a context manager;
    requests are made with :meth:`request` and :meth:`stream`.
    """

    def __init__(
        self,
        *,
        path: None | str = None,
        host: None | str = None,
        port: None | int = None,
        num_connections: None | int = None,
        connection_timeout: int = 60,
        backlog: int = 2048,
    ):
        """
        Parameters
        ----------
        path, host, port
            Either ``path`` is given (for Unix socket), or ``port`` (plus
            optionally ``host``) is given (for TCP socket).
            These values should, of course, be consistent with the
            corresponding server.
        num_connections
            Number of connections to open to the server; defaults to 32.
            This is expected to have a direct impact on the performance,
            hence needs experimentation.
        connection_timeout
            How many seconds to wait while connecting to the server.
            This is meant for waiting for server to be ready, rather than for
            the action of "connecting" itself (which should be fast).
        backlog
            Size of the queue of requests waiting to be sent.
        """
        # Experiments showed `max_connections` can be up to 200.
        # This needs to be improved.
        if path:
            # `path` must be consistent with that passed to `run_unix_server`.
            assert not host
            assert not port
            self._path = path
            self._host = None
            self._port = None
        else:
            # If both client and server run on the same machine
            # in separate Docker containers, `host` should be
            # `get_docker_host_ip()`.
            assert port
            if not host:
                host = get_docker_host_ip()  # in Docker
            self._path = None
            self._host = host
            self._port = int(port)
        self._num_connections = num_connections or 32
        self._connection_timeout = connection_timeout
        self._backlog = backlog
        self._encoder = 'pickle'  # encoder when sending requests.
        self._prepare_shutdown = threading.Event()
        self._to_shutdown = threading.Event()
        self._executor = ThreadPoolExecutor()
        self._tasks = []
        # Requests waiting to be sent; `None` while the client is not started.
        self._pending_requests: queue.Queue | None = None
        # Requests sent (or being sent) whose responses are not yet received,
        # keyed by request ID.
        self._active_requests = {}
        self._shutdown_timeout = 60

    def __repr__(self):
        if self._path:
            return f"{self.__class__.__name__}('{self._path}')"
        return f"{self.__class__.__name__}('{self._host}:{self._port}')"

    def __str__(self):
        return self.__repr__()

    def __enter__(self):
        """
        Open all connections in a background thread and wait until each of
        them has been established. If any connection fails, the client is
        shut down and the failure is raised.
        """
        self._pending_requests = SingleLane(self._backlog)
        q = queue.Queue()
        self._tasks.append(self._executor.submit(self._open_connections, q))
        n = 0
        while n < self._num_connections:
            # Each connection attempt reports either 'OK' or an exception.
            z = q.get()
            if z == 'OK':
                n += 1
            if isinstance(z, BaseException):
                self._to_shutdown.set()
                concurrent.futures.wait(self._tasks)
                self._executor.shutdown()
                raise z
        logger.info('client %s is ready', self)
        return self

    def __exit__(self, *args, **kwargs):
        """
        Stop accepting requests, wait (up to ``self._shutdown_timeout``
        seconds in total) for queued requests to be sent and for outstanding
        responses to arrive, then close the connections.
        """
        self._prepare_shutdown.set()
        t0 = perf_counter()
        while not self._pending_requests.empty():
            if perf_counter() - t0 > self._shutdown_timeout:
                break
            time.sleep(0.1)
        while self._active_requests:
            if perf_counter() - t0 > self._shutdown_timeout:
                break
            time.sleep(0.1)
        self._to_shutdown.set()
        concurrent.futures.wait(self._tasks)
        self._executor.shutdown()
        self._pending_requests = None

    def _open_connections(self, q: queue.Queue):
        """
        Run, in the calling (background) thread, an event loop that opens
        ``self._num_connections`` connections and keeps sending requests and
        receiving responses on them until shutdown.

        For every connection attempt, ``'OK'`` or the exception is put in ``q``.
        """

        async def _keep_sending(writer):
            pending = self._pending_requests
            # This queue is populated by `_enqueue`.
            active = self._active_requests
            encoder = self._encoder
            to_shutdown = self._to_shutdown
            while True:
                try:
                    x, fut = pending.get_nowait()
                except queue.Empty:
                    if to_shutdown.is_set():
                        return
                    await asyncio.sleep(0.0012)
                    # This sleep should be short, but the length
                    # is not critical, because this happens only when
                    # the queue is empty. In a busy application,
                    # the queue should be rarely empty.
                    continue
                req_id = id(fut)
                # Register the future BEFORE writing: `write_record` may yield
                # in `drain()`, and the response could be read by
                # `_keep_receiving` before this coroutine resumes.
                active[req_id] = fut
                try:
                    await write_record(writer, req_id, x, encoder=encoder)
                except Exception as e:
                    # The request did not go out; do not leave the caller
                    # waiting for a response that will never come.
                    if active.pop(req_id, None) is not None:
                        fut.set_exception(e)
                    raise

        async def _keep_receiving(reader):
            active = self._active_requests
            to_shutdown = self._to_shutdown
            while True:
                try:
                    req_id, data = await read_record(reader, timeout=0.1)
                    req_id = int(req_id)
                except asyncio.TimeoutError:
                    if to_shutdown.is_set():
                        return
                    continue
                # Do not capture `asyncio.IncompleteReadError`;
                # let it stop this function.
                fut = active.pop(req_id)
                if isinstance(data, BaseException):
                    fut.set_exception(data)
                else:
                    fut.set_result(data)

        async def _open_connection(k):
            try:
                if self._path:
                    reader, writer = await open_unix_connection(
                        self._path, timeout=self._connection_timeout
                    )
                else:
                    reader, writer = await open_tcp_connection(
                        self._host, self._port, timeout=self._connection_timeout
                    )
            except Exception as e:
                q.put(e)
                return

            q.put('OK')
            addr = writer.get_extra_info('peername')
            logger.info('connection %d to server %r is openned', k + 1, addr)
            tw = asyncio.create_task(_keep_sending(writer))
            tr = asyncio.create_task(_keep_receiving(reader))
            try:
                await tr
            except asyncio.IncompleteReadError:
                # Propagated from `tr` indicating server has closed the connection.
                tw.cancel()
            else:
                await tw
            writer.close()
            await writer.wait_closed()
            logger.info('connection %d to server %r is closed', k + 1, addr)

        async def _main():
            tasks = [_open_connection(i) for i in range(self._num_connections)]
            await asyncio.gather(*tasks)
            # If any of the tasks raises exception, it will be propagated here.

        asyncio.run(_main())
        # If `_main` raises exception, it will be propagated here,
        # hence stopping this thread.

    def _enqueue(self, path: str, data, *, timeout=None):
        """
        Queue the request ``(path, data)`` for sending and return the
        ``concurrent.futures.Future`` that will hold the response.
        """
        # Test against `None` explicitly: an empty queue object must not be
        # mistaken for "not started" should the queue type define `__len__`.
        if self._pending_requests is None:
            raise Exception(
                'Client is not yet started. Please use it in a context manager'
            )
        if self._prepare_shutdown.is_set() or self._to_shutdown.is_set():
            raise Exception('Client is closed')
        fut = concurrent.futures.Future()
        # `data` is the payload to send.
        # `fut` will hold the corresponding response to be received.
        # Every request will get a response from the server.
        self._pending_requests.put(((path, data), fut), timeout=timeout)
        return fut

    def request(
        self, path: str, data=None, *, enqueue_timeout=None, response_timeout=None
    ):
        """
        Send one request and return the server's response.

        This could raise ``concurrent.futures.TimeoutError``. That means the
        result is not available in the specified time, but the request may
        well have been sent to the server. However the user handles the
        exception, it will not affect the server's response to the request.
        The user will not be able to resume the wait for the result.

        If caller does not need the response, use ``response_timeout=0``;
        then ``None`` is returned right after the request is queued.

        In some cases, the request does not need to send data, e.g. if
        the request is for certain info query. In such situations, the
        corresponding function on the server side takes no argument, and in
        this call to ``request``, ``data`` should be ``None``.

        Example::

            request('/shutdown', response_timeout=0)
        """
        fut = self._enqueue(path, data, timeout=enqueue_timeout)
        if response_timeout is not None and response_timeout <= 0:
            return None
        return fut.result(timeout=response_timeout)

    def stream(
        self,
        path: str,
        data: Iterable,
        *,
        return_x: bool = False,
        return_exceptions: bool = False,
        enqueue_timeout=60,
        response_timeout=60,
    ):
        """
        Send every element of ``data`` as a request to ``path`` and yield the
        responses in the order of the inputs.

        If ``return_x`` is ``True``, yield ``(x, y)`` tuples, where ``x`` is
        the input element and ``y`` is the response to it.
        If ``return_x`` is ``False``, yield ``y`` only.

        If ``return_exceptions`` is ``True``, ``y`` could be an Exception
        object; otherwise the exception is raised.

        ``response_timeout`` is counted from the moment the request was
        queued; ``None`` means wait indefinitely.
        """
        # Refer to `mpserver.Server.stream` for in-code documentation.
        tasks = SingleLane(self._backlog)
        nomore = object()  # marks the end of the stream

        def _enqueue():
            en = self._enqueue
            et = enqueue_timeout
            tt = tasks
            Future = concurrent.futures.Future
            to_shutdown = self._to_shutdown
            for x in data:
                try:
                    fut = en(path, x, timeout=et)
                except Exception as e:
                    if return_exceptions:
                        fut = Future()
                        fut.set_exception(e)
                    else:
                        # logger.error("exception '%r' happened for input '%s'", e, x)
                        logger.error(repr(e))
                        raise
                t0 = perf_counter()
                if not put_in_queue(tt, (x, fut, t0), to_shutdown):
                    return
            put_in_queue(tt, nomore, to_shutdown)

        t = self._executor.submit(_enqueue)
        self._tasks.append(t)

        while True:
            try:
                z = tasks.get(timeout=0.1)
            except queue.Empty:
                if not t.done():
                    if self._to_shutdown.is_set():
                        break
                    continue
                # The producer has finished. It may have put its last item
                # right after the timed `get` expired, so look once more
                # before concluding that nothing is left.
                try:
                    z = tasks.get_nowait()
                except queue.Empty:
                    if t.exception():
                        raise t.exception()
                    if not self._to_shutdown.is_set():
                        raise ValueError(
                            f'expecting `self._to_shutdown.is_set()` to be True but got: {self._to_shutdown.is_set()}'
                        )
                    break
            if z is nomore:
                break
            x, fut, t0 = z
            if response_timeout is None:
                remaining = None
            else:
                remaining = response_timeout - (perf_counter() - t0)
            try:
                y = fut.result(timeout=remaining)
            except Exception as e:
                if return_exceptions:
                    if return_x:
                        yield x, e
                    else:
                        yield e
                else:
                    # logger.error("exception '%r' happened for input '%s'", e, x)
                    logger.error(repr(e))
                    raise
            else:
                if return_x:
                    yield x, y
                else:
                    yield y