diff --git a/surya/table_rec/__init__.py b/surya/table_rec/__init__.py
index f14e7a0..a0d74df 100644
--- a/surya/table_rec/__init__.py
+++ b/surya/table_rec/__init__.py
@@ -74,7 +74,7 @@ def inference_loop(
             processed_logits = {}
             for k, _, mode in BOX_PROPERTIES:
                 k_logits = return_dict["box_property_logits"][k][:, -1, :]  # Get all batch logits at once
-                
+
                 if mode == "classification":
                     # Process all classification logits in one operation
                     items = torch.argmax(k_logits, dim=-1)
@@ -104,13 +104,12 @@
                         box_property[k] = int(items[k][j].item())
                 box_properties.append(box_property)
 
-            all_done = all_done | done
+            all_done = all_done | done.cpu()
             mark_step()
 
             if all_done.all():
                 break
 
-
             batch_input_ids = torch.tensor(shaper.dict_to_labels(box_properties), dtype=torch.long).to(self.model.device)
             batch_input_ids = batch_input_ids.unsqueeze(1)  # Add sequence length dimension
 
@@ -222,7 +221,6 @@ def batch_table_recognition(
     return output_order
 
 
-
 def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper):
     results = []
     for j, (img_predictions, orig_size) in enumerate(zip(rowcol_predictions, orig_sizes)):
@@ -309,9 +307,9 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si
                                 used_spanning_cells.add(zz)
                                 spanning_cell.col_id = l
                                 cells.append(spanning_cell)
-                                skip_columns = spanning_cell.colspan - 1 # Skip columns that are part of the spanning cell
+                                skip_columns = spanning_cell.colspan - 1  # Skip columns that are part of the spanning cell
                             else:
-                                used_spanning_cells.add(zz) # Skip this spanning cell
+                                used_spanning_cells.add(zz)  # Skip this spanning cell
 
                     if not cell_added:
                         cells.append(
@@ -371,4 +369,4 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si
             image_bbox=[0, 0, orig_size[0], orig_size[1]],
         )
         results.append(result)
-    return results
\ No newline at end of file
+    return results