Source code for caniusethat.shareable

import inspect
import logging
import pickle
from functools import wraps
from threading import Lock
from typing import Any, Callable, Dict, List

import zmq
from zmq.utils.win32 import allow_interrupt

from caniusethat._logging import getLogger
from caniusethat._thread import StoppableThread
from caniusethat._types import (
    RemoteProcedureCall,
    RemoteProcedureError,
    RemoteProcedureResponse,
    SharedMethodDescriptor,
    SharedObjectDescriptor,
)

_logger = getLogger(__name__)


def _is_shared_method(obj: Any) -> bool:
    return inspect.ismethod(obj) and hasattr(obj, "_you_can_use_this")


def _is_locking_method(obj: Any) -> bool:
    return inspect.ismethod(obj) and hasattr(obj, "_acquire_lock")


def _is_unlocking_method(obj: Any) -> bool:
    return inspect.ismethod(obj) and hasattr(obj, "_release_lock")


def _dealer_address(name: str) -> str:
    return f"inproc://{name}_worker"


def _force_remote_server_stop(server_address: str) -> Any:
    """Ask the server listening at ``server_address`` to stop.

    Opens a temporary REQ socket, sends a ``RemoteProcedureCall("_server",
    "stop")`` and waits for the reply.

    Args:
        server_address: The ZeroMQ address the server's router socket is
            bound to.

    Returns:
        The unpickled reply received from the server.
    """
    context = zmq.Context.instance()
    request_socket = context.socket(zmq.REQ)
    try:
        request_socket.connect(server_address)
        request_socket.send(pickle.dumps(RemoteProcedureCall("_server", "stop")))
        # NOTE(security): the reply is unpickled as-is; only point this at a
        # server you trust, since unpickling can execute arbitrary code.
        return pickle.loads(request_socket.recv())
    finally:
        # Always release the socket, even when connecting, sending or
        # decoding fails, so it does not linger in the shared context.
        request_socket.close(linger=10)


def _package_reply(reply: Any, error: RemoteProcedureError) -> bytes:
    """Pickle a ``RemoteProcedureResponse`` built from ``reply`` and ``error``."""
    response = RemoteProcedureResponse(reply, error)
    return pickle.dumps(response)


def _package_error(error: RemoteProcedureError) -> bytes:
    """Package a response that carries ``error`` and a ``None`` payload."""
    return _package_reply(reply=None, error=error)


def _package_success_reply(reply: Any) -> bytes:
    """Package ``reply`` as a response flagged with ``NO_ERROR``."""
    no_error = RemoteProcedureError.NO_ERROR
    return _package_reply(reply, no_error)


def you_can_use_this(f: Callable) -> Callable:
    """A decorator that marks a method as a shared method.

    The returned wrapper forwards every call to ``f`` unchanged and carries a
    ``_you_can_use_this`` attribute set to ``True``; ``f`` itself is not
    modified.

    Example:

        >>> @you_can_use_this
        ... def get_name(self) -> str:
        ...     return self.name
    """

    @wraps(f)
    def wrapper(*args, **kwds):
        return f(*args, **kwds)

    # Marker attribute looked up by ``hasattr`` when methods are collected.
    wrapper._you_can_use_this = True  # type: ignore
    return wrapper
def acquire_lock(f: Callable) -> Callable:
    """A decorator that acquires the lock of the object before calling the method.

    The returned wrapper forwards every call to ``f`` unchanged and carries an
    ``_acquire_lock`` attribute set to ``True``. Because ``functools.wraps``
    copies the wrapped function's attributes, markers already present on ``f``
    are kept on the wrapper.

    Example:

        >>> @acquire_lock
        ... @you_can_use_this
        ... def start_phone_call(self, phone_number: str) -> None:
        ...     self._make_phone_call(phone_number)
    """

    @wraps(f)
    def wrapper(*args, **kwds):
        return f(*args, **kwds)

    # Marker attribute looked up by ``hasattr`` when methods are collected.
    wrapper._acquire_lock = True  # type: ignore
    return wrapper
