diff --git a/sharktank/conftest.py b/sharktank/conftest.py index 459272d5d..1387b0611 100644 --- a/sharktank/conftest.py +++ b/sharktank/conftest.py @@ -6,7 +6,8 @@ from pathlib import Path import pytest -from typing import Optional +from pytest import FixtureRequest +from typing import Optional, Any # Tests under each top-level directory will get a mark. @@ -47,6 +48,15 @@ def pytest_addoption(parser): default=None, help="Exported model parameters. If not specified a temporary file will be used.", ) + parser.addoption( + "--prefix", + type=str, + default=None, + help=( + "Path prefix for test artifacts. " + "Other arguments may override this for specific values." + ), + ) parser.addoption( "--caching", action="store_true", @@ -55,21 +65,40 @@ def pytest_addoption(parser): ) -@pytest.fixture(scope="session") -def mlir_path(pytestconfig: pytest.Config) -> Optional[Path]: - return pytestconfig.getoption("mlir") +def set_fixture_from_cli_option( + request: FixtureRequest, + cli_option_name: str, + class_attribute_name: Optional[str] = None, +) -> Optional[Any]: + res = request.config.getoption(cli_option_name) + if request.cls is None: + return res + else: + if class_attribute_name is None: + class_attribute_name = cli_option_name + setattr(request.cls, class_attribute_name, res) + + +@pytest.fixture(scope="class") +def mlir_path(request: FixtureRequest) -> Optional[Path]: + return set_fixture_from_cli_option(request, "mlir", "mlir_path") + + +@pytest.fixture(scope="class") +def module_path(request: FixtureRequest) -> Optional[Path]: + return set_fixture_from_cli_option(request, "module", "module_path") -@pytest.fixture(scope="session") -def module_path(pytestconfig: pytest.Config) -> Optional[Path]: - return pytestconfig.getoption("module") +@pytest.fixture(scope="class") +def parameters_path(request: FixtureRequest) -> Optional[Path]: + return set_fixture_from_cli_option(request, "parameters", "parameters_path") -@pytest.fixture(scope="session") -def parameters_path(pytestconfig: pytest.Config) -> Optional[Path]: - return pytestconfig.getoption("parameters") +@pytest.fixture(scope="class") +def path_prefix(request: FixtureRequest) -> Optional[str]: + return set_fixture_from_cli_option(request, "prefix", "path_prefix") -@pytest.fixture(scope="session") -def caching(pytestconfig: pytest.Config) -> Optional[Path]: - return pytestconfig.getoption("caching") +@pytest.fixture(scope="class") +def caching(request: FixtureRequest) -> Optional[bool]: + return set_fixture_from_cli_option(request, "caching") diff --git a/sharktank/sharktank/types/tensors.py b/sharktank/sharktank/types/tensors.py index 200800d44..324cc4331 100644 --- a/sharktank/sharktank/types/tensors.py +++ b/sharktank/sharktank/types/tensors.py @@ -26,6 +26,7 @@ import torch from torch import Tensor from torch.utils._pytree import register_pytree_node, SequenceKey +import torch.utils._pytree from ..utils.math import ceildiv from iree.turbine.aot import ( ExternalTensorTrait, @@ -48,6 +49,7 @@ "ReplicatedTensor", "ShardedTensor", "SplitPrimitiveTensor", + "torch_tree_flatten", "unbox_tensor", "UnreducedTensor", ] @@ -1360,3 +1362,9 @@ def flatten_with_keys_replicated_tensor(t: ReplicatedTensor): unflatten_fn=unflatten_replicated_tensor, flatten_with_keys_fn=flatten_with_keys_replicated_tensor, ) + + +def torch_tree_flatten(tree: tree_utils.Tree): + """Flatten a tree of tensors the same way they will be flattened during torch.export.export + if they are arguments or results of a function signature.""" + return 
torch.utils._pytree.tree_flatten(tree=tree) diff --git a/sharktank/sharktank/utils/iree.py b/sharktank/sharktank/utils/iree.py new file mode 100644 index 000000000..7c666ff62 --- /dev/null +++ b/sharktank/sharktank/utils/iree.py @@ -0,0 +1,189 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import iree.runtime +from typing import List, Tuple, Optional, Union +from pathlib import Path +import torch +import numpy as np +import collections.abc +from collections import OrderedDict +from ..types.tensors import ( + AnyTensor, + InferenceTensor, + ShardedTensor, + DefaultPrimitiveTensor, + unbox_tensor, + torch_tree_flatten, +) +from .tree import Tree + + +def get_iree_devices(driver: str, device_count: int) -> List[iree.runtime.HalDevice]: + hal_driver = iree.runtime.get_driver(driver) + available_devices = hal_driver.query_available_devices() + if driver in ["local-task", "local-sync"]: + # Use the same actual device for all devices. + return [ + hal_driver.create_device(available_devices[0]) for _ in range(device_count) + ] + else: + return [ + hal_driver.create_device(available_devices[i]) for i in range(device_count) + ] + + +def load_iree_module( + module_path: str, + devices: List[iree.runtime.HalDevice], + parameters_path: Optional[str] = None, +) -> Tuple[iree.runtime.VmModule, iree.runtime.VmContext, iree.runtime.VmInstance]: + """The VmContext and VmInstance need to outlive the VmModule and any device + buffers.""" + vm_instance = iree.runtime.VmInstance() + hal_module = iree.runtime.create_hal_module(instance=vm_instance, devices=devices) + modules = [hal_module] + if parameters_path is not None: + params_path = Path(parameters_path) + parameter_index = iree.runtime.ParameterIndex() + if len(devices) > 1: + # TODO: make IREE able to load the parameters from the top parameter file + # without having to specify the parameter file for each shard separately. + for i in range(len(devices)): + parameter_index.load( + file_path=str( + Path(params_path).with_suffix(f".rank{i}{params_path.suffix}") + ) + ) + else: + parameter_index.load(file_path=str(params_path)) + parameter_provider = parameter_index.create_provider(scope="model") + parameters_module = iree.runtime.create_io_parameters_module( + vm_instance, parameter_provider + ) + modules.append(parameters_module) + vm_module = iree.runtime.VmModule.mmap(vm_instance, str(module_path)) + modules.append(vm_module) + vm_context = iree.runtime.VmContext(instance=vm_instance, modules=modules) + return vm_module, vm_context, vm_instance + + +def run_iree_module_function( + module: iree.runtime.VmModule, + vm_context: iree.runtime.VmContext, + args: List[iree.runtime.DeviceArray], + driver: str, + function_name: str = "main", + trace_path_prefix: Optional[str] = None, +) -> List[iree.runtime.DeviceArray]: + """Run IREE module function with optional tracing of arguments/results.""" + vm_function = module.lookup_function(function_name) + invoker = iree.runtime.FunctionInvoker( + vm_context=vm_context, + # TODO: rework iree.runtime.FunctionInvoker interface for multiple devices. + # This works, but does not look right. 
+ device=iree.runtime.get_device(driver, cache=False), + vm_function=vm_function, + ) + if trace_path_prefix is not None: + for i, arg in enumerate(args): + np.save(f"{trace_path_prefix}{function_name}_arg{i}.npy", arg.to_host()) + results = invoker(*args) + if isinstance(results, iree.runtime.DeviceArray): + results = (results,) + + if trace_path_prefix is not None: + for i, arg in enumerate(args): + np.save( + f"{trace_path_prefix}{function_name}_arg_post_call{i}.npy", + arg.to_host(), + ) + for i, arg in enumerate(results): + np.save(f"{trace_path_prefix}{function_name}_result{i}.npy", arg.to_host()) + return results + + +def prepare_iree_module_function_args( + args: List[Union[AnyTensor, List[AnyTensor]]], devices: List[iree.runtime.HalDevice] +) -> List[iree.runtime.DeviceArray]: + """Flatten composite tensors into their parts and place them on devices. + Sharded tensors become a list of their shards while placing them onto their + corresponding device. + All unsharded tensors go on device 0. + """ + res = [] + for arg in args: + if isinstance(arg, ShardedTensor): + assert len(devices) == len(arg.shards) + res.extend( + [ + prepare_iree_module_function_args([shard], [device])[0] + for shard, device in zip(arg.shards, devices) + ] + ) + elif isinstance(arg, (DefaultPrimitiveTensor, torch.Tensor)): + res.append( + iree.runtime.asdevicearray( + devices[0], unbox_tensor(arg).to("cpu").numpy() + ) + ) + else: + assert isinstance(arg, collections.abc.Sequence) + res.extend(prepare_iree_module_function_args(arg, devices)) + return res + + +def flatten_for_iree_signature(tree: Tree) -> List[torch.Tensor]: + """Flatten a tree of arguments or results for an IREE call. + E.g. sharded tensors gets flattened into their shards.""" + + return torch_tree_flatten(tree)[0] + + +def call_torch_module_function( + module: torch.nn.Module, + function_name: str, + kwargs: OrderedDict, + trace_path_prefix: Optional[str] = None, +): + """Call a torch module function with optional tracing. 
+ For tracing the arguments/results are flattened to match IREE's signature.""" + assert isinstance( + kwargs, OrderedDict + ), "Make sure when flattening the order is preserved" + if trace_path_prefix is not None: + flat_args = flatten_for_iree_signature(kwargs) + for i, arg in enumerate(flat_args): + np.save( + f"{trace_path_prefix}{function_name}_arg{i}.npy", + arg.to("cpu").numpy(), + ) + res = getattr(module, function_name)(**kwargs) + if trace_path_prefix is not None: + flat_args = flatten_for_iree_signature(kwargs) + for i, arg in enumerate(flat_args): + np.save( + f"{trace_path_prefix}{function_name}_arg{i}.npy", + arg.to("cpu").numpy(), + ) + results = ( + (res,) + if isinstance( + res, + ( + torch.Tensor, + InferenceTensor, + ), + ) + else res + ) + flat_results = flatten_for_iree_signature(results) + for i, result in enumerate(flat_results): + np.save( + f"{trace_path_prefix}{function_name}_result{i}.npy", + result.to("cpu").numpy(), + ) + return res diff --git a/sharktank/tests/models/llama/sharded_llama_test.py b/sharktank/tests/models/llama/sharded_llama_test.py index bdace4972..4d34dc704 100644 --- a/sharktank/tests/models/llama/sharded_llama_test.py +++ b/sharktank/tests/models/llama/sharded_llama_test.py @@ -5,115 +5,40 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import unittest -from typing import Any, List, Tuple, Union, OrderedDict -import collections.abc +import pytest +from typing import Any, List, Tuple, OrderedDict from sharktank.models.llama.llama import LlamaModelConfig, PagedLlamaModelV1 import sharktank.ops as ops from sharktank.types import ( unbox_tensor, - ShardedTensor, - DefaultPrimitiveTensor, Dataset, - AnyTensor, ) from sharktank.models.llama.testing import make_random_llama_theta from sharktank.models.llama.sharding import shard_theta from sharktank.layers.configs import LlamaHParams from sharktank.utils.math import round_up_to_multiple_of from sharktank.utils import iterables_equal +from sharktank.utils.iree import ( + get_iree_devices, + load_iree_module, + run_iree_module_function, + prepare_iree_module_function_args, + call_torch_module_function, +) import tempfile import torch from copy import deepcopy from iree.turbine.aot import FxProgramsBuilder, export import iree.runtime -from pathlib import Path - - -def get_iree_devices(driver: str, device_count: int) -> List[iree.runtime.HalDevice]: - hal_driver = iree.runtime.get_driver(driver) - available_devices = hal_driver.query_available_devices() - # Use the same actual device for all devices. - return [hal_driver.create_device(available_devices[0]) for _ in range(device_count)] - - -def load_iree_module( - module_path: str, - parameters_path: str, - devices: List[iree.runtime.HalDevice], -) -> Tuple[iree.runtime.VmModule, iree.runtime.VmContext, iree.runtime.VmInstance]: - params_path = Path(parameters_path) - # TODO: make IREE able to load the parameters from the top parameter file - # without having to specify the parameter file for each shard separately. 
- parameter_index = iree.runtime.ParameterIndex() - for i in range(len(devices)): - parameter_index.load( - file_path=str( - Path(params_path).with_suffix(f".rank{i}{params_path.suffix}") - ) - ) - parameter_provider = parameter_index.create_provider(scope="model") - vm_instance = iree.runtime.VmInstance() - parameters_module = iree.runtime.create_io_parameters_module( - vm_instance, parameter_provider - ) - vm_module = iree.runtime.VmModule.mmap(vm_instance, str(module_path)) - hal_module = iree.runtime.create_hal_module(instance=vm_instance, devices=devices) - vm_context = iree.runtime.VmContext( - instance=vm_instance, modules=(hal_module, parameters_module, vm_module) - ) - return vm_module, vm_context, vm_instance - - -def run_iree_module_function( - module: iree.runtime.VmModule, - vm_context: iree.runtime.VmContext, - function_name: str, - args: List[iree.runtime.DeviceArray], - driver: str, -) -> List[iree.runtime.DeviceArray]: - vm_function = module.lookup_function(function_name) - invoker = iree.runtime.FunctionInvoker( - vm_context=vm_context, - # TODO: rework iree.runtime.FunctionInvoker interface for multiple devices. - # This works, but does not look right. - device=iree.runtime.get_device(driver, cache=False), - vm_function=vm_function, - ) - res = invoker(*args) - if isinstance(res, iree.runtime.DeviceArray): - res = (res,) - return res - - -def prepare_iree_module_function_args( - args: List[Union[AnyTensor, List[AnyTensor]]], devices: List[iree.runtime.HalDevice] -) -> List[iree.runtime.DeviceArray]: - res = [] - for arg in args: - if isinstance(arg, ShardedTensor): - assert len(devices) == len(arg.shards) - res.extend( - [ - prepare_iree_module_function_args([shard], [device])[0] - for shard, device in zip(arg.shards, devices) - ] - ) - elif isinstance(arg, (DefaultPrimitiveTensor, torch.Tensor)): - res.append( - iree.runtime.asdevicearray( - devices[0], unbox_tensor(arg).to("cpu").numpy() - ) - ) - else: - assert isinstance(arg, collections.abc.Sequence) - res.extend(prepare_iree_module_function_args(arg, devices)) - return res +import numpy as np +import os def iree_to_torch(*tensors: iree.runtime.DeviceArray) -> List[torch.Tensor]: return [torch.tensor(tensor.to_host()) for tensor in tensors] +@pytest.mark.usefixtures("caching", "path_prefix") class ShardedLlamaTest(unittest.TestCase): def setUp(self): torch.random.manual_seed(123456) @@ -304,25 +229,44 @@ def testExportAndRunToySizedModelWithIree(self): """Test exporting to MLIR and compiling with IREE the sharded Llama model. 
Test numerical accuracy of the IREE module against PyTorch.""" - with tempfile.TemporaryDirectory() as temp_dir: - sharded_theta = shard_theta(self.theta, self.sharded_config) - sharded_theta.rename_tensors_to_paths() - sharded_dataset = Dataset({}, sharded_theta) - sharded_parameters_path = f"{temp_dir}/parameters.irpa" - sharded_dataset.save(sharded_parameters_path) - sharded_dataset = Dataset.load(sharded_parameters_path, mmap=False) - iree_driver = "local-task" - - model = PagedLlamaModelV1(self.theta, self.config) - sharded_model = PagedLlamaModelV1( - sharded_dataset.root_theta, self.sharded_config + if self.path_prefix is not None: + self.runTestExportAndRunToySizedModelWithIree( + path_prefix=self.path_prefix, dump_enabled=True ) - sharded_fxb = FxProgramsBuilder(sharded_model) + else: + with tempfile.TemporaryDirectory() as temp_dir: + self.runTestExportAndRunToySizedModelWithIree( + path_prefix=f"{temp_dir}/", dump_enabled=False + ) + + def runTestExportAndRunToySizedModelWithIree( + self, path_prefix: str, dump_enabled: bool + ): + sharded_theta = shard_theta(self.theta, self.sharded_config) + sharded_theta.rename_tensors_to_paths() + sharded_dataset = Dataset({}, sharded_theta) + sharded_parameters_path = f"{path_prefix}parameters.irpa" + sharded_dataset.save(sharded_parameters_path) + sharded_dataset = Dataset.load(sharded_parameters_path, mmap=False) + iree_driver = "local-task" + + model = PagedLlamaModelV1(self.theta, self.config) + sharded_model = PagedLlamaModelV1( + sharded_dataset.root_theta, self.sharded_config + ) + ( + _, + sharded_prefill_args, + ) = self.make_equal_unsharded_and_sharded_prefill_args(model, sharded_model) + ( + _, + sharded_decode_args, + ) = self.make_equal_unsharded_and_sharded_decode_args(model, sharded_model) - ( - _, - sharded_prefill_args, - ) = self.make_equal_unsharded_and_sharded_prefill_args(model, sharded_model) + iree_module_path = f"{path_prefix}program.vmfb" + if not self.caching or not os.path.exists(iree_module_path): + # Export and compile the IREE module. + sharded_fxb = FxProgramsBuilder(sharded_model) @sharded_fxb.export_program( name="prefill", args=tuple(), kwargs=sharded_prefill_args @@ -330,10 +274,6 @@ def testExportAndRunToySizedModelWithIree(self): def _(model, *args, **kwargs) -> torch.Tensor: return model.prefill(*args, **kwargs) - ( - _, - sharded_decode_args, - ) = self.make_equal_unsharded_and_sharded_decode_args(model, sharded_model) # TODO: remove strict=False when # https://github.com/pytorch/pytorch/issues/136757 # is resolved. @@ -346,91 +286,105 @@ def _(model, *args, **kwargs) -> torch.Tensor: def _(model, *args, **kwargs) -> torch.Tensor: return model.decode(*args, **kwargs) - # Compile the IREE module. 
output = export(sharded_fxb) - output.save_mlir(f"{temp_dir}/program.mlir") + if dump_enabled: + output.save_mlir(f"{path_prefix}program.mlir") output.session.set_flags( *[ f"--iree-hal-target-device=llvm-cpu[{i}]" for i in range(self.sharded_config.tensor_parallelism_size) ] ) - iree_module_path = f"{temp_dir}/program.vmfb" output.compile( save_to=iree_module_path, target_backends=None, ) - iree_devices = get_iree_devices( - driver=iree_driver, - device_count=self.sharded_config.tensor_parallelism_size, - ) - iree_module, vm_context, vm_instance = load_iree_module( - module_path=iree_module_path, - devices=iree_devices, - parameters_path=sharded_parameters_path, - ) + iree_devices = get_iree_devices( + driver=iree_driver, + device_count=self.sharded_config.tensor_parallelism_size, + ) + iree_module, vm_context, vm_instance = load_iree_module( + module_path=iree_module_path, + devices=iree_devices, + parameters_path=sharded_parameters_path, + ) - # Check IREE's prefill step is close to torch. - prefill_iree_args = prepare_iree_module_function_args( - args=deepcopy(sharded_prefill_args).values(), devices=iree_devices - ) - prefill_iree_result = run_iree_module_function( - args=prefill_iree_args, - function_name="prefill", - module=iree_module, - vm_context=vm_context, - driver=iree_driver, - ) - prefill_iree_result = iree_to_torch(*prefill_iree_result) - assert len(prefill_iree_result) == 1 - expected_prefill_result = sharded_model.prefill(**sharded_prefill_args) - # TODO: Although, not entirely wrong, investigate why this accuracy is that - # low for fp32 (atol=0.0011, rtol=0.013). - torch.testing.assert_close( - prefill_iree_result[0], - expected_prefill_result, - ) - prefill_iree_cache_state_shards = prefill_iree_args[ - -self.config.tensor_parallelism_size - 1 : - ] - prefill_iree_cache_state_shards = iree_to_torch( - *prefill_iree_cache_state_shards - ) - for actual_cache_state_shard, expected_cache_state_shard in zip( - prefill_iree_cache_state_shards, - sharded_prefill_args["cache_state"][0].shards, - ): - # TODO: debug inaccuracy. - torch.testing.assert_close( - actual_cache_state_shard, unbox_tensor(expected_cache_state_shard) - ) + # Run prefill step. + prefill_iree_args = prepare_iree_module_function_args( + args=deepcopy(sharded_prefill_args).values(), devices=iree_devices + ) + for i, arg in enumerate(prefill_iree_args): + np.save(f"{path_prefix}prefill_arg{i}.npy", arg.to_host()) + prefill_iree_result = run_iree_module_function( + args=prefill_iree_args, + function_name="prefill", + module=iree_module, + vm_context=vm_context, + driver=iree_driver, + trace_path_prefix=path_prefix if dump_enabled else None, + ) + prefill_iree_result = iree_to_torch(*prefill_iree_result) + assert len(prefill_iree_result) == 1 + expected_prefill_result = call_torch_module_function( + module=sharded_model, + function_name="prefill", + kwargs=sharded_prefill_args, + trace_path_prefix=f"{path_prefix}expected_" if dump_enabled else None, + ) + prefill_iree_cache_state_shards = prefill_iree_args[ + -self.config.tensor_parallelism_size - 1 : + ] + prefill_iree_cache_state_shards = iree_to_torch( + *prefill_iree_cache_state_shards + ) - # Check IREE's decode step is close to torch. - decode_iree_args = prepare_iree_module_function_args( - args=deepcopy(sharded_decode_args).values(), devices=iree_devices - ) - decode_iree_result = run_iree_module_function( - args=decode_iree_args, - function_name="decode", - module=iree_module, - vm_context=vm_context, + # Run decode step. 
+ decode_iree_args = prepare_iree_module_function_args( + args=deepcopy(sharded_decode_args).values(), devices=iree_devices + ) + decode_iree_result = run_iree_module_function( + args=decode_iree_args, + function_name="decode", + module=iree_module, + vm_context=vm_context, + driver=iree_driver, + trace_path_prefix=path_prefix if dump_enabled else None, + ) + decode_iree_result = iree_to_torch(*decode_iree_result) + expected_decode_result = call_torch_module_function( + module=sharded_model, + function_name="decode", + kwargs=sharded_decode_args, + trace_path_prefix=f"{path_prefix}expected_" if dump_enabled else None, + ) + decode_iree_cache_state_shards = decode_iree_args[ + -self.config.tensor_parallelism_size - 1 : + ] + decode_iree_cache_state_shards = iree_to_torch(*decode_iree_cache_state_shards) + + # Check IREE's numerical correctness against PyTorch. + # TODO: Although, not entirely wrong, investigate why this accuracy is that + # low for fp32 (atol=0.0011, rtol=0.013). + torch.testing.assert_close( + prefill_iree_result[0], + expected_prefill_result, + ) + for actual_cache_state_shard, expected_cache_state_shard in zip( + prefill_iree_cache_state_shards, + sharded_prefill_args["cache_state"][0].shards, + ): + # TODO: debug inaccuracy. + torch.testing.assert_close( + actual_cache_state_shard, unbox_tensor(expected_cache_state_shard) ) - decode_iree_result = iree_to_torch(*decode_iree_result) - expected_decode_result = sharded_model.decode(**sharded_decode_args) + # TODO: debug inaccuracy. + torch.testing.assert_close(decode_iree_result[0], expected_decode_result) + for actual_cache_state_shard, expected_cache_state_shard in zip( + decode_iree_cache_state_shards, + sharded_decode_args["cache_state"][0].shards, + ): # TODO: debug inaccuracy. - torch.testing.assert_close(decode_iree_result[0], expected_decode_result) - decode_iree_cache_state_shards = decode_iree_args[ - -self.config.tensor_parallelism_size - 1 : - ] - decode_iree_cache_state_shards = iree_to_torch( - *decode_iree_cache_state_shards + torch.testing.assert_close( + actual_cache_state_shard, unbox_tensor(expected_cache_state_shard) ) - for actual_cache_state_shard, expected_cache_state_shard in zip( - decode_iree_cache_state_shards, - sharded_decode_args["cache_state"][0].shards, - ): - # TODO: debug inaccuracy. - torch.testing.assert_close( - actual_cache_state_shard, unbox_tensor(expected_cache_state_shard) - )
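The reworked conftest fixtures are consumed as class attributes rather than as function arguments. A minimal sketch of that pattern, assuming it runs under this conftest.py; the class name is hypothetical and the option values are placeholders, e.g. an invocation like `pytest sharktank/tests/models/llama/sharded_llama_test.py --prefix=/tmp/artifacts_ --caching`:

import pytest
import unittest


@pytest.mark.usefixtures("caching", "path_prefix")
class MyIreeTest(unittest.TestCase):  # hypothetical test class
    def testArtifacts(self):
        # set_fixture_from_cli_option stores the CLI values on the test class,
        # so self.path_prefix comes from --prefix (None when not given) and
        # self.caching from --caching (False when not given).
        if self.caching and self.path_prefix is not None:
            print(f"reusing artifacts under {self.path_prefix}")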
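For orientation, a hypothetical end-to-end sketch (not part of the diff) of the new sharktank.utils.iree helpers. "program.vmfb", "parameters.irpa", and the "prefill" entry point are placeholder names for an already exported and compiled 2-way sharded program:

import torch
from sharktank.utils.iree import (
    get_iree_devices,
    load_iree_module,
    prepare_iree_module_function_args,
    run_iree_module_function,
)

driver = "local-task"
devices = get_iree_devices(driver=driver, device_count=2)
# The returned VmContext and VmInstance must outlive any use of the results.
module, vm_context, vm_instance = load_iree_module(
    module_path="program.vmfb",
    devices=devices,
    # With more than one device this loads parameters.rank0.irpa,
    # parameters.rank1.irpa, ... rather than the top-level file.
    parameters_path="parameters.irpa",
)
# Placeholder arguments; sharded tensors would be flattened into one
# device array per shard, unsharded tensors go to device 0.
args = prepare_iree_module_function_args(
    args=[torch.zeros(1, 8, dtype=torch.int64)], devices=devices
)
results = run_iree_module_function(
    module=module,
    vm_context=vm_context,
    args=args,
    driver=driver,
    function_name="prefill",
    trace_path_prefix="/tmp/trace_",  # optional .npy dumps of args/results
)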
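Similarly, a small self-contained sketch (the Toy module is hypothetical, standing in for e.g. PagedLlamaModelV1) of tracing the eager PyTorch side with the same file layout as the IREE run, so the .npy dumps can be compared pairwise:

import torch
from collections import OrderedDict
from sharktank.utils.iree import call_torch_module_function, flatten_for_iree_signature


class Toy(torch.nn.Module):  # hypothetical stand-in model
    def prefill(self, tokens):
        return tokens * 2


kwargs = OrderedDict([("tokens", torch.arange(8).reshape(1, 8))])
expected = call_torch_module_function(
    module=Toy(),
    function_name="prefill",
    kwargs=kwargs,                       # must be an OrderedDict so ordering is stable
    trace_path_prefix="/tmp/expected_",  # writes expected_prefill_arg*.npy / expected_prefill_result*.npy
)
# flatten_for_iree_signature yields the flat tensor list in IREE argument order,
# e.g. a sharded tensor contributes one entry per shard.
flat_expected = flatten_for_iree_signature(expected)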