Skip to content

Commit

Permalink
style: apply ruff formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
dacorvo committed Jan 10, 2025
1 parent 689779d commit 2202f84
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions bench/generation/metrics/latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def elapsed_time(self, other):

memory = get_device_memory(device)
if memory is not None:
-print(f"Device memory: {memory / (2 ** 30):.4f} GB")
+print(f"Device memory: {memory / (2**30):.4f} GB")

latencies = []
input_ids = torch.randint(1, model.config.vocab_size - 1, size=(batch_size, prompt_length)).to(device)
Expand All @@ -89,7 +89,7 @@ def elapsed_time(self, other):

if device.type == "cuda":
peak_memory = torch.cuda.max_memory_allocated()
-print(f"Peak memory during benchmark: {peak_memory / (2 ** 30):.4f} GB")
+print(f"Peak memory during benchmark: {peak_memory / (2**30):.4f} GB")

mean_latency = np.mean(latencies) / generation_config.min_new_tokens
print(f"Average latency per token: {mean_latency} ms")
Expand Down
2 changes: 1 addition & 1 deletion bench/generation/setup/quanto.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def setup(
calibrate(model, tokenizer, batch_size, batches=4)
print("Freezing")
freeze(model)
-print(f"Finished: {time.time()-start:.2f}")
+print(f"Finished: {time.time() - start:.2f}")
return model, tokenizer


Expand Down
6 changes: 3 additions & 3 deletions test/tensor/weights/weight_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@ def check_weight_qtensor_linear(qweight, batch_size, tokens, use_bias, rel_max_e
rel_max_err = max_err / mean_val
# These values were evaluated empirically without any optimized kernels.
rtol = {"cpu": 1e-2, "cuda": 2e-2, "mps": 1e-2, "xpu": 2e-2}[device.type]
-assert (
-rel_max_err < rtol
-), f"Maximum error {max_err:.2f} is too high for input of mean value {mean_val:.2f} ({rel_max_err*100:.2f} %)"
+assert rel_max_err < rtol, (
+f"Maximum error {max_err:.2f} is too high for input of mean value {mean_val:.2f} ({rel_max_err * 100:.2f} %)"
+)

0 comments on commit 2202f84

Please sign in to comment.