From d7e4a3976a2a5c08e05125c3f9dbe7fb860eecaf Mon Sep 17 00:00:00 2001 From: xxhZs <1060434431@qq.com> Date: Thu, 14 Mar 2024 18:25:58 +0800 Subject: [PATCH 1/9] save --- Cargo.lock | 113 +++++++++++++++++++++++++--- src/connector/Cargo.toml | 3 + src/connector/src/sink/big_query.rs | 102 ++++++++++++++++++++----- 3 files changed, 190 insertions(+), 28 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c0e821669fda0..ef73103f3e020 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -273,7 +273,26 @@ dependencies = [ "arrow-row 48.0.1", "arrow-schema 48.0.1", "arrow-select 48.0.1", - "arrow-string", + "arrow-string 48.0.1", +] + +[[package]] +name = "arrow" +version = "50.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa285343fba4d829d49985bdc541e3789cf6000ed0e84be7c039438df4a4e78c" +dependencies = [ + "arrow-arith 50.0.0", + "arrow-array 50.0.0", + "arrow-buffer 50.0.0", + "arrow-cast 50.0.0", + "arrow-data 50.0.0", + "arrow-ipc 50.0.0", + "arrow-ord 50.0.0", + "arrow-row 50.0.0", + "arrow-schema 50.0.0", + "arrow-select 50.0.0", + "arrow-string 50.0.0", ] [[package]] @@ -627,6 +646,22 @@ dependencies = [ "regex-syntax 0.8.2", ] +[[package]] +name = "arrow-string" +version = "50.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00f3b37f2aeece31a2636d1b037dabb69ef590e03bdc7eb68519b51ec86932a7" +dependencies = [ + "arrow-array 50.0.0", + "arrow-buffer 50.0.0", + "arrow-data 50.0.0", + "arrow-schema 50.0.0", + "arrow-select 50.0.0", + "num", + "regex", + "regex-syntax 0.8.2", +] + [[package]] name = "arrow-udf-js" version = "0.1.2" @@ -1509,6 +1544,7 @@ dependencies = [ "num-bigint", "num-integer", "num-traits", + "serde", ] [[package]] @@ -2969,7 +3005,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "676796427e638d85e9eadf13765705212be60b34f8fc5d3934d95184c63ca1b4" dependencies = [ "ahash 0.8.6", - "arrow", + "arrow 48.0.1", "arrow-array 48.0.1", "arrow-schema 48.0.1", "async-compression", @@ -3016,7 +3052,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "31e23b3d21a6531259d291bd20ce59282ea794bda1018b0a1e278c13cd52e50c" dependencies = [ "ahash 0.8.6", - "arrow", + "arrow 48.0.1", "arrow-array 48.0.1", "arrow-buffer 48.0.1", "arrow-schema 48.0.1", @@ -3034,7 +3070,7 @@ version = "33.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4de1fd0d8db0f2b8e4f4121bfa1c7c09d3a5c08a0a65c2229cd849eb65cff855" dependencies = [ - "arrow", + "arrow 48.0.1", "chrono", "dashmap", "datafusion-common", @@ -3056,7 +3092,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "18e227fe88bf6730cab378d0cd8fc4c6b2ea42bc7e414a8ea9feba7225932735" dependencies = [ "ahash 0.8.6", - "arrow", + "arrow 48.0.1", "arrow-array 48.0.1", "datafusion-common", "sqlparser", @@ -3070,7 +3106,7 @@ version = "33.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c6648e62ea7605b9bfcd87fdc9d67e579c3b9ac563a87734ae5fe6d79ee4547" dependencies = [ - "arrow", + "arrow 48.0.1", "async-trait", "chrono", "datafusion-common", @@ -3089,7 +3125,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f32b8574add16a32411a9b3fb3844ac1fc09ab4e7be289f86fd56d620e4f2508" dependencies = [ "ahash 0.8.6", - "arrow", + "arrow 48.0.1", "arrow-array 48.0.1", "arrow-buffer 48.0.1", "arrow-ord 48.0.1", @@ -3124,7 +3160,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"796abd77d5bfecd9e5275a99daf0ec45f5b3a793ec431349ce8211a67826fd22" dependencies = [ "ahash 0.8.6", - "arrow", + "arrow 48.0.1", "arrow-array 48.0.1", "arrow-buffer 48.0.1", "arrow-schema 48.0.1", @@ -3154,7 +3190,7 @@ version = "33.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26de2592417beb20f73f29b131a04d7de14e2a6336c631554d611584b4306236" dependencies = [ - "arrow", + "arrow 48.0.1", "chrono", "datafusion", "datafusion-common", @@ -3169,7 +3205,7 @@ version = "33.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced70b8a5648ba7b95c61fc512183c33287ffe2c9f22ffe22700619d7d48c84f" dependencies = [ - "arrow", + "arrow 48.0.1", "arrow-schema 48.0.1", "datafusion-common", "datafusion-expr", @@ -3225,7 +3261,7 @@ name = "deltalake-core" version = "0.17.0" source = "git+https://github.com/risingwavelabs/delta-rs?rev=5c2dccd4640490202ffe98adbd13b09cef8e007b#5c2dccd4640490202ffe98adbd13b09cef8e007b" dependencies = [ - "arrow", + "arrow 48.0.1", "arrow-array 48.0.1", "arrow-buffer 48.0.1", "arrow-cast 48.0.1", @@ -4664,6 +4700,33 @@ dependencies = [ "urlencoding", ] +[[package]] +name = "google-cloud-bigquery" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c48abc8687f4c4cc143dd5bd3da5f1d7ef38334e4af5cef6de4c39295c6a3fd0" +dependencies = [ + "anyhow", + "arrow 50.0.0", + "async-trait", + "backon", + "base64 0.21.7", + "bigdecimal 0.4.2", + "google-cloud-auth", + "google-cloud-gax", + "google-cloud-googleapis", + "google-cloud-token", + "num-bigint", + "reqwest", + "reqwest-middleware", + "serde", + "serde_json", + "thiserror", + "time", + "tokio", + "tracing", +] + [[package]] name = "google-cloud-gax" version = "0.17.0" @@ -8534,6 +8597,7 @@ dependencies = [ "js-sys", "log", "mime", + "mime_guess", "native-tls", "once_cell", "percent-encoding", @@ -8558,6 +8622,21 @@ dependencies = [ "winreg", ] +[[package]] +name = "reqwest-middleware" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88a3e86aa6053e59030e7ce2d2a3b258dd08fc2d337d52f73f6cb480f5858690" +dependencies = [ + "anyhow", + "async-trait", + "http 0.2.9", + "reqwest", + "serde", + "task-local-extensions", + "thiserror", +] + [[package]] name = "resolv-conf" version = "0.7.0" @@ -9149,6 +9228,9 @@ dependencies = [ "futures-async-stream", "gcp-bigquery-client", "glob", + "google-cloud-bigquery", + "google-cloud-gax", + "google-cloud-googleapis", "google-cloud-pubsub", "http 0.2.9", "hyper", @@ -12025,6 +12107,15 @@ version = "0.12.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "14c39fd04924ca3a864207c66fc2cd7d22d7c016007f9ce846cbb9326331930a" +[[package]] +name = "task-local-extensions" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba323866e5d033818e3240feeb9f7db2c4296674e4d9e16b97b7bf8f490434e8" +dependencies = [ + "pin-utils", +] + [[package]] name = "task_stats_alloc" version = "0.1.11" diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index e18b5d2bb2c8f..dcccbe4ac4edd 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -114,6 +114,9 @@ redis = { version = "0.24.0", features = [ regex = "1.4" reqwest = { version = "0.11", features = ["json"] } risingwave_common = { workspace = true } +google-cloud-bigquery = { version = "0.7.0", features = ["auth"] } +google-cloud-gax = "0.17.0" +google-cloud-googleapis = "0.12.0" risingwave_jni_core = { workspace = true } 
risingwave_pb = { workspace = true } risingwave_rpc_client = { workspace = true } diff --git a/src/connector/src/sink/big_query.rs b/src/connector/src/sink/big_query.rs index 42918a6b72dfe..70f8f842664f0 100644 --- a/src/connector/src/sink/big_query.rs +++ b/src/connector/src/sink/big_query.rs @@ -13,6 +13,7 @@ // limitations under the License. use core::mem; +use core::time::Duration; use std::collections::HashMap; use std::sync::Arc; @@ -22,6 +23,14 @@ use gcp_bigquery_client::model::query_request::QueryRequest; use gcp_bigquery_client::model::table_data_insert_all_request::TableDataInsertAllRequest; use gcp_bigquery_client::model::table_data_insert_all_request_rows::TableDataInsertAllRequestRows; use gcp_bigquery_client::Client; +use google_cloud_gax::grpc::{Request}; +use google_cloud_googleapis::cloud::bigquery::storage::v1::append_rows_request::Rows as AppendRowsRequestRows; +use google_cloud_gax::conn::{ConnectionOptions, Environment}; +use google_cloud_bigquery::grpc::apiv1::bigquery_client::StreamingWriteClient; +use google_cloud_bigquery::grpc::apiv1::conn_pool::{WriteConnectionManager, DOMAIN}; +use google_cloud_googleapis::cloud::bigquery::storage::v1::AppendRowsRequest; +use google_cloud_pubsub::client::google_cloud_auth; +use google_cloud_pubsub::client::google_cloud_auth::credentials::CredentialsFile; use risingwave_common::array::{Op, StreamChunk}; use risingwave_common::buffer::Bitmap; use risingwave_common::catalog::Schema; @@ -30,6 +39,7 @@ use serde_derive::Deserialize; use serde_json::Value; use serde_with::{serde_as, DisplayFromStr}; use url::Url; +use uuid::Uuid; use with_options::WithOptions; use yup_oauth2::ServiceAccountKey; @@ -44,6 +54,7 @@ use crate::sink::{ }; pub const BIGQUERY_SINK: &str = "bigquery"; +const DEFAULT_GRPC_CHANNEL_NUMS: usize = 4; #[serde_as] #[derive(Deserialize, Debug, Clone, WithOptions)] @@ -69,27 +80,40 @@ fn default_max_batch_rows() -> usize { impl BigQueryCommon { pub(crate) async fn build_client(&self, aws_auth_props: &AwsAuthProps) -> Result { - let service_account = if let Some(local_path) = &self.local_path { - let auth_json = std::fs::read_to_string(local_path) + let auth_json = self.get_auth_json_from_path(aws_auth_props).await?; + + let service_account = serde_json::from_str::(&auth_json) .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; - serde_json::from_str::(&auth_json) - .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))? - } else if let Some(s3_path) = &self.s3_path { - let url = - Url::parse(s3_path).map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; - let auth_json = load_file_descriptor_from_s3(&url, aws_auth_props) - .await - .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; - serde_json::from_slice::(&auth_json) - .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))? 
- } else { - return Err(SinkError::BigQuery(anyhow::anyhow!("`bigquery.local.path` and `bigquery.s3.path` set at least one, configure as needed."))); - }; let client: Client = Client::from_service_account_key(service_account, false) .await .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; Ok(client) } + + pub(crate) async fn build_writer_client(&self, aws_auth_props: &AwsAuthProps) -> Result { + let auth_json = self.get_auth_json_from_path(aws_auth_props).await?; + + let credentials_file= CredentialsFile::new_from_str(&auth_json).await.unwrap(); + let client = StorageWriterClient::new(credentials_file).await?; + Ok(client) + } + + async fn get_auth_json_from_path(&self, aws_auth_props: &AwsAuthProps) -> Result{ + if let Some(local_path) = &self.local_path{ + std::fs::read_to_string(local_path) + .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err))) + }else if let Some(s3_path) = &self.s3_path { + let url = + Url::parse(s3_path).map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; + let auth_vec = load_file_descriptor_from_s3(&url, aws_auth_props) + .await + .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; + Ok(String::from_utf8(auth_vec).unwrap()) + }else{ + Err(SinkError::BigQuery(anyhow::anyhow!("`bigquery.local.path` and `bigquery.s3.path` set at least one, configure as needed."))) + } + } + } #[serde_as] @@ -280,7 +304,7 @@ pub struct BigQuerySinkWriter { pub config: BigQueryConfig, schema: Schema, pk_indices: Vec, - client: Client, + client: StorageWriterClient, is_append_only: bool, insert_request: TableDataInsertAllRequest, row_encoder: JsonEncoder, @@ -308,7 +332,7 @@ impl BigQuerySinkWriter { pk_indices: Vec, is_append_only: bool, ) -> Result { - let client = config.common.build_client(&config.aws_auth_props).await?; + let client = config.common.build_writer_client(&config.aws_auth_props).await?; Ok(Self { config, schema: schema.clone(), @@ -401,6 +425,50 @@ impl SinkWriter for BigQuerySinkWriter { } } +struct StorageWriterClient{ + client: StreamingWriteClient, + environment: Environment, +} +impl StorageWriterClient{ + pub async fn new(credentials: CredentialsFile) -> Result{ + // let credentials = CredentialsFile::new_from_file("/home/xxhx/winter-dynamics-383822-9690ac19ce78.json".to_string()).await.unwrap(); + let ts_grpc = google_cloud_auth::token::DefaultTokenSourceProvider::new_with_credentials( + Self::bigquery_grpc_auth_config(), + Box::new(credentials), + ) + .await.unwrap(); + let conn_options = ConnectionOptions{ + connect_timeout: Some(Duration::from_secs(30)), + timeout: None, + }; + let environment = Environment::GoogleCloud(Box::new(ts_grpc)); + let conn = WriteConnectionManager::new(DEFAULT_GRPC_CHANNEL_NUMS, &environment, DOMAIN, &conn_options).await.unwrap(); + let client = conn.conn(); + Ok(StorageWriterClient{ + client, + environment, + }) + } + pub async fn append_rows( + &mut self, + rows: Vec, + write_stream: String, + ) -> Result<()> { + let trace_id = Uuid::new_v4().hyphenated().to_string(); + let append_req:Vec = rows.into_iter().map(|row| AppendRowsRequest{ write_stream: write_stream.clone(), offset:None, trace_id: trace_id.clone(), missing_value_interpretations: HashMap::default(), rows: Some(row)}).collect(); + let a = self.client.append_rows(Request::new(tokio_stream::iter(append_req))).await; + Ok(()) + } + + fn bigquery_grpc_auth_config() -> google_cloud_auth::project::Config<'static> { + google_cloud_auth::project::Config { + audience: Some(google_cloud_bigquery::grpc::apiv1::conn_pool::AUDIENCE), + scopes: 
Some(&google_cloud_bigquery::grpc::apiv1::conn_pool::SCOPES), + sub: None, + } + } +} + #[cfg(test)] mod test { use risingwave_common::types::{DataType, StructType}; From af66328c205cd9f6454d7bbdf3ce71d4053fd5f0 Mon Sep 17 00:00:00 2001 From: xxhZs <1060434431@qq.com> Date: Tue, 19 Mar 2024 15:39:44 +0800 Subject: [PATCH 2/9] support bigquery cdc --- Cargo.lock | 2 + src/connector/Cargo.toml | 6 +- src/connector/src/lib.rs | 1 + src/connector/src/sink/big_query.rs | 435 ++++++++++++++++++------ src/connector/src/sink/encoder/json.rs | 28 +- src/connector/src/sink/encoder/mod.rs | 6 +- src/connector/src/sink/encoder/proto.rs | 160 +++++++-- src/connector/src/sink/formatter/mod.rs | 11 +- src/connector/with_options_sink.yaml | 4 - src/workspace-hack/Cargo.toml | 7 +- 10 files changed, 496 insertions(+), 164 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ef73103f3e020..a7295685d3685 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -14180,6 +14180,7 @@ dependencies = [ "futures-util", "generic-array", "getrandom", + "google-cloud-googleapis", "governor", "hashbrown 0.13.2", "hashbrown 0.14.3", @@ -14203,6 +14204,7 @@ dependencies = [ "madsim-tokio", "md-5", "memchr", + "mime_guess", "mio", "moka", "nom", diff --git a/src/connector/Cargo.toml b/src/connector/Cargo.toml index dcccbe4ac4edd..533407a63fa51 100644 --- a/src/connector/Cargo.toml +++ b/src/connector/Cargo.toml @@ -57,6 +57,9 @@ futures = { version = "0.3", default-features = false, features = ["alloc"] } futures-async-stream = { workspace = true } gcp-bigquery-client = "0.18.0" glob = "0.3" +google-cloud-bigquery = { version = "0.7.0", features = ["auth"] } +google-cloud-gax = "0.17.0" +google-cloud-googleapis = "0.12.0" google-cloud-pubsub = "0.23" http = "0.2" hyper = { version = "0.14", features = [ @@ -114,9 +117,6 @@ redis = { version = "0.24.0", features = [ regex = "1.4" reqwest = { version = "0.11", features = ["json"] } risingwave_common = { workspace = true } -google-cloud-bigquery = { version = "0.7.0", features = ["auth"] } -google-cloud-gax = "0.17.0" -google-cloud-googleapis = "0.12.0" risingwave_jni_core = { workspace = true } risingwave_pb = { workspace = true } risingwave_rpc_client = { workspace = true } diff --git a/src/connector/src/lib.rs b/src/connector/src/lib.rs index fb38f2db00c4f..a530cd681951d 100644 --- a/src/connector/src/lib.rs +++ b/src/connector/src/lib.rs @@ -33,6 +33,7 @@ #![feature(try_blocks)] #![feature(error_generic_member_access)] #![feature(register_tool)] +#![feature(assert_matches)] #![register_tool(rw)] #![recursion_limit = "256"] diff --git a/src/connector/src/sink/big_query.rs b/src/connector/src/sink/big_query.rs index 70f8f842664f0..96975c9967538 100644 --- a/src/connector/src/sink/big_query.rs +++ b/src/connector/src/sink/big_query.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use core::mem; use core::time::Duration; use std::collections::HashMap; use std::sync::Arc; @@ -20,30 +19,36 @@ use std::sync::Arc; use anyhow::anyhow; use async_trait::async_trait; use gcp_bigquery_client::model::query_request::QueryRequest; -use gcp_bigquery_client::model::table_data_insert_all_request::TableDataInsertAllRequest; -use gcp_bigquery_client::model::table_data_insert_all_request_rows::TableDataInsertAllRequestRows; use gcp_bigquery_client::Client; -use google_cloud_gax::grpc::{Request}; -use google_cloud_googleapis::cloud::bigquery::storage::v1::append_rows_request::Rows as AppendRowsRequestRows; -use google_cloud_gax::conn::{ConnectionOptions, Environment}; use google_cloud_bigquery::grpc::apiv1::bigquery_client::StreamingWriteClient; use google_cloud_bigquery::grpc::apiv1::conn_pool::{WriteConnectionManager, DOMAIN}; -use google_cloud_googleapis::cloud::bigquery::storage::v1::AppendRowsRequest; +use google_cloud_gax::conn::{ConnectionOptions, Environment}; +use google_cloud_gax::grpc::Request; +use google_cloud_googleapis::cloud::bigquery::storage::v1::append_rows_request::{ + ProtoData, Rows as AppendRowsRequestRows, +}; +use google_cloud_googleapis::cloud::bigquery::storage::v1::{ + AppendRowsRequest, ProtoRows, ProtoSchema, +}; use google_cloud_pubsub::client::google_cloud_auth; use google_cloud_pubsub::client::google_cloud_auth::credentials::CredentialsFile; +use prost_reflect::MessageDescriptor; +use prost_types::{ + field_descriptor_proto, DescriptorProto, FieldDescriptorProto, FileDescriptorProto, + FileDescriptorSet, +}; use risingwave_common::array::{Op, StreamChunk}; use risingwave_common::buffer::Bitmap; use risingwave_common::catalog::Schema; use risingwave_common::types::DataType; use serde_derive::Deserialize; -use serde_json::Value; -use serde_with::{serde_as, DisplayFromStr}; +use serde_with::serde_as; use url::Url; use uuid::Uuid; use with_options::WithOptions; use yup_oauth2::ServiceAccountKey; -use super::encoder::{JsonEncoder, RowEncoder}; +use super::encoder::{CustomProtoType, ProtoEncoder, ProtoHeader, RowEncoder, SerTo}; use super::writer::LogSinkerOf; use super::{SinkError, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT}; use crate::aws_utils::load_file_descriptor_from_s3; @@ -54,7 +59,10 @@ use crate::sink::{ }; pub const BIGQUERY_SINK: &str = "bigquery"; +pub const CHANGE_TYPE: &str = "_CHANGE_TYPE"; const DEFAULT_GRPC_CHANNEL_NUMS: usize = 4; +const CONNECT_TIMEOUT: Option = Some(Duration::from_secs(30)); +const CONNECTION_TIMEOUT: Option = None; #[serde_as] #[derive(Deserialize, Debug, Clone, WithOptions)] @@ -69,9 +77,6 @@ pub struct BigQueryCommon { pub dataset: String, #[serde(rename = "bigquery.table")] pub table: String, - #[serde(rename = "bigquery.max_batch_rows", default = "default_max_batch_rows")] - #[serde_as(as = "DisplayFromStr")] - pub max_batch_rows: usize, } fn default_max_batch_rows() -> usize { @@ -79,41 +84,45 @@ fn default_max_batch_rows() -> usize { } impl BigQueryCommon { - pub(crate) async fn build_client(&self, aws_auth_props: &AwsAuthProps) -> Result { + async fn build_client(&self, aws_auth_props: &AwsAuthProps) -> Result { let auth_json = self.get_auth_json_from_path(aws_auth_props).await?; - + let service_account = serde_json::from_str::(&auth_json) - .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; + .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; let client: Client = Client::from_service_account_key(service_account, false) .await .map_err(|err| 
SinkError::BigQuery(anyhow::anyhow!(err)))?; Ok(client) } - pub(crate) async fn build_writer_client(&self, aws_auth_props: &AwsAuthProps) -> Result { + async fn build_writer_client( + &self, + aws_auth_props: &AwsAuthProps, + ) -> Result { let auth_json = self.get_auth_json_from_path(aws_auth_props).await?; - - let credentials_file= CredentialsFile::new_from_str(&auth_json).await.unwrap(); + + let credentials_file = CredentialsFile::new_from_str(&auth_json) + .await + .map_err(|e| SinkError::BigQuery(e.into()))?; let client = StorageWriterClient::new(credentials_file).await?; Ok(client) } - async fn get_auth_json_from_path(&self, aws_auth_props: &AwsAuthProps) -> Result{ - if let Some(local_path) = &self.local_path{ + async fn get_auth_json_from_path(&self, aws_auth_props: &AwsAuthProps) -> Result { + if let Some(local_path) = &self.local_path { std::fs::read_to_string(local_path) .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err))) - }else if let Some(s3_path) = &self.s3_path { + } else if let Some(s3_path) = &self.s3_path { let url = Url::parse(s3_path).map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; let auth_vec = load_file_descriptor_from_s3(&url, aws_auth_props) .await .map_err(|err| SinkError::BigQuery(anyhow::anyhow!(err)))?; - Ok(String::from_utf8(auth_vec).unwrap()) - }else{ + Ok(String::from_utf8(auth_vec).map_err(|e| SinkError::BigQuery(e.into()))?) + } else { Err(SinkError::BigQuery(anyhow::anyhow!("`bigquery.local.path` and `bigquery.s3.path` set at least one, configure as needed."))) } } - } #[serde_as] @@ -211,9 +220,7 @@ impl BigQuerySink { DataType::Decimal => Ok("NUMERIC".to_owned()), DataType::Date => Ok("DATE".to_owned()), DataType::Varchar => Ok("STRING".to_owned()), - DataType::Time => Err(SinkError::BigQuery(anyhow::anyhow!( - "Bigquery cannot support Time" - ))), + DataType::Time => Ok("TIME".to_owned()), DataType::Timestamp => Ok("DATETIME".to_owned()), DataType::Timestamptz => Ok("TIMESTAMP".to_owned()), DataType::Interval => Ok("INTERVAL".to_owned()), @@ -258,12 +265,6 @@ impl Sink for BigQuerySink { } async fn validate(&self) -> Result<()> { - if !self.is_append_only { - return Err(SinkError::Config(anyhow!( - "BigQuery sink don't support upsert" - ))); - } - let client = self .config .common @@ -306,8 +307,10 @@ pub struct BigQuerySinkWriter { pk_indices: Vec, client: StorageWriterClient, is_append_only: bool, - insert_request: TableDataInsertAllRequest, - row_encoder: JsonEncoder, + row_encoder: ProtoEncoder, + writer_pb_schema: ProtoSchema, + message_descriptor: MessageDescriptor, + write_stream: String, } impl TryFrom for BigQuerySink { @@ -332,66 +335,126 @@ impl BigQuerySinkWriter { pk_indices: Vec, is_append_only: bool, ) -> Result { - let client = config.common.build_writer_client(&config.aws_auth_props).await?; + let client = config + .common + .build_writer_client(&config.aws_auth_props) + .await?; + let mut descriptor_proto = build_protobuf_schema( + schema + .fields() + .iter() + .map(|f| (f.name.as_str(), &f.data_type)), + config.common.table.clone(), + 1, + ); + + if !is_append_only { + let field = FieldDescriptorProto { + name: Some(CHANGE_TYPE.to_string()), + number: Some((schema.len() + 1) as i32), + r#type: Some(field_descriptor_proto::Type::String.into()), + ..Default::default() + }; + descriptor_proto.field.push(field); + } + + let descriptor_pool = build_protobuf_descriptor_pool(&descriptor_proto); + let message_descriptor = descriptor_pool + .get_message_by_name(&config.common.table) + .ok_or_else(|| { + 
SinkError::BigQuery(anyhow::anyhow!( + "Can't find message proto {}", + &config.common.table + )) + })?; + let row_encoder = ProtoEncoder::new( + schema.clone(), + None, + message_descriptor.clone(), + ProtoHeader::None, + CustomProtoType::BigQuery, + )?; Ok(Self { + write_stream: format!( + "projects/{}/datasets/{}/tables/{}/streams/_default", + config.common.project, config.common.dataset, config.common.table + ), config, - schema: schema.clone(), + schema, pk_indices, client, is_append_only, - insert_request: TableDataInsertAllRequest::new(), - row_encoder: JsonEncoder::new_with_bigquery(schema, None), + row_encoder, + message_descriptor, + writer_pb_schema: ProtoSchema { + proto_descriptor: Some(descriptor_proto), + }, }) } async fn append_only(&mut self, chunk: StreamChunk) -> Result<()> { - let mut insert_vec = Vec::with_capacity(chunk.capacity()); + let mut serialized_rows: Vec> = Vec::with_capacity(chunk.capacity()); for (op, row) in chunk.rows() { if op != Op::Insert { - return Err(SinkError::BigQuery(anyhow::anyhow!( - "BigQuery sink don't support upsert" - ))); + continue; } - insert_vec.push(TableDataInsertAllRequestRows { - insert_id: None, - json: Value::Object(self.row_encoder.encode(row)?), - }) - } - self.insert_request - .add_rows(insert_vec) - .map_err(|e| SinkError::BigQuery(e.into()))?; - if self - .insert_request - .len() - .ge(&self.config.common.max_batch_rows) - { - self.insert_data().await?; + + serialized_rows.push(self.row_encoder.encode(row)?.ser_to()?) } + let rows = AppendRowsRequestRows::ProtoRows(ProtoData { + writer_schema: Some(self.writer_pb_schema.clone()), + rows: Some(ProtoRows { serialized_rows }), + }); + self.client + .append_rows(vec![rows], self.write_stream.clone()) + .await?; Ok(()) } - async fn insert_data(&mut self) -> Result<()> { - if !self.insert_request.is_empty() { - let insert_request = - mem::replace(&mut self.insert_request, TableDataInsertAllRequest::new()); - let request = self - .client - .tabledata() - .insert_all( - &self.config.common.project, - &self.config.common.dataset, - &self.config.common.table, - insert_request, - ) - .await - .map_err(|e| SinkError::BigQuery(e.into()))?; - if let Some(error) = request.insert_errors { - return Err(SinkError::BigQuery(anyhow::anyhow!( - "Insert error: {:?}", - error - ))); - } + async fn upsert(&mut self, chunk: StreamChunk) -> Result<()> { + let mut serialized_rows: Vec> = Vec::with_capacity(chunk.capacity()); + for (op, row) in chunk.rows() { + let mut pb_row = self.row_encoder.encode(row)?; + let proto_field = self + .message_descriptor + .get_field_by_name(CHANGE_TYPE) + .ok_or_else(|| { + SinkError::BigQuery(anyhow::anyhow!("Can't find {}", CHANGE_TYPE)) + })?; + match op { + Op::Insert => pb_row + .message + .try_set_field( + &proto_field, + prost_reflect::Value::String("INSERT".to_string()), + ) + .map_err(|e| SinkError::BigQuery(e.into()))?, + Op::Delete => pb_row + .message + .try_set_field( + &proto_field, + prost_reflect::Value::String("DELETE".to_string()), + ) + .map_err(|e| SinkError::BigQuery(e.into()))?, + Op::UpdateDelete => continue, + Op::UpdateInsert => pb_row + .message + .try_set_field( + &proto_field, + prost_reflect::Value::String("UPSERT".to_string()), + ) + .map_err(|e| SinkError::BigQuery(e.into()))?, + }; + + serialized_rows.push(pb_row.ser_to()?) 
} + let rows = AppendRowsRequestRows::ProtoRows(ProtoData { + writer_schema: Some(self.writer_pb_schema.clone()), + rows: Some(ProtoRows { serialized_rows }), + }); + self.client + .append_rows(vec![rows], self.write_stream.clone()) + .await?; Ok(()) } } @@ -402,9 +465,7 @@ impl SinkWriter for BigQuerySinkWriter { if self.is_append_only { self.append_only(chunk).await } else { - Err(SinkError::BigQuery(anyhow::anyhow!( - "BigQuery sink don't support upsert" - ))) + self.upsert(chunk).await } } @@ -417,7 +478,7 @@ impl SinkWriter for BigQuerySinkWriter { } async fn barrier(&mut self, _is_checkpoint: bool) -> Result<()> { - self.insert_data().await + Ok(()) } async fn update_vnode_bitmap(&mut self, _vnode_bitmap: Arc) -> Result<()> { @@ -425,38 +486,71 @@ impl SinkWriter for BigQuerySinkWriter { } } -struct StorageWriterClient{ +struct StorageWriterClient { client: StreamingWriteClient, environment: Environment, } -impl StorageWriterClient{ - pub async fn new(credentials: CredentialsFile) -> Result{ - // let credentials = CredentialsFile::new_from_file("/home/xxhx/winter-dynamics-383822-9690ac19ce78.json".to_string()).await.unwrap(); +impl StorageWriterClient { + pub async fn new(credentials: CredentialsFile) -> Result { let ts_grpc = google_cloud_auth::token::DefaultTokenSourceProvider::new_with_credentials( Self::bigquery_grpc_auth_config(), Box::new(credentials), ) - .await.unwrap(); - let conn_options = ConnectionOptions{ - connect_timeout: Some(Duration::from_secs(30)), - timeout: None, + .await + .map_err(|e| SinkError::BigQuery(e.into()))?; + let conn_options = ConnectionOptions { + connect_timeout: CONNECT_TIMEOUT, + timeout: CONNECTION_TIMEOUT, }; let environment = Environment::GoogleCloud(Box::new(ts_grpc)); - let conn = WriteConnectionManager::new(DEFAULT_GRPC_CHANNEL_NUMS, &environment, DOMAIN, &conn_options).await.unwrap(); + let conn = WriteConnectionManager::new( + DEFAULT_GRPC_CHANNEL_NUMS, + &environment, + DOMAIN, + &conn_options, + ) + .await + .map_err(|e| SinkError::BigQuery(e.into()))?; let client = conn.conn(); - Ok(StorageWriterClient{ + Ok(StorageWriterClient { client, environment, }) } + pub async fn append_rows( &mut self, rows: Vec, write_stream: String, ) -> Result<()> { - let trace_id = Uuid::new_v4().hyphenated().to_string(); - let append_req:Vec = rows.into_iter().map(|row| AppendRowsRequest{ write_stream: write_stream.clone(), offset:None, trace_id: trace_id.clone(), missing_value_interpretations: HashMap::default(), rows: Some(row)}).collect(); - let a = self.client.append_rows(Request::new(tokio_stream::iter(append_req))).await; + let trace_id = Uuid::new_v4().hyphenated().to_string(); + let append_req: Vec = rows + .into_iter() + .map(|row| AppendRowsRequest { + write_stream: write_stream.clone(), + offset: None, + trace_id: trace_id.clone(), + missing_value_interpretations: HashMap::default(), + rows: Some(row), + }) + .collect(); + let resp = self + .client + .append_rows(Request::new(tokio_stream::iter(append_req))) + .await + .map_err(|e| SinkError::BigQuery(e.into()))? 
+ .into_inner() + .message() + .await + .map_err(|e| SinkError::BigQuery(e.into()))?; + if let Some(i) = resp { + if !i.row_errors.is_empty() { + return Err(SinkError::BigQuery(anyhow::anyhow!( + "Insert error {:?}", + i.row_errors + ))); + } + } Ok(()) } @@ -469,11 +563,99 @@ impl StorageWriterClient{ } } +fn build_protobuf_descriptor_pool(desc: &DescriptorProto) -> prost_reflect::DescriptorPool { + let file_descriptor = FileDescriptorProto { + message_type: vec![desc.clone()], + name: Some("bigquery".to_string()), + ..Default::default() + }; + + prost_reflect::DescriptorPool::from_file_descriptor_set(FileDescriptorSet { + file: vec![file_descriptor], + }) + .unwrap() +} + +fn build_protobuf_schema<'a>( + fields: impl Iterator, + name: String, + index: i32, +) -> DescriptorProto { + let mut proto = DescriptorProto { + name: Some(name), + ..Default::default() + }; + let mut index_mut = index; + let mut field_vec = vec![]; + let mut struct_vec = vec![]; + for (name, data_type) in fields { + let (field, des_proto) = build_protobuf_field(data_type, index_mut, name.to_string()); + field_vec.push(field); + if let Some(sv) = des_proto { + struct_vec.push(sv); + } + index_mut += 1; + } + proto.field = field_vec; + proto.nested_type = struct_vec; + proto +} + +fn build_protobuf_field( + data_type: &DataType, + index: i32, + name: String, +) -> (FieldDescriptorProto, Option) { + let mut field = FieldDescriptorProto { + name: Some(name.clone()), + number: Some(index), + ..Default::default() + }; + match data_type { + DataType::Boolean => field.r#type = Some(field_descriptor_proto::Type::Bool.into()), + DataType::Int32 => field.r#type = Some(field_descriptor_proto::Type::Int32.into()), + DataType::Int16 | DataType::Int64 => { + field.r#type = Some(field_descriptor_proto::Type::Int64.into()) + } + DataType::Float64 => field.r#type = Some(field_descriptor_proto::Type::Double.into()), + DataType::Decimal => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Date => field.r#type = Some(field_descriptor_proto::Type::Int32.into()), + DataType::Varchar => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Time => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Timestamp => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Timestamptz => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Interval => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Struct(s) => { + field.r#type = Some(field_descriptor_proto::Type::Message.into()); + let name = format!("Struct{}", name); + let sub_proto = build_protobuf_schema(s.iter(), name.clone(), 1); + field.type_name = Some(name); + return (field, Some(sub_proto)); + } + DataType::List(l) => { + let (mut field, proto) = build_protobuf_field(l.as_ref(), index, name.clone()); + field.label = Some(field_descriptor_proto::Label::Repeated.into()); + return (field, proto); + } + DataType::Bytea => field.r#type = Some(field_descriptor_proto::Type::Bytes.into()), + DataType::Jsonb => field.r#type = Some(field_descriptor_proto::Type::String.into()), + DataType::Serial => field.r#type = Some(field_descriptor_proto::Type::Int64.into()), + DataType::Float32 | DataType::Int256 => todo!(), + } + (field, None) +} + #[cfg(test)] mod test { + + use std::assert_matches::assert_matches; + + use risingwave_common::catalog::{Field, Schema}; use risingwave_common::types::{DataType, StructType}; - use 
crate::sink::big_query::BigQuerySink; + use crate::sink::big_query::{ + build_protobuf_descriptor_pool, build_protobuf_schema, BigQuerySink, + }; #[tokio::test] async fn test_type_check() { @@ -493,4 +675,63 @@ mod test { big_query_type_string ); } + + #[tokio::test] + async fn test_schema_check() { + let schema = Schema { + fields: vec![ + Field::with_name(DataType::Int64, "v1"), + Field::with_name(DataType::Float64, "v2"), + Field::with_name( + DataType::List(Box::new(DataType::Struct(StructType::new(vec![ + ("v1".to_owned(), DataType::List(Box::new(DataType::Int64))), + ( + "v3".to_owned(), + DataType::Struct(StructType::new(vec![ + ("v1".to_owned(), DataType::Int64), + ("v2".to_owned(), DataType::Int64), + ])), + ), + ])))), + "v3", + ), + ], + }; + let fields = schema + .fields() + .iter() + .map(|f| (f.name.as_str(), &f.data_type)); + let desc = build_protobuf_schema(fields, "t1".to_string(), 1); + let pool = build_protobuf_descriptor_pool(&desc); + let t1_message = pool.get_message_by_name("t1").unwrap(); + assert_matches!( + t1_message.get_field_by_name("v1").unwrap().kind(), + prost_reflect::Kind::Int64 + ); + assert_matches!( + t1_message.get_field_by_name("v2").unwrap().kind(), + prost_reflect::Kind::Double + ); + assert_matches!( + t1_message.get_field_by_name("v3").unwrap().kind(), + prost_reflect::Kind::Message(_) + ); + + let v3_message = pool.get_message_by_name("t1.Structv3").unwrap(); + assert_matches!( + v3_message.get_field_by_name("v1").unwrap().kind(), + prost_reflect::Kind::Int64 + ); + assert!(v3_message.get_field_by_name("v1").unwrap().is_list()); + + let v3_v3_message = pool.get_message_by_name("t1.Structv3.Structv3").unwrap(); + assert_matches!( + v3_v3_message.get_field_by_name("v1").unwrap().kind(), + prost_reflect::Kind::Int64 + ); + assert_matches!( + v3_v3_message.get_field_by_name("v2").unwrap().kind(), + prost_reflect::Kind::Int64 + ); + } } diff --git a/src/connector/src/sink/encoder/json.rs b/src/connector/src/sink/encoder/json.rs index 64a06ff70770f..006500c60914d 100644 --- a/src/connector/src/sink/encoder/json.rs +++ b/src/connector/src/sink/encoder/json.rs @@ -114,19 +114,6 @@ impl JsonEncoder { } } - pub fn new_with_bigquery(schema: Schema, col_indices: Option>) -> Self { - Self { - schema, - col_indices, - time_handling_mode: TimeHandlingMode::Milli, - date_handling_mode: DateHandlingMode::String, - timestamp_handling_mode: TimestampHandlingMode::String, - timestamptz_handling_mode: TimestamptzHandlingMode::UtcString, - custom_json_type: CustomJsonType::BigQuery, - kafka_connect: None, - } - } - pub fn with_kafka_connect(self, kafka_connect: KafkaConnectParams) -> Self { Self { kafka_connect: Some(Arc::new(kafka_connect)), @@ -204,14 +191,7 @@ fn datum_to_json_object( ) -> ArrayResult { let scalar_ref = match datum { None => { - if let CustomJsonType::BigQuery = custom_json_type - && matches!(field.data_type(), DataType::List(_)) - { - // Bigquery need to convert null of array to empty array https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types - return Ok(Value::Array(vec![])); - } else { - return Ok(Value::Null); - } + return Ok(Value::Null); } Some(datum) => datum, }; @@ -259,7 +239,7 @@ fn datum_to_json_object( } json!(v_string) } - CustomJsonType::Es | CustomJsonType::None | CustomJsonType::BigQuery => { + CustomJsonType::Es | CustomJsonType::None => { json!(v.to_text()) } }, @@ -311,7 +291,7 @@ fn datum_to_json_object( } (DataType::Jsonb, ScalarRefImpl::Jsonb(jsonb_ref)) => match custom_json_type { CustomJsonType::Es | 
CustomJsonType::StarRocks(_) => JsonbVal::from(jsonb_ref).take(), - CustomJsonType::Doris(_) | CustomJsonType::None | CustomJsonType::BigQuery => { + CustomJsonType::Doris(_) | CustomJsonType::None => { json!(jsonb_ref.to_string()) } }, @@ -362,7 +342,7 @@ fn datum_to_json_object( "starrocks can't support struct".to_string(), )); } - CustomJsonType::Es | CustomJsonType::None | CustomJsonType::BigQuery => { + CustomJsonType::Es | CustomJsonType::None => { let mut map = Map::with_capacity(st.len()); for (sub_datum_ref, sub_field) in struct_ref.iter_fields_ref().zip_eq_debug( st.iter() diff --git a/src/connector/src/sink/encoder/mod.rs b/src/connector/src/sink/encoder/mod.rs index 34dc4c8886448..4b4807f291bc0 100644 --- a/src/connector/src/sink/encoder/mod.rs +++ b/src/connector/src/sink/encoder/mod.rs @@ -144,7 +144,11 @@ pub enum CustomJsonType { Es, // starrocks' need jsonb is struct StarRocks(HashMap), - // bigquery need null array -> [] + None, +} + +#[derive(Clone)] +pub enum CustomProtoType { BigQuery, None, } diff --git a/src/connector/src/sink/encoder/proto.rs b/src/connector/src/sink/encoder/proto.rs index a5f1090dbafaf..4d464e488ece7 100644 --- a/src/connector/src/sink/encoder/proto.rs +++ b/src/connector/src/sink/encoder/proto.rs @@ -22,7 +22,7 @@ use risingwave_common::row::Row; use risingwave_common::types::{DataType, DatumRef, ScalarRefImpl, StructType}; use risingwave_common::util::iter_util::ZipEqDebug; -use super::{FieldEncodeError, Result as SinkResult, RowEncoder, SerTo}; +use super::{CustomProtoType, FieldEncodeError, Result as SinkResult, RowEncoder, SerTo}; type Result = std::result::Result; @@ -31,6 +31,7 @@ pub struct ProtoEncoder { col_indices: Option>, descriptor: MessageDescriptor, header: ProtoHeader, + custom_proto_type: CustomProtoType, } #[derive(Debug, Clone, Copy)] @@ -49,6 +50,7 @@ impl ProtoEncoder { col_indices: Option>, descriptor: MessageDescriptor, header: ProtoHeader, + custom_proto_type: CustomProtoType, ) -> SinkResult { match &col_indices { Some(col_indices) => validate_fields( @@ -57,6 +59,7 @@ impl ProtoEncoder { (f.name.as_str(), &f.data_type) }), &descriptor, + custom_proto_type.clone(), )?, None => validate_fields( schema @@ -64,6 +67,7 @@ impl ProtoEncoder { .iter() .map(|f| (f.name.as_str(), &f.data_type)), &descriptor, + custom_proto_type.clone(), )?, }; @@ -72,12 +76,13 @@ impl ProtoEncoder { col_indices, descriptor, header, + custom_proto_type, }) } } pub struct ProtoEncoded { - message: DynamicMessage, + pub message: DynamicMessage, header: ProtoHeader, } @@ -103,6 +108,7 @@ impl RowEncoder for ProtoEncoder { ((f.name.as_str(), &f.data_type), row.datum_at(idx)) }), &self.descriptor, + self.custom_proto_type.clone(), ) .map_err(Into::into) .map(|m| ProtoEncoded { @@ -180,9 +186,19 @@ trait MaybeData: std::fmt::Debug { fn on_base(self, f: impl FnOnce(ScalarRefImpl<'_>) -> Result) -> Result; - fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result; - - fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result; + fn on_struct( + self, + st: &StructType, + pb: &MessageDescriptor, + custom_proto_type: CustomProtoType, + ) -> Result; + + fn on_list( + self, + elem: &DataType, + pb: &FieldDescriptor, + custom_proto_type: CustomProtoType, + ) -> Result; } impl MaybeData for () { @@ -192,12 +208,22 @@ impl MaybeData for () { Ok(self) } - fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result { - validate_fields(st.iter(), pb) + fn on_struct( + self, + st: &StructType, + pb: &MessageDescriptor, + 
custom_proto_type: CustomProtoType, + ) -> Result { + validate_fields(st.iter(), pb, custom_proto_type) } - fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result { - encode_field(elem, (), pb, true) + fn on_list( + self, + elem: &DataType, + pb: &FieldDescriptor, + custom_proto_type: CustomProtoType, + ) -> Result { + encode_field(elem, (), pb, true, custom_proto_type) } } @@ -213,13 +239,27 @@ impl MaybeData for ScalarRefImpl<'_> { f(self) } - fn on_struct(self, st: &StructType, pb: &MessageDescriptor) -> Result { + fn on_struct( + self, + st: &StructType, + pb: &MessageDescriptor, + custom_proto_type: CustomProtoType, + ) -> Result { let d = self.into_struct(); - let message = encode_fields(st.iter().zip_eq_debug(d.iter_fields_ref()), pb)?; + let message = encode_fields( + st.iter().zip_eq_debug(d.iter_fields_ref()), + pb, + custom_proto_type, + )?; Ok(Value::Message(message)) } - fn on_list(self, elem: &DataType, pb: &FieldDescriptor) -> Result { + fn on_list( + self, + elem: &DataType, + pb: &FieldDescriptor, + custom_proto_type: CustomProtoType, + ) -> Result { let d = self.into_list(); let vs = d .iter() @@ -231,6 +271,7 @@ impl MaybeData for ScalarRefImpl<'_> { })?, pb, true, + custom_proto_type.clone(), ) }) .try_collect()?; @@ -241,6 +282,7 @@ impl MaybeData for ScalarRefImpl<'_> { fn validate_fields<'a>( fields: impl Iterator, descriptor: &MessageDescriptor, + custom_proto_type: CustomProtoType, ) -> Result<()> { for (name, t) in fields { let Some(proto_field) = descriptor.get_field_by_name(name) else { @@ -249,7 +291,8 @@ fn validate_fields<'a>( if proto_field.cardinality() == prost_reflect::Cardinality::Required { return Err(FieldEncodeError::new("`required` not supported").with_name(name)); } - encode_field(t, (), &proto_field, false).map_err(|e| e.with_name(name))?; + encode_field(t, (), &proto_field, false, custom_proto_type.clone()) + .map_err(|e| e.with_name(name))?; } Ok(()) } @@ -257,14 +300,15 @@ fn validate_fields<'a>( fn encode_fields<'a>( fields_with_datums: impl Iterator)>, descriptor: &MessageDescriptor, + custom_proto_type: CustomProtoType, ) -> Result { let mut message = DynamicMessage::new(descriptor.clone()); for ((name, t), d) in fields_with_datums { let proto_field = descriptor.get_field_by_name(name).unwrap(); // On `null`, simply skip setting the field. 
if let Some(scalar) = d { - let value = - encode_field(t, scalar, &proto_field, false).map_err(|e| e.with_name(name))?; + let value = encode_field(t, scalar, &proto_field, false, custom_proto_type.clone()) + .map_err(|e| e.with_name(name))?; message .try_set_field(&proto_field, value) .map_err(|e| FieldEncodeError::new(e).with_name(name))?; @@ -284,6 +328,7 @@ fn encode_field( maybe: D, proto_field: &FieldDescriptor, in_repeated: bool, + custom_proto_type: CustomProtoType, ) -> Result { // Regarding (proto_field.is_list, in_repeated): // (F, T) => impossible @@ -307,7 +352,7 @@ fn encode_field( proto_field.kind() ))) }; - + let is_big_query = matches!(custom_proto_type, CustomProtoType::BigQuery); let value = match &data_type { // Group A: perfect match between RisingWave types and ProtoBuf types DataType::Boolean => match (expect_list, proto_field.kind()) { @@ -345,11 +390,11 @@ fn encode_field( _ => return no_match_err(), }, DataType::Struct(st) => match (expect_list, proto_field.kind()) { - (false, Kind::Message(pb)) => maybe.on_struct(st, &pb)?, + (false, Kind::Message(pb)) => maybe.on_struct(st, &pb, custom_proto_type)?, _ => return no_match_err(), }, DataType::List(elem) => match expect_list { - true => maybe.on_list(elem, proto_field)?, + true => maybe.on_list(elem, proto_field, custom_proto_type)?, false => return no_match_err(), }, // Group B: match between RisingWave types and ProtoBuf Well-Known types @@ -364,18 +409,61 @@ fn encode_field( Ok(Value::Message(message.transcode_to_dynamic())) })? } + (false, Kind::String) if is_big_query => { + maybe.on_base(|s| Ok(Value::String(s.into_timestamptz().to_string())))? + } + _ => return no_match_err(), + }, + DataType::Jsonb => match (expect_list, proto_field.kind()) { + (false, Kind::String) if is_big_query => { + maybe.on_base(|s| Ok(Value::String(s.into_jsonb().to_string())))? + } + _ => return no_match_err(), /* Value, NullValue, Struct (map), ListValue + * Group C: experimental */ + }, + DataType::Int16 => match (expect_list, proto_field.kind()) { + (false, Kind::String) if is_big_query => { + maybe.on_base(|s| Ok(Value::I64(s.into_int16() as i64)))? + } _ => return no_match_err(), }, - DataType::Jsonb => return no_match_err(), // Value, NullValue, Struct (map), ListValue - // Group C: experimental - DataType::Int16 => return no_match_err(), - DataType::Date => return no_match_err(), // google.type.Date - DataType::Time => return no_match_err(), // google.type.TimeOfDay - DataType::Timestamp => return no_match_err(), // google.type.DateTime - DataType::Decimal => return no_match_err(), // google.type.Decimal - DataType::Interval => return no_match_err(), - // Group D: unsupported - DataType::Serial | DataType::Int256 => { + DataType::Date => match (expect_list, proto_field.kind()) { + (false, Kind::Int32) if is_big_query => { + maybe.on_base(|s| Ok(Value::I32(s.into_date().get_nums_days_unix_epoch())))? + } + _ => return no_match_err(), // google.type.Date + }, + DataType::Time => match (expect_list, proto_field.kind()) { + (false, Kind::String) if is_big_query => { + maybe.on_base(|s| Ok(Value::String(s.into_time().to_string())))? + } + _ => return no_match_err(), // google.type.TimeOfDay + }, + DataType::Timestamp => match (expect_list, proto_field.kind()) { + (false, Kind::String) if is_big_query => { + maybe.on_base(|s| Ok(Value::String(s.into_timestamp().to_string())))? 
+ } + _ => return no_match_err(), // google.type.DateTime + }, + DataType::Decimal => match (expect_list, proto_field.kind()) { + (false, Kind::String) if is_big_query => { + maybe.on_base(|s| Ok(Value::String(s.into_decimal().to_string())))? + } + _ => return no_match_err(), // google.type.Decimal + }, + DataType::Interval => match (expect_list, proto_field.kind()) { + (false, Kind::String) if is_big_query => { + maybe.on_base(|s| Ok(Value::String(s.into_interval().as_iso_8601())))? + } + _ => return no_match_err(), // Group D: unsupported + }, + DataType::Serial => match (expect_list, proto_field.kind()) { + (false, Kind::Int64) if is_big_query => { + maybe.on_base(|s| Ok(Value::I64(s.into_serial().as_row_id())))? + } + _ => return no_match_err(), // Group D: unsupported + }, + DataType::Int256 => { return no_match_err(); } }; @@ -398,7 +486,7 @@ mod tests { let pool_bytes = std::fs::read(pool_path).unwrap(); let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap(); let descriptor = pool.get_message_by_name("recursive.AllTypes").unwrap(); - + println!("a{:?}", descriptor.descriptor_proto()); let schema = Schema::new(vec![ Field::with_name(DataType::Boolean, "bool_field"), Field::with_name(DataType::Varchar, "string_field"), @@ -441,8 +529,14 @@ mod tests { Some(ScalarImpl::Timestamptz(Timestamptz::from_micros(3))), ]); - let encoder = - ProtoEncoder::new(schema, None, descriptor.clone(), ProtoHeader::None).unwrap(); + let encoder = ProtoEncoder::new( + schema, + None, + descriptor.clone(), + ProtoHeader::None, + CustomProtoType::None, + ) + .unwrap(); let m = encoder.encode(row).unwrap(); let encoded: Vec = m.ser_to().unwrap(); assert_eq!( @@ -480,6 +574,7 @@ mod tests { .iter() .map(|f| (f.name.as_str(), &f.data_type)), &message_descriptor, + CustomProtoType::None, ) .unwrap_err(); assert_eq!( @@ -505,6 +600,7 @@ mod tests { .map(|f| (f.name.as_str(), &f.data_type)) .zip_eq_debug(row.iter()), &message_descriptor, + CustomProtoType::None, ) .unwrap_err(); assert_eq!( @@ -524,6 +620,7 @@ mod tests { let err = validate_fields( std::iter::once(("not_exists", &DataType::Int16)), &message_descriptor, + CustomProtoType::None, ) .unwrap_err(); assert_eq!( @@ -534,6 +631,7 @@ mod tests { let err = validate_fields( std::iter::once(("map_field", &DataType::Jsonb)), &message_descriptor, + CustomProtoType::None, ) .unwrap_err(); assert_eq!( diff --git a/src/connector/src/sink/formatter/mod.rs b/src/connector/src/sink/formatter/mod.rs index d923d337a3ffb..1ce6675d7d456 100644 --- a/src/connector/src/sink/formatter/mod.rs +++ b/src/connector/src/sink/formatter/mod.rs @@ -29,7 +29,8 @@ pub use upsert::UpsertFormatter; use super::catalog::{SinkEncode, SinkFormat, SinkFormatDesc}; use super::encoder::template::TemplateEncoder; use super::encoder::{ - DateHandlingMode, KafkaConnectParams, TimeHandlingMode, TimestamptzHandlingMode, + CustomProtoType, DateHandlingMode, KafkaConnectParams, TimeHandlingMode, + TimestamptzHandlingMode, }; use super::redis::{KEY_FORMAT, VALUE_FORMAT}; use crate::sink::encoder::{ @@ -134,7 +135,13 @@ impl SinkFormatterImpl { None => ProtoHeader::None, Some(sid) => ProtoHeader::ConfluentSchemaRegistry(sid), }; - let val_encoder = ProtoEncoder::new(schema, None, descriptor, header)?; + let val_encoder = ProtoEncoder::new( + schema, + None, + descriptor, + header, + CustomProtoType::None, + )?; let formatter = AppendOnlyFormatter::new(key_encoder, val_encoder); Ok(SinkFormatterImpl::AppendOnlyProto(formatter)) } diff --git 
a/src/connector/with_options_sink.yaml b/src/connector/with_options_sink.yaml index f9d459fddfd9c..9c7b87ade0ef2 100644 --- a/src/connector/with_options_sink.yaml +++ b/src/connector/with_options_sink.yaml @@ -17,10 +17,6 @@ BigQueryConfig: - name: bigquery.table field_type: String required: true - - name: bigquery.max_batch_rows - field_type: usize - required: false - default: '1024' - name: region field_type: String required: false diff --git a/src/workspace-hack/Cargo.toml b/src/workspace-hack/Cargo.toml index bd33f5268aedb..28eded0121a63 100644 --- a/src/workspace-hack/Cargo.toml +++ b/src/workspace-hack/Cargo.toml @@ -31,7 +31,7 @@ aws-smithy-runtime = { version = "1", default-features = false, features = ["cli aws-smithy-types = { version = "1", default-features = false, features = ["byte-stream-poll-next", "http-body-0-4-x", "hyper-0-14-x", "rt-tokio"] } axum = { version = "0.6" } base64 = { version = "0.21" } -bigdecimal = { version = "0.4" } +bigdecimal = { version = "0.4", features = ["serde"] } bit-vec = { version = "0.6" } bitflags = { version = "2", default-features = false, features = ["serde", "std"] } byteorder = { version = "1" } @@ -61,6 +61,7 @@ futures-task = { version = "0.3" } futures-util = { version = "0.3", features = ["channel", "io", "sink"] } generic-array = { version = "0.14", default-features = false, features = ["more_lengths", "zeroize"] } getrandom = { git = "https://github.com/madsim-rs/getrandom.git", rev = "e79a7ae", default-features = false, features = ["js", "rdrand", "std"] } +google-cloud-googleapis = { version = "0.12", default-features = false, features = ["bigquery", "pubsub"] } governor = { version = "0.6", default-features = false, features = ["dashmap", "jitter", "std"] } hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["nightly", "raw"] } hashbrown-594e8ee84c453af0 = { package = "hashbrown", version = "0.13", features = ["raw"] } @@ -84,6 +85,7 @@ madsim-rdkafka = { version = "0.3", features = ["cmake-build", "gssapi", "ssl-ve madsim-tokio = { version = "0.2", default-features = false, features = ["fs", "io-util", "macros", "net", "process", "rt", "rt-multi-thread", "signal", "sync", "time", "tracing"] } md-5 = { version = "0.10" } memchr = { version = "2" } +mime_guess = { version = "2" } mio = { version = "0.8", features = ["net", "os-ext"] } moka = { version = "0.12", features = ["future", "sync"] } nom = { version = "7" } @@ -112,7 +114,7 @@ redis = { version = "0.24", features = ["async-std-comp", "tokio-comp"] } regex = { version = "1" } regex-automata = { version = "0.4", default-features = false, features = ["dfa", "hybrid", "meta", "nfa", "perf", "unicode"] } regex-syntax = { version = "0.8" } -reqwest = { version = "0.11", features = ["blocking", "json", "rustls-tls"] } +reqwest = { version = "0.11", features = ["blocking", "json", "multipart", "rustls-tls", "stream"] } ring = { version = "0.16", features = ["std"] } rust_decimal = { version = "1", features = ["db-postgres", "maths"] } rustc-hash = { version = "1" } @@ -181,6 +183,7 @@ lazy_static = { version = "1", default-features = false, features = ["spin_no_st libc = { version = "0.2", features = ["extra_traits"] } log = { version = "0.4", default-features = false, features = ["kv_unstable", "std"] } memchr = { version = "2" } +mime_guess = { version = "2" } nom = { version = "7" } num-bigint = { version = "0.4" } num-integer = { version = "0.1", features = ["i128"] } From eaba3b63e04192eb6704467b8eb77c74167e3dac Mon Sep 17 00:00:00 2001 From: 
xxhZs <1060434431@qq.com> Date: Tue, 19 Mar 2024 16:25:43 +0800 Subject: [PATCH 3/9] fix remove index fix fix --- src/connector/src/sink/big_query.rs | 49 +++++++++++++------------ src/connector/src/sink/encoder/proto.rs | 3 +- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/src/connector/src/sink/big_query.rs b/src/connector/src/sink/big_query.rs index 96975c9967538..cc2d8b53e7c6f 100644 --- a/src/connector/src/sink/big_query.rs +++ b/src/connector/src/sink/big_query.rs @@ -345,8 +345,7 @@ impl BigQuerySinkWriter { .iter() .map(|f| (f.name.as_str(), &f.data_type)), config.common.table.clone(), - 1, - ); + )?; if !is_append_only { let field = FieldDescriptorProto { @@ -579,33 +578,33 @@ fn build_protobuf_descriptor_pool(desc: &DescriptorProto) -> prost_reflect::Desc fn build_protobuf_schema<'a>( fields: impl Iterator, name: String, - index: i32, -) -> DescriptorProto { +) -> Result { let mut proto = DescriptorProto { name: Some(name), ..Default::default() }; - let mut index_mut = index; - let mut field_vec = vec![]; let mut struct_vec = vec![]; - for (name, data_type) in fields { - let (field, des_proto) = build_protobuf_field(data_type, index_mut, name.to_string()); - field_vec.push(field); - if let Some(sv) = des_proto { - struct_vec.push(sv); - } - index_mut += 1; - } + let field_vec = fields + .enumerate() + .map(|(index, (name, data_type))| { + let (field, des_proto) = + build_protobuf_field(data_type, (index + 1) as i32, name.to_string())?; + if let Some(sv) = des_proto { + struct_vec.push(sv); + } + Ok(field) + }) + .collect::>>()?; proto.field = field_vec; proto.nested_type = struct_vec; - proto + Ok(proto) } fn build_protobuf_field( data_type: &DataType, index: i32, name: String, -) -> (FieldDescriptorProto, Option) { +) -> Result<(FieldDescriptorProto, Option)> { let mut field = FieldDescriptorProto { name: Some(name.clone()), number: Some(index), @@ -628,21 +627,25 @@ fn build_protobuf_field( DataType::Struct(s) => { field.r#type = Some(field_descriptor_proto::Type::Message.into()); let name = format!("Struct{}", name); - let sub_proto = build_protobuf_schema(s.iter(), name.clone(), 1); + let sub_proto = build_protobuf_schema(s.iter(), name.clone())?; field.type_name = Some(name); - return (field, Some(sub_proto)); + return Ok((field, Some(sub_proto))); } DataType::List(l) => { - let (mut field, proto) = build_protobuf_field(l.as_ref(), index, name.clone()); + let (mut field, proto) = build_protobuf_field(l.as_ref(), index, name.clone())?; field.label = Some(field_descriptor_proto::Label::Repeated.into()); - return (field, proto); + return Ok((field, proto)); } DataType::Bytea => field.r#type = Some(field_descriptor_proto::Type::Bytes.into()), DataType::Jsonb => field.r#type = Some(field_descriptor_proto::Type::String.into()), DataType::Serial => field.r#type = Some(field_descriptor_proto::Type::Int64.into()), - DataType::Float32 | DataType::Int256 => todo!(), + DataType::Float32 | DataType::Int256 => { + return Err(SinkError::BigQuery(anyhow::anyhow!( + "Don't support Float32 and Int256" + ))) + } } - (field, None) + Ok((field, None)) } #[cfg(test)] @@ -701,7 +704,7 @@ mod test { .fields() .iter() .map(|f| (f.name.as_str(), &f.data_type)); - let desc = build_protobuf_schema(fields, "t1".to_string(), 1); + let desc = build_protobuf_schema(fields, "t1".to_string()).unwrap(); let pool = build_protobuf_descriptor_pool(&desc); let t1_message = pool.get_message_by_name("t1").unwrap(); assert_matches!( diff --git a/src/connector/src/sink/encoder/proto.rs 
b/src/connector/src/sink/encoder/proto.rs index 4d464e488ece7..4461c2a695efc 100644 --- a/src/connector/src/sink/encoder/proto.rs +++ b/src/connector/src/sink/encoder/proto.rs @@ -422,7 +422,7 @@ fn encode_field( * Group C: experimental */ }, DataType::Int16 => match (expect_list, proto_field.kind()) { - (false, Kind::String) if is_big_query => { + (false, Kind::Int64) if is_big_query => { maybe.on_base(|s| Ok(Value::I64(s.into_int16() as i64)))? } _ => return no_match_err(), @@ -486,7 +486,6 @@ mod tests { let pool_bytes = std::fs::read(pool_path).unwrap(); let pool = prost_reflect::DescriptorPool::decode(pool_bytes.as_ref()).unwrap(); let descriptor = pool.get_message_by_name("recursive.AllTypes").unwrap(); - println!("a{:?}", descriptor.descriptor_proto()); let schema = Schema::new(vec![ Field::with_name(DataType::Boolean, "bool_field"), Field::with_name(DataType::Varchar, "string_field"), From 9333ae2433f82778ad9181c6fef5c35d3bc24533 Mon Sep 17 00:00:00 2001 From: xxhZs <1060434431@qq.com> Date: Tue, 2 Apr 2024 15:20:11 +0800 Subject: [PATCH 4/9] add handling mode --- src/connector/src/sink/big_query.rs | 5 +- src/connector/src/sink/encoder/json.rs | 5 + src/connector/src/sink/encoder/mod.rs | 37 +++- src/connector/src/sink/encoder/proto.rs | 220 ++++++++++++++++++------ src/connector/src/sink/formatter/mod.rs | 12 +- 5 files changed, 204 insertions(+), 75 deletions(-) diff --git a/src/connector/src/sink/big_query.rs b/src/connector/src/sink/big_query.rs index cc2d8b53e7c6f..a67664c93a035 100644 --- a/src/connector/src/sink/big_query.rs +++ b/src/connector/src/sink/big_query.rs @@ -48,7 +48,7 @@ use uuid::Uuid; use with_options::WithOptions; use yup_oauth2::ServiceAccountKey; -use super::encoder::{CustomProtoType, ProtoEncoder, ProtoHeader, RowEncoder, SerTo}; +use super::encoder::{ProtoEncoder, ProtoHeader, RowEncoder, SerTo}; use super::writer::LogSinkerOf; use super::{SinkError, SINK_TYPE_APPEND_ONLY, SINK_TYPE_OPTION, SINK_TYPE_UPSERT}; use crate::aws_utils::load_file_descriptor_from_s3; @@ -366,12 +366,11 @@ impl BigQuerySinkWriter { &config.common.table )) })?; - let row_encoder = ProtoEncoder::new( + let row_encoder = ProtoEncoder::new_with_bigquery( schema.clone(), None, message_descriptor.clone(), ProtoHeader::None, - CustomProtoType::BigQuery, )?; Ok(Self { write_stream: format!( diff --git a/src/connector/src/sink/encoder/json.rs b/src/connector/src/sink/encoder/json.rs index 006500c60914d..29ac6fa4d0765 100644 --- a/src/connector/src/sink/encoder/json.rs +++ b/src/connector/src/sink/encoder/json.rs @@ -256,6 +256,11 @@ fn datum_to_json_object( } TimestamptzHandlingMode::Micro => json!(v.timestamp_micros()), TimestamptzHandlingMode::Milli => json!(v.timestamp_millis()), + TimestamptzHandlingMode::PbMessage => { + return Err(ArrayError::internal( + "TimestamptzHandlingMode::PbMessage only support for proto format".to_string(), + )) + } }, (DataType::Time, ScalarRefImpl::Time(v)) => match time_handling_mode { TimeHandlingMode::Milli => { diff --git a/src/connector/src/sink/encoder/mod.rs b/src/connector/src/sink/encoder/mod.rs index 4b4807f291bc0..f5a933032db5c 100644 --- a/src/connector/src/sink/encoder/mod.rs +++ b/src/connector/src/sink/encoder/mod.rs @@ -110,6 +110,37 @@ pub enum TimestamptzHandlingMode { UtcWithoutSuffix, Micro, Milli, + PbMessage, +} + +#[derive(Clone, Copy)] +pub enum JsonbHandlingMode { + Jsonb, + String, +} + +#[derive(Clone, Copy)] +pub enum Int16HandlingMode { + Int16, + Int64, +} + +#[derive(Clone, Copy)] +pub enum DecimalHandlingMode { + Decimal, 
+ String, +} + +#[derive(Clone, Copy)] +pub enum IntervalHandlingMode { + Interval, + String, +} + +#[derive(Clone, Copy)] +pub enum SerialHandlingMode { + Serial, + Int64, } impl TimestamptzHandlingMode { @@ -147,12 +178,6 @@ pub enum CustomJsonType { None, } -#[derive(Clone)] -pub enum CustomProtoType { - BigQuery, - None, -} - #[derive(Debug)] struct FieldEncodeError { message: String, diff --git a/src/connector/src/sink/encoder/proto.rs b/src/connector/src/sink/encoder/proto.rs index 4461c2a695efc..a5b9aee065cdd 100644 --- a/src/connector/src/sink/encoder/proto.rs +++ b/src/connector/src/sink/encoder/proto.rs @@ -22,7 +22,11 @@ use risingwave_common::row::Row; use risingwave_common::types::{DataType, DatumRef, ScalarRefImpl, StructType}; use risingwave_common::util::iter_util::ZipEqDebug; -use super::{CustomProtoType, FieldEncodeError, Result as SinkResult, RowEncoder, SerTo}; +use super::{ + DateHandlingMode, DecimalHandlingMode, FieldEncodeError, Int16HandlingMode, + IntervalHandlingMode, JsonbHandlingMode, Result as SinkResult, RowEncoder, SerTo, + SerialHandlingMode, TimeHandlingMode, TimestampHandlingMode, TimestamptzHandlingMode, +}; type Result = std::result::Result; @@ -31,7 +35,7 @@ pub struct ProtoEncoder { col_indices: Option>, descriptor: MessageDescriptor, header: ProtoHeader, - custom_proto_type: CustomProtoType, + proto_handling_modes: ProtoHandlingModes, } #[derive(Debug, Clone, Copy)] @@ -44,13 +48,55 @@ pub enum ProtoHeader { ConfluentSchemaRegistry(i32), } +#[derive(Clone, Copy)] +struct ProtoHandlingModes { + time_handling_mode: TimeHandlingMode, + date_handling_mode: DateHandlingMode, + timestamp_handling_mode: TimestampHandlingMode, + timestamptz_handling_mode: TimestamptzHandlingMode, + json_handling_mode: JsonbHandlingMode, + int16_handling_mode: Int16HandlingMode, + decimal_handling_mode: DecimalHandlingMode, + interval_handling_mode: IntervalHandlingMode, + serial_handling_mode: SerialHandlingMode, +} +impl ProtoHandlingModes { + pub fn new_with_default() -> Self { + Self { + time_handling_mode: TimeHandlingMode::Milli, + date_handling_mode: DateHandlingMode::FromEpoch, + timestamp_handling_mode: TimestampHandlingMode::Milli, + timestamptz_handling_mode: TimestamptzHandlingMode::PbMessage, + json_handling_mode: JsonbHandlingMode::Jsonb, + int16_handling_mode: Int16HandlingMode::Int16, + decimal_handling_mode: DecimalHandlingMode::Decimal, + interval_handling_mode: IntervalHandlingMode::Interval, + serial_handling_mode: SerialHandlingMode::Serial, + } + } + + pub fn new_with_bigquery() -> Self { + Self { + time_handling_mode: TimeHandlingMode::String, + date_handling_mode: DateHandlingMode::FromCe, + timestamp_handling_mode: TimestampHandlingMode::String, + timestamptz_handling_mode: TimestamptzHandlingMode::UtcString, + json_handling_mode: JsonbHandlingMode::String, + int16_handling_mode: Int16HandlingMode::Int64, + decimal_handling_mode: DecimalHandlingMode::String, + interval_handling_mode: IntervalHandlingMode::String, + serial_handling_mode: SerialHandlingMode::Int64, + } + } +} + impl ProtoEncoder { - pub fn new( + fn new( schema: Schema, col_indices: Option>, descriptor: MessageDescriptor, header: ProtoHeader, - custom_proto_type: CustomProtoType, + proto_handling_modes: ProtoHandlingModes, ) -> SinkResult { match &col_indices { Some(col_indices) => validate_fields( @@ -59,7 +105,7 @@ impl ProtoEncoder { (f.name.as_str(), &f.data_type) }), &descriptor, - custom_proto_type.clone(), + proto_handling_modes, )?, None => validate_fields( schema @@ -67,7 
+113,7 @@ impl ProtoEncoder { .iter() .map(|f| (f.name.as_str(), &f.data_type)), &descriptor, - custom_proto_type.clone(), + proto_handling_modes, )?, }; @@ -76,9 +122,39 @@ impl ProtoEncoder { col_indices, descriptor, header, - custom_proto_type, + proto_handling_modes, }) } + + pub fn new_with_default( + schema: Schema, + col_indices: Option>, + descriptor: MessageDescriptor, + header: ProtoHeader, + ) -> SinkResult { + Self::new( + schema, + col_indices, + descriptor, + header, + ProtoHandlingModes::new_with_default(), + ) + } + + pub fn new_with_bigquery( + schema: Schema, + col_indices: Option>, + descriptor: MessageDescriptor, + header: ProtoHeader, + ) -> SinkResult { + Self::new( + schema, + col_indices, + descriptor, + header, + ProtoHandlingModes::new_with_default(), + ) + } } pub struct ProtoEncoded { @@ -108,7 +184,7 @@ impl RowEncoder for ProtoEncoder { ((f.name.as_str(), &f.data_type), row.datum_at(idx)) }), &self.descriptor, - self.custom_proto_type.clone(), + self.proto_handling_modes, ) .map_err(Into::into) .map(|m| ProtoEncoded { @@ -190,14 +266,14 @@ trait MaybeData: std::fmt::Debug { self, st: &StructType, pb: &MessageDescriptor, - custom_proto_type: CustomProtoType, + proto_handling_modes: ProtoHandlingModes, ) -> Result; fn on_list( self, elem: &DataType, pb: &FieldDescriptor, - custom_proto_type: CustomProtoType, + proto_handling_modes: ProtoHandlingModes, ) -> Result; } @@ -212,18 +288,18 @@ impl MaybeData for () { self, st: &StructType, pb: &MessageDescriptor, - custom_proto_type: CustomProtoType, + proto_handling_modes: ProtoHandlingModes, ) -> Result { - validate_fields(st.iter(), pb, custom_proto_type) + validate_fields(st.iter(), pb, proto_handling_modes) } fn on_list( self, elem: &DataType, pb: &FieldDescriptor, - custom_proto_type: CustomProtoType, + proto_handling_modes: ProtoHandlingModes, ) -> Result { - encode_field(elem, (), pb, true, custom_proto_type) + encode_field(elem, (), pb, true, proto_handling_modes) } } @@ -243,13 +319,13 @@ impl MaybeData for ScalarRefImpl<'_> { self, st: &StructType, pb: &MessageDescriptor, - custom_proto_type: CustomProtoType, + proto_handling_modes: ProtoHandlingModes, ) -> Result { let d = self.into_struct(); let message = encode_fields( st.iter().zip_eq_debug(d.iter_fields_ref()), pb, - custom_proto_type, + proto_handling_modes, )?; Ok(Value::Message(message)) } @@ -258,7 +334,7 @@ impl MaybeData for ScalarRefImpl<'_> { self, elem: &DataType, pb: &FieldDescriptor, - custom_proto_type: CustomProtoType, + proto_handling_modes: ProtoHandlingModes, ) -> Result { let d = self.into_list(); let vs = d @@ -271,7 +347,7 @@ impl MaybeData for ScalarRefImpl<'_> { })?, pb, true, - custom_proto_type.clone(), + proto_handling_modes, ) }) .try_collect()?; @@ -282,7 +358,7 @@ impl MaybeData for ScalarRefImpl<'_> { fn validate_fields<'a>( fields: impl Iterator, descriptor: &MessageDescriptor, - custom_proto_type: CustomProtoType, + proto_handling_modes: ProtoHandlingModes, ) -> Result<()> { for (name, t) in fields { let Some(proto_field) = descriptor.get_field_by_name(name) else { @@ -291,7 +367,7 @@ fn validate_fields<'a>( if proto_field.cardinality() == prost_reflect::Cardinality::Required { return Err(FieldEncodeError::new("`required` not supported").with_name(name)); } - encode_field(t, (), &proto_field, false, custom_proto_type.clone()) + encode_field(t, (), &proto_field, false, proto_handling_modes) .map_err(|e| e.with_name(name))?; } Ok(()) @@ -300,14 +376,14 @@ fn validate_fields<'a>( fn encode_fields<'a>( fields_with_datums: impl 
Iterator)>, descriptor: &MessageDescriptor, - custom_proto_type: CustomProtoType, + proto_handling_modes: ProtoHandlingModes, ) -> Result { let mut message = DynamicMessage::new(descriptor.clone()); for ((name, t), d) in fields_with_datums { let proto_field = descriptor.get_field_by_name(name).unwrap(); // On `null`, simply skip setting the field. if let Some(scalar) = d { - let value = encode_field(t, scalar, &proto_field, false, custom_proto_type.clone()) + let value = encode_field(t, scalar, &proto_field, false, proto_handling_modes) .map_err(|e| e.with_name(name))?; message .try_set_field(&proto_field, value) @@ -328,7 +404,7 @@ fn encode_field( maybe: D, proto_field: &FieldDescriptor, in_repeated: bool, - custom_proto_type: CustomProtoType, + proto_handling_modes: ProtoHandlingModes, ) -> Result { // Regarding (proto_field.is_list, in_repeated): // (F, T) => impossible @@ -352,7 +428,6 @@ fn encode_field( proto_field.kind() ))) }; - let is_big_query = matches!(custom_proto_type, CustomProtoType::BigQuery); let value = match &data_type { // Group A: perfect match between RisingWave types and ProtoBuf types DataType::Boolean => match (expect_list, proto_field.kind()) { @@ -390,16 +465,20 @@ fn encode_field( _ => return no_match_err(), }, DataType::Struct(st) => match (expect_list, proto_field.kind()) { - (false, Kind::Message(pb)) => maybe.on_struct(st, &pb, custom_proto_type)?, + (false, Kind::Message(pb)) => maybe.on_struct(st, &pb, proto_handling_modes)?, _ => return no_match_err(), }, DataType::List(elem) => match expect_list { - true => maybe.on_list(elem, proto_field, custom_proto_type)?, + true => maybe.on_list(elem, proto_field, proto_handling_modes)?, false => return no_match_err(), }, // Group B: match between RisingWave types and ProtoBuf Well-Known types - DataType::Timestamptz => match (expect_list, proto_field.kind()) { - (false, Kind::Message(pb)) if pb.full_name() == WKT_TIMESTAMP => { + DataType::Timestamptz => match ( + expect_list, + proto_field.kind(), + proto_handling_modes.timestamptz_handling_mode, + ) { + (false, Kind::Message(pb), _) if pb.full_name() == WKT_TIMESTAMP => { maybe.on_base(|s| { let d = s.into_timestamptz(); let message = prost_types::Timestamp { @@ -409,56 +488,88 @@ fn encode_field( Ok(Value::Message(message.transcode_to_dynamic())) })? } - (false, Kind::String) if is_big_query => { + (false, Kind::String, TimestamptzHandlingMode::UtcString) => { maybe.on_base(|s| Ok(Value::String(s.into_timestamptz().to_string())))? } _ => return no_match_err(), }, - DataType::Jsonb => match (expect_list, proto_field.kind()) { - (false, Kind::String) if is_big_query => { + DataType::Jsonb => match ( + expect_list, + proto_field.kind(), + proto_handling_modes.json_handling_mode, + ) { + (false, Kind::String, JsonbHandlingMode::String) => { maybe.on_base(|s| Ok(Value::String(s.into_jsonb().to_string())))? } _ => return no_match_err(), /* Value, NullValue, Struct (map), ListValue * Group C: experimental */ }, - DataType::Int16 => match (expect_list, proto_field.kind()) { - (false, Kind::Int64) if is_big_query => { + DataType::Int16 => match ( + expect_list, + proto_field.kind(), + proto_handling_modes.int16_handling_mode, + ) { + (false, Kind::Int64, Int16HandlingMode::Int64) => { maybe.on_base(|s| Ok(Value::I64(s.into_int16() as i64)))? 
} _ => return no_match_err(), }, - DataType::Date => match (expect_list, proto_field.kind()) { - (false, Kind::Int32) if is_big_query => { + DataType::Date => match ( + expect_list, + proto_field.kind(), + proto_handling_modes.date_handling_mode, + ) { + (false, Kind::Int32, DateHandlingMode::FromCe) => { maybe.on_base(|s| Ok(Value::I32(s.into_date().get_nums_days_unix_epoch())))? } _ => return no_match_err(), // google.type.Date }, - DataType::Time => match (expect_list, proto_field.kind()) { - (false, Kind::String) if is_big_query => { + DataType::Time => match ( + expect_list, + proto_field.kind(), + proto_handling_modes.time_handling_mode, + ) { + (false, Kind::String, TimeHandlingMode::String) => { maybe.on_base(|s| Ok(Value::String(s.into_time().to_string())))? } _ => return no_match_err(), // google.type.TimeOfDay }, - DataType::Timestamp => match (expect_list, proto_field.kind()) { - (false, Kind::String) if is_big_query => { + DataType::Timestamp => match ( + expect_list, + proto_field.kind(), + proto_handling_modes.timestamp_handling_mode, + ) { + (false, Kind::String, TimestampHandlingMode::String) => { maybe.on_base(|s| Ok(Value::String(s.into_timestamp().to_string())))? } _ => return no_match_err(), // google.type.DateTime }, - DataType::Decimal => match (expect_list, proto_field.kind()) { - (false, Kind::String) if is_big_query => { + DataType::Decimal => match ( + expect_list, + proto_field.kind(), + proto_handling_modes.decimal_handling_mode, + ) { + (false, Kind::String, DecimalHandlingMode::String) => { maybe.on_base(|s| Ok(Value::String(s.into_decimal().to_string())))? } _ => return no_match_err(), // google.type.Decimal }, - DataType::Interval => match (expect_list, proto_field.kind()) { - (false, Kind::String) if is_big_query => { + DataType::Interval => match ( + expect_list, + proto_field.kind(), + proto_handling_modes.interval_handling_mode, + ) { + (false, Kind::String, IntervalHandlingMode::String) => { maybe.on_base(|s| Ok(Value::String(s.into_interval().as_iso_8601())))? } _ => return no_match_err(), // Group D: unsupported }, - DataType::Serial => match (expect_list, proto_field.kind()) { - (false, Kind::Int64) if is_big_query => { + DataType::Serial => match ( + expect_list, + proto_field.kind(), + proto_handling_modes.serial_handling_mode, + ) { + (false, Kind::Int64, SerialHandlingMode::Int64) => { maybe.on_base(|s| Ok(Value::I64(s.into_serial().as_row_id())))? 
} _ => return no_match_err(), // Group D: unsupported @@ -528,14 +639,9 @@ mod tests { Some(ScalarImpl::Timestamptz(Timestamptz::from_micros(3))), ]); - let encoder = ProtoEncoder::new( - schema, - None, - descriptor.clone(), - ProtoHeader::None, - CustomProtoType::None, - ) - .unwrap(); + let encoder = + ProtoEncoder::new_with_default(schema, None, descriptor.clone(), ProtoHeader::None) + .unwrap(); let m = encoder.encode(row).unwrap(); let encoded: Vec = m.ser_to().unwrap(); assert_eq!( @@ -573,7 +679,7 @@ mod tests { .iter() .map(|f| (f.name.as_str(), &f.data_type)), &message_descriptor, - CustomProtoType::None, + ProtoHandlingModes::new_with_default(), ) .unwrap_err(); assert_eq!( @@ -599,7 +705,7 @@ mod tests { .map(|f| (f.name.as_str(), &f.data_type)) .zip_eq_debug(row.iter()), &message_descriptor, - CustomProtoType::None, + ProtoHandlingModes::new_with_default(), ) .unwrap_err(); assert_eq!( @@ -619,7 +725,7 @@ mod tests { let err = validate_fields( std::iter::once(("not_exists", &DataType::Int16)), &message_descriptor, - CustomProtoType::None, + ProtoHandlingModes::new_with_default(), ) .unwrap_err(); assert_eq!( @@ -630,7 +736,7 @@ mod tests { let err = validate_fields( std::iter::once(("map_field", &DataType::Jsonb)), &message_descriptor, - CustomProtoType::None, + ProtoHandlingModes::new_with_default(), ) .unwrap_err(); assert_eq!( diff --git a/src/connector/src/sink/formatter/mod.rs b/src/connector/src/sink/formatter/mod.rs index 1ce6675d7d456..d6be02c70aea3 100644 --- a/src/connector/src/sink/formatter/mod.rs +++ b/src/connector/src/sink/formatter/mod.rs @@ -29,8 +29,7 @@ pub use upsert::UpsertFormatter; use super::catalog::{SinkEncode, SinkFormat, SinkFormatDesc}; use super::encoder::template::TemplateEncoder; use super::encoder::{ - CustomProtoType, DateHandlingMode, KafkaConnectParams, TimeHandlingMode, - TimestamptzHandlingMode, + DateHandlingMode, KafkaConnectParams, TimeHandlingMode, TimestamptzHandlingMode, }; use super::redis::{KEY_FORMAT, VALUE_FORMAT}; use crate::sink::encoder::{ @@ -135,13 +134,8 @@ impl SinkFormatterImpl { None => ProtoHeader::None, Some(sid) => ProtoHeader::ConfluentSchemaRegistry(sid), }; - let val_encoder = ProtoEncoder::new( - schema, - None, - descriptor, - header, - CustomProtoType::None, - )?; + let val_encoder = + ProtoEncoder::new_with_default(schema, None, descriptor, header)?; let formatter = AppendOnlyFormatter::new(key_encoder, val_encoder); Ok(SinkFormatterImpl::AppendOnlyProto(formatter)) } From a1d6ac131a7080550fada37dcfa9349b5b7f2738 Mon Sep 17 00:00:00 2001 From: xxhZs Date: Tue, 2 Apr 2024 08:11:35 +0000 Subject: [PATCH 5/9] Fix "cargo-hakari" --- Cargo.lock | 14 +------------- src/workspace-hack/Cargo.toml | 8 ++------ 2 files changed, 3 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8a16737991477..0d7016d3b4a8a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4209,7 +4209,6 @@ checksum = "1657b4441c3403d9f7b3409e47575237dac27b1b5726df654a6ecbf92f0f7577" dependencies = [ "futures-core", "futures-sink", - "nanorand", "pin-project", "spin 0.9.8", ] @@ -6555,15 +6554,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "034a0ad7deebf0c2abcf2435950a6666c3c15ea9d8fad0c0f48efa8a7f843fed" -[[package]] -name = "nanorand" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" -dependencies = [ - "getrandom", -] - [[package]] name = "native-tls" version = "0.2.11" 
@@ -14514,7 +14504,6 @@ dependencies = [ "either", "fail", "flate2", - "flume 0.10.14", "frunk_core", "futures", "futures-channel", @@ -14525,7 +14514,6 @@ dependencies = [ "futures-task", "futures-util", "generic-array", - "getrandom", "google-cloud-googleapis", "governor", "hashbrown 0.13.2", @@ -14599,6 +14587,7 @@ dependencies = [ "sqlx-postgres", "sqlx-sqlite", "strum 0.25.0", + "strum 0.26.1", "subtle", "syn 1.0.109", "syn 2.0.57", @@ -14611,7 +14600,6 @@ dependencies = [ "tokio-stream", "tokio-util", "toml_datetime", - "toml_edit 0.19.15", "tonic 0.10.2", "tower", "tracing", diff --git a/src/workspace-hack/Cargo.toml b/src/workspace-hack/Cargo.toml index adca77b2886f7..9f03850f08cfc 100644 --- a/src/workspace-hack/Cargo.toml +++ b/src/workspace-hack/Cargo.toml @@ -49,7 +49,6 @@ digest = { version = "0.10", features = ["mac", "oid", "std"] } either = { version = "1", features = ["serde"] } fail = { version = "0.5", default-features = false, features = ["failpoints"] } flate2 = { version = "1", features = ["zlib"] } -flume = { version = "0.10" } frunk_core = { version = "0.4", default-features = false, features = ["std"] } futures = { version = "0.3" } futures-channel = { version = "0.3", features = ["sink"] } @@ -60,7 +59,6 @@ futures-sink = { version = "0.3" } futures-task = { version = "0.3" } futures-util = { version = "0.3", features = ["channel", "io", "sink"] } generic-array = { version = "0.14", default-features = false, features = ["more_lengths", "zeroize"] } -getrandom = { git = "https://github.com/madsim-rs/getrandom.git", rev = "e79a7ae", default-features = false, features = ["js", "rdrand", "std"] } google-cloud-googleapis = { version = "0.12", default-features = false, features = ["bigquery", "pubsub"] } governor = { version = "0.6", default-features = false, features = ["dashmap", "jitter", "std"] } hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["nightly", "raw"] } @@ -133,7 +131,8 @@ sqlx-core = { version = "0.7", features = ["_rt-tokio", "_tls-native-tls", "bigd sqlx-mysql = { version = "0.7", default-features = false, features = ["bigdecimal", "chrono", "json", "rust_decimal", "time", "uuid"] } sqlx-postgres = { version = "0.7", default-features = false, features = ["bigdecimal", "chrono", "json", "rust_decimal", "time", "uuid"] } sqlx-sqlite = { version = "0.7", default-features = false, features = ["chrono", "json", "time", "uuid"] } -strum = { version = "0.25", features = ["derive"] } +strum-2f80eeee3b1b6c7e = { package = "strum", version = "0.26", features = ["derive"] } +strum-2ffb4c3fe830441c = { package = "strum", version = "0.25", features = ["derive"] } subtle = { version = "2" } syn-dff4ba8e3ae991db = { package = "syn", version = "1", features = ["extra-traits", "full", "visit", "visit-mut"] } target-lexicon = { version = "0.12", features = ["std"] } @@ -144,7 +143,6 @@ tokio-postgres = { git = "https://github.com/madsim-rs/rust-postgres.git", rev = tokio-stream = { git = "https://github.com/madsim-rs/tokio.git", rev = "fe39bb8e", features = ["fs", "net"] } tokio-util = { version = "0.7", features = ["codec", "io"] } toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } -toml_edit = { version = "0.19", features = ["serde"] } tonic = { version = "0.10", features = ["gzip", "tls-webpki-roots"] } tower = { version = "0.4", features = ["balance", "buffer", "filter", "limit", "load-shed", "timeout", "util"] } tracing = { version = "0.1", features = ["log"] } @@ -174,7 +172,6 @@ digest = { 
version = "0.10", features = ["mac", "oid", "std"] } either = { version = "1", features = ["serde"] } frunk_core = { version = "0.4", default-features = false, features = ["std"] } generic-array = { version = "0.14", default-features = false, features = ["more_lengths", "zeroize"] } -getrandom = { git = "https://github.com/madsim-rs/getrandom.git", rev = "e79a7ae", default-features = false, features = ["js", "rdrand", "std"] } hashbrown-582f2526e08bb6a0 = { package = "hashbrown", version = "0.14", features = ["nightly", "raw"] } indexmap-f595c2ba2a3f28df = { package = "indexmap", version = "2", features = ["serde"] } itertools = { version = "0.11" } @@ -212,7 +209,6 @@ target-lexicon = { version = "0.12", features = ["std"] } time = { version = "0.3", features = ["local-offset", "macros", "serde-well-known"] } time-macros = { version = "0.2", default-features = false, features = ["formatting", "parsing", "serde"] } toml_datetime = { version = "0.6", default-features = false, features = ["serde"] } -toml_edit = { version = "0.19", features = ["serde"] } zeroize = { version = "1" } ### END HAKARI SECTION From 4605ae0a679262576bba1e1bdf9e670862ca0c9e Mon Sep 17 00:00:00 2001 From: xxhZs <1060434431@qq.com> Date: Tue, 9 Apr 2024 13:05:23 +0800 Subject: [PATCH 6/9] fix --- Cargo.lock | 2 +- src/connector/src/sink/big_query.rs | 2 +- src/connector/src/sink/encoder/proto.rs | 21 +-------------------- src/connector/src/sink/formatter/mod.rs | 3 +-- 4 files changed, 4 insertions(+), 24 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8fb0405fa50a4..bcd2f0971da2c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8516,7 +8516,7 @@ dependencies = [ "indoc", "libc", "memoffset", - "parking_lot 0.11.2", + "parking_lot 0.12.1", "portable-atomic", "pyo3-build-config", "pyo3-ffi", diff --git a/src/connector/src/sink/big_query.rs b/src/connector/src/sink/big_query.rs index 30f392b8268e8..f3933ff860294 100644 --- a/src/connector/src/sink/big_query.rs +++ b/src/connector/src/sink/big_query.rs @@ -366,7 +366,7 @@ impl BigQuerySinkWriter { &config.common.table )) })?; - let row_encoder = ProtoEncoder::new_with_bigquery( + let row_encoder = ProtoEncoder::new( schema.clone(), None, message_descriptor.clone(), diff --git a/src/connector/src/sink/encoder/proto.rs b/src/connector/src/sink/encoder/proto.rs index bca69ed91fd80..3f50b3d97ff26 100644 --- a/src/connector/src/sink/encoder/proto.rs +++ b/src/connector/src/sink/encoder/proto.rs @@ -74,24 +74,6 @@ impl ProtoEncoder { header, }) } - - pub fn new_with_default( - schema: Schema, - col_indices: Option>, - descriptor: MessageDescriptor, - header: ProtoHeader, - ) -> SinkResult { - Self::new(schema, col_indices, descriptor, header) - } - - pub fn new_with_bigquery( - schema: Schema, - col_indices: Option>, - descriptor: MessageDescriptor, - header: ProtoHeader, - ) -> SinkResult { - Self::new(schema, col_indices, descriptor, header) - } } pub struct ProtoEncoded { @@ -499,8 +481,7 @@ mod tests { ]); let encoder = - ProtoEncoder::new_with_default(schema, None, descriptor.clone(), ProtoHeader::None) - .unwrap(); + ProtoEncoder::new(schema, None, descriptor.clone(), ProtoHeader::None).unwrap(); let m = encoder.encode(row).unwrap(); let encoded: Vec = m.ser_to().unwrap(); assert_eq!( diff --git a/src/connector/src/sink/formatter/mod.rs b/src/connector/src/sink/formatter/mod.rs index d6be02c70aea3..d923d337a3ffb 100644 --- a/src/connector/src/sink/formatter/mod.rs +++ b/src/connector/src/sink/formatter/mod.rs @@ -134,8 +134,7 @@ impl SinkFormatterImpl { None => 
ProtoHeader::None, Some(sid) => ProtoHeader::ConfluentSchemaRegistry(sid), }; - let val_encoder = - ProtoEncoder::new_with_default(schema, None, descriptor, header)?; + let val_encoder = ProtoEncoder::new(schema, None, descriptor, header)?; let formatter = AppendOnlyFormatter::new(key_encoder, val_encoder); Ok(SinkFormatterImpl::AppendOnlyProto(formatter)) } From 11f2ffb25bfdbca934f7fff1b57b3b7d5d753bbd Mon Sep 17 00:00:00 2001 From: xxhZs <1060434431@qq.com> Date: Tue, 9 Apr 2024 15:59:04 +0800 Subject: [PATCH 7/9] save fix --- src/connector/src/sink/big_query.rs | 99 ++++++++++++++++++---------- src/connector/with_options_sink.yaml | 4 ++ 2 files changed, 67 insertions(+), 36 deletions(-) diff --git a/src/connector/src/sink/big_query.rs b/src/connector/src/sink/big_query.rs index f3933ff860294..43f9d908c27a9 100644 --- a/src/connector/src/sink/big_query.rs +++ b/src/connector/src/sink/big_query.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use core::mem; use core::time::Duration; use std::collections::HashMap; use std::sync::Arc; @@ -32,7 +33,7 @@ use google_cloud_googleapis::cloud::bigquery::storage::v1::{ }; use google_cloud_pubsub::client::google_cloud_auth; use google_cloud_pubsub::client::google_cloud_auth::credentials::CredentialsFile; -use prost_reflect::MessageDescriptor; +use prost_reflect::{FieldDescriptor, MessageDescriptor}; use prost_types::{ field_descriptor_proto, DescriptorProto, FieldDescriptorProto, FileDescriptorProto, FileDescriptorSet, @@ -42,7 +43,7 @@ use risingwave_common::buffer::Bitmap; use risingwave_common::catalog::Schema; use risingwave_common::types::DataType; use serde_derive::Deserialize; -use serde_with::serde_as; +use serde_with::{serde_as, DisplayFromStr}; use url::Url; use uuid::Uuid; use with_options::WithOptions; @@ -77,6 +78,9 @@ pub struct BigQueryCommon { pub dataset: String, #[serde(rename = "bigquery.table")] pub table: String, + #[serde(rename = "bigquery.max_batch_rows", default = "default_max_batch_rows")] + #[serde_as(as = "DisplayFromStr")] + pub max_batch_rows: usize, } fn default_max_batch_rows() -> usize { @@ -311,6 +315,9 @@ pub struct BigQuerySinkWriter { writer_pb_schema: ProtoSchema, message_descriptor: MessageDescriptor, write_stream: String, + proto_field: Option, + write_rows: Vec, + write_rows_count: usize, } impl TryFrom for BigQuerySink { @@ -366,6 +373,16 @@ impl BigQuerySinkWriter { &config.common.table )) })?; + let proto_field = if !is_append_only { + let proto_field = message_descriptor + .get_field_by_name(CHANGE_TYPE) + .ok_or_else(|| { + SinkError::BigQuery(anyhow::anyhow!("Can't find {}", CHANGE_TYPE)) + })?; + Some(proto_field) + } else { + None + }; let row_encoder = ProtoEncoder::new( schema.clone(), None, @@ -384,53 +401,45 @@ impl BigQuerySinkWriter { is_append_only, row_encoder, message_descriptor, + proto_field, writer_pb_schema: ProtoSchema { proto_descriptor: Some(descriptor_proto), }, + write_rows: vec![], + write_rows_count: 0, }) } - async fn append_only(&mut self, chunk: StreamChunk) -> Result<()> { + fn append_only(&mut self, chunk: StreamChunk) -> Result>> { let mut serialized_rows: Vec> = Vec::with_capacity(chunk.capacity()); for (op, row) in chunk.rows() { if op != Op::Insert { continue; } - serialized_rows.push(self.row_encoder.encode(row)?.ser_to()?) 
} - let rows = AppendRowsRequestRows::ProtoRows(ProtoData { - writer_schema: Some(self.writer_pb_schema.clone()), - rows: Some(ProtoRows { serialized_rows }), - }); - self.client - .append_rows(vec![rows], self.write_stream.clone()) - .await?; - Ok(()) + Ok(serialized_rows) } - async fn upsert(&mut self, chunk: StreamChunk) -> Result<()> { + fn upsert(&mut self, chunk: StreamChunk) -> Result>> { let mut serialized_rows: Vec> = Vec::with_capacity(chunk.capacity()); for (op, row) in chunk.rows() { + if op == Op::UpdateDelete { + continue; + } let mut pb_row = self.row_encoder.encode(row)?; - let proto_field = self - .message_descriptor - .get_field_by_name(CHANGE_TYPE) - .ok_or_else(|| { - SinkError::BigQuery(anyhow::anyhow!("Can't find {}", CHANGE_TYPE)) - })?; match op { Op::Insert => pb_row .message .try_set_field( - &proto_field, + self.proto_field.as_ref().unwrap(), prost_reflect::Value::String("INSERT".to_string()), ) .map_err(|e| SinkError::BigQuery(e.into()))?, Op::Delete => pb_row .message .try_set_field( - &proto_field, + self.proto_field.as_ref().unwrap(), prost_reflect::Value::String("DELETE".to_string()), ) .map_err(|e| SinkError::BigQuery(e.into()))?, @@ -438,7 +447,7 @@ impl BigQuerySinkWriter { Op::UpdateInsert => pb_row .message .try_set_field( - &proto_field, + self.proto_field.as_ref().unwrap(), prost_reflect::Value::String("UPSERT".to_string()), ) .map_err(|e| SinkError::BigQuery(e.into()))?, @@ -446,12 +455,17 @@ impl BigQuerySinkWriter { serialized_rows.push(pb_row.ser_to()?) } - let rows = AppendRowsRequestRows::ProtoRows(ProtoData { - writer_schema: Some(self.writer_pb_schema.clone()), - rows: Some(ProtoRows { serialized_rows }), - }); + Ok(serialized_rows) + } + + async fn write_rows(&mut self) -> Result<()> { + if self.write_rows.is_empty() { + return Ok(()); + } + let rows = mem::take(&mut self.write_rows); + self.write_rows_count = 0; self.client - .append_rows(vec![rows], self.write_stream.clone()) + .append_rows(rows, self.write_stream.clone()) .await?; Ok(()) } @@ -460,14 +474,27 @@ impl BigQuerySinkWriter { #[async_trait] impl SinkWriter for BigQuerySinkWriter { async fn write_batch(&mut self, chunk: StreamChunk) -> Result<()> { - if self.is_append_only { - self.append_only(chunk).await + let serialized_rows = if self.is_append_only { + self.append_only(chunk)? } else { - self.upsert(chunk).await + self.upsert(chunk)? + }; + self.write_rows_count += serialized_rows.len(); + let rows = AppendRowsRequestRows::ProtoRows(ProtoData { + writer_schema: Some(self.writer_pb_schema.clone()), + rows: Some(ProtoRows { serialized_rows }), + }); + self.write_rows.push(rows); + + if self.write_rows_count >= self.config.common.max_batch_rows { + self.write_rows().await?; } + + Ok(()) } async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> { + self.write_rows().await?; Ok(()) } @@ -521,27 +548,27 @@ impl StorageWriterClient { rows: Vec, write_stream: String, ) -> Result<()> { - let trace_id = Uuid::new_v4().hyphenated().to_string(); let append_req: Vec = rows .into_iter() .map(|row| AppendRowsRequest { write_stream: write_stream.clone(), offset: None, - trace_id: trace_id.clone(), + trace_id: Uuid::new_v4().hyphenated().to_string(), missing_value_interpretations: HashMap::default(), rows: Some(row), }) .collect(); - let resp = self + let mut resp = self .client .append_rows(Request::new(tokio_stream::iter(append_req))) .await .map_err(|e| SinkError::BigQuery(e.into()))? 
- .into_inner() + .into_inner(); + while let Some(i) = resp .message() .await - .map_err(|e| SinkError::BigQuery(e.into()))?; - if let Some(i) = resp { + .map_err(|e| SinkError::BigQuery(e.into()))? + { if !i.row_errors.is_empty() { return Err(SinkError::BigQuery(anyhow::anyhow!( "Insert error {:?}", diff --git a/src/connector/with_options_sink.yaml b/src/connector/with_options_sink.yaml index 9a4dcc25a0bcb..b287bcd6aa4b4 100644 --- a/src/connector/with_options_sink.yaml +++ b/src/connector/with_options_sink.yaml @@ -17,6 +17,10 @@ BigQueryConfig: - name: bigquery.table field_type: String required: true + - name: bigquery.max_batch_rows + field_type: usize + required: false + default: '1024' - name: region field_type: String required: false From bd5181bd8ae88c7343829c41475c87cf849b626f Mon Sep 17 00:00:00 2001 From: xxhZs <1060434431@qq.com> Date: Tue, 9 Apr 2024 18:59:15 +0800 Subject: [PATCH 8/9] fix --- src/connector/src/sink/big_query.rs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/connector/src/sink/big_query.rs b/src/connector/src/sink/big_query.rs index 43f9d908c27a9..b5a4a4c72d517 100644 --- a/src/connector/src/sink/big_query.rs +++ b/src/connector/src/sink/big_query.rs @@ -29,7 +29,7 @@ use google_cloud_googleapis::cloud::bigquery::storage::v1::append_rows_request:: ProtoData, Rows as AppendRowsRequestRows, }; use google_cloud_googleapis::cloud::bigquery::storage::v1::{ - AppendRowsRequest, ProtoRows, ProtoSchema, + AppendRowsRequest, ProtoRows, ProtoSchema }; use google_cloud_pubsub::client::google_cloud_auth; use google_cloud_pubsub::client::google_cloud_auth::credentials::CredentialsFile; @@ -494,7 +494,6 @@ impl SinkWriter for BigQuerySinkWriter { } async fn begin_epoch(&mut self, _epoch: u64) -> Result<()> { - self.write_rows().await?; Ok(()) } @@ -502,7 +501,10 @@ impl SinkWriter for BigQuerySinkWriter { Ok(()) } - async fn barrier(&mut self, _is_checkpoint: bool) -> Result<()> { + async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> { + if is_checkpoint{ + self.write_rows().await?; + } Ok(()) } @@ -548,6 +550,7 @@ impl StorageWriterClient { rows: Vec, write_stream: String, ) -> Result<()> { + let mut resp_count = rows.len(); let append_req: Vec = rows .into_iter() .map(|row| AppendRowsRequest { @@ -564,18 +567,20 @@ impl StorageWriterClient { .await .map_err(|e| SinkError::BigQuery(e.into()))? .into_inner(); - while let Some(i) = resp + while let Some(append_rows_response) = resp .message() .await .map_err(|e| SinkError::BigQuery(e.into()))? 
{ - if !i.row_errors.is_empty() { + resp_count -= 1; + if !append_rows_response.row_errors.is_empty() { return Err(SinkError::BigQuery(anyhow::anyhow!( "Insert error {:?}", - i.row_errors + append_rows_response.row_errors ))); } } + assert_eq!(resp_count,0,"bigquery sink insert error: the number of response inserted is not equal to the number of request"); Ok(()) } From 26e4cffd538b4db13ff592a625c7e32cdd7d7ef5 Mon Sep 17 00:00:00 2001 From: xxhZs <1060434431@qq.com> Date: Thu, 11 Apr 2024 14:10:01 +0800 Subject: [PATCH 9/9] fix upsert fmt --- src/connector/src/sink/big_query.rs | 31 +++++++++++++++++------------ 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/src/connector/src/sink/big_query.rs b/src/connector/src/sink/big_query.rs index b5a4a4c72d517..ee385ad8c010e 100644 --- a/src/connector/src/sink/big_query.rs +++ b/src/connector/src/sink/big_query.rs @@ -29,7 +29,7 @@ use google_cloud_googleapis::cloud::bigquery::storage::v1::append_rows_request:: ProtoData, Rows as AppendRowsRequestRows, }; use google_cloud_googleapis::cloud::bigquery::storage::v1::{ - AppendRowsRequest, ProtoRows, ProtoSchema + AppendRowsRequest, ProtoRows, ProtoSchema, }; use google_cloud_pubsub::client::google_cloud_auth; use google_cloud_pubsub::client::google_cloud_auth::credentials::CredentialsFile; @@ -269,6 +269,10 @@ impl Sink for BigQuerySink { } async fn validate(&self) -> Result<()> { + if !self.is_append_only && self.pk_indices.is_empty() { + return Err(SinkError::Config(anyhow!( + "Primary key not defined for upsert bigquery sink (please define in `primary_key` field)"))); + } let client = self .config .common @@ -433,7 +437,7 @@ impl BigQuerySinkWriter { .message .try_set_field( self.proto_field.as_ref().unwrap(), - prost_reflect::Value::String("INSERT".to_string()), + prost_reflect::Value::String("UPSERT".to_string()), ) .map_err(|e| SinkError::BigQuery(e.into()))?, Op::Delete => pb_row @@ -479,17 +483,18 @@ impl SinkWriter for BigQuerySinkWriter { } else { self.upsert(chunk)? }; - self.write_rows_count += serialized_rows.len(); - let rows = AppendRowsRequestRows::ProtoRows(ProtoData { - writer_schema: Some(self.writer_pb_schema.clone()), - rows: Some(ProtoRows { serialized_rows }), - }); - self.write_rows.push(rows); - - if self.write_rows_count >= self.config.common.max_batch_rows { - self.write_rows().await?; + if !serialized_rows.is_empty() { + self.write_rows_count += serialized_rows.len(); + let rows = AppendRowsRequestRows::ProtoRows(ProtoData { + writer_schema: Some(self.writer_pb_schema.clone()), + rows: Some(ProtoRows { serialized_rows }), + }); + self.write_rows.push(rows); + + if self.write_rows_count >= self.config.common.max_batch_rows { + self.write_rows().await?; + } } - Ok(()) } @@ -502,7 +507,7 @@ impl SinkWriter for BigQuerySinkWriter { } async fn barrier(&mut self, is_checkpoint: bool) -> Result<()> { - if is_checkpoint{ + if is_checkpoint { self.write_rows().await?; } Ok(())
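
Note on the final patches: taken together, patches 7 through 9 converge on a buffer-and-flush pattern for the BigQuery sink — serialized rows from each chunk are accumulated in memory and sent via AppendRows either when `bigquery.max_batch_rows` is reached or when a checkpoint barrier arrives, and the streaming responses are counted and checked for row errors. The sketch below restates only that control flow in isolation, as a hedged illustration; `Row`, `Client`, `BufferedWriter`, and the plain `String` error type are hypothetical stand-ins, not types from this patch series, and the real writer is async and speaks the BigQuery Storage Write API.

// Minimal sketch of the buffer-and-flush pattern, assuming hypothetical
// `Row`/`Client` types. Only the control flow mirrors the patch: flush when
// the buffer reaches `max_batch_rows` or when a checkpoint barrier arrives.
struct Row(Vec<u8>);

struct Client;

impl Client {
    fn append_rows(&self, batch: Vec<Row>) -> Result<(), String> {
        // A real implementation would issue the AppendRows RPC and check
        // every response for row errors before returning Ok.
        println!("appending {} rows", batch.len());
        Ok(())
    }
}

struct BufferedWriter {
    client: Client,
    max_batch_rows: usize,
    buffer: Vec<Row>,
}

impl BufferedWriter {
    fn write_batch(&mut self, rows: Vec<Row>) -> Result<(), String> {
        if rows.is_empty() {
            return Ok(()); // nothing to buffer for this chunk
        }
        self.buffer.extend(rows);
        if self.buffer.len() >= self.max_batch_rows {
            self.flush()?; // size-triggered flush
        }
        Ok(())
    }

    fn barrier(&mut self, is_checkpoint: bool) -> Result<(), String> {
        if is_checkpoint {
            self.flush()?; // persist buffered rows before the checkpoint completes
        }
        Ok(())
    }

    fn flush(&mut self) -> Result<(), String> {
        if self.buffer.is_empty() {
            return Ok(());
        }
        let batch = std::mem::take(&mut self.buffer);
        self.client.append_rows(batch)
    }
}

fn main() -> Result<(), String> {
    let mut writer = BufferedWriter { client: Client, max_batch_rows: 1024, buffer: Vec::new() };
    writer.write_batch(vec![Row(b"r1".to_vec()), Row(b"r2".to_vec())])?;
    writer.barrier(true) // checkpoint forces a flush of whatever remains buffered
}

Flushing on checkpoint barriers rather than on every chunk amortizes the cost of the AppendRows RPC while still ensuring that all buffered rows are durable in BigQuery before a checkpoint is acknowledged.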