Skip to content

Commit

Permalink
feat: Optimize() validations across TRT, VLLM, Neuron container optimizations (#4927)
Browse files Browse the repository at this point in the history

* changes for blackbird - model sharding

changes for blackbird - model sharding

add more tests

fix sharded model flag

add optimization validations

fix formatting and messaging

fixing validation bugs

add UTs

simplify logic

update messaging

formatting

fix UTs

add more UTs

fix validations

update ruleset

update formatting

update validation logic

update bug fixes

Disable network isolation if using sharded models.

check sharding + network iso pre optimization

add more UTs for sharding

add more UTs

* fix rebase issues

---------

Co-authored-by: Ashish Gupta <[email protected]>
  • Loading branch information
gwang111 and Ashish Gupta authored Nov 19, 2024
1 parent 663bbb6 commit efd6c80
Show file tree
Hide file tree
Showing 8 changed files with 1,166 additions and 21 deletions.
14 changes: 14 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ def __init__(
self.endpoint_name = None
self.inference_component_name = None
self._is_compiled_model = False
self._is_sharded_model = False
self._compilation_job_name = None
self._is_edge_packaged_model = False
self.inference_recommender_job_results = None
Expand Down Expand Up @@ -1599,6 +1600,19 @@ def deploy(
if self._base_name is not None:
self._base_name = "-".join((self._base_name, compiled_model_suffix))

if self._is_sharded_model and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
logging.warning(
"Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
"Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints."
)
endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED

if self._is_sharded_model and self._enable_network_isolation:
raise ValueError(
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
"Loading of model requires network access."
)

# Support multiple models on same endpoint
if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
if endpoint_name:
Expand Down
23 changes: 20 additions & 3 deletions src/sagemaker/serve/builder/jumpstart_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,7 @@ def _optimize_for_jumpstart(
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
speculative_decoding_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
env_vars: Optional[Dict] = None,
vpc_config: Optional[Dict] = None,
kms_key: Optional[str] = None,
Expand All @@ -705,6 +706,8 @@ def _optimize_for_jumpstart(
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
Defaults to ``None``
sharding_config (Optional[Dict]): Model sharding configuration.
Defaults to ``None``
env_vars (Optional[Dict]): Additional environment variables to run the optimization
container. Defaults to ``None``.
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
Expand All @@ -730,8 +733,13 @@ def _optimize_for_jumpstart(
pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type)

# optimization_config can contain configs for both quantization and compilation
optimization_config, quantization_override_env, compilation_override_env = (
_extract_optimization_config_and_env(quantization_config, compilation_config)
(
optimization_config,
quantization_override_env,
compilation_override_env,
sharding_override_env,
) = _extract_optimization_config_and_env(
quantization_config, compilation_config, sharding_config
)

if not optimization_config:
Expand Down Expand Up @@ -807,11 +815,20 @@ def _optimize_for_jumpstart(
{
**(quantization_override_env or {}),
**(compilation_override_env or {}),
**(sharding_override_env or {}),
},
)
if optimization_env_vars:
self.pysdk_model.env.update(optimization_env_vars)
if quantization_config or is_compilation:

if sharding_config and self.pysdk_model._enable_network_isolation:
logger.warning(
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
"Loading of model requires network access. Setting it to False."
)
self.pysdk_model._enable_network_isolation = False

if quantization_config or sharding_config or is_compilation:
return create_optimization_job_args
return None

Expand Down
79 changes: 76 additions & 3 deletions src/sagemaker/serve/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
get_huggingface_model_metadata,
download_huggingface_model_metadata,
)
from sagemaker.serve.validations.optimization import _validate_optimization_configuration

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1120,6 +1121,7 @@ def optimize(
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
speculative_decoding_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
env_vars: Optional[Dict] = None,
vpc_config: Optional[Dict] = None,
kms_key: Optional[str] = None,
Expand All @@ -1143,6 +1145,8 @@ def optimize(
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
Defaults to ``None``
sharding_config (Optional[Dict]): Model sharding configuration.
Defaults to ``None``
env_vars (Optional[Dict]): Additional environment variables to run the optimization
container. Defaults to ``None``.
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
Expand Down Expand Up @@ -1171,6 +1175,7 @@ def optimize(
quantization_config=quantization_config,
compilation_config=compilation_config,
speculative_decoding_config=speculative_decoding_config,
sharding_config=sharding_config,
env_vars=env_vars,
vpc_config=vpc_config,
kms_key=kms_key,
Expand All @@ -1190,6 +1195,7 @@ def _model_builder_optimize_wrapper(
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
speculative_decoding_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
env_vars: Optional[Dict] = None,
vpc_config: Optional[Dict] = None,
kms_key: Optional[str] = None,
Expand All @@ -1213,6 +1219,8 @@ def _model_builder_optimize_wrapper(
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
Defaults to ``None``
sharding_config (Optional[Dict]): Model sharding configuration.
Defaults to ``None``
env_vars (Optional[Dict]): Additional environment variables to run the optimization
container. Defaults to ``None``.
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
Expand All @@ -1227,6 +1235,27 @@ def _model_builder_optimize_wrapper(
Returns:
Model: A deployable ``Model`` object.
"""
if (
hasattr(self, "enable_network_isolation")
and self.enable_network_isolation
and sharding_config
):
raise ValueError(
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
"Loading of model requires network access."
)

# TODO: ideally these dictionaries need to be sagemaker_core shapes
# TODO: for organization, abstract all validation behind this fn
_validate_optimization_configuration(
is_jumpstart=self._is_jumpstart_model_id(),
instance_type=instance_type,
quantization_config=quantization_config,
compilation_config=compilation_config,
sharding_config=sharding_config,
speculative_decoding_config=speculative_decoding_config,
)

self.is_compiled = compilation_config is not None
self.is_quantized = quantization_config is not None
self.speculative_decoding_draft_model_source = _extract_speculative_draft_model_provider(
Expand All @@ -1236,6 +1265,36 @@ def _model_builder_optimize_wrapper(
if self.mode != Mode.SAGEMAKER_ENDPOINT:
raise ValueError("Model optimization is only supported in Sagemaker Endpoint Mode.")

if sharding_config and (
quantization_config or compilation_config or speculative_decoding_config
):
raise ValueError(
(
"Sharding config is mutually exclusive "
"and cannot be combined with any other optimization."
)
)

if sharding_config:
has_tensor_parallel_degree_in_env_vars = (
env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" in env_vars
)
has_tensor_parallel_degree_in_overrides = (
sharding_config
and sharding_config.get("OverrideEnvironment")
and "OPTION_TENSOR_PARALLEL_DEGREE" in sharding_config.get("OverrideEnvironment")
)
if (
not has_tensor_parallel_degree_in_env_vars
and not has_tensor_parallel_degree_in_overrides
):
raise ValueError(
(
"OPTION_TENSOR_PARALLEL_DEGREE is a required "
"environment variable with sharding config."
)
)

self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
self.instance_type = instance_type or self.instance_type
self.role_arn = role_arn or self.role_arn
Expand All @@ -1252,6 +1311,7 @@ def _model_builder_optimize_wrapper(
quantization_config=quantization_config,
compilation_config=compilation_config,
speculative_decoding_config=speculative_decoding_config,
sharding_config=sharding_config,
env_vars=env_vars,
vpc_config=vpc_config,
kms_key=kms_key,
Expand All @@ -1270,12 +1330,16 @@ def _model_builder_optimize_wrapper(
quantization_config=quantization_config,
compilation_config=compilation_config,
speculative_decoding_config=speculative_decoding_config,
sharding_config=sharding_config,
env_vars=env_vars,
vpc_config=vpc_config,
kms_key=kms_key,
max_runtime_in_sec=max_runtime_in_sec,
)

if sharding_config:
self.pysdk_model._is_sharded_model = True

if input_args:
optimization_instance_type = input_args["DeploymentInstanceType"]

Expand Down Expand Up @@ -1325,6 +1389,7 @@ def _optimize_for_hf(
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
speculative_decoding_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
env_vars: Optional[Dict] = None,
vpc_config: Optional[Dict] = None,
kms_key: Optional[str] = None,
Expand All @@ -1340,6 +1405,8 @@ def _optimize_for_hf(
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
Defaults to ``None``
sharding_config (Optional[Dict]): Model sharding configuration.
Defaults to ``None``
env_vars (Optional[Dict]): Additional environment variables to run the optimization
container. Defaults to ``None``.
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
Expand All @@ -1363,7 +1430,7 @@ def _optimize_for_hf(
self.pysdk_model, speculative_decoding_config, False
)

if quantization_config or compilation_config:
if quantization_config or compilation_config or sharding_config:
create_optimization_job_args = {
"OptimizationJobName": job_name,
"DeploymentInstanceType": self.instance_type,
Expand All @@ -1378,8 +1445,13 @@ def _optimize_for_hf(
model_source = _generate_model_source(self.pysdk_model.model_data, False)
create_optimization_job_args["ModelSource"] = model_source

optimization_config, quantization_override_env, compilation_override_env = (
_extract_optimization_config_and_env(quantization_config, compilation_config)
(
optimization_config,
quantization_override_env,
compilation_override_env,
sharding_override_env,
) = _extract_optimization_config_and_env(
quantization_config, compilation_config, sharding_config
)
create_optimization_job_args["OptimizationConfigs"] = [
{k: v} for k, v in optimization_config.items()
Expand All @@ -1388,6 +1460,7 @@ def _optimize_for_hf(
{
**(quantization_override_env or {}),
**(compilation_override_env or {}),
**(sharding_override_env or {}),
}
)

Expand Down
22 changes: 17 additions & 5 deletions src/sagemaker/serve/utils/optimize_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,16 +361,19 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool:


def _extract_optimization_config_and_env(
quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None
) -> Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]:
quantization_config: Optional[Dict] = None,
compilation_config: Optional[Dict] = None,
sharding_config: Optional[Dict] = None,
) -> Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict]]]:
"""Extracts optimization config and environment variables.
Args:
quantization_config (Optional[Dict]): The quantization config.
compilation_config (Optional[Dict]): The compilation config.
sharding_config (Optional[Dict]): The sharding config.
Returns:
Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]:
Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict]]]:
The optimization config and environment variables.
"""
optimization_config = {}
Expand All @@ -380,18 +383,27 @@ def _extract_optimization_config_and_env(
compilation_override_env = (
compilation_config.get("OverrideEnvironment") if compilation_config else None
)
sharding_override_env = sharding_config.get("OverrideEnvironment") if sharding_config else None

if quantization_config is not None:
optimization_config["ModelQuantizationConfig"] = quantization_config

if compilation_config is not None:
optimization_config["ModelCompilationConfig"] = compilation_config

if sharding_config is not None:
optimization_config["ModelShardingConfig"] = sharding_config

# Return optimization config dict and environment variables if either is present
if optimization_config:
return optimization_config, quantization_override_env, compilation_override_env
return (
optimization_config,
quantization_override_env,
compilation_override_env,
sharding_override_env,
)

return None, None, None
return None, None, None, None


def _custom_speculative_decoding(
Expand Down
Loading

0 comments on commit efd6c80

Please sign in to comment.