Commit a727b55: fix delegate cache duplicate bug
Differential Revision: D67067997

Pull Request resolved: #7281
cccclai authored Jan 15, 2025
1 parent ba6c552 · commit a727b55
Showing 5 changed files with 79 additions and 10 deletions.
8 changes: 6 additions & 2 deletions exir/_serialize/_program.py
@@ -224,6 +224,7 @@ def _extract_delegate_segments(
     """
     remaining_inline: List[BackendDelegateInlineData] = []
     inline_indices_seen: set[int] = set()
+    segment_index_map: dict[bytes, int] = {}
     for plan in program.execution_plan:
         for delegate in plan.delegates:
             if delegate.processed.location != DataLocation.INLINE:
@@ -249,8 +250,11 @@
             inline_indices_seen.add(delegate.processed.index)
             if inline.data:
                 # Move the delegate data out of the program.
-                segment_index = len(segments)
-                segments.append(Cord(inline.data))
+                segment_index = segment_index_map.get(inline.data)
+                if segment_index is None:
+                    segment_index = len(segments)
+                    segments.append(Cord(inline.data))
+                    segment_index_map[inline.data] = segment_index
                 delegate.processed = BackendDelegateDataReference(
                     location=DataLocation.SEGMENT,
                     index=segment_index,
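As an aside, here is a minimal standalone sketch of the deduplication pattern this hunk introduces; the names extract_segments and blobs are illustrative only, not part of the ExecuTorch API. Identical inline delegate payloads are keyed by their raw bytes, so they map to a single serialized segment:

from typing import Dict, List, Tuple

def extract_segments(blobs: List[bytes]) -> Tuple[List[bytes], List[int]]:
    """Deduplicate byte payloads: identical blobs share one segment.

    Returns the unique segments plus, for each input blob, the index of the
    segment it was mapped to (mirroring segment_index_map in the hunk above).
    """
    segments: List[bytes] = []
    segment_index_map: Dict[bytes, int] = {}
    indices: List[int] = []
    for blob in blobs:
        index = segment_index_map.get(blob)
        if index is None:
            # First time we see this payload: allocate a new segment.
            index = len(segments)
            segments.append(blob)
            segment_index_map[blob] = index
        indices.append(index)
    return segments, indices

# Two identical payloads collapse into a single segment.
segs, idxs = extract_segments([b"payload-a", b"payload-b", b"payload-a"])
assert len(segs) == 2 and idxs == [0, 1, 0]

The emitter-side change below applies the same idea, but keys its caches by a SHA-256 digest of the payload instead of the raw bytes.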
1 change: 1 addition & 0 deletions exir/backend/test/demos/rpc/TARGETS
@@ -28,6 +28,7 @@ runtime.python_library(
     ],
     visibility = [
         "//executorch/exir/backend/test/...",
+        "//executorch/exir/emit/test/...",
     ],
     deps = [
         ":executor_backend_preprocess",
23 changes: 16 additions & 7 deletions exir/emit/_emitter.py
@@ -122,6 +122,8 @@ class _ProgramState:
     # Delegate data stored directly in the flatbuffer. Pointed to by BackendDelegateDataReference,
     # and should be copied to Program.backend_delegate_data.
     backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list)
+    # Delegate cache that is used across all entry points. Key is the hash of the delegated payload.
+    backend_delegate_data_cache: Dict[str, int] = field(default_factory=dict)
 
     # Constants are optionally stored in external files.
     # Aggregate unique external constants into one buffer.
@@ -144,7 +146,8 @@ class _EmitterState:
     operators: List[Operator]
     delegates: List[BackendDelegate]
     operator_cache: Dict[Tuple[str, str], int]
-    delegate_cache: Dict[bytes, int]
+    # delegate_cache: the key is hash(delegated_payload) and the value is the index in delegates
+    delegate_cache: Dict[str, int]
     emit_stacktrace: bool
 
     spec2id_dict: Dict[TensorSpec, int] = field(default_factory=dict)
@@ -1073,8 +1076,8 @@ def _emit_delegate(
         """Emit the delegates inputs and outputs as specified by the schema, then emit the
        delegate's blob."""
         processed_bytes = lowered_module.processed_bytes
-
-        delegate_index = self.emitter_state.delegate_cache.get(processed_bytes)
+        hashed = hashlib.sha256(processed_bytes).hexdigest()
+        delegate_index = self.emitter_state.delegate_cache.get(hashed)
         delegate_ret = None
 
         if isinstance(self.node.meta["spec"], list):
@@ -1112,10 +1115,16 @@
         if delegate_index is None:
             # Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
             # present.
-            data_index: int = len(self.program_state.backend_delegate_data)
-            self.program_state.backend_delegate_data.append(
-                BackendDelegateInlineData(data=processed_bytes)
+            hashed = hashlib.sha256(processed_bytes).hexdigest()
+            data_index: Optional[int] = (
+                self.program_state.backend_delegate_data_cache.get(hashed)
             )
+            if data_index is None:
+                data_index = len(self.program_state.backend_delegate_data)
+                self.program_state.backend_delegate_data_cache[hashed] = data_index
+                self.program_state.backend_delegate_data.append(
+                    BackendDelegateInlineData(data=processed_bytes)
+                )
 
             backend_delegate = BackendDelegate(
                 id=lowered_module.backend_id,
@@ -1126,7 +1135,7 @@
             )
             delegate_index = len(self.emitter_state.delegate_cache)
             self.emitter_state.delegates.append(backend_delegate)
-            self.emitter_state.delegate_cache[processed_bytes] = delegate_index
+            self.emitter_state.delegate_cache[hashed] = delegate_index
 
         # TODO(angelayi) Will need to emit the kwargs too, in the correct order according to the
         # function's spec and with default arguments. This requires us to store the function's spec
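For illustration, a self-contained sketch of the emitter-side pattern, assuming only the Python standard library; DelegateStore and add_payload are hypothetical names, not ExecuTorch APIs. The cache key is the SHA-256 hex digest of the processed payload, so the same lowered-module blob emitted from different entry points resolves to a single stored entry:

import hashlib
from typing import Dict, List

class DelegateStore:
    """Content-addressed store: identical payloads share one entry."""

    def __init__(self) -> None:
        self.entries: List[bytes] = []
        self.cache: Dict[str, int] = {}  # sha256 hex digest -> index into entries

    def add_payload(self, processed_bytes: bytes) -> int:
        hashed = hashlib.sha256(processed_bytes).hexdigest()
        index = self.cache.get(hashed)
        if index is None:
            # Cache miss: store the payload once and remember its index.
            index = len(self.entries)
            self.cache[hashed] = index
            self.entries.append(processed_bytes)
        return index

store = DelegateStore()
# The same payload added twice (e.g. from two entry points) is stored once.
assert store.add_payload(b"lowered-module-blob") == store.add_payload(b"lowered-module-blob")
assert len(store.entries) == 1

The new test below exercises exactly this case: two exported methods sharing one lowered module should produce a single entry in backend_delegate_data.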
1 change: 1 addition & 0 deletions exir/emit/test/TARGETS
@@ -16,6 +16,7 @@ python_unittest(
         "//executorch/exir:lib",
         "//executorch/exir:print_program",
         "//executorch/exir:schema",
+        "//executorch/exir/backend/test/demos/rpc:executor_backend_partitioner",
         "//executorch/exir/backend:backend_api",
         "//executorch/exir/emit:lib",
         "//executorch/exir/passes:const_prop_pass",
56 changes: 55 additions & 1 deletion exir/emit/test/test_emit.py
@@ -27,6 +27,9 @@
 from executorch.exir._serialize._program import deserialize_pte_binary
 from executorch.exir.backend.backend_api import to_backend
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
+from executorch.exir.backend.test.demos.rpc.executor_backend_partitioner import (
+    ExecutorBackendPartitioner,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.emit import emit_program  # noqa
 from executorch.exir.error import InternalError
@@ -63,7 +66,7 @@
 from functorch.experimental import control_flow
 from torch import nn
 
-from torch.export import Dim, export
+from torch.export import Dim, export, export_for_training
 
 
 class WrapperModule(torch.nn.Module):
@@ -1679,3 +1682,54 @@ def forward(self, x):
         ]
         self.assertEqual(external_map["linear.weight"], 0)
         self.assertEqual(external_map["linear.bias"], 1)
+
+    def test_delegate_deduplicate(self) -> None:
+        class SharedModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        class Module1(torch.nn.Module):
+            def __init__(self, shared_module):
+                super().__init__()
+                self.shared_module = shared_module
+
+            def forward(self, x):
+                return self.shared_module(x)
+
+        class Module2(torch.nn.Module):
+            def __init__(self, shared_module):
+                super().__init__()
+                self.shared_module = shared_module
+
+            def forward(self, x):
+                return self.shared_module(x)
+
+        shared_module = SharedModule()
+        module_1 = Module1(shared_module)
+        module_2 = Module2(shared_module)
+        example_inputs = (torch.randn(2, 2),)
+        module_1(*example_inputs)
+        module_2(*example_inputs)
+
+        ep1 = export_for_training(module_1, example_inputs)
+        ep2 = export_for_training(module_2, example_inputs)
+
+        edge_program_manager = exir.to_edge(
+            {"forward1": ep1, "forward2": ep2},
+            compile_config=exir.EdgeCompileConfig(
+                _check_ir_validity=False, _use_edge_ops=True
+            ),
+        )
+
+        edge_program_manager = edge_program_manager.to_backend(
+            ExecutorBackendPartitioner()
+        ).to_executorch()
+
+        # Check that there is only one delegate because two methods are exactly the same
+        self.assertEqual(
+            len(edge_program_manager.executorch_program.backend_delegate_data), 1
+        )
