Skip to content

Commit

Permalink
adapt for gnn properties
Browse files (browse the repository at this point in the history)
  • Loading branch information
sfluegel committed Dec 20, 2023
1 parent c8bb0da commit 421ed68
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
10 changes: 8 additions & 2 deletions chebai/preprocessing/datasets/base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -77,13 +77,16 @@ def _filter_labels(self, row):
row["labels"] = [row["labels"][self.label_filter]]
return row

def dataloader(self, kind, **kwargs) -> DataLoader:
def _load_processed_data(self, kind: str) -> List:
try:
# processed_file_names_dict is only implemented for _ChEBIDataExtractor
filename = self.processed_file_names_dict[kind]
except NotImplementedError:
filename = f"{kind}.pt"
dataset = torch.load(os.path.join(self.processed_dir, filename))
return torch.load(os.path.join(self.processed_dir, filename))

def dataloader(self, kind, **kwargs) -> DataLoader:
dataset = self._load_processed_data(kind)
if "ids" in kwargs:
ids = kwargs.pop("ids")
_dataset = []
Expand Down Expand Up @@ -180,6 +183,9 @@ def setup(self, **kwargs):
)
)

if not ("keep_reader" in kwargs and kwargs["keep_reader"]):
self.reader.on_finish()

def setup_processed(self):
raise NotImplementedError

Expand Down
8 changes: 4 additions & 4 deletions chebai/preprocessing/datasets/chebi.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -93,7 +93,7 @@ class JCITokenData(JCIBase):


def extract_class_hierarchy(chebi_path):
with open(chebi_path, encoding="utf-8") as chebi: # encoding for windows users
with open(chebi_path, encoding="utf-8") as chebi:
chebi = "\n".join(l for l in chebi if not l.startswith("xref:"))
elements = [
term_callback(clause)
Expand Down Expand Up @@ -211,12 +211,12 @@ def setup_processed(self):
os.path.join(self.processed_dir, processed_name),
)
# create second test set with classes used in train
if self.chebi_version_train is not None:
if self.chebi_version_train is not None and not os.path.isfile(
os.path.join(self.processed_dir, self.processed_file_names_dict["test"])
):
print("transform test (select classes)")
self._setup_pruned_test_set()

self.reader.on_finish()

def get_test_split(self, df: pd.DataFrame):
print("Split dataset into train (including val) / test")

Expand Down
1 change: 1 addition & 0 deletions chebai/preprocessing/reader.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -88,6 +88,7 @@ def to_data(self, row):
)

def on_finish(self):
"""Hook to run at the end of preprocessing."""
return


Expand Down

0 comments on commit 421ed68

Please sign in to comment.