feat: implement retrain and eval workflow (fix #9)
ijdoc authored Oct 10, 2024
1 parent e30f0d2 commit e9e6df2
Showing 9 changed files with 426 additions and 194 deletions.
@@ -16,8 +16,8 @@ jobs:
permissions:
contents: read
issues: write # Grant write access to issues
outputs: # Define outputs for downstream jobs
drift_detected: ${{ steps.drift_check.outputs.drift_detected }}
# outputs: # Define outputs for downstream jobs
# drift_detected: ${{ steps.drift_check.outputs.drift_detected }}
steps:
- name: ⏬ Checkout repository
uses: actions/checkout@v4
38 changes: 38 additions & 0 deletions .github/workflows/train_and_eval.yml
@@ -0,0 +1,38 @@
# Checks model performance, and retrains
# if model degradation is detected
name: Performance Evaluation

on:
  repository_dispatch: # Allow triggering from a POST request
    types: ["Train On Promoted Dataset"]
  push:
    branches: ijdoc/issue9

jobs:
  train:
    runs-on: self-hosted
    steps:
      - name: ⏬ Checkout repository
        uses: actions/checkout@v4

      - name: ⚙️ Train!
        env:
          WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
        run: |
          python train.py
  eval_check:
    needs: train
    runs-on: self-hosted
    steps:
      - name: ⏬ Checkout repository
        uses: actions/checkout@v4

      - name: ⚙️ Run Evaluation
        env:
          WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}
        id: eval_check
        run: |
          output=$(python evaluate.py)
          echo "$output" >> $GITHUB_STEP_SUMMARY
1 change: 1 addition & 0 deletions .gitignore
@@ -4,6 +4,7 @@
**/artifacts
**/checkpoints
**/report.txt
**/*.pth

# Byte-compiled / optimized / DLL files
__pycache__/
19 changes: 10 additions & 9 deletions drift/check_drift.py
@@ -10,12 +10,13 @@
) as run:

# Grab the latest training and production dataframes
train_artifact = run.use_artifact("jdoc-org/wandb-registry-dataset/training:latest")
run.config["train_data"] = train_artifact.source_name
registered_training_dataset = "jdoc-org/wandb-registry-dataset/training:latest"
train_artifact = run.use_artifact(registered_training_dataset)
run.config["train_data"] = train_artifact.name
train_data = train_artifact.get("training_data").get_dataframe()

prod_artifact = run.use_artifact("production_data:latest")
run.config["prod_data"] = prod_artifact.source_name
run.config["prod_data"] = prod_artifact.name
prod_data = prod_artifact.get("production_data").get_dataframe()

feature_list = ["active_power", "temp", "humidity", "pressure"]
@@ -72,19 +73,19 @@
artifact.description = prod_artifact.description
artifact = run.log_artifact(artifact).wait()
# Open a github issue asking for manual review
issue_title = f"Data drift detected on {train_artifact.source_name}"
issue_title = f"Data drift detected on {train_artifact.name}"
issue_body = (
f"Data drift has been detected when comparing the registered training dataset with recent production data.\n\n"
f"Please review the [candidate artifact](https://wandb.ai/{run.entity}/{run.project}/artifacts/{artifact.type}/{artifact.source_name}) "
f"Please review the [candidate artifact](https://wandb.ai/{run.entity}/{run.project}/artifacts/{artifact.type}/{artifact.name}) "
f"and the [drift report]({report_url}) to determine if the registered training data should be updated.\n\n"
f"To approve the new candidate after review, link it to [the training Dataset Registry](https://wandb.ai/registry/dataset?selectionPath=jdoc-org%2Fwandb-registry-dataset%2Ftraining&view=versions) at "
f"(`jdoc-org/wandb-registry-dataset/training`), otherwise close this issue."
f"(`{registered_training_dataset}`), otherwise close this issue."
)
issue_url = open_github_issue(issue_title, issue_body, labels=["drift", "data"])
print(
f"Production batch `{prod_artifact.source_name}` has been logged "
f"as candidate to replace training data `{artifact.source_name}`. "
f"An [issue]({issue_url}) was created for manual review:\n"
f"as candidate `{artifact.name}` to replace training data. "
f"An [issue]({issue_url}) was also created for manual review:\n"
)
print(f"- [Data Drift Issue]({issue_url})")
else:
@@ -93,7 +94,7 @@
print(f"- [W&B Run]({run.url})")
print(f"- [Full data drift report]({report_url})")

# Optionally the drift detection result in a parseable format.
# Optionally print the drift detection result in a parseable format.
# Helpful if you want to use this result in a CI/CD pipeline
# to automatically update the data and/or retrain your model.
# print(f"DRIFT_DETECTED={drift_detected}")
144 changes: 144 additions & 0 deletions evaluate.py
@@ -0,0 +1,144 @@
import torch
import torch.nn.functional as F
from sklearn.preprocessing import StandardScaler
import pandas as pd
from simple_model import load_model
import wandb
import os
from utils import plot_predictions_vs_actuals, prep_time_series_data
import numpy as np

# Initialize W&B for evaluation job
wandb.init(
project="wandb-webinar-cicd-2024",
job_type="evaluate",
)

# Load the production model
prod_model_name = "jdoc-org/wandb-registry-model/production:latest"
prod_artifact = wandb.use_artifact(prod_model_name)
wandb.config["prod_model"] = prod_artifact.name
prod_model_path = os.path.join(prod_artifact.download(), "best_model.pth")

# Load the rival model
rival_artifact = wandb.use_artifact("trained_model:latest")
wandb.config["rival_model"] = rival_artifact.name
rival_model_path = os.path.join(rival_artifact.download(), "best_model.pth")

# Load the checkpoint
model_checkpoint = torch.load(prod_model_path, map_location=torch.device("cpu"))
rival_checkpoint = torch.load(rival_model_path, map_location=torch.device("cpu"))

# Load the production model's metrics for comparison
prod_metrics = model_checkpoint["metrics"]

# Load rival scalers and metrics
scaler_X = rival_checkpoint["scaler_X"]
scaler_y = rival_checkpoint["scaler_y"]
config = rival_checkpoint["config"]
metrics = rival_checkpoint["metrics"]

# Instantiate the model and load its state dictionary
model = load_model(
config["input_size"] * config["n_time_steps"],
config["hidden_size"],
config["output_size"],
)
model.load_state_dict(rival_checkpoint["model_state_dict"])
model.eval() # Set the model to evaluation mode

# Load the latest production data artifact from W&B
artifact = wandb.use_artifact("production_data:latest")
df_test = artifact.get("production_data").get_dataframe()

# Prepare data (assumes the first column is the target value)
X_test = df_test.iloc[:, :].values  # All columns as input features
y_test = df_test.iloc[:, 0].values.reshape(-1, 1)  # First column as target

# Normalize the data using StandardScaler
scaler_X = StandardScaler()
scaler_y = StandardScaler()

X_test_scaled = scaler_X.fit_transform(X_test)
y_test_scaled = scaler_y.fit_transform(y_test)

# Create time series data using n_time_steps
n_time_steps = config["n_time_steps"]
X_time_series, y_time_series = prep_time_series_data(
X_test_scaled, y_test_scaled, config["n_time_steps"]
)

# Convert time series data to tensors
X_test_tensor = torch.tensor(X_time_series, dtype=torch.float32)
y_test_tensor = torch.tensor(y_time_series, dtype=torch.float32)

# Create a DataLoader for the test data
test_dataset = torch.utils.data.TensorDataset(X_test_tensor, y_test_tensor)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=config["batch_size"])

# Evaluation loop
mse_loss = 0.0
mae_loss = 0.0
ss_res = 0.0
ss_tot = 0.0
all_predictions = []
all_actuals = []

with torch.no_grad():
    for batch_X, batch_y in test_loader:
        outputs = model(batch_X)

        # Store predictions and actuals for plotting
        all_predictions.append(outputs.numpy())
        all_actuals.append(batch_y.numpy())

        # Calculate MSE
        mse_loss += F.mse_loss(outputs, batch_y).item()

        # Calculate MAE
        mae_loss += F.l1_loss(outputs, batch_y).item()

        # Calculate R²
        ss_res += torch.sum((batch_y - outputs) ** 2).item()
        ss_tot += torch.sum((batch_y - torch.mean(batch_y)) ** 2).item()

# Average the losses over the test dataset
mse_loss /= len(test_loader)
mae_loss /= len(test_loader)
r2_score = 1 - (ss_res / ss_tot)

# Log evaluation metrics to W&B
eval_table = wandb.Table(columns=["Metric", "Production", "Candidate"])
eval_table.add_data("MSE", prod_metrics["val_loss"], mse_loss)
eval_table.add_data("MAE", prod_metrics["val_mae"], mae_loss)
eval_table.add_data("R²", prod_metrics["val_r2"], r2_score)
wandb.log({"performance_metrics": eval_table})

# Convert predictions and actuals to numpy arrays for plotting
all_predictions = scaler_y.inverse_transform(np.vstack(all_predictions))
all_actuals = scaler_y.inverse_transform(np.vstack(all_actuals))

# Generate and log predictions vs actuals plot
plt = plot_predictions_vs_actuals(all_actuals, all_predictions)
wandb.log({"predictions_vs_actuals": wandb.Image(plt)})

if prod_metrics["val_r2"] > r2_score:
    print("> Candidate model did not perform as well as the production model\n\n")
else:
    print("> [!INFO]")
    print("> The candidate model performed better than the production model\n\n")

    # Link the rival model to the production model registry
    rival_artifact.link("jdoc-org/wandb-registry-model/production")
    print(
        "The candidate model has been promoted to the [production model registry](https://wandb.ai/registry/model?selectionPath=jdoc-org%2Fwandb-registry-model%2Fproduction&view=versions)!"
    )

print(f"- [W&B Run]({wandb.run.url})")

# Finish W&B run
wandb.finish()
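One caveat worth flagging in evaluate.py: `ss_tot` is accumulated using each batch's own mean, so the reported R² is a per-batch approximation rather than the R² over the whole test set. A dataset-level alternative could look like the sketch below; this is illustrative only and not part of the commit, with `r2` a hypothetical helper and `all_actuals`/`all_predictions` referring to the stacked arrays built above.

```python
import numpy as np

def r2(actuals: np.ndarray, predictions: np.ndarray) -> float:
    # R² computed once over the full arrays, using the global mean of the actuals
    ss_res = float(np.sum((actuals - predictions) ** 2))
    ss_tot = float(np.sum((actuals - actuals.mean()) ** 2))
    return 1.0 - ss_res / ss_tot

# dataset_r2 = r2(all_actuals, all_predictions)
```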
50 changes: 50 additions & 0 deletions simple_model.py
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn


class SimpleDNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, dropout_prob=0.2):
        super(SimpleDNN, self).__init__()
        self.half_hidden_size = round(hidden_size / 2)

        self.fc1 = nn.Linear(input_size, hidden_size)
        self.bn1 = nn.BatchNorm1d(hidden_size)

        self.fc2 = nn.Linear(hidden_size, self.half_hidden_size)
        self.bn2 = nn.BatchNorm1d(self.half_hidden_size)
        self.fc3 = nn.Linear(self.half_hidden_size, hidden_size)
        self.bn3 = nn.BatchNorm1d(hidden_size)

        self.fc4 = nn.Linear(hidden_size, output_size)  # Output layer

        # Activation function
        self.relu = nn.ReLU()

        # Dropout layer
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        # First hidden layer
        out = self.fc1(x)
        out = self.bn1(out)  # Apply batch normalization
        out = self.relu(out)
        out = self.dropout(out)  # Apply dropout

        # Second hidden layer
        out = self.fc2(out)
        out = self.bn2(out)  # Apply batch normalization
        out = self.relu(out)

        # Third hidden layer
        out = self.fc3(out)
        out = self.bn3(out)  # Apply batch normalization
        out = self.relu(out)

        # Output layer (no activation here, typically applied outside for regression/classification)
        out = self.fc4(out)
        return out


def load_model(input_size, hidden_size, output_size, dropout_prob=0.2):
    model = SimpleDNN(input_size, hidden_size, output_size, dropout_prob)
    return model
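For quick local checks, the model can be exercised with random data. A minimal smoke-test sketch, assuming the flattened input layout used in evaluate.py (`input_size * n_time_steps`); the sizes below are illustrative and not taken from the commit's config.

```python
import torch
from simple_model import load_model

input_size, n_time_steps = 4, 6   # assumed to mirror the config keys used in evaluate.py
hidden_size, output_size = 32, 1

model = load_model(input_size * n_time_steps, hidden_size, output_size)
model.eval()  # put BatchNorm/Dropout into inference mode

dummy_batch = torch.randn(8, input_size * n_time_steps)  # (batch, flattened time window)
with torch.no_grad():
    preds = model(dummy_batch)
print(preds.shape)  # torch.Size([8, 1])
```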