# Source code for mpservice.concurrent.futures

from __future__ import annotations

# Public API of this module: the names exported by a star-import.
# The two shared-pool getters are documented, user-facing functions defined
# further down in this file, so they are listed alongside the executor classes.
__all__ = [
    'ThreadPoolExecutor',
    'ProcessPoolExecutor',
    'get_shared_thread_pool',
    'get_shared_process_pool',
    'wait',
    'as_completed',
    'ALL_COMPLETED',
    'FIRST_COMPLETED',
    'FIRST_EXCEPTION',
]

import concurrent.futures
import multiprocessing
import os
import sys
import threading
import traceback
import weakref

from mpservice.multiprocessing import MP_SPAWN_CTX

# Re-export the standard-library helpers and constants unchanged, so that this
# module can be used as a drop-in replacement for ``concurrent.futures``.
from concurrent.futures import (
    ALL_COMPLETED,
    FIRST_COMPLETED,
    FIRST_EXCEPTION,
    as_completed,
    wait,
)


def _loud_thread_function(fn, *args, **kwargs):
    """
    Call ``fn(*args, **kwargs)`` and return its result.

    If the call raises an ``Exception``, a header naming the current process
    and thread, followed by the traceback, is written to ``sys.stderr``, and
    the exception is then re-raised unchanged so the caller still sees it.

    Background for the print-then-reraise approach:
    https://stackoverflow.com/a/54295910
    """
    try:
        return fn(*args, **kwargs)
    except Exception:
        print(
            f"Exception in process '{multiprocessing.current_process().name}' thread '{threading.current_thread().name}':",
            file=sys.stderr,
        )
        # Prints the exception currently being handled (to stderr by default).
        traceback.print_exc()
        raise


class ThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
    """
    Drop-in replacement of the standard
    `concurrent.futures.ThreadPoolExecutor <https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor>`_.

    The only addition is the keyword ``loud_exception`` accepted by :meth:`submit`.
    It controls whether exception info is printed in the worker thread when the
    submitted task fails with an exception. The default is ``True``; passing
    ``False`` gives the behavior of the standard library, which does not print
    exception info in the worker thread.
    """

    def submit(self, fn, /, *args, loud_exception: bool = True, **kwargs):
        """
        Schedule ``fn(*args, **kwargs)`` and return the resulting ``Future``.

        ``loud_exception`` is consumed here and is not forwarded to ``fn``.
        When it is true, ``fn`` is run through ``_loud_thread_function``;
        otherwise ``fn`` is submitted as is.
        """
        if not loud_exception:
            return super().submit(fn, *args, **kwargs)
        return super().submit(_loud_thread_function, fn, *args, **kwargs)
def _loud_process_function(fn, *args, **kwargs):
    """
    Call ``fn(*args, **kwargs)`` and return its result.

    If the call raises an ``Exception``, a header naming the current process,
    followed by the traceback, is written to ``sys.stderr``, and the exception
    is then re-raised unchanged.
    """
    try:
        return fn(*args, **kwargs)
    except Exception:
        print(
            f"Exception in process '{multiprocessing.current_process().name}':",
            file=sys.stderr,
        )
        # Prints the exception currently being handled (to stderr by default).
        traceback.print_exc()
        raise
class ProcessPoolExecutor(concurrent.futures.ProcessPoolExecutor):
    """
    Drop-in replacement of the standard
    `concurrent.futures.ProcessPoolExecutor <https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ProcessPoolExecutor>`_.

    By default, it uses a "spawn" context and the process class :class:`SpawnProcess`.

    In addition, the keyword ``loud_exception`` accepted by :meth:`submit` controls
    whether exception info is printed in the worker process when the submitted task
    fails with an exception. The default is ``True``; passing ``False`` gives the
    behavior of the standard library, which does not print exception info in the
    worker process.

    The exception info is also available in the caller process via the
    `Future <https://docs.python.org/3/library/concurrent.futures.html#future-objects>`_
    object returned from :meth:`submit`. Printing in the worker process is handy
    for debugging when the user fails to check the Future object in a timely manner.
    """

    # The loud-ness of this executor is different from the loud-ness of
    # ``SpawnProcess``. In this executor, loudness refers to each submitted function.
    # A process may stay on and execute many submitted functions.
    # The loudness of SpawnProcess plays a role only when that process crashes.

    def __init__(self, max_workers=None, mp_context=None, **kwargs):
        """
        Create the executor; ``mp_context`` falls back to ``MP_SPAWN_CTX`` when
        it is ``None``. All other arguments are passed on to the parent class.
        """
        super().__init__(
            max_workers=max_workers,
            mp_context=MP_SPAWN_CTX if mp_context is None else mp_context,
            **kwargs,
        )

    def submit(self, fn, /, *args, loud_exception: bool = True, **kwargs):
        """
        Schedule ``fn(*args, **kwargs)`` and return the resulting ``Future``.

        ``loud_exception`` is consumed here and is not forwarded to ``fn``.
        When it is true, ``fn`` is run through ``_loud_process_function``;
        otherwise ``fn`` is submitted as is.
        """
        if not loud_exception:
            return super().submit(fn, *args, **kwargs)
        return super().submit(_loud_process_function, fn, *args, **kwargs)
