Skip to content

Commit

Permalink
ruff
Browse files Browse the repository at this point in the history
levtelyatnikov committed Oct 31, 2024
1 parent 3265ef4 commit 1ba34e4
Showing 4 changed files with 191 additions and 64 deletions.
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -20,8 +20,6 @@ repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
hooks:
- id: ruff
args: [ --no-fix ]
- id: ruff-format

- repo: https://github.com/numpy/numpydoc
4 changes: 2 additions & 2 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# numpydoc ignore=GL08
import test

import configs
import topobenchmarkx
import test


__all__ = [
"topobenchmarkx",
127 changes: 86 additions & 41 deletions topobenchmarkx/transforms/data_manipulations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,86 @@
"""Data manipulations module."""

from .calculate_simplicial_curvature import (
CalculateSimplicialCurvature,
)
from .equal_gaus_features import EqualGausFeatures
from .identity_transform import IdentityTransform
from .infere_knn_connectivity import InfereKNNConnectivity
from .infere_radius_connectivity import InfereRadiusConnectivity
from .keep_only_connected_component import KeepOnlyConnectedComponent
from .keep_selected_data_fields import KeepSelectedDataFields
from .node_degrees import NodeDegrees
from .node_features_to_float import NodeFeaturesToFloat
from .one_hot_degree_features import OneHotDegreeFeatures

DATA_MANIPULATIONS = {
"Identity": IdentityTransform,
"InfereKNNConnectivity": InfereKNNConnectivity,
"InfereRadiusConnectivity": InfereRadiusConnectivity,
"NodeDegrees": NodeDegrees,
"OneHotDegreeFeatures": OneHotDegreeFeatures,
"EqualGausFeatures": EqualGausFeatures,
"NodeFeaturesToFloat": NodeFeaturesToFloat,
"CalculateSimplicialCurvature": CalculateSimplicialCurvature,
"KeepOnlyConnectedComponent": KeepOnlyConnectedComponent,
"KeepSelectedDataFields": KeepSelectedDataFields,
}

__all__ = [
"IdentityTransform",
"InfereKNNConnectivity",
"InfereRadiusConnectivity",
"EqualGausFeatures",
"NodeFeaturesToFloat",
"NodeDegrees",
"KeepOnlyConnectedComponent",
"CalculateSimplicialCurvature",
"OneHotDegreeFeatures",
"KeepSelectedDataFields",
"DATA_MANIPULATIONS",
]
"""Data manipulations module with automated exports."""

import inspect
from importlib import util
from pathlib import Path
from typing import Any


class ModuleExportsManager:
"""Manages automatic discovery and registration of data manipulation classes."""

@staticmethod
def is_manipulation_class(obj: Any) -> bool:
"""Check if an object is a valid manipulation class.
Parameters
----------
obj : Any
The object to check if it's a valid manipulation class.
Returns
-------
bool
True if the object is a valid manipulation class (non-private class
defined in __main__), False otherwise.
"""
return (
inspect.isclass(obj)
and obj.__module__ == "__main__"
and not obj.__name__.startswith("_")
)

@classmethod
def discover_manipulations(cls, package_path: str) -> dict[str, type]:
"""Dynamically discover all manipulation classes in the package.
Parameters
----------
package_path : str
Path to the package's __init__.py file.
Returns
-------
dict[str, type]
Dictionary mapping class names to their corresponding class objects.
"""
manipulations = {}

# Get the directory containing the manipulation modules
package_dir = Path(package_path).parent

# Iterate through all .py files in the directory
for file_path in package_dir.glob("*.py"):
if file_path.stem == "__init__":
continue

# Import the module
module_name = f"{Path(package_path).stem}.{file_path.stem}"
spec = util.spec_from_file_location(module_name, file_path)
if spec and spec.loader:
module = util.module_from_spec(spec)
spec.loader.exec_module(module)

# Find all manipulation classes in the module
for name, obj in inspect.getmembers(module):
if (
inspect.isclass(obj)
and obj.__module__ == module.__name__
and not name.startswith("_")
):
manipulations[name] = obj

return manipulations


# Create the exports manager
manager = ModuleExportsManager()

# Automatically discover and populate DATA_MANIPULATIONS
DATA_MANIPULATIONS = manager.discover_manipulations(__file__)

# Automatically generate __all__
__all__ = [*DATA_MANIPULATIONS.keys(), "DATA_MANIPULATIONS"]

# For backwards compatibility, also create individual imports
locals().update(DATA_MANIPULATIONS)
122 changes: 103 additions & 19 deletions topobenchmarkx/transforms/feature_liftings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,104 @@
"""Feature lifting transforms."""

from .concatenation import Concatenation
from .identity import Identity
from .projection_sum import ProjectionSum
from .set import Set

FEATURE_LIFTINGS = {
"Concatenation": Concatenation,
"ProjectionSum": ProjectionSum,
"Set": Set,
None: Identity,
}

__all__ = [
"Concatenation",
"ProjectionSum",
"Set",
"FEATURE_LIFTINGS",
"""Feature lifting transforms with automated exports."""

import inspect
from importlib import util
from pathlib import Path
from typing import Any

from .identity import Identity # Import Identity for special case


class ModuleExportsManager:
"""Manages automatic discovery and registration of feature lifting classes."""

@staticmethod
def is_lifting_class(obj: Any) -> bool:
"""Check if an object is a valid lifting class.
Parameters
----------
obj : Any
The object to check if it's a valid lifting class.
Returns
-------
bool
True if the object is a valid lifting class (non-private class
defined in __main__), False otherwise.
"""
return (
inspect.isclass(obj)
and obj.__module__ == "__main__"
and not obj.__name__.startswith("_")
)

@classmethod
def discover_liftings(
cls, package_path: str, special_cases: dict[Any, type] | None = None
) -> dict[str, type]:
"""Dynamically discover all lifting classes in the package.
Parameters
----------
package_path : str
Path to the package's __init__.py file.
special_cases : Optional[dict[Any, type]]
Dictionary of special case mappings (e.g., {None: Identity}),
by default None.
Returns
-------
dict[str, type]
Dictionary mapping class names to their corresponding class objects,
including any special cases if provided.
"""
liftings = {}

# Get the directory containing the lifting modules
package_dir = Path(package_path).parent

# Iterate through all .py files in the directory
for file_path in package_dir.glob("*.py"):
if file_path.stem == "__init__":
continue

# Import the module
module_name = f"{Path(package_path).stem}.{file_path.stem}"
spec = util.spec_from_file_location(module_name, file_path)
if spec and spec.loader:
module = util.module_from_spec(spec)
spec.loader.exec_module(module)

# Find all lifting classes in the module
for name, obj in inspect.getmembers(module):
if (
inspect.isclass(obj)
and obj.__module__ == module.__name__
and not name.startswith("_")
):
liftings[name] = obj

# Add special cases if provided
if special_cases:
liftings.update(special_cases)

return liftings


# Create the exports manager
manager = ModuleExportsManager()

# Automatically discover and populate FEATURE_LIFTINGS with special case for None
FEATURE_LIFTINGS = manager.discover_liftings(
__file__, special_cases={None: Identity}
)

# Automatically generate __all__ (excluding None key)
__all__ = [name for name in FEATURE_LIFTINGS if isinstance(name, str)] + [
"FEATURE_LIFTINGS"
]

# For backwards compatibility, create individual imports (excluding None key)
locals().update(
{k: v for k, v in FEATURE_LIFTINGS.items() if isinstance(k, str)}
)

0 comments on commit 1ba34e4

Please sign in to comment.