scope: data preparation code
aditya0by0 committed Jan 22, 2025
1 parent 3b17487 commit f4d1d74
Showing 1 changed file with 196 additions and 99 deletions.
295 changes: 196 additions & 99 deletions chebai/preprocessing/datasets/scope/scope.py
@@ -51,13 +51,28 @@ class _SCOPeDataExtractor(_DynamicDataset, ABC):
"https://files.rcsb.org/pub/pdb/derived_data/pdb_seqres.txt.gz"
)

SCOPE_HIERARCHY: Dict[str, str] = {
"cl": "class",
"cf": "fold",
"sf": "superfamily",
"fa": "family",
"dm": "protein",
"sp": "species",
"px": "domain",
}
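# Note (added for clarity): SCOPe orders these levels from most general ("cl", class)
# to most specific ("px", domain); each sunid in the hierarchy sits on exactly one level.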

def __init__(
self,
scope_version: float,
scope_version_train: Optional[float] = None,
scope_hierarchy_level: str = "cl",
**kwargs,
):

assert (
scope_hierarchy_level in self.SCOPE_HIERARCHY
), f"scope_hierarchy_level must be one of {list(self.SCOPE_HIERARCHY.keys())}"
self.scope_hierarchy_level = scope_hierarchy_level
self.scope_version: float = scope_version
self.scope_version_train: float = scope_version_train

@@ -67,7 +82,8 @@ def __init__(
# If "scope_version_train" is given, instantiate the same class with "scope_version" set
# to "scope_version_train", so data is read from the directory for that version
_init_kwargs = kwargs
_init_kwargs["chebi_version"] = self.scope_version_train
_init_kwargs["scope_version"] = self.scope_version_train
_init_kwargs["scope_hierarchy_level"] = self.scope_hierarchy_level
self._scope_version_train_obj = self.__class__(
**_init_kwargs,
)
@@ -150,18 +166,40 @@ def _download_scope_raw_data(self) -> str:
open(scope_path, "wb").write(r.content)
return "dummy/path"

def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]:
pdb_chain_seq_mapping: Dict[str, Dict[str, str]] = {}
for record in SeqIO.parse(
os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]), "fasta"
):
pdb_id, chain = record.id.split("_")
pdb_chain_seq_mapping.setdefault(pdb_id, {})[chain] = str(record.seq)
return pdb_chain_seq_mapping

def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
print("Extracting class hierarchy...")
df_scope = self._get_scope_data()

g = nx.DiGraph()

edges = []
for _, row in df_scope.iterrows():
g.add_node(row["sunid"], **{"sid": row["sid"], "level": row["level"]})
if row["parent_sunid"] != -1:
edges.append((row["parent_sunid"], row["sunid"]))

for child_id in row["children_sunids"]:
edges.append((row["sunid"], child_id))

g.add_edges_from(edges)

print("Computing transitive closure")
return nx.transitive_closure_dag(g)

def _get_scope_data(self) -> pd.DataFrame:
df_cla = self._get_classification_data()
df_hie = self._get_hierarchy_data()
df_des = self._get_node_description_data()
df_hie_with_cla = pd.merge(df_hie, df_cla, how="left", on="sunid")
df_all = pd.merge(
df_hie_with_cla,
df_des.drop(columns=["sid"]),
how="left",
on="sunid",
)
return df_all

def _get_classification_data(self) -> pd.DataFrame:
# Load and preprocess CLA file
df_cla = pd.read_csv(
os.path.join(self.raw_dir, self.raw_file_names_dict["CLA"]),
@@ -175,125 +213,166 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
"description",
"sccs",
"sunid",
"ancestor_nodes",
"hie_levels",
]
df_cla["sunid"] = pd.to_numeric(
df_cla["sunid"], errors="coerce", downcast="integer"
)
df_cla["ancestor_nodes"] = df_cla["ancestor_nodes"].apply(

# Convert to dict - {cl:46456, cf:46457, sf:46458, fa:46459, dm:46460, sp:116748, px:113449}
df_cla["hie_levels"] = df_cla["hie_levels"].apply(
lambda x: {k: int(v) for k, v in (item.split("=") for item in x.split(","))}
)
df_cla.set_index("sunid", inplace=True)

# Split hie_levels into separate columns, one per hierarchy level
for key in self.SCOPE_HIERARCHY.keys():
df_cla[self.SCOPE_HIERARCHY[key]] = df_cla["hie_levels"].apply(
lambda x: x[key]
)
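# e.g. hie_levels {"cl": 46456, "cf": 46457, ...} yields columns
# "class" = 46456, "fold" = 46457, and so on (illustrative sunids)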

df_cla["sunid"] = df_cla["sunid"].astype("int64")

return df_cla

def _get_hierarchy_data(self) -> pd.DataFrame:
# Load and preprocess HIE file
df_hie = pd.read_csv(
os.path.join(self.raw_dir, self.raw_file_names_dict["HIE"]),
sep="\t",
header=None,
comment="#",
low_memory=False,
)
df_hie.columns = ["sunid", "parent_sunid", "children_sunids"]
df_hie["sunid"] = pd.to_numeric(
df_hie["sunid"], errors="coerce", downcast="integer"
)

# if not parent id, then insert -1
df_hie["parent_sunid"] = df_hie["parent_sunid"].replace("-", -1).astype(int)
# convert children ids to list of ids
df_hie["children_sunids"] = df_hie["children_sunids"].apply(
lambda x: list(map(int, x.split(","))) if x != "-" else []
)
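# e.g. "46457,46458" -> [46457, 46458]; a lone "-" means the node has no children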

# Initialize directed graph
g = nx.DiGraph()
# Ensure the 'sunid' column in both DataFrames has the same type
df_hie["sunid"] = df_hie["sunid"].astype("int64")
return df_hie

# Add nodes and edges efficiently
g.add_edges_from(
df_hie[df_hie["parent_sunid"] != -1].apply(
lambda row: (row["parent_sunid"], row["sunid"]), axis=1
)
)
g.add_edges_from(
df_hie.explode("children_sunids")
.dropna()
.apply(lambda row: (row["sunid"], row["children_sunids"]), axis=1)
def _get_node_description_data(self) -> pd.DataFrame:
# Load and preprocess DES file
df_des = pd.read_csv(
os.path.join(self.raw_dir, self.raw_file_names_dict["DES"]),
sep="\t",
header=None,
comment="#",
low_memory=False,
)
df_des.columns = ["sunid", "level", "sccs", "sid", "description"]
df_des.loc[len(df_des)] = {"sunid": 0, "level": "root"}

pdb_chain_seq_mapping = self._parse_pdb_sequence_file()
# Ensure the 'sunid' column in both DataFrames has the same type
df_des["sunid"] = df_des["sunid"].astype("int64")
return df_des

node_to_pdb_id = df_cla["PDB_ID"].to_dict()
def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
print(f"Process graph")

for node in g.nodes():
pdb_id = node_to_pdb_id[node]
chain_mapping = pdb_chain_seq_mapping.get(pdb_id, {})
sids = nx.get_node_attributes(graph, "sid")
levels = nx.get_node_attributes(graph, "level")

# Add nodes and edges for chains in the mapping
for chain, sequence in chain_mapping.items():
chain_node = f"{pdb_id}_{chain}"
g.add_node(chain_node, sequence=sequence)
g.add_edge(node, chain_node)
sun_ids = []
sids_list = []

print("Compute transitive closure...")
return nx.transitive_closure_dag(g)
selected_sids_dict = self.select_classes(graph)

def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
"""
Processes a directed acyclic graph (DAG) to create a raw dataset in DataFrame format. The dataset includes
Swiss-Prot protein data and their associations with Gene Ontology (GO) terms.
Note:
- GO classes are used as labels in the dataset. Each GO term is represented as a column, and its value
indicates whether a Swiss-Prot protein is associated with that GO term.
- Swiss-Prot proteins serve as samples. There is no 1-to-1 correspondence between Swiss-Prot proteins
and GO terms.
Data Format: pd.DataFrame
- Column 0 : swiss_id (Identifier for SwissProt protein)
- Column 1 : Accession of the protein
- Column 2 : GO IDs (associated GO terms)
- Column 3 : Sequence of the protein
- Column 4 to Column "n": Each column corresponding to a class with value True/False indicating whether the
protein is associated with this GO term.
for sun_id, level in levels.items():
if level == self.scope_hierarchy_level and sun_id in selected_sids_dict:
sun_ids.append(sun_id)
sids_list.append(sids.get(sun_id))

Args:
g (nx.DiGraph): The class hierarchy graph.
# data_df = pd.DataFrame(OrderedDict(sun_id=sun_ids, sids=sids_list))
df_cla = self._get_classification_data()
target_col_name = self.SCOPE_HIERARCHY[self.scope_hierarchy_level]
df_cla = df_cla[df_cla[target_col_name].isin(sun_ids)]
df_cla = df_cla[["sid", target_col_name]]

Returns:
pd.DataFrame: The raw dataset created from the graph.
"""
print(f"Processing graph")

data_df = self._get_swiss_to_go_mapping()
# add ancestors to go ids
data_df["go_ids"] = data_df["go_ids"].apply(
lambda go_ids: sorted(
set(
itertools.chain.from_iterable(
[
[go_id] + list(g.predecessors(go_id))
for go_id in go_ids
if go_id in g.nodes
]
)
)
)
assert (
len(df_cla) > 1
), "DataFrame must contain more than one row for `pd.get_dummies` to work as expected"
df_encoded = pd.get_dummies(
df_cla, columns=[target_col_name], drop_first=False, sparse=True
)
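# pd.get_dummies one-hot encodes the target level, e.g. a "class" value of 46456
# becomes a boolean column named "class_46456" (illustrative sunid)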
# Initialize the GO term labels/columns to False
selected_classes = self.select_classes(g, data_df=data_df)
new_label_columns = pd.DataFrame(
False, index=data_df.index, columns=selected_classes

pdb_chain_seq_mapping = self._parse_pdb_sequence_file()

sequence_hierarchy_df = pd.DataFrame(
columns=list(df_encoded.columns) + ["sids"]
)
data_df = pd.concat([data_df, new_label_columns], axis=1)

# Set True for the corresponding GO IDs in the DataFrame go labels/columns
for index, row in data_df.iterrows():
for go_id in row["go_ids"]:
if go_id in data_df.columns:
data_df.at[index, go_id] = True
for _, row in df_encoded.iterrows():
assert sum(row.iloc[1:].tolist()) == 1
sid = row["sid"]
# SID: 7-char identifier ("d" + 4-char PDB ID + chain ID ('_' for none, '.' for multiple)
# + domain specifier ('_' if not needed))
assert len(sid) == 7, "sid should have 7 characters"
pdb_id, chain_id = sid[1:5], sid[5]
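# e.g. (illustrative) sid "d1dlwa_" -> pdb_id "1dlw", chain_id "a"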

pdb_to_chain_mapping = pdb_chain_seq_mapping.get(pdb_id, None)
if not pdb_to_chain_mapping:
continue

if chain_id != "_":
chain_sequence = pdb_to_chain_mapping.get(chain_id, None)
if chain_sequence:
self._update_or_add_sequence(
chain_sequence, row, sequence_hierarchy_df
)

else:
# No explicit chain: add an entry for every chain of this PDB id
for chain, chain_sequence in pdb_to_chain_mapping.items():
self._update_or_add_sequence(
chain_sequence, row, sequence_hierarchy_df
)

sequence_hierarchy_df.drop(columns=["sid"], inplace=True)
sequence_hierarchy_df = sequence_hierarchy_df[
["sids"] + [col for col in sequence_hierarchy_df.columns if col != "sids"]
]

# This filters the DataFrame to include only the rows where at least one value in the row from 5th column
# onwards is True/non-zero.
# Quote from DeepGo Paper: `For training and testing, we use proteins which have been annotated with at least
# one GO term from the set of the GO terms for the model`
data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)]
return data_df
sequence_hierarchy_df = sequence_hierarchy_df[
sequence_hierarchy_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)
]
return sequence_hierarchy_df

def _parse_pdb_sequence_file(self) -> Dict[str, Dict[str, str]]:
pdb_chain_seq_mapping: Dict[str, Dict[str, str]] = {}
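# Resulting shape (illustrative): {"1dlw": {"a": "MKV..."}}, keyed by
# lower-cased PDB id, then lower-cased chain id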
for record in SeqIO.parse(
os.path.join(self.raw_dir, self.raw_file_names_dict["PDB"]), "fasta"
):
pdb_id, chain = record.id.split("_")
if str(record.seq):
pdb_chain_seq_mapping.setdefault(pdb_id.lower(), {})[chain.lower()] = (
str(record.seq)
)
return pdb_chain_seq_mapping

@staticmethod
def _update_or_add_sequence(sequence, row, sequence_hierarchy_df):
# Check whether the sequence already exists as an index
# Slice the row from the second column (index 1) onwards, i.e. the one-hot class columns
sliced_data = row.iloc[1:]

# Get the name of the column holding the True value
true_column = sliced_data.idxmax() if sliced_data.any() else None
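# Note: on a boolean Series, idxmax returns the label of the first True entry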

if sequence in sequence_hierarchy_df.index:
# Update the encoded columns only if the incoming row has a True value
if true_column is not None and row[true_column]:
sequence_hierarchy_df.loc[sequence, true_column] = True
sequence_hierarchy_df.loc[sequence, "sids"].append(row["sid"])
else:
# Add a new row with the sequence as the index and the hierarchy data
new_row = row.copy()
new_row["sids"] = [row["sid"]]
sequence_hierarchy_df.loc[sequence] = new_row

# ------------------------------ Phase: Setup data -----------------------------------
def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
Expand Down Expand Up @@ -367,15 +446,33 @@ def raw_file_names_dict(self) -> dict:

class SCOPE(_SCOPeDataExtractor):
READER = ProteinDataReader
THRESHOLD = 1

@property
def _name(self) -> str:
return "test"

def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
pass
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict:
# Filter nodes and create a dictionary of node and out-degree
sun_ids_dict = {
node: g.out_degree(node) # Store node and its out-degree
for node in g.nodes
if g.out_degree(node) >= self.THRESHOLD
}
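# out_degree counts successors in the supplied graph; on the transitive closure
# this is the number of descendants, so THRESHOLD = 1 keeps every non-leaf node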

# Return a dictionary sorted by node id (sunid)
sorted_dict = dict(
sorted(sun_ids_dict.items(), key=lambda item: item[0])
)

filename = "classes.txt"
with open(os.path.join(self.processed_dir_main, filename), "wt") as fout:
fout.writelines(str(sun_id) + "\n" for sun_id in sorted_dict.keys())

return sorted_dict


if __name__ == "__main__":
scope = SCOPE(scope_version=2.08)
scope._parse_pdb_sequence_file()
g = scope._extract_class_hierarchy("d")
scope._graph_to_raw_dataset(g)
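# (illustrative) the resulting DataFrame is indexed by chain sequence, with a
# "sids" list column followed by one boolean column per selected class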
