Fixed the EncoderManager
DomInvivo committed Mar 30, 2023
1 parent e9dc4ee commit 24632de
Showing 1 changed file with 58 additions and 57 deletions.
115 changes: 58 additions & 57 deletions goli/nn/architectures/encoder_manager.py
@@ -47,13 +47,13 @@ def __init__(
         super().__init__()
         self.name = name
         self.max_num_nodes_per_graph = max_num_nodes_per_graph
-        max_nodes = pe_encoders_kwargs.pop("max_num_nodes_per_graph", None)
-        if max_nodes is not None:
-            if self.max_num_nodes_per_graph is not None:
-                assert (
-                    self.max_num_nodes_per_graph == max_nodes
-                ), f"max_num_nodes_per_graph mismatch {self.max_num_nodes_per_graph}!={max_nodes}"
-            else:
+        if pe_encoders_kwargs is not None:
+            max_nodes = pe_encoders_kwargs.pop("max_num_nodes_per_graph", None)
+            if max_nodes is not None:
+                if self.max_num_nodes_per_graph is not None:
+                    assert (
+                        self.max_num_nodes_per_graph == max_nodes
+                    ), f"max_num_nodes_per_graph mismatch {self.max_num_nodes_per_graph}!={max_nodes}"
                 self.max_num_nodes_per_graph = max_nodes
 
         self.pe_encoders_kwargs = deepcopy(pe_encoders_kwargs)
@@ -69,63 +69,64 @@ def _initialize_positional_encoders(self, pe_encoders_kwargs: Dict[str, Any]) ->
             pe_encoders: a nn.ModuleDict containing all positional encoders specified by encoder_name in pe_encoders_kwargs["encoders"]
         """
         # TODO: Currently only supports PE/SE on the nodes. Need to add edges.
-        pe_encoders = None
-
-        if pe_encoders_kwargs is not None:
-            pe_encoders = nn.ModuleDict()
-
-            # Pooling options here for pe encoders
-            self.pe_pool = pe_encoders_kwargs["pool"]
-            pe_out_dim = pe_encoders_kwargs["out_dim"]
-            in_dim_dict = pe_encoders_kwargs["in_dims"]
-
-            # Loop every positional encoding to assign it
-            for encoder_name, encoder_kwargs in pe_encoders_kwargs["encoders"].items():
-                encoder_kwargs = deepcopy(encoder_kwargs)
-                encoder_type = encoder_kwargs.pop("encoder_type")
-                encoder = PE_ENCODERS_DICT[encoder_type]
-
-                # Get the keys associated to in_dim. First check if there's a key that starts with `encoder_name/`
-                # Then check for the exact key
-                this_in_dims = {}
-                for key, dim in in_dim_dict.items():
-                    if isinstance(key, str) and key.startswith(f"{encoder_name}/"):
-                        key_name = "in_dim_" + key[len(encoder_name) + 1 :]
-                        this_in_dims[key_name] = dim
-                if len(this_in_dims) == 0:
-                    for key in encoder_kwargs.get("input_keys", []):
-                        if key in in_dim_dict:
-                            this_in_dims[key] = in_dim_dict[key]
-                        else:
-                            raise ValueError(
-                                f"Key '{key}' not found in `in_dim_dict`. Encoder '{encoder_name}/' is also not found.\n Available keys: {in_dim_dict.keys()}"
-                            )
-
-                # Parse the in_dims based on Encoder's signature
-                accepted_keys = inspect.signature(encoder).parameters.keys()
-                if all([key in accepted_keys for key in this_in_dims.keys()]):
-                    pass
-                elif "in_dim" in accepted_keys:
-                    if len(set(this_in_dims.values())) == 1:
-                        this_in_dims = {"in_dim": list(this_in_dims.values())[0]}
-                    else:
-                        raise ValueError(
-                            f"All `in_dims` must be equal for encoder {encoder_name}. Provided: {this_in_dims}"
-                        )
-                else:
-                    raise ValueError(
-                        f"`in_dim` not understood for encoder {encoder_name}. Provided: {this_in_dims}. Accepted keys are: {accepted_keys}"
-                    )
-
-                # Add the max_num_nodes_per_graph if it's in the accepted input keys
-                if "max_num_nodes_per_graph" in accepted_keys:
-                    encoder_kwargs["max_num_nodes_per_graph"] = self.max_num_nodes_per_graph
-
-                # Initialize the pe_encoder layer
-                pe_out_dim2 = encoder_kwargs.pop("out_dim", None)
-                if pe_out_dim2 is not None:
-                    assert pe_out_dim == pe_out_dim2, f"values mismatch {pe_out_dim}!={pe_out_dim2}"
-                pe_encoders[encoder_name] = encoder(out_dim=pe_out_dim, **this_in_dims, **encoder_kwargs)
+        if (pe_encoders_kwargs is None) or (len(pe_encoders_kwargs) == 0):
+            return
+
+        pe_encoders = nn.ModuleDict()
+
+        # Pooling options here for pe encoders
+        self.pe_pool = pe_encoders_kwargs["pool"]
+        pe_out_dim = pe_encoders_kwargs["out_dim"]
+        in_dim_dict = pe_encoders_kwargs["in_dims"]
+
+        # Loop every positional encoding to assign it
+        for encoder_name, encoder_kwargs in pe_encoders_kwargs["encoders"].items():
+            encoder_kwargs = deepcopy(encoder_kwargs)
+            encoder_type = encoder_kwargs.pop("encoder_type")
+            encoder = PE_ENCODERS_DICT[encoder_type]
+
+            # Get the keys associated to in_dim. First check if there's a key that starts with `encoder_name/`
+            # Then check for the exact key
+            this_in_dims = {}
+            for key, dim in in_dim_dict.items():
+                if isinstance(key, str) and key.startswith(f"{encoder_name}/"):
+                    key_name = "in_dim_" + key[len(encoder_name) + 1 :]
+                    this_in_dims[key_name] = dim
+            if len(this_in_dims) == 0:
+                for key in encoder_kwargs.get("input_keys", []):
+                    if key in in_dim_dict:
+                        this_in_dims[key] = in_dim_dict[key]
+                    else:
+                        raise ValueError(
+                            f"Key '{key}' not found in `in_dim_dict`. Encoder '{encoder_name}/' is also not found.\n Available keys: {in_dim_dict.keys()}"
+                        )
+
+            # Parse the in_dims based on Encoder's signature
+            accepted_keys = inspect.signature(encoder).parameters.keys()
+            if all([key in accepted_keys for key in this_in_dims.keys()]):
+                pass
+            elif "in_dim" in accepted_keys:
+                if len(set(this_in_dims.values())) == 1:
+                    this_in_dims = {"in_dim": list(this_in_dims.values())[0]}
+                else:
+                    raise ValueError(
+                        f"All `in_dims` must be equal for encoder {encoder_name}. Provided: {this_in_dims}"
+                    )
+            else:
+                raise ValueError(
+                    f"`in_dim` not understood for encoder {encoder_name}. Provided: {this_in_dims}. Accepted keys are: {accepted_keys}"
+                )
+
+            # Add the max_num_nodes_per_graph if it's in the accepted input keys
+            if "max_num_nodes_per_graph" in accepted_keys:
+                encoder_kwargs["max_num_nodes_per_graph"] = self.max_num_nodes_per_graph
+
+            # Initialize the pe_encoder layer
+            pe_out_dim2 = encoder_kwargs.pop("out_dim", None)
+            if pe_out_dim2 is not None:
+                assert pe_out_dim == pe_out_dim2, f"values mismatch {pe_out_dim}!={pe_out_dim2}"
+            pe_encoders[encoder_name] = encoder(out_dim=pe_out_dim, **this_in_dims, **encoder_kwargs)
 
         return pe_encoders
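The practical effect of the change, as far as this diff shows, is that an EncoderManager can now be constructed without any positional-encoder configuration: __init__ only pops "max_num_nodes_per_graph" when pe_encoders_kwargs is actually provided, and _initialize_positional_encoders now also bails out on an empty dict instead of indexing into it. Below is a minimal sketch of the two now-supported cases; the import path and the assumption that the remaining constructor arguments have defaults are taken from this diff only and not verified against the full file.

# Illustrative only, not part of the commit; assumes the constructor arguments
# not shown in this diff have default values.
from goli.nn.architectures.encoder_manager import EncoderManager  # assumed import path

# No positional-encoder config at all. Before this commit, __init__ raised an
# AttributeError on `pe_encoders_kwargs.pop(...)` when given None.
manager = EncoderManager(pe_encoders_kwargs=None)

# An empty config now behaves the same way, thanks to the new
# `(pe_encoders_kwargs is None) or (len(pe_encoders_kwargs) == 0)` early return.
manager_empty = EncoderManager(pe_encoders_kwargs={})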

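For reference when reading the unchanged loop above, the shape of a non-empty pe_encoders_kwargs can be inferred from the keys the method reads ("pool", "out_dim", "in_dims", "encoders", and "encoder_type" inside each encoder entry). The encoder name, type string, input keys, and dimensions below are made-up placeholders for illustration, not values from the repository.

# Hypothetical config consumed by _initialize_positional_encoders (illustrative values only)
pe_encoders_kwargs = {
    "pool": "sum",      # pooling option, read into self.pe_pool
    "out_dim": 32,      # shared output dim; a per-encoder "out_dim", if given, must match it
    "in_dims": {
        "my_pe/eigvals": 8,   # "<encoder_name>/<suffix>" keys become in_dim_<suffix> kwargs
        "my_pe/eigvecs": 8,
    },
    "encoders": {
        "my_pe": {
            "encoder_type": "some_encoder",        # looked up in PE_ENCODERS_DICT
            "input_keys": ["eigvals", "eigvecs"],  # fallback lookup used only if no "my_pe/" prefix matches
        },
    },
}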
