Merge pull request #223 from thodkatz/decouple-server-init
Decouple server initialization when creating a new process
thodkatz authored Dec 9, 2024
2 parents ef1629d + 706e49c commit 3c1477f
Showing 6 changed files with 106 additions and 64 deletions.
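
In short: the worker process is now spawned first and the model is initialized afterwards through an explicit `init` RPC, so initialization failures propagate to the client and can be cleaned up, instead of being buried in process startup. A condensed sketch of the new flow (`session_manager`, `device_pool`, `content`, and `devices` stand in for the servicer's attributes; error handling simplified):

```python
# Condensed from CreateModelSession below: spawn first, initialize second.
_, client = start_model_session_process(model_bytes=content)  # no model loading yet
session = session_manager.create_session(client)
session.on_close(client.api.shutdown)  # reclaim the worker process on close

lease = device_pool.lease(devices)
session.on_close(lease.terminate)      # release leased devices on close

try:
    client.api.init(model_bytes=content, devices=devices)  # may raise; the client sees why
except Exception:
    session_manager.close_session(session.id)  # fires the on_close callbacks
    raise
```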
16 changes: 16 additions & 0 deletions tests/test_server/test_grpc/test_inference_servicer.py
@@ -82,6 +82,22 @@ def test_predict_call_fails_without_specifying_model_session_id(self, grpc_stub)
         assert grpc.StatusCode.FAILED_PRECONDITION == e.value.code()
         assert "model-session-id has not been provided" in e.value.details()
 
+    def test_model_init_failed_close_session(self, bioimage_model_explicit_add_one_siso_v5, grpc_stub):
+        """
+        If the model initialization fails, the session should be closed, so we can initialize a new one
+        """
+
+        model_req = inference_pb2.CreateModelSessionRequest(
+            model_blob=inference_pb2.Blob(content=b""), deviceIds=["cpu"]
+        )
+
+        with pytest.raises(Exception):
+            grpc_stub.CreateModelSession(model_req)
+
+        model_bytes = bioimage_model_explicit_add_one_siso_v5
+        response = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
+        assert response.id is not None
+
 
 class TestDeviceManagement:
     def test_list_devices(self, grpc_stub):
23 changes: 18 additions & 5 deletions tiktorch/rpc/interface.py
@@ -1,5 +1,7 @@
 from typing import Any, Callable, Dict
 
+from tiktorch.rpc import Shutdown
+
 
 class RPCInterfaceMeta(type):
     def __new__(mcls, name, bases, namespace, **kwargs):
@@ -8,22 +10,33 @@ def __new__(mcls, name, bases, namespace, **kwargs):
 
         for base in bases:
             if issubclass(base, RPCInterface):
-                exposed ^= getattr(base, "__exposedmethods__", set())
+                exposed |= getattr(base, "__exposedmethods__", set())
 
         cls.__exposedmethods__ = frozenset(exposed)
         return cls
 
 
-class RPCInterface(metaclass=RPCInterfaceMeta):
-    pass
-
-
 def exposed(method: Callable[..., Any]) -> Callable[..., Any]:
     """decorator to mark method as exposed in the public API of the class"""
     method.__exposed__ = True
     return method
 
 
+class RPCInterface(metaclass=RPCInterfaceMeta):
+    @exposed
+    def init(self, *args, **kwargs):
+        """
+        Initialize server
+        Server initialization postponed so the client can handle errors occurring during server initialization.
+        """
+        raise NotImplementedError
+
+    @exposed
+    def shutdown(self) -> Shutdown:
+        raise NotImplementedError
+
+
 def get_exposed_methods(obj: RPCInterface) -> Dict[str, Callable[..., Any]]:
     exposed = getattr(obj, "__exposedmethods__", None)
 
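
Two things changed in this file. First, `exposed ^= ...` (symmetric difference) was a latent bug: a method exposed by both a base class and a subclass would be toggled out of `__exposedmethods__`; the union `|=` keeps it. Second, `RPCInterface` now ships `init` and `shutdown` as exposed defaults, so every RPC server shares the deferred-initialization contract. A small illustration with a hypothetical subclass:

```python
class EchoAPI(RPCInterface):  # hypothetical subclass, for illustration only
    @exposed
    def echo(self, msg: str) -> str:
        return msg

    @exposed
    def shutdown(self) -> Shutdown:  # overrides an inherited exposed method
        return Shutdown()

# The metaclass merges exposed names from all bases; with the old `^=`,
# "shutdown" (present in both base and subclass) would have been dropped.
assert EchoAPI.__exposedmethods__ == frozenset({"init", "shutdown", "echo"})
```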
32 changes: 19 additions & 13 deletions tiktorch/server/grpc/inference_servicer.py
@@ -4,14 +4,17 @@
 
 from tiktorch.converters import pb_tensors_to_sample, sample_to_pb_tensors
 from tiktorch.proto import inference_pb2, inference_pb2_grpc
+from tiktorch.rpc.mp import BioModelClient
 from tiktorch.server.data_store import IDataStore
 from tiktorch.server.device_pool import DeviceStatus, IDevicePool
 from tiktorch.server.session.process import InputSampleValidator, start_model_session_process
 from tiktorch.server.session_manager import Session, SessionManager
 
 
 class InferenceServicer(inference_pb2_grpc.InferenceServicer):
-    def __init__(self, device_pool: IDevicePool, session_manager: SessionManager, data_store: IDataStore) -> None:
+    def __init__(
+        self, device_pool: IDevicePool, session_manager: SessionManager[BioModelClient], data_store: IDataStore
+    ) -> None:
         self.__device_pool = device_pool
         self.__session_manager = session_manager
         self.__data_store = data_store
@@ -28,25 +31,28 @@ def CreateModelSession(
         else:
             content = request.model_blob.content
 
-        lease = self.__device_pool.lease(request.deviceIds)
-
-        try:
-            _, client = start_model_session_process(model_bytes=content, devices=[d.id for d in lease.devices])
-        except Exception:
-            lease.terminate()
-            raise
+        devices = list(request.deviceIds)
+
+        _, client = start_model_session_process(model_bytes=content)
         session = self.__session_manager.create_session(client)
-        session.on_close(lease.terminate)
+        session.on_close(client.api.shutdown)
+
+        lease = self.__device_pool.lease(devices)
+        session.on_close(lease.terminate)
+
+        try:
+            client.api.init(model_bytes=content, devices=devices)
+        except Exception as e:
+            self.__session_manager.close_session(session.id)
+            raise e
 
         return inference_pb2.ModelSession(id=session.id)
 
     def CreateDatasetDescription(
         self, request: inference_pb2.CreateDatasetDescriptionRequest, context
     ) -> inference_pb2.DatasetDescription:
         session = self._getModelSession(context, request.modelSessionId)
-        id = session.bio_model_client.api.create_dataset_description(mean=request.mean, stddev=request.stddev)
+        id = session.client.api.create_dataset_description(mean=request.mean, stddev=request.stddev)
         return inference_pb2.DatasetDescription(id=id)
 
     def CloseModelSession(self, request: inference_pb2.ModelSession, context) -> inference_pb2.Empty:
@@ -85,12 +91,12 @@ def ListDevices(self, request: inference_pb2.Empty, context) -> inference_pb2.De
     def Predict(self, request: inference_pb2.PredictRequest, context) -> inference_pb2.PredictResponse:
         session = self._getModelSession(context, request.modelSessionId)
         input_sample = pb_tensors_to_sample(request.tensors)
-        tensor_validator = InputSampleValidator(session.bio_model_client.input_specs)
+        tensor_validator = InputSampleValidator(session.client.input_specs)
         tensor_validator.check_tensors(input_sample)
-        res = session.bio_model_client.api.forward(input_sample)
+        res = session.client.api.forward(input_sample).result()
         return inference_pb2.PredictResponse(tensors=sample_to_pb_tensors(res))
 
-    def _getModelSession(self, context, modelSessionId: str) -> Session:
+    def _getModelSession(self, context, modelSessionId: str) -> Session[BioModelClient]:
         if not modelSessionId:
             context.abort(grpc.StatusCode.FAILED_PRECONDITION, "model-session-id has not been provided by client")
 
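
Note the ordering in the new `CreateModelSession`: the session and its cleanup callbacks (`client.api.shutdown`, `lease.terminate`) are registered before the risky `init` call, so closing the session after a failed initialization releases both the worker process and the device lease, which is exactly what the new test above exercises. Separately, the client proxy's `forward` now hands back a future that `Predict` resolves explicitly; a minimal sketch of that contract, assuming `concurrent.futures` semantics:

```python
from concurrent.futures import Future

# forward() returns a Future rather than a finished result;
# result() blocks and re-raises any exception from the worker process.
fut: Future = session.client.api.forward(input_sample)
res = fut.result()
```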
57 changes: 37 additions & 20 deletions tiktorch/server/session/process.py
@@ -4,7 +4,7 @@
 import uuid
 from concurrent.futures import Future
 from multiprocessing.connection import Connection
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Type, TypeVar, Union
 
 from bioimageio.core import PredictionPipeline, Tensor, create_prediction_pipeline
 from bioimageio.spec import InvalidDescr, load_description
@@ -13,6 +13,7 @@
 from tiktorch import log
 from tiktorch.rpc import Shutdown
 from tiktorch.rpc import mp as _mp_rpc
+from tiktorch.rpc.interface import RPCInterface
 from tiktorch.rpc.mp import BioModelClient, MPServer
 
 from ...converters import Sample
@@ -72,14 +73,18 @@ def _get_spec(self, tensor_id: str) -> v0_5.InputTensorDescr:
         return specs[0]
 
 
-class ModelSessionProcess(IRPCModelSession[PredictionPipeline]):
-    def __init__(self, model: PredictionPipeline) -> None:
-        super().__init__(model)
+class ModelSessionProcess(IRPCModelSession):
+    def __init__(self) -> None:
+        super().__init__()
         self._datasets = {}
-        self._worker = base.SessionBackend(self._model)
+        self._worker: Optional[base.SessionBackend] = None
+
+    def init(self, model_bytes: bytes, devices: List[str]):
+        prediction_pipeline = _get_prediction_pipeline_from_model_bytes(model_bytes, devices)
+        self._worker = base.SessionBackend(prediction_pipeline)
 
     def forward(self, sample: Sample) -> Future:
-        res = self._worker.forward(sample)
+        res = self.worker.forward(sample)
         return res
 
     def create_dataset(self, mean, stddev):
@@ -88,13 +93,19 @@ def create_dataset(self, mean, stddev):
         return id_
 
     def shutdown(self) -> Shutdown:
-        self._worker.shutdown()
+        if self._worker is None:
+            return Shutdown()
+        self.worker.shutdown()
         return Shutdown()
 
+    @property
+    def worker(self) -> base.SessionBackend:
+        if self._worker is None:
+            raise ValueError("Server isn't initialized")
+        return self._worker
+
 
-def _run_model_session_process(
-    conn: Connection, model_bytes: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None
-):
+def _run_server(api: RPCInterface, conn: Connection, log_queue: Optional[_mp.Queue] = None):
     try:
         # from: https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667
         import resource
@@ -107,24 +118,30 @@ def _run_model_session_process(
     if log_queue:
         log.configure(log_queue)
 
-    prediction_pipeline = _get_prediction_pipeline_from_model_bytes(model_bytes, devices)
-    session_proc = ModelSessionProcess(prediction_pipeline)
-    srv = MPServer(session_proc, conn)
+    srv = MPServer(api, conn)
     srv.listen()
 
 
-def start_model_session_process(
-    model_bytes: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None
-) -> Tuple[_mp.Process, BioModelClient]:
+T = TypeVar("T", bound=RPCInterface)
+
+
+def start_process(interface_class: Type[T], log_queue: Optional[_mp.Queue] = None) -> Tuple[_mp.Process, T]:
     client_conn, server_conn = _mp.Pipe()
     proc = _mp.Process(
-        target=_run_model_session_process,
-        name="ModelSessionProcess",
-        kwargs={"conn": server_conn, "log_queue": log_queue, "devices": devices, "model_bytes": model_bytes},
+        target=_run_server,
+        name="TiktorchProcess",
+        kwargs={"conn": server_conn, "log_queue": log_queue, "api": interface_class()},
     )
     proc.start()
-    api = _mp_rpc.create_client_api(iface_cls=IRPCModelSession, conn=client_conn)
+    api: T = _mp_rpc.create_client_api(iface_cls=interface_class, conn=client_conn)
+    return proc, api
+
+
+def start_model_session_process(
+    model_bytes: bytes, log_queue: Optional[_mp.Queue] = None
+) -> Tuple[_mp.Process, BioModelClient]:
     model_descr = _get_model_descr_from_model_bytes(model_bytes)
+    proc, api = start_process(interface_class=ModelSessionProcess, log_queue=log_queue)
     return proc, BioModelClient(
         input_specs=model_descr.inputs,
         output_specs=model_descr.outputs,
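
`start_process` is now generic over any `RPCInterface` subclass: it spawns a child process, serves a fresh instance of that interface over a `multiprocessing` pipe, and returns a typed client proxy; `start_model_session_process` becomes a thin wrapper that plugs in `ModelSessionProcess`. A sketch of reusing it for a different service (the `PingAPI` class is hypothetical, and whether proxy calls block or return futures depends on the MP-RPC client implementation):

```python
from tiktorch.rpc import RPCInterface, Shutdown, exposed
from tiktorch.server.session.process import start_process


class PingAPI(RPCInterface):  # hypothetical interface, for illustration only
    @exposed
    def init(self):
        pass  # nothing to set up for this toy service

    @exposed
    def ping(self) -> str:
        return "pong"

    @exposed
    def shutdown(self) -> Shutdown:
        return Shutdown()  # assuming MPServer treats a Shutdown result as a stop signal


proc, api = start_process(interface_class=PingAPI)  # child process + typed proxy
api.init()      # the explicit, possibly failing, initialization step
api.shutdown()  # exposed by the RPCInterface base
proc.join()
```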
18 changes: 4 additions & 14 deletions tiktorch/server/session/rpc_interface.py
@@ -1,24 +1,14 @@
-from typing import Generic, List, TypeVar
+from typing import List
 
 from tiktorch.converters import Sample
-from tiktorch.rpc import RPCInterface, Shutdown, exposed
+from tiktorch.rpc import RPCInterface, exposed
 from tiktorch.tiktypes import TikTensorBatch
 from tiktorch.types import ModelState
 
-ModelType = TypeVar("ModelType")
-
-
-class IRPCModelSession(RPCInterface, Generic[ModelType]):
-    def __init__(self, model: ModelType):
-        super().__init__()
-        self._model = model
-
-    @property
-    def model(self):
-        return self._model
-
+class IRPCModelSession(RPCInterface):
     @exposed
-    def shutdown(self) -> Shutdown:
+    def init(self, model_bytes: bytes, devices: List[str]):
         raise NotImplementedError
 
     @exposed
24 changes: 12 additions & 12 deletions tiktorch/server/session_manager.py
@@ -3,30 +3,30 @@
 import threading
 from collections import defaultdict
 from logging import getLogger
-from typing import Callable, Dict, List, Optional
+from typing import Callable, Dict, Generic, List, Optional, TypeVar
 from uuid import uuid4
 
-from tiktorch.rpc.mp import BioModelClient
-
 logger = getLogger(__name__)
 
 CloseCallback = Callable[[], None]
 
+SessionClient = TypeVar("SessionClient")
+
 
-class Session:
+class Session(Generic[SessionClient]):
     """
     session object has unique id
     Used for resource managent
     """
 
-    def __init__(self, id_: str, bio_model_client: BioModelClient, manager: SessionManager) -> None:
+    def __init__(self, id_: str, client: SessionClient, manager: SessionManager) -> None:
         self.__id = id_
         self.__manager = manager
-        self.__bio_model_client = bio_model_client
+        self.__client = client
 
     @property
-    def bio_model_client(self) -> BioModelClient:
-        return self.__bio_model_client
+    def client(self) -> SessionClient:
+        return self.__client
 
     @property
     def id(self) -> str:
@@ -42,23 +42,23 @@ def on_close(self, handler: CloseCallback) -> None:
         self.__manager._on_close(self, handler)
 
 
-class SessionManager:
+class SessionManager(Generic[SessionClient]):
     """
     Manages session lifecycle (create/close)
     """
 
-    def create_session(self, bio_model_client: BioModelClient) -> Session:
+    def create_session(self, client: SessionClient) -> Session[SessionClient]:
         """
         Creates new session with unique id
        """
         with self.__lock:
             session_id = uuid4().hex
-            session = Session(session_id, bio_model_client=bio_model_client, manager=self)
+            session = Session(session_id, client=client, manager=self)
             self.__session_by_id[session_id] = session
             logger.info("Created session %s", session.id)
             return session
 
-    def get(self, session_id: str) -> Optional[Session]:
+    def get(self, session_id: str) -> Optional[Session[SessionClient]]:
         """
         Returns existing session with given id if it exists
         """
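
With the `BioModelClient` import gone, `SessionManager` no longer depends on the model layer: it is a generic container over whatever client type the caller tracks, and the inference servicer pins it to `SessionManager[BioModelClient]`. A usage sketch (the callback and lookup behavior is as implied by the servicer code above, not verified here):

```python
manager: SessionManager[BioModelClient] = SessionManager()

session = manager.create_session(client)  # returns Session[BioModelClient]
session.on_close(client.api.shutdown)     # hooks run when the session closes
session.on_close(lease.terminate)

assert manager.get(session.id) is session
manager.close_session(session.id)  # runs the registered close callbacks
assert manager.get(session.id) is None
```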
