Skip to content

Commit

Permalink
Expose logic that checks whether an element fits in a packing bin.
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 715735425
  • Loading branch information
Grain Team authored and copybara-github committed Jan 15, 2025
1 parent 1fab87c commit f3bec91
Showing 1 changed file with 34 additions and 21 deletions.
55 changes: 34 additions & 21 deletions grain/_src/python/dataset/transformations/packing_packed_batch.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -42,7 +42,7 @@


@dataclasses.dataclass(frozen=True, kw_only=True)
class _SuccessfulRowOrFailingComponents:
class SuccessfulRowOrFailingComponents:
# Holds the index of the row to put a new element into if it can fit,
# or None if it can't fit into any row.
row: int | None
Expand Down Expand Up @@ -163,40 +163,42 @@ def get_packed_batch(self):
meta_features=self._meta_features,
)

def _can_add_at_row(
self, element: jt.PyTree[np.ndarray]
) -> _SuccessfulRowOrFailingComponents:
@classmethod
def can_add_at_row(
cls,
element_feature_lengths: jt.PyTree[int],
num_packing_bins: int,
length_struct: jt.PyTree[int],
first_free_cell_per_row: jt.PyTree[int],
) -> SuccessfulRowOrFailingComponents:
"""Checks whether the element can be added in any of the rows.
Args:
element: The element we are trying to fit into a row in the batch.
element_feature_lengths: The lengths of each feature in the element.
num_packing_bins: The number of packing bins.
length_struct: The max length of each feature.
first_free_cell_per_row: The first free cell per row.
Returns:
SuccessfulRowOrFailingComponents: If the element fits into a row,
return the index of that row. If it doesn't fit in any of the rows,
return the names of the components that caused it to fail to fit.
"""
tree.assert_same_structure(element, self._length_struct)

element_feature_lengths = tree.map_structure(
lambda x: 1 if np.ndim(x) == 0 else len(x), element
)

# Check no feature exceeds max length
features_exceeding_max_length = []
for (path, feature_length), (_, max_length) in zip(
tree.flatten_with_path(element_feature_lengths),
tree.flatten_with_path(self._length_struct),
tree.flatten_with_path(length_struct),
strict=True,
):
if feature_length > max_length:
features_exceeding_max_length.append((path, feature_length, max_length))

if features_exceeding_max_length:
raise ValueError(
f"Inputs to {self.__class__.__name__} must be truncated to max"
" length. Received the following features that exceed their max: "
"(feature_path, feature_length, max_length) = "
f"Inputs to {cls.__name__} must be truncated to max length. Received "
"the following features that exceed their max: (feature_path, "
"feature_length, max_length) = "
f"{features_exceeding_max_length}"
)

Expand All @@ -209,16 +211,16 @@ def _feature_will_fit(feature_length, first_free_cell, max_length):
tree.map_structure(
_feature_will_fit,
element_feature_lengths,
self._first_free_cell_per_row,
self._length_struct,
first_free_cell_per_row,
length_struct,
)
)

# Pick first row (if exists) where element can be added.
for i in range(self._num_packing_bins): # For each row.
for i in range(num_packing_bins): # For each row.
if all(free[i] for _, free in is_row_free_struct):
# All components are free at that row.
return _SuccessfulRowOrFailingComponents(row=i, failing_components=None)
return SuccessfulRowOrFailingComponents(row=i, failing_components=None)

# There is no guarantee we have a single failing component, since one
# component could be the reason an element could not fit in one row
Expand All @@ -237,7 +239,7 @@ def _feature_will_fit(feature_length, first_free_cell, max_length):
reverse=True,
)
failing_components = [e[0] for e in sorted_failing_components if e[1] > 0]
return _SuccessfulRowOrFailingComponents(
return SuccessfulRowOrFailingComponents(
row=None, failing_components=failing_components
)

Expand Down Expand Up @@ -279,7 +281,18 @@ def try_add_to_batch(self, element) -> list[str] | None:
could not be added, returns a list of strings indicating the components
that failed.
"""
successful_row_or_failing_component = self._can_add_at_row(element)
tree.assert_same_structure(element, self._length_struct)

element_feature_lengths = tree.map_structure(
lambda x: 1 if np.ndim(x) == 0 else len(x), element
)

successful_row_or_failing_component = self.can_add_at_row(
element_feature_lengths,
self._num_packing_bins,
self._length_struct,
self._first_free_cell_per_row,
)
successful_row = successful_row_or_failing_component.row
failing_components = successful_row_or_failing_component.failing_components
if successful_row is None:
Expand Down

0 comments on commit f3bec91

Please sign in to comment.