diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py index b31cce443..8c42ab14d 100644 --- a/src/graphnet/models/standard_model.py +++ b/src/graphnet/models/standard_model.py @@ -18,7 +18,7 @@ from graphnet.models.graphs import GraphDefinition from graphnet.models.gnn.gnn import GNN from graphnet.models.model import Model -from graphnet.models.task import Task +from graphnet.models.task import StandardLearnedTask class StandardModel(Model): @@ -33,7 +33,7 @@ def __init__( *, graph_definition: GraphDefinition, gnn: GNN, - tasks: Union[Task, List[Task]], + tasks: Union[StandardLearnedTask, List[StandardLearnedTask]], optimizer_class: Type[torch.optim.Optimizer] = Adam, optimizer_kwargs: Optional[Dict] = None, scheduler_class: Optional[type] = None, @@ -45,10 +45,10 @@ def __init__( super().__init__(name=__name__, class_name=self.__class__.__name__) # Check(s) - if isinstance(tasks, Task): + if isinstance(tasks, StandardLearnedTask): tasks = [tasks] assert isinstance(tasks, (list, tuple)) - assert all(isinstance(task, Task) for task in tasks) + assert all(isinstance(task, StandardLearnedTask) for task in tasks) assert isinstance(graph_definition, GraphDefinition) assert isinstance(gnn, GNN)