Skip to content

Commit

Permalink
Implement model validation
Browse files Browse the repository at this point in the history
Issue: #96

In order to ensure stable operations, models need to be validated before
deployment.

Therefore this commit implements support for model validation using a
Task.
  • Loading branch information
Jessie Chatham Spencer committed Mar 17, 2023
1 parent ba47b05 commit 786f01b
Show file tree
Hide file tree
Showing 13 changed files with 3,971 additions and 25 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,4 @@ similar-asserts = "1.4.2"
serial_test = "0.9.0"
temp-env = { version = "0.3.1", features = ["async_closure"] }
rstest = "0.16.0"
k8s-openapi = { version = "0.17.0", features = ["v1_23", "schemars" ] }
114 changes: 114 additions & 0 deletions cli/tests/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,20 @@ use controller::{
common::{find_ame_endpoint, private_repo_gh_pat, setup_cluster},
project_source_ctrl::ProjectSrcCtrl,
};
use controller::{ModelValidationStatus, Project};
use fs_extra::dir::CopyOptions;

use futures_util::StreamExt;

use insta::assert_snapshot;
use k8s_openapi::api::apps::v1::Deployment;
use kube::api::ListParams;
use kube::runtime::{watcher, WatchStreamExt};
use kube::Api;
use kube::Client;
use rstest::*;
use serial_test::serial;
use std::time::Duration;
use std::{
path::{Path, PathBuf},
process::Command,
Expand Down Expand Up @@ -303,3 +311,109 @@ async fn cannot_create_multiple_sources_for_the_same_repo() -> Result<(), Box<dy

Ok(())
}

#[tokio::test]
#[serial]
#[ignore]
async fn can_train_validate_and_deploy_model() -> Result<(), Box<dyn std::error::Error>> {
test_setup().await?;

// The template repo is required as the ame-demo requires it to train.
let template_repo = "https://github.com/TeaInSpace/ame-template-demo.git";
let repo = "https://github.com/TeaInSpace/ame-demo.git";
let model_name = "logreg"; // this name from the ame-demo repo.
let project_src_ctrl = ProjectSrcCtrl::new(kube_client().await?, AME_NAMESPACE);

let kube_client = kube_client().await?;
let deployments: Api<Deployment> = Api::namespaced(kube_client.clone(), AME_NAMESPACE);
let projects: Api<Project> = Api::namespaced(kube_client.clone(), AME_NAMESPACE);

let _ = project_src_ctrl.delete_project_src_for_repo(repo).await;
let mut cmd = Command::cargo_bin("cli")?;

let _output = cmd
.arg("projectsrc")
.arg("create")
.arg(template_repo)
.assert()
.failure();

let mut cmd = Command::cargo_bin("cli")?;

let _output = cmd
.arg("projectsrc")
.arg("create")
.arg(repo)
.assert()
.failure();

// Note that the model will start training now if now version is present.
// We will need to check for this in future tests.

// It is important that we only have 2 projects as the watcher
// does not perform any filtering.
assert_eq!(
projects
.list_metadata(&ListParams::default())
.await?
.items
.len(),
2
);

// No deployment for the model should be present.
assert!(deployments.get(model_name).await.is_err());

let mut project_watcher = watcher(projects, ListParams::default())
.applied_objects()
.boxed();

while let Some(e) = project_watcher.next().await {
let Some(mut status) = e?.status else {
continue;
};

let Some(model_status) = status.get_model_status(model_name) else {
continue ;
};

match model_status.validation {
Some(ModelValidationStatus::Validated { .. }) => break,
Some(ModelValidationStatus::FailedValidation { .. }) => {
return Err("model failed validation".into());
}
_ => (),
};
}

let mut deployment_watcher = watcher(
deployments,
ListParams::default().fields(&format!("metadata.name=={model_name}")),
)
.applied_objects()
.boxed();

let timeout = Duration::from_secs(60);
let start = std::time::Instant::now();

while let Some(e) = deployment_watcher.next().await {
let Some(status) = e?.status else {
return Err("missing deployment status".into());
};

if let Some(1) = status.ready_replicas {
break;
}

if start.elapsed() > timeout {
return Err("failed to deploy model within timeout".into());
}
}

project_src_ctrl.delete_project_src_for_repo(repo).await?;
project_src_ctrl
.delete_project_src_for_repo(template_repo)
.await?;

Ok(())
}
6 changes: 6 additions & 0 deletions controller/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ pub enum Error {

#[error("got error from AME's secret store: {0}")]
SecretError(#[from] SecretError),

#[error("failed to find model status for: {0}")]
MissingModelStatus(String),

#[error("failed to find validation task status for model: {0}")]
MissingValidationTask(String),
}

pub type Result<T, E = Error> = std::result::Result<T, E>;
Expand Down
5 changes: 4 additions & 1 deletion controller/src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,10 @@ impl Task {
"exec {}
save_artifacts {}",
self.spec.runcommand.clone().unwrap(),
self.spec
.runcommand
.clone()
.unwrap_or("missing command".to_string()), // TODO: handle missing commands
self.task_artifacts_path()
)
};
Expand Down
Loading

0 comments on commit 786f01b

Please sign in to comment.