def release_lock(f: Callable) -> Callable:
    """A decorator that releases the lock of the object after calling the method.

    The returned wrapper forwards every call to ``f`` unchanged and carries a
    ``_release_lock`` attribute set to ``True``. Because ``functools.wraps``
    copies the wrapped function's attributes, markers already present on ``f``
    are kept on the wrapper.

    Example:

        >>> @release_lock
        ... @you_can_use_this
        ... def stop_phone_call(self) -> None:
        ...     self._hang_up()
    """

    @wraps(f)
    def wrapper(*args, **kwds):
        return f(*args, **kwds)

    # Marker attribute looked up by ``hasattr`` when methods are collected.
    wrapper._release_lock = True  # type: ignore
    return wrapper
class Server(StoppableThread):
    """The Server takes care of sharing the objects on the network, handling the
    remote procedure calls from multiple users and their locks.

    Clients talk to a single ROUTER socket bound to ``router_address``. Every
    shared object gets its own DEALER socket and its own ``_ObjectWorker``
    thread; requests are forwarded to the matching dealer and the worker
    replies are relayed back through the router.

    Attributes:
        router_address (str): The address that the server will listen on.

    Example:
        >>> server = Server("tcp://127.0.0.1:6555")
        >>> server.start()
        >>> server.add_object("mobile_phone_interface", mobile_phone_interface)
    """

    _LINGER_TIME = 1000  # milliseconds

    # ``_server`` commands whose first positional argument names an object.
    _TARGETED_SERVER_METHODS = (
        "get_object_methods",
        "release_lock_if_any",
        "force_release_lock",
    )

    def __init__(self, router_address: str) -> None:
        super().__init__()
        self.router_address = router_address
        # Objects currently served, keyed by their public name.
        self.shared_objects: Dict[str, SharedObjectDescriptor] = {}
        # Objects added through ``add_object`` that the server thread has not
        # picked up yet; guarded by ``new_object_lock``.
        self.shared_objects_queue: Dict[str, SharedObjectDescriptor] = {}
        self.new_object_lock = Lock()
        # One DEALER socket and one worker thread per shared object.
        self.dealers: Dict[str, Any] = {}
        self.workers: Dict[str, _ObjectWorker] = {}
        # Maps an object name to the router address of the client holding
        # its lock.
        self.worker_locks: Dict[str, bytes] = {}
        self.log_lock = Lock()

    def _safe_log(self, message: str, level: int = logging.INFO) -> None:
        """Log ``message`` at ``level`` while holding ``log_lock``."""
        with self.log_lock:
            _logger.log(level, message)

    def _send_reply(self, address: bytes, payload: bytes) -> None:
        """Send ``payload`` to the client identified by ``address``."""
        # The empty frame is the delimiter between the routing envelope and
        # the message body.
        self.router_socket.send_multipart([address, b"", payload])

    def _task_setup(self):
        """Create and bind the router socket and register it with a poller."""
        self._safe_log(
            f"Starting 👀 caniusethat server, listening on {self.router_address}."
        )
        self.context = zmq.Context.instance()
        self.router_socket = self.context.socket(zmq.ROUTER)
        self.router_socket.bind(self.router_address)
        self.poller = zmq.Poller()
        self.poller.register(self.router_socket, zmq.POLLIN)

    def _task_cleanup(self):
        """Close the router and dealer sockets, then stop every worker."""
        self._safe_log("Closing 👀 caniusethat server connections.")
        self.poller.unregister(self.router_socket)
        self.router_socket.close(linger=self._LINGER_TIME)
        for dealer_socket in self.dealers.values():
            self.poller.unregister(dealer_socket)
            dealer_socket.close(linger=self._LINGER_TIME)
        for worker in self.workers.values():
            worker.stop()

    def _task_cycle(self):
        """Run one iteration: adopt queued objects, then route messages."""
        # NOTE(review): ``allow_interrupt`` comes from ``zmq.utils.win32``;
        # presumably it lets Ctrl-C call ``self.stop`` on Windows. Confirm
        # against the pyzmq documentation.
        with allow_interrupt(self.stop):
            # Add any new objects to the shared objects.
            with self.new_object_lock:
                self._process_new_object_queue()

            # Short timeout so the cycle returns regularly.
            poll_sockets = dict(self.poller.poll(timeout=10))

            # Check if there are new requests.
            if poll_sockets.get(self.router_socket) == zmq.POLLIN:
                address, _, message = self.router_socket.recv_multipart()
                self._process_incoming_rpc(address, message)

            # Check if there are any new replies.
            for dealer_socket in self.dealers.values():
                if poll_sockets.get(dealer_socket) == zmq.POLLIN:
                    self._safe_log(
                        f"Received reply from worker {dealer_socket}", logging.DEBUG
                    )
                    message = dealer_socket.recv_multipart()
                    # Send the reply back to the client; the frames still
                    # carry the client address that was forwarded with the
                    # request.
                    self.router_socket.send_multipart(message)

    def _process_incoming_rpc(self, address: bytes, message: bytes) -> None:
        """Validate one request and either answer it or forward it to a worker.

        Args:
            address: Router identity of the client that sent the request.
            message: The pickled ``RemoteProcedureCall``.
        """
        # NOTE(security): unpickling executes whatever the sender encoded;
        # only expose the router address to trusted clients.
        try:
            rpc = pickle.loads(message)
        except Exception as exc:
            # Without a reply the requesting client would wait forever, so
            # answer undecodable messages explicitly instead of raising.
            self._safe_log(
                f"Received undecodable RPC message: {exc!r}", logging.WARNING
            )
            self._send_reply(address, _package_error(RemoteProcedureError.INVALID_RPC))
            return

        # Check if the RPC is properly formatted.
        if not isinstance(rpc, RemoteProcedureCall):
            self._safe_log(
                f"Received invalid RemoteProcedureCall: {rpc}", logging.WARNING
            )
            self._send_reply(address, _package_error(RemoteProcedureError.INVALID_RPC))
            return

        self._safe_log(f"Received RPC: {rpc}", logging.DEBUG)

        is_server_rpc = rpc.name == "_server"

        # Server commands that target an object need its name as first
        # argument; reject the call instead of failing on ``rpc.args[0]``.
        if (
            is_server_rpc
            and rpc.method in self._TARGETED_SERVER_METHODS
            and not rpc.args
        ):
            self._safe_log(
                f"Received server RPC without a target object: {rpc}",
                logging.WARNING,
            )
            self._send_reply(address, _package_error(RemoteProcedureError.INVALID_RPC))
            return

        # Check if the RPC is asking for the list of shared methods.
        if is_server_rpc and rpc.method == "get_object_methods":
            target = rpc.args[0]
            if target not in self.shared_objects:
                self._safe_log(f"No such object: {target}", logging.WARNING)
                reply = _package_error(RemoteProcedureError.NO_SUCH_THING)
            else:
                reply = _package_success_reply(
                    self.shared_objects[target].shared_methods
                )
            self._send_reply(address, reply)
            return

        # Check if the RPC is asking for the list of shared objects.
        if is_server_rpc and rpc.method == "get_object_list":
            self._send_reply(
                address, _package_success_reply(list(self.shared_objects.keys()))
            )
            return

        # Check if the RPC is asking for the server to terminate (useful in
        # testing). The reply is sent before stopping.
        if is_server_rpc and rpc.method == "stop":
            self._send_reply(address, _package_success_reply(None))
            self.stop()
            return

        # Check if the RPC is asking for the server to release a lock. Only
        # the client that holds the lock may release it this way.
        if is_server_rpc and rpc.method == "release_lock_if_any":
            target = rpc.args[0]
            if self.worker_locks.get(target) == address:
                self.worker_locks.pop(target)
                self._safe_log(f"Released lock for {target}", logging.DEBUG)
            self._send_reply(address, _package_success_reply(None))
            return

        # Check if the RPC is asking for the server to release a lock
        # forcefully, regardless of which client holds it.
        if is_server_rpc and rpc.method == "force_release_lock":
            target = rpc.args[0]
            if target in self.worker_locks:
                self.worker_locks.pop(target)
                self._safe_log(
                    f"Forcefully released lock for {target}", logging.WARNING
                )
            self._send_reply(address, _package_success_reply(None))
            return

        # Check if the RPC object is in the server.
        if rpc.name not in self.shared_objects:
            self._safe_log(
                f"Received RPC for unknown object: {rpc.name}", logging.WARNING
            )
            self._send_reply(address, _package_error(RemoteProcedureError.NO_SUCH_THING))
            return

        descriptor = self.shared_objects[rpc.name]

        # Check if the RPC method is not one of the shared ones.
        if all(method.name != rpc.method for method in descriptor.shared_methods):
            self._safe_log(
                f"Received RPC for unknown method: {rpc.name}.{rpc.method}",
                logging.WARNING,
            )
            self._send_reply(
                address, _package_error(RemoteProcedureError.NO_SUCH_METHOD)
            )
            return

        # Check if the worker is locked by a different client.
        if (rpc.name in self.worker_locks) and (self.worker_locks[rpc.name] != address):
            self._safe_log(
                f"Worker {rpc.name} is already locked by {str(self.worker_locks[rpc.name])}",
                logging.WARNING,
            )
            self._send_reply(
                address, _package_error(RemoteProcedureError.THING_IS_LOCKED)
            )
            return

        # Check if the worker needs to be locked.
        if (rpc.name not in self.worker_locks) and (
            rpc.method in descriptor.locking_methods
        ):
            self._safe_log(
                f"Locking worker {rpc.name} to {str(address)}", logging.DEBUG
            )
            self.worker_locks[rpc.name] = address

        # Everything looks good so far, dispatch the RPC to the correct worker.
        self._safe_log(f"Dispatching RPC to worker {rpc.name}", logging.DEBUG)
        self.dealers[rpc.name].send_multipart([address, b"", message])

        # Check if the worker needs to be unlocked. The lock is dropped as
        # soon as the call has been dispatched.
        if (rpc.name in self.worker_locks) and (
            rpc.method in descriptor.unlocking_methods
        ):
            self._safe_log(f"Unlocking worker {rpc.name}", logging.DEBUG)
            self.worker_locks.pop(rpc.name)

    def add_object(self, name: str, obj: Any) -> None:
        """Add an object to the server.

        The object is only queued here; the server thread starts serving it
        on its next cycle.

        Args:
            name: A unique name that will be used to refer to the object.
            obj: The object to add to the server.

        Raises:
            RuntimeError: If ``obj`` has no shared methods, or has locking
                methods without unlocking methods (or the other way round).
        """
        # Build the SharedObjectDescriptor.
        shared_methods = [
            SharedMethodDescriptor(
                method_name,
                str(inspect.signature(method)),
                inspect.getdoc(method) or "",
            )
            for method_name, method in inspect.getmembers(obj, _is_shared_method)
        ]
        if not shared_methods:
            raise RuntimeError(f"No shared methods found in {obj!r}")

        locking_methods = [
            method_name
            for method_name, _ in inspect.getmembers(obj, _is_locking_method)
        ]
        unlocking_methods = [
            method_name
            for method_name, _ in inspect.getmembers(obj, _is_unlocking_method)
        ]

        # A lock that can be taken but never released (or the reverse) is
        # rejected up front.
        if locking_methods and not unlocking_methods:
            raise RuntimeError(
                f"Locking methods found in {obj!r} but no unlocking methods."
            )
        if unlocking_methods and not locking_methods:
            raise RuntimeError(
                f"Unlocking methods found in {obj!r} but no locking methods."
            )

        descriptor = SharedObjectDescriptor(
            name, obj, shared_methods, locking_methods, unlocking_methods
        )

        self._safe_log(f"Adding object {name} to server")
        with self.new_object_lock:
            self.shared_objects_queue[name] = descriptor

    def get_object_methods(self, name: str) -> List[SharedMethodDescriptor]:
        """Returns a list of methods of the object with the given name.

        Args:
            name: The name of the object to get the methods of.

        Returns:
            A list of SharedMethodDescriptors.

        Raises:
            KeyError: If no object with that name is currently served.
        """
        return self.shared_objects[name].shared_methods

    def get_object_list(self) -> List[str]:
        """Returns a list of the names of the objects in the server.

        Returns:
            A list of strings.
        """
        return list(self.shared_objects.keys())

    def _process_new_object_queue(self):
        """Start serving every queued object.

        Each queued descriptor gets a DEALER socket bound to its ``inproc``
        address, registered with the poller, and a started ``_ObjectWorker``.

        Raises:
            RuntimeError: If a queued name is already being served.
        """
        # First obtain a list of the names, we don't want to change the
        # dictionary while we're iterating over it.
        names = list(self.shared_objects_queue.keys())
        for name in names:
            descriptor = self.shared_objects_queue.pop(name)
            if name in self.shared_objects:
                raise RuntimeError(
                    f"Object {name} already exists, use a different name."
                )
            self.shared_objects[name] = descriptor

            dealer_socket = self.context.socket(zmq.DEALER)
            dealer_socket.bind(_dealer_address(name))
            self.poller.register(dealer_socket, zmq.POLLIN)
            self.dealers[name] = dealer_socket

            worker = _ObjectWorker(name, descriptor)
            worker.start()
            self.workers[name] = worker
