fix
KeXiangWang committed Jan 25, 2025
1 parent 7c4a688 commit b975e81
Showing 7 changed files with 72 additions and 101 deletions.
19 changes: 0 additions & 19 deletions proto/batch_plan.proto
@@ -195,25 +195,6 @@ message InsertNode {
uint32 session_id = 7;
}

// A special insert node for non-pgwire insert, not really a batch node.
message FastInsertNode {
// Id of the table to perform inserting.
uint32 table_id = 1;
// Version of the table.
uint64 table_version_id = 2;
repeated uint32 column_indices = 3;
data.DataChunk data_chunk = 4;

// An optional field and will be `None` for tables without user-defined pk.
// The `BatchInsertExecutor` should add a column with NULL value which will
// be filled in streaming.
optional uint32 row_id_index = 5;

// Session id is used to ensure that dml data from the same session should be
// sent to a fixed worker node and channel.
uint32 session_id = 6;
}

message DeleteNode {
// Id of the table to perform deleting.
uint32 table_id = 1;
17 changes: 15 additions & 2 deletions proto/task_service.proto
@@ -61,8 +61,21 @@ message GetDataResponse {
}

message FastInsertRequest {
bool wait_for_persistence = 1;
batch_plan.FastInsertNode fast_insert_node = 2;
// Id of the table to insert into.
uint32 table_id = 1;
// Version of the table.
uint64 table_version_id = 2;
repeated uint32 column_indices = 3;
data.DataChunk data_chunk = 4;

// An optional field; it will be `None` for tables without a user-defined pk,
// in which case the `BatchInsertExecutor` should add a NULL-valued column to
// be filled in during streaming.
optional uint32 row_id_index = 5;

// The session id is used to route inserts from the same session to a fixed
// worker node and channel.
uint32 session_id = 6;
bool wait_for_persistence = 7;
// TODO(kexiang): add support for default columns; a plan_common.ExprContext
// expr_context field will be needed for that.
}

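With the node message folded into the request, callers now populate FastInsertRequest directly instead of wrapping a separate FastInsertNode. A minimal sketch of constructing the flattened message, assuming the prost-generated Rust types for these protos (the helper and its field values are illustrative only, mirroring the webhook usage later in this commit):

use risingwave_pb::data::DataChunk as PbDataChunk;
use risingwave_pb::task_service::FastInsertRequest;

// Hypothetical helper; not part of this commit.
fn build_fast_insert_request(
    table_id: u32,
    table_version_id: u64,
    chunk: PbDataChunk,
    session_id: u32,
) -> FastInsertRequest {
    FastInsertRequest {
        table_id,
        table_version_id,
        column_indices: vec![0],   // insert into the first column only
        data_chunk: Some(chunk),
        row_id_index: Some(1),     // the executor appends the row_id column
        session_id,
        wait_for_persistence: true,
    }
}
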
16 changes: 8 additions & 8 deletions src/batch/src/executor/fast_insert.rs
@@ -22,7 +22,7 @@ use risingwave_common::transaction::transaction_id::TxnId;
use risingwave_common::types::DataType;
use risingwave_common::util::epoch::{Epoch, INVALID_EPOCH};
use risingwave_dml::dml_manager::DmlManagerRef;
use risingwave_pb::batch_plan::FastInsertNode;
use risingwave_pb::task_service::FastInsertRequest;

use crate::error::Result;

@@ -42,28 +42,28 @@ pub struct FastInsertExecutor {
impl FastInsertExecutor {
pub fn build(
dml_manager: DmlManagerRef,
insert_node: FastInsertNode,
insert_req: FastInsertRequest,
) -> Result<(FastInsertExecutor, DataChunk)> {
let table_id = TableId::new(insert_node.table_id);
let column_indices = insert_node
let table_id = TableId::new(insert_req.table_id);
let column_indices = insert_req
.column_indices
.iter()
.map(|&i| i as usize)
.collect();
let mut schema = Schema::new(vec![Field::unnamed(DataType::Jsonb)]);
schema.fields.push(Field::unnamed(DataType::Serial)); // row_id column
let data_chunk_pb = insert_node
let data_chunk_pb = insert_req
.data_chunk
.expect("no data_chunk found in fast insert node");

Ok((
FastInsertExecutor::new(
table_id,
insert_node.table_version_id,
insert_req.table_version_id,
dml_manager,
column_indices,
insert_node.row_id_index.as_ref().map(|index| *index as _),
insert_node.session_id,
insert_req.row_id_index.as_ref().map(|index| *index as _),
insert_req.session_id,
),
DataChunk::from_protobuf(&data_chunk_pb)?,
))
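
The executor is now built straight from the request. A rough usage sketch on the compute-node side, assuming `dml_manager: DmlManagerRef` and `req: FastInsertRequest` are in scope (error handling elided):

// Read the flag before `req` is moved into the executor.
let wait_for_persistence = req.wait_for_persistence;
let (executor, data_chunk) = FastInsertExecutor::build(dml_manager, req)?;
let epoch = executor.do_execute(data_chunk, wait_for_persistence).await?;
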
18 changes: 6 additions & 12 deletions src/batch/src/rpc/service/task_service.rs
@@ -15,7 +15,7 @@
use std::sync::Arc;

use risingwave_common::util::tracing::TracingContext;
use risingwave_pb::batch_plan::{FastInsertNode, TaskOutputId};
use risingwave_pb::batch_plan::TaskOutputId;
use risingwave_pb::task_service::task_service_server::TaskService;
use risingwave_pb::task_service::{
fast_insert_response, CancelTaskRequest, CancelTaskResponse, CreateTaskRequest, ExecuteRequest,
@@ -128,10 +128,7 @@ impl TaskService for BatchServiceImpl {
request: Request<FastInsertRequest>,
) -> Result<Response<FastInsertResponse>, Status> {
let req = request.into_inner();
let insert_node = req.fast_insert_node.expect("no fast insert node found");
let res = self
.do_fast_insert(insert_node, req.wait_for_persistence)
.await;
let res = self.do_fast_insert(req).await;
match res {
Ok(_) => Ok(Response::new(FastInsertResponse {
status: fast_insert_response::Status::Succeeded.into(),
@@ -217,14 +214,11 @@ impl BatchServiceImpl {
Ok(Response::new(ReceiverStream::new(rx)))
}

async fn do_fast_insert(
&self,
insert_node: FastInsertNode,
wait_for_persistence: bool,
) -> Result<(), BatchError> {
let table_id = insert_node.table_id;
async fn do_fast_insert(&self, insert_req: FastInsertRequest) -> Result<(), BatchError> {
let table_id = insert_req.table_id;
let wait_for_persistence = insert_req.wait_for_persistence;
let (executor, data_chunk) =
FastInsertExecutor::build(self.env.dml_manager_ref(), insert_node)?;
FastInsertExecutor::build(self.env.dml_manager_ref(), insert_req)?;
let epoch = executor
.do_execute(data_chunk, wait_for_persistence)
.await?;
26 changes: 13 additions & 13 deletions src/frontend/src/scheduler/fast_insert.rs
@@ -12,12 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

//! Local execution for batch query.
use std::sync::Arc;

use anyhow::anyhow;
use itertools::Itertools;
use pgwire::pg_server::Session;
use risingwave_batch::error::BatchError;
use risingwave_batch::worker_manager::worker_node_manager::WorkerNodeSelector;
use risingwave_common::hash::WorkerSlotMapping;
@@ -26,15 +22,15 @@ use risingwave_rpc_client::ComputeClient;

use crate::catalog::TableId;
use crate::scheduler::{SchedulerError, SchedulerResult};
use crate::session::{FrontendEnv, SessionImpl};
use crate::session::FrontendEnv;

pub async fn choose_fast_insert_client(
table_id: &TableId,
// wait_for_persistence: bool,
session: &Arc<SessionImpl>,
frontend_env: &FrontendEnv,
session_id: i32,
) -> SchedulerResult<ComputeClient> {
let worker = choose_worker(table_id, session)?;
let client = session.env().client_pool().get(&worker).await?;
let worker = choose_worker(table_id, frontend_env, session_id)?;
let client = frontend_env.client_pool().get(&worker).await?;
Ok(client)
}

Expand All @@ -61,13 +57,17 @@ fn get_table_dml_vnode_mapping(
.map_err(|e| e.into())
}

fn choose_worker(table_id: &TableId, session: &Arc<SessionImpl>) -> SchedulerResult<WorkerNode> {
fn choose_worker(
table_id: &TableId,
frontend_env: &FrontendEnv,
session_id: i32,
) -> SchedulerResult<WorkerNode> {
let worker_node_manager =
WorkerNodeSelector::new(session.env().worker_node_manager_ref(), false);
let session_id: u32 = session.id().0 as u32;
WorkerNodeSelector::new(frontend_env.worker_node_manager_ref(), false);
let session_id: u32 = session_id as u32;

// dml should use streaming vnode mapping
let vnode_mapping = get_table_dml_vnode_mapping(table_id, session.env(), &worker_node_manager)?;
let vnode_mapping = get_table_dml_vnode_mapping(table_id, frontend_env, &worker_node_manager)?;
let worker_node = {
let worker_ids = vnode_mapping.iter_unique().collect_vec();
let candidates = worker_node_manager
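
The scheduler thus no longer needs a full SessionImpl: a FrontendEnv handle plus a bare session id is enough to pick a worker and open a client. A sketch of the call site under those assumptions:

// `table_id: TableId`, `frontend_env: &FrontendEnv`, `session_id: i32`.
let client = choose_fast_insert_client(&table_id, frontend_env, session_id).await?;
let response = client.fast_insert(request).await?; // `request: FastInsertRequest`
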
6 changes: 5 additions & 1 deletion src/frontend/src/session.rs
@@ -1459,7 +1459,7 @@ impl SessionManagerImpl {
};

// Assign a session id and insert into sessions map (for cancel request).
let secret_key = self.number.fetch_add(1, Ordering::Relaxed);
let secret_key = self.generate_secret_key();
// Use a trivial strategy: process_id and secret_key are equal.
let id = (secret_key, secret_key);
// Read session params snapshot from frontend env.
@@ -1484,6 +1484,10 @@
)))
}
}

pub fn generate_secret_key(&self) -> i32 {
self.number.fetch_add(1, Ordering::Relaxed)
}
}

impl Session for SessionImpl {
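
Exposing generate_secret_key lets the webhook handler below obtain a routing id without constructing a dummy session. A sketch of the intended use:

let session_mgr = SESSION_MANAGER
    .get()
    .expect("session manager has been initialized");
// An i32 counter value, used here only to pin DML to a worker node and channel.
let session_id = session_mgr.generate_secret_key();
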
71 changes: 25 additions & 46 deletions src/frontend/src/webhook/mod.rs
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::net::{IpAddr, SocketAddr};
use std::net::SocketAddr;
use std::sync::Arc;

use anyhow::{anyhow, Context};
@@ -21,12 +21,9 @@ use axum::extract::{Extension, Path};
use axum::http::{HeaderMap, Method, StatusCode};
use axum::routing::post;
use axum::Router;
use pgwire::net::Address;
use pgwire::pg_server::SessionManager;
use risingwave_common::array::{Array, ArrayBuilder, DataChunk};
use risingwave_common::secret::LocalSecretManager;
use risingwave_common::types::{DataType, JsonbVal, Scalar};
use risingwave_pb::batch_plan::FastInsertNode;
use risingwave_pb::catalog::WebhookSourceInfo;
use risingwave_pb::task_service::{FastInsertRequest, FastInsertResponse};
use tokio::net::TcpListener;
@@ -47,7 +44,7 @@ const USER: &str = "root";
#[derive(Clone)]
pub struct FastInsertContext {
pub webhook_source_info: WebhookSourceInfo,
pub fast_insert_node: FastInsertNode,
pub fast_insert_request: FastInsertRequest,
pub compute_client: ComputeClient,
}

@@ -57,58 +54,43 @@ pub struct WebhookService {
}

pub(super) mod handlers {
use std::net::Ipv4Addr;

use jsonbb::Value;
use pgwire::pg_server::Session;
use risingwave_common::array::JsonbArrayBuilder;
use risingwave_pb::batch_plan::FastInsertNode;
use risingwave_common::session_config::SearchPath;
use risingwave_pb::catalog::WebhookSourceInfo;
use risingwave_pb::task_service::fast_insert_response;
use utils::{header_map_to_json, verify_signature};

use super::*;
use crate::catalog::root_catalog::SchemaPath;
use crate::scheduler::choose_fast_insert_client;
use crate::session::{SessionImpl, SESSION_MANAGER};
use crate::session::{FrontendEnv, SESSION_MANAGER};

pub async fn handle_post_request(
Extension(_srv): Extension<Service>,
headers: HeaderMap,
Path((database, schema, table)): Path<(String, String, String)>,
body: Bytes,
) -> Result<()> {
// Can be any address; we use the meta port to indicate that it's an internal request.
let dummy_addr = Address::Tcp(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 5691));

// FIXME(kexiang): the dummy_session can lead to memory leakage
// TODO(kexiang): optimize this
// get a session object for the corresponding database
let session_mgr = SESSION_MANAGER
.get()
.expect("session manager has been initialized");
let session = session_mgr
.connect(database.as_str(), USER, Arc::new(dummy_addr))
.map_err(|e| {
err(
anyhow!(e).context(format!(
"Failed to create session for database `{}` with user `{}`",
database, USER
)),
StatusCode::UNAUTHORIZED,
)
})?;

let frontend_env = session_mgr.env().clone();
// FIXME(kexiang): the session_id is an i32, so overflow is possible
let session_id = session_mgr.generate_secret_key();

let FastInsertContext {
webhook_source_info,
mut fast_insert_node,
mut fast_insert_request,
compute_client,
} = acquire_table_info(&session, &database, &schema, &table).await?;
} = acquire_table_info(&frontend_env, session_id, &database, &schema, &table).await?;

let WebhookSourceInfo {
signature_expr,
secret_ref,
wait_for_persistence,
wait_for_persistence: _,
} = webhook_source_info;

let secret_string = LocalSecretManager::global()
@@ -147,10 +129,9 @@ pub(super) mod handlers {
let data_chunk = DataChunk::new(vec![builder.finish().into_ref()], 1);

// fill the data_chunk
fast_insert_node.data_chunk = Some(data_chunk.to_protobuf());

fast_insert_request.data_chunk = Some(data_chunk.to_protobuf());
// execute on the compute node
let res = execute(fast_insert_node, wait_for_persistence, compute_client).await?;
let res = execute(fast_insert_request, compute_client).await?;

if res.status == fast_insert_response::Status::Succeeded as i32 {
Ok(())
@@ -163,16 +144,17 @@
}

async fn acquire_table_info(
session: &Arc<SessionImpl>,
frontend_env: &FrontendEnv,
session_id: i32,
database: &String,
schema: &String,
table: &String,
) -> Result<FastInsertContext> {
let search_path = session.config().search_path();
let search_path = SearchPath::default();
let schema_path = SchemaPath::new(Some(schema.as_str()), &search_path, USER);

let (webhook_source_info, table_id, version_id) = {
let reader = session.env().catalog_reader().read_guard();
let reader = frontend_env.catalog_reader().read_guard();
let (table_catalog, _schema) = reader
.get_any_table_by_name(database.as_str(), schema_path, table)
.map_err(|e| err(e, StatusCode::NOT_FOUND))?;
@@ -194,35 +176,32 @@
)
};

let fast_insert_node = FastInsertNode {
let fast_insert_request = FastInsertRequest {
table_id: table_id.table_id,
table_version_id: version_id,
column_indices: vec![0],
// leave the data_chunk empty for now
data_chunk: None,
row_id_index: Some(1),
session_id: session.id().0 as u32,
session_id: session_id as u32,
wait_for_persistence: webhook_source_info.wait_for_persistence,
};

let compute_client = choose_fast_insert_client(&table_id, session).await.unwrap();
let compute_client = choose_fast_insert_client(&table_id, &frontend_env, session_id)
.await
.unwrap();

Ok(FastInsertContext {
webhook_source_info,
fast_insert_node,
fast_insert_request,
compute_client,
})
}

async fn execute(
fast_insert_node: FastInsertNode,
wait_for_persistence: bool,
request: FastInsertRequest,
client: ComputeClient,
) -> Result<FastInsertResponse> {
let request = FastInsertRequest {
fast_insert_node: Some(fast_insert_node),
wait_for_persistence,
};

let response = client.fast_insert(request).await.map_err(|e| {
err(
anyhow!(e).context("Failed to execute on compute node"),
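
Putting the pieces together, the reworked webhook path is roughly the following (a condensed sketch of the handler above, with signature verification and error mapping omitted):

let frontend_env = session_mgr.env().clone();
let session_id = session_mgr.generate_secret_key();
let FastInsertContext { mut fast_insert_request, compute_client, .. } =
    acquire_table_info(&frontend_env, session_id, &database, &schema, &table).await?;
fast_insert_request.data_chunk = Some(data_chunk.to_protobuf());
let res = execute(fast_insert_request, compute_client).await?;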
