Skip to content

Commit

Permalink
✨Enhanced async_fetch
Browse files: browse the repository at this point in the history
  • Loading branch information
carefree0910 committed Oct 17, 2024
1 parent bffd55a commit f6f11b5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion core/learn/data/array.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -64,7 +64,7 @@ def async_submit(self, cursor: int, index: Any) -> bool:
self._map[cursor] = self.x[index], None if self.y is None else self.y[index]
return True

def async_fetch(self, cursor: int) -> tensor_dict_type:
def async_fetch(self, cursor: int, index: Any) -> tensor_dict_type:
x, y = self._map.pop(cursor)
batch = {INPUT_KEY: x, LABEL_KEY: y}
batch = np_batch_to_tensor(batch)
Expand Down
14 changes: 7 additions & 7 deletions core/learn/schema.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -242,23 +242,23 @@ def async_submit(self, cursor: int, index: Any) -> bool:
"""return whether the submission is successful"""

@abstractmethod
def async_fetch(self, cursor: int) -> Optional[Any]:
def async_fetch(self, cursor: int, index: Any) -> Optional[Any]:
"""fetch the data after submission, return None if not ready"""

@abstractmethod
def async_finalize(self) -> None:
"""finalize the dataset at the end of each epoch"""

def poll(self, cursor: int) -> Any:
def poll(self, cursor: int, index: Any) -> Any:
while True:
fetched = self.async_fetch(cursor)
fetched = self.async_fetch(cursor, index)
if fetched is not None:
return fetched
time.sleep(0.01) # pragma: no cover


class AsyncDataLoaderIter(_SingleProcessDataLoaderIter):
_queue: Optional[List[Any]]
_queue: Optional[List[Tuple[int, Any]]]
_drained: bool
_queue_cursor: int
_dataset: IAsyncDataset
Expand Down Expand Up @@ -291,7 +291,7 @@ def _sumbit_next(self) -> None:
msg = f"failed to submit async task with cursor={cursor} and index={index}"
console.error(msg)
raise RuntimeError("failed to sumbit async task")
self._queue.append(cursor) # type: ignore
self._queue.append((cursor, index)) # type: ignore
self._queue_cursor = cursor + 1

def _next_data(self) -> Any:
Expand All @@ -314,8 +314,8 @@ def _next_data(self) -> Any:
self._sumbit_next()
except StopIteration:
self._drained = True
cursor = self._queue.pop(0)
data = self._dataset.poll(cursor)
cursor, index = self._queue.pop(0)
data = self._dataset.poll(cursor, index)
if self._pin_memory: # pragma: no cover
data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
return data
Expand Down

0 comments on commit f6f11b5

Please sign in to comment.