Skip to content

Commit

Permalink
Fix table recognition device bug
Browse files: browse the repository at this point in the history
  • Loading branch information
iammosespaulr committed Feb 6, 2025
1 parent 3f6ff8d commit 90d6a2a
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions surya/table_rec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def inference_loop(
processed_logits = {}
for k, _, mode in BOX_PROPERTIES:
k_logits = return_dict["box_property_logits"][k][:, -1, :] # Get all batch logits at once

if mode == "classification":
# Process all classification logits in one operation
items = torch.argmax(k_logits, dim=-1)
Expand Down Expand Up @@ -104,13 +104,12 @@ def inference_loop(
box_property[k] = int(items[k][j].item())
box_properties.append(box_property)

all_done = all_done | done
all_done = all_done | done.cpu()

mark_step()
if all_done.all():
break


batch_input_ids = torch.tensor(shaper.dict_to_labels(box_properties), dtype=torch.long).to(self.model.device)
batch_input_ids = batch_input_ids.unsqueeze(1) # Add sequence length dimension

Expand Down Expand Up @@ -222,7 +221,6 @@ def batch_table_recognition(

return output_order


def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper):
results = []
for j, (img_predictions, orig_size) in enumerate(zip(rowcol_predictions, orig_sizes)):
Expand Down Expand Up @@ -309,9 +307,9 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si
used_spanning_cells.add(zz)
spanning_cell.col_id = l
cells.append(spanning_cell)
skip_columns = spanning_cell.colspan - 1 # Skip columns that are part of the spanning cell
skip_columns = spanning_cell.colspan - 1 # Skip columns that are part of the spanning cell
else:
used_spanning_cells.add(zz) # Skip this spanning cell
used_spanning_cells.add(zz) # Skip this spanning cell

if not cell_added:
cells.append(
Expand Down Expand Up @@ -371,4 +369,4 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si
image_bbox=[0, 0, orig_size[0], orig_size[1]],
)
results.append(result)
return results
return results

0 comments on commit 90d6a2a

Please sign in to comment.