Skip to content

Commit

Permalink
Update README
Browse files (browse the repository at this point in the history)
  • Loading branch information
alexioannides committed Jun 20, 2022
1 parent eddcecc commit afc5108
Showing 1 changed file with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions README.md
Original file line number | Diff line number | Diff line change
Expand Up @@ -1181,7 +1181,7 @@ from typing import Any, Dict, List, NamedTuple, Tuple
from bodywork_pipeline_utils import aws, logging
from bodywork_pipeline_utils.aws import Dataset
from numpy import array, ndarray
from numpy import array
from pandas import DataFrame
from sklearn.base import BaseEstimator
from sklearn.model_selection import GridSearchCV, train_test_split
Expand All @@ -1191,10 +1191,10 @@ from sklearn.tree import DecisionTreeRegressor
PRODUCT_CODE_MAP = {"SKU001": 0, "SKU002": 1, "SKU003": 2, "SKU004": 3, "SKU005": 4}
HYPERPARAM_GRID = {
"random_state": [42],
"criterion": ["mse", "mae"],
"max_depth": [2, 3, 4, 5, 6, 7, 8, 9, 10, None],
"min_samples_split": [2, 3, 4, 5, 6, 7, 8, 9, 10],
"min_samples_leaf": [2, 3, 4, 5, 6, 7, 8, 9, 10],
"criterion": ["squared_error", "absolute_error"],
"max_depth": [2, 4, 6, 8, 10, None],
"min_samples_split": [2, 4, 6, 8, 10],
"min_samples_leaf": [2, 4, 6, 8, 10],
}
log = logging.configure_logger()
Expand All @@ -1220,7 +1220,7 @@ def main(
s3_bucket: str,
metric_error_threshold: float,
metric_warning_threshold: float,
hyperparam_grid: Dict[str, Any]
hyperparam_grid: Dict[str, Any],
) -> None:
"""Main training job."""
log.info("Starting train-model stage.")
Expand Down Expand Up @@ -1258,13 +1258,6 @@ def prepare_data(data: DataFrame) -> FeatureAndLabels:
return FeatureAndLabels(X_train, X_test, y_train, y_test)
def compute_metrics(y_true: ndarray, y_pred: ndarray) -> TaskMetrics:
"""Compute performance metrics for the task and log them."""
mae = mean_absolute_error(y_true, y_pred)
r2 = r2_score(y_true, y_pred)
return TaskMetrics(r2, mae)
def train_model(
data: FeatureAndLabels, hyperparam_grid: Dict[str, Any]
) -> Tuple[BaseEstimator, TaskMetrics]:
Expand All @@ -1279,7 +1272,10 @@ def train_model(
grid_search.fit(preprocess(data.X_train), data.y_train)
best_model = grid_search.best_estimator_
y_test_pred = best_model.predict(preprocess(data.X_test))
performance_metrics = compute_metrics(data.y_test, y_test_pred)
performance_metrics = TaskMetrics(
r2_score(data.y_test, y_test_pred),
mean_absolute_error(data.y_test, y_test_pred),
)
return (best_model, performance_metrics)
Expand Down Expand Up @@ -1354,7 +1350,7 @@ if __name__ == "__main__":
s3_bucket,
r2_metric_error_threshold,
r2_metric_warning_threshold,
HYPERPARAM_GRID
HYPERPARAM_GRID,
)
except Exception as e:
log.error(f"Error encountered when training model - {e}")
Expand Down

0 comments on commit afc5108

Please sign in to comment.