diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index 3a9b63ef4..9987a799c 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -29,10 +29,13 @@ class TransformerBackend(ModuleBackend):
     def __init__(
         self,
         *args,
+        block_index: int,
         config: PretrainedConfig,
         memory_cache: MemoryCache,
         backend_dtype: torch.dtype,
         max_chunk_size_bytes: int,
+        cache_dir: str,
+        max_disk_space: int,
         **kwargs,
     ):
         import petals.utils.peft as _peft_module
@@ -41,9 +44,12 @@ def __init__(
 
         super().__init__(*args, **kwargs)
         assert isinstance(self.module, TensorParallel)
+        self.block_index = block_index
         self.config = config
         self.memory_cache = memory_cache
         self.max_chunk_size_bytes = max_chunk_size_bytes
+        self.cache_dir = cache_dir
+        self.max_disk_space = max_disk_space
 
         for name, param in self.module.named_parameters():
             assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
@@ -51,15 +57,15 @@ def __init__(
             assert not buf.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
 
         max_batch_size = self.forward_pool.max_batch_size
-        device = self.module.devices[self.module.output_device_index]
+        self.device = self.module.devices[self.module.output_device_index]
         self.inference_pool = PrioritizedTaskPool(
-            self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
+            self.inference_step, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_inference"
         )  # note: inference_pools may be merged later, see merge_inference_pools_inplace
         self.forward_pool = PrioritizedTaskPool(
-            self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
+            self.forward, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_forward"
         )
         self.backward_pool = PrioritizedTaskPool(
-            self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
+            self.backward, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_backward"
         )
 
         self.dtype = backend_dtype
diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index d8f0ec05e..ef0e9b93d 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -15,6 +15,7 @@
     MSGPackSerializer,
     P2PContext,
     PeerID,
+    TensorDescriptor,
     deserialize_tensor_stream,
     deserialize_torch_tensor,
     nested_flatten,
@@ -152,6 +153,7 @@ async def rpc_inference(
                 session_id = metadata.get("session_id")
                 alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
                 args_structure = metadata.get("args_structure")
+                active_adapter = self._get_active_adapter(metadata)
                 if not requested_uids:
                     raise ValueError("User must specify at least one block for inference, but got none")
                 assert isinstance(
@@ -169,12 +171,14 @@ async def rpc_inference(
 
                 async with self._allocate_cache(
                     requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
-                ) as cache_handles:
+                ) as cache_handles, self._load_peft_module(
+                    requested_backends, active_adapter=active_adapter, timeout=alloc_timeout
+                ):
                     background_tasks = set()
                     async for output_tensors, can_push in iterate_rpc_inference(
                         requested_uids=requested_uids,
                         requested_backends=requested_backends,
-                        active_adapter=self._get_active_adapter(metadata),
+                        active_adapter=active_adapter,
                         input_iterator=self._iterate_inference_steps(
                             request, requests, session_id, requested_uids, context
                         ),
@@ -489,9 +493,9 @@ async def rpc_backward_stream(
 
     def _get_active_adapter(self, metadata: dict) -> str:
active_adapter = metadata.get("active_adapter", "") - if active_adapter and (active_adapter not in self.adapters): - raise KeyError(f"adapter {active_adapter} not found") - return active_adapter + if active_adapter: + return active_adapter + return "" def _serialize_grads( self, @@ -546,6 +550,51 @@ async def _allocate_cache( async with backends[0].memory_cache.allocate_cache(*chain(*descriptors), timeout=timeout) as handles: yield nested_pack(handles, descriptors) + @contextlib.asynccontextmanager + async def _load_peft_module( + self, + backends: Sequence[TransformerBackend], + *, + active_adapter: str, + timeout: float, + ): + if active_adapter == "": + yield + elif active_adapter in self.adapters: + yield + else: + _peft_module = backends[0]._peft_module + token = None # TODO: Provide token from user request maybe? + + estimated_peft_size = _peft_module.get_estimated_peft_module_size( + active_adapter, + token=token, + ) + + fake_descriptor = TensorDescriptor( + size=(estimated_peft_size,), + dtype=torch.int8, + device=backends[0].device, + ) + + async with backends[0].memory_cache.allocate_cache(fake_descriptor, timeout=timeout) as _: + try: + for backend in backends: + adapter_config, adapter_state_dict = _peft_module.load_peft( + active_adapter, + block_idx=backend.block_index, + token=token, + cache_dir=backend.cache_dir, + max_disk_space=backend.max_disk_space, + ) + + _peft_module.add_adapter_to_block( + backend.module, backend.block_index, active_adapter, adapter_config, adapter_state_dict + ) + finally: + for backend in backends: + _peft_module.remove_adapter_from_block(backend.module, active_adapter) + def _log_request( self, method: str, diff --git a/src/petals/server/server.py b/src/petals/server/server.py index fd9f76660..5370c99f2 100644 --- a/src/petals/server/server.py +++ b/src/petals/server/server.py @@ -231,6 +231,8 @@ def __init__( gib = 1024**3 self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB") + self.adapters_cache_bytes = self.attn_cache_bytes + logger.info(f"Adapter cache for all blocks will consume up to {self.adapters_cache_bytes / gib:.2f} GiB") assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"] if throughput in ["auto", "eval", "dry_run"]: @@ -335,6 +337,7 @@ def run(self): converted_model_name_or_path=self.converted_model_name_or_path, block_config=self.block_config, attn_cache_bytes=self.attn_cache_bytes, + adapters_cache_bytes=self.adapters_cache_bytes, server_info=self.server_info, model_info=self.model_info, block_indices=block_indices, @@ -442,6 +445,7 @@ def create( converted_model_name_or_path: str, block_config: PretrainedConfig, attn_cache_bytes: int, + adapters_cache_bytes: int, server_info: ServerInfo, model_info: ModelInfo, block_indices: List[int], @@ -464,7 +468,7 @@ def create( **kwargs, ) -> ModuleContainer: module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices] - memory_cache = MemoryCache(attn_cache_bytes, max_alloc_timeout) + memory_cache = MemoryCache(attn_cache_bytes + adapters_cache_bytes, max_alloc_timeout) server_info.state = ServerState.JOINING dht_announcer = ModuleAnnouncerThread( @@ -512,10 +516,13 @@ def create( blocks[module_uid] = TransformerBackend( module_uid, block, + block_index=block_index, config=block_config, memory_cache=memory_cache, backend_dtype=torch_dtype, max_chunk_size_bytes=max_chunk_size_bytes, + cache_dir=cache_dir, + 
+                    max_disk_space=max_disk_space,
                     args_schema=(
                         BatchTensorDescriptor(
                             1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
diff --git a/src/petals/utils/convert_block.py b/src/petals/utils/convert_block.py
index 94d3e29f3..30aa96973 100644
--- a/src/petals/utils/convert_block.py
+++ b/src/petals/utils/convert_block.py
@@ -58,10 +58,11 @@ def convert_block(
         for shard, device in zip(block.module_shards, block.devices):
             shard.to(device)
 
-    if adapters:
-        from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
+    from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
+
+    create_lora_adapter(block, quant_type=quant_type)
 
-        create_lora_adapter(block, quant_type=quant_type)
+    if adapters:
         for adapter_name in adapters:
             adapter_config, adapter_state_dict = load_peft(
                 adapter_name,
diff --git a/src/petals/utils/dht.py b/src/petals/utils/dht.py
index 0710f60e5..6a9952d11 100644
--- a/src/petals/utils/dht.py
+++ b/src/petals/utils/dht.py
@@ -111,11 +111,6 @@ async def _get_remote_module_infos(
             try:
                 peer_id = PeerID.from_base58(peer_id)
                 server_info = ServerInfo.from_tuple(server_info.value)
-
-                if active_adapter and active_adapter not in server_info.adapters:
-                    logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
-                    continue
-
                 servers[peer_id] = server_info
             except (TypeError, ValueError) as e:
                 logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
diff --git a/src/petals/utils/peft.py b/src/petals/utils/peft.py
index e4d29fc64..76d00ad40 100644
--- a/src/petals/utils/peft.py
+++ b/src/petals/utils/peft.py
@@ -128,6 +128,15 @@ def load_peft(
             time.sleep(delay)
 
 
+def get_estimated_peft_module_size(
+    repo_id: str,
+    revision: Optional[str] = None,
+    token: Optional[Union[str, bool]] = None,
+):
+    weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
+    return get_hf_file_metadata(weight_url, token=token).size
+
+
 class AdapterContextMixin:
     """A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""
 
@@ -267,6 +276,22 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
     logger.info(f"Loaded adapter {adapter_name} for block {block_index}")
 
 
+def remove_adapter_from_block(block, adapter_name):
+    for _, module in block.named_modules():
+        for child_name, child in module.named_children():
+            if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
+                continue
+
+            if adapter_name in child.lora_A:
+                del child.lora_A[adapter_name]
+            if adapter_name in child.lora_B:
+                del child.lora_B[adapter_name]
+
+    # TODO: check if this is needed
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
 def estimate_adapter_memory_per_block(
     block_config: transformers.PretrainedConfig,
     torch_dtype: Optional[torch.dtype],
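
Usage sketch (not part of the patch above): with this change, a server can attach a LoRA adapter on demand when a request names one via "active_adapter", instead of serving only the adapters passed at startup via --adapters. A minimal client-side illustration of the flow this enables is below; it assumes the existing from_pretrained(..., active_adapter=...) client argument as used in the Petals test suite, and the model and adapter repo names are placeholder assumptions, not something introduced by this diff.

    from transformers import AutoTokenizer
    from petals import AutoDistributedModelForCausalLM

    MODEL_NAME = "bigscience/bloom-560m"                 # assumption: any Petals-served model
    ADAPTER_NAME = "artek0chumak/bloom-560m-safe-peft"   # assumption: example LoRA adapter repo

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # The client sends the adapter name in request metadata ("active_adapter"); with this patch,
    # servers that do not have it preloaded reserve cache space sized from the safetensors file
    # metadata, download and attach the adapter for the requested blocks, and detach it again
    # when the inference session ends.
    model = AutoDistributedModelForCausalLM.from_pretrained(MODEL_NAME, active_adapter=ADAPTER_NAME)

    inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
    outputs = model.generate(inputs, max_new_tokens=5)
    print(tokenizer.decode(outputs[0]))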