From 6ec975a47062311bc496e9e770835ce214ac10d8 Mon Sep 17 00:00:00 2001
From: stonepage <40830455+st1page@users.noreply.github.com>
Date: Thu, 23 Nov 2023 14:01:18 +0800
Subject: [PATCH] fix(sink): force sink shuffle with the sink pk (#13516)

---
 .../tests/testdata/input/sink.yaml          | 22 +++++++++++++++
 .../tests/testdata/output/sink.yaml         | 28 +++++++++++++++++++
 .../src/optimizer/plan_node/stream_sink.rs  | 14 +++++++++-
 3 files changed, 63 insertions(+), 1 deletion(-)

diff --git a/src/frontend/planner_test/tests/testdata/input/sink.yaml b/src/frontend/planner_test/tests/testdata/input/sink.yaml
index 524b74b1a592d..eb865bedbb79a 100644
--- a/src/frontend/planner_test/tests/testdata/input/sink.yaml
+++ b/src/frontend/planner_test/tests/testdata/input/sink.yaml
@@ -31,3 +31,25 @@
       force_append_only='true');
   expected_outputs:
   - explain_output
+- id: create_upsert_kafka_sink_with_downstream_pk1
+  sql: |
+    create table t1 (v1 int, v2 double precision, v3 varchar, v4 bigint, v5 decimal, primary key (v3,v4));
+    explain create sink s1_mysql as select v1, v2, v3, v5 from t1 WITH (
+      connector='kafka',
+      topic='abc',
+      type='upsert',
+      primary_key='v1,v2'
+    );
+  expected_outputs:
+  - explain_output
+- id: downstream_pk_same_with_upstream
+  sql: |
+    create table t1 (v1 int, v2 double precision, v3 varchar, v4 bigint, v5 decimal, primary key (v3,v4));
+    explain create sink s1_mysql as select v2, v1, count(*) from t1 group by v1, v2 WITH (
+      connector='kafka',
+      topic='abc',
+      type='upsert',
+      primary_key='v2,v1'
+    );
+  expected_outputs:
+  - explain_output
diff --git a/src/frontend/planner_test/tests/testdata/output/sink.yaml b/src/frontend/planner_test/tests/testdata/output/sink.yaml
index edda2d08910a1..6956da028a008 100644
--- a/src/frontend/planner_test/tests/testdata/output/sink.yaml
+++ b/src/frontend/planner_test/tests/testdata/output/sink.yaml
@@ -37,3 +37,31 @@
   explain_output: |
     StreamSink { type: append-only, columns: [v1, v2, v3, v5] }
     └─StreamTableScan { table: t1, columns: [v1, v2, v3, v5] }
+- id: create_upsert_kafka_sink_with_downstream_pk1
+  sql: |
+    create table t1 (v1 int, v2 double precision, v3 varchar, v4 bigint, v5 decimal, primary key (v3,v4));
+    explain create sink s1_mysql as select v1, v2, v3, v5 from t1 WITH (
+      connector='kafka',
+      topic='abc',
+      type='upsert',
+      primary_key='v1,v2'
+    );
+  explain_output: |
+    StreamSink { type: upsert, columns: [v1, v2, v3, v5, t1.v4(hidden)], pk: [t1.v3, t1.v4] }
+    └─StreamExchange { dist: HashShard(t1.v1, t1.v2) }
+      └─StreamTableScan { table: t1, columns: [v1, v2, v3, v5, v4] }
+- id: downstream_pk_same_with_upstream
+  sql: |
+    create table t1 (v1 int, v2 double precision, v3 varchar, v4 bigint, v5 decimal, primary key (v3,v4));
+    explain create sink s1_mysql as select v2, v1, count(*) from t1 group by v1, v2 WITH (
+      connector='kafka',
+      topic='abc',
+      type='upsert',
+      primary_key='v2,v1'
+    );
+  explain_output: |
+    StreamSink { type: upsert, columns: [v2, v1, count], pk: [t1.v1, t1.v2] }
+    └─StreamProject { exprs: [t1.v2, t1.v1, count] }
+      └─StreamHashAgg { group_key: [t1.v1, t1.v2], aggs: [count] }
+        └─StreamExchange { dist: HashShard(t1.v1, t1.v2) }
+          └─StreamTableScan { table: t1, columns: [v1, v2, v3, v4] }
diff --git a/src/frontend/src/optimizer/plan_node/stream_sink.rs b/src/frontend/src/optimizer/plan_node/stream_sink.rs
index 324e0d21648cf..b54019fc7dbb6 100644
--- a/src/frontend/src/optimizer/plan_node/stream_sink.rs
+++ b/src/frontend/src/optimizer/plan_node/stream_sink.rs
@@ -170,7 +170,19 @@ impl StreamSink {
             }
             _ => {
                 assert_matches!(user_distributed_by, RequiredDist::Any);
-                RequiredDist::shard_by_key(input.schema().len(), input.expect_stream_key())
+                if downstream_pk.is_empty() {
+                    RequiredDist::shard_by_key(
+                        input.schema().len(),
+                        input.expect_stream_key(),
+                    )
+                } else {
+                    // Force rows with the same downstream primary key into the same sink shard, so that the sink executor's pk-mismatch compaction stays effective:
+                    // https://github.com/risingwavelabs/risingwave/blob/6d88344c286f250ea8a7e7ef6b9d74dea838269e/src/stream/src/executor/sink.rs#L169-L198
+                    RequiredDist::shard_by_key(
+                        input.schema().len(),
+                        downstream_pk.as_slice(),
+                    )
+                }
             }
         }
     }
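
Reviewer note: below is a minimal, self-contained sketch of the distribution
rule this patch introduces. The enum `RequiredDist` and the helper
`sink_required_dist` are hypothetical stand-ins rather than the planner's real
types; only the selection logic mirrors the change in stream_sink.rs above.

    // Hypothetical stand-in for the planner's distribution requirement.
    #[derive(Debug, PartialEq)]
    enum RequiredDist {
        // Hash-shard rows by the given column indices.
        ShardByKey(Vec<usize>),
    }

    // Mirrors the new rule: shard by the user-declared downstream pk when one
    // is given, and fall back to the stream key otherwise (the old behavior).
    fn sink_required_dist(stream_key: &[usize], downstream_pk: &[usize]) -> RequiredDist {
        if downstream_pk.is_empty() {
            RequiredDist::ShardByKey(stream_key.to_vec())
        } else {
            // Rows that share a downstream pk must land on the same shard,
            // or per-shard pk-mismatch compaction cannot see them together.
            RequiredDist::ShardByKey(downstream_pk.to_vec())
        }
    }

    fn main() {
        // create_upsert_kafka_sink_with_downstream_pk1: the scan emits
        // [v1, v2, v3, v5, v4], so the stream key (v3, v4) sits at [2, 4] and
        // the downstream pk (v1, v2) at [0, 1] -- hence HashShard(t1.v1, t1.v2).
        assert_eq!(
            sink_required_dist(&[2, 4], &[0, 1]),
            RequiredDist::ShardByKey(vec![0, 1])
        );
        // Without a downstream pk, keep sharding by the stream key.
        assert_eq!(
            sink_required_dist(&[2, 4], &[]),
            RequiredDist::ShardByKey(vec![2, 4])
        );
    }

Preferring the downstream pk over the stream key is what guarantees that all
updates for one downstream key reach the same sink shard, which the per-shard
pk-mismatch compaction in the sink executor relies on.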