# Source code for mpservice.mpserver._worker

import logging
import multiprocessing
import multiprocessing.queues
import os
import queue
import threading
from collections.abc import Iterable, Iterator
from queue import Empty
from time import perf_counter
from typing import Any, Callable

from mpservice._queues import SingleLane
from mpservice.multiprocessing import MP_SPAWN_CTX
from mpservice.multiprocessing.remote_exception import RemoteException
from mpservice.streamer import Parmapper
from mpservice.threading import Thread

# Module-level logger; `Worker._start_batch` uses it to report batch-size statistics.
logger = logging.getLogger(__name__)


class _SimpleProcessQueue(multiprocessing.queues.SimpleQueue):
    """
    A customization of
    `multiprocessing.SimpleQueue <https://docs.python.org/3/library/multiprocessing.html#multiprocessing.SimpleQueue>`_
    whose read lock is reentrant.

    This class reduces some overhead in a particular use-case in this module,
    where one consumer of the queue greedily grabs elements out of the queue
    towards a batch-size limit: the consumer acquires ``self._rlock`` once and
    then calls :meth:`get` repeatedly while holding it, so that other readers
    cannot interleave. Because the inherited :meth:`get` itself acquires
    ``self._rlock``, that lock must be an ``RLock``; the plain ``Lock`` that
    ``SimpleQueue`` creates by default would deadlock in this pattern.

    This queue is meant to be used between two processes or between a process
    and a thread.

    The main use case of this class is in :meth:`Worker._build_input_batches`.

    Parameters
    ----------
    ctx
        Multiprocessing context used to create the underlying pipe and locks.
        If ``None``, ``MP_SPAWN_CTX`` is used.
    """

    def __init__(self, *, ctx=None):
        if ctx is None:
            ctx = MP_SPAWN_CTX
        super().__init__(ctx=ctx)
        # Replace the default read Lock by an RLock so that a consumer may hold
        # `_rlock` across several `get()` calls, each of which also acquires it.
        self._rlock = ctx.RLock()


class _SimpleThreadQueue(queue.SimpleQueue):
    """
    A customization of
    `queue.SimpleQueue <https://docs.python.org/3/library/queue.html#queue.SimpleQueue>`_;
    this class is analogous to :class:`_SimpleProcessQueue` but is designed
    to be used between two threads.

    It adds a reentrant lock ``self._rlock`` so that code written against
    :class:`_SimpleProcessQueue` (e.g. :meth:`Worker._build_input_batches`,
    which holds ``q_in._rlock`` while greedily calling ``get``) works with
    either queue type. Unlike in :class:`_SimpleProcessQueue`, the inherited
    :meth:`get` does not acquire this lock. The lock only serializes consumers
    that acquire it explicitly.
    """

    def __init__(self):
        super().__init__()
        # Advisory lock for cooperating consumers; `queue.SimpleQueue.get`
        # never touches it.
        self._rlock = threading.RLock()


