Updated pass
dvorjackz committed Dec 16, 2024
1 parent 917fb0d commit ee2eb15
Showing 7 changed files with 19 additions and 39 deletions.
6 changes: 5 additions & 1 deletion examples/models/llama3_2_vision/runner/native.py
@@ -19,6 +19,7 @@
 )

 from executorch.extension.pybindings.portable_lib import _load_for_executorch
+from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer

 # Load custom ops and quantized ops.
 from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
@@ -43,7 +44,10 @@ def __init__(self, args):
             use_kv_cache=args.kv_cache,
             vocab_size=params["vocab_size"],
         )
-        self.model = _load_for_executorch(args.pte)
+        with open(args.pte, "rb") as f:
+            model_bytes = f.read()
+        self.model = _load_for_executorch_from_buffer(model_bytes)
+        # self.model = _load_for_executorch(args.pte)
         self.use_kv_cache = args.kv_cache

     def forward(
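For reference, a minimal sketch of the two loading APIs this change switches between (both appear in the diff; the model path below is a placeholder, not from the commit):

    from executorch.extension.pybindings.portable_lib import (
        _load_for_executorch,
        _load_for_executorch_from_buffer,
    )

    # Old path: load the program directly from a .pte file on disk.
    module = _load_for_executorch("model.pte")

    # New path: read the serialized program into memory first, then
    # load from the buffer, as the runner above now does.
    with open("model.pte", "rb") as f:
        module = _load_for_executorch_from_buffer(f.read())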
8 changes: 6 additions & 2 deletions exir/emit/_emitter.py
@@ -1566,7 +1566,6 @@ def _find_fqn_for_placeholder(
             fqn = self.exported_program.graph_signature.inputs_to_parameters[target]

         elif target in self.exported_program.graph_signature.inputs_to_buffers:
-            breakpoint()
             fqn = self.exported_program.graph_signature.inputs_to_buffers[target]

             # if the buffer is mutated then record that
@@ -1603,6 +1602,7 @@ def placeholder(
         """
         spec = self.node.meta["spec"]
         constant_tag = self.node.meta.get("constant_tag", None)
+        initialize_buffer = self.node.meta.get("et_init_buffer", None)
         is_user_input = True

         if isinstance(target, str) and isinstance(spec, TensorSpec):
@@ -1657,7 +1657,11 @@
             spec.storage = real_tensor.untyped_storage()

         # User inputs and mutable buffers are not constants, other buffers or parameters are.
-        spec.const = not is_user_input
+        if initialize_buffer:
+            assert is_mutable_buffer
+            spec.const = True
+        else:
+            spec.const = not (is_user_input or is_mutable_buffer)

         evalue = (
             self._tensor_spec_to_evalue(spec, constant_tag)
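To spell out the const-marking rule the emitter now applies, here is an illustrative standalone restatement of the logic above (not the emitter's actual code):

    def spec_is_const(is_user_input: bool, is_mutable_buffer: bool,
                      initialize_buffer: bool) -> bool:
        # Mutable buffers tagged with meta["et_init_buffer"] keep their
        # initial value serialized in the program.
        if initialize_buffer:
            assert is_mutable_buffer
            return True
        # User inputs and untagged mutable buffers are not constants;
        # parameters and other buffers are.
        return not (is_user_input or is_mutable_buffer)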
34 changes: 4 additions & 30 deletions exir/passes/init_mutable_buffer_pass.py
@@ -13,35 +13,9 @@ class InitMutableBufferPass(ExportPass):
     def __init__(self) -> None:
         super().__init__()

-    def update_placeholder_tensor_specs(
-        self,
-        exported_program: torch.export.ExportedProgram,
-        graph_module: torch.fx.GraphModule,
-    ) -> None:
-        """
-        Update the tensor specs for all placeholder nodes such that
-        placeholders that are parameters are marked as constant.
-        """
-        for node in graph_module.graph.nodes:
-            if node.op != "placeholder":
-                continue
-            if "spec" not in node.meta:
-                raise RuntimeError(f"Placeholder node {node} missing meta['spec']")
-            # print(node)
-            spec = node.meta["spec"]
-            if (isinstance(node.target, str) and
-                    node.target in exported_program.graph_signature.inputs_to_buffers and
-                    exported_program.graph_signature.inputs_to_buffers[node.target] in exported_program.state_dict):
-                # print(f"Setting {node.target}.const = True")
-                # breakpoint()
-                # print(exported_program.state_dict[exported_program.graph_signature.inputs_to_buffers[node.target]])
-                spec.const = True
-
     # pyre-ignore
     def placeholder(self, name: str, arg, meta):
-        # print(name)
-        meta["spec"] = make_spec(arg, const=meta.data['spec'].const)
-        # if name == "b_kv_cache_cache_pos":
-        #     print("breakpoint")
-        #     breakpoint()
+        if "cache_pos" in name:
+            meta["et_init_buffer"] = True

         return super().placeholder(name, arg, meta)
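An assumed usage sketch of the slimmed-down pass: it now only tags placeholders whose name contains "cache_pos", and the emitter above consumes the tag. The exported_program below and the metadata propagation are assumptions about the surrounding export flow, not guaranteed by the commit:

    from executorch.exir.passes.init_mutable_buffer_pass import InitMutableBufferPass

    # exported_program: a torch.export.ExportedProgram (assumed available).
    result = InitMutableBufferPass()(exported_program.graph_module)
    for node in result.graph_module.graph.nodes:
        if node.op == "placeholder" and "cache_pos" in node.name:
            # Expectation: the pass tagged this placeholder so the
            # emitter serializes its initial value.
            assert node.meta.get("et_init_buffer") is True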

4 changes: 2 additions & 2 deletions exir/passes/spec_prop_pass.py
@@ -18,9 +18,9 @@


 # pyre-ignore
-def make_spec(x, const=False):
+def make_spec(x):
     if isinstance(x, torch.Tensor):
-        return TensorSpec.from_tensor(x, const)
+        return TensorSpec.from_tensor(x)
     elif isinstance(x, (int, bool, float)):
         return x
     else:
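With the const parameter dropped, make_spec always yields a non-const spec, and const-ness is decided later in the emitter (see _emitter.py above). A small sketch under that assumption; TensorSpec.from_tensor and spec.const both appear in the diffs, but the import path is an assumption:

    import torch
    from executorch.exir.tensor import TensorSpec

    spec = TensorSpec.from_tensor(torch.zeros(2, 2))
    print(spec.const)  # False until the emitter marks it otherwise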
2 changes: 0 additions & 2 deletions exir/program/_program.py
@@ -1354,8 +1354,6 @@ def to_executorch(
         gm, new_signature = insert_write_back_for_buffers_pass(program)
         new_gm = program.graph_module
         for p in edge_to_executorch_passes(config, name):
-            if isinstance(p, InitMutableBufferPass):
-                p.update_placeholder_tensor_specs(program, new_gm)
             new_gm_res = p(new_gm)
             assert new_gm_res is not None
             new_gm = new_gm_res.graph_module
2 changes: 1 addition & 1 deletion extension/llm/export/builder.py
@@ -414,7 +414,7 @@ def to_executorch(self) -> "LLMEdgeManager":
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             )
         )
-        print(self.export_program.to_executorch_program(verbose=True))
+        print(self.export_program.dump_executorch_program(verbose=True))
         logging.info(
             "Required memory for activation in bytes: {}".format(
                 self.export_program._emitter_output.program.execution_plan[
2 changes: 1 addition & 1 deletion extension/llm/modules/kv_cache.py
@@ -56,7 +56,7 @@ def __init__(
             "v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
         )
         self.register_buffer(
-            "cache_pos", torch.arange(0, self.max_seq_len), persistent=True
+            "cache_pos", torch.arange(0, self.max_seq_len), persistent=False
         )
         self.batch_size = batch_size
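Flipping cache_pos to persistent=False keeps the buffer out of the module's state_dict, so checkpoints neither save nor expect it. A self-contained illustration (the Cache class here is hypothetical, not from the commit):

    import torch

    class Cache(torch.nn.Module):
        def __init__(self, max_seq_len: int = 8):
            super().__init__()
            # persistent=False: the buffer lives on the module but is
            # excluded from state_dict() and load_state_dict() checks.
            self.register_buffer(
                "cache_pos", torch.arange(0, max_seq_len), persistent=False
            )

    print("cache_pos" in Cache().state_dict())  # False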
