Skip to content

Commit

Permalink
H-1841: Trigger embedding creation workflow when entity was created/updated (#3862)
Browse files Browse the repository at this point in the history
  • Loading branch information
TimDiekmann authored Jan 12, 2024
1 parent 197422f commit 9d8b901
Show file tree
Hide file tree
Showing 11 changed files with 173 additions and 43 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 14 additions & 12 deletions apps/hash-ai-worker-ts/src/activities/embeddings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,6 @@ export const createEmbeddings = async (params: {
embeddings: { property?: BaseUrl; embedding: number[] }[];
usage: Usage;
}> => {
if (params.propertyTypes.length === 0) {
return {
embeddings: [],
usage: {
prompt_tokens: 0,
total_tokens: 0,
},
};
}

// sort property types by their base url
params.propertyTypes.sort((a, b) =>
a.metadata.recordId.baseUrl.localeCompare(b.metadata.recordId.baseUrl),
Expand All @@ -48,7 +38,8 @@ export const createEmbeddings = async (params: {
// 2. A list of all property key:value pairs
//
// We use the last item in the array to store the combined 'all properties' list.
const propertyEmbeddings = [];
const propertyEmbeddings: string[] = [];
const usedPropertyTypes: PropertyTypeWithMetadata[] = [];
let combinedEntityEmbedding = "";
for (const propertyType of params.propertyTypes) {
const property =
Expand All @@ -61,6 +52,17 @@ export const createEmbeddings = async (params: {
const embeddingInput = createEmbeddingInput({ propertyType, property });
combinedEntityEmbedding += `${embeddingInput}\n`;
propertyEmbeddings.push(embeddingInput);
usedPropertyTypes.push(propertyType);
}

if (usedPropertyTypes.length === 0) {
return {
embeddings: [],
usage: {
prompt_tokens: 0,
total_tokens: 0,
},
};
}

const response = await openai.embeddings.create({
Expand All @@ -71,7 +73,7 @@ export const createEmbeddings = async (params: {
return {
usage: response.usage,
embeddings: response.data.map((data, idx) => ({
property: params.propertyTypes[idx]?.metadata.recordId.baseUrl,
property: usedPropertyTypes[idx]?.metadata.recordId.baseUrl,
embedding: data.embedding,
})),
};
Expand Down
5 changes: 5 additions & 0 deletions apps/hash-graph/libs/api/src/rest/entity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ use graph_types::{
Embedding,
};
use serde::Deserialize;
use temporal_client::TemporalClient;
use temporal_versioning::{DecisionTime, Timestamp, TransactionTime};
use type_system::url::VersionedUrl;
use utoipa::{OpenApi, ToSchema};
Expand Down Expand Up @@ -201,6 +202,7 @@ async fn create_entity<S, A>(
AuthenticatedUserHeader(actor_id): AuthenticatedUserHeader,
store_pool: Extension<Arc<S>>,
authorization_api_pool: Extension<Arc<A>>,
temporal_client: Extension<Option<Arc<TemporalClient>>>,
body: Json<CreateEntityRequest>,
) -> Result<Json<EntityMetadata>, Response>
where
Expand All @@ -227,6 +229,7 @@ where
.create_entity(
actor_id,
&mut authorization_api,
temporal_client.as_deref(),
owned_by_id,
entity_uuid,
None,
Expand Down Expand Up @@ -469,6 +472,7 @@ async fn update_entity<S, A>(
AuthenticatedUserHeader(actor_id): AuthenticatedUserHeader,
store_pool: Extension<Arc<S>>,
authorization_api_pool: Extension<Arc<A>>,
temporal_client: Extension<Option<Arc<TemporalClient>>>,
body: Json<UpdateEntityRequest>,
) -> Result<Json<EntityMetadata>, Response>
where
Expand All @@ -494,6 +498,7 @@ where
.update_entity(
actor_id,
&mut authorization_api,
temporal_client.as_deref(),
entity_id,
None,
archived,
Expand Down
1 change: 1 addition & 0 deletions apps/hash-graph/libs/graph/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ description = "HASH Graph API"
graph-types = { workspace = true, features = ["postgres", "utoipa"] }
validation = { workspace = true }
temporal-versioning = { workspace = true, features = ["postgres", "utoipa"] }
temporal-client = { workspace = true }
type-fetcher = { workspace = true }
authorization = { workspace = true, features = ["utoipa"] }
codec = { workspace = true }
Expand Down
5 changes: 5 additions & 0 deletions apps/hash-graph/libs/graph/src/store/fetcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use graph_types::{
owned_by_id::OwnedById,
};
use tarpc::context;
use temporal_client::TemporalClient;
use temporal_versioning::{DecisionTime, Timestamp, TransactionTime};
use tokio::net::ToSocketAddrs;
use tokio_serde::formats::Json;
Expand Down Expand Up @@ -1065,6 +1066,7 @@ where
&mut self,
actor_id: AccountId,
authorization_api: &mut Au,
temporal_client: Option<&TemporalClient>,
owned_by_id: OwnedById,
entity_uuid: Option<EntityUuid>,
decision_time: Option<Timestamp<DecisionTime>>,
Expand All @@ -1090,6 +1092,7 @@ where
.create_entity(
actor_id,
authorization_api,
temporal_client,
owned_by_id,
entity_uuid,
decision_time,
Expand Down Expand Up @@ -1175,6 +1178,7 @@ where
&mut self,
actor_id: AccountId,
authorization_api: &mut Au,
temporal_client: Option<&TemporalClient>,
entity_id: EntityId,
decision_time: Option<Timestamp<DecisionTime>>,
archived: bool,
Expand All @@ -1199,6 +1203,7 @@ where
.update_entity(
actor_id,
authorization_api,
temporal_client,
entity_id,
decision_time,
archived,
Expand Down
3 changes: 3 additions & 0 deletions apps/hash-graph/libs/graph/src/store/knowledge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use graph_types::{
},
owned_by_id::OwnedById,
};
use temporal_client::TemporalClient;
use temporal_versioning::{DecisionTime, Timestamp, TransactionTime};
use type_system::{url::VersionedUrl, EntityType};
use validation::ValidationProfile;
Expand Down Expand Up @@ -62,6 +63,7 @@ pub trait EntityStore: crud::Read<Entity> {
&mut self,
actor_id: AccountId,
authorization_api: &mut A,
temporal_client: Option<&TemporalClient>,
owned_by_id: OwnedById,
entity_uuid: Option<EntityUuid>,
decision_time: Option<Timestamp<DecisionTime>>,
Expand Down Expand Up @@ -165,6 +167,7 @@ pub trait EntityStore: crud::Read<Entity> {
&mut self,
actor_id: AccountId,
authorization_api: &mut A,
temporal_client: Option<&TemporalClient>,
entity_id: EntityId,
decision_time: Option<Timestamp<DecisionTime>>,
archived: bool,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use graph_types::{
};
use hash_status::StatusCode;
use postgres_types::{Json, ToSql};
use temporal_client::TemporalClient;
use temporal_versioning::{DecisionTime, RightBoundedTemporalInterval, Timestamp, TransactionTime};
use tokio_postgres::GenericClient;
use type_system::{url::VersionedUrl, EntityType};
Expand Down Expand Up @@ -298,6 +299,7 @@ impl<C: AsClient> EntityStore for PostgresStore<C> {
&mut self,
actor_id: AccountId,
authorization_api: &mut A,
temporal_client: Option<&TemporalClient>,
owned_by_id: OwnedById,
entity_uuid: Option<EntityUuid>,
decision_time: Option<Timestamp<DecisionTime>>,
Expand Down Expand Up @@ -569,7 +571,7 @@ impl<C: AsClient> EntityStore for PostgresStore<C> {
} else {
let decision_time = row.get(0);
let transaction_time = row.get(1);
Ok(EntityMetadata {
let entity_metadata = EntityMetadata {
record_id: EntityRecordId {
entity_id,
edition_id,
Expand All @@ -589,7 +591,21 @@ impl<C: AsClient> EntityStore for PostgresStore<C> {
},
archived,
draft,
})
};
if let Some(temporal_client) = temporal_client {
temporal_client
.start_update_entity_embeddings_workflow(
actor_id,
Entity {
properties,
link_data,
metadata: entity_metadata.clone(),
},
)
.await
.change_context(InsertionError)?;
}
Ok(entity_metadata)
}
}

Expand Down Expand Up @@ -968,6 +984,7 @@ impl<C: AsClient> EntityStore for PostgresStore<C> {
&mut self,
actor_id: AccountId,
authorization_api: &mut A,
temporal_client: Option<&TemporalClient>,
entity_id: EntityId,
decision_time: Option<Timestamp<DecisionTime>>,
archived: bool,
Expand Down Expand Up @@ -1126,7 +1143,7 @@ impl<C: AsClient> EntityStore for PostgresStore<C> {

transaction.commit().await.change_context(UpdateError)?;

Ok(EntityMetadata {
let entity_metadata = EntityMetadata {
record_id: EntityRecordId {
entity_id,
edition_id,
Expand All @@ -1152,13 +1169,33 @@ impl<C: AsClient> EntityStore for PostgresStore<C> {
},
archived,
draft,
})
};
if let Some(temporal_client) = temporal_client {
temporal_client
.start_update_entity_embeddings_workflow(
actor_id,
Entity {
properties,
link_data: previous_entity
.link_data
.map(|previous_link_data| LinkData {
left_entity_id: previous_link_data.left_entity_id,
right_entity_id: previous_link_data.right_entity_id,
order: link_order,
}),
metadata: entity_metadata.clone(),
},
)
.await
.change_context(UpdateError)?;
}
Ok(entity_metadata)
}

async fn update_entity_embeddings<A: AuthorizationApi + Send + Sync>(
&mut self,
actor_id: AccountId,
authorization_api: &mut A,
_actor_id: AccountId,
_authorization_api: &mut A,
embeddings: impl IntoIterator<Item = EntityEmbedding<'_>> + Send,
updated_at_transaction_time: Timestamp<TransactionTime>,
updated_at_decision_time: Timestamp<DecisionTime>,
Expand Down Expand Up @@ -1190,26 +1227,30 @@ impl<C: AsClient> EntityStore for PostgresStore<C> {
)
})
.unzip();
let permissions = authorization_api
.check_entities_permission(
actor_id,
EntityPermission::Update,
entity_ids.iter().copied(),
Consistency::FullyConsistent,
)
.await
.change_context(UpdateError)?
.0
.into_iter()
.filter_map(|(entity_id, has_permission)| (!has_permission).then_some(entity_id))
.collect::<Vec<_>>();
if !permissions.is_empty() {
let mut status = Report::new(PermissionAssertion);
for entity_id in permissions {
status = status.attach(format!("Permission denied for entity {entity_id}"));
}
return Err(status.change_context(UpdateError));
}

// TODO: Add permission to allow updating embeddings
// see https://linear.app/hash/issue/H-1870
// let permissions = authorization_api
// .check_entities_permission(
// actor_id,
// EntityPermission::UpdateEmbeddings,
// entity_ids.iter().copied(),
// Consistency::FullyConsistent,
// )
// .await
// .change_context(UpdateError)?
// .0
// .into_iter()
// .filter_map(|(entity_id, has_permission)| (!has_permission).then_some(entity_id))
// .collect::<Vec<_>>();
// if !permissions.is_empty() {
// let mut status = Report::new(PermissionAssertion);
// for entity_id in permissions {
// status = status.attach(format!("Permission denied for entity {entity_id}"));
// }
// return Err(status.change_context(UpdateError));
// }

if reset {
let (owned_by_id, entity_uuids): (Vec<_>, Vec<_>) = entity_ids
.into_iter()
Expand Down
64 changes: 64 additions & 0 deletions libs/@local/temporal-client/src/ai.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use std::collections::HashMap;

use error_stack::{Report, ResultExt};
use graph_types::{account::AccountId, knowledge::entity::Entity};
use serde::Serialize;
use temporal_io_client::{WorkflowClientTrait, WorkflowOptions};
use temporal_io_sdk_core_protos::{
temporal::api::common::v1::Payload, ENCODING_PAYLOAD_KEY, JSON_ENCODING_VAL,
};
use uuid::Uuid;

use crate::{TemporalClient, WorkflowError};

impl TemporalClient {
    /// Kicks off the `updateEntityEmbeddings` workflow on the `ai` task queue for `entity`.
    ///
    /// The workflow receives a single JSON-encoded payload containing the acting account
    /// (under `authentication`) and the entity itself. A freshly generated random UUID is
    /// used as the workflow ID.
    ///
    /// On success, the run ID taken from the start-workflow response is returned.
    ///
    /// # Errors
    ///
    /// Returns a [`WorkflowError`] if the workflow input cannot be serialized to JSON or if
    /// starting the workflow fails.
    pub async fn start_update_entity_embeddings_workflow(
        &self,
        actor_id: AccountId,
        entity: Entity,
    ) -> Result<String, Report<WorkflowError>> {
        // Name of the workflow; also used to label every error raised from this function.
        const WORKFLOW_TYPE: &str = "updateEntityEmbeddings";
        const TASK_QUEUE: &str = "ai";

        #[derive(Serialize)]
        #[serde(rename_all = "camelCase")]
        struct WorkflowAuthentication {
            actor_id: AccountId,
        }

        #[derive(Serialize)]
        #[serde(rename_all = "camelCase")]
        struct WorkflowInput {
            authentication: WorkflowAuthentication,
            entity: Entity,
        }

        // Mark the payload as JSON-encoded so it is decoded accordingly.
        let mut metadata = HashMap::new();
        metadata.insert(
            ENCODING_PAYLOAD_KEY.to_owned(),
            JSON_ENCODING_VAL.as_bytes().to_vec(),
        );

        let input = WorkflowInput {
            authentication: WorkflowAuthentication { actor_id },
            entity,
        };
        let data = serde_json::to_vec(&input).change_context(WorkflowError(WORKFLOW_TYPE))?;

        let payloads = vec![Payload { metadata, data }];
        let workflow_id = Uuid::new_v4().to_string();

        let response = self
            .client
            .start_workflow(
                payloads,
                TASK_QUEUE.to_owned(),
                workflow_id,
                WORKFLOW_TYPE.to_owned(),
                None,
                WorkflowOptions::default(),
            )
            .await
            .change_context(WorkflowError(WORKFLOW_TYPE))?;

        Ok(response.run_id)
    }
}
4 changes: 4 additions & 0 deletions libs/@local/temporal-client/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ pub struct ConfigError;
#[derive(Debug, Error)]
#[error("Could not connect to Temporal.io Server")]
pub struct ConnectionError;

/// Error reported when a workflow job fails.
///
/// The wrapped string is the job name interpolated into the error message.
#[derive(Debug, Error)]
#[error("Workflow execution of job {0} failed")]
pub struct WorkflowError(pub &'static str);
Loading

0 comments on commit 9d8b901

Please sign in to comment.