
Commit

added logging and improved docstrings
shvbsle committed Jan 14, 2025
1 parent 8e9c7f7 commit 5995b8c
Showing 1 changed file with 58 additions and 52 deletions.
110 changes: 58 additions & 52 deletions test/images/neuron-inference/infer.py
@@ -1,3 +1,4 @@
+import logging
 import os
 import sys
 import time
@@ -8,27 +9,52 @@
 from torch.utils.data import DataLoader, TensorDataset
 from transformers import BertForPreTraining, BertTokenizer
 
+logging.basicConfig(
+    level=logging.INFO,
+    format='[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s',
+    handlers=[logging.StreamHandler(sys.stdout)]
+)
+logger = logging.getLogger("BERTNeuronInference")


 def print_info(msg: str):
     """Helper function to prefix all info messages uniformly."""
-    print(f"[INFO] {msg}")
+    logger.info(msg)  # the [INFO] tag now comes from the logging format


 def print_warning(msg: str):
     """Helper function for warnings."""
-    print(f"[WARNING] {msg}")
+    logger.warning(msg)  # level tag comes from the logging format


 def print_error(msg: str):
     """Helper function for errors."""
-    print(f"[ERROR] {msg}")
+    logger.error(msg)  # level tag comes from the logging format
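
For reference, with the basicConfig format defined above, a helper call produces output of roughly this shape (the timestamp is illustrative):

    print_info("Model tracing completed successfully.")
    # -> [2025-01-14 10:30:00,123] [INFO] [BERTNeuronInference] Model tracing completed successfully.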


 def create_dummy_data(tokenizer, batch_size, num_samples=100, max_length=128, seed=42):
     """
-    Creates a realistic NSP-style dataset (50% next-sentence, 50% random).
-    Ensures num_samples is a multiple of batch_size.
+    Creates a realistic Next Sentence Prediction (NSP) dataset for BERT model testing.
+    Args:
+        tokenizer (BertTokenizer): tokenizer used to tokenize the input sentences
+        batch_size (int): size of each batch
+        num_samples (int): total number of samples to generate (default: 100)
+        max_length (int): maximum sequence length for tokenization (default: 128)
+        seed (int): random seed to ensure reproducibility (default: 42)
+    Returns:
+        TensorDataset containing:
+            - input_ids (torch.Tensor): tokenized input sequences
+            - attention_mask (torch.Tensor): attention masks
+            - nsp_labels (torch.Tensor): NSP labels (0 for random next sentence, 1 for actual next sentence)
+    Notes:
+        - Automatically adjusts num_samples to be a multiple of batch_size
+        - Creates a balanced dataset with 50% true next sentences and 50% random sentences
+        - Uses a predefined set of sample sentences for generating pairs
     """

     random.seed(seed)
 
     if num_samples % batch_size != 0:
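
The body of create_dummy_data is collapsed in this diff view. As a rough sketch of the 50/50 pairing the docstring describes; the helper name and sentence pool below are illustrative assumptions, not the repository's actual code:

    def _create_dummy_data_sketch(tokenizer, batch_size, num_samples=100, max_length=128, seed=42):
        random.seed(seed)
        # Round num_samples down to a multiple of batch_size, as the docstring promises.
        num_samples = max(batch_size, (num_samples // batch_size) * batch_size)
        # Illustrative sentence pool.
        sentences = [
            "The model compiles to Neuron.",
            "Compilation happens ahead of time.",
            "The cat sat on the mat.",
            "It started to rain.",
        ]
        first, second, labels = [], [], []
        for i in range(num_samples):
            a = random.randrange(len(sentences))
            if i < num_samples // 2:
                # True next sentence -> label 1, per the docstring's convention.
                b = (a + 1) % len(sentences)
                labels.append(1)
            else:
                # Random sentence -> label 0.
                b = random.randrange(len(sentences))
                labels.append(0)
            first.append(sentences[a])
            second.append(sentences[b])
        enc = tokenizer(first, second, padding="max_length", truncation=True,
                        max_length=max_length, return_tensors="pt")
        return TensorDataset(enc["input_ids"], enc["attention_mask"],
                             torch.tensor(labels, dtype=torch.long))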
@@ -88,11 +114,26 @@ def create_dummy_data(tokenizer, batch_size, num_samples=100, max_length=128, seed=42):

 def run_inference(model, tokenizer, batch_size, mode):
     """
-    1) Creates dummy NSP data
-    2) Moves model and data to the XLA device ("xla") for Inf2 usage
-    3) Defines a wrapper for torch_neuronx.trace(...) that expects 2 positional arguments
-    4) Traces and then runs inference in a loop
+    Runs BERT model inference using the Neuron runtime with dummy NSP (Next Sentence Prediction) data.
+    Args:
+        model (BertForPreTraining): model instance to be used for inference
+        tokenizer (BertTokenizer): tokenizer for processing input text
+        batch_size (int): batch size (8 for throughput mode, 1 for latency mode)
+        mode (str): inference mode ('throughput' or 'latency')
+    Returns:
+        None, but logs performance metrics including:
+            - Average time per batch
+            - Throughput (samples per second)
+    Notes:
+        - Performance metrics are logged with the prefix [BERT_INFERENCE_NEURON_METRICS]
+        - Uses torch_neuronx for model compilation
+        - Handles both throughput and latency testing modes
+        - Runs inference with no gradient computation (torch.no_grad)
     """

     print_info("About to create dummy data...")
     try:
         dataset = create_dummy_data(tokenizer, batch_size=batch_size)
Expand All @@ -108,54 +149,18 @@ def run_inference(model, tokenizer, batch_size, mode):
     # Since we run inference in batches, we must first
     # split the dataset into the size of input expected in a
     # single batch. This input signature would then be used
-    # to call the .trace() method and compile the Bert model to Neuron
+    # to call the .trace() method and compile the Bert model to Neuron.
     _input_ids, _attention_masks, _output_ids = dataset.tensors
     _split_input_ids = torch.split(_input_ids, batch_size)[0]
     _split_attention_masks = torch.split(_attention_masks, batch_size)[0]
 
     batch_input = (_split_input_ids, _split_attention_masks)
-    model_neuron = torch_neuronx.trace(model, batch_input)

-    print_info(f"DataLoader created with {len(dataloader)} batches.")
-
-
-    """
-    # The XLA device for Inf2 usage.
-    device = torch.device("xla")
-    print_info(f"Using device: {device}")
-    print_info("Moving model to XLA device...")
-    model.to(device)
-    model.eval()
-    print_info("Model moved to device and set to eval mode.")
-    """
-
-    """
-    def bert_inference_func(input_ids, attention_mask):
-        # BERT forward pass with two inputs
-        return model(input_ids=input_ids, attention_mask=attention_mask)
-    # Grab a sample batch to compile the model
-    try:
-        sample_inputs, sample_masks, _ = next(iter(dataloader))
-    except StopIteration:
-        print_error("DataLoader returned no batches; cannot trace model.")
-        raise RuntimeError("No data to perform tracing.")
-    print_info("Casting sample inputs to long and moving to device...")
-    sample_inputs = sample_inputs.long().to(device)
-    sample_masks = sample_masks.long().to(device)
-    print_info("About to trace model with torch_neuronx.trace()...")
     try:
-        model_neuron = torch_neuronx.trace(
-            bert_inference_func,
-            (sample_inputs, sample_masks)
-        )
+        model_neuron = torch_neuronx.trace(model, batch_input)
     except Exception as e:
-        print_error(f"Model tracing failed: {e}")
-        raise
-    """
+        logger.exception(f"Failed to trace BERT model: {e}")
+        raise
 
+    print_info(f"DataLoader created with {len(dataloader)} batches.")
+    print_info("Model tracing completed successfully.")

     total_time = 0.0
@@ -164,7 +169,8 @@ def bert_inference_func(input_ids, attention_mask):
     print_info(f"Starting Neuron inference loop with {total_batches} total batches...")
     with torch.no_grad():
         for batch_idx, batch in enumerate(dataloader):
-            input_tuple = tuple(batch[:2])
+            batch_input_tensor, batch_attention_tensor, _ = batch
+            input_tuple = (batch_input_tensor, batch_attention_tensor)  # positional inputs for the traced model
             print_info(f"Processing batch {batch_idx}/{total_batches - 1}.")
             start_time = time.time()
             try:
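
The rest of the inference loop and the final metrics computation are truncated in this view. Going by the docstring (average time per batch, throughput, and the [BERT_INFERENCE_NEURON_METRICS] prefix), the reporting step plausibly reduces to something like this hypothetical helper; the exact field names and message format are assumptions:

    def _log_metrics_sketch(total_time, total_batches, batch_size, mode):
        # Derive the two metrics named in the run_inference docstring.
        avg_time_per_batch = total_time / total_batches
        throughput = (total_batches * batch_size) / total_time
        logger.info(
            f"[BERT_INFERENCE_NEURON_METRICS] mode={mode} "
            f"avg_time_per_batch={avg_time_per_batch:.4f}s "
            f"throughput={throughput:.2f} samples/sec"
        )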
