Skip to content

Commit

Permalink
remove undefined variables in match.__main__() (#749)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeongyoonlee authored Mar 23, 2024
1 parent 750e84e commit 6b2b9a8
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions causalml/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,18 +424,22 @@ def search_best_match(self, df):


if __name__ == "__main__":
from .features import TREATMENT_COL, SCORE_COL, GROUPBY_COL, PROPENSITY_FEATURES
from .features import PROPENSITY_FEATURE_TRANSFORMATIONS, MATCHING_COVARIATES
from .features import load_data
from .propensity import ElasticNetPropensityModel

TREATMENT_COL = "treatment"
SCORE_COL = "score"
GROUPBY_COL = "group"

parser = argparse.ArgumentParser()
parser.add_argument("--input-file", required=True, dest="input_file")
parser.add_argument("--output-file", required=True, dest="output_file")
parser.add_argument("--treatment-col", default=TREATMENT_COL, dest="treatment_col")
parser.add_argument("--groupby-col", default=GROUPBY_COL, dest="groupby_col")
parser.add_argument("--score-col", default=SCORE_COL, dest="score_col")
parser.add_argument("--feature-cols", nargs="+", required=True, dest="feature_cols")
parser.add_argument(
"--feature-cols", nargs="+", default=PROPENSITY_FEATURES, dest="feature_cols"
"--matching-cols", nargs="+", required=True, dest="matching_cols"
)
parser.add_argument("--caliper", type=float, default=0.2)
parser.add_argument("--replace", default=False, action="store_true")
Expand All @@ -455,16 +459,15 @@ def search_best_match(self, df):
X = load_data(
data=df,
features=args.feature_cols,
transformations=PROPENSITY_FEATURE_TRANSFORMATIONS,
)

logger.info("Scoring with a propensity model: {}".format(pm))
df[SCORE_COL] = pm.fit_predict(X, w)
df[args.score_col] = pm.fit_predict(X, w)

logger.info(
"Balance before matching:\n{}".format(
create_table_one(
data=df, treatment_col=args.treatment_col, features=MATCHING_COVARIATES
data=df, treatment_col=args.treatment_col, features=args.matching_cols
)
)
)
Expand All @@ -475,7 +478,7 @@ def search_best_match(self, df):
matched = psm.match_by_group(
data=df,
treatment_col=args.treatment_col,
score_cols=[SCORE_COL],
score_cols=[args.score_col],
groupby_col=args.groupby_col,
)
logger.info("shape: {}\n{}".format(matched.shape, matched.head()))
Expand All @@ -485,7 +488,7 @@ def search_best_match(self, df):
create_table_one(
data=matched,
treatment_col=args.treatment_col,
features=MATCHING_COVARIATES,
features=args.matching_cols,
)
)
)
Expand Down

0 comments on commit 6b2b9a8

Please sign in to comment.