Skip to content

Commit

Permalink
Add check for StandardLearnedTask
Browse files Browse the repository at this point in the history
  • Loading branch information
RasmusOrsoe committed Dec 3, 2023
1 parent 8a274b2 commit 1e46bcd
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/graphnet/models/standard_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from graphnet.models.graphs import GraphDefinition
from graphnet.models.gnn.gnn import GNN
from graphnet.models.model import Model
from graphnet.models.task import Task
from graphnet.models.task import StandardLearnedTask


class StandardModel(Model):
Expand All @@ -33,7 +33,7 @@ def __init__(
*,
graph_definition: GraphDefinition,
gnn: GNN,
tasks: Union[Task, List[Task]],
tasks: Union[StandardLearnedTask, List[StandardLearnedTask]],
optimizer_class: Type[torch.optim.Optimizer] = Adam,
optimizer_kwargs: Optional[Dict] = None,
scheduler_class: Optional[type] = None,
Expand All @@ -45,10 +45,10 @@ def __init__(
super().__init__(name=__name__, class_name=self.__class__.__name__)

# Check(s)
if isinstance(tasks, Task):
if isinstance(tasks, StandardLearnedTask):
tasks = [tasks]
assert isinstance(tasks, (list, tuple))
assert all(isinstance(task, Task) for task in tasks)
assert all(isinstance(task, StandardLearnedTask) for task in tasks)
assert isinstance(graph_definition, GraphDefinition)
assert isinstance(gnn, GNN)

Expand Down

0 comments on commit 1e46bcd

Please sign in to comment.