diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 340d35b250..83efa57cb8 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -372,6 +372,7 @@ def __init__( self.endpoint_name = None self.inference_component_name = None self._is_compiled_model = False + self._is_sharded_model = False self._compilation_job_name = None self._is_edge_packaged_model = False self.inference_recommender_job_results = None @@ -1599,6 +1600,19 @@ def deploy( if self._base_name is not None: self._base_name = "-".join((self._base_name, compiled_model_suffix)) + if self._is_sharded_model and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED: + logging.warning( + "Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - " + "Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints." + ) + endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED + + if self._is_sharded_model and self._enable_network_isolation: + raise ValueError( + "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model " + "Loading of model requires network access." + ) + # Support multiple models on same endpoint if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED: if endpoint_name: diff --git a/src/sagemaker/serve/builder/jumpstart_builder.py b/src/sagemaker/serve/builder/jumpstart_builder.py index 7d6a052023..37a77179cb 100644 --- a/src/sagemaker/serve/builder/jumpstart_builder.py +++ b/src/sagemaker/serve/builder/jumpstart_builder.py @@ -684,6 +684,7 @@ def _optimize_for_jumpstart( quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None, speculative_decoding_config: Optional[Dict] = None, + sharding_config: Optional[Dict] = None, env_vars: Optional[Dict] = None, vpc_config: Optional[Dict] = None, kms_key: Optional[str] = None, @@ -705,6 +706,8 @@ def _optimize_for_jumpstart( compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. Defaults to ``None`` + sharding_config (Optional[Dict]): Model sharding configuration. + Defaults to ``None`` env_vars (Optional[Dict]): Additional environment variables to run the optimization container. Defaults to ``None``. vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. @@ -730,8 +733,13 @@ def _optimize_for_jumpstart( pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type) # optimization_config can contain configs for both quantization and compilation - optimization_config, quantization_override_env, compilation_override_env = ( - _extract_optimization_config_and_env(quantization_config, compilation_config) + ( + optimization_config, + quantization_override_env, + compilation_override_env, + sharding_override_env, + ) = _extract_optimization_config_and_env( + quantization_config, compilation_config, sharding_config ) if not optimization_config: @@ -807,11 +815,20 @@ def _optimize_for_jumpstart( { **(quantization_override_env or {}), **(compilation_override_env or {}), + **(sharding_override_env or {}), }, ) if optimization_env_vars: self.pysdk_model.env.update(optimization_env_vars) - if quantization_config or is_compilation: + + if sharding_config and self.pysdk_model._enable_network_isolation: + logger.warning( + "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model " + "Loading of model requires network access. Setting it to False." + ) + self.pysdk_model._enable_network_isolation = False + + if quantization_config or sharding_config or is_compilation: return create_optimization_job_args return None diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 61af6953a2..6a3b093ac5 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -105,6 +105,7 @@ get_huggingface_model_metadata, download_huggingface_model_metadata, ) +from sagemaker.serve.validations.optimization import _validate_optimization_configuration logger = logging.getLogger(__name__) @@ -1120,6 +1121,7 @@ def optimize( quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None, speculative_decoding_config: Optional[Dict] = None, + sharding_config: Optional[Dict] = None, env_vars: Optional[Dict] = None, vpc_config: Optional[Dict] = None, kms_key: Optional[str] = None, @@ -1143,6 +1145,8 @@ def optimize( compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. Defaults to ``None`` + sharding_config (Optional[Dict]): Model sharding configuration. + Defaults to ``None`` env_vars (Optional[Dict]): Additional environment variables to run the optimization container. Defaults to ``None``. vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. @@ -1171,6 +1175,7 @@ def optimize( quantization_config=quantization_config, compilation_config=compilation_config, speculative_decoding_config=speculative_decoding_config, + sharding_config=sharding_config, env_vars=env_vars, vpc_config=vpc_config, kms_key=kms_key, @@ -1190,6 +1195,7 @@ def _model_builder_optimize_wrapper( quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None, speculative_decoding_config: Optional[Dict] = None, + sharding_config: Optional[Dict] = None, env_vars: Optional[Dict] = None, vpc_config: Optional[Dict] = None, kms_key: Optional[str] = None, @@ -1213,6 +1219,8 @@ def _model_builder_optimize_wrapper( compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. Defaults to ``None`` + sharding_config (Optional[Dict]): Model sharding configuration. + Defaults to ``None`` env_vars (Optional[Dict]): Additional environment variables to run the optimization container. Defaults to ``None``. vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. @@ -1227,6 +1235,27 @@ def _model_builder_optimize_wrapper( Returns: Model: A deployable ``Model`` object. """ + if ( + hasattr(self, "enable_network_isolation") + and self.enable_network_isolation + and sharding_config + ): + raise ValueError( + "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model " + "Loading of model requires network access." + ) + + # TODO: ideally these dictionaries need to be sagemaker_core shapes + # TODO: for organization, abstract all validation behind this fn + _validate_optimization_configuration( + is_jumpstart=self._is_jumpstart_model_id(), + instance_type=instance_type, + quantization_config=quantization_config, + compilation_config=compilation_config, + sharding_config=sharding_config, + speculative_decoding_config=speculative_decoding_config, + ) + self.is_compiled = compilation_config is not None self.is_quantized = quantization_config is not None self.speculative_decoding_draft_model_source = _extract_speculative_draft_model_provider( @@ -1236,6 +1265,36 @@ def _model_builder_optimize_wrapper( if self.mode != Mode.SAGEMAKER_ENDPOINT: raise ValueError("Model optimization is only supported in Sagemaker Endpoint Mode.") + if sharding_config and ( + quantization_config or compilation_config or speculative_decoding_config + ): + raise ValueError( + ( + "Sharding config is mutually exclusive " + "and cannot be combined with any other optimization." + ) + ) + + if sharding_config: + has_tensor_parallel_degree_in_env_vars = ( + env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" in env_vars + ) + has_tensor_parallel_degree_in_overrides = ( + sharding_config + and sharding_config.get("OverrideEnvironment") + and "OPTION_TENSOR_PARALLEL_DEGREE" in sharding_config.get("OverrideEnvironment") + ) + if ( + not has_tensor_parallel_degree_in_env_vars + and not has_tensor_parallel_degree_in_overrides + ): + raise ValueError( + ( + "OPTION_TENSOR_PARALLEL_DEGREE is a required " + "environment variable with sharding config." + ) + ) + self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session() self.instance_type = instance_type or self.instance_type self.role_arn = role_arn or self.role_arn @@ -1252,6 +1311,7 @@ def _model_builder_optimize_wrapper( quantization_config=quantization_config, compilation_config=compilation_config, speculative_decoding_config=speculative_decoding_config, + sharding_config=sharding_config, env_vars=env_vars, vpc_config=vpc_config, kms_key=kms_key, @@ -1270,12 +1330,16 @@ def _model_builder_optimize_wrapper( quantization_config=quantization_config, compilation_config=compilation_config, speculative_decoding_config=speculative_decoding_config, + sharding_config=sharding_config, env_vars=env_vars, vpc_config=vpc_config, kms_key=kms_key, max_runtime_in_sec=max_runtime_in_sec, ) + if sharding_config: + self.pysdk_model._is_sharded_model = True + if input_args: optimization_instance_type = input_args["DeploymentInstanceType"] @@ -1325,6 +1389,7 @@ def _optimize_for_hf( quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None, speculative_decoding_config: Optional[Dict] = None, + sharding_config: Optional[Dict] = None, env_vars: Optional[Dict] = None, vpc_config: Optional[Dict] = None, kms_key: Optional[str] = None, @@ -1340,6 +1405,8 @@ def _optimize_for_hf( compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``. speculative_decoding_config (Optional[Dict]): Speculative decoding configuration. Defaults to ``None`` + sharding_config (Optional[Dict]): Model sharding configuration. + Defaults to ``None`` env_vars (Optional[Dict]): Additional environment variables to run the optimization container. Defaults to ``None``. vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``. @@ -1363,7 +1430,7 @@ def _optimize_for_hf( self.pysdk_model, speculative_decoding_config, False ) - if quantization_config or compilation_config: + if quantization_config or compilation_config or sharding_config: create_optimization_job_args = { "OptimizationJobName": job_name, "DeploymentInstanceType": self.instance_type, @@ -1378,8 +1445,13 @@ def _optimize_for_hf( model_source = _generate_model_source(self.pysdk_model.model_data, False) create_optimization_job_args["ModelSource"] = model_source - optimization_config, quantization_override_env, compilation_override_env = ( - _extract_optimization_config_and_env(quantization_config, compilation_config) + ( + optimization_config, + quantization_override_env, + compilation_override_env, + sharding_override_env, + ) = _extract_optimization_config_and_env( + quantization_config, compilation_config, sharding_config ) create_optimization_job_args["OptimizationConfigs"] = [ {k: v} for k, v in optimization_config.items() @@ -1388,6 +1460,7 @@ def _optimize_for_hf( { **(quantization_override_env or {}), **(compilation_override_env or {}), + **(sharding_override_env or {}), } ) diff --git a/src/sagemaker/serve/utils/optimize_utils.py b/src/sagemaker/serve/utils/optimize_utils.py index 14df6b3639..68ed1e846d 100644 --- a/src/sagemaker/serve/utils/optimize_utils.py +++ b/src/sagemaker/serve/utils/optimize_utils.py @@ -361,16 +361,19 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool: def _extract_optimization_config_and_env( - quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None -) -> Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]: + quantization_config: Optional[Dict] = None, + compilation_config: Optional[Dict] = None, + sharding_config: Optional[Dict] = None, +) -> Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict]]]: """Extracts optimization config and environment variables. Args: quantization_config (Optional[Dict]): The quantization config. compilation_config (Optional[Dict]): The compilation config. + sharding_config (Optional[Dict]): The sharding config. Returns: - Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict]]]: + Optional[Tuple[Optional[Dict], Optional[Dict], Optional[Dict], Optional[Dict]]]: The optimization config and environment variables. """ optimization_config = {} @@ -380,6 +383,7 @@ def _extract_optimization_config_and_env( compilation_override_env = ( compilation_config.get("OverrideEnvironment") if compilation_config else None ) + sharding_override_env = sharding_config.get("OverrideEnvironment") if sharding_config else None if quantization_config is not None: optimization_config["ModelQuantizationConfig"] = quantization_config @@ -387,11 +391,19 @@ def _extract_optimization_config_and_env( if compilation_config is not None: optimization_config["ModelCompilationConfig"] = compilation_config + if sharding_config is not None: + optimization_config["ModelShardingConfig"] = sharding_config + # Return optimization config dict and environment variables if either is present if optimization_config: - return optimization_config, quantization_override_env, compilation_override_env + return ( + optimization_config, + quantization_override_env, + compilation_override_env, + sharding_override_env, + ) - return None, None, None + return None, None, None, None def _custom_speculative_decoding( diff --git a/src/sagemaker/serve/validations/optimization.py b/src/sagemaker/serve/validations/optimization.py new file mode 100644 index 0000000000..58ef167039 --- /dev/null +++ b/src/sagemaker/serve/validations/optimization.py @@ -0,0 +1,229 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Holds the validation logic used for the .optimize() function. INTERNAL only""" +from __future__ import absolute_import + +import textwrap +import logging +from typing import Any, Dict, Set, Optional +from enum import Enum +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + + +class _OptimizationContainer(Enum): + """Optimization containers""" + + TRT = "TRT" + VLLM = "vLLM" + NEURON = "Neuron" + + +class _OptimizationCombination(BaseModel): + """Optimization ruleset data structure for comparing input to ruleset""" + + optimization_container: _OptimizationContainer = None + compilation: Set[Optional[bool]] + speculative_decoding: Set[Optional[bool]] + sharding: Set[Optional[bool]] + quantization_technique: Set[Optional[str]] + + def validate_against(self, optimization_combination, rule_set: _OptimizationContainer): + """Validator for optimization containers""" + + # check the validity of each individual field + if not optimization_combination.compilation.issubset(self.compilation): + raise ValueError("Compilation") + if not optimization_combination.quantization_technique.issubset( + self.quantization_technique + ): + copy_quantization_technique = optimization_combination.quantization_technique.copy() + raise ValueError(f"Quantization:{copy_quantization_technique.pop()}") + if not optimization_combination.speculative_decoding.issubset(self.speculative_decoding): + raise ValueError("Speculative Decoding") + if not optimization_combination.sharding.issubset(self.sharding): + raise ValueError("Sharding") + + # optimization technique combinations that need to be validated + if optimization_combination.compilation and optimization_combination.speculative_decoding: + is_compiled = optimization_combination.compilation.copy().pop() + is_speculative_decoding = optimization_combination.speculative_decoding.copy().pop() + if is_compiled and is_speculative_decoding: + raise ValueError("Compilation and Speculative Decoding together") + + if rule_set == _OptimizationContainer.TRT: + is_compiled = optimization_combination.compilation.copy().pop() + is_quantized = optimization_combination.quantization_technique.copy().pop() + if is_quantized and not is_compiled: + raise ValueError(f"Quantization:{is_quantized} must be provided with Compilation") + + +TRUTHY_SET = {None, True} +FALSY_SET = {None, False} +TRT_CONFIGURATION = { + "supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"}, + "optimization_combination": _OptimizationCombination( + optimization_container=_OptimizationContainer.TRT, + compilation=TRUTHY_SET, + quantization_technique={None, "awq", "fp8", "smoothquant"}, + speculative_decoding=FALSY_SET, + sharding=FALSY_SET, + ), +} +VLLM_CONFIGURATION = { + "supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"}, + "optimization_combination": _OptimizationCombination( + optimization_container=_OptimizationContainer.VLLM, + compilation=FALSY_SET, + quantization_technique={None, "awq", "fp8"}, + speculative_decoding=TRUTHY_SET, + sharding=TRUTHY_SET, + ), +} +NEURON_CONFIGURATION = { + "supported_instance_families": {"inf2", "trn1", "trn1n"}, + "optimization_combination": _OptimizationCombination( + optimization_container=_OptimizationContainer.NEURON, + compilation=TRUTHY_SET, + quantization_technique={None}, + speculative_decoding=FALSY_SET, + sharding=FALSY_SET, + ), +} + + +def _validate_optimization_configuration( + is_jumpstart: bool, + instance_type: str, + quantization_config: Dict[str, Any], + compilation_config: Dict[str, Any], + sharding_config: Dict[str, Any], + speculative_decoding_config: Dict[str, Any], +): + """Validate .optimize() input off of standard ruleset""" + + instance_family = None + if instance_type: + split_instance_type = instance_type.split(".") + if len(split_instance_type) == 3: + instance_family = split_instance_type[1] + + if ( + instance_family not in TRT_CONFIGURATION["supported_instance_families"] + and instance_family not in VLLM_CONFIGURATION["supported_instance_families"] + and instance_family not in NEURON_CONFIGURATION["supported_instance_families"] + ): + invalid_instance_type_msg = ( + f"Optimizations that uses {instance_type} instance type are " + "not currently supported both on GPU and Neuron instances" + ) + raise ValueError(invalid_instance_type_msg) + + quantization_technique = None + if ( + quantization_config + and quantization_config.get("OverrideEnvironment") + and quantization_config.get("OverrideEnvironment").get("OPTION_QUANTIZE") + ): + quantization_technique = quantization_config.get("OverrideEnvironment").get( + "OPTION_QUANTIZE" + ) + + optimization_combination = _OptimizationCombination( + compilation={None if compilation_config is None else True}, + speculative_decoding={None if speculative_decoding_config is None else True}, + sharding={None if sharding_config is None else True}, + quantization_technique={quantization_technique}, + ) + + # Check the case where no optimization combination is provided + if ( + optimization_combination.compilation == {None} + and optimization_combination.quantization_technique == {None} + and optimization_combination.speculative_decoding == {None} + and optimization_combination.sharding == {None} + ): + # JumpStart has defaults for Inf/Trn instances + if is_jumpstart and instance_family in NEURON_CONFIGURATION["supported_instance_families"]: + return + raise ValueError( + ( + "Optimizations that provide no optimization configs " + "are currently not support on both GPU and Neuron instances." + ) + ) + + # Validate based off of instance type + if instance_family in NEURON_CONFIGURATION["supported_instance_families"]: + try: + ( + NEURON_CONFIGURATION["optimization_combination"].validate_against( + optimization_combination, rule_set=_OptimizationContainer.NEURON + ) + ) + except ValueError as neuron_compare_error: + raise ValueError( + ( + f"Optimizations that use {neuron_compare_error} " + "are not supported on Neuron instances." + ) + ) + else: + if optimization_combination.compilation.copy().pop(): # Compilation is only enabled for TRT + try: + TRT_CONFIGURATION["optimization_combination"].validate_against( + optimization_combination, rule_set=_OptimizationContainer.TRT + ) + except ValueError as trt_compare_error: + raise ValueError( + ( + f"Optimizations that use Compilation and {trt_compare_error} " + "are not supported for GPU instances." + ) + ) + else: + try: + ( + VLLM_CONFIGURATION["optimization_combination"].validate_against( + optimization_combination, rule_set=_OptimizationContainer.VLLM + ) + ) + except ValueError as vllm_compare_error: + try: # try both VLLM and TRT to cover both rule sets + ( + TRT_CONFIGURATION["optimization_combination"].validate_against( + optimization_combination, rule_set=_OptimizationContainer.TRT + ) + ) + except ValueError as trt_compare_error: + if ( + str(trt_compare_error) + == "Quantization:smoothquant must be provided with Compilation" + ): + raise ValueError( + f"Optimizations that use {trt_compare_error} for GPU instances." + ) + if str(trt_compare_error) == str(vllm_compare_error): + raise ValueError( + ( + f"Optimizations that use {trt_compare_error} " + "are not supported for GPU instances." + ) + ) + joint_error_msg = f""" + Optimization cannot be performed for the following reasons: + - Optimizations that use {trt_compare_error} are not supported for GPU instances. + - Optimizations that use {vllm_compare_error} are not supported for GPU instances. + """ + raise ValueError(textwrap.dedent(joint_error_msg)) diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index e43ad0ed0a..316df7420d 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -959,6 +959,56 @@ def test_all_framework_models_inference_component_based_endpoint_deploy_path( sagemaker_session.create_model.reset_mock() +@patch("sagemaker.utils.repack_model") +@patch("sagemaker.fw_utils.tar_and_upload_dir") +def test_sharded_model_force_inference_component_based_endpoint_deploy_path( + repack_model, tar_and_uload_dir, sagemaker_session +): + framework_model_classes_to_kwargs = { + HuggingFaceModel: { + "pytorch_version": "1.7.1", + "py_version": "py36", + "transformers_version": "4.6.1", + }, + } + + sagemaker_session.settings = SessionSettings(include_jumpstart_tags=False) + + source_dir = "s3://blah/blah/blah" + for framework_model_class, kwargs in framework_model_classes_to_kwargs.items(): + test_sharded_model = framework_model_class( + entry_point=ENTRY_POINT_INFERENCE, + role=ROLE, + sagemaker_session=sagemaker_session, + model_data=source_dir, + **kwargs, + ) + test_sharded_model._is_sharded_model = True + test_sharded_model.deploy( + instance_type="ml.m2.xlarge", + initial_instance_count=INSTANCE_COUNT, + endpoint_type=EndpointType.MODEL_BASED, + resources=ResourceRequirements( + requests={ + "num_accelerators": 1, + "memory": 8192, + "copies": 1, + }, + limits={}, + ), + ) + + # Verified inference component based endpoint and inference component creation + # path + sagemaker_session.endpoint_in_service_or_not.assert_called_once() + sagemaker_session.create_model.assert_called_once() + sagemaker_session.create_inference_component.assert_called_once() + + sagemaker_session.create_inference_component.reset_mock() + sagemaker_session.endpoint_in_service_or_not.reset_mock() + sagemaker_session.create_model.reset_mock() + + @patch("sagemaker.utils.repack_model") def test_repack_code_location_with_key_prefix(repack_model, sagemaker_session): diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 2da09aece3..4e34c5f864 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -11,12 +11,14 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import -from unittest.mock import MagicMock, patch, Mock, mock_open + +from unittest.mock import MagicMock, patch, Mock, mock_open, ANY import unittest from pathlib import Path from copy import deepcopy +from sagemaker.model import Model from sagemaker.serve import SchemaBuilder from sagemaker.serve.builder.model_builder import ModelBuilder from sagemaker.serve.mode.function_pointers import Mode @@ -25,6 +27,7 @@ from sagemaker.serve.utils.exceptions import TaskNotFoundException from sagemaker.serve.utils.predictors import TensorflowServingLocalPredictor from sagemaker.serve.utils.types import ModelServer +from sagemaker.serve.validations.optimization import _validate_optimization_configuration from tests.unit.sagemaker.serve.constants import MOCK_IMAGE_CONFIG, MOCK_VPC_CONFIG schema_builder = MagicMock() @@ -2383,11 +2386,11 @@ def test_optimize( builder.pysdk_model = pysdk_model job_name = "my-optimization-job" - instance_type = "ml.inf1.xlarge" + instance_type = "ml.g5.24xlarge" output_path = "s3://my-bucket/output" quantization_config = { "Image": "quantization-image-uri", - "OverrideEnvironment": {"ENV_VAR": "value"}, + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, } env_vars = {"Var1": "value", "Var2": "value"} kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id" @@ -2425,7 +2428,7 @@ def test_optimize( mock_send_telemetry.assert_called_once() mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( OptimizationJobName="my-optimization-job", - DeploymentInstanceType="ml.inf1.xlarge", + DeploymentInstanceType="ml.g5.24xlarge", RoleArn="arn:aws:iam::123456789012:role/SageMakerRole", OptimizationEnvironment={"Var1": "value", "Var2": "value"}, ModelSource={"S3": {"S3Uri": "s3://uri"}}, @@ -2433,7 +2436,7 @@ def test_optimize( { "ModelQuantizationConfig": { "Image": "quantization-image-uri", - "OverrideEnvironment": {"ENV_VAR": "value"}, + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, } } ], @@ -2646,7 +2649,8 @@ def test_optimize_local_mode(self, mock_get_serve_setting): ValueError, "Model optimization is only supported in Sagemaker Endpoint Mode.", lambda: model_builder.optimize( - quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}} + instance_type="ml.g5.24xlarge", + quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}, ), ) @@ -2721,6 +2725,42 @@ def test_optimize_for_hf_with_both_quantization_and_compilation( }, ) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_exclusive_sharding(self, mock_get_serve_setting): + mock_sagemaker_session = Mock() + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + sagemaker_session=mock_sagemaker_session, + ) + + self.assertRaisesRegex( + ValueError, + "Optimizations that use Compilation and Sharding are not supported for GPU instances.", + lambda: model_builder.optimize( + instance_type="ml.g5.24xlarge", + quantization_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}, + compilation_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}, + sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}, + ), + ) + + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + def test_optimize_exclusive_sharding_args(self, mock_get_serve_setting): + mock_sagemaker_session = Mock() + model_builder = ModelBuilder( + model="meta-textgeneration-llama-3-70b", + sagemaker_session=mock_sagemaker_session, + ) + + self.assertRaisesRegex( + ValueError, + "OPTION_TENSOR_PARALLEL_DEGREE is a required environment variable with sharding config.", + lambda: model_builder.optimize( + instance_type="ml.g5.24xlarge", + sharding_config={"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}}, + ), + ) + @patch.object(ModelBuilder, "_prepare_for_mode") @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) def test_optimize_for_hf_with_custom_s3_path( @@ -2887,6 +2927,7 @@ def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation( "Compilation is not supported for Llama-3.1 with a GPU instance.", lambda: model_builder.optimize( job_name="job_name-123", + instance_type="ml.g5.24xlarge", compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}}, output_path="s3://bucket/code/", ), @@ -2935,9 +2976,10 @@ def test_optimize_with_gpu_instance_and_compilation_with_speculative_decoding( self.assertRaisesRegex( ValueError, - "Compilation is not supported with speculative decoding with a GPU instance.", + "Optimizations that use Compilation and Speculative Decoding are not supported for GPU instances.", lambda: model_builder.optimize( job_name="job_name-123", + instance_type="ml.g5.24xlarge", speculative_decoding_config={ "ModelProvider": "custom", "ModelSource": "s3://data-source", @@ -2946,3 +2988,678 @@ def test_optimize_with_gpu_instance_and_compilation_with_speculative_decoding( output_path="s3://bucket/code/", ), ) + + +class TestModelBuilderOptimizationSharding(unittest.TestCase): + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_build_for_djl") + @patch.object(ModelBuilder, "_is_jumpstart_model_id", return_value=False) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + def test_optimize_sharding_with_env_vars( + self, + mock_send_telemetry, + mock_get_serve_setting, + mock_is_jumpstart_model_id, + mock_build_for_djl, + mock_prepare_for_mode, + ): + mock_sagemaker_session = Mock() + + mock_settings = Mock() + mock_settings.telemetry_opt_out = False + mock_get_serve_setting.return_value = mock_settings + + pysdk_model = Mock() + pysdk_model.env = {"key": "val"} + pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None + + mock_build_for_djl.side_effect = lambda **kwargs: pysdk_model + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "S3Uri": "s3://uri", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + {"key": "val"}, + ) + + builder = ModelBuilder( + schema_builder=SchemaBuilder( + sample_input={"inputs": "Hello", "parameters": {}}, + sample_output=[{"generated_text": "Hello"}], + ), + model="meta-llama/Meta-Llama-3-8B", + sagemaker_session=mock_sagemaker_session, + env_vars={"HF_TOKEN": "token"}, + model_metadata={"CUSTOM_MODEL_PATH": "/tmp/modelbuilders/code"}, + ) + builder.pysdk_model = pysdk_model + + job_name = "my-optimization-job" + instance_type = "ml.g5.24xlarge" + output_path = "s3://my-bucket/output" + sharding_config = {"key": "value"} + env_vars = {"OPTION_TENSOR_PARALLEL_DEGREE": "1"} + kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id" + max_runtime_in_sec = 3600 + tags = [ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ] + vpc_config = { + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + } + + mock_sagemaker_session.wait_for_optimization_job.side_effect = lambda *args, **kwargs: { + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job", + "OptimizationJobName": "my-optimization-job", + } + + # With override + builder.optimize( + instance_type=instance_type, + output_path=output_path, + role_arn=mock_role_arn, + job_name=job_name, + sharding_config=sharding_config, + env_vars=env_vars, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + tags=tags, + vpc_config=vpc_config, + ) + + self.assertEqual(builder.env_vars["HUGGING_FACE_HUB_TOKEN"], "token") + self.assertEqual(builder.model_server, ModelServer.DJL_SERVING) + + mock_send_telemetry.assert_called_once() + mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( + OptimizationJobName="my-optimization-job", + DeploymentInstanceType="ml.g5.24xlarge", + RoleArn="arn:aws:iam::123456789012:role/SageMakerRole", + OptimizationEnvironment={"OPTION_TENSOR_PARALLEL_DEGREE": "1"}, + ModelSource={"S3": {"S3Uri": "s3://uri"}}, + OptimizationConfigs=[{"ModelShardingConfig": {"key": "value"}}], + OutputConfig={ + "S3OutputLocation": "s3://my-bucket/output", + "KmsKeyId": "arn:aws:kms:us-west-2:123456789012:key/my-key-id", + }, + StoppingCondition={"MaxRuntimeInSeconds": 3600}, + Tags=[ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ], + VpcConfig={ + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + }, + ) + + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_build_for_djl") + @patch.object(ModelBuilder, "_is_jumpstart_model_id", return_value=False) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + def test_optimize_sharding_with_override_and_env_var( + self, + mock_send_telemetry, + mock_get_serve_setting, + mock_is_jumpstart_model_id, + mock_build_for_djl, + mock_prepare_for_mode, + ): + mock_sagemaker_session = Mock() + + mock_settings = Mock() + mock_settings.telemetry_opt_out = False + mock_get_serve_setting.return_value = mock_settings + + pysdk_model = Mock() + pysdk_model.env = {"key": "val"} + pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None + + mock_build_for_djl.side_effect = lambda **kwargs: pysdk_model + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "S3Uri": "s3://uri", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + {"key": "val"}, + ) + + builder = ModelBuilder( + schema_builder=SchemaBuilder( + sample_input={"inputs": "Hello", "parameters": {}}, + sample_output=[{"generated_text": "Hello"}], + ), + model="meta-llama/Meta-Llama-3-8B", + sagemaker_session=mock_sagemaker_session, + env_vars={"HF_TOKEN": "token"}, + model_metadata={"CUSTOM_MODEL_PATH": "/tmp/modelbuilders/code"}, + ) + builder.pysdk_model = pysdk_model + + job_name = "my-optimization-job" + instance_type = "ml.g5.24xlarge" + output_path = "s3://my-bucket/output" + sharding_config = {"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"}} + env_vars = {"OPTION_TENSOR_PARALLEL_DEGREE": "1"} + kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id" + max_runtime_in_sec = 3600 + tags = [ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ] + vpc_config = { + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + } + + mock_sagemaker_session.wait_for_optimization_job.side_effect = lambda *args, **kwargs: { + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job", + "OptimizationJobName": "my-optimization-job", + } + + # With override + builder.optimize( + instance_type=instance_type, + output_path=output_path, + role_arn=mock_role_arn, + job_name=job_name, + sharding_config=sharding_config, + env_vars=env_vars, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + tags=tags, + vpc_config=vpc_config, + ) + + self.assertEqual(builder.env_vars["HUGGING_FACE_HUB_TOKEN"], "token") + self.assertEqual(builder.model_server, ModelServer.DJL_SERVING) + + mock_send_telemetry.assert_called_once() + mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( + OptimizationJobName="my-optimization-job", + DeploymentInstanceType="ml.g5.24xlarge", + RoleArn="arn:aws:iam::123456789012:role/SageMakerRole", + OptimizationEnvironment={"OPTION_TENSOR_PARALLEL_DEGREE": "1"}, + ModelSource={"S3": {"S3Uri": "s3://uri"}}, + OptimizationConfigs=[ + { + "ModelShardingConfig": { + "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"} + } + } + ], + OutputConfig={ + "S3OutputLocation": "s3://my-bucket/output", + "KmsKeyId": "arn:aws:kms:us-west-2:123456789012:key/my-key-id", + }, + StoppingCondition={"MaxRuntimeInSeconds": 3600}, + Tags=[ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ], + VpcConfig={ + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + }, + ) + + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_build_for_djl") + @patch.object(ModelBuilder, "_is_jumpstart_model_id", return_value=False) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + def test_optimize_sharding_with_override( + self, + mock_send_telemetry, + mock_get_serve_setting, + mock_is_jumpstart_model_id, + mock_build_for_djl, + mock_prepare_for_mode, + ): + mock_sagemaker_session = Mock() + + mock_settings = Mock() + mock_settings.telemetry_opt_out = False + mock_get_serve_setting.return_value = mock_settings + + pysdk_model = Mock() + pysdk_model.env = {"key": "val"} + pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None + + mock_build_for_djl.side_effect = lambda **kwargs: pysdk_model + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "S3Uri": "s3://uri", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + {"key": "val"}, + ) + + builder = ModelBuilder( + schema_builder=SchemaBuilder( + sample_input={"inputs": "Hello", "parameters": {}}, + sample_output=[{"generated_text": "Hello"}], + ), + model="meta-llama/Meta-Llama-3-8B", + sagemaker_session=mock_sagemaker_session, + env_vars={"HF_TOKEN": "token"}, + model_metadata={"CUSTOM_MODEL_PATH": "/tmp/modelbuilders/code"}, + ) + builder.pysdk_model = pysdk_model + + job_name = "my-optimization-job" + instance_type = "ml.g5.24xlarge" + output_path = "s3://my-bucket/output" + sharding_config = {"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"}} + env_vars = {"Var1": "value", "Var2": "value"} + kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id" + max_runtime_in_sec = 3600 + tags = [ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ] + vpc_config = { + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + } + + mock_sagemaker_session.wait_for_optimization_job.side_effect = lambda *args, **kwargs: { + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job", + "OptimizationJobName": "my-optimization-job", + } + + # With override + builder.optimize( + instance_type=instance_type, + output_path=output_path, + role_arn=mock_role_arn, + job_name=job_name, + sharding_config=sharding_config, + env_vars=env_vars, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + tags=tags, + vpc_config=vpc_config, + ) + + self.assertEqual(builder.env_vars["HUGGING_FACE_HUB_TOKEN"], "token") + self.assertEqual(builder.model_server, ModelServer.DJL_SERVING) + + mock_send_telemetry.assert_called_once() + mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( + OptimizationJobName="my-optimization-job", + DeploymentInstanceType="ml.g5.24xlarge", + RoleArn="arn:aws:iam::123456789012:role/SageMakerRole", + OptimizationEnvironment={"Var1": "value", "Var2": "value"}, + ModelSource={"S3": {"S3Uri": "s3://uri"}}, + OptimizationConfigs=[ + { + "ModelShardingConfig": { + "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"} + } + } + ], + OutputConfig={ + "S3OutputLocation": "s3://my-bucket/output", + "KmsKeyId": "arn:aws:kms:us-west-2:123456789012:key/my-key-id", + }, + StoppingCondition={"MaxRuntimeInSeconds": 3600}, + Tags=[ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ], + VpcConfig={ + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + }, + ) + + # squeeze in some validations + with self.assertRaises(ValueError): + builder.enable_network_isolation = True + builder.optimize(sharding_config={}) + + @patch.object(ModelBuilder, "_prepare_for_mode") + @patch.object(ModelBuilder, "_build_for_jumpstart") + @patch.object(ModelBuilder, "_is_jumpstart_model_id", return_value=True) + @patch.object(ModelBuilder, "_get_serve_setting", autospec=True) + @patch("sagemaker.serve.utils.telemetry_logger._send_telemetry") + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_gated_model", return_value=False + ) + @patch( + "sagemaker.serve.builder.jumpstart_builder.JumpStart._find_compatible_deployment_config", + return_value=Mock(), + ) + def test_optimize_sharding_with_override_for_js( + self, + mock_find_compatible_deployment_config, + mock_is_gated_model, + mock_send_telemetry, + mock_get_serve_setting, + mock_is_jumpstart_model_id, + mock_build_for_jumpstart, + mock_prepare_for_mode, + ): + mock_sagemaker_session = Mock() + + mock_settings = Mock() + mock_settings.telemetry_opt_out = False + mock_get_serve_setting.return_value = mock_settings + + pysdk_model = Mock() + pysdk_model.env = {"key": "val"} + pysdk_model._enable_network_isolation = True + pysdk_model.add_tags.side_effect = lambda *arg, **kwargs: None + + mock_build_for_jumpstart.side_effect = lambda **kwargs: pysdk_model + mock_prepare_for_mode.side_effect = lambda *args, **kwargs: ( + { + "S3DataSource": { + "S3Uri": "s3://uri", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + }, + {"key": "val"}, + ) + + builder = ModelBuilder( + schema_builder=SchemaBuilder( + sample_input={"inputs": "Hello", "parameters": {}}, + sample_output=[{"generated_text": "Hello"}], + ), + model="meta-llama/Meta-Llama-3-8B", + sagemaker_session=mock_sagemaker_session, + env_vars={"HF_TOKEN": "token"}, + model_metadata={"CUSTOM_MODEL_PATH": "/tmp/modelbuilders/code"}, + ) + builder.pysdk_model = pysdk_model + + job_name = "my-optimization-job" + instance_type = "ml.g5.24xlarge" + output_path = "s3://my-bucket/output" + sharding_config = {"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"}} + env_vars = {"Var1": "value", "Var2": "value"} + kms_key = "arn:aws:kms:us-west-2:123456789012:key/my-key-id" + max_runtime_in_sec = 3600 + tags = [ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ] + vpc_config = { + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + } + + mock_sagemaker_session.wait_for_optimization_job.side_effect = lambda *args, **kwargs: { + "OptimizationJobArn": "arn:aws:sagemaker:us-west-2:123456789012:optimization-job/my-optimization-job", + "OptimizationJobName": "my-optimization-job", + } + + # With override + model = builder.optimize( + instance_type=instance_type, + output_path=output_path, + role_arn=mock_role_arn, + job_name=job_name, + sharding_config=sharding_config, + env_vars=env_vars, + kms_key=kms_key, + max_runtime_in_sec=max_runtime_in_sec, + tags=tags, + vpc_config=vpc_config, + ) + + self.assertEqual(builder.env_vars["HUGGING_FACE_HUB_TOKEN"], "token") + + mock_send_telemetry.assert_called_once() + mock_sagemaker_session.sagemaker_client.create_optimization_job.assert_called_once_with( + OptimizationJobName="my-optimization-job", + ModelSource={"S3": {"S3Uri": ANY}}, + DeploymentInstanceType="ml.g5.24xlarge", + OptimizationConfigs=[ + { + "ModelShardingConfig": { + "OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "1"} + } + } + ], + OutputConfig={ + "S3OutputLocation": "s3://my-bucket/output", + "KmsKeyId": "arn:aws:kms:us-west-2:123456789012:key/my-key-id", + }, + RoleArn="arn:aws:iam::123456789012:role/SageMakerRole", + OptimizationEnvironment={ + "key": "val", + "Var1": "value", + "Var2": "value", + "OPTION_TENSOR_PARALLEL_DEGREE": "1", + }, + StoppingCondition={"MaxRuntimeInSeconds": 3600}, + Tags=[ + {"Key": "Project", "Value": "my-project"}, + {"Key": "Environment", "Value": "production"}, + ], + VpcConfig={ + "SecurityGroupIds": ["sg-01234567890abcdef", "sg-fedcba9876543210"], + "Subnets": ["subnet-01234567", "subnet-89abcdef"], + }, + ) + + assert not model._enable_network_isolation + + def test_model_sharding_with_eni_fails(self): + test_model = Model(role="mock role") + test_model._is_sharded_model = True + test_model._enable_network_isolation = True + self.assertRaisesRegex( + ValueError, + ( + "EnableNetworkIsolation cannot be set to True since " + "SageMaker Fast Model Loading of model requires network access." + ), + lambda: test_model.deploy(initial_instance_count=1, instance_type="ml.g5.24xlarge"), + ) + + +class TestModelBuilderOptimizeValidations(unittest.TestCase): + + def test_corner_cases_throw_errors(self): + self.assertRaisesRegex( + ValueError, + "Optimizations that uses None instance type are not currently supported", + lambda: _validate_optimization_configuration( + is_jumpstart=False, + sharding_config={"key": "value"}, + instance_type=None, + quantization_config=None, + speculative_decoding_config=None, + compilation_config=None, + ), + ) + + self.assertRaisesRegex( + ValueError, + ( + "Optimizations that provide no optimization configs " + "are currently not support on both GPU and Neuron instances." + ), + lambda: _validate_optimization_configuration( + is_jumpstart=False, + instance_type="ml.g5.24xlarge", + quantization_config=None, + speculative_decoding_config=None, + compilation_config=None, + sharding_config=None, + ), + ) + + _validate_optimization_configuration( + is_jumpstart=True, + instance_type="ml.inf2.xlarge", + quantization_config=None, + speculative_decoding_config=None, + compilation_config=None, + sharding_config=None, + ) + + def test_trt_and_vllm_configurations_throw_errors_for_rule_set(self): + # Quantization:smoothquant without compilation + self.assertRaisesRegex( + ValueError, + "Optimizations that use Quantization:smoothquant must be provided with Compilation for GPU instances.", + lambda: _validate_optimization_configuration( + is_jumpstart=False, + instance_type="ml.g5.24xlarge", + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "smoothquant"}, + }, + sharding_config=None, + speculative_decoding_config=None, + compilation_config=None, + ), + ) + + # Invalid quantization technique + self.assertRaisesRegex( + ValueError, + "Optimizations that use Quantization:test are not supported for GPU instances.", + lambda: _validate_optimization_configuration( + is_jumpstart=False, + instance_type="ml.g5.24xlarge", + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "test"}, + }, + sharding_config=None, + speculative_decoding_config=None, + compilation_config=None, + ), + ) + + def test_neuron_configurations_throw_errors_for_rule_set(self): + self.assertRaisesRegex( + ValueError, + "Optimizations that use Speculative Decoding are not supported on Neuron instances.", + lambda: _validate_optimization_configuration( + is_jumpstart=False, + instance_type="ml.inf2.xlarge", + quantization_config=None, + speculative_decoding_config={"key": "value"}, + compilation_config=None, + sharding_config=None, + ), + ) + + self.assertRaisesRegex( + ValueError, + "Optimizations that use Sharding are not supported on Neuron instances.", + lambda: _validate_optimization_configuration( + is_jumpstart=False, + instance_type="ml.inf2.xlarge", + quantization_config=None, + speculative_decoding_config=None, + compilation_config=None, + sharding_config={"key": "value"}, + ), + ) + + def test_trt_configurations_rule_set(self): + # Can be compiled with quantization + _validate_optimization_configuration( + is_jumpstart=False, + instance_type="ml.g5.24xlarge", + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "smoothquant"}, + }, + sharding_config=None, + speculative_decoding_config=None, + compilation_config={"key": "value"}, + ), + + # Can be just compiled + _validate_optimization_configuration( + is_jumpstart=False, + instance_type="ml.g5.24xlarge", + quantization_config=None, + sharding_config=None, + speculative_decoding_config=None, + compilation_config={"key": "value"}, + ) + + # Can be just compiled with empty dict + _validate_optimization_configuration( + is_jumpstart=False, + instance_type="ml.g5.24xlarge", + quantization_config=None, + sharding_config=None, + speculative_decoding_config=None, + compilation_config={}, + ) + + def test_vllm_configurations_rule_set(self): + # Can use speculative decoding + _validate_optimization_configuration( + is_jumpstart=False, + instance_type="ml.g5.24xlarge", + quantization_config=None, + sharding_config=None, + speculative_decoding_config={"key": "value"}, + compilation_config=None, + ) + + # Can be quantized + _validate_optimization_configuration( + is_jumpstart=False, + instance_type="ml.g5.24xlarge", + quantization_config={ + "OverrideEnvironment": {"OPTION_QUANTIZE": "awq"}, + }, + sharding_config=None, + speculative_decoding_config=None, + compilation_config=None, + ) + + # Can be sharded + _validate_optimization_configuration( + is_jumpstart=False, + instance_type="ml.g5.24xlarge", + quantization_config=None, + sharding_config={"key": "value"}, + speculative_decoding_config=None, + compilation_config=None, + ) + + def test_neuron_configurations_rule_set(self): + # Can be compiled + _validate_optimization_configuration( + is_jumpstart=False, + instance_type="ml.inf2.xlarge", + quantization_config=None, + sharding_config=None, + speculative_decoding_config=None, + compilation_config={"key": "value"}, + ) + + # Can be compiled with empty dict + _validate_optimization_configuration( + is_jumpstart=False, + instance_type="ml.inf2.xlarge", + quantization_config=None, + sharding_config=None, + speculative_decoding_config=None, + compilation_config={}, + ) diff --git a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py index 7cf0406f42..b392b255da 100644 --- a/tests/unit/sagemaker/serve/utils/test_optimize_utils.py +++ b/tests/unit/sagemaker/serve/utils/test_optimize_utils.py @@ -284,7 +284,10 @@ def test_is_draft_model_gated(draft_model_config, expected): @pytest.mark.parametrize( - "quantization_config, compilation_config, expected_config, expected_quant_env, expected_compilation_env", + ( + "quantization_config, compilation_config, sharding_config, expected_config, " + "expected_quant_env, expected_compilation_env, expected_sharding_env" + ), [ ( None, @@ -293,6 +296,7 @@ def test_is_draft_model_gated(draft_model_config, expected): "OPTION_TENSOR_PARALLEL_DEGREE": "2", } }, + None, { "ModelCompilationConfig": { "OverrideEnvironment": { @@ -304,6 +308,7 @@ def test_is_draft_model_gated(draft_model_config, expected): { "OPTION_TENSOR_PARALLEL_DEGREE": "2", }, + None, ), ( { @@ -312,6 +317,7 @@ def test_is_draft_model_gated(draft_model_config, expected): } }, None, + None, { "ModelQuantizationConfig": { "OverrideEnvironment": { @@ -323,21 +329,48 @@ def test_is_draft_model_gated(draft_model_config, expected): "OPTION_TENSOR_PARALLEL_DEGREE": "2", }, None, + None, + ), + ( + None, + None, + { + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + } + }, + { + "ModelShardingConfig": { + "OverrideEnvironment": { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + } + }, + }, + None, + None, + { + "OPTION_TENSOR_PARALLEL_DEGREE": "2", + }, ), - (None, None, None, None, None), + (None, None, None, None, None, None, None), ], ) def test_extract_optimization_config_and_env( quantization_config, compilation_config, + sharding_config, expected_config, expected_quant_env, expected_compilation_env, + expected_sharding_env, ): - assert _extract_optimization_config_and_env(quantization_config, compilation_config) == ( + assert _extract_optimization_config_and_env( + quantization_config, compilation_config, sharding_config + ) == ( expected_config, expected_quant_env, expected_compilation_env, + expected_sharding_env, )