class Worker:
    """
    ``Worker`` defines operations on a single input item or a batch of items
    in usual synchronous code. This is supposed to run in its own process (or thread)
    and use that single process (or thread) only.

    Typically a subclass needs to enhance :meth:`__init__` and implement :meth:`call`,
    and leave the other methods intact.

    A ``Worker`` object is not created and used by itself. It is always started by a
    :class:`ProcessServlet` or a :class:`ThreadServlet`.
    """

    @classmethod
    def run(
        cls,
        *,
        q_in: _SimpleProcessQueue | _SimpleThreadQueue,
        q_out: _SimpleProcessQueue | _SimpleThreadQueue,
        **init_kwargs,
    ):
        """
        A :class:`Servlet` object will arrange to start a :class:`Worker` object
        in a thread or process. This classmethod will be the ``target`` argument
        to `Thread`_ or `Process`_.

        This method creates a :class:`Worker` object and calls its :meth:`~Worker.start`
        method to kick off the work.

        Once the object has been created, its ``name`` is put in ``q_out`` to
        signal that initialization has completed. If ``__init__`` raises,
        ``None`` is put in ``q_out`` instead and the exception is re-raised.

        Parameters
        ----------
        q_in
            A queue that carries input elements to be processed.

            If used in a :class:`ProcessServlet`, ``q_in`` is a :class:`_SimpleProcessQueue`.
            If used in a :class:`ThreadServlet`, ``q_in`` is either a :class:`_SimpleProcessQueue`
            or a :class:`_SimpleThreadQueue`.
        q_out
            A queue that carries output values.

            If used in a :class:`ProcessServlet`, ``q_out`` is a :class:`_SimpleProcessQueue`.
            If used in a :class:`ThreadServlet`, ``q_out`` is either a :class:`_SimpleProcessQueue`
            or a :class:`_SimpleThreadQueue`.

            The elements in ``q_out`` are results for each individual element in ``q_in``.
            "Batching" is an internal optimization for speed;
            ``q_out`` does not contain result batches, but rather results of individuals.
        **init_kwargs
            Passed on to :meth:`__init__`.

            If the worker is going to run in a child process, then elements in
            ``**init_kwargs`` go through pickling, hence they should consist
            mainly of small, Python builtin types such as string, number, small dict's, etc.
            Be careful about passing custom class objects in ``**init_kwargs``.
        """
        try:
            obj = cls(**init_kwargs)
        except Exception:
            q_out.put(None)
            raise
        q_out.put(obj.name)
        # This sends a signal to the caller (or "coordinator")
        # indicating completion of init.
        obj.start(q_in=q_in, q_out=q_out)

    def __init__(
        self,
        *,
        worker_index: int,
        batch_size: int | None = None,
        batch_wait_time: float | None = None,
        cpu_affinity: int | list[int] | None = None,
    ):
        """
        The main concern here is to set up controls for "batching" via
        the two parameters ``batch_size`` and ``batch_wait_time``.

        If the algorithm can not vectorize the computation, then there is
        no advantage in enabling batching. In that case, the subclass should
        simply fix ``batch_size`` to 0 in their ``__init__`` and invoke
        ``super().__init__`` accordingly.

        The ``__init__`` of a subclass may define additional input parameters;
        they can be passed in through :meth:`run`.

        Parameters
        ----------
        worker_index
            0-based sequential number of the worker in a "servlet".
            A subclass may use this to distinguish the worker processes/threads
            in the same Servlet and give them some different treatments,
            although they do essentially the same thing.
            For example, let each worker use one particular GPU.

            This argument is provided in :meth:`Servlet.start` when starting the worker.
            A subclass does not worry about providing this argument; it simply uses it
            if needed. The parameter has a proper value in ``__init__`` and continues
            to be available as an instance attribute.
        batch_size
            Max batch size; see :meth:`call`.

            Remember to pass in ``batch_size`` in accordance with the implementation
            of :meth:`call`. In other words, if ``batch_size > 0``, then :meth:`call`
            must handle a list input that contains a batch of elements.
            On the other hand, if ``batch_size`` is 0, then the input to :meth:`call`
            is a single element.

            If ``None``, then 0 is used, meaning no batching.

            If ``batch_size=1``, then processing is batched in form without
            speed benefits of batching.
        batch_wait_time
            Seconds, may be 0; the total duration to wait for one batch
            after the first item has arrived.

            For example, suppose ``batch_size`` is 100 and ``batch_wait_time`` is 1.
            After the first item has arrived, if at least 99 items arrive within
            1 second, then a batch of 100 elements will be produced;
            if less than 99 elements arrive within 1 second, then the wait will stop
            at 1 second, hence a batch of less than 100 elements will be produced;
            the batch could have only one element.

            If 0, then there's no wait. After the first element is obtained,
            if there are more elements in ``q_in`` "right there right now",
            they will be retrieved until a batch of ``batch_size`` elements is produced.
            Any moment when ``q_in`` is empty, the collection will stop,
            and the elements collected so far (less than ``batch_size`` count of them)
            will make a batch.
            In other words, batching happens only for items that are already
            "piled up" in ``q_in`` at the moment.

            To leverage batching, it is recommended to set ``batch_wait_time``
            to a small positive value. Small, so that there is not much futile waiting.
            Positive (as opposed to 0), so that it always waits a little bit
            just in case more elements are coming in.

            When ``batch_wait_time > 0``, it will hurt performance during
            sequential calls (i.e. send a request with a single element, wait for
            the result, then send the next, and so on), because this worker will
            always wait for this long for additional items to come and form a batch,
            yet additional items will never come during sequential calls.
            However, when batching is enabled, sequential calls are not
            the intended use case. Beware of this factor in benchmarking.

            If ``batch_size`` is 0 or 1, then ``batch_wait_time`` should be left
            unspecified, otherwise the only valid value is 0 (checked with ``assert``).

            If ``batch_size > 1``, then ``batch_wait_time`` is 0.01 by default.
        cpu_affinity
            Which CPUs this worker process is going to be "pinned" to.
            An ``int`` pins to that single CPU; a list is de-duplicated and sorted.

            If `None`, no pinning.

            If this worker is used in a `ThreadServlet`, this parameter is not
            specified in the call, and its value is `None` and is unused.
        """
        if batch_size is None or batch_size == 0:
            batch_size = 0
        if batch_size == 0 or batch_size == 1:
            # No real batching, hence nothing to wait for.
            if batch_wait_time is None:
                batch_wait_time = 0
            else:
                assert batch_wait_time == 0, (
                    f'`batch_wait_time` must be 0 or unspecified when `batch_size` '
                    f'is {batch_size}; got {batch_wait_time}'
                )
        elif batch_wait_time is None:
            batch_wait_time = 0.01

        self.worker_index = worker_index
        self.batch_size = batch_size
        self.batch_size_log_cadence = 1_000_000
        # Log batch size statistics every this many batches. If ``None`` (or 0),
        # this log is turned off.
        # This log is for debugging and development purposes.
        # This is only used when ``batch_size > 1``.
        self.batch_wait_time = batch_wait_time
        self.name = f'{multiprocessing.current_process().name}-{threading.current_thread().name}'

        if cpu_affinity is not None:
            if isinstance(cpu_affinity, int):
                cpu_affinity = [cpu_affinity]
            else:
                cpu_affinity = sorted(set(cpu_affinity))
            os.sched_setaffinity(0, cpu_affinity)
        self.cpu_affinity = cpu_affinity
        # If `None`, `os.sched_getaffinity` will return all CPUs.

        self.num_stream_threads: int = 0
        # See :meth:`stream`.

        self.preprocess: Callable[[Any], Any]
        """
        If a subclass has a method ``preprocess`` or an attribute ``preprocess``
        that is a free-standing function, this method or function must take one
        data element (not a batch) as the sole, positional argument.
        This processes/transforms the data, and the output is used in :meth:`call`.

        If this function raises an exception, this element is not sent to :meth:`call`;
        instead, the exception object is short-circuited to the output queue.

        When ``self.batch_size > 1``, if :meth:`call` needs to take care of an element
        of the batch that might fail a pre-condition, it is tedious to properly assemble
        the "good" and "bad" elements to further processing or to output in right order.
        This ``preprocess`` mechanism helps to deal with that situation.

        When a subclass is designed to do non-batching work, this attribute is not
        necessary, because the same concern can be handled in :meth:`call` directly.

        When ``self.preprocess`` is defined, it is used in :meth:`_start_single`
        and :meth:`_build_input_batches`.

        If a ``Server`` contains a single servlet, which uses this ``Worker``,
        then the functionalities of this ``self.preprocess`` can be largely provided
        by the parameter ``preprocessor`` to ``Server.stream``. In those cases,
        there is no need for this ``self.preprocess``.
        """

    def call(self, x):
        """
        Private methods of this class wait on the input queue to gather "work orders",
        send them to :meth:`call` for processing,
        collect the outputs of :meth:`call`, and put them in the output queue.

        If ``self.batch_size == 0``, then ``x`` is a single
        element, and this method returns result for ``x``.
        ``x`` is not an ``Exception`` or ``RemoteException`` object; such a value
        would have been routed to the outgoing pipe and not passed to this method.
        The same is true for elements of ``x`` when ``self.batch_size > 0``.

        If ``self.batch_size > 0`` (including 1), then
        ``x`` is a list of input data elements, and this
        method returns a list (or `Sequence`_) of results corresponding
        to the elements in ``x``, one result per element.
        When ``self.batch_size > 1`` and the number of results differs from
        ``len(x)``, every element of the batch receives an error instead.
        This output, when received by private methods of this class,
        will be split and individually put in the output queue,
        so that the elements in the output queue (``q_out``)
        correspond to the elements in the input queue (``q_in``),
        although *vectorized* computation, or *batching*, has happened internally.

        When batching is enabled (i.e. when ``self.batch_size > 0``), the number of
        elements in ``x`` varies between calls depending on the supply
        in the input queue. The list ``x`` does not have a fixed length.

        Be sure to distinguish the case with batching (``batch_size > 0``)
        and the case w/o batching (``batch_size = 0``) where a single
        input is a list. In the latter case, the output of
        this method is the result corresponding to the single input ``x``.
        The result could be anything---it may or may not be a list.

        If a subclass fixes ``batch_size`` in its ``__init__`` to be
        0 or nonzero, make sure this method is implemented accordingly.
        If ``__init__`` does not fix the value of ``batch_size``,
        then a particular instance may have been created with or without batching.
        In this case, this method needs to check ``self.batch_size`` and act accordingly.

        If this method raises exceptions, unless the user has specific things to do,
        do not handle them; just let them happen. They will be handled
        in private methods of this class that call this method.

        Usually this is the only method a subclass needs to customize.
        In rare cases, a subclass may want to customize :meth:`stream` instead of
        or in addition to :meth:`call`.
        """
        raise NotImplementedError

    def stream(self, xx: Iterable) -> Iterator:
        """
        `xx` is an iterable of input `x` to :meth:`call`.
        (If `self.batch_size > 0`, then `xx` is an iterable of batches.)

        This function yields the results of :meth:`call` for the elements of `xx`,
        in the right order. If any invocation of :meth:`call` raises an exception,
        the exception object is yielded.

        The elements of `xx` (or elements of the elements of `xx` when
        `self.batch_size > 0`) are not instances of `Exception` or `RemoteException`.
        Such values would have been routed to the outgoing pipe and not passed
        to this method.

        The background loop in :meth:`start` calls this method and does not call
        :meth:`call` directly. This method is provided mainly for the special
        use cases where a subclass wants to set `self.num_stream_threads` to
        a positive number, thereby using threading concurrency in this method.

        If a subclass re-implements this method without calling :meth:`call`,
        then :meth:`call` does not need to be implemented, because this method
        is the only place of this class that calls :meth:`call`.
        """
        if self.num_stream_threads < 1:
            for x in xx:
                try:
                    y = self.call(x)
                except Exception as e:
                    y = e
                yield y
        else:
            # This branch runs :meth:`call` in threads under these conditions:
            #
            #   - the main operations in :meth:`call` are IO bound and release the GIL
            #   - the threads share some common context that is set up in
            #     this worker's :meth:`__init__`
            #
            # If the second condition is not met, one could just use a
            # :class:`ThreadServlet` with multiple thread workers that run
            # independently of each other.
            #
            # To use this branch, a subclass needs to set `self.num_stream_threads`
            # appropriately after calling `super().__init__`. See tests for an example.
            yield from Parmapper(
                xx,
                self.call,
                executor='thread',
                concurrency=self.num_stream_threads,
                return_exceptions=True,
                parmapper_name=f'{self.__class__.__name__}.stream',
            )

    def cleanup(self, exc=None):
        """
        This method is called when the object exits its service loop and stops.

        This is the place for cleanup code, e.g. releasing resources,
        exiting context managers (that have been entered in :meth:`__init__`), etc.

        Parameters
        ----------
        exc
            The exception that ended the service loop (including
            ``KeyboardInterrupt``), or ``None`` if the loop ended normally.
        """
        pass

    def start(self, *, q_in, q_out):
        """
        This is called by :meth:`run` to kick off the processing loop.

        To stop the processing, pass in the constant ``None``
        through ``q_in``.

        If the loop ends with an exception, ``None`` is put in both ``q_in``
        (to pass the stop signal on to a fellow worker) and ``q_out``, and
        :meth:`cleanup` is called with the exception. ``KeyboardInterrupt`` is
        swallowed after that; any other exception is re-raised.
        """
        try:
            if self.batch_size > 1:
                self._start_batch(q_in=q_in, q_out=q_out)
            else:
                self._start_single(q_in=q_in, q_out=q_out)
        except KeyboardInterrupt as e:
            q_in.put(None)  # broadcast to one fellow worker
            q_out.put(None)
            print(self.name, 'stopped by KeyboardInterrupt')
            # The process or thread will exit. Don't print the usual
            # exception stuff as that's not needed when user
            # pressed Ctrl-C.
            self.cleanup(e)
            # TODO: do we need to `raise` here?
        except BaseException as e:
            q_in.put(None)
            q_out.put(None)
            self.cleanup(e)
            raise
        else:
            self.cleanup()

    def _start_single(self, *, q_in, q_out):
        # Processing loop used when ``self.batch_size`` is 0 or 1.
        # Inputs are fed one at a time to `self.stream` (wrapped in a one-element
        # list when ``batch_size == 1``); each result is paired with its ID and
        # put in `q_out`.
        batched = self.batch_size > 0
        preprocess = getattr(self, 'preprocess', None)

        def get_input(q_in, q_out, q_uid):
            # Yield inputs for `self.stream`, recording each ID in `q_uid` so
            # that results can be re-associated in order. Exceptions (incoming or
            # raised by `preprocess`) are short-circuited to `q_out`.
            while True:
                z = q_in.get()
                if z is None:
                    q_in.put(z)  # broadcast to one fellow worker
                    q_out.put(z)
                    break

                uid, x = z
                if preprocess is not None and not isinstance(
                    x, (Exception, RemoteException)
                ):
                    try:
                        x = preprocess(x)
                    except Exception as e:
                        x = e

                # If it's an exception, short-circuit to output.
                if isinstance(x, Exception):
                    # `RemoteException` is not a subclass of `Exception`.
                    x = RemoteException(x)
                if isinstance(x, RemoteException):
                    q_out.put((uid, x))
                    continue

                q_uid.put(uid)
                yield [x] if batched else x

        q_uid = queue.SimpleQueue()
        for y in self.stream(get_input(q_in, q_out, q_uid)):
            if isinstance(y, Exception):
                y = RemoteException(y)
                # There are opportunities to print traceback
                # and details later. Be brief on the logging here.
            elif batched:
                y = y[0]
            uid = q_uid.get()
            q_out.put((uid, y))
            # Element in the output queue is always a 2-tuple, that is, (ID, value).

    def _start_batch(self, *, q_in, q_out):
        # Processing loop used when ``self.batch_size > 1``.
        # A background thread (`_build_input_batches`) fills
        # `self._batch_buffer`; this thread takes batches out of it via
        # `_get_input_batch`, runs them through `self.stream`, and splits the
        # results back into per-element (ID, value) tuples in `q_out`.
        def print_batching_info():
            logger.info(
                '%s: %d batches with sizes %d--%d, mean %.1f',
                self.name,
                n_batches,
                batch_size_min,
                batch_size_max,
                batch_size_mean,
            )

        self._batch_buffer = SingleLane(self.batch_size + 10)
        self._batch_get_called = threading.Event()
        collector_thread = Thread(
            target=self._build_input_batches,
            args=(q_in, q_out),
            name=f'{self.name}._build_input_batches',
        )
        collector_thread.start()

        def get_input(q_in, q_out, q_uids):
            # Yield value batches for `self.stream`; the matching ID lists go
            # into `q_uids` in the same order.
            while True:
                batch = self._get_input_batch()
                if batch is None:
                    # NOTE(review): `_build_input_batches` has already put `None`
                    # in `q_in` and `q_out` before forwarding it through the
                    # buffer, so a batching worker emits the sentinel twice on
                    # each queue, and the collector's `q_out.put(None)` can
                    # precede results of items still in the buffer. Presumably
                    # the consumers tolerate this -- confirm.
                    q_in.put(batch)  # broadcast to fellow workers.
                    q_out.put(batch)
                    break
                # The batch is a list of (ID, value) tuples.
                us = [v[0] for v in batch]
                batch = [v[1] for v in batch]
                q_uids.put(us)
                yield batch

        n_batches = 0
        batch_size_log_cadence = self.batch_size_log_cadence
        q_uids = queue.SimpleQueue()
        try:
            for yy in self.stream(get_input(q_in, q_out, q_uids)):
                if batch_size_log_cadence and n_batches == 0:
                    batch_size_max = -1
                    batch_size_min = 1000000
                    batch_size_mean = 0.0

                uids = q_uids.get()
                if isinstance(yy, Exception):
                    err = RemoteException(yy)
                    for u in uids:
                        q_out.put((u, err))
                else:
                    results = list(yy)
                    if len(results) == len(uids):
                        for z in zip(uids, results):
                            q_out.put(z)
                            # Each element in the output queue is a (ID, value) tuple.
                    else:
                        # `call` violated its contract. A plain `zip` would
                        # silently drop unmatched IDs (their requests would never
                        # get a result) or surplus results; report an error for
                        # every element of the batch instead.
                        try:
                            raise ValueError(
                                f'`call` returned {len(results)} results '
                                f'for a batch of {len(uids)} inputs'
                            )
                        except ValueError as e:
                            err = RemoteException(e)
                        for u in uids:
                            q_out.put((u, err))

                if batch_size_log_cadence:
                    n = len(uids)
                    n_batches += 1
                    batch_size_max = max(batch_size_max, n)
                    batch_size_min = min(batch_size_min, n)
                    batch_size_mean = (
                        batch_size_mean * (n_batches - 1) + n
                    ) / n_batches
                    if n_batches >= batch_size_log_cadence:
                        print_batching_info()
                        n_batches = 0
        finally:
            if batch_size_log_cadence and n_batches:
                # Finally, log this if `batch_size_log_cadence` is "truthy"
                # and there has been any unlogged batch.
                print_batching_info()
            collector_thread.join()

    def _build_input_batches(self, q_in, q_out):
        # This background thread gets elements from `q_in`
        # and puts them in `self._batch_buffer`.
        # Exceptions taken out of `q_in` will be short-circuited
        # to `q_out`.
        buffer = self._batch_buffer
        batchsize = self.batch_size
        preprocess = getattr(self, 'preprocess', None)

        while True:
            if buffer.full():
                # NOTE(review): `full()` is checked outside `_not_full`, so a
                # notification sent between the check and `wait()` is missed
                # and the wait relies on a later `get` from the buffer.
                # Confirm `SingleLane`'s locking before tightening this.
                with buffer._not_full:
                    buffer._not_full.wait()

            # Multiple workers in separate processes may be competing
            # to get data out of this `q_in`.
            with q_in._rlock:
                # Now we've got hold of the read lock.
                # In order to facilitate batching,
                # we hold on to the lock and keep getting
                # data from `q_in` even though other readers are waiting.
                # We let go the lock when certain conditions are met.
                while True:
                    z = q_in.get()  # wait as long as it takes to get one item.
                    while True:
                        if z is None:
                            buffer.put(z)
                            q_in.put(z)  # broadcast to fellow workers.
                            q_out.put(z)
                            return

                        uid, x = z
                        if preprocess is not None:
                            if not isinstance(x, (Exception, RemoteException)):
                                try:
                                    x = preprocess(x)
                                except Exception as e:
                                    x = e

                        if isinstance(x, Exception):
                            q_out.put((uid, RemoteException(x)))
                        elif isinstance(x, RemoteException):
                            q_out.put((uid, x))
                        else:
                            buffer.put((uid, x))

                        # If `q_in` currently has more data right there
                        # and `buffer` has not reached `batchsize` yet,
                        # keep grabbing more data.
                        if not q_in.empty() and buffer.qsize() < batchsize:
                            z = q_in.get()
                        else:
                            break

                    # Now, either `q_in` is empty or `buffer` already has
                    # a batch-ful of items, and we have retrieved at least one
                    # item during this holding of the lock.

                    if self._batch_get_called.is_set():
                        # `_get_input_batch` has been called in this round;
                        # that is, `self` has already taken a (partial) batch
                        # of data away to process. Even though that might have
                        # made `buffer` low at this time, we should let go
                        # the lock to give others a chance to read data.
                        self._batch_get_called.clear()
                        break

                    if buffer.qsize() >= batchsize:
                        # `buffer` has reached `batchsize`, which is the most
                        # that `_get_input_batch` will take in one call.
                        # Even if `buffer` is not full, we no longer have priority
                        # for more data. Release the lock to give others
                        # a chance.
                        break

    def _get_input_batch(self):
        # This function gets a batch from `self._batch_buffer`.
        # Returns `None` once the stop sentinel is reached; otherwise a
        # non-empty list of at most `self.batch_size` (ID, value) tuples.
        extra_timeout = self.batch_wait_time
        batchsize = self.batch_size
        buffer = self._batch_buffer

        out = buffer.get()
        if out is None:
            return out

        out = [out]
        n = 1

        deadline = perf_counter() + extra_timeout
        # Timeout starts after the first item is obtained.

        while n < batchsize:
            t = deadline - perf_counter()
            # `t` is the remaining time to wait.
            # If `extra_timeout == 0`, then `t <= 0`.
            # If `t <= 0`, will still get an item if it is already
            # in the buffer.
            try:
                z = buffer.get(timeout=max(0, t))
                # If `extra_timeout == 0`, then `timeout=0`,
                # hence will get an item w/o wait.
            except Empty:
                break
            if z is None:
                # Return the batch so far.
                # Put this indicator back in the buffer.
                # Next call to this method will get
                # the indicator.
                buffer.put(z)
                break
            out.append(z)
            n += 1

        self._batch_get_called.set()
        return out
