From 9d8b9018eba0338ec2abd95837132db980c799ed Mon Sep 17 00:00:00 2001 From: Tim Diekmann <21277928+TimDiekmann@users.noreply.github.com> Date: Fri, 12 Jan 2024 14:52:07 +0100 Subject: [PATCH] H-1841: Trigger embedding creation workflow when entity was created/updated (#3862) --- Cargo.lock | 1 + .../src/activities/embeddings.ts | 26 +++--- apps/hash-graph/libs/api/src/rest/entity.rs | 5 + apps/hash-graph/libs/graph/Cargo.toml | 1 + .../libs/graph/src/store/fetcher.rs | 5 + .../libs/graph/src/store/knowledge.rs | 3 + .../store/postgres/knowledge/entity/mod.rs | 93 +++++++++++++------ libs/@local/temporal-client/src/ai.rs | 64 +++++++++++++ libs/@local/temporal-client/src/error.rs | 4 + libs/@local/temporal-client/src/lib.rs | 10 +- tests/hash-graph-integration/postgres/lib.rs | 4 + 11 files changed, 173 insertions(+), 43 deletions(-) create mode 100644 libs/@local/temporal-client/src/ai.rs diff --git a/Cargo.lock b/Cargo.lock index 65a6f1822f5..b772e2c11c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1078,6 +1078,7 @@ dependencies = [ "serde", "serde_json", "tarpc", + "temporal-client 0.0.0", "temporal-versioning", "time", "tokio", diff --git a/apps/hash-ai-worker-ts/src/activities/embeddings.ts b/apps/hash-ai-worker-ts/src/activities/embeddings.ts index 7f813dd66a4..3380eeaad29 100644 --- a/apps/hash-ai-worker-ts/src/activities/embeddings.ts +++ b/apps/hash-ai-worker-ts/src/activities/embeddings.ts @@ -28,16 +28,6 @@ export const createEmbeddings = async (params: { embeddings: { property?: BaseUrl; embedding: number[] }[]; usage: Usage; }> => { - if (params.propertyTypes.length === 0) { - return { - embeddings: [], - usage: { - prompt_tokens: 0, - total_tokens: 0, - }, - }; - } - // sort property types by their base url params.propertyTypes.sort((a, b) => a.metadata.recordId.baseUrl.localeCompare(b.metadata.recordId.baseUrl), @@ -48,7 +38,8 @@ export const createEmbeddings = async (params: { // 2. A list of all property key:value pairs // // We use the last item in the array to store the combined 'all properties' list. - const propertyEmbeddings = []; + const propertyEmbeddings: string[] = []; + const usedPropertyTypes: PropertyTypeWithMetadata[] = []; let combinedEntityEmbedding = ""; for (const propertyType of params.propertyTypes) { const property = @@ -61,6 +52,17 @@ export const createEmbeddings = async (params: { const embeddingInput = createEmbeddingInput({ propertyType, property }); combinedEntityEmbedding += `${embeddingInput}\n`; propertyEmbeddings.push(embeddingInput); + usedPropertyTypes.push(propertyType); + } + + if (usedPropertyTypes.length === 0) { + return { + embeddings: [], + usage: { + prompt_tokens: 0, + total_tokens: 0, + }, + }; } const response = await openai.embeddings.create({ @@ -71,7 +73,7 @@ export const createEmbeddings = async (params: { return { usage: response.usage, embeddings: response.data.map((data, idx) => ({ - property: params.propertyTypes[idx]?.metadata.recordId.baseUrl, + property: usedPropertyTypes[idx]?.metadata.recordId.baseUrl, embedding: data.embedding, })), }; diff --git a/apps/hash-graph/libs/api/src/rest/entity.rs b/apps/hash-graph/libs/api/src/rest/entity.rs index 2d02c76a000..03a10f57d8e 100644 --- a/apps/hash-graph/libs/api/src/rest/entity.rs +++ b/apps/hash-graph/libs/api/src/rest/entity.rs @@ -46,6 +46,7 @@ use graph_types::{ Embedding, }; use serde::Deserialize; +use temporal_client::TemporalClient; use temporal_versioning::{DecisionTime, Timestamp, TransactionTime}; use type_system::url::VersionedUrl; use utoipa::{OpenApi, ToSchema}; @@ -201,6 +202,7 @@ async fn create_entity( AuthenticatedUserHeader(actor_id): AuthenticatedUserHeader, store_pool: Extension>, authorization_api_pool: Extension>, + temporal_client: Extension>>, body: Json, ) -> Result, Response> where @@ -227,6 +229,7 @@ where .create_entity( actor_id, &mut authorization_api, + temporal_client.as_deref(), owned_by_id, entity_uuid, None, @@ -469,6 +472,7 @@ async fn update_entity( AuthenticatedUserHeader(actor_id): AuthenticatedUserHeader, store_pool: Extension>, authorization_api_pool: Extension>, + temporal_client: Extension>>, body: Json, ) -> Result, Response> where @@ -494,6 +498,7 @@ where .update_entity( actor_id, &mut authorization_api, + temporal_client.as_deref(), entity_id, None, archived, diff --git a/apps/hash-graph/libs/graph/Cargo.toml b/apps/hash-graph/libs/graph/Cargo.toml index 0b1bacf486f..a316e31ba65 100644 --- a/apps/hash-graph/libs/graph/Cargo.toml +++ b/apps/hash-graph/libs/graph/Cargo.toml @@ -9,6 +9,7 @@ description = "HASH Graph API" graph-types = { workspace = true, features = ["postgres", "utoipa"] } validation = { workspace = true } temporal-versioning = { workspace = true, features = ["postgres", "utoipa"] } +temporal-client = { workspace = true } type-fetcher = { workspace = true } authorization = { workspace = true, features = ["utoipa"] } codec = { workspace = true } diff --git a/apps/hash-graph/libs/graph/src/store/fetcher.rs b/apps/hash-graph/libs/graph/src/store/fetcher.rs index 88a19235d56..10a026fdf81 100644 --- a/apps/hash-graph/libs/graph/src/store/fetcher.rs +++ b/apps/hash-graph/libs/graph/src/store/fetcher.rs @@ -27,6 +27,7 @@ use graph_types::{ owned_by_id::OwnedById, }; use tarpc::context; +use temporal_client::TemporalClient; use temporal_versioning::{DecisionTime, Timestamp, TransactionTime}; use tokio::net::ToSocketAddrs; use tokio_serde::formats::Json; @@ -1065,6 +1066,7 @@ where &mut self, actor_id: AccountId, authorization_api: &mut Au, + temporal_client: Option<&TemporalClient>, owned_by_id: OwnedById, entity_uuid: Option, decision_time: Option>, @@ -1090,6 +1092,7 @@ where .create_entity( actor_id, authorization_api, + temporal_client, owned_by_id, entity_uuid, decision_time, @@ -1175,6 +1178,7 @@ where &mut self, actor_id: AccountId, authorization_api: &mut Au, + temporal_client: Option<&TemporalClient>, entity_id: EntityId, decision_time: Option>, archived: bool, @@ -1199,6 +1203,7 @@ where .update_entity( actor_id, authorization_api, + temporal_client, entity_id, decision_time, archived, diff --git a/apps/hash-graph/libs/graph/src/store/knowledge.rs b/apps/hash-graph/libs/graph/src/store/knowledge.rs index 751b8e400cf..9f8684629ff 100644 --- a/apps/hash-graph/libs/graph/src/store/knowledge.rs +++ b/apps/hash-graph/libs/graph/src/store/knowledge.rs @@ -11,6 +11,7 @@ use graph_types::{ }, owned_by_id::OwnedById, }; +use temporal_client::TemporalClient; use temporal_versioning::{DecisionTime, Timestamp, TransactionTime}; use type_system::{url::VersionedUrl, EntityType}; use validation::ValidationProfile; @@ -62,6 +63,7 @@ pub trait EntityStore: crud::Read { &mut self, actor_id: AccountId, authorization_api: &mut A, + temporal_client: Option<&TemporalClient>, owned_by_id: OwnedById, entity_uuid: Option, decision_time: Option>, @@ -165,6 +167,7 @@ pub trait EntityStore: crud::Read { &mut self, actor_id: AccountId, authorization_api: &mut A, + temporal_client: Option<&TemporalClient>, entity_id: EntityId, decision_time: Option>, archived: bool, diff --git a/apps/hash-graph/libs/graph/src/store/postgres/knowledge/entity/mod.rs b/apps/hash-graph/libs/graph/src/store/postgres/knowledge/entity/mod.rs index e2721fe10cd..0425b38acdb 100644 --- a/apps/hash-graph/libs/graph/src/store/postgres/knowledge/entity/mod.rs +++ b/apps/hash-graph/libs/graph/src/store/postgres/knowledge/entity/mod.rs @@ -31,6 +31,7 @@ use graph_types::{ }; use hash_status::StatusCode; use postgres_types::{Json, ToSql}; +use temporal_client::TemporalClient; use temporal_versioning::{DecisionTime, RightBoundedTemporalInterval, Timestamp, TransactionTime}; use tokio_postgres::GenericClient; use type_system::{url::VersionedUrl, EntityType}; @@ -298,6 +299,7 @@ impl EntityStore for PostgresStore { &mut self, actor_id: AccountId, authorization_api: &mut A, + temporal_client: Option<&TemporalClient>, owned_by_id: OwnedById, entity_uuid: Option, decision_time: Option>, @@ -569,7 +571,7 @@ impl EntityStore for PostgresStore { } else { let decision_time = row.get(0); let transaction_time = row.get(1); - Ok(EntityMetadata { + let entity_metadata = EntityMetadata { record_id: EntityRecordId { entity_id, edition_id, @@ -589,7 +591,21 @@ impl EntityStore for PostgresStore { }, archived, draft, - }) + }; + if let Some(temporal_client) = temporal_client { + temporal_client + .start_update_entity_embeddings_workflow( + actor_id, + Entity { + properties, + link_data, + metadata: entity_metadata.clone(), + }, + ) + .await + .change_context(InsertionError)?; + } + Ok(entity_metadata) } } @@ -968,6 +984,7 @@ impl EntityStore for PostgresStore { &mut self, actor_id: AccountId, authorization_api: &mut A, + temporal_client: Option<&TemporalClient>, entity_id: EntityId, decision_time: Option>, archived: bool, @@ -1126,7 +1143,7 @@ impl EntityStore for PostgresStore { transaction.commit().await.change_context(UpdateError)?; - Ok(EntityMetadata { + let entity_metadata = EntityMetadata { record_id: EntityRecordId { entity_id, edition_id, @@ -1152,13 +1169,33 @@ impl EntityStore for PostgresStore { }, archived, draft, - }) + }; + if let Some(temporal_client) = temporal_client { + temporal_client + .start_update_entity_embeddings_workflow( + actor_id, + Entity { + properties, + link_data: previous_entity + .link_data + .map(|previous_link_data| LinkData { + left_entity_id: previous_link_data.left_entity_id, + right_entity_id: previous_link_data.right_entity_id, + order: link_order, + }), + metadata: entity_metadata.clone(), + }, + ) + .await + .change_context(UpdateError)?; + } + Ok(entity_metadata) } async fn update_entity_embeddings( &mut self, - actor_id: AccountId, - authorization_api: &mut A, + _actor_id: AccountId, + _authorization_api: &mut A, embeddings: impl IntoIterator> + Send, updated_at_transaction_time: Timestamp, updated_at_decision_time: Timestamp, @@ -1190,26 +1227,30 @@ impl EntityStore for PostgresStore { ) }) .unzip(); - let permissions = authorization_api - .check_entities_permission( - actor_id, - EntityPermission::Update, - entity_ids.iter().copied(), - Consistency::FullyConsistent, - ) - .await - .change_context(UpdateError)? - .0 - .into_iter() - .filter_map(|(entity_id, has_permission)| (!has_permission).then_some(entity_id)) - .collect::>(); - if !permissions.is_empty() { - let mut status = Report::new(PermissionAssertion); - for entity_id in permissions { - status = status.attach(format!("Permission denied for entity {entity_id}")); - } - return Err(status.change_context(UpdateError)); - } + + // TODO: Add permission to allow updating embeddings + // see https://linear.app/hash/issue/H-1870 + // let permissions = authorization_api + // .check_entities_permission( + // actor_id, + // EntityPermission::UpdateEmbeddings, + // entity_ids.iter().copied(), + // Consistency::FullyConsistent, + // ) + // .await + // .change_context(UpdateError)? + // .0 + // .into_iter() + // .filter_map(|(entity_id, has_permission)| (!has_permission).then_some(entity_id)) + // .collect::>(); + // if !permissions.is_empty() { + // let mut status = Report::new(PermissionAssertion); + // for entity_id in permissions { + // status = status.attach(format!("Permission denied for entity {entity_id}")); + // } + // return Err(status.change_context(UpdateError)); + // } + if reset { let (owned_by_id, entity_uuids): (Vec<_>, Vec<_>) = entity_ids .into_iter() diff --git a/libs/@local/temporal-client/src/ai.rs b/libs/@local/temporal-client/src/ai.rs new file mode 100644 index 00000000000..e22c44b706f --- /dev/null +++ b/libs/@local/temporal-client/src/ai.rs @@ -0,0 +1,64 @@ +use std::collections::HashMap; + +use error_stack::{Report, ResultExt}; +use graph_types::{account::AccountId, knowledge::entity::Entity}; +use serde::Serialize; +use temporal_io_client::{WorkflowClientTrait, WorkflowOptions}; +use temporal_io_sdk_core_protos::{ + temporal::api::common::v1::Payload, ENCODING_PAYLOAD_KEY, JSON_ENCODING_VAL, +}; +use uuid::Uuid; + +use crate::{TemporalClient, WorkflowError}; + +impl TemporalClient { + /// Starts a workflow to update the embeddings for the provided entity. + /// + /// Returns the run ID of the workflow. + /// + /// # Errors + /// + /// Returns an error if the workflow fails to start. + pub async fn start_update_entity_embeddings_workflow( + &self, + actor_id: AccountId, + entity: Entity, + ) -> Result> { + #[derive(Serialize)] + #[serde(rename_all = "camelCase")] + struct AuthenticationContext { + actor_id: AccountId, + } + + #[derive(Serialize)] + #[serde(rename_all = "camelCase")] + struct UpdateEntityEmbeddingsParams { + authentication: AuthenticationContext, + entity: Entity, + } + + Ok(self + .client + .start_workflow( + vec![Payload { + metadata: HashMap::from([( + ENCODING_PAYLOAD_KEY.to_owned(), + JSON_ENCODING_VAL.as_bytes().to_vec(), + )]), + data: serde_json::to_vec(&UpdateEntityEmbeddingsParams { + authentication: AuthenticationContext { actor_id }, + entity, + }) + .change_context(WorkflowError("updateEntityEmbeddings"))?, + }], + "ai".to_owned(), + Uuid::new_v4().to_string(), + "updateEntityEmbeddings".to_owned(), + None, + WorkflowOptions::default(), + ) + .await + .change_context(WorkflowError("updateEntityEmbeddings"))? + .run_id) + } +} diff --git a/libs/@local/temporal-client/src/error.rs b/libs/@local/temporal-client/src/error.rs index 3eae4564e2d..90987facb8f 100644 --- a/libs/@local/temporal-client/src/error.rs +++ b/libs/@local/temporal-client/src/error.rs @@ -7,3 +7,7 @@ pub struct ConfigError; #[derive(Debug, Error)] #[error("Could not connect to Temporal.io Server")] pub struct ConnectionError; + +#[derive(Debug, Error)] +#[error("Workflow execution of job {0} failed")] +pub struct WorkflowError(pub &'static str); diff --git a/libs/@local/temporal-client/src/lib.rs b/libs/@local/temporal-client/src/lib.rs index ff85534c481..7c6b71312d2 100644 --- a/libs/@local/temporal-client/src/lib.rs +++ b/libs/@local/temporal-client/src/lib.rs @@ -1,6 +1,9 @@ -#![feature(lint_reasons, impl_trait_in_assoc_type)] +#![feature(impl_trait_in_assoc_type)] -pub mod error; +pub use self::error::{ConfigError, ConnectionError, WorkflowError}; + +mod ai; +mod error; use std::future::{Future, IntoFuture}; @@ -8,11 +11,8 @@ use error_stack::{Report, ResultExt}; use temporal_io_client::{Client, ClientOptions, ClientOptionsBuilder, RetryClient}; use url::Url; -use self::error::{ConfigError, ConnectionError}; - #[derive(Debug)] pub struct TemporalClient { - #[expect(dead_code)] client: RetryClient, } diff --git a/tests/hash-graph-integration/postgres/lib.rs b/tests/hash-graph-integration/postgres/lib.rs index 0bff9823747..d19b62f575f 100644 --- a/tests/hash-graph-integration/postgres/lib.rs +++ b/tests/hash-graph-integration/postgres/lib.rs @@ -514,6 +514,7 @@ impl DatabaseApi<'_> { .create_entity( self.account_id, &mut NoAuthorization, + None, OwnedById::new(self.account_id.into_uuid()), entity_uuid, Some(generate_decision_time()), @@ -627,6 +628,7 @@ impl DatabaseApi<'_> { .update_entity( self.account_id, &mut NoAuthorization, + None, entity_id, Some(generate_decision_time()), false, @@ -650,6 +652,7 @@ impl DatabaseApi<'_> { .create_entity( self.account_id, &mut NoAuthorization, + None, OwnedById::new(self.account_id.into_uuid()), entity_uuid, None, @@ -830,6 +833,7 @@ impl DatabaseApi<'_> { .update_entity( self.account_id, &mut NoAuthorization, + None, entity_id, None, true,