-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1 parent
3265ef4
commit 1ba34e4
Showing
4 changed files
with
191 additions
and
64 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
127 changes: 86 additions & 41 deletions
127
topobenchmarkx/transforms/data_manipulations/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,41 +1,86 @@ | ||
"""Data manipulations module.""" | ||
|
||
from .calculate_simplicial_curvature import ( | ||
CalculateSimplicialCurvature, | ||
) | ||
from .equal_gaus_features import EqualGausFeatures | ||
from .identity_transform import IdentityTransform | ||
from .infere_knn_connectivity import InfereKNNConnectivity | ||
from .infere_radius_connectivity import InfereRadiusConnectivity | ||
from .keep_only_connected_component import KeepOnlyConnectedComponent | ||
from .keep_selected_data_fields import KeepSelectedDataFields | ||
from .node_degrees import NodeDegrees | ||
from .node_features_to_float import NodeFeaturesToFloat | ||
from .one_hot_degree_features import OneHotDegreeFeatures | ||
|
||
DATA_MANIPULATIONS = { | ||
"Identity": IdentityTransform, | ||
"InfereKNNConnectivity": InfereKNNConnectivity, | ||
"InfereRadiusConnectivity": InfereRadiusConnectivity, | ||
"NodeDegrees": NodeDegrees, | ||
"OneHotDegreeFeatures": OneHotDegreeFeatures, | ||
"EqualGausFeatures": EqualGausFeatures, | ||
"NodeFeaturesToFloat": NodeFeaturesToFloat, | ||
"CalculateSimplicialCurvature": CalculateSimplicialCurvature, | ||
"KeepOnlyConnectedComponent": KeepOnlyConnectedComponent, | ||
"KeepSelectedDataFields": KeepSelectedDataFields, | ||
} | ||
|
||
__all__ = [ | ||
"IdentityTransform", | ||
"InfereKNNConnectivity", | ||
"InfereRadiusConnectivity", | ||
"EqualGausFeatures", | ||
"NodeFeaturesToFloat", | ||
"NodeDegrees", | ||
"KeepOnlyConnectedComponent", | ||
"CalculateSimplicialCurvature", | ||
"OneHotDegreeFeatures", | ||
"KeepSelectedDataFields", | ||
"DATA_MANIPULATIONS", | ||
] | ||
"""Data manipulations module with automated exports.""" | ||
|
||
import inspect | ||
from importlib import util | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
|
||
class ModuleExportsManager: | ||
"""Manages automatic discovery and registration of data manipulation classes.""" | ||
|
||
@staticmethod | ||
def is_manipulation_class(obj: Any) -> bool: | ||
"""Check if an object is a valid manipulation class. | ||
Parameters | ||
---------- | ||
obj : Any | ||
The object to check if it's a valid manipulation class. | ||
Returns | ||
------- | ||
bool | ||
True if the object is a valid manipulation class (non-private class | ||
defined in __main__), False otherwise. | ||
""" | ||
return ( | ||
inspect.isclass(obj) | ||
and obj.__module__ == "__main__" | ||
and not obj.__name__.startswith("_") | ||
) | ||
|
||
@classmethod | ||
def discover_manipulations(cls, package_path: str) -> dict[str, type]: | ||
"""Dynamically discover all manipulation classes in the package. | ||
Parameters | ||
---------- | ||
package_path : str | ||
Path to the package's __init__.py file. | ||
Returns | ||
------- | ||
dict[str, type] | ||
Dictionary mapping class names to their corresponding class objects. | ||
""" | ||
manipulations = {} | ||
|
||
# Get the directory containing the manipulation modules | ||
package_dir = Path(package_path).parent | ||
|
||
# Iterate through all .py files in the directory | ||
for file_path in package_dir.glob("*.py"): | ||
if file_path.stem == "__init__": | ||
continue | ||
|
||
# Import the module | ||
module_name = f"{Path(package_path).stem}.{file_path.stem}" | ||
spec = util.spec_from_file_location(module_name, file_path) | ||
if spec and spec.loader: | ||
module = util.module_from_spec(spec) | ||
spec.loader.exec_module(module) | ||
|
||
# Find all manipulation classes in the module | ||
for name, obj in inspect.getmembers(module): | ||
if ( | ||
inspect.isclass(obj) | ||
and obj.__module__ == module.__name__ | ||
and not name.startswith("_") | ||
): | ||
manipulations[name] = obj | ||
|
||
return manipulations | ||
|
||
|
||
# Create the exports manager | ||
manager = ModuleExportsManager() | ||
|
||
# Automatically discover and populate DATA_MANIPULATIONS | ||
DATA_MANIPULATIONS = manager.discover_manipulations(__file__) | ||
|
||
# Automatically generate __all__ | ||
__all__ = [*DATA_MANIPULATIONS.keys(), "DATA_MANIPULATIONS"] | ||
|
||
# For backwards compatibility, also create individual imports | ||
locals().update(DATA_MANIPULATIONS) |
122 changes: 103 additions & 19 deletions
122
topobenchmarkx/transforms/feature_liftings/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,104 @@ | ||
"""Feature lifting transforms.""" | ||
|
||
from .concatenation import Concatenation | ||
from .identity import Identity | ||
from .projection_sum import ProjectionSum | ||
from .set import Set | ||
|
||
FEATURE_LIFTINGS = { | ||
"Concatenation": Concatenation, | ||
"ProjectionSum": ProjectionSum, | ||
"Set": Set, | ||
None: Identity, | ||
} | ||
|
||
__all__ = [ | ||
"Concatenation", | ||
"ProjectionSum", | ||
"Set", | ||
"FEATURE_LIFTINGS", | ||
"""Feature lifting transforms with automated exports.""" | ||
|
||
import inspect | ||
from importlib import util | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
from .identity import Identity # Import Identity for special case | ||
|
||
|
||
class ModuleExportsManager: | ||
"""Manages automatic discovery and registration of feature lifting classes.""" | ||
|
||
@staticmethod | ||
def is_lifting_class(obj: Any) -> bool: | ||
"""Check if an object is a valid lifting class. | ||
Parameters | ||
---------- | ||
obj : Any | ||
The object to check if it's a valid lifting class. | ||
Returns | ||
------- | ||
bool | ||
True if the object is a valid lifting class (non-private class | ||
defined in __main__), False otherwise. | ||
""" | ||
return ( | ||
inspect.isclass(obj) | ||
and obj.__module__ == "__main__" | ||
and not obj.__name__.startswith("_") | ||
) | ||
|
||
@classmethod | ||
def discover_liftings( | ||
cls, package_path: str, special_cases: dict[Any, type] | None = None | ||
) -> dict[str, type]: | ||
"""Dynamically discover all lifting classes in the package. | ||
Parameters | ||
---------- | ||
package_path : str | ||
Path to the package's __init__.py file. | ||
special_cases : Optional[dict[Any, type]] | ||
Dictionary of special case mappings (e.g., {None: Identity}), | ||
by default None. | ||
Returns | ||
------- | ||
dict[str, type] | ||
Dictionary mapping class names to their corresponding class objects, | ||
including any special cases if provided. | ||
""" | ||
liftings = {} | ||
|
||
# Get the directory containing the lifting modules | ||
package_dir = Path(package_path).parent | ||
|
||
# Iterate through all .py files in the directory | ||
for file_path in package_dir.glob("*.py"): | ||
if file_path.stem == "__init__": | ||
continue | ||
|
||
# Import the module | ||
module_name = f"{Path(package_path).stem}.{file_path.stem}" | ||
spec = util.spec_from_file_location(module_name, file_path) | ||
if spec and spec.loader: | ||
module = util.module_from_spec(spec) | ||
spec.loader.exec_module(module) | ||
|
||
# Find all lifting classes in the module | ||
for name, obj in inspect.getmembers(module): | ||
if ( | ||
inspect.isclass(obj) | ||
and obj.__module__ == module.__name__ | ||
and not name.startswith("_") | ||
): | ||
liftings[name] = obj | ||
|
||
# Add special cases if provided | ||
if special_cases: | ||
liftings.update(special_cases) | ||
|
||
return liftings | ||
|
||
|
||
# Create the exports manager | ||
manager = ModuleExportsManager() | ||
|
||
# Automatically discover and populate FEATURE_LIFTINGS with special case for None | ||
FEATURE_LIFTINGS = manager.discover_liftings( | ||
__file__, special_cases={None: Identity} | ||
) | ||
|
||
# Automatically generate __all__ (excluding None key) | ||
__all__ = [name for name in FEATURE_LIFTINGS if isinstance(name, str)] + [ | ||
"FEATURE_LIFTINGS" | ||
] | ||
|
||
# For backwards compatibility, create individual imports (excluding None key) | ||
locals().update( | ||
{k: v for k, v in FEATURE_LIFTINGS.items() if isinstance(k, str)} | ||
) |