Skip to content

Commit

Permalink
Maintain project requirements and code QA (#268)
Browse files Browse the repository at this point in the history
- Apply ruff 0.9.3
- explicitly install R in GHA tests
- explicitly install torch along with pgmpy
- Remove reliance on ananke, since it has way too strict version pins
(poetry strikes again)
  • Loading branch information
cthoyt authored Jan 24, 2025
1 parent e422973 commit da2032c
Show file tree
Hide file tree
Showing 30 changed files with 137 additions and 126 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ jobs:
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Set up R
uses: r-lib/actions/setup-r@v2
- name: Install dependencies
run: |
pip install tox tox-uv
Expand Down
26 changes: 13 additions & 13 deletions notebooks/Make Counterfactual Graph.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,9 @@
"metadata": {},
"outputs": [],
"source": [
"assert (\n",
" figure_11a_calculated == figure_11a.graph\n",
"), \"Calculated figure 11a is not equal to the actual figure 11a\"\n",
"assert figure_11a_calculated == figure_11a.graph, (\n",
" \"Calculated figure 11a is not equal to the actual figure 11a\"\n",
")\n",
"print(\"Figure 11a calculated is correct\")"
]
},
Expand Down Expand Up @@ -465,9 +465,9 @@
},
"outputs": [],
"source": [
"assert (\n",
" figure_11b_calculated == figure_11b.graph\n",
"), \"Calculated figure 11b is not equal to the actual figure 11b\"\n",
"assert figure_11b_calculated == figure_11b.graph, (\n",
" \"Calculated figure 11b is not equal to the actual figure 11b\"\n",
")\n",
"\n",
"print(\"Figure 11b calculated is correct\")"
]
Expand Down Expand Up @@ -523,9 +523,9 @@
"metadata": {},
"outputs": [],
"source": [
"assert (\n",
" figure_11c_calculated == figure_11c.graph\n",
"), \"Calculated figure 11c is not equal to the actual figure 11c\"\n",
"assert figure_11c_calculated == figure_11c.graph, (\n",
" \"Calculated figure 11c is not equal to the actual figure 11c\"\n",
")\n",
"\n",
"print(\"Figure 11c calculated is correct\")"
]
Expand Down Expand Up @@ -586,9 +586,9 @@
"metadata": {},
"outputs": [],
"source": [
"assert (\n",
" figure_9c_calculated == figure_9c.graph\n",
"), \"Calculated figure 9c is not equal to the actual figure 9c\"\n",
"assert figure_9c_calculated == figure_9c.graph, (\n",
" \"Calculated figure 9c is not equal to the actual figure 9c\"\n",
")\n",
"print(\"Figure 9c calculated is correct\")"
]
},
Expand Down Expand Up @@ -728,7 +728,7 @@
"source": [
"for c_component in C_components:\n",
" print(\n",
" f\"C-Component:{c_component} Interventions: {get_events_of_district(cf_graph,c_component,new_event)}\"\n",
" f\"C-Component:{c_component} Interventions: {get_events_of_district(cf_graph, c_component, new_event)}\"\n",
" )"
]
},
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ r = [
tests = [
"pytest",
"coverage",
"ananke-causal",
# "ananke-causal",
"pgmpy",
"torch",
]
docs = [
# waiting on https://github.com/readthedocs/sphinx_rtd_theme/issues/1582
Expand Down
4 changes: 2 additions & 2 deletions src/y0/algorithm/conditional_independencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
from ..util.combinatorics import powerset

__all__ = [
"add_ci_undirected_edges",
"are_d_separated",
"minimal",
"get_conditional_independencies",
"minimal",
"test_conditional_independencies",
"add_ci_undirected_edges",
]


Expand Down
11 changes: 5 additions & 6 deletions src/y0/algorithm/counterfactual_transport/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
)

__all__ = [
"unconditional_cft",
"conditional_cft",
"transport_unconditional_counterfactual_query",
"transport_conditional_counterfactual_query",
#
"Event",
"CFTDomain",
"ConditionalCFTResult",
"Event",
"UnconditionalCFTResult",
"conditional_cft",
"transport_conditional_counterfactual_query",
"transport_unconditional_counterfactual_query",
"unconditional_cft",
]
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

__all__ = [
"get_ancestors_of_counterfactual",
"minimize_counterfactual",
"get_ancestral_components",
"minimize_counterfactual",
]

logger = logging.getLogger(__file__)
Expand Down
26 changes: 12 additions & 14 deletions src/y0/algorithm/counterfactual_transport/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,26 +41,24 @@
)

