Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
cwognum authored Aug 22, 2023
2 parents f74bf20 + f96c1b6 commit f3c9711
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10"]
python-version: ["3.8", "3.9", "3.10"]
pytorch-version: ["2.0"]

runs-on: "ubuntu-latest"
Expand Down
6 changes: 2 additions & 4 deletions graphium/finetuning/finetuning_architecture.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, Optional, Union

import torch
import torch.nn as nn
from loguru import logger
from torch import Tensor
from torch_geometric.data import Batch

Expand Down Expand Up @@ -307,7 +305,7 @@ def __init__(self, finetuning_head_kwargs: Dict[str, Any]):
self.net = net(**finetuning_head_kwargs)

def forward(self, g: Union[Dict[str, Union[torch.Tensor, Batch]], torch.Tensor, Batch]):
if isinstance(g, Union[torch.Tensor, Batch]):
if isinstance(g, (torch.Tensor, Batch)):
pass
elif isinstance(g, Dict) and len(g) == 1:
g = list(g.values())[0]
Expand Down

0 comments on commit f3c9711

Please sign in to comment.