Skip to content

Commit

Permalink
Merge pull request #639 from graphnet-team/Task-Refactor
Browse files Browse the repository at this point in the history
Task refactor
  • Loading branch information
RasmusOrsoe authored Dec 4, 2023
2 parents 776d37a + 1e46bcd commit 1017aa7
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 65 deletions.
8 changes: 4 additions & 4 deletions src/graphnet/models/standard_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from graphnet.models.graphs import GraphDefinition
from graphnet.models.gnn.gnn import GNN
from graphnet.models.model import Model
from graphnet.models.task import Task
from graphnet.models.task import StandardLearnedTask


class StandardModel(Model):
Expand All @@ -33,7 +33,7 @@ def __init__(
*,
graph_definition: GraphDefinition,
gnn: GNN,
tasks: Union[Task, List[Task]],
tasks: Union[StandardLearnedTask, List[StandardLearnedTask]],
optimizer_class: Type[torch.optim.Optimizer] = Adam,
optimizer_kwargs: Optional[Dict] = None,
scheduler_class: Optional[type] = None,
Expand All @@ -45,10 +45,10 @@ def __init__(
super().__init__(name=__name__, class_name=self.__class__.__name__)

# Check(s)
if isinstance(tasks, Task):
if isinstance(tasks, StandardLearnedTask):
tasks = [tasks]
assert isinstance(tasks, (list, tuple))
assert all(isinstance(task, Task) for task in tasks)
assert all(isinstance(task, StandardLearnedTask) for task in tasks)
assert isinstance(graph_definition, GraphDefinition)
assert isinstance(gnn, GNN)

Expand Down
7 changes: 6 additions & 1 deletion src/graphnet/models/task/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""Physics task-specific modules to be used as model "read-outs"."""

from .task import Task, IdentityTask
from .task import (
Task,
IdentityTask,
StandardLearnedTask,
StandardFlowTask,
)
6 changes: 3 additions & 3 deletions src/graphnet/models/task/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from torch import Tensor

from graphnet.models.task import Task, IdentityTask
from graphnet.models.task import IdentityTask, StandardLearnedTask


class MulticlassClassificationTask(IdentityTask):
Expand All @@ -17,7 +17,7 @@ class MulticlassClassificationTask(IdentityTask):
"""


class BinaryClassificationTask(Task):
class BinaryClassificationTask(StandardLearnedTask):
"""Performs binary classification."""

# Requires one feature, logit for being signal class.
Expand All @@ -30,7 +30,7 @@ def _forward(self, x: Tensor) -> Tensor:
return torch.sigmoid(x)


class BinaryClassificationTaskLogits(Task):
class BinaryClassificationTaskLogits(StandardLearnedTask):
"""Performs binary classification form logits."""

# Requires one feature, logit for being signal class.
Expand Down
22 changes: 11 additions & 11 deletions src/graphnet/models/task/reconstruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import torch
from torch import Tensor

from graphnet.models.task import Task
from graphnet.models.task import StandardLearnedTask
from graphnet.utilities.maths import eps_like


class AzimuthReconstructionWithKappa(Task):
class AzimuthReconstructionWithKappa(StandardLearnedTask):
"""Reconstructs azimuthal angle and associated kappa (1/var)."""

# Requires two features: untransformed points in (x,y)-space.
Expand Down Expand Up @@ -46,7 +46,7 @@ def _forward(self, x: Tensor) -> Tensor:
return angle


class DirectionReconstructionWithKappa(Task):
class DirectionReconstructionWithKappa(StandardLearnedTask):
"""Reconstructs direction with kappa from the 3D-vMF distribution."""

# Requires three features: untransformed points in (x,y,z)-space.
Expand All @@ -70,7 +70,7 @@ def _forward(self, x: Tensor) -> Tensor:
return torch.stack((vec_x, vec_y, vec_z, kappa), dim=1)


class ZenithReconstruction(Task):
class ZenithReconstruction(StandardLearnedTask):
"""Reconstructs zenith angle."""

# Requires two features: zenith angle itself.
Expand Down Expand Up @@ -98,7 +98,7 @@ def _forward(self, x: Tensor) -> Tensor:
return torch.stack((angle, kappa), dim=1)


class EnergyReconstruction(Task):
class EnergyReconstruction(StandardLearnedTask):
"""Reconstructs energy using stable method."""

# Requires one feature: untransformed energy
Expand All @@ -112,7 +112,7 @@ def _forward(self, x: Tensor) -> Tensor:
return torch.nn.functional.softplus(x, beta=0.05) + eps_like(x)


class EnergyReconstructionWithPower(Task):
class EnergyReconstructionWithPower(StandardLearnedTask):
"""Reconstructs energy."""

# Requires one feature: untransformed energy
Expand All @@ -125,7 +125,7 @@ def _forward(self, x: Tensor) -> Tensor:
return torch.pow(10, x[:, 0] + 1.0).unsqueeze(1)


class EnergyTCReconstruction(Task):
class EnergyTCReconstruction(StandardLearnedTask):
"""Reconstructs track and cascade energies using stable method."""

# Requires two features: untransformed energy for track and cascade
Expand Down Expand Up @@ -161,7 +161,7 @@ def _forward(self, x: Tensor) -> Tensor:
return pred


class VertexReconstruction(Task):
class VertexReconstruction(StandardLearnedTask):
"""Reconstructs vertex position and time."""

# Requires four features, x, y, z, and t.
Expand All @@ -183,7 +183,7 @@ def _forward(self, x: Tensor) -> Tensor:
return x


class PositionReconstruction(Task):
class PositionReconstruction(StandardLearnedTask):
"""Reconstructs vertex position."""

# Requires three features, x, y, and z.
Expand All @@ -204,7 +204,7 @@ def _forward(self, x: Tensor) -> Tensor:
return x


class TimeReconstruction(Task):
class TimeReconstruction(StandardLearnedTask):
"""Reconstructs time."""

# Requires one feature, time.
Expand All @@ -217,7 +217,7 @@ def _forward(self, x: Tensor) -> Tensor:
return x


class InelasticityReconstruction(Task):
class InelasticityReconstruction(StandardLearnedTask):
"""Reconstructs interaction inelasticity.
That is, 1-(track energy / hadronic energy).
Expand Down
Loading

0 comments on commit 1017aa7

Please sign in to comment.