Fixed the priority of parameters defined in register curl cmd vs model-config.yaml (#2858)

* fmt

* format java

* update default value

---------

Co-authored-by: Ankith Gunapal <[email protected]>
lxning and agunapal authored Dec 20, 2023
1 parent 814ef96 commit 77ca411
Showing 6 changed files with 207 additions and 47 deletions.
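Why the change: previously batch_size and max_batch_delay in RegisterModelRequest always defaulted to 1 and 100, so the frontend could not distinguish "parameter omitted in the register curl cmd" from "caller explicitly passed the default", and values in the archive's model-config.yaml were silently overridden. The commit switches the request defaults to negated sentinels (-1 * DEFAULT_*) and only consults model-config.yaml (then config.properties, then the built-in default) when the sentinel survives. A minimal runnable sketch of that priority rule; resolveBatchSize is an illustrative helper, not TorchServe API:

```java
/**
 * Hedged sketch of the priority rule this commit implements:
 * register-API parameter > model-config.yaml > built-in default.
 */
public class ParamPrioritySketch {
    static final int DEFAULT_BATCH_SIZE = 1;

    // requestValue carries -1 * DEFAULT_BATCH_SIZE when the register call omitted it.
    static int resolveBatchSize(int requestValue, Integer yamlValue) {
        if (requestValue != -1 * DEFAULT_BATCH_SIZE) {
            return requestValue;       // 1) explicit register parameter wins
        }
        if (yamlValue != null && yamlValue > 0) {
            return yamlValue;          // 2) otherwise model-config.yaml applies
        }
        return DEFAULT_BATCH_SIZE;     // 3) otherwise the built-in default
    }

    public static void main(String[] args) {
        System.out.println(resolveBatchSize(2, 4));     // 2: register parameter wins
        System.out.println(resolveBatchSize(-1, 4));    // 4: yaml value used
        System.out.println(resolveBatchSize(-1, null)); // 1: built-in default
    }
}
```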
frontend/server/src/main/java/org/pytorch/serve/http/messages/RegisterModelRequest.java
@@ -8,6 +8,9 @@
 
 /** Register Model Request for Model server */
 public class RegisterModelRequest {
+    public static final Integer DEFAULT_BATCH_SIZE = 1;
+    public static final Integer DEFAULT_MAX_BATCH_DELAY = 100;
+
     @SerializedName("model_name")
     private String modelName;
 
@@ -42,15 +45,18 @@ public RegisterModelRequest(QueryStringDecoder decoder) {
         modelName = NettyUtils.getParameter(decoder, "model_name", null);
         runtime = NettyUtils.getParameter(decoder, "runtime", null);
         handler = NettyUtils.getParameter(decoder, "handler", null);
-        batchSize = NettyUtils.getIntParameter(decoder, "batch_size", 1);
-        maxBatchDelay = NettyUtils.getIntParameter(decoder, "max_batch_delay", 100);
+        batchSize = NettyUtils.getIntParameter(decoder, "batch_size", -1 * DEFAULT_BATCH_SIZE);
+        maxBatchDelay =
+                NettyUtils.getIntParameter(
+                        decoder, "max_batch_delay", -1 * DEFAULT_MAX_BATCH_DELAY);
         initialWorkers =
                 NettyUtils.getIntParameter(
                         decoder,
                         "initial_workers",
                         ConfigManager.getInstance().getConfiguredDefaultWorkersPerModel());
         synchronous = Boolean.parseBoolean(NettyUtils.getParameter(decoder, "synchronous", "true"));
-        responseTimeout = NettyUtils.getIntParameter(decoder, "response_timeout", -1);
+        responseTimeout =
+                NettyUtils.getIntParameter(decoder, "response_timeout", -1);
         modelUrl = NettyUtils.getParameter(decoder, "url", null);
         s3SseKms = Boolean.parseBoolean(NettyUtils.getParameter(decoder, "s3_sse_kms", "false"));
     }
@@ -59,8 +65,10 @@ public RegisterModelRequest(org.pytorch.serve.grpc.management.RegisterModelRequest request) {
         modelName = GRPCUtils.getRegisterParam(request.getModelName(), null);
         runtime = GRPCUtils.getRegisterParam(request.getRuntime(), null);
         handler = GRPCUtils.getRegisterParam(request.getHandler(), null);
-        batchSize = GRPCUtils.getRegisterParam(request.getBatchSize(), 1);
-        maxBatchDelay = GRPCUtils.getRegisterParam(request.getMaxBatchDelay(), 100);
+        batchSize = GRPCUtils.getRegisterParam(request.getBatchSize(), -1 * DEFAULT_BATCH_SIZE);
+        maxBatchDelay =
+                GRPCUtils.getRegisterParam(
+                        request.getMaxBatchDelay(), -1 * DEFAULT_MAX_BATCH_DELAY);
         initialWorkers =
                 GRPCUtils.getRegisterParam(
                         request.getInitialWorkers(),
@@ -72,8 +80,8 @@ public RegisterModelRequest(org.pytorch.serve.grpc.management.RegisterModelRequest request) {
     }
 
     public RegisterModelRequest() {
-        batchSize = 1;
-        maxBatchDelay = 100;
+        batchSize = -1 * DEFAULT_BATCH_SIZE;
+        maxBatchDelay = -1 * DEFAULT_MAX_BATCH_DELAY;
         synchronous = true;
        initialWorkers = ConfigManager.getInstance().getConfiguredDefaultWorkersPerModel();
         responseTimeout = -1;
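The sentinel trick above works because the query-string helper returns the caller-supplied default whenever the parameter is absent. A hedged stand-in for NettyUtils.getIntParameter (only the call sites above are from the commit; this body is an assumption about the helper's behavior):

```java
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for NettyUtils.getIntParameter; the real implementation
// may differ, but the hunks above rely only on this defaulting behavior.
public class GetIntParameterSketch {
    static int getIntParameter(Map<String, List<String>> params, String name, int def) {
        List<String> values = params.get(name);
        if (values == null || values.isEmpty()) {
            return def; // absent parameter -> default, so a sentinel default marks "not supplied"
        }
        try {
            return Integer.parseInt(values.get(0));
        } catch (NumberFormatException e) {
            return def; // unparsable values also fall back to the default
        }
    }

    public static void main(String[] args) {
        System.out.println(getIntParameter(Map.of(), "batch_size", -1)); // -1 (sentinel kept)
        System.out.println(
                getIntParameter(Map.of("batch_size", List.of("2")), "batch_size", -1)); // 2
    }
}
```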
frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java
@@ -30,6 +30,7 @@
 import org.pytorch.serve.archive.model.ModelVersionNotFoundException;
 import org.pytorch.serve.http.ConflictStatusException;
 import org.pytorch.serve.http.InvalidModelVersionException;
+import org.pytorch.serve.http.messages.RegisterModelRequest;
 import org.pytorch.serve.job.Job;
 import org.pytorch.serve.util.ConfigManager;
 import org.pytorch.serve.util.messages.EnvironmentUtils;
@@ -300,43 +301,47 @@ private Model createModel(
             boolean isWorkflowModel) {
         Model model = new Model(archive, configManager.getJobQueueSize());
 
-        if (archive.getModelConfig() != null) {
-            int marBatchSize = archive.getModelConfig().getBatchSize();
-            batchSize =
-                    marBatchSize > 0
-                            ? marBatchSize
-                            : configManager.getJsonIntValue(
-                                    archive.getModelName(),
-                                    archive.getModelVersion(),
-                                    Model.BATCH_SIZE,
-                                    batchSize);
-        } else {
-            batchSize =
-                    configManager.getJsonIntValue(
-                            archive.getModelName(),
-                            archive.getModelVersion(),
-                            Model.BATCH_SIZE,
-                            batchSize);
+        if (batchSize == -1 * RegisterModelRequest.DEFAULT_BATCH_SIZE) {
+            if (archive.getModelConfig() != null) {
+                int marBatchSize = archive.getModelConfig().getBatchSize();
+                batchSize =
+                        marBatchSize > 0
+                                ? marBatchSize
+                                : configManager.getJsonIntValue(
+                                        archive.getModelName(),
+                                        archive.getModelVersion(),
+                                        Model.BATCH_SIZE,
+                                        RegisterModelRequest.DEFAULT_BATCH_SIZE);
+            } else {
+                batchSize =
+                        configManager.getJsonIntValue(
+                                archive.getModelName(),
+                                archive.getModelVersion(),
+                                Model.BATCH_SIZE,
+                                RegisterModelRequest.DEFAULT_BATCH_SIZE);
+            }
         }
         model.setBatchSize(batchSize);
 
-        if (archive.getModelConfig() != null) {
-            int marMaxBatchDelay = archive.getModelConfig().getMaxBatchDelay();
-            maxBatchDelay =
-                    marMaxBatchDelay > 0
-                            ? marMaxBatchDelay
-                            : configManager.getJsonIntValue(
-                                    archive.getModelName(),
-                                    archive.getModelVersion(),
-                                    Model.MAX_BATCH_DELAY,
-                                    maxBatchDelay);
-        } else {
-            maxBatchDelay =
-                    configManager.getJsonIntValue(
-                            archive.getModelName(),
-                            archive.getModelVersion(),
-                            Model.MAX_BATCH_DELAY,
-                            maxBatchDelay);
+        if (maxBatchDelay == -1 * RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY) {
+            if (archive.getModelConfig() != null) {
+                int marMaxBatchDelay = archive.getModelConfig().getMaxBatchDelay();
+                maxBatchDelay =
+                        marMaxBatchDelay > 0
+                                ? marMaxBatchDelay
+                                : configManager.getJsonIntValue(
+                                        archive.getModelName(),
+                                        archive.getModelVersion(),
+                                        Model.MAX_BATCH_DELAY,
+                                        RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY);
+            } else {
+                maxBatchDelay =
+                        configManager.getJsonIntValue(
+                                archive.getModelName(),
+                                archive.getModelVersion(),
+                                Model.MAX_BATCH_DELAY,
+                                RegisterModelRequest.DEFAULT_MAX_BATCH_DELAY);
+            }
         }
         model.setMaxBatchDelay(maxBatchDelay);
 
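For reference, configManager.getJsonIntValue in the fallback chain reads per-model settings from the models JSON in config.properties, e.g. models={"foo":{"1.0":{"batchSize":4}}}. A hedged, Gson-based sketch of what that lookup amounts to (the parsing code is illustrative; only the model -> version -> key nesting follows TorchServe's documented config format):

```java
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;

/** Illustrative reimplementation of the getJsonIntValue fallback used in createModel. */
public class JsonIntValueSketch {
    static int getJsonIntValue(String modelsJson, String model, String version, String key, int def) {
        JsonObject root = JsonParser.parseString(modelsJson).getAsJsonObject();
        if (!root.has(model) || !root.getAsJsonObject(model).has(version)) {
            return def; // no per-model entry in config.properties -> caller's default
        }
        JsonObject settings = root.getAsJsonObject(model).getAsJsonObject(version);
        return settings.has(key) ? settings.get(key).getAsInt() : def;
    }

    public static void main(String[] args) {
        String models = "{\"foo\":{\"1.0\":{\"batchSize\":4}}}";
        System.out.println(getJsonIntValue(models, "foo", "1.0", "batchSize", 1));       // 4
        System.out.println(getJsonIntValue(models, "foo", "1.0", "maxBatchDelay", 100)); // 100
    }
}
```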
frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java
@@ -207,6 +207,7 @@ private void addThreads(
             List<WorkerThread> threads, Model model, int count, CompletableFuture<Integer> future) {
         WorkerStateListener listener = new WorkerStateListener(future, count);
         int maxGpu = model.getNumCores();
+        int stride = model.getParallelLevel() > 0 ? model.getParallelLevel() : 1;
         for (int i = 0; i < count; ++i) {
             int gpuId = -1;
 
@@ -215,12 +216,7 @@ private void addThreads(
                 gpuId =
                         model.getGpuCounter()
                                 .getAndAccumulate(
-                                        maxGpu,
-                                        (prev, maxGpuId) ->
-                                                (prev + model.getParallelLevel() > 0
-                                                                ? model.getParallelLevel()
-                                                                : 1)
-                                                        % maxGpuId);
+                                        stride, (prev, myStride) -> (prev + myStride) % maxGpu);
                 if (model.getParallelLevel() == 0) {
                     gpuId = model.getDeviceIds().get(gpuId);
                 }
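Besides hoisting the stride, this hunk fixes the accumulator itself: in the old lambda, operator precedence made the ternary condition `prev + model.getParallelLevel() > 0`, so the expression evaluated to the parallel level (or 1) modulo maxGpu and the counter never actually advanced from prev. A runnable sketch of the corrected round-robin assignment (device counts here are made up):

```java
import java.util.concurrent.atomic.AtomicInteger;

/** Sketch of the corrected round-robin GPU assignment from the hunk above. */
public class GpuRoundRobinSketch {
    public static void main(String[] args) {
        final int maxGpu = 4;  // stands in for model.getNumCores()
        final int stride = 2;  // parallelLevel when > 0, otherwise 1
        AtomicInteger gpuCounter = new AtomicInteger(0);
        for (int worker = 0; worker < 5; worker++) {
            // getAndAccumulate returns the previous value and stores the accumulated
            // one, so successive workers get 0, 2, 0, 2, 0 for maxGpu=4, stride=2.
            int gpuId = gpuCounter.getAndAccumulate(
                    stride, (prev, myStride) -> (prev + myStride) % maxGpu);
            System.out.println("worker " + worker + " -> gpu " + gpuId);
        }
    }
}
```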
1 change: 1 addition & 0 deletions requirements/developer.txt
@@ -18,3 +18,4 @@ intel_extension_for_pytorch==2.1.0; sys_platform != 'win32' and sys_platform !=
 onnxruntime==1.15.0
 googleapis-common-protos
 onnx==1.14.1
+orjson
142 changes: 142 additions & 0 deletions test/pytest/test_model_config.py
@@ -0,0 +1,142 @@
import shutil
from pathlib import Path
from unittest.mock import patch

import pytest
import test_utils
from model_archiver import ModelArchiverConfig

CURR_FILE_PATH = Path(__file__).parent
REPO_ROOT_DIR = CURR_FILE_PATH.parent.parent

MODEL_PY = """
import torch
import torch.nn as nn
class Foo(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x
"""

HANDLER_PY = """
import os
import torch
from ts.torch_handler.base_handler import BaseHandler
class FooHandler(BaseHandler):
def initialize(self, ctx):
super().initialize(ctx)
def preprocess(self, data):
return torch.as_tensor(int(data[0].get('body').decode('utf-8')), device=self.device)
def postprocess(self, x):
return [x.item()]
"""

MODEL_CONFIG_YAML = f"""
#frontend settings
# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 4
maxBatchDelay: 100
batchSize: 4
"""


@pytest.fixture(scope="module")
def model_name():
yield "foo"


@pytest.fixture(scope="module")
def work_dir(tmp_path_factory, model_name):
return Path(tmp_path_factory.mktemp(model_name))


@pytest.fixture(scope="module", name="mar_file_path")
def create_mar_file(work_dir, model_archiver, model_name):
mar_file_path = work_dir.joinpath(model_name + ".mar")

model_config_yaml_file = work_dir / "model_config.yaml"
model_config_yaml_file.write_text(MODEL_CONFIG_YAML)

model_py_file = work_dir / "model.py"
model_py_file.write_text(MODEL_PY)

handler_py_file = work_dir / "handler.py"
handler_py_file.write_text(HANDLER_PY)

config = ModelArchiverConfig(
model_name=model_name,
version="1.0",
serialized_file=None,
model_file=model_py_file.as_posix(),
handler=handler_py_file.as_posix(),
extra_files=None,
export_path=work_dir,
requirements_file=None,
runtime="python",
force=False,
archive_format="default",
config_file=model_config_yaml_file.as_posix(),
)

with patch("archiver.ArgParser.export_model_args_parser", return_value=config):
model_archiver.generate_model_archive()

assert mar_file_path.exists()

yield mar_file_path.as_posix()

# Clean up files
mar_file_path.unlink(missing_ok=True)


def register_model(mar_file_path, model_store, params, torchserve):
shutil.copy(mar_file_path, model_store)

file_name = Path(mar_file_path).name

model_name = Path(file_name).stem

params = params + (
("model_name", model_name),
("url", file_name),
)

test_utils.reg_resp = test_utils.register_model_with_params(params)
return model_name


def test_register_model_with_batch_size(mar_file_path, model_store, torchserve):
params = (
("initial_workers", "2"),
("synchronous", "true"),
("batch_size", "2"),
)

model_name = register_model(mar_file_path, model_store, params, torchserve)

describe_resp = test_utils.describe_model(model_name, "1.0")

assert describe_resp[0]["batchSize"] == 2

test_utils.unregister_model(model_name)


def test_register_model_without_batch_size(mar_file_path, model_store, torchserve):
params = (
("initial_workers", "2"),
("synchronous", "true"),
)
model_name = register_model(mar_file_path, model_store, params, torchserve)

describe_resp = test_utils.describe_model(model_name, "1.0")

assert describe_resp[0]["batchSize"] == 4

test_utils.unregister_model(model_name)
8 changes: 8 additions & 0 deletions test/pytest/test_utils.py
@@ -12,6 +12,7 @@
 from queue import Queue
 from subprocess import PIPE, STDOUT, Popen
 
+import orjson
 import requests
 
 # To help discover margen modules
@@ -125,6 +126,13 @@ def unregister_model(model_name):
     return response
 
 
+def describe_model(model_name, version):
+    response = requests.get(
+        "http://localhost:8081/models/{}/{}".format(model_name, version)
+    )
+    return orjson.loads(response.content)
+
+
 def delete_mar_file_from_model_store(model_store=None, model_mar=None):
     model_store = (
         model_store
