diff --git a/Cargo.lock b/Cargo.lock index 0d8e7359..36ffd004 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -594,6 +594,7 @@ dependencies = [ "hyper", "hyper-rustls", "insta", + "k8s-openapi", "kube", "oauth2", "open", diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 7b81db77..93c15dac 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -50,3 +50,4 @@ similar-asserts = "1.4.2" serial_test = "0.9.0" temp-env = { version = "0.3.1", features = ["async_closure"] } rstest = "0.16.0" +k8s-openapi = { version = "0.17.0", features = ["v1_23", "schemars" ] } diff --git a/cli/tests/cli.rs b/cli/tests/cli.rs index 264f7b51..04a0db75 100644 --- a/cli/tests/cli.rs +++ b/cli/tests/cli.rs @@ -5,12 +5,20 @@ use controller::{ common::{find_ame_endpoint, private_repo_gh_pat, setup_cluster}, project_source_ctrl::ProjectSrcCtrl, }; +use controller::{ModelValidationStatus, Project}; use fs_extra::dir::CopyOptions; +use futures_util::StreamExt; + use insta::assert_snapshot; +use k8s_openapi::api::apps::v1::Deployment; +use kube::api::ListParams; +use kube::runtime::{watcher, WatchStreamExt}; +use kube::Api; use kube::Client; use rstest::*; use serial_test::serial; +use std::time::Duration; use std::{ path::{Path, PathBuf}, process::Command, @@ -303,3 +311,109 @@ async fn cannot_create_multiple_sources_for_the_same_repo() -> Result<(), Box Result<(), Box> { + test_setup().await?; + + // The template repo is required as the ame-demo requires it to train. + let template_repo = "https://github.com/TeaInSpace/ame-template-demo.git"; + let repo = "https://github.com/TeaInSpace/ame-demo.git"; + let model_name = "logreg"; // this name from the ame-demo repo. + let project_src_ctrl = ProjectSrcCtrl::new(kube_client().await?, AME_NAMESPACE); + + let kube_client = kube_client().await?; + let deployments: Api = Api::namespaced(kube_client.clone(), AME_NAMESPACE); + let projects: Api = Api::namespaced(kube_client.clone(), AME_NAMESPACE); + + let _ = project_src_ctrl.delete_project_src_for_repo(repo).await; + let mut cmd = Command::cargo_bin("cli")?; + + let _output = cmd + .arg("projectsrc") + .arg("create") + .arg(template_repo) + .assert() + .failure(); + + let mut cmd = Command::cargo_bin("cli")?; + + let _output = cmd + .arg("projectsrc") + .arg("create") + .arg(repo) + .assert() + .failure(); + + // Note that the model will start training now if now version is present. + // We will need to check for this in future tests. + + // It is important that we only have 2 projects as the watcher + // does not perform any filtering. + assert_eq!( + projects + .list_metadata(&ListParams::default()) + .await? + .items + .len(), + 2 + ); + + // No deployment for the model should be present. + assert!(deployments.get(model_name).await.is_err()); + + let mut project_watcher = watcher(projects, ListParams::default()) + .applied_objects() + .boxed(); + + while let Some(e) = project_watcher.next().await { + let Some(mut status) = e?.status else { + continue; + }; + + let Some(model_status) = status.get_model_status(model_name) else { + continue ; + }; + + match model_status.validation { + Some(ModelValidationStatus::Validated { .. }) => break, + Some(ModelValidationStatus::FailedValidation { .. }) => { + return Err("model failed validation".into()); + } + _ => (), + }; + } + + let mut deployment_watcher = watcher( + deployments, + ListParams::default().fields(&format!("metadata.name=={model_name}")), + ) + .applied_objects() + .boxed(); + + let timeout = Duration::from_secs(60); + let start = std::time::Instant::now(); + + while let Some(e) = deployment_watcher.next().await { + let Some(status) = e?.status else { + return Err("missing deployment status".into()); + }; + + if let Some(1) = status.ready_replicas { + break; + } + + if start.elapsed() > timeout { + return Err("failed to deploy model within timeout".into()); + } + } + + project_src_ctrl.delete_project_src_for_repo(repo).await?; + project_src_ctrl + .delete_project_src_for_repo(template_repo) + .await?; + + Ok(()) +} diff --git a/controller/src/lib.rs b/controller/src/lib.rs index 5eda97dc..c7de1e2d 100644 --- a/controller/src/lib.rs +++ b/controller/src/lib.rs @@ -68,6 +68,12 @@ pub enum Error { #[error("got error from AME's secret store: {0}")] SecretError(#[from] SecretError), + + #[error("failed to find model status for: {0}")] + MissingModelStatus(String), + + #[error("failed to find validation task status for model: {0}")] + MissingValidationTask(String), } pub type Result = std::result::Result; diff --git a/controller/src/manager.rs b/controller/src/manager.rs index bba452e7..4dde720e 100644 --- a/controller/src/manager.rs +++ b/controller/src/manager.rs @@ -389,7 +389,10 @@ impl Task { "exec {} save_artifacts {}", - self.spec.runcommand.clone().unwrap(), + self.spec + .runcommand + .clone() + .unwrap_or("missing command".to_string()), // TODO: handle missing commands self.task_artifacts_path() ) }; diff --git a/controller/src/project.rs b/controller/src/project.rs index d756d81f..f5f3f25e 100644 --- a/controller/src/project.rs +++ b/controller/src/project.rs @@ -1,10 +1,14 @@ use std::{ collections::{BTreeMap, HashMap}, + default::Default, sync::Arc, time::Duration, }; -use crate::{manager, Error, Result, TaskSpec}; +use crate::{ + manager::{self, TaskPhase}, + Error, Result, TaskSpec, +}; use ame::grpc::LogEntry; use futures::{future::BoxFuture, FutureExt, StreamExt}; @@ -26,14 +30,16 @@ use k8s_openapi::{ }, }; use kube::{ - api::{ListParams, Patch, PatchParams}, + api::{ListParams, Patch, PatchParams, PostParams}, core::ObjectMeta, runtime::{controller::Action, Controller}, Api, Client, CustomResource, Resource, ResourceExt, }; +use reqwest::Url; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use serde_json::json; +use tracing::{debug, error, log::info}; #[derive(CustomResource, Deserialize, Serialize, Clone, Debug, JsonSchema, Default)] #[kube( @@ -65,7 +71,48 @@ pub struct ProjectSpec { #[derive(Deserialize, Serialize, Clone, Debug, JsonSchema, PartialEq, Eq, Default)] pub struct ProjectStatus { #[serde(skip_serializing_if = "Option::is_none")] - models: Option>, + pub models: Option>, +} + +impl ProjectStatus { + fn set_model_status(&mut self, name: &str, status: ModelStatus) { + if let Some(ref mut statuses) = self.models { + statuses.insert(name.to_string(), status); + } else { + let mut statuses: BTreeMap = BTreeMap::new(); + statuses.insert(name.to_string(), status); + self.models = Some(statuses); + } + } + + fn set_model_validation(&mut self, name: &str, validation: ModelValidationStatus) { + let mut default = ModelStatus::default(); + let mut status = self.get_model_status(name).unwrap_or(&mut default).clone(); + + status.validation = Some(validation); + + self.set_model_status(name, status); + } + + pub fn get_model_status(&mut self, name: &str) -> Option<&mut ModelStatus> { + self.models.as_mut().and_then(|models| models.get_mut(name)) + } + + fn set_latest_valid_model_version(&mut self, name: &str, version: String) { + let mut status = self + .get_model_status(name) + .map(|s| s.to_owned()) + .unwrap_or_default(); + + status.latest_valid_model_version = Some(version); + self.set_model_status(name, status) + } + + fn get_latest_valid_model_version(&mut self, name: &str) -> Option { + self.get_model_status(name) + .as_ref() + .and_then(|model_status| model_status.latest_valid_model_version.clone()) + } } #[derive(Deserialize, Serialize, Clone, Debug, JsonSchema, Default)] @@ -78,6 +125,9 @@ pub struct Model { training: TrainingCfg, #[serde(skip_serializing_if = "Option::is_none")] deployment: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + validation_task: Option, } #[derive(Deserialize, Serialize, Clone, Debug, JsonSchema, Default)] @@ -113,16 +163,67 @@ pub enum ModelType { #[derive(Deserialize, Serialize, Clone, Debug, JsonSchema, PartialEq, Eq, Default)] #[serde(rename_all = "camelCase")] pub struct ModelStatus { - name: String, #[serde(skip_serializing_if = "Option::is_none")] - latest_model_version: Option