[Lint] Make build.py more compliant to pylint/mypy (octoml#176)
junrushao authored May 18, 2023
1 parent 5184eb8 commit 615020d
Showing 5 changed files with 38 additions and 1,570 deletions.
59 changes: 34 additions & 25 deletions build.py
@@ -1,3 +1,4 @@
+# pylint: disable=missing-docstring
 import argparse
 import json
 import os
@@ -86,7 +87,7 @@ def _parse_args():
     return parsed


-def _setup_model_path(args):
+def _setup_model_path(args):  # pylint: disable=too-many-branches
     if args.hf_path:
         if args.model != "auto":
             assert args.model == os.path.basename(args.hf_path), (
@@ -124,14 +125,14 @@ def _setup_model_path(args):
         ):
             try:
                 validate_config(os.path.join(lookup_path, dirname))
-            except:
+            except:  # pylint: disable=bare-except
                 pass
             else:
                 args.model_path = os.path.join(lookup_path, dirname)
                 args.model = dirname
                 break
     if args.model == "auto":
-        raise ValueError(f"Please specify either the model_path or the hf_path.")
+        raise ValueError("Please specify either the model_path or the hf_path.")

     print(f'Using path "{args.model_path}" for model "{args.model}"')
     return args
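
All three fixes in _setup_model_path are message-level, not behavioral: the inline disable on the def line acknowledges R0912 (too-many-branches) on a function that genuinely needs its branching, the annotated bare except: acknowledges W0702 where swallowing any validation failure is deliberate, and dropping the f prefix from a string with no placeholders clears W1309 (f-string-without-interpolation). A minimal, self-contained sketch of the same three patterns — pick_model and _validate are illustrative names, not from build.py:

    def _validate(name):
        # Stand-in for validate_config(); rejects anything falsy.
        if not name:
            raise ValueError(name)

    def pick_model(candidates):  # pylint: disable=too-many-branches
        for name in candidates:
            try:
                _validate(name)
            except:  # pylint: disable=bare-except
                continue  # any failure just means "try the next candidate"
            return name
        # Plain string literal: with no placeholders, an f-string here would trip W1309.
        raise ValueError("Please specify a valid model.")

    print(pick_model(["", "vicuna-v1-7b"]))
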
@@ -141,8 +142,8 @@ def validate_config(model_path: str):
     assert os.path.exists(
         os.path.join(model_path, "config.json")
     ), "Model path must contain valid config file."
-    with open(os.path.join(model_path, "config.json")) as f:
-        config = json.load(f)
+    with open(os.path.join(model_path, "config.json"), encoding="utf-8") as i_f:
+        config = json.load(i_f)
     assert "model_type" in config, "Invalid config format."
     assert (
         config["model_type"] in utils.supported_model_types
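
The encoding="utf-8" additions that recur through this diff all address W1514 (unspecified-encoding): in text mode, open() otherwise falls back to the platform's locale encoding, so a UTF-8 config.json that parses fine on Linux can fail on a cp1252 Windows machine. A small sketch of the pattern (the file name is illustrative):

    import json

    # Implicit encoding -- platform-dependent, which is what W1514 warns about:
    #   with open("config.json") as f:
    #       config = json.load(f)

    # Explicit encoding -- deterministic on every platform:
    with open("config.json", encoding="utf-8") as i_f:
        config = json.load(i_f)
    print(config.get("model_type"))
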
@@ -154,15 +155,18 @@ def debug_dump_script(mod, name, args):
     if not args.debug_dump:
         return
     dump_path = os.path.join(args.artifact_path, "debug", name)
-    with open(dump_path, "w") as outfile:
+    with open(dump_path, "w", encoding="utf-8") as outfile:
         outfile.write(mod.script(show_meta=True))
     print(f"Dump mod to {dump_path}")


 def debug_load_script(name, args):
     input_path = os.path.join(args.artifact_path, "debug", name)
     lib = {"__file__": input_path}
-    exec(compile(open(input_path, "rb").read(), input_path, "exec"), lib, lib)
+    with open(input_path, "rb") as i_f:
+        exec(  # pylint: disable=exec-used
+            compile(i_f.read(), input_path, "exec"), lib, lib
+        )
     return lib["Module"]

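Rewriting debug_load_script clears two warnings at once: hoisting the open() into a with block closes the file handle even if compile or exec raises (the old one-liner leaked it, which newer pylint also flags as R1732, consider-using-with), and the line-level disable acknowledges W0122 (exec-used) where executing a dumped script is the whole point. A self-contained sketch of the idiom, assuming the target file defines a top-level name Module:

    def load_script(input_path):
        # Execute a Python source file into a fresh namespace, then pull
        # one symbol back out of it. exec is intentional here.
        lib = {"__file__": input_path}
        with open(input_path, "rb") as i_f:  # closed even if compile/exec raises
            exec(  # pylint: disable=exec-used
                compile(i_f.read(), input_path, "exec"), lib, lib
            )
        return lib["Module"]
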
@@ -180,7 +184,7 @@ def debug_dump_shader(ex, name, args):
     suffix = suffix_map.get(target_kind, ".txt")
     dump_path = os.path.join(args.artifact_path, "debug", name + suffix)
     source = ex.mod.imported_modules[0].imported_modules[0].get_source()
-    with open(dump_path, "w") as outfile:
+    with open(dump_path, "w", encoding="utf-8") as outfile:
         outfile.write(source)
     print(f"Dump shader to {dump_path}")
@@ -200,17 +204,16 @@ def mod_transform_before_build(
     ]

     if args.quantization.mode != "no":
-        mod = mlc_llm.transform.GroupQuantize(
+        mod = mlc_llm.transform.GroupQuantize(  # pylint: disable=not-callable
             group_size=40 if args.quantization.mode.endswith("3") else 32,
             sym=args.quantization.sym,
             mode=args.quantization.mode,
             storage_nbit=args.quantization.storage_nbit,
             dtype=args.quantization.model_dtype,
         )(mod)
-    mod = mlc_llm.transform.FuseTransposeMatmul()(mod)
-
-    mod = relax.pipeline.get_pipeline()(mod)
-    mod = mlc_llm.transform.FuseDecodeMatmulEwise(
+    mod = mlc_llm.transform.FuseTransposeMatmul()(mod)  # pylint: disable=not-callable
+    mod = relax.pipeline.get_pipeline()(mod)  # pylint: disable=no-value-for-parameter
+    mod = mlc_llm.transform.FuseDecodeMatmulEwise(  # pylint: disable=not-callable
         args.quantization.model_dtype, args.target_kind
     )(mod)
     mod = relax.transform.DeadCodeElimination(model_names)(mod)
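
The not-callable (E1102) disables in this hunk silence false positives: pylint cannot infer that GroupQuantize, FuseTransposeMatmul, and FuseDecodeMatmulEwise produce TVM pass objects that are themselves callable on an IRModule, and no-value-for-parameter (E1120) similarly misfires on relax.pipeline.get_pipeline(), whose parameters are presumably all defaulted. The convention is to disable on the offending line only, so the checks stay live everywhere else in the file. A toy illustration — _registry is a hypothetical stand-in for a pass factory pylint cannot introspect:

    def _registry(name):
        # Returns a callable "pass"; real registries often build these
        # objects dynamically, which is what confuses pylint's inference.
        return lambda mod: dict(mod, visited_by=name)

    mod = {}
    mod = _registry("FuseTransposeMatmul")(mod)  # pylint: disable=not-callable
    print(mod)
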
@@ -226,7 +229,6 @@

 def dump_default_mlc_chat_config(args):
     params_path = os.path.join(args.artifact_path, "params")
-    config = dict()
     config: Dict[str, Any] = {}
     config["model_lib"] = f"{args.model}-{args.quantization.name}"
     config["local_id"] = f"{args.model}-{args.quantization.name}"
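
Dropping config = dict() removes a redundant assignment and leaves only the annotated initializer, which is what keeps mypy happy: a bare {} usually forces mypy to guess or demand an annotation, while config: Dict[str, Any] = {} pins the key type so later writes are checked. (Pylint separately prefers the {} literal over a dict() call, R1735.) A sketch, with a hypothetical value:

    from typing import Any, Dict

    config: Dict[str, Any] = {}           # mypy now knows keys must be str
    config["model_lib"] = "model-q3f16"   # fine
    # config[0] = "oops"                  # mypy would reject: key is not str
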
@@ -238,7 +240,7 @@ def dump_default_mlc_chat_config(args):
     config["tokenizer_files"] = utils.get_tokenizer_files(params_path)

     dump_path = os.path.join(params_path, "mlc-chat-config.json")
-    with open(dump_path, "w") as outfile:
+    with open(dump_path, "w", encoding="utf-8") as outfile:
         json.dump(config, outfile, indent=4)
     print(f"Finish exporting chat config to {dump_path}")
@@ -248,15 +250,21 @@ def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
     debug_dump_script(mod_deploy, "mod_before_build.py", args)
     if target_kind != "cpu":
         if os.path.exists(args.db_path):
-            db = ms.database.create(work_dir=args.db_path)
+            db = ms.database.create(  # pylint: disable=invalid-name
+                work_dir=args.db_path
+            )
         else:
-            db = ms.database.MemoryDatabase()
+            db = ms.database.MemoryDatabase()  # pylint: disable=invalid-name
         with db, tvm.target.Target("apple/m1-gpu-restricted"):
             mod_deploy = relax.transform.MetaScheduleApplyDatabase()(mod_deploy)
         if args.target_kind == "android":
-            mod_deploy = mlc_llm.dispatch.DispatchTIROperatorAdreno()(mod_deploy)
-        mod_deploy = mlc_llm.dispatch.DispatchTIROperator(args.model_category)(
-            mod_deploy
+            mod_deploy = mlc_llm.dispatch.DispatchTIROperatorAdreno()(  # pylint: disable=not-callable
+                mod_deploy
+            )
+        mod_deploy = (
+            mlc_llm.dispatch.DispatchTIROperator(  # pylint: disable=not-callable
+                args.model_category
+            )(mod_deploy)
         )
         mod_deploy = tvm.tir.transform.DefaultGPUSchedule()(mod_deploy)
         mod_deploy = tvm.tir.transform.ForceNarrowIndexToInt32()(mod_deploy)
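
The invalid-name (C0103) disables exist only because db is shorter than pylint's default variable-name pattern allows (the stock snake_case regex wants at least three characters, with a handful of whitelisted exceptions such as i, j, and _). When a longer name would add nothing, the usual options are an inline disable, as here, or whitelisting the name project-wide — sketched below with a hypothetical _make_database stand-in:

    def _make_database():
        # Stand-in for ms.database.MemoryDatabase().
        return object()

    db = _make_database()  # pylint: disable=invalid-name
    # Project-wide alternative (in pylintrc): good-names=i,j,k,db,_
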
@@ -291,10 +299,10 @@ def dump_split_tir(mod: tvm.IRModule):
     static_path = os.path.join(ARGS.artifact_path, "debug", "mod_tir_static.py")
     dynamic_path = os.path.join(ARGS.artifact_path, "debug", "mod_tir_dynamic.py")
     print(f"Dump static shape TIR to {static_path}")
-    with open(static_path, "w") as o_f:
+    with open(static_path, "w", encoding="utf-8") as o_f:
         o_f.write(template.format(content=mod_static.script()))
     print(f"Dump dynamic shape TIR to {dynamic_path}")
-    with open(dynamic_path, "w") as o_f:
+    with open(dynamic_path, "w", encoding="utf-8") as o_f:
         o_f.write(template.format(content=mod_dynamic.script()))

@@ -305,8 +313,8 @@ def main():
         ARGS.artifact_path, f"mod_cache_before_build_{ARGS.target_kind}.pkl"
     )
     use_cache = ARGS.use_cache and os.path.isfile(cache_path)
-    with open(os.path.join(ARGS.model_path, "config.json")) as f:
-        config = json.load(f)
+    with open(os.path.join(ARGS.model_path, "config.json"), encoding="utf-8") as i_f:
+        config = json.load(i_f)
     if not use_cache:
         if ARGS.model_category == "llama":
             mod, params = llama.get_model(ARGS, config)
@@ -326,7 +334,8 @@ def main():
             f"Load cached module from {cache_path} and skip tracing. "
             "You can use --use-cache=0 to retrace"
         )
-        mod = pickle.load(open(cache_path, "rb"))
+        with open(cache_path, "rb") as pkl:
+            mod = pickle.load(pkl)
     dump_split_tir(mod)
     build(mod, ARGS)
     dump_default_mlc_chat_config(ARGS)

2 changes: 1 addition & 1 deletion ios/MLCChat.xcodeproj/project.pbxproj
@@ -418,7 +418,7 @@
					"-lsentencepiece",
					"-ltokenizers_c",
				);
-				PRODUCT_BUNDLE_IDENTIFIER = "mlc.Chat-junru";
+				PRODUCT_BUNDLE_IDENTIFIER = mlc.Chat;
				PRODUCT_NAME = "$(TARGET_NAME)";
				SWIFT_EMIT_LOC_STRINGS = YES;
				SWIFT_OBJC_BRIDGING_HEADER = "MLCChat/MLCChat-Bridging-Header.h";

2 changes: 1 addition & 1 deletion mlc_llm/dispatch/gpt_neox/dolly_v2_3b.py
@@ -837,7 +837,7 @@ def layer_norm1(sch: tir.Schedule):
     sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v9)
     l10, l11, l12 = sch.get_loops(block=b1)
     l13 = sch.fuse(l10, l11, l12, preserve_unit_iters=True)
-    l14, l15, l16 = sch.split(loop=l13, factors=[None, 256, 1024], preserve_unit_iters=True)
+    l14, l15, l16 = sch.split(loop=l13, factors=[None, 256, 256], preserve_unit_iters=True)
     sch.reorder(l15, l16, l14)
     sch.bind(loop=l15, thread_axis="blockIdx.x")
     sch.bind(loop=l16, thread_axis="threadIdx.x")
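
Unlike the mechanical lint fixes above, this hunk changes behavior: the fused layer-norm loop is now split with factors [None, 256, 256] instead of [None, 256, 1024], so after the reorder and bind calls, threadIdx.x gets extent 256 rather than 1024 — presumably to keep the per-block thread count within the launch limits of the mobile GPUs this hand-written schedule targets. A minimal, self-contained sketch of the same split-and-bind idiom on a toy kernel (illustrative, not the dolly schedule):

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def add_one(A: T.Buffer((65536,), "float32"), B: T.Buffer((65536,), "float32")):
        for i in range(65536):
            with T.block("B"):
                vi = T.axis.spatial(65536, i)
                B[vi] = A[vi] + T.float32(1)

    sch = tvm.tir.Schedule(add_one)
    (i,) = sch.get_loops(sch.get_block("B"))
    bx, tx = sch.split(i, factors=[None, 256], preserve_unit_iters=True)
    sch.bind(bx, "blockIdx.x")   # 65536 / 256 = 256 blocks
    sch.bind(tx, "threadIdx.x")  # 256 threads per block
    print(sch.mod.script())
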
(Diffs for the remaining two changed files were not loaded.)
