Skip to content

Commit

Permalink
Move PredictRequest and ModelSession to utils.proto
Browse files Browse the repository at this point in the history
Since both inference and training servicers have common the concept of
id, the training session id was replaced with the model session one used
for inference. This model session protobuf interfaced moved to a
separate utils proto file.

The PredictRequest being common, can be leveraged for abstraction.
  • Loading branch information
thodkatz committed Jan 16, 2025
1 parent 941db36 commit d3d6702
Show file tree
Hide file tree
Showing 13 changed files with 211 additions and 253 deletions.
14 changes: 1 addition & 13 deletions proto/inference.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ service Inference {


message CreateDatasetDescriptionRequest {
string modelSessionId = 1;
ModelSession modelSessionId = 1;
double mean = 3;
double stddev = 4;
}
Expand Down Expand Up @@ -53,9 +53,6 @@ message NamedFloats {
}


message ModelSession {
string id = 1;
}

message LogEntry {
enum Level {
Expand All @@ -73,15 +70,6 @@ message LogEntry {
}


message PredictRequest {
string modelSessionId = 1;
string datasetId = 2;
repeated Tensor tensors = 3;
}

message PredictResponse {
repeated Tensor tensors = 1;
}


service FlightControl {
Expand Down
35 changes: 12 additions & 23 deletions proto/training.proto
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,33 @@ import "utils.proto";
service Training {
rpc ListDevices(Empty) returns (Devices) {}

rpc Init(TrainingConfig) returns (TrainingSessionId) {}
rpc Init(TrainingConfig) returns (ModelSession) {}

rpc Start(TrainingSessionId) returns (Empty) {}
rpc Start(ModelSession) returns (Empty) {}

rpc Resume(TrainingSessionId) returns (Empty) {}
rpc Resume(ModelSession) returns (Empty) {}

rpc Pause(TrainingSessionId) returns (Empty) {}
rpc Pause(ModelSession) returns (Empty) {}

rpc StreamUpdates(TrainingSessionId) returns (stream StreamUpdateResponse) {}
rpc StreamUpdates(ModelSession) returns (stream StreamUpdateResponse) {}

rpc GetLogs(TrainingSessionId) returns (GetLogsResponse) {}
rpc GetLogs(ModelSession) returns (GetLogsResponse) {}

rpc Save(TrainingSessionId) returns (Empty) {}

rpc Export(TrainingSessionId) returns (Empty) {}
rpc Export(ModelSession) returns (Empty) {}

rpc Predict(PredictRequest) returns (PredictResponse) {}

rpc GetStatus(TrainingSessionId) returns (GetStatusResponse) {}
rpc GetStatus(ModelSession) returns (GetStatusResponse) {}

rpc CloseTrainerSession(TrainingSessionId) returns (Empty) {}
rpc CloseTrainerSession(ModelSession) returns (Empty) {}
}

message TrainingSessionId {

message GetBestModelIdxResponse {
string id = 1;
}


message Logs {
enum ModelPhase {
Train = 0;
Expand All @@ -59,17 +59,6 @@ message GetLogsResponse {
}



message PredictRequest {
repeated Tensor tensors = 1;
TrainingSessionId sessionId = 2;
}


message PredictResponse {
repeated Tensor tensors = 1;
}

message ValidationResponse {
double validation_score_average = 1;
}
Expand Down
13 changes: 13 additions & 0 deletions proto/utils.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@ syntax = "proto3";

message Empty {}

message ModelSession {
string id = 1;
}

message PredictRequest {
ModelSession modelSessionId = 1;
repeated Tensor tensors = 2;
}

message PredictResponse {
repeated Tensor tensors = 1;
}

message NamedInt {
uint32 size = 1;
string name = 2;
Expand Down
43 changes: 22 additions & 21 deletions tests/test_server/test_grpc/test_inference_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ def test_model_session_creation_using_non_existent_upload(self, grpc_stub):

def test_predict_call_fails_without_specifying_model_session_id(self, grpc_stub):
with pytest.raises(grpc.RpcError) as e:
grpc_stub.Predict(inference_pb2.PredictRequest())
grpc_stub.Predict(utils_pb2.PredictRequest())

assert grpc.StatusCode.FAILED_PRECONDITION == e.value.code()
assert "model-session-id has not been provided" in e.value.details()
assert "model-session with id doesn't exist" in e.value.details()

def test_model_init_failed_close_session(self, bioimage_model_explicit_add_one_siso_v5, grpc_stub):
"""
Expand Down Expand Up @@ -169,17 +169,18 @@ def test_returns_ack_message(self, bioimage_model_explicit_add_one_siso_v5, grpc
class TestForwardPass:
def test_call_fails_with_unknown_model_session_id(self, grpc_stub):
with pytest.raises(grpc.RpcError) as e:
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId="myid1"))
model_id = utils_pb2.ModelSession(id="myid")
grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id))
assert grpc.StatusCode.FAILED_PRECONDITION == e.value.code()
assert "model-session with id myid1 doesn't exist" in e.value.details()
assert "model-session with id myid doesn't exist" in e.value.details()

def test_call_predict_valid_explicit(self, grpc_stub, bioimage_model_explicit_add_one_siso_v5):
model_bytes = bioimage_model_explicit_add_one_siso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=("batch", "channel", "x", "y"))
input_tensor_id = "input"
input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)]
res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
res = grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
assert len(res.tensors) == 1
pb_tensor = res.tensors[0]
assert pb_tensor.tensorId == "output"
Expand All @@ -188,11 +189,11 @@ def test_call_predict_valid_explicit(self, grpc_stub, bioimage_model_explicit_ad

def test_call_predict_valid_explicit_v4(self, grpc_stub, bioimage_model_add_one_v4):
model_bytes = bioimage_model_add_one_v4
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=("batch", "channel", "x", "y"))
input_tensor_id = "input"
input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)]
res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
res = grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
assert len(res.tensors) == 1
pb_tensor = res.tensors[0]
assert pb_tensor.tensorId == "output"
Expand All @@ -201,16 +202,16 @@ def test_call_predict_valid_explicit_v4(self, grpc_stub, bioimage_model_add_one_

def test_call_predict_invalid_shape_explicit(self, grpc_stub, bioimage_model_explicit_add_one_siso_v5):
model_bytes = bioimage_model_explicit_add_one_siso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("batch", "channel", "x", "y"))
input_tensors = [converters.xarray_to_pb_tensor("input", arr)]
with pytest.raises(grpc.RpcError) as error:
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
assert error.value.details().startswith("Exception calling application: Incompatible axis")

def test_call_predict_multiple_inputs_with_reference(self, grpc_stub, bioimage_model_add_one_miso_v5):
model_bytes = bioimage_model_add_one_miso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))

arr1 = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=("batch", "channel", "x", "y"))
input_tensor_id1 = "input1"
Expand All @@ -227,8 +228,8 @@ def test_call_predict_multiple_inputs_with_reference(self, grpc_stub, bioimage_m
converters.xarray_to_pb_tensor(tensor_id, arr) for tensor_id, arr in zip(input_tensor_ids, tensors_arr)
]

res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.CloseModelSession(model)
res = grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
grpc_stub.CloseModelSession(model_id)
assert len(res.tensors) == 1
pb_tensor = res.tensors[0]
assert pb_tensor.tensorId == "output"
Expand All @@ -238,33 +239,33 @@ def test_call_predict_multiple_inputs_with_reference(self, grpc_stub, bioimage_m
@pytest.mark.parametrize("shape", [(1, 2, 10, 20), (1, 2, 12, 20), (1, 2, 10, 23), (1, 2, 12, 23)])
def test_call_predict_valid_shape_parameterized(self, grpc_stub, shape, bioimage_model_param_add_one_siso_v5):
model_bytes = bioimage_model_param_add_one_siso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("batch", "channel", "x", "y"))
input_tensor_id = "input"
input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)]
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))

@pytest.mark.parametrize(
"shape",
[(1, 1, 10, 20), (1, 2, 8, 20), (1, 2, 11, 20), (1, 2, 10, 21)],
)
def test_call_predict_invalid_shape_parameterized(self, grpc_stub, shape, bioimage_model_param_add_one_siso_v5):
model_bytes = bioimage_model_param_add_one_siso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("batch", "channel", "x", "y"))
input_tensor_id = "input"
input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)]
with pytest.raises(grpc.RpcError) as error:
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
assert error.value.details().startswith("Exception calling application: Incompatible axis")

def test_call_predict_invalid_tensor_ids(self, grpc_stub, bioimage_model_explicit_add_one_siso_v5):
model_bytes = bioimage_model_explicit_add_one_siso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=("batch", "channel", "x", "y"))
input_tensors = [converters.xarray_to_pb_tensor("invalidTensorName", arr)]
with pytest.raises(grpc.RpcError) as error:
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
assert error.value.details().startswith("Exception calling application: Spec 'invalidTensorName' doesn't exist")

@pytest.mark.parametrize(
Expand All @@ -278,10 +279,10 @@ def test_call_predict_invalid_tensor_ids(self, grpc_stub, bioimage_model_explici
)
def test_call_predict_invalid_axes(self, grpc_stub, axes, bioimage_model_explicit_add_one_siso_v5):
model_bytes = bioimage_model_explicit_add_one_siso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=axes)
input_tensor_id = "input"
input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)]
with pytest.raises(grpc.RpcError) as error:
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
assert error.value.details().startswith("Exception calling application: Incompatible axes names")
11 changes: 6 additions & 5 deletions tests/test_server/test_grpc/test_training_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def test_init_failed_then_devices_are_released(self, grpc_stub):
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
assert init_response.id is not None


def test_start_training_success(self):
"""
Test starting training after successful initialization.
Expand Down Expand Up @@ -348,7 +349,7 @@ def test_start_training_without_init(self, grpc_stub):
with pytest.raises(grpc.RpcError) as excinfo:
grpc_stub.Start(utils_pb2.Empty())
assert excinfo.value.code() == grpc.StatusCode.FAILED_PRECONDITION
assert "trainer-session with id doesn't exist" in excinfo.value.details()
assert "model-session with id doesn't exist" in excinfo.value.details()

def test_recover_training_failed(self):
class MockedExceptionTrainer:
Expand Down Expand Up @@ -484,7 +485,7 @@ def test_close_trainer_session_twice(self, grpc_stub):

def test_forward_while_running(self, grpc_stub):
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
training_session_id = utils_pb2.ModelSession(id=init_response.id)

grpc_stub.Start(training_session_id)

Expand All @@ -495,7 +496,7 @@ def test_forward_while_running(self, grpc_stub):
data = np.random.rand(*shape).astype(np.float32)
xarray_data = xr.DataArray(data, dims=("b", "c", "z", "y", "x"))
pb_tensor = xarray_to_pb_tensor(tensor_id="", array=xarray_data)
predict_request = training_pb2.PredictRequest(sessionId=training_session_id, tensors=[pb_tensor])
predict_request = utils_pb2.PredictRequest(modelSessionId=training_session_id, tensors=[pb_tensor])

response = grpc_stub.Predict(predict_request)

Expand All @@ -510,7 +511,7 @@ def test_forward_while_running(self, grpc_stub):

def test_forward_while_paused(self, grpc_stub):
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
training_session_id = utils_pb2.ModelSession(id=init_response.id)

grpc_stub.Start(training_session_id)

Expand All @@ -521,7 +522,7 @@ def test_forward_while_paused(self, grpc_stub):
data = np.random.rand(*shape).astype(np.float32)
xarray_data = xr.DataArray(data, dims=("b", "c", "z", "y", "x"))
pb_tensor = xarray_to_pb_tensor(tensor_id="", array=xarray_data)
predict_request = training_pb2.PredictRequest(sessionId=training_session_id, tensors=[pb_tensor])
predict_request = utils_pb2.PredictRequest(modelSessionId=training_session_id, tensors=[pb_tensor])

grpc_stub.Pause(training_session_id)

Expand Down
Loading

0 comments on commit d3d6702

Please sign in to comment.