From 4798dbe25e9afe7c1215aa9cfc11a2c4065b9a62 Mon Sep 17 00:00:00 2001 From: Theodoros Katzalis Date: Sat, 21 Dec 2024 03:37:47 +0100 Subject: [PATCH] Move PredictRequest and ModelSession to utils.proto Since both the inference and training servicers share the concept of a session id, the training session id was replaced with the model session id already used for inference. The ModelSession protobuf message was moved to a separate utils.proto file. Since PredictRequest is also common to both, it can be shared for abstraction as well. --- proto/inference.proto | 14 +-- proto/training.proto | 35 +++---- proto/utils.proto | 13 +++ .../test_grpc/test_inference_servicer.py | 43 ++++---- .../test_grpc/test_training_servicer.py | 32 +++--- tiktorch/proto/inference_pb2.py | 46 ++++----- tiktorch/proto/inference_pb2_grpc.py | 24 ++--- tiktorch/proto/training_pb2.py | 50 +++++----- tiktorch/proto/training_pb2_grpc.py | 99 +++++++------------ tiktorch/proto/utils_pb2.py | 32 +++--- tiktorch/server/grpc/inference_servicer.py | 30 ++---- tiktorch/server/grpc/training_servicer.py | 48 ++++----- tiktorch/server/grpc/utils_servicer.py | 19 ++++ 13 files changed, 221 insertions(+), 264 deletions(-) diff --git a/proto/inference.proto b/proto/inference.proto index 3d10929b..f8530953 100644 --- a/proto/inference.proto +++ b/proto/inference.proto @@ -21,7 +21,7 @@ service Inference { message CreateDatasetDescriptionRequest { - string modelSessionId = 1; + ModelSession modelSessionId = 1; double mean = 3; double stddev = 4; } @@ -53,9 +53,6 @@ message NamedFloats { } -message ModelSession { - string id = 1; -} message LogEntry { enum Level { @@ -73,15 +70,6 @@ message LogEntry { } -message PredictRequest { - string modelSessionId = 1; - string datasetId = 2; - repeated Tensor tensors = 3; -} - -message PredictResponse { - repeated Tensor tensors = 1; -} service FlightControl { diff --git a/proto/training.proto b/proto/training.proto index 496a6eaa..1c85b7e3 100644 --- a/proto/training.proto +++ b/proto/training.proto @@ -9,33 +9,33 @@ import "utils.proto"; service Training { rpc ListDevices(Empty) returns (Devices) {} - rpc Init(TrainingConfig) returns (TrainingSessionId) {} + rpc Init(TrainingConfig) returns (ModelSession) {} - rpc Start(TrainingSessionId) returns (Empty) {} + rpc Start(ModelSession) returns (Empty) {} - rpc Resume(TrainingSessionId) returns (Empty) {} + rpc Resume(ModelSession) returns (Empty) {} - rpc Pause(TrainingSessionId) returns (Empty) {} + rpc Pause(ModelSession) returns (Empty) {} - rpc StreamUpdates(TrainingSessionId) returns (stream StreamUpdateResponse) {} + rpc StreamUpdates(ModelSession) returns (stream StreamUpdateResponse) {} - rpc GetLogs(TrainingSessionId) returns (GetLogsResponse) {} + rpc GetLogs(ModelSession) returns (GetLogsResponse) {} - rpc Save(TrainingSessionId) returns (Empty) {} - - rpc Export(TrainingSessionId) returns (Empty) {} + rpc Export(ModelSession) returns (Empty) {} rpc Predict(PredictRequest) returns (PredictResponse) {} - rpc GetStatus(TrainingSessionId) returns (GetStatusResponse) {} + rpc GetStatus(ModelSession) returns (GetStatusResponse) {} - rpc CloseTrainerSession(TrainingSessionId) returns (Empty) {} + rpc CloseTrainerSession(ModelSession) returns (Empty) {} } -message TrainingSessionId { + +message GetBestModelIdxResponse { string id = 1; } + message Logs { enum ModelPhase { Train = 0; Eval = 1; } @@ -59,17 +59,6 @@ message GetLogsResponse { } - -message PredictRequest { - repeated Tensor tensors = 1; - TrainingSessionId sessionId = 2; -} - - -message PredictResponse { - repeated Tensor 
tensors = 1; -} - message ValidationResponse { double validation_score_average = 1; } diff --git a/proto/utils.proto b/proto/utils.proto index cb24d3e3..86ab7744 100644 --- a/proto/utils.proto +++ b/proto/utils.proto @@ -2,6 +2,19 @@ syntax = "proto3"; message Empty {} +message ModelSession { + string id = 1; +} + +message PredictRequest { + ModelSession modelSessionId = 1; + repeated Tensor tensors = 2; +} + +message PredictResponse { + repeated Tensor tensors = 1; +} + message NamedInt { uint32 size = 1; string name = 2; diff --git a/tests/test_server/test_grpc/test_inference_servicer.py b/tests/test_server/test_grpc/test_inference_servicer.py index 52827193..363e03df 100644 --- a/tests/test_server/test_grpc/test_inference_servicer.py +++ b/tests/test_server/test_grpc/test_inference_servicer.py @@ -77,10 +77,10 @@ def test_model_session_creation_using_non_existent_upload(self, grpc_stub): def test_predict_call_fails_without_specifying_model_session_id(self, grpc_stub): with pytest.raises(grpc.RpcError) as e: - grpc_stub.Predict(inference_pb2.PredictRequest()) + grpc_stub.Predict(utils_pb2.PredictRequest()) assert grpc.StatusCode.FAILED_PRECONDITION == e.value.code() - assert "model-session-id has not been provided" in e.value.details() + assert "model-session with id doesn't exist" in e.value.details() def test_model_init_failed_close_session(self, bioimage_model_explicit_add_one_siso_v5, grpc_stub): """ @@ -169,17 +169,18 @@ def test_returns_ack_message(self, bioimage_model_explicit_add_one_siso_v5, grpc class TestForwardPass: def test_call_fails_with_unknown_model_session_id(self, grpc_stub): with pytest.raises(grpc.RpcError) as e: - grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId="myid1")) + model_id = utils_pb2.ModelSession(id="myid") + grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id)) assert grpc.StatusCode.FAILED_PRECONDITION == e.value.code() - assert "model-session with id myid1 doesn't exist" in e.value.details() + assert "model-session with id myid doesn't exist" in e.value.details() def test_call_predict_valid_explicit(self, grpc_stub, bioimage_model_explicit_add_one_siso_v5): model_bytes = bioimage_model_explicit_add_one_siso_v5 - model = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) + model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) arr = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=("batch", "channel", "x", "y")) input_tensor_id = "input" input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)] - res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) + res = grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors)) assert len(res.tensors) == 1 pb_tensor = res.tensors[0] assert pb_tensor.tensorId == "output" @@ -188,11 +189,11 @@ def test_call_predict_valid_explicit(self, grpc_stub, bioimage_model_explicit_ad def test_call_predict_valid_explicit_v4(self, grpc_stub, bioimage_model_add_one_v4): model_bytes = bioimage_model_add_one_v4 - model = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) + model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) arr = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=("batch", "channel", "x", "y")) input_tensor_id = "input" input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)] - res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) + res = 
grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors)) assert len(res.tensors) == 1 pb_tensor = res.tensors[0] assert pb_tensor.tensorId == "output" @@ -201,16 +202,16 @@ def test_call_predict_valid_explicit_v4(self, grpc_stub, bioimage_model_add_one_ def test_call_predict_invalid_shape_explicit(self, grpc_stub, bioimage_model_explicit_add_one_siso_v5): model_bytes = bioimage_model_explicit_add_one_siso_v5 - model = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) + model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("batch", "channel", "x", "y")) input_tensors = [converters.xarray_to_pb_tensor("input", arr)] with pytest.raises(grpc.RpcError) as error: - grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) + grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors)) assert error.value.details().startswith("Exception calling application: Incompatible axis") def test_call_predict_multiple_inputs_with_reference(self, grpc_stub, bioimage_model_add_one_miso_v5): model_bytes = bioimage_model_add_one_miso_v5 - model = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) + model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) arr1 = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=("batch", "channel", "x", "y")) input_tensor_id1 = "input1" @@ -227,8 +228,8 @@ def test_call_predict_multiple_inputs_with_reference(self, grpc_stub, bioimage_m converters.xarray_to_pb_tensor(tensor_id, arr) for tensor_id, arr in zip(input_tensor_ids, tensors_arr) ] - res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) - grpc_stub.CloseModelSession(model) + res = grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors)) + grpc_stub.CloseModelSession(model_id) assert len(res.tensors) == 1 pb_tensor = res.tensors[0] assert pb_tensor.tensorId == "output" @@ -238,11 +239,11 @@ def test_call_predict_multiple_inputs_with_reference(self, grpc_stub, bioimage_m @pytest.mark.parametrize("shape", [(1, 2, 10, 20), (1, 2, 12, 20), (1, 2, 10, 23), (1, 2, 12, 23)]) def test_call_predict_valid_shape_parameterized(self, grpc_stub, shape, bioimage_model_param_add_one_siso_v5): model_bytes = bioimage_model_param_add_one_siso_v5 - model = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) + model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("batch", "channel", "x", "y")) input_tensor_id = "input" input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)] - grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) + grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors)) @pytest.mark.parametrize( "shape", @@ -250,21 +251,21 @@ def test_call_predict_valid_shape_parameterized(self, grpc_stub, shape, bioimage ) def test_call_predict_invalid_shape_parameterized(self, grpc_stub, shape, bioimage_model_param_add_one_siso_v5): model_bytes = bioimage_model_param_add_one_siso_v5 - model = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) + model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("batch", "channel", "x", "y")) 
input_tensor_id = "input" input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)] with pytest.raises(grpc.RpcError) as error: - grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) + grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors)) assert error.value.details().startswith("Exception calling application: Incompatible axis") def test_call_predict_invalid_tensor_ids(self, grpc_stub, bioimage_model_explicit_add_one_siso_v5): model_bytes = bioimage_model_explicit_add_one_siso_v5 - model = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) + model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) arr = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=("batch", "channel", "x", "y")) input_tensors = [converters.xarray_to_pb_tensor("invalidTensorName", arr)] with pytest.raises(grpc.RpcError) as error: - grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) + grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors)) assert error.value.details().startswith("Exception calling application: Spec 'invalidTensorName' doesn't exist") @pytest.mark.parametrize( @@ -278,10 +279,10 @@ def test_call_predict_invalid_tensor_ids(self, grpc_stub, bioimage_model_explici ) def test_call_predict_invalid_axes(self, grpc_stub, axes, bioimage_model_explicit_add_one_siso_v5): model_bytes = bioimage_model_explicit_add_one_siso_v5 - model = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) + model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes)) arr = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=axes) input_tensor_id = "input" input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)] with pytest.raises(grpc.RpcError) as error: - grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors)) + grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors)) assert error.value.details().startswith("Exception calling application: Incompatible axes names") diff --git a/tests/test_server/test_grpc/test_training_servicer.py b/tests/test_server/test_grpc/test_training_servicer.py index 0dcb9107..9dcf0c06 100644 --- a/tests/test_server/test_grpc/test_training_servicer.py +++ b/tests/test_server/test_grpc/test_training_servicer.py @@ -233,7 +233,7 @@ def test_init_failed_then_devices_are_released(self, grpc_stub): # attempt to init with the same device init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) - response = training_pb2.TrainingSessionId(id=init_response.id) + response = utils_pb2.ModelSession(id=init_response.id) assert response.id is not None def test_start_training_success(self): @@ -271,7 +271,7 @@ def test_concurrent_state_transitions(self, grpc_stub): The test should exit gracefully without hanging processes or threads. 
""" init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) - training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + training_session_id = utils_pb2.ModelSession(id=init_response.id) threads = [] for _ in range(2): @@ -290,7 +290,7 @@ def assert_state(state_to_check): self.assert_state(grpc_stub, training_session_id, state_to_check) init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) - training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + training_session_id = utils_pb2.ModelSession(id=init_response.id) grpc_stub.Start(training_session_id) assert_state(TrainerState.RUNNING) @@ -304,7 +304,7 @@ def assert_state(state_to_check): def test_error_handling_on_invalid_state_transitions_after_training_started(self, grpc_stub): init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) - training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + training_session_id = utils_pb2.ModelSession(id=init_response.id) # Attempt to start again while already running grpc_stub.Start(training_session_id) @@ -326,7 +326,7 @@ def test_error_handling_on_invalid_state_transitions_after_training_started(self def test_error_handling_on_invalid_state_transitions_before_training_started(self, grpc_stub): init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) - training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + training_session_id = utils_pb2.ModelSession(id=init_response.id) # Attempt to resume before start with pytest.raises(grpc.RpcError) as excinfo: @@ -345,7 +345,7 @@ def test_start_training_without_init(self, grpc_stub): with pytest.raises(grpc.RpcError) as excinfo: grpc_stub.Start(utils_pb2.Empty()) assert excinfo.value.code() == grpc.StatusCode.FAILED_PRECONDITION - assert "trainer-session with id doesn't exist" in excinfo.value.details() + assert "model-session with id doesn't exist" in excinfo.value.details() def test_recover_training_failed(self): class MockedExceptionTrainer: @@ -439,25 +439,25 @@ def init(self, trainer_yaml_config: str = ""): def test_graceful_shutdown_after_init(self, grpc_stub): init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) - training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + training_session_id = utils_pb2.ModelSession(id=init_response.id) grpc_stub.CloseTrainerSession(training_session_id) def test_graceful_shutdown_after_start(self, grpc_stub): init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) - training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + training_session_id = utils_pb2.ModelSession(id=init_response.id) grpc_stub.Start(training_session_id) grpc_stub.CloseTrainerSession(training_session_id) def test_graceful_shutdown_after_pause(self, grpc_stub): init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) - training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + training_session_id = utils_pb2.ModelSession(id=init_response.id) grpc_stub.Start(training_session_id) grpc_stub.Pause(training_session_id) grpc_stub.CloseTrainerSession(training_session_id) def test_graceful_shutdown_after_resume(self, grpc_stub): init_response = 
grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) - training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + training_session_id = utils_pb2.ModelSession(id=init_response.id) grpc_stub.Start(training_session_id) grpc_stub.Pause(training_session_id) grpc_stub.Resume(training_session_id) @@ -466,7 +466,7 @@ def test_graceful_shutdown_after_resume(self, grpc_stub): def test_close_trainer_session_twice(self, grpc_stub): # Attempt to close the session twice init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) - training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + training_session_id = utils_pb2.ModelSession(id=init_response.id) grpc_stub.CloseTrainerSession(training_session_id) # The second attempt should raise an error @@ -476,7 +476,7 @@ def test_close_trainer_session_twice(self, grpc_stub): def test_forward_while_running(self, grpc_stub): init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) - training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + training_session_id = utils_pb2.ModelSession(id=init_response.id) grpc_stub.Start(training_session_id) @@ -487,7 +487,7 @@ def test_forward_while_running(self, grpc_stub): data = np.random.rand(*shape).astype(np.float32) xarray_data = xr.DataArray(data, dims=("b", "c", "z", "y", "x")) pb_tensor = xarray_to_pb_tensor(tensor_id="", array=xarray_data) - predict_request = training_pb2.PredictRequest(sessionId=training_session_id, tensors=[pb_tensor]) + predict_request = utils_pb2.PredictRequest(modelSessionId=training_session_id, tensors=[pb_tensor]) response = grpc_stub.Predict(predict_request) @@ -502,7 +502,7 @@ def test_forward_while_running(self, grpc_stub): def test_forward_while_paused(self, grpc_stub): init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) - training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + training_session_id = utils_pb2.ModelSession(id=init_response.id) grpc_stub.Start(training_session_id) @@ -513,7 +513,7 @@ def test_forward_while_paused(self, grpc_stub): data = np.random.rand(*shape).astype(np.float32) xarray_data = xr.DataArray(data, dims=("b", "c", "z", "y", "x")) pb_tensor = xarray_to_pb_tensor(tensor_id="", array=xarray_data) - predict_request = training_pb2.PredictRequest(sessionId=training_session_id, tensors=[pb_tensor]) + predict_request = utils_pb2.PredictRequest(modelSessionId=training_session_id, tensors=[pb_tensor]) grpc_stub.Pause(training_session_id) @@ -533,7 +533,7 @@ def test_close_session(self, grpc_stub): Test closing a training session. """ init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())) - training_session_id = training_pb2.TrainingSessionId(id=init_response.id) + training_session_id = utils_pb2.ModelSession(id=init_response.id) grpc_stub.CloseTrainerSession(training_session_id) # attempt to perform an operation while session is closed diff --git a/tiktorch/proto/inference_pb2.py b/tiktorch/proto/inference_pb2.py index be9fbde2..5ae66345 100644 --- a/tiktorch/proto/inference_pb2.py +++ b/tiktorch/proto/inference_pb2.py @@ -14,7 +14,7 @@ from . 
import utils_pb2 as utils__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0finference.proto\x12\tinference\x1a\x0butils.proto\"W\n\x1f\x43reateDatasetDescriptionRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x0c\n\x04mean\x18\x03 \x01(\x01\x12\x0e\n\x06stddev\x18\x04 \x01(\x01\" \n\x12\x44\x61tasetDescription\x12\n\n\x02id\x18\x01 \x01(\t\"\'\n\x04\x42lob\x12\x0e\n\x06\x66ormat\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"s\n\x19\x43reateModelSessionRequest\x12\x13\n\tmodel_uri\x18\x01 \x01(\tH\x00\x12%\n\nmodel_blob\x18\x02 \x01(\x0b\x32\x0f.inference.BlobH\x00\x12\x11\n\tdeviceIds\x18\x05 \x03(\tB\x07\n\x05model\")\n\tNamedInts\x12\x1c\n\tnamedInts\x18\x01 \x03(\x0b\x32\t.NamedInt\"/\n\x0bNamedFloats\x12 \n\x0bnamedFloats\x18\x01 \x03(\x0b\x32\x0b.NamedFloat\"\x1a\n\x0cModelSession\x12\n\n\x02id\x18\x01 \x01(\t\"\xa8\x01\n\x08LogEntry\x12\x11\n\ttimestamp\x18\x01 \x01(\r\x12(\n\x05level\x18\x02 \x01(\x0e\x32\x19.inference.LogEntry.Level\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\t\"N\n\x05Level\x12\n\n\x06NOTSET\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x08\n\x04INFO\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\x0c\n\x08\x43RITICAL\x10\x05\"U\n\x0ePredictRequest\x12\x16\n\x0emodelSessionId\x18\x01 \x01(\t\x12\x11\n\tdatasetId\x18\x02 \x01(\t\x12\x18\n\x07tensors\x18\x03 \x03(\x0b\x32\x07.Tensor\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor2\x96\x03\n\tInference\x12U\n\x12\x43reateModelSession\x12$.inference.CreateModelSessionRequest\x1a\x17.inference.ModelSession\"\x00\x12\x36\n\x11\x43loseModelSession\x12\x17.inference.ModelSession\x1a\x06.Empty\"\x00\x12g\n\x18\x43reateDatasetDescription\x12*.inference.CreateDatasetDescriptionRequest\x1a\x1d.inference.DatasetDescription\"\x00\x12*\n\x07GetLogs\x12\x06.Empty\x1a\x13.inference.LogEntry\"\x00\x30\x01\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12\x42\n\x07Predict\x12\x19.inference.PredictRequest\x1a\x1a.inference.PredictResponse\"\x00\x32G\n\rFlightControl\x12\x18\n\x04Ping\x12\x06.Empty\x1a\x06.Empty\"\x00\x12\x1c\n\x08Shutdown\x12\x06.Empty\x1a\x06.Empty\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0finference.proto\x12\tinference\x1a\x0butils.proto\"f\n\x1f\x43reateDatasetDescriptionRequest\x12%\n\x0emodelSessionId\x18\x01 \x01(\x0b\x32\r.ModelSession\x12\x0c\n\x04mean\x18\x03 \x01(\x01\x12\x0e\n\x06stddev\x18\x04 \x01(\x01\" \n\x12\x44\x61tasetDescription\x12\n\n\x02id\x18\x01 \x01(\t\"\'\n\x04\x42lob\x12\x0e\n\x06\x66ormat\x18\x01 \x01(\t\x12\x0f\n\x07\x63ontent\x18\x02 \x01(\x0c\"s\n\x19\x43reateModelSessionRequest\x12\x13\n\tmodel_uri\x18\x01 \x01(\tH\x00\x12%\n\nmodel_blob\x18\x02 \x01(\x0b\x32\x0f.inference.BlobH\x00\x12\x11\n\tdeviceIds\x18\x05 \x03(\tB\x07\n\x05model\")\n\tNamedInts\x12\x1c\n\tnamedInts\x18\x01 \x03(\x0b\x32\t.NamedInt\"/\n\x0bNamedFloats\x12 \n\x0bnamedFloats\x18\x01 \x03(\x0b\x32\x0b.NamedFloat\"\xa8\x01\n\x08LogEntry\x12\x11\n\ttimestamp\x18\x01 \x01(\r\x12(\n\x05level\x18\x02 \x01(\x0e\x32\x19.inference.LogEntry.Level\x12\x0f\n\x07\x63ontent\x18\x03 
\x01(\t\"N\n\x05Level\x12\n\n\x06NOTSET\x10\x00\x12\t\n\x05\x44\x45\x42UG\x10\x01\x12\x08\n\x04INFO\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\x0c\n\x08\x43RITICAL\x10\x05\x32\xee\x02\n\tInference\x12K\n\x12\x43reateModelSession\x12$.inference.CreateModelSessionRequest\x1a\r.ModelSession\"\x00\x12,\n\x11\x43loseModelSession\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12g\n\x18\x43reateDatasetDescription\x12*.inference.CreateDatasetDescriptionRequest\x1a\x1d.inference.DatasetDescription\"\x00\x12*\n\x07GetLogs\x12\x06.Empty\x1a\x13.inference.LogEntry\"\x00\x30\x01\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12.\n\x07Predict\x12\x0f.PredictRequest\x1a\x10.PredictResponse\"\x00\x32G\n\rFlightControl\x12\x18\n\x04Ping\x12\x06.Empty\x1a\x06.Empty\"\x00\x12\x1c\n\x08Shutdown\x12\x06.Empty\x1a\x06.Empty\"\x00\x62\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'inference_pb2', globals()) @@ -22,29 +22,23 @@ DESCRIPTOR._options = None _CREATEDATASETDESCRIPTIONREQUEST._serialized_start=43 - _CREATEDATASETDESCRIPTIONREQUEST._serialized_end=130 - _DATASETDESCRIPTION._serialized_start=132 - _DATASETDESCRIPTION._serialized_end=164 - _BLOB._serialized_start=166 - _BLOB._serialized_end=205 - _CREATEMODELSESSIONREQUEST._serialized_start=207 - _CREATEMODELSESSIONREQUEST._serialized_end=322 - _NAMEDINTS._serialized_start=324 - _NAMEDINTS._serialized_end=365 - _NAMEDFLOATS._serialized_start=367 - _NAMEDFLOATS._serialized_end=414 - _MODELSESSION._serialized_start=416 - _MODELSESSION._serialized_end=442 - _LOGENTRY._serialized_start=445 - _LOGENTRY._serialized_end=613 - _LOGENTRY_LEVEL._serialized_start=535 - _LOGENTRY_LEVEL._serialized_end=613 - _PREDICTREQUEST._serialized_start=615 - _PREDICTREQUEST._serialized_end=700 - _PREDICTRESPONSE._serialized_start=702 - _PREDICTRESPONSE._serialized_end=745 - _INFERENCE._serialized_start=748 - _INFERENCE._serialized_end=1154 - _FLIGHTCONTROL._serialized_start=1156 - _FLIGHTCONTROL._serialized_end=1227 + _CREATEDATASETDESCRIPTIONREQUEST._serialized_end=145 + _DATASETDESCRIPTION._serialized_start=147 + _DATASETDESCRIPTION._serialized_end=179 + _BLOB._serialized_start=181 + _BLOB._serialized_end=220 + _CREATEMODELSESSIONREQUEST._serialized_start=222 + _CREATEMODELSESSIONREQUEST._serialized_end=337 + _NAMEDINTS._serialized_start=339 + _NAMEDINTS._serialized_end=380 + _NAMEDFLOATS._serialized_start=382 + _NAMEDFLOATS._serialized_end=429 + _LOGENTRY._serialized_start=432 + _LOGENTRY._serialized_end=600 + _LOGENTRY_LEVEL._serialized_start=522 + _LOGENTRY_LEVEL._serialized_end=600 + _INFERENCE._serialized_start=603 + _INFERENCE._serialized_end=969 + _FLIGHTCONTROL._serialized_start=971 + _FLIGHTCONTROL._serialized_end=1042 # @@protoc_insertion_point(module_scope) diff --git a/tiktorch/proto/inference_pb2_grpc.py b/tiktorch/proto/inference_pb2_grpc.py index b49f2e42..22e28495 100644 --- a/tiktorch/proto/inference_pb2_grpc.py +++ b/tiktorch/proto/inference_pb2_grpc.py @@ -18,11 +18,11 @@ def __init__(self, channel): self.CreateModelSession = channel.unary_unary( '/inference.Inference/CreateModelSession', request_serializer=inference__pb2.CreateModelSessionRequest.SerializeToString, - response_deserializer=inference__pb2.ModelSession.FromString, + response_deserializer=utils__pb2.ModelSession.FromString, ) self.CloseModelSession = channel.unary_unary( '/inference.Inference/CloseModelSession', - request_serializer=inference__pb2.ModelSession.SerializeToString, + 
request_serializer=utils__pb2.ModelSession.SerializeToString, response_deserializer=utils__pb2.Empty.FromString, ) self.CreateDatasetDescription = channel.unary_unary( @@ -42,8 +42,8 @@ def __init__(self, channel): ) self.Predict = channel.unary_unary( '/inference.Inference/Predict', - request_serializer=inference__pb2.PredictRequest.SerializeToString, - response_deserializer=inference__pb2.PredictResponse.FromString, + request_serializer=utils__pb2.PredictRequest.SerializeToString, + response_deserializer=utils__pb2.PredictResponse.FromString, ) @@ -92,11 +92,11 @@ def add_InferenceServicer_to_server(servicer, server): 'CreateModelSession': grpc.unary_unary_rpc_method_handler( servicer.CreateModelSession, request_deserializer=inference__pb2.CreateModelSessionRequest.FromString, - response_serializer=inference__pb2.ModelSession.SerializeToString, + response_serializer=utils__pb2.ModelSession.SerializeToString, ), 'CloseModelSession': grpc.unary_unary_rpc_method_handler( servicer.CloseModelSession, - request_deserializer=inference__pb2.ModelSession.FromString, + request_deserializer=utils__pb2.ModelSession.FromString, response_serializer=utils__pb2.Empty.SerializeToString, ), 'CreateDatasetDescription': grpc.unary_unary_rpc_method_handler( @@ -116,8 +116,8 @@ def add_InferenceServicer_to_server(servicer, server): ), 'Predict': grpc.unary_unary_rpc_method_handler( servicer.Predict, - request_deserializer=inference__pb2.PredictRequest.FromString, - response_serializer=inference__pb2.PredictResponse.SerializeToString, + request_deserializer=utils__pb2.PredictRequest.FromString, + response_serializer=utils__pb2.PredictResponse.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -142,7 +142,7 @@ def CreateModelSession(request, metadata=None): return grpc.experimental.unary_unary(request, target, '/inference.Inference/CreateModelSession', inference__pb2.CreateModelSessionRequest.SerializeToString, - inference__pb2.ModelSession.FromString, + utils__pb2.ModelSession.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -158,7 +158,7 @@ def CloseModelSession(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/inference.Inference/CloseModelSession', - inference__pb2.ModelSession.SerializeToString, + utils__pb2.ModelSession.SerializeToString, utils__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -226,8 +226,8 @@ def Predict(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/inference.Inference/Predict', - inference__pb2.PredictRequest.SerializeToString, - inference__pb2.PredictResponse.FromString, + utils__pb2.PredictRequest.SerializeToString, + utils__pb2.PredictResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/tiktorch/proto/training_pb2.py b/tiktorch/proto/training_pb2.py index 70d93933..ee0a9d27 100644 --- a/tiktorch/proto/training_pb2.py +++ b/tiktorch/proto/training_pb2.py @@ -14,37 +14,33 @@ from . 
import utils_pb2 as utils__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0etraining.proto\x12\x08training\x1a\x0butils.proto\"\x1f\n\x11TrainingSessionId\x12\n\n\x02id\x18\x01 \x01(\t\"\x87\x01\n\x04Logs\x12\'\n\x04mode\x18\x01 \x01(\x0e\x32\x19.training.Logs.ModelPhase\x12\x12\n\neval_score\x18\x02 \x01(\x01\x12\x0c\n\x04loss\x18\x03 \x01(\x01\x12\x11\n\titeration\x18\x04 \x01(\r\"!\n\nModelPhase\x12\t\n\x05Train\x10\x00\x12\x08\n\x04\x45val\x10\x01\"L\n\x14StreamUpdateResponse\x12\x16\n\x0e\x62\x65st_model_idx\x18\x01 \x01(\r\x12\x1c\n\x04logs\x18\x02 \x01(\x0b\x32\x0e.training.Logs\"/\n\x0fGetLogsResponse\x12\x1c\n\x04logs\x18\x01 \x03(\x0b\x32\x0e.training.Logs\"Z\n\x0ePredictRequest\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\x12.\n\tsessionId\x18\x02 \x01(\x0b\x32\x1b.training.TrainingSessionId\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\"6\n\x12ValidationResponse\x12 \n\x18validation_score_average\x18\x01 \x01(\x01\"\x8b\x01\n\x11GetStatusResponse\x12\x30\n\x05state\x18\x01 \x01(\x0e\x32!.training.GetStatusResponse.State\"D\n\x05State\x12\x08\n\x04Idle\x10\x00\x12\x0b\n\x07Running\x10\x01\x12\n\n\x06Paused\x10\x02\x12\n\n\x06\x46\x61iled\x10\x03\x12\x0c\n\x08\x46inished\x10\x04\",\n\x1eGetCurrentBestModelIdxResponse\x12\n\n\x02id\x18\x01 \x01(\r\"&\n\x0eTrainingConfig\x12\x14\n\x0cyaml_content\x18\x01 \x01(\t2\xbf\x05\n\x08Training\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12?\n\x04Init\x12\x18.training.TrainingConfig\x1a\x1b.training.TrainingSessionId\"\x00\x12.\n\x05Start\x12\x1b.training.TrainingSessionId\x1a\x06.Empty\"\x00\x12/\n\x06Resume\x12\x1b.training.TrainingSessionId\x1a\x06.Empty\"\x00\x12.\n\x05Pause\x12\x1b.training.TrainingSessionId\x1a\x06.Empty\"\x00\x12P\n\rStreamUpdates\x12\x1b.training.TrainingSessionId\x1a\x1e.training.StreamUpdateResponse\"\x00\x30\x01\x12\x43\n\x07GetLogs\x12\x1b.training.TrainingSessionId\x1a\x19.training.GetLogsResponse\"\x00\x12-\n\x04Save\x12\x1b.training.TrainingSessionId\x1a\x06.Empty\"\x00\x12/\n\x06\x45xport\x12\x1b.training.TrainingSessionId\x1a\x06.Empty\"\x00\x12@\n\x07Predict\x12\x18.training.PredictRequest\x1a\x19.training.PredictResponse\"\x00\x12G\n\tGetStatus\x12\x1b.training.TrainingSessionId\x1a\x1b.training.GetStatusResponse\"\x00\x12<\n\x13\x43loseTrainerSession\x12\x1b.training.TrainingSessionId\x1a\x06.Empty\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0etraining.proto\x12\x08training\x1a\x0butils.proto\"%\n\x17GetBestModelIdxResponse\x12\n\n\x02id\x18\x01 \x01(\t\"\x87\x01\n\x04Logs\x12\'\n\x04mode\x18\x01 \x01(\x0e\x32\x19.training.Logs.ModelPhase\x12\x12\n\neval_score\x18\x02 \x01(\x01\x12\x0c\n\x04loss\x18\x03 \x01(\x01\x12\x11\n\titeration\x18\x04 \x01(\r\"!\n\nModelPhase\x12\t\n\x05Train\x10\x00\x12\x08\n\x04\x45val\x10\x01\"L\n\x14StreamUpdateResponse\x12\x16\n\x0e\x62\x65st_model_idx\x18\x01 \x01(\r\x12\x1c\n\x04logs\x18\x02 \x01(\x0b\x32\x0e.training.Logs\"/\n\x0fGetLogsResponse\x12\x1c\n\x04logs\x18\x01 \x03(\x0b\x32\x0e.training.Logs\"6\n\x12ValidationResponse\x12 \n\x18validation_score_average\x18\x01 \x01(\x01\"\x8b\x01\n\x11GetStatusResponse\x12\x30\n\x05state\x18\x01 \x01(\x0e\x32!.training.GetStatusResponse.State\"D\n\x05State\x12\x08\n\x04Idle\x10\x00\x12\x0b\n\x07Running\x10\x01\x12\n\n\x06Paused\x10\x02\x12\n\n\x06\x46\x61iled\x10\x03\x12\x0c\n\x08\x46inished\x10\x04\",\n\x1eGetCurrentBestModelIdxResponse\x12\n\n\x02id\x18\x01 
\x01(\r\"&\n\x0eTrainingConfig\x12\x14\n\x0cyaml_content\x18\x01 \x01(\t2\x80\x04\n\x08Training\x12!\n\x0bListDevices\x12\x06.Empty\x1a\x08.Devices\"\x00\x12\x31\n\x04Init\x12\x18.training.TrainingConfig\x1a\r.ModelSession\"\x00\x12 \n\x05Start\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12!\n\x06Resume\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12 \n\x05Pause\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12\x42\n\rStreamUpdates\x12\r.ModelSession\x1a\x1e.training.StreamUpdateResponse\"\x00\x30\x01\x12\x35\n\x07GetLogs\x12\r.ModelSession\x1a\x19.training.GetLogsResponse\"\x00\x12!\n\x06\x45xport\x12\r.ModelSession\x1a\x06.Empty\"\x00\x12.\n\x07Predict\x12\x0f.PredictRequest\x1a\x10.PredictResponse\"\x00\x12\x39\n\tGetStatus\x12\r.ModelSession\x1a\x1b.training.GetStatusResponse\"\x00\x12.\n\x13\x43loseTrainerSession\x12\r.ModelSession\x1a\x06.Empty\"\x00\x62\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'training_pb2', globals()) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _TRAININGSESSIONID._serialized_start=41 - _TRAININGSESSIONID._serialized_end=72 - _LOGS._serialized_start=75 - _LOGS._serialized_end=210 - _LOGS_MODELPHASE._serialized_start=177 - _LOGS_MODELPHASE._serialized_end=210 - _STREAMUPDATERESPONSE._serialized_start=212 - _STREAMUPDATERESPONSE._serialized_end=288 - _GETLOGSRESPONSE._serialized_start=290 - _GETLOGSRESPONSE._serialized_end=337 - _PREDICTREQUEST._serialized_start=339 - _PREDICTREQUEST._serialized_end=429 - _PREDICTRESPONSE._serialized_start=431 - _PREDICTRESPONSE._serialized_end=474 - _VALIDATIONRESPONSE._serialized_start=476 - _VALIDATIONRESPONSE._serialized_end=530 - _GETSTATUSRESPONSE._serialized_start=533 - _GETSTATUSRESPONSE._serialized_end=672 - _GETSTATUSRESPONSE_STATE._serialized_start=604 - _GETSTATUSRESPONSE_STATE._serialized_end=672 - _GETCURRENTBESTMODELIDXRESPONSE._serialized_start=674 - _GETCURRENTBESTMODELIDXRESPONSE._serialized_end=718 - _TRAININGCONFIG._serialized_start=720 - _TRAININGCONFIG._serialized_end=758 - _TRAINING._serialized_start=761 - _TRAINING._serialized_end=1464 + _GETBESTMODELIDXRESPONSE._serialized_start=41 + _GETBESTMODELIDXRESPONSE._serialized_end=78 + _LOGS._serialized_start=81 + _LOGS._serialized_end=216 + _LOGS_MODELPHASE._serialized_start=183 + _LOGS_MODELPHASE._serialized_end=216 + _STREAMUPDATERESPONSE._serialized_start=218 + _STREAMUPDATERESPONSE._serialized_end=294 + _GETLOGSRESPONSE._serialized_start=296 + _GETLOGSRESPONSE._serialized_end=343 + _VALIDATIONRESPONSE._serialized_start=345 + _VALIDATIONRESPONSE._serialized_end=399 + _GETSTATUSRESPONSE._serialized_start=402 + _GETSTATUSRESPONSE._serialized_end=541 + _GETSTATUSRESPONSE_STATE._serialized_start=473 + _GETSTATUSRESPONSE_STATE._serialized_end=541 + _GETCURRENTBESTMODELIDXRESPONSE._serialized_start=543 + _GETCURRENTBESTMODELIDXRESPONSE._serialized_end=587 + _TRAININGCONFIG._serialized_start=589 + _TRAININGCONFIG._serialized_end=627 + _TRAINING._serialized_start=630 + _TRAINING._serialized_end=1142 # @@protoc_insertion_point(module_scope) diff --git a/tiktorch/proto/training_pb2_grpc.py b/tiktorch/proto/training_pb2_grpc.py index 79bf33df..5286471d 100644 --- a/tiktorch/proto/training_pb2_grpc.py +++ b/tiktorch/proto/training_pb2_grpc.py @@ -23,56 +23,51 @@ def __init__(self, channel): self.Init = channel.unary_unary( '/training.Training/Init', request_serializer=training__pb2.TrainingConfig.SerializeToString, - 
response_deserializer=training__pb2.TrainingSessionId.FromString, + response_deserializer=utils__pb2.ModelSession.FromString, ) self.Start = channel.unary_unary( '/training.Training/Start', - request_serializer=training__pb2.TrainingSessionId.SerializeToString, + request_serializer=utils__pb2.ModelSession.SerializeToString, response_deserializer=utils__pb2.Empty.FromString, ) self.Resume = channel.unary_unary( '/training.Training/Resume', - request_serializer=training__pb2.TrainingSessionId.SerializeToString, + request_serializer=utils__pb2.ModelSession.SerializeToString, response_deserializer=utils__pb2.Empty.FromString, ) self.Pause = channel.unary_unary( '/training.Training/Pause', - request_serializer=training__pb2.TrainingSessionId.SerializeToString, + request_serializer=utils__pb2.ModelSession.SerializeToString, response_deserializer=utils__pb2.Empty.FromString, ) self.StreamUpdates = channel.unary_stream( '/training.Training/StreamUpdates', - request_serializer=training__pb2.TrainingSessionId.SerializeToString, + request_serializer=utils__pb2.ModelSession.SerializeToString, response_deserializer=training__pb2.StreamUpdateResponse.FromString, ) self.GetLogs = channel.unary_unary( '/training.Training/GetLogs', - request_serializer=training__pb2.TrainingSessionId.SerializeToString, + request_serializer=utils__pb2.ModelSession.SerializeToString, response_deserializer=training__pb2.GetLogsResponse.FromString, ) - self.Save = channel.unary_unary( - '/training.Training/Save', - request_serializer=training__pb2.TrainingSessionId.SerializeToString, - response_deserializer=utils__pb2.Empty.FromString, - ) self.Export = channel.unary_unary( '/training.Training/Export', - request_serializer=training__pb2.TrainingSessionId.SerializeToString, + request_serializer=utils__pb2.ModelSession.SerializeToString, response_deserializer=utils__pb2.Empty.FromString, ) self.Predict = channel.unary_unary( '/training.Training/Predict', - request_serializer=training__pb2.PredictRequest.SerializeToString, - response_deserializer=training__pb2.PredictResponse.FromString, + request_serializer=utils__pb2.PredictRequest.SerializeToString, + response_deserializer=utils__pb2.PredictResponse.FromString, ) self.GetStatus = channel.unary_unary( '/training.Training/GetStatus', - request_serializer=training__pb2.TrainingSessionId.SerializeToString, + request_serializer=utils__pb2.ModelSession.SerializeToString, response_deserializer=training__pb2.GetStatusResponse.FromString, ) self.CloseTrainerSession = channel.unary_unary( '/training.Training/CloseTrainerSession', - request_serializer=training__pb2.TrainingSessionId.SerializeToString, + request_serializer=utils__pb2.ModelSession.SerializeToString, response_deserializer=utils__pb2.Empty.FromString, ) @@ -122,12 +117,6 @@ def GetLogs(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') - def Save(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - def Export(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -163,56 +152,51 @@ def add_TrainingServicer_to_server(servicer, server): 'Init': grpc.unary_unary_rpc_method_handler( servicer.Init, request_deserializer=training__pb2.TrainingConfig.FromString, - 
response_serializer=training__pb2.TrainingSessionId.SerializeToString, + response_serializer=utils__pb2.ModelSession.SerializeToString, ), 'Start': grpc.unary_unary_rpc_method_handler( servicer.Start, - request_deserializer=training__pb2.TrainingSessionId.FromString, + request_deserializer=utils__pb2.ModelSession.FromString, response_serializer=utils__pb2.Empty.SerializeToString, ), 'Resume': grpc.unary_unary_rpc_method_handler( servicer.Resume, - request_deserializer=training__pb2.TrainingSessionId.FromString, + request_deserializer=utils__pb2.ModelSession.FromString, response_serializer=utils__pb2.Empty.SerializeToString, ), 'Pause': grpc.unary_unary_rpc_method_handler( servicer.Pause, - request_deserializer=training__pb2.TrainingSessionId.FromString, + request_deserializer=utils__pb2.ModelSession.FromString, response_serializer=utils__pb2.Empty.SerializeToString, ), 'StreamUpdates': grpc.unary_stream_rpc_method_handler( servicer.StreamUpdates, - request_deserializer=training__pb2.TrainingSessionId.FromString, + request_deserializer=utils__pb2.ModelSession.FromString, response_serializer=training__pb2.StreamUpdateResponse.SerializeToString, ), 'GetLogs': grpc.unary_unary_rpc_method_handler( servicer.GetLogs, - request_deserializer=training__pb2.TrainingSessionId.FromString, + request_deserializer=utils__pb2.ModelSession.FromString, response_serializer=training__pb2.GetLogsResponse.SerializeToString, ), - 'Save': grpc.unary_unary_rpc_method_handler( - servicer.Save, - request_deserializer=training__pb2.TrainingSessionId.FromString, - response_serializer=utils__pb2.Empty.SerializeToString, - ), 'Export': grpc.unary_unary_rpc_method_handler( servicer.Export, - request_deserializer=training__pb2.TrainingSessionId.FromString, + request_deserializer=utils__pb2.ModelSession.FromString, response_serializer=utils__pb2.Empty.SerializeToString, ), 'Predict': grpc.unary_unary_rpc_method_handler( servicer.Predict, - request_deserializer=training__pb2.PredictRequest.FromString, - response_serializer=training__pb2.PredictResponse.SerializeToString, + request_deserializer=utils__pb2.PredictRequest.FromString, + response_serializer=utils__pb2.PredictResponse.SerializeToString, ), 'GetStatus': grpc.unary_unary_rpc_method_handler( servicer.GetStatus, - request_deserializer=training__pb2.TrainingSessionId.FromString, + request_deserializer=utils__pb2.ModelSession.FromString, response_serializer=training__pb2.GetStatusResponse.SerializeToString, ), 'CloseTrainerSession': grpc.unary_unary_rpc_method_handler( servicer.CloseTrainerSession, - request_deserializer=training__pb2.TrainingSessionId.FromString, + request_deserializer=utils__pb2.ModelSession.FromString, response_serializer=utils__pb2.Empty.SerializeToString, ), } @@ -255,7 +239,7 @@ def Init(request, metadata=None): return grpc.experimental.unary_unary(request, target, '/training.Training/Init', training__pb2.TrainingConfig.SerializeToString, - training__pb2.TrainingSessionId.FromString, + utils__pb2.ModelSession.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -271,7 +255,7 @@ def Start(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/training.Training/Start', - training__pb2.TrainingSessionId.SerializeToString, + utils__pb2.ModelSession.SerializeToString, utils__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -288,7 +272,7 @@ def Resume(request, 
timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/training.Training/Resume', - training__pb2.TrainingSessionId.SerializeToString, + utils__pb2.ModelSession.SerializeToString, utils__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -305,7 +289,7 @@ def Pause(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/training.Training/Pause', - training__pb2.TrainingSessionId.SerializeToString, + utils__pb2.ModelSession.SerializeToString, utils__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -322,7 +306,7 @@ def StreamUpdates(request, timeout=None, metadata=None): return grpc.experimental.unary_stream(request, target, '/training.Training/StreamUpdates', - training__pb2.TrainingSessionId.SerializeToString, + utils__pb2.ModelSession.SerializeToString, training__pb2.StreamUpdateResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -339,28 +323,11 @@ def GetLogs(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/training.Training/GetLogs', - training__pb2.TrainingSessionId.SerializeToString, + utils__pb2.ModelSession.SerializeToString, training__pb2.GetLogsResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - @staticmethod - def Save(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary(request, target, '/training.Training/Save', - training__pb2.TrainingSessionId.SerializeToString, - utils__pb2.Empty.FromString, - options, channel_credentials, - insecure, call_credentials, compression, wait_for_ready, timeout, metadata) - @staticmethod def Export(request, target, @@ -373,7 +340,7 @@ def Export(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/training.Training/Export', - training__pb2.TrainingSessionId.SerializeToString, + utils__pb2.ModelSession.SerializeToString, utils__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -390,8 +357,8 @@ def Predict(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/training.Training/Predict', - training__pb2.PredictRequest.SerializeToString, - training__pb2.PredictResponse.FromString, + utils__pb2.PredictRequest.SerializeToString, + utils__pb2.PredictResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -407,7 +374,7 @@ def GetStatus(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/training.Training/GetStatus', - training__pb2.TrainingSessionId.SerializeToString, + utils__pb2.ModelSession.SerializeToString, training__pb2.GetStatusResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @@ -424,7 +391,7 @@ def CloseTrainerSession(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/training.Training/CloseTrainerSession', - training__pb2.TrainingSessionId.SerializeToString, + 
utils__pb2.ModelSession.SerializeToString, utils__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/tiktorch/proto/utils_pb2.py b/tiktorch/proto/utils_pb2.py index c0709d97..8d025ce6 100644 --- a/tiktorch/proto/utils_pb2.py +++ b/tiktorch/proto/utils_pb2.py @@ -13,7 +13,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0butils.proto\"\x07\n\x05\x45mpty\"&\n\x08NamedInt\x12\x0c\n\x04size\x18\x01 \x01(\r\x12\x0c\n\x04name\x18\x02 \x01(\t\"(\n\nNamedFloat\x12\x0c\n\x04size\x18\x01 \x01(\x02\x12\x0c\n\x04name\x18\x02 \x01(\t\"S\n\x06Tensor\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x10\n\x08tensorId\x18\x03 \x01(\t\x12\x18\n\x05shape\x18\x04 \x03(\x0b\x32\t.NamedInt\"Y\n\x06\x44\x65vice\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1e\n\x06status\x18\x02 \x01(\x0e\x32\x0e.Device.Status\"#\n\x06Status\x12\r\n\tAVAILABLE\x10\x00\x12\n\n\x06IN_USE\x10\x01\"#\n\x07\x44\x65vices\x12\x18\n\x07\x64\x65vices\x18\x01 \x03(\x0b\x32\x07.Deviceb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0butils.proto\"\x07\n\x05\x45mpty\"\x1a\n\x0cModelSession\x12\n\n\x02id\x18\x01 \x01(\t\"Q\n\x0ePredictRequest\x12%\n\x0emodelSessionId\x18\x01 \x01(\x0b\x32\r.ModelSession\x12\x18\n\x07tensors\x18\x02 \x03(\x0b\x32\x07.Tensor\"+\n\x0fPredictResponse\x12\x18\n\x07tensors\x18\x01 \x03(\x0b\x32\x07.Tensor\"&\n\x08NamedInt\x12\x0c\n\x04size\x18\x01 \x01(\r\x12\x0c\n\x04name\x18\x02 \x01(\t\"(\n\nNamedFloat\x12\x0c\n\x04size\x18\x01 \x01(\x02\x12\x0c\n\x04name\x18\x02 \x01(\t\"S\n\x06Tensor\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\x12\r\n\x05\x64type\x18\x02 \x01(\t\x12\x10\n\x08tensorId\x18\x03 \x01(\t\x12\x18\n\x05shape\x18\x04 \x03(\x0b\x32\t.NamedInt\"Y\n\x06\x44\x65vice\x12\n\n\x02id\x18\x01 \x01(\t\x12\x1e\n\x06status\x18\x02 \x01(\x0e\x32\x0e.Device.Status\"#\n\x06Status\x12\r\n\tAVAILABLE\x10\x00\x12\n\n\x06IN_USE\x10\x01\"#\n\x07\x44\x65vices\x12\x18\n\x07\x64\x65vices\x18\x01 \x03(\x0b\x32\x07.Deviceb\x06proto3') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'utils_pb2', globals()) @@ -22,16 +22,22 @@ DESCRIPTOR._options = None _EMPTY._serialized_start=15 _EMPTY._serialized_end=22 - _NAMEDINT._serialized_start=24 - _NAMEDINT._serialized_end=62 - _NAMEDFLOAT._serialized_start=64 - _NAMEDFLOAT._serialized_end=104 - _TENSOR._serialized_start=106 - _TENSOR._serialized_end=189 - _DEVICE._serialized_start=191 - _DEVICE._serialized_end=280 - _DEVICE_STATUS._serialized_start=245 - _DEVICE_STATUS._serialized_end=280 - _DEVICES._serialized_start=282 - _DEVICES._serialized_end=317 + _MODELSESSION._serialized_start=24 + _MODELSESSION._serialized_end=50 + _PREDICTREQUEST._serialized_start=52 + _PREDICTREQUEST._serialized_end=133 + _PREDICTRESPONSE._serialized_start=135 + _PREDICTRESPONSE._serialized_end=178 + _NAMEDINT._serialized_start=180 + _NAMEDINT._serialized_end=218 + _NAMEDFLOAT._serialized_start=220 + _NAMEDFLOAT._serialized_end=260 + _TENSOR._serialized_start=262 + _TENSOR._serialized_end=345 + _DEVICE._serialized_start=347 + _DEVICE._serialized_end=436 + _DEVICE_STATUS._serialized_start=401 + _DEVICE_STATUS._serialized_end=436 + _DEVICES._serialized_start=438 + _DEVICES._serialized_end=473 # @@protoc_insertion_point(module_scope) diff --git a/tiktorch/server/grpc/inference_servicer.py b/tiktorch/server/grpc/inference_servicer.py index 4318fa50..58628b31 100644 --- 
a/tiktorch/server/grpc/inference_servicer.py +++ b/tiktorch/server/grpc/inference_servicer.py @@ -1,13 +1,11 @@ import time -import grpc - from tiktorch.converters import pb_tensors_to_sample, sample_to_pb_tensors from tiktorch.proto import inference_pb2, inference_pb2_grpc, utils_pb2 from tiktorch.rpc.mp import BioModelClient from tiktorch.server.data_store import IDataStore from tiktorch.server.device_pool import IDevicePool -from tiktorch.server.grpc.utils_servicer import list_devices +from tiktorch.server.grpc.utils_servicer import get_model_session, list_devices from tiktorch.server.session.process import InputSampleValidator, start_model_session_process from tiktorch.server.session_manager import Session, SessionManager @@ -20,9 +18,7 @@ def __init__( self.__session_manager = session_manager self.__data_store = data_store - def CreateModelSession( - self, request: inference_pb2.CreateModelSessionRequest, context - ) -> inference_pb2.ModelSession: + def CreateModelSession(self, request: inference_pb2.CreateModelSessionRequest, context) -> utils_pb2.ModelSession: if request.HasField("model_uri"): if not request.model_uri.startswith("upload://"): raise NotImplementedError("Only upload:// URI supported") @@ -47,7 +43,7 @@ def CreateModelSession( self.__session_manager.close_session(session.id) raise e - return inference_pb2.ModelSession(id=session.id) + return utils_pb2.ModelSession(id=session.id) def CreateDatasetDescription( self, request: inference_pb2.CreateDatasetDescriptionRequest, context @@ -56,7 +52,7 @@ def CreateDatasetDescription( id = session.client.api.create_dataset_description(mean=request.mean, stddev=request.stddev) return inference_pb2.DatasetDescription(id=id) - def CloseModelSession(self, request: inference_pb2.ModelSession, context) -> utils_pb2.Empty: + def CloseModelSession(self, request: utils_pb2.ModelSession, context) -> utils_pb2.Empty: self.__session_manager.close_session(request.id) return utils_pb2.Empty() @@ -77,21 +73,15 @@ def GetLogs(self, request: utils_pb2.Empty, context): def ListDevices(self, request: utils_pb2.Empty, context) -> utils_pb2.Devices: return list_devices(self.__device_pool) - def Predict(self, request: inference_pb2.PredictRequest, context) -> inference_pb2.PredictResponse: + def Predict(self, request: utils_pb2.PredictRequest, context) -> utils_pb2.PredictResponse: session = self._getModelSession(context, request.modelSessionId) input_sample = pb_tensors_to_sample(request.tensors) tensor_validator = InputSampleValidator(session.client.input_specs) tensor_validator.check_tensors(input_sample) res = session.client.api.forward(input_sample).result() - return inference_pb2.PredictResponse(tensors=sample_to_pb_tensors(res)) - - def _getModelSession(self, context, modelSessionId: str) -> Session[BioModelClient]: - if not modelSessionId: - context.abort(grpc.StatusCode.FAILED_PRECONDITION, "model-session-id has not been provided by client") - - session = self.__session_manager.get(modelSessionId) + return utils_pb2.PredictResponse(tensors=sample_to_pb_tensors(res)) - if session is None: - context.abort(grpc.StatusCode.FAILED_PRECONDITION, f"model-session with id {modelSessionId} doesn't exist") - - return session + def _getModelSession(self, context, modelSessionId: utils_pb2.ModelSession) -> Session[BioModelClient]: + return get_model_session( + session_manager=self.__session_manager, model_session_id=modelSessionId, context=context + ) diff --git a/tiktorch/server/grpc/training_servicer.py b/tiktorch/server/grpc/training_servicer.py index 
c44a0ad0..b3d677b9 100644 --- a/tiktorch/server/grpc/training_servicer.py +++ b/tiktorch/server/grpc/training_servicer.py @@ -5,13 +5,12 @@ from pathlib import Path from typing import Callable, List -import grpc import torch from tiktorch.converters import pb_tensor_to_numpy, trainer_state_to_pb from tiktorch.proto import training_pb2, training_pb2_grpc, utils_pb2 from tiktorch.server.device_pool import IDevicePool -from tiktorch.server.grpc.utils_servicer import list_devices +from tiktorch.server.grpc.utils_servicer import get_model_session, list_devices from tiktorch.server.session.process import start_trainer_process from tiktorch.server.session.rpc_interface import IRPCTrainer from tiktorch.server.session_manager import Session, SessionManager @@ -51,39 +50,39 @@ def Init(self, request: training_pb2.TrainingConfig, context): self._session_manager.close_session(session.id) raise e - return training_pb2.TrainingSessionId(id=session.id) + return utils_pb2.ModelSession(id=session.id) - def Start(self, request: training_pb2.TrainingSessionId, context): - session = self._getTrainerSession(context, request.id) + def Start(self, request: utils_pb2.ModelSession, context): + session = self._getTrainerSession(context, request) session.client.start_training() return utils_pb2.Empty() def Resume(self, request, context): - session = self._getTrainerSession(context, request.id) + session = self._getTrainerSession(context, request) session.client.resume_training() return utils_pb2.Empty() - def Pause(self, request: training_pb2.TrainingSessionId, context): - session = self._getTrainerSession(context, request.id) + def Pause(self, request: utils_pb2.ModelSession, context): + session = self._getTrainerSession(context, request) session.client.pause_training() return utils_pb2.Empty() def Save(self, request: training_pb2.SaveRequest, context): - session = self._getTrainerSession(context, request.sessionId.id) + session = self._getTrainerSession(context, request.modelSessionId) session.client.save(Path(request.filePath)) return utils_pb2.Empty() def Export(self, request: training_pb2.ExportRequest, context): - session = self._getTrainerSession(context, request.sessionId.id) + session = self._getTrainerSession(context, request.modelSessionId) session.client.export(Path(request.filePath)) return utils_pb2.Empty() - def Predict(self, request: training_pb2.PredictRequest, context): - session = self._getTrainerSession(context, request.sessionId.id) + def Predict(self, request: utils_pb2.PredictRequest, context): + session = self._getTrainerSession(context, request.modelSessionId) tensors = [torch.tensor(pb_tensor_to_numpy(pb_tensor)) for pb_tensor in request.tensors] assert len(tensors) == 1, "We support models with one input" predictions = session.client.forward(tensors).result() - return training_pb2.PredictResponse(tensors=[self._tensor_to_pb(predictions)]) + return utils_pb2.PredictResponse(tensors=[self._tensor_to_pb(predictions)]) def _tensor_to_pb(self, tensor: torch.Tensor): dims = Trainer.get_axes_from_tensor(tensor) @@ -92,30 +91,25 @@ def _tensor_to_pb(self, tensor: torch.Tensor): proto_tensor = utils_pb2.Tensor(tensorId="", dtype=str(np_array.dtype), shape=shape, buffer=np_array.tobytes()) return proto_tensor - def StreamUpdates(self, request: training_pb2.TrainingSessionId, context): + def StreamUpdates(self, request: utils_pb2.ModelSession, context): raise NotImplementedError - def GetLogs(self, request: training_pb2.TrainingSessionId, context): + def GetLogs(self, request: utils_pb2.ModelSession, 
context): raise NotImplementedError - def GetStatus(self, request: training_pb2.TrainingSessionId, context): - session = self._getTrainerSession(context, request.id) + def GetStatus(self, request: utils_pb2.ModelSession, context): + session = self._getTrainerSession(context, request) state = session.client.get_state() return training_pb2.GetStatusResponse(state=trainer_state_to_pb[state]) - def CloseTrainerSession(self, request: training_pb2.TrainingSessionId, context) -> training_pb2.Empty: + def CloseTrainerSession(self, request: utils_pb2.ModelSession, context) -> utils_pb2.Empty: self._session_manager.close_session(request.id) return utils_pb2.Empty() def close_all_sessions(self): self._session_manager.close_all_sessions() - def _getTrainerSession(self, context, trainer_session_id: str) -> Session[IRPCTrainer]: - session = self._session_manager.get(trainer_session_id) - - if session is None: - context.abort( - grpc.StatusCode.FAILED_PRECONDITION, f"trainer-session with id {trainer_session_id} doesn't exist" - ) - - return session + def _getTrainerSession(self, context, model_session_id: utils_pb2.ModelSession) -> Session[IRPCTrainer]: + return get_model_session( + session_manager=self._session_manager, model_session_id=model_session_id, context=context + ) diff --git a/tiktorch/server/grpc/utils_servicer.py b/tiktorch/server/grpc/utils_servicer.py index bb23b40c..9e26d51c 100644 --- a/tiktorch/server/grpc/utils_servicer.py +++ b/tiktorch/server/grpc/utils_servicer.py @@ -1,5 +1,10 @@ +from typing import TypeVar + +import grpc + from tiktorch.proto import utils_pb2 from tiktorch.server.device_pool import DeviceStatus, IDevicePool +from tiktorch.server.session_manager import Session, SessionManager def list_devices(device_pool: IDevicePool) -> utils_pb2.Devices: @@ -16,3 +21,17 @@ def list_devices(device_pool: IDevicePool) -> utils_pb2.Devices: pb_devices.append(utils_pb2.Device(id=dev.id, status=pb_status)) return utils_pb2.Devices(devices=pb_devices) + + +T = TypeVar("T") + + +def get_model_session( + session_manager: SessionManager[T], model_session_id: utils_pb2.ModelSession, context +) -> Session[T]: + session = session_manager.get(model_session_id.id) + + if session is None: + context.abort(grpc.StatusCode.FAILED_PRECONDITION, f"model-session with id {model_session_id.id} doesn't exist") + + return session
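
Note (not part of the patch): below is a minimal client-side sketch of how the unified messages are used after this change, mirroring the tests above. Both stubs now exchange the shared utils_pb2.ModelSession and utils_pb2.PredictRequest/PredictResponse types. The server address, the device id, the "upload://" model id, the trainer YAML placeholder, and the stub class names (InferenceStub/TrainingStub, per standard gRPC codegen) are illustrative assumptions, not values taken from this patch.

import grpc
import numpy as np
import xarray as xr

from tiktorch.converters import xarray_to_pb_tensor
from tiktorch.proto import inference_pb2, inference_pb2_grpc, training_pb2, training_pb2_grpc, utils_pb2

channel = grpc.insecure_channel("localhost:5567")  # assumed address of a running tiktorch server
inference = inference_pb2_grpc.InferenceStub(channel)
training = training_pb2_grpc.TrainingStub(channel)

# Inference: CreateModelSession now returns the shared utils_pb2.ModelSession message.
model_session = inference.CreateModelSession(
    inference_pb2.CreateModelSessionRequest(model_uri="upload://<upload-id>", deviceIds=["cpu"])
)

arr = xr.DataArray(np.zeros((1, 2, 10, 20), dtype=np.float32), dims=("batch", "channel", "x", "y"))
response = inference.Predict(
    utils_pb2.PredictRequest(
        modelSessionId=model_session,  # a ModelSession message, no longer a bare string id
        tensors=[xarray_to_pb_tensor("input", arr)],
    )
)  # -> utils_pb2.PredictResponse

# Training: Init also returns utils_pb2.ModelSession, and Predict takes the very same request type.
train_session = training.Init(training_pb2.TrainingConfig(yaml_content="<trainer yaml>"))
training.Start(train_session)
training.Predict(
    utils_pb2.PredictRequest(
        modelSessionId=train_session,
        tensors=[xarray_to_pb_tensor("", arr)],  # tensor shape/axes must match the configured trainer
    )
)

inference.CloseModelSession(model_session)
training.CloseTrainerSession(train_session)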