BUG: deadlock in transfer actor in the case of GPU #488

Closed · wants to merge 10 commits
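
The diff below is the whole of the change. As a rough, hypothetical illustration of the failure class named in the title (`ToyActor` and everything in this snippet are invented for the example, not Mars/Xorbits code), the sketch shows how two async actors that serialize every method behind a per-actor lock can deadlock once a transfer on one band calls back into an actor whose lock is already held, and why taking the read path out from under the lock, as the `@mo.no_lock` markers on `open_reader`/`batch_open_readers` in this diff appear to do, breaks the cycle.

```python
import asyncio


class ToyActor:
    """Toy stand-in for an actor that serializes its methods behind one lock."""

    def __init__(self, name: str):
        self.name = name
        self.peer = None  # wired up after construction
        self._lock = asyncio.Lock()

    async def fetch(self):
        # Holds this actor's lock while awaiting the peer...
        async with self._lock:
            await self.peer.send()

    async def send(self):
        # ...which calls back into the first actor, whose lock is still held.
        async with self._lock:
            await self.peer.open_reader()

    async def open_reader(self):
        async with self._lock:  # never acquired: fetch() still owns this lock
            return f"reader on {self.name}"

    async def open_reader_unlocked(self):
        # Rough analogue of decorating open_reader with @mo.no_lock in this PR:
        # the read path no longer queues behind the actor lock.
        return f"reader on {self.name}"


async def main():
    a, b = ToyActor("band-a"), ToyActor("band-b")
    a.peer, b.peer = b, a
    try:
        await asyncio.wait_for(a.fetch(), timeout=1)
    except asyncio.TimeoutError:
        print("deadlocked: open_reader is queued behind the lock fetch still holds")


asyncio.run(main())
```

Swapping `await self.peer.open_reader()` for `await self.peer.open_reader_unlocked()` lets the same call chain complete, which is the spirit of the `@mo.no_lock` additions and of the new `to_send_keys`/`to_wait_keys` bookkeeping in `fetch_via_transfer` below.
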
9 changes: 9 additions & 0 deletions python/xorbits/_mars/services/storage/core.py
@@ -93,6 +93,12 @@ async def clean_up(self):
self._file.close()

async def close(self):
logger.debug(
"Writer closed for %s, %s on %s",
self._session_id,
self._data_key,
self._data_manager.address,
)
self._file.close()
if self._object_id is None:
# for some backends like vineyard,
@@ -322,6 +328,9 @@ def delete_data_info(
level: StorageLevel,
band_name: str,
):
logger.debug(
"Deleting %s, %s from %s, %s", session_id, data_key, level, band_name
)
if (session_id, data_key) in self._data_key_to_infos:
self._data_info_list[level, band_name].pop((session_id, data_key))
self._spill_strategy[level, band_name].record_delete_info(
170 changes: 153 additions & 17 deletions python/xorbits/_mars/services/storage/handler.py
@@ -66,6 +66,7 @@ def __init__(
self._quota_refs = quota_refs
self._band_name = band_name
self._supervisor_address = None
self._lock = asyncio.Lock()

@classmethod
def gen_uid(cls, band_name: str):
@@ -284,6 +285,7 @@ async def delete_object(
level: StorageLevel,
):
data_key = await self._data_manager_ref.get_store_key(session_id, data_key)
logger.debug("Delete object %s, %s on %s", session_id, data_key, self.address)
await self._data_manager_ref.delete_data_info(
session_id, data_key, level, self._band_name
)
@@ -292,6 +294,7 @@

@mo.extensible
async def delete(self, session_id: str, data_key: str, error: str = "raise"):
logger.debug("Delete %s, %s on %s", session_id, data_key, self.address)
if error not in ("raise", "ignore"): # pragma: no cover
raise ValueError("error must be raise or ignore")

@@ -367,6 +370,9 @@ async def batch_delete(self, args_list, kwargs_list):
session_id, key, level, info.band
)
)
logger.debug(
"Batch delete %s, %s on %s", session_id, key, self.address
)
to_removes.append((level, info.object_id))
level_sizes[level] += info.store_size

@@ -382,6 +388,7 @@
await self._quota_refs[level].release_quota(size)

@mo.extensible
@mo.no_lock
async def open_reader(self, session_id: str, data_key: str) -> StorageFileObject:
data_info = await self._data_manager_ref.get_data_info(
session_id, data_key, self._band_name
@@ -390,6 +397,7 @@ async def batch_open_readers(self, args_list, kwargs_list):
return reader

@open_reader.batch
@mo.no_lock
async def batch_open_readers(self, args_list, kwargs_list):
get_data_infos = []
for args, kwargs in zip(args_list, kwargs_list):
@@ -522,7 +530,21 @@ async def _fetch_remote(
await self._data_manager_ref.put_data_info.batch(*put_data_info_delays)
await asyncio.gather(*fetch_tasks)

async def get_receive_manager_ref(self, band_name: str):
from .transfer import ReceiverManagerActor

return await mo.actor_ref(
address=self.address,
uid=ReceiverManagerActor.gen_uid(band_name),
)

@staticmethod
async def get_send_manager_ref(address: str, band: str):
from .transfer import SenderManagerActor

return await mo.actor_ref(address=address, uid=SenderManagerActor.gen_uid(band))

async def fetch_via_transfer(
self,
session_id: str,
data_keys: List[Union[str, tuple]],
@@ -531,21 +553,136 @@ async def _fetch_via_transfer(
fetch_band_name: str,
error: str,
):
from .transfer import ReceiverManagerActor, SenderManagerActor

logger.debug("Begin to fetch %s from band %s", data_keys, remote_band)
sender_ref: mo.ActorRefType[SenderManagerActor] = await mo.actor_ref(
address=remote_band[0], uid=SenderManagerActor.gen_uid(remote_band[1])

remote_data_manager_ref: mo.ActorRefType[DataManagerActor] = await mo.actor_ref(
address=remote_band[0], uid=DataManagerActor.default_uid()
)
await sender_ref.send_batch_data(
session_id,

logger.debug("Getting actual keys for %s", data_keys)
tasks = []
for key in data_keys:
tasks.append(remote_data_manager_ref.get_store_key.delay(session_id, key))
data_keys = await remote_data_manager_ref.get_store_key.batch(*tasks)
data_keys = list(set(data_keys))

logger.debug("Getting sub infos for %s", data_keys)
sub_infos = await remote_data_manager_ref.get_sub_infos.batch(
*[
remote_data_manager_ref.get_sub_infos.delay(session_id, key)
for key in data_keys
]
)

get_info_tasks = []
pin_tasks = []
for data_key in data_keys:
get_info_tasks.append(
remote_data_manager_ref.get_data_info.delay(
session_id, data_key, remote_band[1], error
)
)
pin_tasks.append(
remote_data_manager_ref.pin.delay(
session_id, data_key, remote_band[1], error
)
)
logger.debug("Getting data infos for %s", data_keys)
infos = await remote_data_manager_ref.get_data_info.batch(*get_info_tasks)
logger.debug("Pining %s", data_keys)
await remote_data_manager_ref.pin.batch(*pin_tasks)

filtered = [
(data_info, data_key)
for data_info, data_key in zip(infos, data_keys)
if data_info is not None
]
if filtered:
infos, data_keys = zip(*filtered)
else: # pragma: no cover
# no data to be transferred
return []
data_sizes = [info.store_size for info in infos]

if level is None:
level = infos[0].level

receiver_ref: mo.ActorRefType[
ReceiverManagerActor
] = await self.get_receive_manager_ref(fetch_band_name)

await self.request_quota_with_spill(level, sum(data_sizes))

open_writer_tasks = []
for data_key, data_size, sub_info in zip(data_keys, data_sizes, sub_infos):
open_writer_tasks.append(
self.open_writer.delay(
session_id,
data_key,
data_size,
level,
request_quota=False,
band_name=fetch_band_name,
)
)
writers = await self.open_writer.batch(*open_writer_tasks)
is_transferring_list = await receiver_ref.add_writers(
session_id, data_keys, data_sizes, sub_infos, writers, level
)

to_send_keys = []
to_wait_keys = []
wait_sizes = []
for data_key, is_transferring, _size in zip(
data_keys, is_transferring_list, data_sizes
):
if is_transferring:
to_wait_keys.append(data_key)
wait_sizes.append(_size)
else:
to_send_keys.append(data_key)

# Quota was over-applied for these wait keys, so return the excess now
if to_wait_keys:
self._quota_refs[level].update_quota(-sum(wait_sizes))

logger.debug(
"Start transferring %s from %s to %s",
data_keys,
remote_band,
(self.address, fetch_band_name),
)
sender_ref: mo.ActorRefType[
SenderManagerActor
] = await self.get_send_manager_ref(remote_band[0], remote_band[1])

try:
await sender_ref.send_batch_data(
session_id,
data_keys,
to_send_keys,
to_wait_keys,
(self.address, fetch_band_name),
)
await receiver_ref.handle_transmission_done(session_id, to_send_keys)
except asyncio.CancelledError:
keys_to_delete = await receiver_ref.handle_transmission_cancellation(
session_id, to_send_keys
)
for key in keys_to_delete:
await self.delete(session_id, key, error="ignore")
raise

unpin_tasks = []
for data_key in data_keys:
unpin_tasks.append(
remote_data_manager_ref.unpin.delay(
session_id, [data_key], remote_band[1], error="ignore"
)
)
await remote_data_manager_ref.unpin.batch(*unpin_tasks)

async def fetch_batch(
self,
@@ -559,10 +696,8 @@ async def fetch_batch(
if error not in ("raise", "ignore"): # pragma: no cover
raise ValueError("error must be raise or ignore")

remote_keys = defaultdict(set)
missing_keys = []
get_metas = []
get_info_delays = []
for data_key in data_keys:
get_info_delays.append(
@@ -586,6 +721,9 @@
else:
# Does not exist locally, fetch from a remote worker
missing_keys.append(data_key)
await self._data_manager_ref.pin.batch(*pin_delays)

meta_api = await self._get_meta_api(session_id)
if address is None or band_name is None:
# some mapper keys are absent, specify error='ignore'
# remember that meta only records those main keys
@@ -599,16 +737,14 @@
)
for data_key in missing_keys
]

if get_metas:
metas = await meta_api.get_chunk_meta.batch(*get_metas)
else: # pragma: no cover
metas = [{"bands": [(address, band_name)]}] * len(missing_keys)
assert len(metas) == len(missing_keys)
for data_key, bands in zip(missing_keys, metas):
if bands is not None:
remote_keys[bands["bands"][0]].add(data_key)

transfer_tasks = []
fetch_keys = []
for band, keys in remote_keys.items():
@@ -620,7 +756,7 @@
else:
# fetch via transfer
transfer_tasks.append(
self.fetch_via_transfer(
session_id, list(keys), level, band, band_name or band[1], error
)
)
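
A closing note for readers skimming the new `fetch_via_transfer`: much of the change is the `to_send_keys` / `to_wait_keys` split driven by `receiver_ref.add_writers`. A key that another request is already transferring is only waited on (and its over-applied quota returned), not sent a second time, so concurrent fetches of the same data no longer pile up behind one another. A minimal, hypothetical sketch of that pattern, with invented names (`InFlightRegistry`, `classify`, `done`, `wait`) rather than the Xorbits API, might look like:

```python
import asyncio
from typing import Dict, List


class InFlightRegistry:
    """Tracks keys that some request is already transferring."""

    def __init__(self):
        self._events: Dict[str, asyncio.Event] = {}

    def classify(self, keys: List[str]):
        # First requester of a key sends it; later requesters only wait for it.
        to_send, to_wait = [], []
        for key in keys:
            if key in self._events:
                to_wait.append(key)
            else:
                self._events[key] = asyncio.Event()
                to_send.append(key)
        return to_send, to_wait

    def done(self, key: str):
        self._events.pop(key).set()

    async def wait(self, keys: List[str]):
        await asyncio.gather(
            *(self._events[k].wait() for k in keys if k in self._events)
        )


async def fetch(name: str, registry: InFlightRegistry, keys: List[str]):
    to_send, to_wait = registry.classify(keys)
    print(f"{name}: send={to_send} wait={to_wait}")
    for key in to_send:
        await asyncio.sleep(0.1)  # pretend to transfer the bytes
        registry.done(key)
    await registry.wait(to_wait)  # overlapping keys arrive without a second transfer
    print(f"{name}: all of {keys} available")


async def main():
    registry = InFlightRegistry()
    await asyncio.gather(
        fetch("band-a", registry, ["k1", "k2"]),
        fetch("band-b", registry, ["k2", "k3"]),
    )


asyncio.run(main())
```
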