__all__ = [
# TODO do a proper audit of which of these a user should ever have to import
"unconditional_cft",
"conditional_cft",
"transport_unconditional_counterfactual_query",
"transport_conditional_counterfactual_query",
#
"Event",
"CFTDomain",
"ConditionalCFTResult",
"Event",
"UnconditionalCFTResult",
# Utilities
"simplify",
"minimize_event",
"same_district",
"is_counterfactual_factor_form",
"get_counterfactual_factors",
"conditional_cft",
"convert_to_counterfactual_factor_form",
"do_counterfactual_factor_factorization",
"counterfactual_factors_are_transportable",
"do_counterfactual_factor_factorization",
"get_counterfactual_factors",
"is_counterfactual_factor_form",
"minimize_event",
"same_district",
"simplify",
"transport_conditional_counterfactual_query",
"transport_district_intervening_on_parents",
"transport_unconditional_counterfactual_query",
"unconditional_cft",
# TODO do a proper audit of which of these a user should ever have to import
# TODO add functions/classes/variables you want to appear in the docs and be exposed to the user in this list
# Run tox -e docs then `open docs/build/html/index.html` to see docs
]
Expand Down
2 changes: 1 addition & 1 deletion src/y0/algorithm/estimation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
)

__all__ = [
"estimate_ace",
"df_covers_graph",
"estimate_ace",
]


Expand Down
2 changes: 1 addition & 1 deletion src/y0/algorithm/estimation/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from y0.graph import NxMixedGraph, get_district_and_predecessors, is_p_fixable

__all__ = [
"get_beta_primal",
"get_primal_ipw_ace",
"get_primal_ipw_point_estimate",
"get_beta_primal",
]

#: The list of Ananke estimators implemented in
Expand Down
2 changes: 1 addition & 1 deletion src/y0/algorithm/estimation/linear_scm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from y0.graph import NxMixedGraph, sympy_nested

__all__ = [
"get_single_door",
"evaluate_admg",
"evaluate_lscm",
"get_single_door",
]

EvalRv = dict[sympy.Symbol, sympy.core.numbers.Rational]
Expand Down
4 changes: 2 additions & 2 deletions src/y0/algorithm/falsification.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from ..struct import CITest, DSeparationJudgement, _ensure_method

__all__ = [
"get_graph_falsifications",
"get_falsifications",
"Falsifications",
"get_falsifications",
"get_graph_falsifications",
]


Expand Down
13 changes: 5 additions & 8 deletions src/y0/algorithm/identify/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,12 @@
from .utils import Identification, Query, Unidentifiable

__all__ = [
# Algorithms
"identify_outcomes",
"identify",
"Identification",
"Query",
"Unidentifiable",
"id_star",
"idc",
"idc_star",
# Data Structures
"Query",
# Exceptions
"Unidentifiable",
"Identification",
"identify",
"identify_outcomes",
]
6 changes: 3 additions & 3 deletions src/y0/algorithm/identify/cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
from y0.graph import NxMixedGraph

__all__ = [
"has_same_function",
"extract_interventions",
"has_same_function",
"is_not_self_intervened",
"is_pw_equivalent",
"merge_pw",
"make_counterfactual_graph",
"make_parallel_worlds_graph",
"is_not_self_intervened",
"merge_pw",
]


Expand Down
2 changes: 1 addition & 1 deletion src/y0/algorithm/identify/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from y0.mutate.canonicalize_expr import canonical_expr_equal

__all__ = [
"Query",
"Identification",
"Query",
"Unidentifiable",
"str_nodes_to_variable_nodes",
]
Expand Down
2 changes: 1 addition & 1 deletion src/y0/algorithm/separation/sigma_separation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

__all__ = [
"are_sigma_separated",
"is_z_sigma_open",
"get_equivalence_classes",
"is_z_sigma_open",
]


Expand Down
8 changes: 4 additions & 4 deletions src/y0/algorithm/simplify_latent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
from ..graph import DEFAULT_TAG, NxMixedGraph, _ensure_set

__all__ = [
"evans_simplify",
"simplify_latent_dag",
"SimplifyResults",
"remove_widow_latents",
"transform_latents_with_parents",
"evans_simplify",
"remove_redundant_latents",
"remove_unidirectional_latents",
"remove_widow_latents",
"simplify_latent_dag",
"transform_latents_with_parents",
]

logger = logging.getLogger(__name__)
Expand Down
4 changes: 2 additions & 2 deletions src/y0/algorithm/taheri_design.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
from y0.util.combinatorics import powerset

__all__ = [
"taheri_design_admg",
"taheri_design_dag",
"Result",
"draw_results",
"taheri_design_admg",
"taheri_design_dag",
]

logger = logging.getLogger(__name__)
Expand Down
8 changes: 4 additions & 4 deletions src/y0/algorithm/tian_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
from y0.graph import NxMixedGraph

__all__ = [
"identify_district_variables",
"compute_ancestral_set_q_value",
"compute_c_factor",
"compute_c_factor_conditioning_on_topological_predecessors",
"compute_q_value_of_variables_with_low_topological_ordering_indices",
"compute_c_factor_marginalizing_over_topological_successors",
"compute_c_factor",
"compute_ancestral_set_q_value",
"compute_q_value_of_variables_with_low_topological_ordering_indices",
"identify_district_variables",
]

logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion src/y0/algorithm/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
from y0.mutate.canonicalize_expr import canonicalize

__all__ = [
"TransportQuery",
"identify_target_outcomes",
"trso",
"TransportQuery",
]

logger = logging.getLogger(__name__)
Expand Down
Loading

0 comments on commit da2032c

Please sign in to comment.