From 1cdec62b950b5e7780b780fe0ce78da0000456c9 Mon Sep 17 00:00:00 2001 From: Nikolai Hartmann Date: Tue, 19 Mar 2024 18:21:02 +0100 Subject: [PATCH] use absolute value of pz to protect against minus infinity for rapidity --- weaver/nn/model/ParticleTransformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/weaver/nn/model/ParticleTransformer.py b/weaver/nn/model/ParticleTransformer.py index 5148ea34..de89b46c 100644 --- a/weaver/nn/model/ParticleTransformer.py +++ b/weaver/nn/model/ParticleTransformer.py @@ -50,7 +50,7 @@ def to_ptrapphim(x, return_mass=True, eps=1e-8, for_onnx=False): px, py, pz, energy = x.split((1, 1, 1, 1), dim=1) pt = torch.sqrt(to_pt2(x, eps=eps)) # rapidity = 0.5 * torch.log((energy + pz) / (energy - pz)) - rapidity = 0.5 * torch.log(1 + (2 * pz) / (energy - pz).clamp(min=1e-20)) + rapidity = 0.5 * pz.sign() * torch.log(1 + (2 * pz.abs()) / (energy - pz.abs()).clamp(min=1e-20)) phi = (atan2 if for_onnx else torch.atan2)(py, px) if not return_mass: return torch.cat((pt, rapidity, phi), dim=1)