From a727b55b3584c9f6330eae3c4762a2ca500247dc Mon Sep 17 00:00:00 2001
From: cccclai
Date: Wed, 15 Jan 2025 13:54:50 -0800
Subject: [PATCH] fix delegate cache duplicate bug

Differential Revision: D67067997

Pull Request resolved: https://github.com/pytorch/executorch/pull/7281
---
 exir/_serialize/_program.py         |  8 +++--
 exir/backend/test/demos/rpc/TARGETS |  1 +
 exir/emit/_emitter.py               | 23 ++++++++----
 exir/emit/test/TARGETS              |  1 +
 exir/emit/test/test_emit.py         | 56 ++++++++++++++++++++++++++++-
 5 files changed, 79 insertions(+), 10 deletions(-)

diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py
index 80b740674a..7656ea3f36 100644
--- a/exir/_serialize/_program.py
+++ b/exir/_serialize/_program.py
@@ -224,6 +224,7 @@ def _extract_delegate_segments(
     """
     remaining_inline: List[BackendDelegateInlineData] = []
     inline_indices_seen: set[int] = set()
+    segment_index_map: dict[bytes, int] = {}
     for plan in program.execution_plan:
         for delegate in plan.delegates:
             if delegate.processed.location != DataLocation.INLINE:
@@ -249,8 +250,11 @@
             inline_indices_seen.add(delegate.processed.index)
             if inline.data:
                 # Move the delegate data out of the program.
-                segment_index = len(segments)
-                segments.append(Cord(inline.data))
+                segment_index = segment_index_map.get(inline.data)
+                if segment_index is None:
+                    segment_index = len(segments)
+                    segments.append(Cord(inline.data))
+                    segment_index_map[inline.data] = segment_index
                 delegate.processed = BackendDelegateDataReference(
                     location=DataLocation.SEGMENT,
                     index=segment_index,
diff --git a/exir/backend/test/demos/rpc/TARGETS b/exir/backend/test/demos/rpc/TARGETS
index a2aadb05ef..63d24ccbda 100644
--- a/exir/backend/test/demos/rpc/TARGETS
+++ b/exir/backend/test/demos/rpc/TARGETS
@@ -28,6 +28,7 @@ runtime.python_library(
     ],
     visibility = [
         "//executorch/exir/backend/test/...",
+        "//executorch/exir/emit/test/...",
     ],
     deps = [
         ":executor_backend_preprocess",
diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py
index d08e68fa73..c40a00b240 100644
--- a/exir/emit/_emitter.py
+++ b/exir/emit/_emitter.py
@@ -122,6 +122,8 @@ class _ProgramState:
     # Delegate data stored directly in the flatbuffer. Pointed to by BackendDelegateDataReference,
     # and should be copied to Program.backend_delegate_data.
     backend_delegate_data: List[BackendDelegateInlineData] = field(default_factory=list)
+    # Delegate cache that is used across all entry points. Key is the hash of the delegated payload.
+    backend_delegate_data_cache: Dict[str, int] = field(default_factory=dict)
 
     # Constants are optionally stored in external files.
     # Aggregate unique external constants into one buffer.
@@ -144,7 +146,8 @@ class _EmitterState:
     operators: List[Operator]
     delegates: List[BackendDelegate]
     operator_cache: Dict[Tuple[str, str], int]
-    delegate_cache: Dict[bytes, int]
+    # delegate_cache: the key is hash(delegated_payload) and the value is the index in delegates
+    delegate_cache: Dict[str, int]
     emit_stacktrace: bool
 
     spec2id_dict: Dict[TensorSpec, int] = field(default_factory=dict)
@@ -1073,8 +1076,8 @@ def _emit_delegate(
         """Emit the delegates inputs and outputs as specified by the schema, then emit the
        delegate's blob."""
         processed_bytes = lowered_module.processed_bytes
-
-        delegate_index = self.emitter_state.delegate_cache.get(processed_bytes)
+        hashed = hashlib.sha256(processed_bytes).hexdigest()
+        delegate_index = self.emitter_state.delegate_cache.get(hashed)
         delegate_ret = None
 
         if isinstance(self.node.meta["spec"], list):
@@ -1112,10 +1115,16 @@
         if delegate_index is None:
             # Allocate an entry for the data. TODO(T150113674): Reuse any duplicate entries if
             # present.
-            data_index: int = len(self.program_state.backend_delegate_data)
-            self.program_state.backend_delegate_data.append(
-                BackendDelegateInlineData(data=processed_bytes)
+            hashed = hashlib.sha256(processed_bytes).hexdigest()
+            data_index: Optional[int] = (
+                self.program_state.backend_delegate_data_cache.get(hashed)
             )
+            if data_index is None:
+                data_index = len(self.program_state.backend_delegate_data)
+                self.program_state.backend_delegate_data_cache[hashed] = data_index
+                self.program_state.backend_delegate_data.append(
+                    BackendDelegateInlineData(data=processed_bytes)
+                )
 
             backend_delegate = BackendDelegate(
                 id=lowered_module.backend_id,
@@ -1126,7 +1135,7 @@
             )
             delegate_index = len(self.emitter_state.delegate_cache)
             self.emitter_state.delegates.append(backend_delegate)
-            self.emitter_state.delegate_cache[processed_bytes] = delegate_index
+            self.emitter_state.delegate_cache[hashed] = delegate_index
 
         # TODO(angelayi) Will need to emit the kwargs too, in the correct order according to the
         # function's spec and with default arguments. This requires us to store the function's spec
diff --git a/exir/emit/test/TARGETS b/exir/emit/test/TARGETS
index 9f416e78ea..153843d45e 100644
--- a/exir/emit/test/TARGETS
+++ b/exir/emit/test/TARGETS
@@ -16,6 +16,7 @@ python_unittest(
         "//executorch/exir:lib",
         "//executorch/exir:print_program",
         "//executorch/exir:schema",
+        "//executorch/exir/backend/test/demos/rpc:executor_backend_partitioner",
         "//executorch/exir/backend:backend_api",
         "//executorch/exir/emit:lib",
         "//executorch/exir/passes:const_prop_pass",
diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py
index 0da4085914..3fca3958fe 100644
--- a/exir/emit/test/test_emit.py
+++ b/exir/emit/test/test_emit.py
@@ -27,6 +27,9 @@
 from executorch.exir._serialize._program import deserialize_pte_binary
 from executorch.exir.backend.backend_api import to_backend
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
+from executorch.exir.backend.test.demos.rpc.executor_backend_partitioner import (
+    ExecutorBackendPartitioner,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.emit import emit_program  # noqa
 from executorch.exir.error import InternalError
@@ -63,7 +66,7 @@
 from functorch.experimental import control_flow
 
 from torch import nn
-from torch.export import Dim, export
+from torch.export import Dim, export, export_for_training
 
 
 class WrapperModule(torch.nn.Module):
@@ -1679,3 +1682,54 @@ def forward(self, x):
         ]
         self.assertEqual(external_map["linear.weight"], 0)
         self.assertEqual(external_map["linear.bias"], 1)
+
+    def test_delegate_deduplicate(self) -> None:
+        class SharedModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        class Module1(torch.nn.Module):
+            def __init__(self, shared_module):
+                super().__init__()
+                self.shared_module = shared_module
+
+            def forward(self, x):
+                return self.shared_module(x)
+
+        class Module2(torch.nn.Module):
+            def __init__(self, shared_module):
+                super().__init__()
+                self.shared_module = shared_module
+
+            def forward(self, x):
+                return self.shared_module(x)
+
+        shared_module = SharedModule()
+        module_1 = Module1(shared_module)
+        module_2 = Module2(shared_module)
+        example_inputs = (torch.randn(2, 2),)
+        module_1(*example_inputs)
+        module_2(*example_inputs)
+
+        ep1 = export_for_training(module_1, example_inputs)
+        ep2 = export_for_training(module_2, example_inputs)
+
+        edge_program_manager = exir.to_edge(
+            {"forward1": ep1, "forward2": ep2},
+            compile_config=exir.EdgeCompileConfig(
+                _check_ir_validity=False, _use_edge_ops=True
+            ),
+        )
+
+        edge_program_manager = edge_program_manager.to_backend(
+            ExecutorBackendPartitioner()
+        ).to_executorch()
+
+        # Check that there is only one delegate because two methods are exactly the same
+        self.assertEqual(
+            len(edge_program_manager.executorch_program.backend_delegate_data), 1
+        )
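
Note: the patch keys both the emitter's backend_delegate_data cache and the serializer's segment extraction on sha256(processed_bytes), so identical delegate payloads shared across entry points are stored once and every reference reuses the same index. The snippet below is a minimal, self-contained sketch of that intern-by-hash pattern under my own naming, not ExecuTorch code; intern_payload, storage, and cache are hypothetical names.

import hashlib
from typing import Dict, List


def intern_payload(payload: bytes, storage: List[bytes], cache: Dict[str, int]) -> int:
    # Key on the payload hash; append to storage only the first time the payload is seen.
    key = hashlib.sha256(payload).hexdigest()
    index = cache.get(key)
    if index is None:
        index = len(storage)
        storage.append(payload)
        cache[key] = index
    return index


storage: List[bytes] = []
cache: Dict[str, int] = {}
# Two methods lowering the same module produce identical processed bytes,
# so both lookups resolve to one stored blob at index 0.
assert intern_payload(b"delegate-blob", storage, cache) == 0
assert intern_payload(b"delegate-blob", storage, cache) == 0
assert len(storage) == 1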