Cleanup mark step usage
VikParuchuri committed Feb 5, 2025
1 parent 882a692 commit 9a8b7d6
Showing 5 changed files with 3 additions and 10 deletions.
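For context: mark_step() in these files appears to be surya's wrapper around torch_xla's xm.mark_step(), which cuts the lazily traced graph and forces it to compile and run on an XLA device, and does nothing elsewhere. This commit removes redundant calls and moves the remaining ones so each decode step pays for at most one graph break. A minimal sketch of such a helper, under that assumption (illustrative only, not the repo's actual implementation):

    # Illustrative sketch only, assuming mark_step() wraps torch_xla's
    # xm.mark_step(); this is not surya's actual implementation.
    try:
        import torch_xla.core.xla_model as xm
    except ImportError:
        xm = None

    def mark_step():
        # On an XLA device this cuts the lazily built graph and executes it;
        # on CUDA/CPU there is nothing to flush, so it is a no-op.
        if xm is not None:
            xm.mark_step()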
surya/common/adetr/decoder.py (2 changes: 0 additions & 2 deletions)
@@ -1,11 +1,9 @@
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers import PretrainedConfig
from transformers.utils import ModelOutput

from transformers import PreTrainedModel
from transformers.activations import ACT2FN
surya/layout/__init__.py (3 changes: 0 additions & 3 deletions)
@@ -181,12 +181,9 @@ def batch_layout_detection(
batch_predictions[j].append(prediction)

token_count += inference_token_count

mark_step()
inference_token_count = batch_decoder_input.shape[1]
batch_decoder_input = batch_decoder_input.to(torch.long)

mark_step()
for j, (pred_dict, orig_size) in enumerate(zip(batch_predictions, orig_sizes)):
boxes = []
preds = [p for p in pred_dict if
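In batch_layout_detection the decode loop was calling mark_step() twice per iteration, once after updating the token count and again after the decoder input is rebuilt; both calls are dropped here. The cost being avoided: each mark_step() on an XLA backend closes the current graph and dispatches it, so two calls per token mean two compile/dispatch round trips where one (or none, if a later sync point covers it) would do. A self-contained toy sketch of a decode loop with a single graph cut per token, reusing the mark_step helper sketched above (decode_one_token is a placeholder, not surya's API):

    # Toy greedy decode loop: one mark_step() per decoded token.
    import torch

    def decode_one_token(tokens: torch.Tensor) -> torch.Tensor:
        # Placeholder "decoder": always predicts token id 0.
        return torch.zeros((tokens.shape[0], 1), dtype=torch.long)

    def greedy_decode(tokens: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
        for _ in range(max_new_tokens):
            next_token = decode_one_token(tokens)
            tokens = torch.cat([tokens, next_token], dim=1)
            mark_step()  # single graph break per step
        return tokens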
surya/recognition/__init__.py (1 change: 0 additions & 1 deletion)
@@ -321,7 +321,6 @@ def batch_recognition(
batch_predictions = batch_predictions.cpu()[:current_batch_size, 1:] # Remove the start token
detected_text = self.processor.tokenizer.batch_decode(batch_predictions)

mark_step()
# Convert sequence_scores to list for the current batch
batch_confidences = sequence_scores.tolist()

surya/table_rec/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -68,6 +68,7 @@ def inference_loop(
use_cache=True,
prefill=is_prefill
)
mark_step()

decoder_position_ids = decoder_position_ids[-1:] + 1

@@ -95,7 +96,6 @@ def inference_loop(
k_logits = torch.clamp(k_logits, min=1)
processed_logits[k] = torch.round(k_logits)

mark_step()
items = {k: processed_logits[k].cpu() for k, _, _ in BOX_PROPERTIES}
for j in range(current_batch_size):
box_property = {}
@@ -181,6 +181,7 @@ def batch_table_recognition(
# We only need to process each image once
with settings.INFERENCE_MODE():
encoder_hidden_states = self.model.encoder(pixel_values=batch_pixel_values).last_hidden_state
mark_step()

# Inference to get rows and columns
rowcol_predictions = self.inference_loop(
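In table recognition the calls move rather than vanish: mark_step() now sits directly after the decoder forward pass in inference_loop and after the encoder in batch_table_recognition, while the call that used to follow the logit clamping and rounding is removed. Cutting the graph right after the model call keeps the traced graph to just the forward pass; the clamping, rounding, and .cpu() transfers that follow then start from already materialized tensors instead of extending one long graph. A rough, runnable sketch of that placement, with a toy module standing in for the real decoder (TinyDecoder and its shapes are invented for illustration):

    # Sketch of the placement only; TinyDecoder is a toy stand-in for the
    # table-rec decoder, not surya's model.
    import torch
    from torch import nn

    class TinyDecoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)

    model = TinyDecoder()
    x = torch.randn(2, 8)

    with torch.inference_mode():
        logits = model(x)
    mark_step()  # flush the forward-pass graph right after the model call

    # Post-processing starts from materialized tensors / a fresh graph.
    processed = torch.round(torch.clamp(logits, min=1))
    items = processed.cpu()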
surya/texify/__init__.py (4 changes: 1 addition & 3 deletions)
@@ -97,7 +97,7 @@ def batch_texify(self, images: List[Image.Image], batch_size: int | None) -> Tup
mark_step()

decoder_position_ids = decoder_position_ids[-1:] + 1
logits = return_dict["logits"][:current_batch_size] # Ignore batch padding
logits = return_dict["logits"] # Ignore batch padding

preds = torch.argmax(logits[:, -1], dim=-1)
scores = torch.max(F.softmax(logits[:, -1], dim=-1), dim=-1).values.unsqueeze(1)
@@ -108,7 +108,7 @@ def batch_texify(self, images: List[Image.Image], batch_size: int | None) -> Tup
scores = scores.masked_fill(all_done, 0)
sequence_scores = torch.cat([sequence_scores, scores], dim=1)

mark_step()
if all_done_cpu[:current_batch_size].all():
break

@@ -127,7 +126,6 @@ def batch_texify(self, images: List[Image.Image], batch_size: int | None) -> Tup
decoder_position_ids = torch.ones_like(batch_input_ids[0, :], dtype=torch.int64,
device=self.model.device).cumsum(0) - 1 + max_position_id

mark_step()
batch_confidences = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1)
batch_confidences = batch_confidences.cpu()[:current_batch_size]
batch_predictions = batch_predictions.cpu()[:current_batch_size, 1:] # Cut off initial token
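Besides dropping two more mark_step() calls, the texify loop stops slicing logits to current_batch_size on every step and keeps working on the full padded batch; the padding rows are only cut away at the end, where batch_confidences and batch_predictions are already sliced after the move to CPU. A plausible motivation, not stated in the commit message, is shape stability: if the last batch is only partially filled, a per-step slice gives the loop tensors a different shape than on full batches, which can force an XLA retrace, whereas the padded shape stays constant. A toy illustration of the difference (all sizes are made up):

    # Toy illustration of the shape-stability point; sizes are invented.
    import torch

    batch_size = 8            # padded batch size used for every batch
    current_batch_size = 5    # images actually present in the last batch

    logits = torch.randn(batch_size, 1, 100)
    per_step_slice = logits[:current_batch_size]  # shape varies with the last batch
    padded = logits                               # shape identical on every batch

    # Final results are sliced once, after moving to CPU.
    sequence_scores = torch.rand(batch_size, 20)
    confidences = sequence_scores.sum(dim=-1).cpu()[:current_batch_size]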
