Skip to content

Commit

Permalink
fix: Feedback addressed to refactor run_linear_regression and make function less complex
Browse files Browse the repository at this point in the history
  • Loading branch information
agendazhang committed Feb 1, 2025
1 parent 62ad25b commit da3765d
Showing 1 changed file with 21 additions and 13 deletions.
34 changes: 21 additions & 13 deletions src/linreg_ally/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,7 @@ def run_linear_regression(dataframe, target_column, numeric_feats, categorical_f
X = dataframe.drop(columns=[target_column])
y = dataframe[target_column]

preprocessor = make_column_transformer(
(StandardScaler(), numeric_feats),
(OneHotEncoder(), categorical_feats),
('drop', drop_feats)
)
preprocessor = preprocess(numeric_feats, categorical_feats, drop_feats)

pipe = Pipeline([
('preprocessor', preprocessor),
Expand All @@ -106,9 +102,26 @@ def run_linear_regression(dataframe, target_column, numeric_feats, categorical_f

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)

pipe.fit(X_train, y_train)
(best_model, scores) = fit_predict(pipe, X_train, X_test, y_train, y_test, scoring_metrics)

print("Model Summary")
print("------------------------")
for metric, score in scores.items():
print(f"Test {metric}: {score:.3f}")

best_model = pipe
return best_model, X_train, X_test, y_train, y_test, scores

def preprocess(numeric_feats, categorical_feats, drop_feats):
    """Build the column transformer used for feature preprocessing.

    Parameters
    ----------
    numeric_feats
        Column selection paired with a ``StandardScaler``
        (presumably a list of column names -- confirm against callers).
    categorical_feats
        Column selection paired with a default-configured ``OneHotEncoder``.
    drop_feats
        Column selection paired with the ``'drop'`` transformer.

    Returns
    -------
    The object produced by ``make_column_transformer`` for the three
    (transformer, columns) pairs above, in that order.
    """
    # Each step is a (transformer, columns) pair; the order is kept as
    # scale -> one-hot encode -> drop.
    scale_step = (StandardScaler(), numeric_feats)
    encode_step = (OneHotEncoder(), categorical_feats)
    drop_step = ('drop', drop_feats)

    return make_column_transformer(scale_step, encode_step, drop_step)

def fit_predict(pipeline, X_train, X_test, y_train, y_test, scoring_metrics):
pipeline.fit(X_train, y_train)

best_model = pipeline

predictions = best_model.predict(X_test)

Expand All @@ -117,9 +130,4 @@ def run_linear_regression(dataframe, target_column, numeric_feats, categorical_f
scorer = get_scorer(metric)
scores[metric] = scorer._score_func(y_test, predictions)

print("Model Summary")
print("------------------------")
for metric, score in scores.items():
print(f"Test {metric}: {score:.3f}")

return best_model, X_train, X_test, y_train, y_test, scores
return (best_model, scores)

0 comments on commit da3765d

Please sign in to comment.