Skip to content

Commit

Permalink
Fix predictions
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Jan 7, 2025
1 parent 4c5a180 commit 9616057
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 9 deletions.
2 changes: 1 addition & 1 deletion surya/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def TORCH_DEVICE_MODEL(self) -> str:
ORDER_BENCH_DATASET_NAME: str = "vikp/order_bench"

# Table Rec
TABLE_REC_MODEL_CHECKPOINT: str = "datalab-to/table_rec_2_test2"
TABLE_REC_MODEL_CHECKPOINT: str = "datalab-to/surya_tablerec"
TABLE_REC_IMAGE_SIZE: Dict = {"height": 768, "width": 768}
TABLE_REC_MAX_BOXES: int = 150
TABLE_REC_BATCH_SIZE: Optional[int] = None
Expand Down
16 changes: 10 additions & 6 deletions surya/table_rec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def batch_table_recognition(
with torch.inference_mode():
encoder_hidden_states = self.model.encoder(pixel_values=batch_pixel_values).last_hidden_state

# Inference to get rows and columns
rowcol_predictions = self.inference_loop(
encoder_hidden_states,
batch_input_ids,
Expand Down Expand Up @@ -194,24 +195,27 @@ def batch_table_recognition(
"merges": 0,
})

# Re-inference to predict cells
row_encoder_hidden_states = torch.stack(row_encoder_hidden_states)
row_inputs = self.processor(images=None, query_items=row_query_items, columns=columns, convert_images=False)
row_input_ids = row_inputs["input_ids"].to(self.model.device)
cell_predictions = []
for j in tqdm(range(0, len(row_input_ids), batch_size), desc="Recognizing tables"):
for j in tqdm(range(0, len(row_input_ids), batch_size), desc="Recognizing table cells"):
cell_batch_hidden_states = row_encoder_hidden_states[j:j + batch_size]
cell_batch_input_ids = row_input_ids[j:j + batch_size]
cell_batch_size = len(cell_batch_input_ids)
cell_predictions.extend(
self.inference_loop(cell_batch_hidden_states, cell_batch_input_ids, cell_batch_size, batch_size)
)
result = self.decode_batch_predictions(rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper)
output_order.append(result)

result = self.decode_batch_predictions(rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper)
output_order.extend(result)

return output_order


def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_sizes, idx_map, shaper):
results = []
for j, (img_predictions, orig_size) in enumerate(zip(rowcol_predictions, orig_sizes)):
row_cell_predictions = [c for i, c in enumerate(cell_predictions) if idx_map[i] == j]
# Each row prediction matches a cell prediction
Expand All @@ -221,8 +225,7 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si

cell_id = 0
row_predictions = [pred for pred in img_predictions if pred["category"] == CATEGORY_TO_ID["Table-row"]]
col_predictions = [pred for pred in img_predictions if
pred["category"] == CATEGORY_TO_ID["Table-column"]]
col_predictions = [pred for pred in img_predictions if pred["category"] == CATEGORY_TO_ID["Table-column"]]

# Generate table columns
for z, col_prediction in enumerate(col_predictions):
Expand Down Expand Up @@ -335,4 +338,5 @@ def decode_batch_predictions(self, rowcol_predictions, cell_predictions, orig_si
cols=columns,
image_bbox=[0, 0, orig_size[0], orig_size[1]],
)
return result
results.append(result)
return results
8 changes: 6 additions & 2 deletions surya/table_rec/model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,12 @@ class SuryaTableRecConfig(PretrainedConfig):
def __init__(self, **kwargs):
super().__init__(**kwargs)

encoder_config = kwargs.pop("encoder")
decoder_config = kwargs.pop("decoder")
if "encoder" in kwargs:
encoder_config = kwargs.pop("encoder")
decoder_config = kwargs.pop("decoder")
else:
encoder_config = DonutSwinTableRecConfig()
decoder_config = SuryaTableRecDecoderConfig()

self.encoder = encoder_config
self.decoder = decoder_config
Expand Down

0 comments on commit 9616057

Please sign in to comment.