Skip to content

Commit

Permalink
minor code refactoring for join pipelines (#17738)
Browse files Browse the repository at this point in the history
minor code refactoring for join pipelines

Approved by: @m-schen
  • Loading branch information
badboynt1 authored Jul 26, 2024
1 parent 0cbffc3 commit 8f45ce6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 18 deletions.
40 changes: 25 additions & 15 deletions pkg/sql/compile/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -3793,7 +3793,7 @@ func (c *Compile) newShuffleJoinScopeList(left, right []*Scope, n *plan.Node) ([
return parent, children
}

func (c *Compile) newJoinProbeScope(s *Scope, ss []*Scope) *Scope {
func (c *Compile) newShuffleJoinProbeScope(s *Scope) *Scope {
rs := newScope(Merge)
mergeOp := merge.NewArgument()
mergeOp.SetIdx(vm.GetLeafOp(s.RootOp).GetOperatorBase().GetIdx())
Expand All @@ -3804,26 +3804,36 @@ func (c *Compile) newJoinProbeScope(s *Scope, ss []*Scope) *Scope {
regTransplant(s, rs, i, i)
}

if ss == nil {
s.Proc.Reg.MergeReceivers[0] = &process.WaitRegister{
Ctx: s.Proc.Ctx,
Ch: make(chan *process.RegisterMessage, shuffleChannelBufferSize),
}
rs.setRootOperator(
connector.NewArgument().
WithReg(s.Proc.Reg.MergeReceivers[0]),
)
s.Proc.Reg.MergeReceivers = append(s.Proc.Reg.MergeReceivers[:1], s.Proc.Reg.MergeReceivers[s.BuildIdx:]...)
s.BuildIdx = 1
} else {
rs.setRootOperator(constructDispatchLocal(false, false, false, extraRegisters(ss, 0)))
s.Proc.Reg.MergeReceivers[0] = &process.WaitRegister{
Ctx: s.Proc.Ctx,
Ch: make(chan *process.RegisterMessage, shuffleChannelBufferSize),
}
rs.setRootOperator(
connector.NewArgument().
WithReg(s.Proc.Reg.MergeReceivers[0]),
)
s.Proc.Reg.MergeReceivers = s.Proc.Reg.MergeReceivers[:1]
rs.IsEnd = true
return rs
}

func (c *Compile) newBroadcastJoinProbeScope(s *Scope, ss []*Scope) *Scope {
rs := newScope(Merge)
mergeOp := merge.NewArgument()
mergeOp.SetIdx(vm.GetLeafOp(s.RootOp).GetOperatorBase().GetIdx())
mergeOp.SetIsFirst(true)
rs.setRootOperator(mergeOp)
rs.Proc = process.NewFromProc(s.Proc, s.Proc.Ctx, s.BuildIdx)
for i := 0; i < s.BuildIdx; i++ {
regTransplant(s, rs, i, i)
}

rs.setRootOperator(constructDispatchLocal(false, false, false, extraRegisters(ss, 0)))
rs.IsEnd = true
return rs
}

func (c *Compile) newJoinBuildScope(s *Scope, ss []*Scope, mcpu int32) *Scope {
func (c *Compile) newJoinBuildScope(s *Scope, mcpu int32) *Scope {
rs := newScope(Merge)
buildLen := len(s.Proc.Reg.MergeReceivers) - s.BuildIdx
rs.Proc = process.NewFromProc(s.Proc, s.Proc.Ctx, buildLen)
Expand Down
6 changes: 3 additions & 3 deletions pkg/sql/compile/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -460,10 +460,10 @@ func buildJoinParallelRun(s *Scope, c *Compile) (*Scope, error) {
}
mcpu := s.NodeInfo.Mcpu
if mcpu <= 1 { // no need to parallel
buildScope := c.newJoinBuildScope(s, nil, 1)
buildScope := c.newJoinBuildScope(s, 1)
s.PreScopes = append(s.PreScopes, buildScope)
if s.BuildIdx > 1 {
probeScope := c.newJoinProbeScope(s, nil)
probeScope := c.newShuffleJoinProbeScope(s)
s.PreScopes = append(s.PreScopes, probeScope)
}
return s, nil
Expand All @@ -483,7 +483,7 @@ func buildJoinParallelRun(s *Scope, c *Compile) (*Scope, error) {
ss[i].Proc = process.NewFromProc(s.Proc, s.Proc.Ctx, 2)
ss[i].Proc.Reg.MergeReceivers[1].Ch = make(chan *process.RegisterMessage, 10)
}
probeScope, buildScope := c.newJoinProbeScope(s, ss), c.newJoinBuildScope(s, ss, int32(mcpu))
probeScope, buildScope := c.newBroadcastJoinProbeScope(s, ss), c.newJoinBuildScope(s, int32(mcpu))

ns, err := newParallelScope(c, s, ss)
if err != nil {
Expand Down

0 comments on commit 8f45ce6

Please sign in to comment.