
Commit

texify mark_step()'s
iammosespaulr committed Jan 31, 2025
1 parent 82ae6ad commit 4bed591
Showing 1 changed file with 11 additions and 0 deletions.
surya/texify/__init__.py: 11 additions & 0 deletions
@@ -6,6 +6,7 @@
from PIL import Image
from tqdm import tqdm

from surya.common.util import mark_step
from surya.common.predictor import BasePredictor
from surya.settings import settings
from surya.texify.loader import TexifyModelLoader
@@ -75,6 +76,7 @@ def batch_texify(self, images: List[Image.Image], batch_size: int | None) -> Tup

with settings.INFERENCE_MODE():
encoder_hidden_states = self.model.encoder(pixel_values=batch_pixel_values).last_hidden_state
mark_step()

while token_count < settings.TEXIFY_MAX_TOKENS - 1:
is_prefill = token_count == 0
@@ -86,6 +88,7 @@ def batch_texify(self, images: List[Image.Image], batch_size: int | None) -> Tup
use_cache=True,
prefill=is_prefill
)
mark_step()

decoder_position_ids = decoder_position_ids[-1:] + 1
logits = return_dict["logits"][:current_batch_size] # Ignore batch padding
@@ -101,27 +104,35 @@ def batch_texify(self, images: List[Image.Image], batch_size: int | None) -> Tup
scores = scores.masked_fill(all_done, 0)
sequence_scores = torch.cat([sequence_scores, scores], dim=1)

mark_step()
if all_done.all():
break

batch_input_ids = preds.unsqueeze(1)

for j, (pred, status) in enumerate(zip(preds, all_done)):
mark_step()
if not status:
batch_predictions[j].append(int(pred))

token_count += inference_token_count

mark_step()
inference_token_count = batch_input_ids.shape[-1]

mark_step()
max_position_id = torch.max(decoder_position_ids).item()
decoder_position_ids = torch.ones_like(batch_input_ids[0, :], dtype=torch.int64,
device=self.model.device).cumsum(0) - 1 + max_position_id

if settings.TEXIFY_STATIC_CACHE:
batch_input_ids = self.pad_to_batch_size(batch_input_ids, batch_size)

mark_step()
batch_confidences = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1)
detected_text = self.processor.tokenizer.batch_decode(batch_predictions)

mark_step()
batch_confidences = batch_confidences.tolist()

if settings.TEXIFY_STATIC_CACHE:

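Note: the helper imported in the first hunk, surya.common.util.mark_step, is not shown in this diff. Below is a minimal sketch of what such a helper commonly looks like, assuming it wraps torch_xla's xm.mark_step() and degrades to a no-op when torch_xla is not installed; the actual body in surya/common/util.py may differ.

    # Hypothetical sketch (assumption): surya.common.util.mark_step is not part of this diff.
    # Under torch_xla, tensor ops are traced lazily; xm.mark_step() cuts the pending graph
    # and dispatches it for compilation/execution.
    try:
        import torch_xla.core.xla_model as xm  # present only in XLA builds
    except ImportError:
        xm = None


    def mark_step() -> None:
        """Flush the lazily traced graph on XLA devices; no-op everywhere else."""
        if xm is not None:
            xm.mark_step()

Calling such a helper after the encoder pass and after each decoder step, as this commit does, is the usual way to keep the lazily traced XLA graph bounded: each generation step then compiles to a small, reusable graph instead of one graph that keeps growing with the number of generated tokens.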