Skip to content

Commit

Permalink
Improvements for rustworkx.visit annotations (#1362)
Browse files Browse the repository at this point in the history
* Improvements for rustworkx.visit annotations

* Black

* Remove double import sys

* Add overloads

* Another fix

* Revert "Add overloads"

This reverts commit 499ea56.

* Revert visitor defaults
  • Loading branch information
IvanIsCoding authored Jan 15, 2025
1 parent b44272f commit 51e2830
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
14 changes: 8 additions & 6 deletions rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ from collections.abc import (
from abc import ABC
from rustworkx import generators # noqa

# from collections.abc import Sequence as SequenceCollection
from typing_extensions import Self

import numpy as np
import numpy.typing as npt
import sys
Expand All @@ -43,6 +40,11 @@ if sys.version_info >= (3, 13):
else:
from typing_extensions import TypeVar

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

_S = TypeVar("_S", default=Any)
_T = TypeVar("_T", default=Any)

Expand Down Expand Up @@ -975,9 +977,9 @@ def graph_transitivity(graph: PyGraph, /) -> float: ...

# Traversal

_BFSVisitor = TypeVar("_BFSVisitor", bound=BFSVisitor)
_DFSVisitor = TypeVar("_DFSVisitor", bound=DFSVisitor)
_DijkstraVisitor = TypeVar("_DijkstraVisitor", bound=DijkstraVisitor)
_BFSVisitor = TypeVar("_BFSVisitor", bound=BFSVisitor, default=BFSVisitor)
_DFSVisitor = TypeVar("_DFSVisitor", bound=DFSVisitor, default=DFSVisitor)
_DijkstraVisitor = TypeVar("_DijkstraVisitor", bound=DijkstraVisitor, default=DijkstraVisitor)

def digraph_bfs_search(
graph: PyDiGraph,
Expand Down
11 changes: 9 additions & 2 deletions rustworkx/visit.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,19 @@
# This file contains only type annotations for PyO3 functions and classes
# For implementation details, see visit.py

from typing import Any, Generic, TypeVar
from typing import Any, Generic

import sys

if sys.version_info >= (3, 13):
from typing import TypeVar
else:
from typing_extensions import TypeVar

class StopSearch(Exception): ...
class PruneSearch(Exception): ...

_T = TypeVar("_T")
_T = TypeVar("_T", default=Any)

class BFSVisitor(Generic[_T]):
def discover_vertex(self, v: int) -> Any: ...
Expand Down

0 comments on commit 51e2830

Please sign in to comment.