def make_worker(func: Callable[[Any], Any]) -> type[Worker]:
    """
    This function defines and returns a simple :class:`Worker` subclass
    for quick, "on-the-fly" use.
    This can be useful when we want to introduce simple servlets
    for pre-processing and post-processing.

    Parameters
    ----------
    func
        This function is what happens in the method :meth:`~Worker.call`.
        It receives exactly what ``call`` receives: a single element when the
        worker's ``batch_size`` is 0, or a list of elements when ``batch_size > 0``.
        Any callable works, including ones without a ``__name__``
        (e.g. ``functools.partial`` objects).

    Returns
    -------
    type[Worker]
        A ``Worker`` subclass whose ``call(x)`` returns ``func(x)``.
        Its ``__name__`` is ``'Worker-<name>'``, where ``<name>`` is
        ``func.__name__`` if available, otherwise the type name of ``func``.
    """

    class MyWorker(Worker):
        def call(self, x):
            return func(x)

    # `functools.partial` objects and callable instances have no `__name__`;
    # fall back to the callable's type name instead of raising AttributeError.
    name = getattr(func, '__name__', None) or type(func).__name__
    MyWorker.__name__ = f'Worker-{name}'
    return MyWorker
class PassThrough(Worker):
    """
    A :class:`Worker` whose :meth:`call` returns its input unchanged.

    This works with and without batching. When ``batch_size > 0`` the input is
    a list of elements, and returning that same list yields exactly one result
    (the element itself) per element.

    Example use of this class::

        def combine(x):
            '''
            Combine the ensemble elements depending on the results
            as well as the original input.
            '''
            x, *y = x
            assert len(y) == 3
            if x < 100:
                return sum(y) / len(y)
            else:
                return max(y)

        s = EnsembleServlet(
            ThreadServlet(PassThrough),
            ProcessServlet(W1),
            ProcessServlet(W2),
            ProcessServlet(W3),
        )
        ss = SequentialServlet(s, ThreadServlet(make_worker(combine)))
    """

    def call(self, x):
        return x