diff --git a/tests/test_server/test_grpc/test_inference_servicer.py b/tests/test_server/test_grpc/test_inference_servicer.py index 62a93a3c..b1de0213 100644 --- a/tests/test_server/test_grpc/test_inference_servicer.py +++ b/tests/test_server/test_grpc/test_inference_servicer.py @@ -82,6 +82,22 @@ def test_predict_call_fails_without_specifying_model_session_id(self, grpc_stub) assert grpc.StatusCode.FAILED_PRECONDITION == e.value.code() assert "model-session-id has not been provided" in e.value.details() + def test_model_init_failed_close_session(self, bioimage_model_explicit_add_one_siso_v5, grpc_stub): + """ + If the model initialization fails, the session should be closed, so we can initialize a new one + """ + + model_req = inference_pb2.CreateModelSessionRequest( + model_blob=inference_pb2.Blob(content=b""), deviceIds=["cpu"] + ) + + with pytest.raises(Exception): + grpc_stub.CreateModelSession(model_req) + + model_bytes = bioimage_model_explicit_add_one_siso_v5 + response = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) + assert response.id is not None + class TestDeviceManagement: def test_list_devices(self, grpc_stub): diff --git a/tiktorch/rpc/interface.py b/tiktorch/rpc/interface.py index 7f3d477b..d47543a5 100644 --- a/tiktorch/rpc/interface.py +++ b/tiktorch/rpc/interface.py @@ -1,5 +1,7 @@ from typing import Any, Callable, Dict +from tiktorch.rpc import Shutdown + class RPCInterfaceMeta(type): def __new__(mcls, name, bases, namespace, **kwargs): @@ -8,22 +10,33 @@ def __new__(mcls, name, bases, namespace, **kwargs): for base in bases: if issubclass(base, RPCInterface): - exposed ^= getattr(base, "__exposedmethods__", set()) + exposed |= getattr(base, "__exposedmethods__", set()) cls.__exposedmethods__ = frozenset(exposed) return cls -class RPCInterface(metaclass=RPCInterfaceMeta): - pass - - def exposed(method: Callable[..., Any]) -> Callable[..., Any]: """decorator to mark method as exposed in the public API of the class""" method.__exposed__ = True return method +class RPCInterface(metaclass=RPCInterfaceMeta): + @exposed + def init(self, *args, **kwargs): + """ + Initialize server + + Server initialization postponed so the client can handle errors occurring during server initialization. + """ + raise NotImplementedError + + @exposed + def shutdown(self) -> Shutdown: + raise NotImplementedError + + def get_exposed_methods(obj: RPCInterface) -> Dict[str, Callable[..., Any]]: exposed = getattr(obj, "__exposedmethods__", None) diff --git a/tiktorch/server/grpc/inference_servicer.py b/tiktorch/server/grpc/inference_servicer.py index a710a963..230674ee 100644 --- a/tiktorch/server/grpc/inference_servicer.py +++ b/tiktorch/server/grpc/inference_servicer.py @@ -4,6 +4,7 @@ from tiktorch.converters import pb_tensors_to_sample, sample_to_pb_tensors from tiktorch.proto import inference_pb2, inference_pb2_grpc +from tiktorch.rpc.mp import BioModelClient from tiktorch.server.data_store import IDataStore from tiktorch.server.device_pool import DeviceStatus, IDevicePool from tiktorch.server.session.process import InputSampleValidator, start_model_session_process @@ -11,7 +12,9 @@ class InferenceServicer(inference_pb2_grpc.InferenceServicer): - def __init__(self, device_pool: IDevicePool, session_manager: SessionManager, data_store: IDataStore) -> None: + def __init__( + self, device_pool: IDevicePool, session_manager: SessionManager[BioModelClient], data_store: IDataStore + ) -> None: self.__device_pool = device_pool self.__session_manager = session_manager self.__data_store = data_store @@ -28,25 +31,28 @@ def CreateModelSession( else: content = request.model_blob.content - lease = self.__device_pool.lease(request.deviceIds) - - try: - _, client = start_model_session_process(model_bytes=content, devices=[d.id for d in lease.devices]) - except Exception: - lease.terminate() - raise + devices = list(request.deviceIds) + _, client = start_model_session_process(model_bytes=content) session = self.__session_manager.create_session(client) - session.on_close(lease.terminate) session.on_close(client.api.shutdown) + lease = self.__device_pool.lease(devices) + session.on_close(lease.terminate) + + try: + client.api.init(model_bytes=content, devices=devices) + except Exception as e: + self.__session_manager.close_session(session.id) + raise e + return inference_pb2.ModelSession(id=session.id) def CreateDatasetDescription( self, request: inference_pb2.CreateDatasetDescriptionRequest, context ) -> inference_pb2.DatasetDescription: session = self._getModelSession(context, request.modelSessionId) - id = session.bio_model_client.api.create_dataset_description(mean=request.mean, stddev=request.stddev) + id = session.client.api.create_dataset_description(mean=request.mean, stddev=request.stddev) return inference_pb2.DatasetDescription(id=id) def CloseModelSession(self, request: inference_pb2.ModelSession, context) -> inference_pb2.Empty: @@ -85,12 +91,12 @@ def ListDevices(self, request: inference_pb2.Empty, context) -> inference_pb2.De def Predict(self, request: inference_pb2.PredictRequest, context) -> inference_pb2.PredictResponse: session = self._getModelSession(context, request.modelSessionId) input_sample = pb_tensors_to_sample(request.tensors) - tensor_validator = InputSampleValidator(session.bio_model_client.input_specs) + tensor_validator = InputSampleValidator(session.client.input_specs) tensor_validator.check_tensors(input_sample) - res = session.bio_model_client.api.forward(input_sample) + res = session.client.api.forward(input_sample).result() return inference_pb2.PredictResponse(tensors=sample_to_pb_tensors(res)) - def _getModelSession(self, context, modelSessionId: str) -> Session: + def _getModelSession(self, context, modelSessionId: str) -> Session[BioModelClient]: if not modelSessionId: context.abort(grpc.StatusCode.FAILED_PRECONDITION, "model-session-id has not been provided by client") diff --git a/tiktorch/server/session/process.py b/tiktorch/server/session/process.py index 6224e23a..f289697e 100644 --- a/tiktorch/server/session/process.py +++ b/tiktorch/server/session/process.py @@ -4,7 +4,7 @@ import uuid from concurrent.futures import Future from multiprocessing.connection import Connection -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Type, TypeVar, Union from bioimageio.core import PredictionPipeline, Tensor, create_prediction_pipeline from bioimageio.spec import InvalidDescr, load_description @@ -13,6 +13,7 @@ from tiktorch import log from tiktorch.rpc import Shutdown from tiktorch.rpc import mp as _mp_rpc +from tiktorch.rpc.interface import RPCInterface from tiktorch.rpc.mp import BioModelClient, MPServer from ...converters import Sample @@ -72,14 +73,18 @@ def _get_spec(self, tensor_id: str) -> v0_5.InputTensorDescr: return specs[0] -class ModelSessionProcess(IRPCModelSession[PredictionPipeline]): - def __init__(self, model: PredictionPipeline) -> None: - super().__init__(model) +class ModelSessionProcess(IRPCModelSession): + def __init__(self) -> None: + super().__init__() self._datasets = {} - self._worker = base.SessionBackend(self._model) + self._worker: Optional[base.SessionBackend] = None + + def init(self, model_bytes: bytes, devices: List[str]): + prediction_pipeline = _get_prediction_pipeline_from_model_bytes(model_bytes, devices) + self._worker = base.SessionBackend(prediction_pipeline) def forward(self, sample: Sample) -> Future: - res = self._worker.forward(sample) + res = self.worker.forward(sample) return res def create_dataset(self, mean, stddev): @@ -88,13 +93,19 @@ def create_dataset(self, mean, stddev): return id_ def shutdown(self) -> Shutdown: - self._worker.shutdown() + if self._worker is None: + return Shutdown() + self.worker.shutdown() return Shutdown() + @property + def worker(self) -> base.SessionBackend: + if self._worker is None: + raise ValueError("Server isn't initialized") + return self._worker + -def _run_model_session_process( - conn: Connection, model_bytes: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None -): +def _run_server(api: RPCInterface, conn: Connection, log_queue: Optional[_mp.Queue] = None): try: # from: https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667 import resource @@ -107,24 +118,30 @@ def _run_model_session_process( if log_queue: log.configure(log_queue) - prediction_pipeline = _get_prediction_pipeline_from_model_bytes(model_bytes, devices) - session_proc = ModelSessionProcess(prediction_pipeline) - srv = MPServer(session_proc, conn) + srv = MPServer(api, conn) srv.listen() -def start_model_session_process( - model_bytes: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None -) -> Tuple[_mp.Process, BioModelClient]: +T = TypeVar("T", bound=RPCInterface) + + +def start_process(interface_class: Type[T], log_queue: Optional[_mp.Queue] = None) -> Tuple[_mp.Process, T]: client_conn, server_conn = _mp.Pipe() proc = _mp.Process( - target=_run_model_session_process, - name="ModelSessionProcess", - kwargs={"conn": server_conn, "log_queue": log_queue, "devices": devices, "model_bytes": model_bytes}, + target=_run_server, + name="TiktorchProcess", + kwargs={"conn": server_conn, "log_queue": log_queue, "api": interface_class()}, ) proc.start() - api = _mp_rpc.create_client_api(iface_cls=IRPCModelSession, conn=client_conn) + api: T = _mp_rpc.create_client_api(iface_cls=interface_class, conn=client_conn) + return proc, api + + +def start_model_session_process( + model_bytes: bytes, log_queue: Optional[_mp.Queue] = None +) -> Tuple[_mp.Process, BioModelClient]: model_descr = _get_model_descr_from_model_bytes(model_bytes) + proc, api = start_process(interface_class=ModelSessionProcess, log_queue=log_queue) return proc, BioModelClient( input_specs=model_descr.inputs, output_specs=model_descr.outputs, diff --git a/tiktorch/server/session/rpc_interface.py b/tiktorch/server/session/rpc_interface.py index bb414138..4efface8 100644 --- a/tiktorch/server/session/rpc_interface.py +++ b/tiktorch/server/session/rpc_interface.py @@ -1,24 +1,14 @@ -from typing import Generic, List, TypeVar +from typing import List from tiktorch.converters import Sample -from tiktorch.rpc import RPCInterface, Shutdown, exposed +from tiktorch.rpc import RPCInterface, exposed from tiktorch.tiktypes import TikTensorBatch from tiktorch.types import ModelState -ModelType = TypeVar("ModelType") - - -class IRPCModelSession(RPCInterface, Generic[ModelType]): - def __init__(self, model: ModelType): - super().__init__() - self._model = model - - @property - def model(self): - return self._model +class IRPCModelSession(RPCInterface): @exposed - def shutdown(self) -> Shutdown: + def init(self, model_bytes: bytes, devices: List[str]): raise NotImplementedError @exposed diff --git a/tiktorch/server/session_manager.py b/tiktorch/server/session_manager.py index 6def4a4a..90a3e1c9 100644 --- a/tiktorch/server/session_manager.py +++ b/tiktorch/server/session_manager.py @@ -3,30 +3,30 @@ import threading from collections import defaultdict from logging import getLogger -from typing import Callable, Dict, List, Optional +from typing import Callable, Dict, Generic, List, Optional, TypeVar from uuid import uuid4 -from tiktorch.rpc.mp import BioModelClient - logger = getLogger(__name__) CloseCallback = Callable[[], None] +SessionClient = TypeVar("SessionClient") + -class Session: +class Session(Generic[SessionClient]): """ session object has unique id Used for resource managent """ - def __init__(self, id_: str, bio_model_client: BioModelClient, manager: SessionManager) -> None: + def __init__(self, id_: str, client: SessionClient, manager: SessionManager) -> None: self.__id = id_ self.__manager = manager - self.__bio_model_client = bio_model_client + self.__client = client @property - def bio_model_client(self) -> BioModelClient: - return self.__bio_model_client + def client(self) -> SessionClient: + return self.__client @property def id(self) -> str: @@ -42,23 +42,23 @@ def on_close(self, handler: CloseCallback) -> None: self.__manager._on_close(self, handler) -class SessionManager: +class SessionManager(Generic[SessionClient]): """ Manages session lifecycle (create/close) """ - def create_session(self, bio_model_client: BioModelClient) -> Session: + def create_session(self, client: SessionClient) -> Session[SessionClient]: """ Creates new session with unique id """ with self.__lock: session_id = uuid4().hex - session = Session(session_id, bio_model_client=bio_model_client, manager=self) + session = Session(session_id, client=client, manager=self) self.__session_by_id[session_id] = session logger.info("Created session %s", session.id) return session - def get(self, session_id: str) -> Optional[Session]: + def get(self, session_id: str) -> Optional[Session[SessionClient]]: """ Returns existing session with given id if it exists """