# References # https://thorstenball.com/blog/2014/10/13/why-threads-cant-fork/ _global_thread_pools_: dict[str, ThreadPoolExecutor] = weakref.WeakValueDictionary() _global_thread_pools_lock: threading.Lock = threading.Lock() _global_process_pools_: dict[str, ProcessPoolExecutor] = weakref.WeakValueDictionary() _global_process_pools_lock: threading.Lock = threading.Lock()
def get_shared_thread_pool(
    name: str, max_workers: int | None = None
) -> ThreadPoolExecutor:
    """
    Return the globally shared "thread pool" registered under ``name``, that is, a
    `concurrent.futures.ThreadPoolExecutor <https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor>`_.

    If no executor is registered under ``name``, one is created with the given
    ``max_workers`` (or the default if not specified) and registered.
    If the named executor exists, it is returned. However, if ``max_workers`` is
    specified and differs from the "max workers" of the existing executor,
    ``ValueError`` is raised.

    The caller should assign the returned executor to a variable and keep that
    variable in scope for as long as the executor is needed. The registry holds
    executors weakly: once all user references to a named executor have been
    garbage collected, the executor is gone, and the next request for the same
    name creates a new one.

    Do not call ``shutdown`` on the returned executor, because it is *shared*
    with other users. (If that happens anyway, the next request for the name
    replaces the shut-down executor with a new one.)

    This function is thread-safe: it can be called from multiple threads with
    the same or different ``name``.

    Example use case: instances of some class each need a ThreadPoolExecutor;
    many instances may be alive at once, yet they are unlikely to use their
    executor at the same time. To avoid having too many threads open, the class
    may have all its instances use one "shared" thread pool.
    """
    assert name
    with _global_thread_pools_lock:
        executor = _global_thread_pools_.get(name)
        # `executor._shutdown` is True if user inadvertently called `shutdown` on the executor.
        if executor is not None and not executor._shutdown:
            if max_workers is not None and max_workers != executor._max_workers:
                raise ValueError(
                    f'`max_workers`, {max_workers}, mismatches the existing value, {executor._max_workers}'
                )
            return executor

        # Nothing usable is registered under this name: create and register a new pool.
        executor = ThreadPoolExecutor(max_workers)
        _global_thread_pools_[name] = executor
        return executor
def get_shared_process_pool(
    name: str, max_workers: int | None = None
) -> ProcessPoolExecutor:
    """
    Get a globally shared "process pool", that is,
    `concurrent.futures.ProcessPoolExecutor <https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ProcessPoolExecutor>`_.

    Analogous to :func:`get_shared_thread_pool`.

    If no executor is registered under ``name``, or the registered one has been
    shut down, a new :class:`ProcessPoolExecutor` is created with ``max_workers``
    (or the default if not specified), registered, and returned.
    Otherwise the existing executor is returned.

    The registry holds executors weakly, so the caller should keep a reference
    to the returned executor for as long as it is needed. Do not call
    ``shutdown`` on it, because it is *shared* with other users.

    Parameters
    ----------
    name
        Non-empty name identifying the shared executor.
    max_workers
        Requested number of worker processes. ``None`` means "use the default"
        when creating, and "accept whatever exists" when reusing.

    Raises
    ------
    ValueError
        If ``max_workers`` is specified and differs from the value of the
        existing executor registered under ``name``.
    """
    assert name
    with _global_process_pools_lock:
        executor = _global_process_pools_.get(name)
        if executor is None or executor._processes is None:
            # `executor._processes` is None if user inadvertently called `shutdown` on the executor.
            executor = ProcessPoolExecutor(max_workers)
            _global_process_pools_[name] = executor
        elif max_workers is not None and max_workers != executor._max_workers:
            raise ValueError(
                f'`max_workers`, {max_workers}, mismatches the existing value, {executor._max_workers}'
            )
    return executor
if hasattr(os, 'register_at_fork'):  # not available on Windows

    def _clear_global_state():
        # Registered below as an ``after_in_child`` fork hook: it empties both
        # shared-pool registries and installs fresh registry locks.
        # NOTE(review): presumably this is needed because the pools' worker
        # threads do not exist in a forked child (see the "why threads can't
        # fork" reference in this module) -- confirm.
        for box in (_global_thread_pools_, _global_process_pools_):
            # Iterate over a snapshot of the keys, because entries are removed in the loop.
            for name in list(box.keys()):
                # The registry holds values weakly, so the pool may already be gone.
                pool = box.get(name)
                if pool is not None:
                    # TODO: if `pool` has locks, are there problems?
                    pool.shutdown(wait=False)
                box.pop(name, None)

        # Make a best-effort release of each registry lock, then replace it with
        # a brand-new lock object, so a lock that was held at fork time cannot
        # block the getters.
        global _global_thread_pools_lock
        try:
            _global_thread_pools_lock.release()
        except RuntimeError:  # 'release unlocked lock'
            pass
        _global_thread_pools_lock = threading.Lock()

        global _global_process_pools_lock
        try:
            _global_process_pools_lock.release()
        except RuntimeError:  # 'release unlocked lock'
            pass
        _global_process_pools_lock = threading.Lock()

    os.register_at_fork(after_in_child=_clear_global_state)