From 627f5b081384a6a5ead306cfcfe3ff7a5264d8ac Mon Sep 17 00:00:00 2001
From: Clifton
Date: Sat, 16 Mar 2024 13:52:23 +0000
Subject: [PATCH] Fix DataParallel training with adapters

---
Reviewer notes (free-form text after the "---" separator; it is not part of
the commit message or of the diff):

* composition_to_func_map now stores the *names* of the compose_* methods
  (strings) instead of bound methods.
* The new _get_compose_func() looks the name up in that map and resolves it
  with getattr(self, ...) at call time. An unmapped block type still raises
  KeyError from the dict lookup, as the direct lookup did before.
* The call sites touched by the hunks of this patch go through
  _get_compose_func() instead of indexing the map directly.
* NOTE(review): presumably the motivation is that bound methods stored on the
  module keep referring to the original instance once it is replicated for
  DataParallel - confirm against the linked issue.
* NOTE(review): code outside this patch that still puts bound methods into
  composition_to_func_map would now fail inside getattr(), which requires a
  string attribute name - verify that no subclass does this.
* NOTE(review): the return annotation "-> callable" names the builtin
  callable() function rather than a type; typing.Callable would be the
  checkable spelling - possible follow-up.

 src/adapters/methods/adapter_layer_base.py | 27 +++++++++++++---------
 src/adapters/methods/bottleneck.py         |  4 ++--
 2 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/src/adapters/methods/adapter_layer_base.py b/src/adapters/methods/adapter_layer_base.py
index 2489d445b2..8e5bf8f43a 100644
--- a/src/adapters/methods/adapter_layer_base.py
+++ b/src/adapters/methods/adapter_layer_base.py
@@ -163,15 +163,20 @@ def __init__(self, *args, **kwargs):
         self._init_mapping()
 
     def _init_mapping(self):
+        # Mapping between composition block types and names of composition functions
         self.composition_to_func_map = {
-            Stack: self.compose_stack,
-            Fuse: self.compose_fuse,
-            Split: self.compose_split,
-            BatchSplit: self.compose_batch_split,
-            Parallel: self.compose_parallel,
-            Average: self.compose_average,
+            Stack: "compose_stack",
+            Fuse: "compose_fuse",
+            Split: "compose_split",
+            BatchSplit: "compose_batch_split",
+            Parallel: "compose_parallel",
+            Average: "compose_average",
         }
 
+    def _get_compose_func(self, composition_type: type) -> callable:
+        """Retrieves the correct composition function based on the mapping in 'composition_to_func_map'."""
+        return getattr(self, self.composition_to_func_map[composition_type])
+
     # START CUSTOMIZABLE METHODS #
     #
     # The following methods should be implemented in derived classes.
@@ -301,7 +306,7 @@ def compose_stack(self, adapter_setup: Stack, state: NamedTuple, lvl: int = 0) - for i, adapter_stack_layer in enumerate(adapter_setup): if isinstance(adapter_stack_layer, AdapterCompositionBlock): self.check_composition_valid(adapter_setup, adapter_stack_layer, lvl) - composition_func = self.composition_to_func_map[type(adapter_stack_layer)] + composition_func = self._get_compose_func(type(adapter_stack_layer)) state = composition_func(adapter_stack_layer, state, lvl=lvl + 1) elif adapter_stack_layer in self.adapter_modules: state = self.pre_block(adapter_stack_layer, state) @@ -353,7 +358,7 @@ def compose_batch_split(self, adapter_setup: BatchSplit, state: NamedTuple, lvl: ) if isinstance(child, AdapterCompositionBlock): self.check_composition_valid(adapter_setup, child, lvl) - composition_func = self.composition_to_func_map[type(child)] + composition_func = self._get_compose_func(type(child)) child_state = composition_func( child, self.vslice(state, slice(*batch_idx)), @@ -410,7 +415,7 @@ def compose_parallel(self, adapter_setup: Parallel, state: NamedTuple, lvl: int for i, child in enumerate(adapter_setup): if isinstance(child, AdapterCompositionBlock): self.check_composition_valid(adapter_setup, child, lvl) - composition_func = self.composition_to_func_map[type(child)] + composition_func = self._get_compose_func(type(child)) child_state = composition_func( child, self.vslice(state, slice(i * orig_batch_size, (i + 1) * orig_batch_size)), @@ -442,7 +447,7 @@ def compose_average(self, adapter_setup: Average, state: NamedTuple, lvl: int = for i, child in enumerate(adapter_setup): if isinstance(child, AdapterCompositionBlock): self.check_composition_valid(adapter_setup, child, lvl) - composition_func = self.composition_to_func_map[type(child)] + composition_func = self._get_compose_func(type(child)) child_state = composition_func(child, state, lvl=lvl + 1) children_states.append(child_state) elif child in self.adapter_modules: @@ -468,7 +473,7 
@@ def compose(self, adapter_setup: Union[AdapterCompositionBlock, str], state: Nam NamedTuple: The state after forwarding through the adapter setup. """ if isinstance(adapter_setup, AdapterCompositionBlock): - composition_func = self.composition_to_func_map[type(adapter_setup)] + composition_func = self._get_compose_func(type(adapter_setup)) state = composition_func(adapter_setup, state, lvl=0) elif adapter_setup in self.adapter_modules: state = self.compose_single(adapter_setup, state, lvl=0) diff --git a/src/adapters/methods/bottleneck.py b/src/adapters/methods/bottleneck.py index c150191695..ab78b88e9c 100644 --- a/src/adapters/methods/bottleneck.py +++ b/src/adapters/methods/bottleneck.py @@ -255,7 +255,7 @@ def compose_fuse(self, adapter_setup: Fuse, state: BottleneckState, lvl: int = 0 for child in adapter_setup: if isinstance(child, AdapterCompositionBlock): self.check_composition_valid(adapter_setup, child, lvl) - composition_func = self.composition_to_func_map[type(child)] + composition_func = self._get_compose_func(type(child)) child_state = composition_func(child, state, lvl=lvl + 1) children_states.append(child_state) elif child in self.adapter_modules: @@ -311,7 +311,7 @@ def compose_split(self, adapter_setup: Split, state: BottleneckState, lvl: int = ) if isinstance(child, AdapterCompositionBlock): self.check_composition_valid(adapter_setup, child, lvl) - composition_func = self.composition_to_func_map[type(child)] + composition_func = self._get_compose_func(type(child)) child_state = composition_func(child, child_state, lvl=lvl + 1) children_states.append(child_state) elif child in self.adapter_modules: