diff --git a/econml/_tree_exporter.py b/econml/_tree_exporter.py index e7aef988d..a9c3a3def 100644 --- a/econml/_tree_exporter.py +++ b/econml/_tree_exporter.py @@ -694,7 +694,7 @@ def export_graphviz(self, out_file=None, feature_names=None, treatment_names=Non own_file = False try: if isinstance(out_file, str): - out_file = open(out_file, "w", encoding="utf-8") + out_file = open(out_file, "w", encoding="utf-8") # noqa: SIM115, we close explicitly by design own_file = True return_string = out_file is None diff --git a/econml/dml/causal_forest.py b/econml/dml/causal_forest.py index 1437b72ca..4a1a536bc 100644 --- a/econml/dml/causal_forest.py +++ b/econml/dml/causal_forest.py @@ -763,7 +763,7 @@ def tune(self, Y, T, *, X=None, W=None, else: # If custom param grid, check that only estimator parameters are being altered estimator_param_names = self.tunable_params - for key in params.keys(): + for key in params: if key not in estimator_param_names: raise ValueError(f"Parameter `{key}` is not an tunable causal forest parameter.") diff --git a/econml/orf/_ortho_forest.py b/econml/orf/_ortho_forest.py index 074d4bbc8..daa1d96b9 100644 --- a/econml/orf/_ortho_forest.py +++ b/econml/orf/_ortho_forest.py @@ -721,9 +721,9 @@ def __init__(self, model_T, model_Y, random_state=None, second_stage=True, def __call__(self, Y, T, X, W, sample_weight=None, split_indices=None): if self.global_residualization: return 0 - if self.discrete_treatment: - # Check that all discrete treatments are represented - if len(np.unique(T @ np.arange(1, T.shape[1] + 1))) < T.shape[1] + 1: + # Check that all discrete treatments are represented + if (self.discrete_treatment and + len(np.unique(T @ np.arange(1, T.shape[1] + 1))) < T.shape[1] + 1): return None # Nuissance estimates evaluated with cross-fitting this_random_state = check_random_state(self.random_state) diff --git a/econml/sklearn_extensions/model_selection.py b/econml/sklearn_extensions/model_selection.py index fc0c7a755..113b65828 100644 --- a/econml/sklearn_extensions/model_selection.py +++ b/econml/sklearn_extensions/model_selection.py @@ -525,10 +525,10 @@ def _convert_model(model, args, kwargs): best_model, score = SklearnCVSelector._convert_model(inner_model, args, kwargs) return Pipeline(steps=[*model.steps[:-1], (name, best_model)]), score - if isinstance(model, GridSearchCV) or isinstance(model, RandomizedSearchCV): + if isinstance(model, (GridSearchCV, RandomizedSearchCV)): return model.best_estimator_, model.best_score_ - for known_type in SklearnCVSelector._model_mapping().keys(): + for known_type in SklearnCVSelector._model_mapping(): if isinstance(model, known_type): converter = SklearnCVSelector._model_mapping()[known_type] return converter(model, args, kwargs) diff --git a/econml/tests/test_drlearner.py b/econml/tests/test_drlearner.py index 9bd7ee6c4..422c1048b 100644 --- a/econml/tests/test_drlearner.py +++ b/econml/tests/test_drlearner.py @@ -669,7 +669,7 @@ def _test_drlearner_with_inference_all_attributes(self, use_ray): [2, 3, len(feature_names) + (W.shape[1] if W is not None else 0)]) - if isinstance(est, LinearDRLearner) or isinstance(est, SparseLinearDRLearner): + if isinstance(est, (LinearDRLearner, SparseLinearDRLearner)): if X is not None: for t in [1, 2]: true_coef = np.zeros( diff --git a/econml/tests/test_policy_forest.py b/econml/tests/test_policy_forest.py index a329ffff3..10810b8d0 100644 --- a/econml/tests/test_policy_forest.py +++ b/econml/tests/test_policy_forest.py @@ -117,7 +117,7 @@ def _test_policy_honesty(self, trainer, dr=False): for sample_weight in [None, 'rand']: for n_outcomes in n_outcome_list: config = self._get_base_config() - config['honest'] = True if not dr else False + config['honest'] = not dr config['criterion'] = criterion config['max_depth'] = 2 config['min_samples_leaf'] = 5 diff --git a/pyproject.toml b/pyproject.toml index e3d0b31df..45916eae3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -150,14 +150,18 @@ ignore = [ "D103", # Missing docstring in public function "D104", # Missing docstring in public package "D105", # Missing docstring in magic method - "D301", # Use r""" if any backslashes in a docstring - ] -extend-select = [ + "D301", # Use r""" if any backslashes in a docstring, + "SIM108", # Use ternary instead of if-else (looks ugly for some of our long expressions) + "SIM300", # Yoda condition detected (these are often easier to understand in array expressions) +] +select = [ "D", # Docstring - "E501", # Line too long "W", # Pycodestyle warnings + "E", # All Pycodestyle erros, not just the default ones + "F", # All pyflakes rules + "SIM", # Simplifification ] extend-per-file-ignores = { "econml/tests" = ["D"] } # ignore docstring rules for tests [tool.ruff.lint.pydocstyle] -convention = "numpy" \ No newline at end of file +convention = "numpy"