Skip to content

Commit

Permalink
Replace native Python sum with torch stack(...).sum().
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben Zickel committed Sep 26, 2024
1 parent 4688dd0 commit 979bafa
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
15 changes: 8 additions & 7 deletions pyro/ops/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,26 +570,27 @@ def energy_score_empirical(
pred = pred[..., pred_batch_size:, :]
# Calculate predictions distance to truth
retval = (
torch.cat(
torch.stack(
[
torch.cdist(pred_batch, truth).sum(dim=-2, keepdim=True)
torch.cdist(pred_batch, truth).sum(dim=-2)
for pred_batch in pred_batches
],
dim=-2,
).sum(dim=-2)
dim=0,
).sum(dim=0)
/ pred_len
)
# Calculate predictions self distance
for aux_pred_batch in pred_batches:
retval = (
retval
- 0.5
* sum( # type: ignore[index]
* torch.stack(
[
torch.cdist(pred_batch, aux_pred_batch).sum(dim=[-1, -2])
for pred_batch in pred_batches
]
)[..., None]
],
dim=0,
).sum(dim=0)[..., None]
/ pred_len
/ pred_len
)
Expand Down
5 changes: 5 additions & 0 deletions tests/ops/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,3 +373,8 @@ def test_energy_score_empirical_batched_calculation(
actual = energy_score_empirical(pred, truth, pred_batch_size=pred_batch_size)
# Check accuracy
assert_close(actual, expected)


def test_jit_compilation():
# Test that functions can be JIT compiled
torch.jit.script(energy_score_empirical)

0 comments on commit 979bafa

Please sign in to comment.