Skip to content

Commit

Permalink
fix: Add configurable startup timeout (#98)
Browse files Browse the repository at this point in the history
* Add configurable startup timeout

Co-authored-by: Aaqib <[email protected]>
  • Loading branch information
davidthomas426 and maaquib authored Feb 1, 2022
1 parent e528089 commit a98a0e1
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 23 deletions.
12 changes: 12 additions & 0 deletions src/sagemaker_inference/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

DEFAULT_MODULE_NAME = "inference.py"
# Timeout defaults are expressed in seconds. They are kept as strings because they
# serve as the fallback argument to os.environ.get() and are converted with int()
# when the Environment object is built.
DEFAULT_MODEL_SERVER_TIMEOUT = "60"
DEFAULT_STARTUP_TIMEOUT = "600" # 10 minutes, expressed in seconds
DEFAULT_HTTP_PORT = "8080"

SAGEMAKER_BASE_PATH = os.path.join("/opt", "ml") # type: str
Expand Down Expand Up @@ -50,6 +51,7 @@ class Environment(object):
module_name (str): The name of the user-provided module. Default is inference.py.
model_server_timeout (int): Timeout in seconds for the model server. Default is 60.
model_server_workers (str): Number of worker processes the model server will use.
default_accept (str): The desired default MIME type of the inference in the response
as specified in the user-supplied SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT environment
variable. Otherwise, returns 'application/json' by default.
Expand All @@ -68,6 +70,9 @@ def __init__(self):
os.environ.get(parameters.MODEL_SERVER_TIMEOUT_ENV, DEFAULT_MODEL_SERVER_TIMEOUT)
)
self._model_server_workers = os.environ.get(parameters.MODEL_SERVER_WORKERS_ENV)
self._startup_timeout = int(
os.environ.get(parameters.STARTUP_TIMEOUT_ENV, DEFAULT_STARTUP_TIMEOUT)
)
self._default_accept = os.environ.get(
parameters.DEFAULT_INVOCATIONS_ACCEPT_ENV, content_types.JSON
)
Expand Down Expand Up @@ -107,6 +112,13 @@ def model_server_workers(self): # type: () -> str
"""str: Number of worker processes the model server is configured to use."""
return self._model_server_workers

@property
def startup_timeout(self):  # type: () -> int
    """int: Timeout, in seconds, allowed for starting up the model server and fetching
    its process id, before giving up and raising an error.

    The value is read once in ``__init__`` from the environment variable named by
    ``parameters.STARTUP_TIMEOUT_ENV``, falling back to ``DEFAULT_STARTUP_TIMEOUT``.
    """
    return self._startup_timeout

@property
def default_accept(self): # type: () -> str
"""str: The desired default MIME type of the inference in the response."""
Expand Down
25 changes: 16 additions & 9 deletions src/sagemaker_inference/model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
else:
_adapt_to_mms_format(handler_service)

_create_model_server_config_file()
env = environment.Environment()
_create_model_server_config_file(env)

if os.path.exists(REQUIREMENTS_PATH):
_install_requirements()
Expand All @@ -93,7 +94,10 @@ def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):

logger.info(multi_model_server_cmd)
subprocess.Popen(multi_model_server_cmd)
mms_process = _retrieve_mms_server_process()

# retry for configured timeout
mms_process = _retry_retrieve_mms_server_process(env.startup_timeout)

_add_sigterm_handler(mms_process)
_add_sigchild_handler()

Expand Down Expand Up @@ -137,15 +141,13 @@ def _set_python_path():
os.environ[PYTHON_PATH_ENV] = code_dir_path


def _create_model_server_config_file():
configuration_properties = _generate_mms_config_properties()
def _create_model_server_config_file(env):
    """Generate the MMS configuration properties for ``env`` and write them to disk.

    Args:
        env: Environment object passed through to ``_generate_mms_config_properties``.
    """
    # The generated properties text is written to MMS_CONFIG_FILE in one step.
    utils.write_file(MMS_CONFIG_FILE, _generate_mms_config_properties(env))


def _generate_mms_config_properties():
env = environment.Environment()

