From 6b2b9a81551d8adb92ca8e546f81414db130603b Mon Sep 17 00:00:00 2001 From: Jeong-Yoon Lee Date: Sat, 23 Mar 2024 08:07:31 -0700 Subject: [PATCH] remove undefined variables in match.__main__() (#749) --- causalml/match.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/causalml/match.py b/causalml/match.py index 1fe5cd77..9bde5872 100644 --- a/causalml/match.py +++ b/causalml/match.py @@ -424,18 +424,22 @@ def search_best_match(self, df): if __name__ == "__main__": - from .features import TREATMENT_COL, SCORE_COL, GROUPBY_COL, PROPENSITY_FEATURES - from .features import PROPENSITY_FEATURE_TRANSFORMATIONS, MATCHING_COVARIATES from .features import load_data from .propensity import ElasticNetPropensityModel + TREATMENT_COL = "treatment" + SCORE_COL = "score" + GROUPBY_COL = "group" + parser = argparse.ArgumentParser() parser.add_argument("--input-file", required=True, dest="input_file") parser.add_argument("--output-file", required=True, dest="output_file") parser.add_argument("--treatment-col", default=TREATMENT_COL, dest="treatment_col") parser.add_argument("--groupby-col", default=GROUPBY_COL, dest="groupby_col") + parser.add_argument("--score-col", default=SCORE_COL, dest="score_col") + parser.add_argument("--feature-cols", nargs="+", required=True, dest="feature_cols") parser.add_argument( - "--feature-cols", nargs="+", default=PROPENSITY_FEATURES, dest="feature_cols" + "--matching-cols", nargs="+", required=True, dest="matching_cols" ) parser.add_argument("--caliper", type=float, default=0.2) parser.add_argument("--replace", default=False, action="store_true") @@ -455,16 +459,15 @@ def search_best_match(self, df): X = load_data( data=df, features=args.feature_cols, - transformations=PROPENSITY_FEATURE_TRANSFORMATIONS, ) logger.info("Scoring with a propensity model: {}".format(pm)) - df[SCORE_COL] = pm.fit_predict(X, w) + df[args.score_col] = pm.fit_predict(X, w) logger.info( "Balance before matching:\n{}".format( create_table_one( - data=df, treatment_col=args.treatment_col, features=MATCHING_COVARIATES + data=df, treatment_col=args.treatment_col, features=args.matching_cols ) ) ) @@ -475,7 +478,7 @@ def search_best_match(self, df): matched = psm.match_by_group( data=df, treatment_col=args.treatment_col, - score_cols=[SCORE_COL], + score_cols=[args.score_col], groupby_col=args.groupby_col, ) logger.info("shape: {}\n{}".format(matched.shape, matched.head())) @@ -485,7 +488,7 @@ def search_best_match(self, df): create_table_one( data=matched, treatment_col=args.treatment_col, - features=MATCHING_COVARIATES, + features=args.matching_cols, ) ) )