Skip to content

Commit

Permalink
Merge pull request #227 from thodkatz/add-forward-to-training-servicer
Browse files Browse the repository at this point in the history
Add forward to training servicer
  • Loading branch information
thodkatz authored Jan 20, 2025
2 parents 57df547 + f787cc6 commit 90b4ce0
Show file tree
Hide file tree
Showing 19 changed files with 412 additions and 262 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ jobs:
os: [macos-latest, windows-latest, ubuntu-latest]
runs-on: ${{ matrix.os }}
env:
TIKTORCH_PACKAGE_NAME: tiktorch-${{ needs.conda-noarch-build.outputs.version }}-py_0.tar.bz2
TIKTORCH_PACKAGE_NAME: tiktorch-${{ needs.conda-noarch-build.outputs.version }}-py_0.conda
steps:
# Use GNU tar instead of BSD tar on Windows
- name: "Use GNU tar instead of BSD tar"
Expand Down
14 changes: 1 addition & 13 deletions proto/inference.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ service Inference {


message CreateDatasetDescriptionRequest {
string modelSessionId = 1;
ModelSession modelSessionId = 1;
double mean = 3;
double stddev = 4;
}
Expand Down Expand Up @@ -53,9 +53,6 @@ message NamedFloats {
}


message ModelSession {
string id = 1;
}

message LogEntry {
enum Level {
Expand All @@ -73,15 +70,6 @@ message LogEntry {
}


message PredictRequest {
string modelSessionId = 1;
string datasetId = 2;
repeated Tensor tensors = 3;
}

message PredictResponse {
repeated Tensor tensors = 1;
}


service FlightControl {
Expand Down
35 changes: 12 additions & 23 deletions proto/training.proto
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,33 @@ import "utils.proto";
service Training {
rpc ListDevices(Empty) returns (Devices) {}

rpc Init(TrainingConfig) returns (TrainingSessionId) {}
rpc Init(TrainingConfig) returns (ModelSession) {}

rpc Start(TrainingSessionId) returns (Empty) {}
rpc Start(ModelSession) returns (Empty) {}

rpc Resume(TrainingSessionId) returns (Empty) {}
rpc Resume(ModelSession) returns (Empty) {}

rpc Pause(TrainingSessionId) returns (Empty) {}
rpc Pause(ModelSession) returns (Empty) {}

rpc StreamUpdates(TrainingSessionId) returns (stream StreamUpdateResponse) {}
rpc StreamUpdates(ModelSession) returns (stream StreamUpdateResponse) {}

rpc GetLogs(TrainingSessionId) returns (GetLogsResponse) {}
rpc GetLogs(ModelSession) returns (GetLogsResponse) {}

rpc Save(TrainingSessionId) returns (Empty) {}

rpc Export(TrainingSessionId) returns (Empty) {}
rpc Export(ModelSession) returns (Empty) {}

rpc Predict(PredictRequest) returns (PredictResponse) {}

rpc GetStatus(TrainingSessionId) returns (GetStatusResponse) {}
rpc GetStatus(ModelSession) returns (GetStatusResponse) {}

rpc CloseTrainerSession(TrainingSessionId) returns (Empty) {}
rpc CloseTrainerSession(ModelSession) returns (Empty) {}
}

message TrainingSessionId {

message GetBestModelIdxResponse {
string id = 1;
}


message Logs {
enum ModelPhase {
Train = 0;
Expand All @@ -59,17 +59,6 @@ message GetLogsResponse {
}



message PredictRequest {
repeated Tensor tensors = 1;
TrainingSessionId sessionId = 2;
}


message PredictResponse {
repeated Tensor tensors = 1;
}

message ValidationResponse {
double validation_score_average = 1;
}
Expand Down
13 changes: 13 additions & 0 deletions proto/utils.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@ syntax = "proto3";

message Empty {}

message ModelSession {
string id = 1;
}

message PredictRequest {
ModelSession modelSessionId = 1;
repeated Tensor tensors = 2;
}

message PredictResponse {
repeated Tensor tensors = 1;
}

message NamedInt {
uint32 size = 1;
string name = 2;
Expand Down
43 changes: 22 additions & 21 deletions tests/test_server/test_grpc/test_inference_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ def test_model_session_creation_using_non_existent_upload(self, grpc_stub):

def test_predict_call_fails_without_specifying_model_session_id(self, grpc_stub):
with pytest.raises(grpc.RpcError) as e:
grpc_stub.Predict(inference_pb2.PredictRequest())
grpc_stub.Predict(utils_pb2.PredictRequest())

assert grpc.StatusCode.FAILED_PRECONDITION == e.value.code()
assert "model-session-id has not been provided" in e.value.details()
assert "model-session with id doesn't exist" in e.value.details()

def test_model_init_failed_close_session(self, bioimage_model_explicit_add_one_siso_v5, grpc_stub):
"""
Expand Down Expand Up @@ -169,17 +169,18 @@ def test_returns_ack_message(self, bioimage_model_explicit_add_one_siso_v5, grpc
class TestForwardPass:
def test_call_fails_with_unknown_model_session_id(self, grpc_stub):
with pytest.raises(grpc.RpcError) as e:
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId="myid1"))
model_id = utils_pb2.ModelSession(id="myid")
grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id))
assert grpc.StatusCode.FAILED_PRECONDITION == e.value.code()
assert "model-session with id myid1 doesn't exist" in e.value.details()
assert "model-session with id myid doesn't exist" in e.value.details()

def test_call_predict_valid_explicit(self, grpc_stub, bioimage_model_explicit_add_one_siso_v5):
model_bytes = bioimage_model_explicit_add_one_siso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=("batch", "channel", "x", "y"))
input_tensor_id = "input"
input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)]
res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
res = grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
assert len(res.tensors) == 1
pb_tensor = res.tensors[0]
assert pb_tensor.tensorId == "output"
Expand All @@ -188,11 +189,11 @@ def test_call_predict_valid_explicit(self, grpc_stub, bioimage_model_explicit_ad

def test_call_predict_valid_explicit_v4(self, grpc_stub, bioimage_model_add_one_v4):
model_bytes = bioimage_model_add_one_v4
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=("batch", "channel", "x", "y"))
input_tensor_id = "input"
input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)]
res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
res = grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
assert len(res.tensors) == 1
pb_tensor = res.tensors[0]
assert pb_tensor.tensorId == "output"
Expand All @@ -201,16 +202,16 @@ def test_call_predict_valid_explicit_v4(self, grpc_stub, bioimage_model_add_one_

def test_call_predict_invalid_shape_explicit(self, grpc_stub, bioimage_model_explicit_add_one_siso_v5):
model_bytes = bioimage_model_explicit_add_one_siso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("batch", "channel", "x", "y"))
input_tensors = [converters.xarray_to_pb_tensor("input", arr)]
with pytest.raises(grpc.RpcError) as error:
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
assert error.value.details().startswith("Exception calling application: Incompatible axis")

def test_call_predict_multiple_inputs_with_reference(self, grpc_stub, bioimage_model_add_one_miso_v5):
model_bytes = bioimage_model_add_one_miso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))

arr1 = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=("batch", "channel", "x", "y"))
input_tensor_id1 = "input1"
Expand All @@ -227,8 +228,8 @@ def test_call_predict_multiple_inputs_with_reference(self, grpc_stub, bioimage_m
converters.xarray_to_pb_tensor(tensor_id, arr) for tensor_id, arr in zip(input_tensor_ids, tensors_arr)
]

res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.CloseModelSession(model)
res = grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
grpc_stub.CloseModelSession(model_id)
assert len(res.tensors) == 1
pb_tensor = res.tensors[0]
assert pb_tensor.tensorId == "output"
Expand All @@ -238,33 +239,33 @@ def test_call_predict_multiple_inputs_with_reference(self, grpc_stub, bioimage_m
@pytest.mark.parametrize("shape", [(1, 2, 10, 20), (1, 2, 12, 20), (1, 2, 10, 23), (1, 2, 12, 23)])
def test_call_predict_valid_shape_parameterized(self, grpc_stub, shape, bioimage_model_param_add_one_siso_v5):
model_bytes = bioimage_model_param_add_one_siso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("batch", "channel", "x", "y"))
input_tensor_id = "input"
input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)]
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))

@pytest.mark.parametrize(
"shape",
[(1, 1, 10, 20), (1, 2, 8, 20), (1, 2, 11, 20), (1, 2, 10, 21)],
)
def test_call_predict_invalid_shape_parameterized(self, grpc_stub, shape, bioimage_model_param_add_one_siso_v5):
model_bytes = bioimage_model_param_add_one_siso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("batch", "channel", "x", "y"))
input_tensor_id = "input"
input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)]
with pytest.raises(grpc.RpcError) as error:
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
assert error.value.details().startswith("Exception calling application: Incompatible axis")

def test_call_predict_invalid_tensor_ids(self, grpc_stub, bioimage_model_explicit_add_one_siso_v5):
model_bytes = bioimage_model_explicit_add_one_siso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=("batch", "channel", "x", "y"))
input_tensors = [converters.xarray_to_pb_tensor("invalidTensorName", arr)]
with pytest.raises(grpc.RpcError) as error:
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
assert error.value.details().startswith("Exception calling application: Spec 'invalidTensorName' doesn't exist")

@pytest.mark.parametrize(
Expand All @@ -278,10 +279,10 @@ def test_call_predict_invalid_tensor_ids(self, grpc_stub, bioimage_model_explici
)
def test_call_predict_invalid_axes(self, grpc_stub, axes, bioimage_model_explicit_add_one_siso_v5):
model_bytes = bioimage_model_explicit_add_one_siso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=axes)
input_tensor_id = "input"
input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)]
with pytest.raises(grpc.RpcError) as error:
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
assert error.value.details().startswith("Exception calling application: Incompatible axes names")
84 changes: 82 additions & 2 deletions tests/test_server/test_grpc/test_training_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import h5py
import numpy as np
import pytest
import xarray as xr

from tiktorch.converters import pb_state_to_trainer, trainer_state_to_pb
from tiktorch.converters import pb_state_to_trainer, pb_tensor_to_xarray, trainer_state_to_pb, xarray_to_pb_tensor
from tiktorch.proto import training_pb2, training_pb2_grpc, utils_pb2
from tiktorch.server.device_pool import TorchDevicePool
from tiktorch.server.grpc import training_servicer
Expand Down Expand Up @@ -347,7 +348,7 @@ def test_start_training_without_init(self, grpc_stub):
with pytest.raises(grpc.RpcError) as excinfo:
grpc_stub.Start(utils_pb2.Empty())
assert excinfo.value.code() == grpc.StatusCode.FAILED_PRECONDITION
assert "trainer-session with id doesn't exist" in excinfo.value.details()
assert "model-session with id doesn't exist" in excinfo.value.details()

def test_recover_training_failed(self):
class MockedExceptionTrainer:
Expand Down Expand Up @@ -481,6 +482,85 @@ def test_close_trainer_session_twice(self, grpc_stub):
grpc_stub.CloseTrainerSession(training_session_id)
assert "Unknown session" in excinfo.value.details()

@pytest.mark.parametrize(
"dims, shape",
[
(
("b", "c", "z", "y", "x"),
(5, 3, 1, 128, 128),
),
(("b", "z", "y", "x", "c"), (5, 1, 128, 128, 3)), # order of input may be different than the one expected
],
)
def test_forward_while_running(self, grpc_stub, dims, shape):
training_session_id = grpc_stub.Init(
training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())
)

grpc_stub.Start(training_session_id)

data = np.random.rand(*shape).astype(np.float32)
xarray_data = xr.DataArray(data, dims=dims)
pb_tensor = xarray_to_pb_tensor(tensor_id="input", array=xarray_data)
predict_request = utils_pb2.PredictRequest(modelSessionId=training_session_id, tensors=[pb_tensor])

response = grpc_stub.Predict(predict_request)

# assert that predict command has retained the init state (e.g. RUNNING)
self.assert_state(grpc_stub, training_session_id, TrainerState.RUNNING)

predicted_tensors = [pb_tensor_to_xarray(pb_tensor) for pb_tensor in response.tensors]
assert len(predicted_tensors) == 1
predicted_tensor = predicted_tensors[0]
assert predicted_tensor.dims == ("b", "c", "z", "y", "x")
out_channels_unet2d = 2
assert predicted_tensor.shape == (5, out_channels_unet2d, 1, 128, 128)

def test_forward_while_paused(self, grpc_stub):
training_session_id = grpc_stub.Init(
training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())
)

grpc_stub.Start(training_session_id)

batch = 5
in_channels_unet2d = 3
out_channels_unet2d = 2
shape = (batch, in_channels_unet2d, 1, 128, 128)
data = np.random.rand(*shape).astype(np.float32)
xarray_data = xr.DataArray(data, dims=("b", "c", "z", "y", "x"))
pb_tensor = xarray_to_pb_tensor(tensor_id="input", array=xarray_data)
predict_request = utils_pb2.PredictRequest(modelSessionId=training_session_id, tensors=[pb_tensor])

grpc_stub.Pause(training_session_id)

response = grpc_stub.Predict(predict_request)

# assert that predict command has retained the init state (e.g. PAUSED)
self.assert_state(grpc_stub, training_session_id, TrainerState.PAUSED)

predicted_tensors = [pb_tensor_to_xarray(pb_tensor) for pb_tensor in response.tensors]
assert len(predicted_tensors) == 1
predicted_tensor = predicted_tensors[0]
assert predicted_tensor.dims == ("b", "c", "z", "y", "x")
assert predicted_tensor.shape == (batch, out_channels_unet2d, 1, 128, 128)

def test_forward_invalid_dims(self, grpc_stub):
training_session_id = grpc_stub.Init(
training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment())
)

grpc_stub.Start(training_session_id)

shape = (10, 11, 12)
data = np.random.rand(*shape).astype(np.float32)
xarray_data = xr.DataArray(data, dims=("dim1", "dim2", "dim3"))
pb_tensor = xarray_to_pb_tensor(tensor_id="input", array=xarray_data)
predict_request = utils_pb2.PredictRequest(modelSessionId=training_session_id, tensors=[pb_tensor])
with pytest.raises(grpc.RpcError) as excinfo:
grpc_stub.Predict(predict_request)
assert "Tensor dims should be" in excinfo.value.details()

def test_close_session(self, grpc_stub):
"""
Test closing a training session.
Expand Down
Loading

0 comments on commit 90b4ce0

Please sign in to comment.