Skip to content

Commit

Permalink
Enable simplification rules
Browse files Browse the repository at this point in the history
Signed-off-by: Keith Battocchi <[email protected]>
  • Loading branch information
kbattocchi committed Aug 11, 2024
1 parent 1fbeb76 commit 8b29a8d
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 14 deletions.
2 changes: 1 addition & 1 deletion econml/_tree_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ def export_graphviz(self, out_file=None, feature_names=None, treatment_names=Non
own_file = False
try:
if isinstance(out_file, str):
out_file = open(out_file, "w", encoding="utf-8")
out_file = open(out_file, "w", encoding="utf-8") # noqa: SIM115, we close explicitly by design
own_file = True

return_string = out_file is None
Expand Down
2 changes: 1 addition & 1 deletion econml/dml/causal_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ def tune(self, Y, T, *, X=None, W=None,
else:
# If custom param grid, check that only estimator parameters are being altered
estimator_param_names = self.tunable_params
for key in params.keys():
for key in params:
if key not in estimator_param_names:
raise ValueError(f"Parameter `{key}` is not an tunable causal forest parameter.")

Expand Down
6 changes: 3 additions & 3 deletions econml/orf/_ortho_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,9 +721,9 @@ def __init__(self, model_T, model_Y, random_state=None, second_stage=True,
def __call__(self, Y, T, X, W, sample_weight=None, split_indices=None):
if self.global_residualization:
return 0
if self.discrete_treatment:
# Check that all discrete treatments are represented
if len(np.unique(T @ np.arange(1, T.shape[1] + 1))) < T.shape[1] + 1:
# Check that all discrete treatments are represented
if (self.discrete_treatment and
len(np.unique(T @ np.arange(1, T.shape[1] + 1))) < T.shape[1] + 1):
return None
# Nuissance estimates evaluated with cross-fitting
this_random_state = check_random_state(self.random_state)
Expand Down
4 changes: 2 additions & 2 deletions econml/sklearn_extensions/model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,10 +525,10 @@ def _convert_model(model, args, kwargs):
best_model, score = SklearnCVSelector._convert_model(inner_model, args, kwargs)
return Pipeline(steps=[*model.steps[:-1], (name, best_model)]), score

if isinstance(model, GridSearchCV) or isinstance(model, RandomizedSearchCV):
if isinstance(model, (GridSearchCV, RandomizedSearchCV)):
return model.best_estimator_, model.best_score_

for known_type in SklearnCVSelector._model_mapping().keys():
for known_type in SklearnCVSelector._model_mapping():
if isinstance(model, known_type):
converter = SklearnCVSelector._model_mapping()[known_type]
return converter(model, args, kwargs)
Expand Down
2 changes: 1 addition & 1 deletion econml/tests/test_drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ def _test_drlearner_with_inference_all_attributes(self, use_ray):
[2, 3, len(feature_names) +
(W.shape[1] if W is not None else 0)])

if isinstance(est, LinearDRLearner) or isinstance(est, SparseLinearDRLearner):
if isinstance(est, (LinearDRLearner, SparseLinearDRLearner)):
if X is not None:
for t in [1, 2]:
true_coef = np.zeros(
Expand Down
2 changes: 1 addition & 1 deletion econml/tests/test_policy_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _test_policy_honesty(self, trainer, dr=False):
for sample_weight in [None, 'rand']:
for n_outcomes in n_outcome_list:
config = self._get_base_config()
config['honest'] = True if not dr else False
config['honest'] = not dr
config['criterion'] = criterion
config['max_depth'] = 2
config['min_samples_leaf'] = 5
Expand Down
14 changes: 9 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,18 @@ ignore = [
"D103", # Missing docstring in public function
"D104", # Missing docstring in public package
"D105", # Missing docstring in magic method
"D301", # Use r""" if any backslashes in a docstring
]
extend-select = [
"D301", # Use r""" if any backslashes in a docstring,
"SIM108", # Use ternary instead of if-else (looks ugly for some of our long expressions)
"SIM300", # Yoda condition detected (these are often easier to understand in array expressions)
]
select = [
"D", # Docstring
"E501", # Line too long
"W", # Pycodestyle warnings
"E", # All Pycodestyle erros, not just the default ones
"F", # All pyflakes rules
"SIM", # Simplifification
]
extend-per-file-ignores = { "econml/tests" = ["D"] } # ignore docstring rules for tests

[tool.ruff.lint.pydocstyle]
convention = "numpy"
convention = "numpy"

0 comments on commit 8b29a8d

Please sign in to comment.