class _ObjectWorker(StoppableThread):
    """Thread that executes the remote procedure calls for one shared object.

    The worker connects a REP socket to the ``inproc`` address derived from
    its name, unpickles each incoming ``RemoteProcedureCall``, invokes the
    requested method on the shared object and sends back a pickled
    ``RemoteProcedureResponse``.
    """

    _LINGER_TIME = 1000  # milliseconds

    def __init__(self, worker_name: str, shared_object: SharedObjectDescriptor) -> None:
        super().__init__()
        self.worker_name = worker_name
        self.shared_object = shared_object

    def reply_address(self) -> str:
        """Return the ``inproc`` address this worker connects to."""
        return _dealer_address(self.worker_name)

    def _task_setup(self):
        """Connect the REP socket and register it with a poller."""
        self.context = zmq.Context.instance()
        self.reply_socket = self.context.socket(zmq.REP)
        self.reply_socket.connect(self.reply_address())
        self.poller = zmq.Poller()
        self.poller.register(self.reply_socket, zmq.POLLIN)

    def _task_cleanup(self):
        """Close the REP socket."""
        self.reply_socket.close(linger=self._LINGER_TIME)

    @staticmethod
    def _serialize_response(result: Any, error: RemoteProcedureError) -> bytes:
        """Pickle a response, falling back to an error if ``result`` cannot be pickled."""
        try:
            return pickle.dumps(RemoteProcedureResponse(result, error))
        except Exception as exc:
            # A REP socket has to answer every request. Report the failure
            # with a plain, always-picklable exception instead of raising
            # and leaving the client without a reply.
            fallback = RuntimeError(
                f"Could not serialize the result of the remote call: {exc!r}"
            )
            return pickle.dumps(
                RemoteProcedureResponse(
                    fallback, RemoteProcedureError.METHOD_EXCEPTION
                )
            )

    def _task_cycle(self):
        """Run one iteration: serve at most one pending request."""
        with allow_interrupt(self.stop):
            # Wait briefly for a request so the cycle returns regularly.
            poll_sockets = dict(self.poller.poll(timeout=10))

            # Nothing to do when no request arrived.
            if poll_sockets.get(self.reply_socket) != zmq.POLLIN:
                return

            message = self.reply_socket.recv()
            rpc: RemoteProcedureCall = pickle.loads(message)
            try:
                method = getattr(self.shared_object.obj, rpc.method)
                call_result = method(*rpc.args, **rpc.kwargs)
            except Exception as e:
                # The exception itself is returned as the result so the
                # caller can see what went wrong.
                call_result = e
                call_error = RemoteProcedureError.METHOD_EXCEPTION
            else:
                call_error = RemoteProcedureError.NO_ERROR

            # Send the result back to the client.
            self.reply_socket.send(self._serialize_response(call_result, call_error))