Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add loading LoRA adapters from clients' requests #506

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/petals/server/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,13 @@ class TransformerBackend(ModuleBackend):
def __init__(
self,
*args,
block_index: int,
config: PretrainedConfig,
memory_cache: MemoryCache,
backend_dtype: torch.dtype,
max_chunk_size_bytes: int,
cache_dir: str,
max_disk_space: int,
**kwargs,
):
import petals.utils.peft as _peft_module
Expand All @@ -41,25 +44,28 @@ def __init__(

super().__init__(*args, **kwargs)
assert isinstance(self.module, TensorParallel)
self.block_index = block_index
self.config = config
self.memory_cache = memory_cache
self.max_chunk_size_bytes = max_chunk_size_bytes
self.cache_dir = cache_dir
self.max_disk_space = max_disk_space

for name, param in self.module.named_parameters():
assert not param.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"
for name, buf in self.module.named_buffers():
assert not buf.requires_grad, f"Block parameters must not accumulate gradients, but {name} does"

max_batch_size = self.forward_pool.max_batch_size
device = self.module.devices[self.module.output_device_index]
self.device = self.module.devices[self.module.output_device_index]
self.inference_pool = PrioritizedTaskPool(
self.inference_step, max_batch_size=max_batch_size, device=device, name=f"{self.name}_inference"
self.inference_step, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_inference"
) # note: inference_pools may be merged later, see merge_inference_pools_inplace
self.forward_pool = PrioritizedTaskPool(
self.forward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_forward"
self.forward, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_forward"
)
self.backward_pool = PrioritizedTaskPool(
self.backward, max_batch_size=max_batch_size, device=device, name=f"{self.name}_backward"
self.backward, max_batch_size=max_batch_size, device=self.device, name=f"{self.name}_backward"
)

self.dtype = backend_dtype
Expand Down
59 changes: 54 additions & 5 deletions src/petals/server/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
MSGPackSerializer,
P2PContext,
PeerID,
TensorDescriptor,
deserialize_tensor_stream,
deserialize_torch_tensor,
nested_flatten,
Expand Down Expand Up @@ -152,6 +153,7 @@ async def rpc_inference(
session_id = metadata.get("session_id")
alloc_timeout = float(metadata.get("alloc_timeout", 0.0))
args_structure = metadata.get("args_structure")
active_adapter = self._get_active_adapter(metadata)
if not requested_uids:
raise ValueError("User must specify at least one block for inference, but got none")
assert isinstance(
Expand All @@ -169,12 +171,14 @@ async def rpc_inference(

async with self._allocate_cache(
requested_backends, batch_size=batch_size, max_length=max_length, timeout=alloc_timeout
) as cache_handles:
) as cache_handles, self._load_peft_module(
requested_backends, active_adapter=active_adapter, timeout=alloc_timeout
):
background_tasks = set()
async for output_tensors, can_push in iterate_rpc_inference(
requested_uids=requested_uids,
requested_backends=requested_backends,
active_adapter=self._get_active_adapter(metadata),
active_adapter=active_adapter,
input_iterator=self._iterate_inference_steps(
request, requests, session_id, requested_uids, context
),
Expand Down Expand Up @@ -489,9 +493,9 @@ async def rpc_backward_stream(

def _get_active_adapter(self, metadata: dict) -> str:
active_adapter = metadata.get("active_adapter", "")
if active_adapter and (active_adapter not in self.adapters):
raise KeyError(f"adapter {active_adapter} not found")
return active_adapter
if active_adapter:
return active_adapter
return ""

def _serialize_grads(
self,
Expand Down Expand Up @@ -546,6 +550,51 @@ async def _allocate_cache(
async with backends[0].memory_cache.allocate_cache(*chain(*descriptors), timeout=timeout) as handles:
yield nested_pack(handles, descriptors)

@contextlib.asynccontextmanager
async def _load_peft_module(
self,
backends: Sequence[TransformerBackend],
*,
active_adapter: str,
timeout: float,
):
if active_adapter == "":
yield
elif active_adapter in self.adapters:
yield
else:
_peft_module = backends[0]._peft_module
token = None # TODO: Provide token from user request maybe?

estimated_peft_size = _peft_module.get_estimated_peft_module_size(
active_adapter,
token=token,
)

fake_descriptor = TensorDescriptor(
size=(estimated_peft_size,),
dtype=torch.int8,
device=backends[0].device,
)

async with backends[0].memory_cache.allocate_cache(fake_descriptor, timeout=timeout) as _:
try:
for backend in backends:
adapter_config, adapter_state_dict = _peft_module.load_peft(
active_adapter,
block_idx=backend.block_index,
token=token,
cache_dir=backend.cache_dir,
max_disk_space=backend.max_disk_space,
)

_peft_module.add_adapter_to_block(
backend.module, backend.block_index, active_adapter, adapter_config, adapter_state_dict
)
finally:
for backend in backends:
_peft_module.remove_adapter_from_block(backend.module, active_adapter)

def _log_request(
self,
method: str,
Expand Down
9 changes: 8 additions & 1 deletion src/petals/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ def __init__(
gib = 1024**3
self.attn_cache_bytes = self._cache_bytes_per_block * num_blocks
logger.info(f"Attention cache for all blocks will consume up to {self.attn_cache_bytes / gib:.2f} GiB")
self.adapters_cache_bytes = self.attn_cache_bytes
logger.info(f"Adapter cache for all blocks will consume up to {self.adapters_cache_bytes / gib:.2f} GiB")

assert isinstance(throughput, float) or throughput in ["auto", "eval", "dry_run"]
if throughput in ["auto", "eval", "dry_run"]:
Expand Down Expand Up @@ -335,6 +337,7 @@ def run(self):
converted_model_name_or_path=self.converted_model_name_or_path,
block_config=self.block_config,
attn_cache_bytes=self.attn_cache_bytes,
adapters_cache_bytes=self.adapters_cache_bytes,
server_info=self.server_info,
model_info=self.model_info,
block_indices=block_indices,
Expand Down Expand Up @@ -442,6 +445,7 @@ def create(
converted_model_name_or_path: str,
block_config: PretrainedConfig,
attn_cache_bytes: int,
adapters_cache_bytes: int,
server_info: ServerInfo,
model_info: ModelInfo,
block_indices: List[int],
Expand All @@ -464,7 +468,7 @@ def create(
**kwargs,
) -> ModuleContainer:
module_uids = [f"{dht_prefix}{UID_DELIMITER}{block_index}" for block_index in block_indices]
memory_cache = MemoryCache(attn_cache_bytes, max_alloc_timeout)
memory_cache = MemoryCache(attn_cache_bytes + adapters_cache_bytes, max_alloc_timeout)

server_info.state = ServerState.JOINING
dht_announcer = ModuleAnnouncerThread(
Expand Down Expand Up @@ -512,10 +516,13 @@ def create(
blocks[module_uid] = TransformerBackend(
module_uid,
block,
block_index=block_index,
config=block_config,
memory_cache=memory_cache,
backend_dtype=torch_dtype,
max_chunk_size_bytes=max_chunk_size_bytes,
cache_dir=cache_dir,
max_disk_space=max_disk_space,
args_schema=(
BatchTensorDescriptor(
1, 2048, block_config.hidden_size, dtype=torch_dtype, compression=compression
Expand Down
7 changes: 4 additions & 3 deletions src/petals/utils/convert_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,11 @@ def convert_block(
for shard, device in zip(block.module_shards, block.devices):
shard.to(device)

if adapters:
from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft
from petals.utils.peft import add_adapter_to_block, create_lora_adapter, load_peft

create_lora_adapter(block, quant_type=quant_type)

create_lora_adapter(block, quant_type=quant_type)
if adapters:
for adapter_name in adapters:
adapter_config, adapter_state_dict = load_peft(
adapter_name,
Expand Down
5 changes: 0 additions & 5 deletions src/petals/utils/dht.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,6 @@ async def _get_remote_module_infos(
try:
peer_id = PeerID.from_base58(peer_id)
server_info = ServerInfo.from_tuple(server_info.value)

if active_adapter and active_adapter not in server_info.adapters:
logger.debug(f"Skipped server {peer_id} since it does not have adapter {active_adapter}")
continue

servers[peer_id] = server_info
except (TypeError, ValueError) as e:
logger.warning(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")
Expand Down
25 changes: 25 additions & 0 deletions src/petals/utils/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ def load_peft(
time.sleep(delay)


def get_estimated_peft_module_size(
repo_id: str,
revision: Optional[str] = None,
token: Optional[Union[str, bool]] = None,
):
weight_url = hf_hub_url(repo_id, SAFETENSORS_WEIGHTS_NAME, revision=revision)
return get_hf_file_metadata(weight_url, token=token).size


class AdapterContextMixin:
"""A mixin that makes LoRA-wrapped linear layers obey an adapter set from context"""

Expand Down Expand Up @@ -267,6 +276,22 @@ def add_adapter_to_block(block, block_index, adapter_name, peft_config, peft_sta
logger.info(f"Loaded adapter {adapter_name} for block {block_index}")


def remove_adapter_from_block(block, adapter_name):
for _, module in block.named_modules():
for child_name, child in module.named_children():
if not isinstance(child, (lora.Linear, lora.Linear8bitLt, lora.Linear4bit)):
continue

if adapter_name in child.lora_A:
del child.lora_A[adapter_name]
if adapter_name in child.lora_B:
del child.lora_B[adapter_name]

# TODO: check is this needed
if torch.cuda.is_available():
torch.cuda.empty_cache()


def estimate_adapter_memory_per_block(
block_config: transformers.PretrainedConfig,
torch_dtype: Optional[torch.dtype],
Expand Down