Skip to content

Commit

Permalink
Auto-tuning scripts to maximize GPU kernel performance (octoml#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao authored May 7, 2023
1 parent b7be649 commit 909f267
Show file tree
Hide file tree
Showing 16 changed files with 8,735 additions and 6,665 deletions.
172 changes: 53 additions & 119 deletions build.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,21 @@ def _parse_args():
args = argparse.ArgumentParser()
utils.argparse_add_common(args)
args.add_argument("--quantization-sym", action="store_true", default=False)
args.add_argument("--quantization-mode", type=str, choices=["int4", "int3", "fp4"], default="int4")
args.add_argument("--quantization-storage-nbit", type=int, choices=[32, 16], default=32)
args.add_argument(
"--quantization-mode", type=str, choices=["int4", "int3", "fp4"], default="int4"
)
args.add_argument(
"--quantization-storage-nbit", type=int, choices=[32, 16], default=32
)
args.add_argument("--no-quantize", action="store_true", default=False)
args.add_argument("--max-seq-len", type=int, default=-1)
args.add_argument("--target", type=str, default="auto")
args.add_argument("--db-path", type=str, default="log_db/")
args.add_argument(
"--db-path",
type=str,
default=None,
help="Path to log database. Default: ./log_db/{model}",
)
args.add_argument("--artifact-path", type=str, default="dist")
args.add_argument(
"--use-cache",
Expand All @@ -34,123 +43,24 @@ def _parse_args():
args.add_argument("--debug-load-script", action="store_true", default=False)

args.add_argument(
"--llvm-mingw", type=str, default="",
help="/path/to/llvm-mingw-root, use llvm-mingw to cross compile to windows"
"--llvm-mingw",
type=str,
default="",
help="/path/to/llvm-mingw-root, use llvm-mingw to cross compile to windows",
)
args.add_argument("--system-lib", action="store_true", default=False)

parsed = args.parse_args()
assert parsed.max_seq_len == -1 or parsed.max_seq_len > 0

parsed.model_path = os.path.join(parsed.artifact_path, "models", parsed.model)
parsed.artifact_path = os.path.join(
parsed.artifact_path, parsed.model, parsed.dtype
)
parsed.export_kwargs = {}
assert parsed.max_seq_len == -1 or parsed.max_seq_len > 0
parsed.lib_format = "so"

if parsed.target == "auto":
if system() == "Darwin":
target = tvm.target.Target("apple/m1-gpu")
else:
has_gpu = tvm.cuda().exist
target = tvm.target.Target("cuda" if has_gpu else "llvm")
print(f"Automatically configuring target: {target}")
parsed.target = tvm.target.Target(target, host="llvm")
parsed.target_kind = parsed.target.kind.default_keys[0]
elif parsed.target == "webgpu":
parsed.target = tvm.target.Target(
"webgpu", host="llvm -mtriple=wasm32-unknown-unknown-wasm"
)
parsed.target_kind = "webgpu"
parsed.lib_format = "wasm"
elif parsed.target.startswith("iphone"):
from tvm.contrib import xcode, cc

# override
@tvm.register_func("tvm_callback_metal_compile")
def compile_metal(src):
return xcode.compile_metal(src, sdk="iphoneos")

dylib = parsed.target == "iphone-dylib"

parsed.target = tvm.target.Target(
tvm.target.Target(
{
"kind": "metal",
"max_threads_per_block": 256,
"max_shared_memory_per_block": 32768,
"thread_warp_size": 1,
}
),
host="llvm -mtriple=arm64-apple-darwin",
)
parsed.target_kind = "iphone"
parsed.export_kwargs = {
"fcompile": cc.create_staticlib,
}
parsed.lib_format = "a"

if dylib:
parsed.export_kwargs = {
"fcompile": xcode.create_dylib,
"sdk": "iphoneos",
"arch": "arm64",
}
parsed.lib_format = "dylib"
else:
parsed.system_lib = True

elif parsed.target == "vulkan":
parsed.target = tvm.target.Target(
tvm.target.Target(
{
"kind": "vulkan",
"max_threads_per_block": 256,
"max_shared_memory_per_block": 32768,
"thread_warp_size": 1,
"supports_float16": 1,
"supports_int16": 1,
"supports_16bit_buffer": 1,
}
),
host="llvm",
)
parsed.target_kind = parsed.target.kind.default_keys[0]
elif parsed.target == "metal_x86_64":
from tvm.contrib import xcode
parsed.target = tvm.target.Target(
tvm.target.Target({
"kind": "metal",
"max_threads_per_block": 256,
"max_shared_memory_per_block": 32768,
"thread_warp_size": 1,
}),
host="llvm -mtriple=x86_64-apple-darwin"
)
parsed.target_kind = "metal_x86_64"
parsed.export_kwargs = {
"fcompile": xcode.create_dylib,
"sdk": "macosx",
"arch": "x86_64",
}
parsed.lib_format = "dylib"
else:
parsed.target = tvm.target.Target(parsed.target, host="llvm")
parsed.target_kind = parsed.target.kind.default_keys[0]

# use mingw to cross compile windows
if parsed.llvm_mingw != "":
from tvm.contrib.cc import cross_compiler
parsed.export_kwargs = {
"fcompile": cross_compiler(
os.path.join(parsed.llvm_mingw, "bin", "x86_64-w64-mingw32-clang++"),
output_format="dll"
),
}
parsed.target = parsed.target.with_host(
"llvm -mtriple=x86_64-w64-windows-gnu")
parsed.lib_format = "dll"

parsed.db_path = parsed.db_path or os.path.join("log_db", parsed.model)
utils.parse_target(parsed)
utils.argparse_postproc_common(parsed)
return parsed

Expand Down Expand Up @@ -227,10 +137,15 @@ def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
if target_kind != "cpu":
from tvm import meta_schedule as ms

db = ms.database.create(work_dir=args.db_path)
if os.path.exists(args.db_path):
db = ms.database.create(work_dir=args.db_path)
else:
db = ms.database.MemoryDatabase()
with db, tvm.target.Target("apple/m1-gpu-restricted"):
mod_deploy = relax.transform.MetaScheduleApplyDatabase()(mod_deploy)
mod_deploy = mlc_llm.transform.DispatchTIROperator()(mod_deploy)
mod_deploy = mlc_llm.transform.DispatchTIROperator(args.model_category)(
mod_deploy
)
mod_deploy = tvm.tir.transform.DefaultGPUSchedule()(mod_deploy)
mod_deploy = tvm.tir.transform.ForceNarrowIndexToInt32()(mod_deploy)

Expand All @@ -243,14 +158,32 @@ def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:

output_filename = f"{args.model}_{target_kind}_{args.dtype}.{args.lib_format}"


debug_dump_shader(ex, f"{args.model}_{target_kind}_{args.dtype}", args)
lib_path = os.path.join(args.artifact_path, output_filename)
ex.export_library(
lib_path, **args.export_kwargs
)
ex.export_library(lib_path, **args.export_kwargs)
print(f"Finish exporting to {lib_path}")


def dump_split_tir(mod: tvm.IRModule):
    """Split *mod* into its static- and dynamic-shape TIR halves and dump
    each half as an importable TVMScript Python file under ARGS.artifact_path.
    """
    # Wrapper that turns the printed TVMScript into a runnable .py module;
    # fmt: off/on guards keep auto-formatters from mangling the generated code.
    wrapper = """
from tvm.script import ir as I
from tvm.script import tir as T
# fmt: off
{content}
# fmt: on
"""

    def _write_module(path, sub_mod):
        # Serialize one sub-module to disk inside the wrapper template.
        with open(path, "w") as handle:
            handle.write(wrapper.format(content=sub_mod.script()))

    static_mod, dynamic_mod = utils.split_static_dynamic_tir(mod)
    static_path = os.path.join(ARGS.artifact_path, "mod_tir_static.py")
    dynamic_path = os.path.join(ARGS.artifact_path, "mod_tir_dynamic.py")
    print(f"Dump static shape TIR to {static_path}")
    _write_module(static_path, static_mod)
    print(f"Dump dynamic shape TIR to {dynamic_path}")
    _write_module(dynamic_path, dynamic_mod)


if __name__ == "__main__":
ARGS = _parse_args()
os.makedirs(ARGS.artifact_path, exist_ok=True)
Expand All @@ -260,11 +193,11 @@ def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
)
use_cache = ARGS.use_cache and os.path.isfile(cache_path)
if not use_cache:
if ARGS.model.startswith("vicuna-") or ARGS.model.startswith("llama-"):
if ARGS.model_category == "llama":
mod, params = llama.get_model(ARGS)
elif ARGS.model.startswith("dolly-") or ARGS.model.startswith("stablelm-"):
elif ARGS.model_category == "gpt_neox":
mod, params = gpt_neox.get_model(ARGS.model, ARGS.model_path, ARGS.dtype)
elif ARGS.model.startswith("moss-"):
elif ARGS.model_category == "moss":
mod, params = moss.get_model(ARGS)
else:
raise ValueError(f"Model {ARGS.model} not supported")
Expand All @@ -278,4 +211,5 @@ def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
"You can use --use-cache=0 to retrace"
)
mod = pickle.load(open(cache_path, "rb"))
dump_split_tir(mod)
build(mod, ARGS)
Loading

0 comments on commit 909f267

Please sign in to comment.