From d6289c659077554ab029b9bce899c7ce8077643c Mon Sep 17 00:00:00 2001 From: Johnny Graettinger Date: Wed, 4 Sep 2024 19:30:38 -0500 Subject: [PATCH] crates/agent: authorization API uses DB snapshots Update the API to fetch on-demand snapshots of authorization state from the DB, bounding both how frequently snapshots may be taken and how long a stale snapshot may be used. A request will either succeed against a current snapshot (the happy fast path), or will fail. If it fails, the caller is asked to retry after a given delay, and meanwhile the API will typically have started to refresh the current snapshot. Only if the request continues to fail against a snapshot which was taken _after_ the issued-at time of the request token is the error considered permanent. This means that requests are cheap, evaluated only by the agent-api, but we *also* don't have caching artifacts that could result in false errors when tasks or collections are being created or changed. Instead, we incur a minor delay while a sufficiently recent snapshot is taken. A final change is that we now prefix-match on shard ID templates and journal name templates, rather than munging both into a more-approximate prefix match over catalog names. --- crates/agent-sql/src/data_plane.rs | 67 --- crates/agent/src/api/authorize.rs | 464 +++++++++++++++---- crates/agent/src/api/mod.rs | 5 + go/runtime/authorizer.go | 50 +- supabase/migrations/60_extract_templates.sql | 14 + 5 files changed, 425 insertions(+), 175 deletions(-) create mode 100644 supabase/migrations/60_extract_templates.sql diff --git a/crates/agent-sql/src/data_plane.rs b/crates/agent-sql/src/data_plane.rs index 1725650dd5..1379159107 100644 --- a/crates/agent-sql/src/data_plane.rs +++ b/crates/agent-sql/src/data_plane.rs @@ -68,70 +68,3 @@ pub async fn fetch_data_planes( Ok(r.into_iter().collect()) } - -pub async fn fetch_data_plane_by_task_and_fqdn( - pool: &sqlx::PgPool, - task_shard: &str, - task_data_plane_fqdn: &str, -) -> sqlx::Result<Option<tables::DataPlane>> { - sqlx::query_as!( - tables::DataPlane, - r#" - select - d.id as "control_id: Id", - d.data_plane_name, - d.data_plane_fqdn, - false as "is_default!: bool", - d.hmac_keys, - d.broker_address, - d.reactor_address, - d.ops_logs_name as "ops_logs_name: models::Collection", - d.ops_stats_name as "ops_stats_name: models::Collection" - from data_planes d - join live_specs t on t.data_plane_id = d.id - where d.data_plane_fqdn = $2 and starts_with($1::text, t.catalog_name) - "#, - task_shard, - task_data_plane_fqdn, - ) - .fetch_optional(pool) - .await -} - -pub async fn verify_task_authorization( - pool: &sqlx::PgPool, - task_shard: &str, - journal_name_or_prefix: &str, - required_role: &str, -) -> sqlx::Result<Option<(String, models::Collection, models::Id, bool)>> { - let r = sqlx::query!( - r#" - select - t.catalog_name as "task_name: String", - c.catalog_name as "collection_name: models::Collection", - c.data_plane_id as "collection_data_plane_id: models::Id", - exists( - select 1 - from internal.task_roles($1, $3::text::grant_capability) r - where starts_with($2, r.role_prefix) - ) as "authorized!: bool" - from live_specs t, live_specs c - where starts_with($1, t.catalog_name) - and starts_with($2, c.catalog_name) - "#, - task_shard, - journal_name_or_prefix, - required_role, - ) - .fetch_optional(pool) - .await?; - - Ok(r.map(|r| { - ( - r.task_name, - r.collection_name, - r.collection_data_plane_id, - r.authorized, - ) - })) -} diff --git a/crates/agent/src/api/authorize.rs b/crates/agent/src/api/authorize.rs index 0427ed3c3e..4d3a128001 100644 ---
a/crates/agent/src/api/authorize.rs +++ b/crates/agent/src/api/authorize.rs @@ -9,15 +9,19 @@ pub struct Request { token: String, } -#[derive(Debug, serde::Serialize, schemars::JsonSchema)] +#[derive(Default, Debug, serde::Serialize, schemars::JsonSchema)] #[serde(rename_all = "camelCase")] pub struct Response { // # JWT token which has been authorized for use. token: String, // # Address of Gazette brokers for the issued token. broker_address: String, + // # Number of milliseconds to wait before retrying the request. + // Non-zero if and only if token is not set. + retry_millis: u64, } +#[axum::debug_handler] pub async fn authorize_task( axum::extract::State(app): axum::extract::State<Arc<App>>, axum::Json(request): axum::Json<Request>, @@ -39,21 +43,17 @@ async fn do_authorize_task(app: &App, Request { token }: &Request) -> anyhow::Re }?; tracing::debug!(?claims, ?header, "decoded authorization request"); - // Split off the leading 'capture', 'derivation', or 'materialization' - // prefix of the Shard ID conveyed in `claims.subject`. - // The remainder of `task_shard` is a catalog task plus a shard suffix. - let Some((task_type, task_shard)) = claims.sub.split_once('/') else { - anyhow::bail!("invalid claims subject {}", claims.sub); - }; - // Map task-type from shard prefix naming, to ops log naming. - let task_type = match task_type { - "capture" => "capture", - "derivation" => "derivation", - "materialize" => "materialization", - _ => anyhow::bail!("invalid shard task type {task_type}"), - }; + let shard_id = claims.sub.as_str(); + if shard_id.is_empty() { + anyhow::bail!("missing required shard ID (`sub` claim)"); + } - let journal_name_or_prefix = labels::expect_one(claims.sel.include(), "name")?; + let shard_data_plane_fqdn = claims.iss.as_str(); + if shard_data_plane_fqdn.is_empty() { + anyhow::bail!("missing required shard data-plane FQDN (`iss` claim)"); + } + + let journal_name_or_prefix = labels::expect_one(claims.sel.include(), "name")?.to_owned(); // Require the request was signed with the AUTHORIZE capability, // and then strip this capability before issuing a response token. @@ -64,28 +64,120 @@ async fn do_authorize_task(app: &App, Request { token }: &Request) -> anyhow::Re // Validate and match the requested capabilities to a corresponding role. let required_role = match claims.cap { - proto_gazette::capability::LIST | proto_gazette::capability::READ => "read", - proto_gazette::capability::APPLY | proto_gazette::capability::APPEND => "write", - _ => { - anyhow::bail!( - "capability {} cannot be authorized by this service", - claims.cap - ); + proto_gazette::capability::LIST | proto_gazette::capability::READ => { + models::Capability::Read + } + proto_gazette::capability::APPLY | proto_gazette::capability::APPEND => { + models::Capability::Write + } + cap => { + anyhow::bail!("capability {cap} cannot be authorized by this service"); } }; - // Resolve the identified data-plane through its task assignment (which is verified) and FQDN. - let Some(task_data_plane) = agent_sql::data_plane::fetch_data_plane_by_task_and_fqdn( - &app.pg_pool, - task_shard, - &claims.iss, - ) - .await? - else { + // Resolve the authorization snapshot against which this request is evaluated. + let snapshot = app.snapshot.read().unwrap(); + + let taken_unix = snapshot + .taken + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(); + + // If the snapshot is too old then the client MUST retry.
+ if claims.iat > taken_unix + MAX_SNAPSHOT_INTERVAL.as_secs() { + begin_refresh(snapshot, &app.snapshot); + + return Ok(Response { + retry_millis: jitter(), + ..Default::default() + }); + } + + match evaluate_authorization( + &snapshot, + shard_id, + shard_data_plane_fqdn, + token, + &journal_name_or_prefix, + required_role, + ) { + Ok((encoding_key, data_plane_fqdn, broker_address)) => { + claims.iss = data_plane_fqdn; + claims.exp = claims.iat + 3_600; // One hour. + + let token = jsonwebtoken::encode(&header, &claims, &encoding_key) + .context("failed to encode authorized JWT")?; + + Ok(Response { + broker_address, + token, + ..Default::default() + }) + } + Err(err) if taken_unix > claims.iat => { + // The snapshot was taken AFTER the authorization request was minted, + // which means the request cannot have prior knowledge of upcoming + // state re-configurations, and this is a terminal error. + Err(err) + } + Err(_) => { + let retry_millis = if let Some(remaining) = + MIN_SNAPSHOT_INTERVAL.checked_sub(snapshot.taken.elapsed().unwrap_or_default()) + { + // Our current snapshot isn't old enough. + remaining.as_millis() as u64 + } else { + begin_refresh(snapshot, &app.snapshot); + 0 + } + jitter(); + + Ok(Response { + retry_millis, + ..Default::default() + }) + } + } +} + +fn evaluate_authorization( + Snapshot { + taken: _, + collections, + data_planes, + role_grants, + tasks, + refresh_tx: _, + }: &Snapshot, + shard_id: &str, + shard_data_plane_fqdn: &str, + token: &str, + journal_name_or_prefix: &str, + required_role: models::Capability, +) -> anyhow::Result<(jsonwebtoken::EncodingKey, String, String)> { + // Map `claims.sub`, a Shard ID, into its task. + let task = tasks + .binary_search_by(|task| { + if shard_id.starts_with(&task.shard_template_id) { + std::cmp::Ordering::Equal + } else { + task.shard_template_id.as_str().cmp(shard_id) + } + }) + .ok() + .map(|index| &tasks[index]); + + // Map `claims.iss`, a data-plane FQDN, into its task-matched data-plane. + let task_data_plane = task.and_then(|task| { + data_planes + .get_by_key(&task.data_plane_id) + .filter(|data_plane| data_plane.data_plane_fqdn == shard_data_plane_fqdn) + }); + + let (Some(task), Some(task_data_plane)) = (task, task_data_plane) else { anyhow::bail!( - "task {task_shard} within data-plane {} is not known", - claims.iss - ); + "task shard {shard_id} within data-plane {shard_data_plane_fqdn} is not known" + ) }; // Attempt to find an HMAC key of this data-plane which validates against the request token. @@ -105,83 +197,279 @@ async fn do_authorize_task(app: &App, Request { token }: &Request) -> anyhow::Re anyhow::bail!("no data-plane keys validated against the token signature"); } - // Query for a task => collection pair and their RBAC authorization. - let (task_name, collection_name, collection_data_plane_id, mut authorized) = - agent_sql::data_plane::verify_task_authorization( - &app.pg_pool, - task_shard, - journal_name_or_prefix, - required_role, - ) - .await? - .unwrap_or(( - String::new(), - models::Collection::default(), - models::Id::zero(), - false, - )); + // Map a required `name` journal label selector into its collection. 
+ let Some(collection) = collections + .binary_search_by(|collection| { + if journal_name_or_prefix.starts_with(&collection.journal_template_name) { + std::cmp::Ordering::Equal + } else { + collection + .journal_template_name + .as_str() + .cmp(journal_name_or_prefix) + } + }) + .ok() + .map(|index| &collections[index]) + else { + anyhow::bail!("journal name or prefix {journal_name_or_prefix} is not known"); + }; + + let Some(collection_data_plane) = data_planes.get_by_key(&collection.data_plane_id) else { + anyhow::bail!( + "collection data-plane {} not found", + collection.data_plane_id + ); + }; + + let ops_kind = match task.spec_type { + models::CatalogType::Capture => "capture", + models::CatalogType::Collection => "derivation", + models::CatalogType::Materialization => "materialization", + models::CatalogType::Test => "test", + }; // As a special case outside of the RBAC system, allow a task to write // to its designated partition within its ops collections. - if !authorized - && required_role == "write" - && (collection_name == task_data_plane.ops_logs_name - || collection_name == task_data_plane.ops_stats_name) + if required_role == models::Capability::Write + && (collection.collection_name == task_data_plane.ops_logs_name + || collection.collection_name == task_data_plane.ops_stats_name) && journal_name_or_prefix.ends_with(&format!( - "/kind={task_type}/name={}/pivot=00", - labels::percent_encoding(&task_name).to_string(), + "/kind={ops_kind}/name={}/pivot=00", + labels::percent_encoding(&task.task_name).to_string(), )) { - authorized = true; - } - - if !authorized { + // Authorized write into designated ops partition. + } else if tables::RoleGrant::is_authorized( + role_grants, + &task.task_name, + &collection.collection_name, + required_role, + ) { + // Authorized access through RBAC. + } else { let ops_suffix = format!( - "/kind={task_type}/name={}/pivot=00", - labels::percent_encoding(&task_name).to_string(), + "/kind={ops_kind}/name={}/pivot=00", + labels::percent_encoding(&task.task_name).to_string(), ); tracing::warn!( - %task_type, - %task_shard, + %task.spec_type, + %shard_id, %journal_name_or_prefix, - required_role, + ?required_role, ops_logs=%task_data_plane.ops_logs_name, ops_stats=%task_data_plane.ops_stats_name, %ops_suffix, "task authorization rejection context" ); - anyhow::bail!("task shard {task_shard} is not authorized to {journal_name_or_prefix} for {required_role}"); + anyhow::bail!( + "task shard {shard_id} is not authorized to {journal_name_or_prefix} for {required_role:?}" + ); } - // We've now completed AuthN and AuthZ checks and can proceed. - - // TODO(johnny): We can avoid a DB query in the common case, - // where a task and collection data-plane are the same. - // I'm not doing this yet to keep the code path simpler while we're testing. - - let collection_data_plane = agent_sql::data_plane::fetch_data_planes( - &app.pg_pool, - vec![collection_data_plane_id], - "", // No default name to retrieve. - uuid::Uuid::nil(), - ) - .await? - .pop() - .context("collection data-plane does not exist")?; - - claims.iss = collection_data_plane.data_plane_fqdn; - claims.iat = jsonwebtoken::get_current_timestamp(); - claims.exp = claims.iat + 3_600; // One hour. 
let Some(encoding_key) = collection_data_plane.hmac_keys.first() else { anyhow::bail!( - "collection data-plane {collection_data_plane_id} has no configured HMAC keys" + "collection data-plane {} has no configured HMAC keys", + collection_data_plane.data_plane_name ); }; let encoding_key = jsonwebtoken::EncodingKey::from_base64_secret(&encoding_key)?; - let token = jsonwebtoken::encode(&header, &claims, &encoding_key)?; - Ok(Response { - broker_address: collection_data_plane.broker_address, - token, + Ok(( + encoding_key, + collection_data_plane.data_plane_fqdn.clone(), + collection_data_plane.broker_address.clone(), + )) +} + +// Snapshot is a point-in-time view of control-plane state +// that influences authorization decisions. +pub struct Snapshot { + // Time immediately before the snapshot was taken. + taken: std::time::SystemTime, + // Platform collections, indexed on `journal_template_name`. + collections: Vec<SnapshotCollection>, + // Platform data-planes. + data_planes: tables::DataPlanes, + // Platform role grants. + role_grants: tables::RoleGrants, + // Platform tasks, indexed on `shard_template_id`. + tasks: Vec<SnapshotTask>, + // `refresh_tx` is take()-en when the current snapshot should be refreshed. + refresh_tx: Option<futures::channel::oneshot::Sender<()>>, +} + +// SnapshotCollection is the state of a live collection which influences authorization. +// It's indexed on `journal_template_name`. +struct SnapshotCollection { + journal_template_name: String, + collection_name: models::Collection, + data_plane_id: models::Id, +} +// SnapshotTask is the state of a live task which influences authorization. +// It's indexed on `shard_template_id`. +struct SnapshotTask { + shard_template_id: String, + task_name: models::Name, + spec_type: models::CatalogType, + data_plane_id: models::Id, +} + +pub fn seed_snapshot() -> (Snapshot, futures::channel::oneshot::Receiver<()>) { + let (next_tx, next_rx) = futures::channel::oneshot::channel(); + + ( + Snapshot { + taken: std::time::SystemTime::UNIX_EPOCH, + collections: Vec::new(), + data_planes: tables::DataPlanes::default(), + role_grants: tables::RoleGrants::default(), + tasks: Vec::new(), + refresh_tx: Some(next_tx), + }, + next_rx, + ) +} + +pub async fn snapshot_loop(app: Arc<App>, mut refresh_rx: futures::channel::oneshot::Receiver<()>) { + while let Ok(()) = refresh_rx.await { + let (next_tx, next_rx) = futures::channel::oneshot::channel(); + refresh_rx = next_rx; + + match try_fetch_snapshot(&app.pg_pool).await { + Ok(mut snapshot) => { + snapshot.refresh_tx = Some(next_tx); + *app.snapshot.write().unwrap() = snapshot; + } + Err(err) => { + tracing::error!(?err, "failed to fetch snapshot (will retry)"); + () = tokio::time::sleep(MIN_SNAPSHOT_INTERVAL).await; + _ = next_tx.send(()); // Wake ourselves to retry.
+ } + }; + } +} + +async fn try_fetch_snapshot(pg_pool: &sqlx::PgPool) -> anyhow::Result<Snapshot> { + tracing::info!("started to fetch authorization snapshot"); + let taken = std::time::SystemTime::now(); + + let mut collections = sqlx::query_as!( + SnapshotCollection, + r#" + select + journal_template_name as "journal_template_name!", + catalog_name as "collection_name: models::Collection", + data_plane_id as "data_plane_id: models::Id" + from live_specs + where journal_template_name is not null + "#, + ) + .fetch_all(pg_pool) + .await + .context("failed to fetch view of live collections")?; + + let data_planes = sqlx::query_as!( + tables::DataPlane, + r#" + select + id as "control_id: models::Id", + data_plane_name, + data_plane_fqdn, + false as "is_default!: bool", + hmac_keys, + broker_address, + reactor_address, + ops_logs_name as "ops_logs_name: models::Collection", + ops_stats_name as "ops_stats_name: models::Collection" + from data_planes + "#, + ) + .fetch_all(pg_pool) + .await + .context("failed to fetch data_planes")?; + + let role_grants = sqlx::query_as!( + tables::RoleGrant, + r#" + select + subject_role as "subject_role: models::Prefix", + object_role as "object_role: models::Prefix", + capability as "capability: models::Capability" + from role_grants + "#, + ) + .fetch_all(pg_pool) + .await + .context("failed to fetch role_grants")?; + + let mut tasks = sqlx::query_as!( + SnapshotTask, + r#" + select + shard_template_id as "shard_template_id!", + catalog_name as "task_name: models::Name", + spec_type as "spec_type!: models::CatalogType", + data_plane_id as "data_plane_id: models::Id" + from live_specs + where shard_template_id is not null + "#, + ) + .fetch_all(pg_pool) + .await + .context("failed to fetch view of live tasks")?; + + let data_planes = tables::DataPlanes::from_iter(data_planes); + let role_grants = tables::RoleGrants::from_iter(role_grants); + + // Shard ID and journal name templates are prefixes which are always + // extended with a slash-separated suffix. Avoid inadvertent matches + // over path component prefixes. + for task in &mut tasks { + task.shard_template_id.push('/'); + } + for collection in &mut collections { + collection.journal_template_name.push('/'); + } + + tasks.sort_by(|t1, t2| t1.shard_template_id.cmp(&t2.shard_template_id)); + collections.sort_by(|c1, c2| c1.journal_template_name.cmp(&c2.journal_template_name)); + + tracing::info!( + collections = collections.len(), + data_planes = data_planes.len(), + role_grants = role_grants.len(), + tasks = tasks.len(), + "fetched authorization snapshot", + ); + + Ok(Snapshot { + taken, + collections, + data_planes, + role_grants, + tasks, + refresh_tx: None, }) } + +fn begin_refresh<'m>( + guard: std::sync::RwLockReadGuard<'_, Snapshot>, + mu: &'m std::sync::RwLock<Snapshot>, +) { + // We must release our read-lock before we can acquire a write lock. + std::mem::drop(guard); + + if let Some(tx) = mu.write().unwrap().refresh_tx.take() { + () = tx.send(()).unwrap(); // Begin a refresh. + } +} + +fn jitter() -> u64 { + use rand::Rng; + let mut rng = rand::thread_rng(); + rng.gen_range(0..=2_000) +} + +const MIN_SNAPSHOT_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10); +const MAX_SNAPSHOT_INTERVAL: std::time::Duration = std::time::Duration::from_secs(300); // 5 minutes.
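[Reviewer note, not part of the patch] The `binary_search_by` lookups above rely on a small trick: templates are kept sorted, and the comparator treats "the needle starts with this template" as equality, so a prefix match costs a binary search rather than a linear scan. A minimal, self-contained sketch of the same technique follows; the names and sample values are illustrative only.

    // Templates are sorted and carry a trailing '/', mirroring how the
    // snapshot pushes '/' onto each shard-ID and journal-name template so
    // that "acmeCo/foo" cannot match a sibling such as "acmeCo/foobar".
    fn lookup<'t>(templates: &'t [String], needle: &str) -> Option<&'t String> {
        templates
            .binary_search_by(|template| {
                if needle.starts_with(template.as_str()) {
                    std::cmp::Ordering::Equal
                } else {
                    template.as_str().cmp(needle)
                }
            })
            .ok()
            .map(|index| &templates[index])
    }

    fn main() {
        let mut templates = vec![
            "capture/acmeCo/source-foo/".to_string(),
            "derivation/acmeCo/rollup/".to_string(),
        ];
        templates.sort();

        // A shard ID extends its template with a slash-separated suffix.
        assert!(lookup(&templates, "capture/acmeCo/source-foo/0011aabb/00000000-00000000").is_some());
        // A sibling name sharing a path-component prefix does not match.
        assert!(lookup(&templates, "capture/acmeCo/source-foobar/0011aabb").is_none());
    }

This stays correct so long as no template is a prefix of another, which the trailing slash enforces for slash-delimited catalog names.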
diff --git a/crates/agent/src/api/mod.rs b/crates/agent/src/api/mod.rs index a314aeacc7..0aa0b1b67f 100644 --- a/crates/agent/src/api/mod.rs +++ b/crates/agent/src/api/mod.rs @@ -36,6 +36,7 @@ struct App { jwt_validation: jsonwebtoken::Validation, pg_pool: sqlx::PgPool, publisher: crate::publications::Publisher, + snapshot: std::sync::RwLock<authorize::Snapshot>, } /// Build the agent's API router. @@ -50,13 +51,17 @@ pub fn build_router( let mut jwt_validation = jsonwebtoken::Validation::default(); jwt_validation.set_audience(&["authenticated"]); + let (snapshot, seed_rx) = authorize::seed_snapshot(); + let app = Arc::new(App { id_generator: Mutex::new(id_generator), jwt_secret, jwt_validation, pg_pool, publisher, + snapshot: std::sync::RwLock::new(snapshot), }); + tokio::spawn(authorize::snapshot_loop(app.clone(), seed_rx)); use axum::routing::post; diff --git a/go/runtime/authorizer.go b/go/runtime/authorizer.go index c1690773a0..0729cf9dca 100644 --- a/go/runtime/authorizer.go +++ b/go/runtime/authorizer.go @@ -170,30 +170,40 @@ func doAuthFetch(controlAPI pb.Endpoint, claims pb.Claims, key jwt.VerificationK } token = `{"token":"` + token + `"}` - // Invoke the authorization API. + var brokerAddress pb.Endpoint var url = controlAPI.URL() url.Path = path.Join(url.Path, "/authorize/task") - httpResp, err := http.Post(url.String(), "application/json", strings.NewReader(token)) - if err != nil { - return "", "", time.Time{}, fmt.Errorf("failed to POST to authorization API: %w", err) - } - respBody, err := io.ReadAll(httpResp.Body) - if err != nil { - return "", "", time.Time{}, fmt.Errorf("failed to read authorization API response: %w", err) - } - if httpResp.StatusCode != 200 { - return "", "", time.Time{}, fmt.Errorf("authorization failed (%s): %s %s", httpResp.Status, string(respBody), token) - } + // Invoke the authorization API, perhaps multiple times if asked to retry.
+ for { + httpResp, err := http.Post(url.String(), "application/json", strings.NewReader(token)) + if err != nil { + return "", "", time.Time{}, fmt.Errorf("failed to POST to authorization API: %w", err) + } + respBody, err := io.ReadAll(httpResp.Body) + if err != nil { + return "", "", time.Time{}, fmt.Errorf("failed to read authorization API response: %w", err) + } + if httpResp.StatusCode != 200 { + return "", "", time.Time{}, fmt.Errorf("authorization failed (%s): %s %s", httpResp.Status, string(respBody), token) + } - var response struct { - Token string - BrokerAddress pb.Endpoint - } - if err = json.Unmarshal(respBody, &response); err != nil { - return "", "", time.Time{}, fmt.Errorf("failed to decode authorization response: %w", err) + var response struct { + Token string + BrokerAddress pb.Endpoint + RetryMillis uint64 + } + if err = json.Unmarshal(respBody, &response); err != nil { + return "", "", time.Time{}, fmt.Errorf("failed to decode authorization response: %w", err) + } + + if response.RetryMillis != 0 { + time.Sleep(time.Millisecond * time.Duration(response.RetryMillis)) + } else { + token, brokerAddress = response.Token, response.BrokerAddress + break + } } - token = response.Token claims = pb.Claims{} if _, _, err = jwt.NewParser().ParseUnverified(token, &claims); err != nil { @@ -206,7 +216,7 @@ func doAuthFetch(controlAPI pb.Endpoint, claims pb.Claims, key jwt.VerificationK return "", "", time.Time{}, fmt.Errorf("authorization server did not include an expires-at claim") } - return token, response.BrokerAddress, claims.ExpiresAt.Time, nil + return token, brokerAddress, claims.ExpiresAt.Time, nil } var _ pb.Authorizer = &ControlPlaneAuthorizer{} diff --git a/supabase/migrations/60_extract_templates.sql b/supabase/migrations/60_extract_templates.sql new file mode 100644 index 0000000000..d45203743f --- /dev/null +++ b/supabase/migrations/60_extract_templates.sql @@ -0,0 +1,14 @@ +begin; + +alter table live_specs +add column journal_template_name text +generated always as (built_spec->'partitionTemplate'->>'name') stored; + +alter table live_specs +add column shard_template_id text +generated always as (coalesce( + built_spec->'shardTemplate'->>'id', + built_spec->'derivation'->'shardTemplate'->>'id' +)) stored; + +commit; \ No newline at end of file
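[Reviewer note, not part of the patch] The snapshot freshness protocol that the commit message describes in prose condenses into one decision per request. The sketch below is commentary only: `classify` and its arguments are illustrative names, with just the two interval constants taken from `authorize.rs` above.

    use std::time::Duration;

    const MIN_SNAPSHOT_INTERVAL: Duration = Duration::from_secs(10);
    const MAX_SNAPSHOT_INTERVAL: Duration = Duration::from_secs(300);

    enum Outcome {
        Fresh,                // Evaluate the request against the current snapshot.
        PermanentError,       // The snapshot post-dates the token: a failure is final.
        RetryAfter(Duration), // Ask the caller to retry after this delay (plus jitter).
    }

    // `token_iat` and `snapshot_taken` are seconds since the Unix epoch;
    // `snapshot_age` is how long ago the snapshot was taken; and
    // `evaluation_failed` reports whether authorization failed against it.
    fn classify(
        token_iat: u64,
        snapshot_taken: u64,
        snapshot_age: Duration,
        evaluation_failed: bool,
    ) -> Outcome {
        if token_iat > snapshot_taken + MAX_SNAPSHOT_INTERVAL.as_secs() {
            // The snapshot is too old to evaluate this token at all:
            // kick off a refresh and have the caller retry immediately.
            return Outcome::RetryAfter(Duration::ZERO);
        }
        if !evaluation_failed {
            return Outcome::Fresh;
        }
        if snapshot_taken > token_iat {
            // The snapshot already reflects state newer than the token,
            // so this failure cannot be a staleness artifact.
            return Outcome::PermanentError;
        }
        // Otherwise wait out the minimum interval before a new snapshot.
        Outcome::RetryAfter(MIN_SNAPSHOT_INTERVAL.saturating_sub(snapshot_age))
    }

    fn main() {
        // A token minted long after the snapshot was taken forces a retry.
        assert!(matches!(
            classify(1_000, 400, Duration::from_secs(600), false),
            Outcome::RetryAfter(_)
        ));
        // A failure against a snapshot newer than the token is permanent.
        assert!(matches!(
            classify(100, 200, Duration::from_secs(5), true),
            Outcome::PermanentError
        ));
    }

The jittered delay in the real handler (0 to 2,000 milliseconds) spreads out the herd of retrying clients that accumulates while a refresh is in flight.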