def _generate_mms_config_properties(env):
user_defined_configuration = {
"default_response_timeout": env.model_server_timeout,
"default_workers_per_model": env.model_server_workers,
Expand Down Expand Up @@ -190,8 +192,13 @@ def _install_requirements():
raise ValueError("failed to install required packages")


# retry for 10 minutes
@retry(wait_fixed=1000, stop_max_delay=10 * 60 * 1000)
def _retry_retrieve_mms_server_process(startup_timeout):
    """Call ``_retrieve_mms_server_process`` with retries, bounded by the startup timeout.

    Args:
        startup_timeout (int): Maximum time, in seconds, to keep retrying.

    Returns:
        The value returned by ``_retrieve_mms_server_process`` once a call succeeds.
    """
    # retry() is configured in milliseconds, hence the conversion from seconds.
    # The decorator is built here, at call time, so the limit can come from
    # configuration instead of being fixed when the module is imported.
    timeout_ms = startup_timeout * 1000
    with_retries = retry(wait_fixed=1000, stop_max_delay=timeout_ms)
    return with_retries(_retrieve_mms_server_process)()


def _retrieve_mms_server_process():
mms_server_processes = list()

Expand Down
1 change: 1 addition & 0 deletions src/sagemaker_inference/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
DEFAULT_INVOCATIONS_ACCEPT_ENV = "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT" # type: str
MODEL_SERVER_WORKERS_ENV = "SAGEMAKER_MODEL_SERVER_WORKERS" # type: str
MODEL_SERVER_TIMEOUT_ENV = "SAGEMAKER_MODEL_SERVER_TIMEOUT" # type: str
# Name of the environment variable holding the model server startup timeout, in seconds.
STARTUP_TIMEOUT_ENV = "SAGEMAKER_STARTUP_TIMEOUT" # type: str
BIND_TO_PORT_ENV = "SAGEMAKER_BIND_TO_PORT" # type: str
SAFE_PORT_RANGE_ENV = "SAGEMAKER_SAFE_PORT_RANGE" # type: str
MULTI_MODEL_ENV = "SAGEMAKER_MULTI_MODEL" # type: str
2 changes: 2 additions & 0 deletions test/unit/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
parameters.USER_PROGRAM_ENV: "main.py",
parameters.MODEL_SERVER_TIMEOUT_ENV: "20",
parameters.MODEL_SERVER_WORKERS_ENV: "8",
parameters.STARTUP_TIMEOUT_ENV: "50",
parameters.DEFAULT_INVOCATIONS_ACCEPT_ENV: "text/html",
parameters.BIND_TO_PORT_ENV: "1738",
parameters.SAFE_PORT_RANGE_ENV: "1111-2222",
Expand All @@ -38,6 +39,7 @@ def test_env():
assert environment.code_dir.endswith("opt/ml/model/code")
assert env.module_name == "main"
assert env.model_server_timeout == 20
assert env.startup_timeout == 50
assert env.model_server_workers == "8"
assert env.default_accept == "text/html"
assert env.inference_http_port == "1738"
Expand Down
39 changes: 25 additions & 14 deletions test/unit/test_model_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import subprocess
import types

from mock import Mock, patch
from mock import ANY, Mock, patch
import pytest

from sagemaker_inference import environment, model_server
Expand All @@ -27,13 +27,15 @@

@patch("subprocess.call")
@patch("subprocess.Popen")
@patch("sagemaker_inference.model_server._retrieve_mms_server_process")
@patch("sagemaker_inference.model_server._retry_retrieve_mms_server_process")
@patch("sagemaker_inference.model_server._add_sigterm_handler")
@patch("sagemaker_inference.model_server._install_requirements")
@patch("os.path.exists", return_value=True)
@patch("sagemaker_inference.model_server._create_model_server_config_file")
@patch("sagemaker_inference.model_server._adapt_to_mms_format")
@patch("sagemaker_inference.environment.Environment")
def test_start_model_server_default_service_handler(
env,
adapt,
create_config,
exists,
Expand All @@ -43,10 +45,12 @@ def test_start_model_server_default_service_handler(
subprocess_popen,
subprocess_call,
):
env.return_value.startup_timeout = 10000

model_server.start_model_server()

adapt.assert_called_once_with(model_server.DEFAULT_HANDLER_SERVICE)
create_config.assert_called_once_with()
create_config.assert_called_once_with(env.return_value)
exists.assert_called_once_with(REQUIREMENTS_PATH)
install_requirements.assert_called_once_with()

Expand All @@ -67,7 +71,7 @@ def test_start_model_server_default_service_handler(

@patch("subprocess.call")
@patch("subprocess.Popen")
@patch("sagemaker_inference.model_server._retrieve_mms_server_process")
@patch("sagemaker_inference.model_server._retry_retrieve_mms_server_process")
@patch("sagemaker_inference.model_server._add_sigterm_handler")
@patch("sagemaker_inference.model_server._create_model_server_config_file")
@patch("sagemaker_inference.model_server._adapt_to_mms_format")
Expand Down Expand Up @@ -146,8 +150,10 @@ def test_new_python_path():

@patch("sagemaker_inference.model_server._generate_mms_config_properties")
@patch("sagemaker_inference.utils.write_file")
def test_create_model_server_config_file(write_file, generate_mms_config_props):
model_server._create_model_server_config_file()
@patch("sagemaker_inference.environment.Environment")
def test_create_model_server_config_file(env, write_file, generate_mms_config_props):

model_server._create_model_server_config_file(env.return_value)

write_file.assert_called_once_with(
model_server.MMS_CONFIG_FILE, generate_mms_config_props.return_value
Expand All @@ -165,7 +171,7 @@ def test_generate_mms_config_properties(env, read_file):
env.return_value.model_server_workers = model_server_workers
env.return_value.inference_http_port = http_port

mms_config_properties = model_server._generate_mms_config_properties()
mms_config_properties = model_server._generate_mms_config_properties(env.return_value)

inference_address = "inference_address=http://0.0.0.0:{}\n".format(http_port)
server_timeout = "default_response_timeout={}\n".format(model_server_timeout)
Expand All @@ -184,7 +190,7 @@ def test_generate_mms_config_properties(env, read_file):
def test_generate_mms_config_properties_default_workers(env, read_file):
env.return_value.model_server_workers = None

mms_config_properties = model_server._generate_mms_config_properties()
mms_config_properties = model_server._generate_mms_config_properties(env.return_value)

workers = "default_workers_per_model={}".format(None)

Expand Down Expand Up @@ -222,9 +228,8 @@ def test_install_requirements_installation_failed(check_call):
assert "failed to install required packages" in str(e.value)


@patch("retrying.Retrying.should_reject", return_value=False)
@patch("psutil.process_iter")
def test_retrieve_mms_server_process(process_iter, retry):
def test_retrieve_mms_server_process(process_iter):
server = Mock()
server.cmdline.return_value = MMS_NAMESPACE

Expand All @@ -238,18 +243,16 @@ def test_retrieve_mms_server_process(process_iter, retry):
assert process == server


@patch("retrying.Retrying.should_reject", return_value=False)
@patch("psutil.process_iter", return_value=list())
def test_retrieve_mms_server_process_no_server(process_iter, retry):
def test_retrieve_mms_server_process_no_server(process_iter):
    """An empty process list must make the retrieval raise the startup-failure error."""
    with pytest.raises(Exception) as excinfo:
        model_server._retrieve_mms_server_process()

    failure_message = str(excinfo.value)
    assert "mms model server was unsuccessfully started" in failure_message


@patch("retrying.Retrying.should_reject", return_value=False)
@patch("psutil.process_iter")
def test_retrieve_mms_server_process_too_many_servers(process_iter, retry):
def test_retrieve_mms_server_process_too_many_servers(process_iter):
server = Mock()
second_server = Mock()
server.cmdline.return_value = MMS_NAMESPACE
Expand All @@ -265,3 +268,11 @@ def test_retrieve_mms_server_process_too_many_servers(process_iter, retry):
model_server._retrieve_mms_server_process()

assert "multiple mms model servers are not supported" in str(e.value)


@patch("sagemaker_inference.model_server.retry", return_value=lambda f: f)
@patch("sagemaker_inference.model_server._retrieve_mms_server_process", return_value=17)
def test_retry_retrieve_mms_server_process(retrieve, retry):
    """The retry wrapper passes the timeout in milliseconds and returns the wrapped result."""
    timeout_seconds = 100

    # retry is stubbed with a pass-through decorator, so the patched retrieval
    # function is invoked directly and its return value comes straight back.
    result = model_server._retry_retrieve_mms_server_process(timeout_seconds)

    assert result == 17
    retry.assert_called_once_with(wait_fixed=ANY, stop_max_delay=timeout_seconds * 1000)

0 comments on commit a98a0e1

Please sign in to comment.