Skip to content

Commit

Permalink
fix: formatting.
Browse files Browse the repository at this point in the history
  • Loading branch information
spohngellert-o committed Oct 5, 2024
1 parent 39b8131 commit 7c0bc0b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion causalml/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def match(self, data, treatment_col, score_cols):
- match_from.loc[from_idx, score_col]
)
# Gets self.ratio lowest dists
to_np_idx_list = np.argpartition(dist, self.ratio)[:self.ratio]
to_np_idx_list = np.argpartition(dist, self.ratio)[: self.ratio]
to_idx_list = dist.index[to_np_idx_list]
for i, to_idx in enumerate(to_idx_list):
if dist[to_idx] <= sdcal:
Expand Down
16 changes: 8 additions & 8 deletions tests/test_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_nearest_neighbor_match_ratio_2(generate_unmatched_data):
matched = psm.match(data=df, treatment_col=TREATMENT_COL, score_cols=[SCORE_COL])
assert sum(matched[TREATMENT_COL] == 0) == 2 * sum(matched[TREATMENT_COL] != 0)


def test_nearest_neighbor_match_by_group(generate_unmatched_data):
df, features = generate_unmatched_data()

Expand All @@ -55,26 +56,25 @@ def test_nearest_neighbor_match_by_group(generate_unmatched_data):

assert sum(matched[TREATMENT_COL] == 0) == sum(matched[TREATMENT_COL] != 0)


def test_nearest_neighbor_match_control_to_treatment(generate_unmatched_data):
'''
"""
Tests whether control to treatment matching is working. Does so
by using:
replace=True
treatment_to_control=False
ratio=2
And testing if we get 2x the number of control matches than treatment
'''
"""
df, features = generate_unmatched_data()

psm = NearestNeighborMatch(replace=True, ratio=2, treatment_to_control=False, random_state=RANDOM_SEED)
matched = psm.match(
data=df,
treatment_col=TREATMENT_COL,
score_cols=[SCORE_COL]
psm = NearestNeighborMatch(
replace=True, ratio=2, treatment_to_control=False, random_state=RANDOM_SEED
)
matched = psm.match(data=df, treatment_col=TREATMENT_COL, score_cols=[SCORE_COL])
assert 2 * sum(matched[TREATMENT_COL] == 0) == sum(matched[TREATMENT_COL] != 0)


Expand Down

0 comments on commit 7c0bc0b

Please sign in to comment.