Skip to content

Commit

Permalink
fix: Use generation id to bootstrap version id (#6029)
Browse files Browse the repository at this point in the history
* use generation id to bootstrap versions

* add event type for sync operation

* move event to response

* introduce a second stage of sync process after the scheduler is ready

* send the right resources based on the event type for the control plane

* update control plane test

* add test coverage

* update control plane test

* allow model progressing status update

* add test for generation id

* Update scheduler/pkg/store/memory.go

Co-authored-by: Lucian Carata <[email protected]>

---------

Co-authored-by: Lucian Carata <[email protected]>
  • Loading branch information
sakoush and lc525 authored Nov 6, 2024
1 parent 1ecc543 commit 7fc3402
Show file tree
Hide file tree
Showing 9 changed files with 563 additions and 315 deletions.
613 changes: 341 additions & 272 deletions apis/go/mlops/scheduler/scheduler.pb.go

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion apis/mlops/scheduler/scheduler.proto
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,12 @@ message ControlPlaneSubscriptionRequest {
}

message ControlPlaneResponse {

enum Event {
UNKNOWN_EVENT = 0;
SEND_SERVERS = 1; // initial sync for the servers
SEND_RESOURCES = 2; // send models / pipelines / experiments
}
Event event = 1;
}

// [END Messages]
Expand Down
47 changes: 25 additions & 22 deletions operator/scheduler/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,34 +142,37 @@ func (s *SchedulerClient) startEventHanders(namespace string, conn *grpc.ClientC
}()
}

func (s *SchedulerClient) handleStateOnReconnect(context context.Context, grpcClient scheduler.SchedulerClient, namespace string) error {
// on new reconnects we send a list of servers to the scheduler
err := s.handleRegisteredServers(context, grpcClient, namespace)
if err != nil {
s.logger.Error(err, "Failed to send registered server to scheduler")
}

if err == nil {
err = s.handleExperiments(context, grpcClient, namespace)
func (s *SchedulerClient) handleStateOnReconnect(context context.Context, grpcClient scheduler.SchedulerClient, namespace string, operation scheduler.ControlPlaneResponse_Event) error {
switch operation {
case scheduler.ControlPlaneResponse_SEND_SERVERS:
// on new reconnects we send a list of servers to the scheduler
err := s.handleRegisteredServers(context, grpcClient, namespace)
if err != nil {
s.logger.Error(err, "Failed to send experiments to scheduler")
s.logger.Error(err, "Failed to send registered server to scheduler")
}
}

if err == nil {
err = s.handlePipelines(context, grpcClient, namespace)
return err
case scheduler.ControlPlaneResponse_SEND_RESOURCES:
err := s.handleExperiments(context, grpcClient, namespace)
if err != nil {
s.logger.Error(err, "Failed to send pipelines to scheduler")
s.logger.Error(err, "Failed to send experiments to scheduler")
}
}

if err == nil {
err = s.handleModels(context, grpcClient, namespace)
if err != nil {
s.logger.Error(err, "Failed to send models to scheduler")
if err == nil {
err = s.handlePipelines(context, grpcClient, namespace)
if err != nil {
s.logger.Error(err, "Failed to send pipelines to scheduler")
}
}
if err == nil {
err = s.handleModels(context, grpcClient, namespace)
if err != nil {
s.logger.Error(err, "Failed to send models to scheduler")
}
}
return err
default:
s.logger.Info("Unknown operation", "operation", operation)
return fmt.Errorf("Unknown operation %v", operation)
}
return err
}

func (s *SchedulerClient) RemoveConnection(namespace string) {
Expand Down
2 changes: 1 addition & 1 deletion operator/scheduler/control_plane.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (s *SchedulerClient) SubscribeControlPlaneEvents(ctx context.Context, grpcC
logger.Info("Received event to handle state", "event", event)

fn := func() error {
return s.handleStateOnReconnect(ctx, grpcClient, namespace)
return s.handleStateOnReconnect(ctx, grpcClient, namespace, event.GetEvent())
}
_, err = execWithTimeout(fn, execTimeOut)
if err != nil {
Expand Down
12 changes: 8 additions & 4 deletions operator/scheduler/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func (s *mockSchedulerServerSubscribeGrpcClient) Recv() (*scheduler.ServerStatus
// Control Plane subscribe mock grpc client

type mockControlPlaneSubscribeGrpcClient struct {
sent bool
sent int
grpc.ClientStream
}

Expand All @@ -120,9 +120,13 @@ func newMockControlPlaneSubscribeGrpcClient() *mockControlPlaneSubscribeGrpcClie
}

func (s *mockControlPlaneSubscribeGrpcClient) Recv() (*scheduler.ControlPlaneResponse, error) {
if !s.sent {
s.sent = true
return &scheduler.ControlPlaneResponse{}, nil
if s.sent == 0 {
s.sent++
return &scheduler.ControlPlaneResponse{Event: scheduler.ControlPlaneResponse_SEND_SERVERS}, nil
}
if s.sent == 1 {
s.sent++
return &scheduler.ControlPlaneResponse{Event: scheduler.ControlPlaneResponse_SEND_RESOURCES}, nil
}
return nil, io.EOF
}
Expand Down
20 changes: 18 additions & 2 deletions scheduler/pkg/server/control_plane.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,15 @@ func (s *SchedulerServer) SubscribeControlPlane(req *pb.ControlPlaneSubscription
return err
}

fin := make(chan bool)
s.synchroniser.WaitReady()

err = s.sendResourcesMarker(stream)
if err != nil {
logger.WithError(err).Errorf("Failed to send resources marker to %s", req.GetSubscriberName())
return err
}

fin := make(chan bool)
s.controlPlaneStream.mu.Lock()
s.controlPlaneStream.streams[stream] = &ControlPlaneSubsription{
name: req.GetSubscriberName(),
Expand Down Expand Up @@ -61,11 +68,20 @@ func (s *SchedulerServer) StopSendControlPlaneEvents() {
// this is to mark the initial start of a new stream (at application level)
// as otherwise the other side sometimes doesn't know if the scheduler has established a new stream explicitly
func (s *SchedulerServer) sendStartServerStreamMarker(stream pb.Scheduler_SubscribeControlPlaneServer) error {
ssr := &pb.ControlPlaneResponse{}
ssr := &pb.ControlPlaneResponse{Event: pb.ControlPlaneResponse_SEND_SERVERS}
_, err := sendWithTimeout(func() error { return stream.Send(ssr) }, s.timeout)
if err != nil {
return err
}
return nil
}

// this is to mark a stage to send resources (models, pipelines, experiments) from the controller
func (s *SchedulerServer) sendResourcesMarker(stream pb.Scheduler_SubscribeControlPlaneServer) error {
ssr := &pb.ControlPlaneResponse{Event: pb.ControlPlaneResponse_SEND_RESOURCES}
_, err := sendWithTimeout(func() error { return stream.Send(ssr) }, s.timeout)
if err != nil {
return err
}
return nil
}
91 changes: 91 additions & 0 deletions scheduler/pkg/server/control_plane_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,22 @@ the Change License after the Change Date as each is defined in accordance with t
package server

import (
"context"
"fmt"
"testing"
"time"

. "github.com/onsi/gomega"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"

pb "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler"

"github.com/seldonio/seldon-core/scheduler/v2/pkg/coordinator"
"github.com/seldonio/seldon-core/scheduler/v2/pkg/internal/testing_utils"
"github.com/seldonio/seldon-core/scheduler/v2/pkg/store"
"github.com/seldonio/seldon-core/scheduler/v2/pkg/synchroniser"
)

func TestStartServerStream(t *testing.T) {
Expand Down Expand Up @@ -67,7 +74,91 @@ func TestStartServerStream(t *testing.T) {
}

g.Expect(msr).ToNot(BeNil())
g.Expect(msr.Event).To(Equal(pb.ControlPlaneResponse_SEND_SERVERS))
}

err = test.server.sendResourcesMarker(stream)
if test.err {
g.Expect(err).ToNot(BeNil())
} else {
g.Expect(err).To(BeNil())

var msr *pb.ControlPlaneResponse
select {
case next := <-stream.msgs:
msr = next
default:
t.Fail()
}

g.Expect(msr).ToNot(BeNil())
g.Expect(msr.Event).To(Equal(pb.ControlPlaneResponse_SEND_RESOURCES))
}
})
}
}

func TestSubscribeControlPlane(t *testing.T) {
log.SetLevel(log.DebugLevel)
g := NewGomegaWithT(t)

type test struct {
name string
}
tests := []test{
{
name: "simple",
},
}

getStream := func(context context.Context, port int) (*grpc.ClientConn, pb.Scheduler_SubscribeControlPlaneClient) {
conn, _ := grpc.NewClient(fmt.Sprintf(":%d", port), grpc.WithTransportCredentials(insecure.NewCredentials()))
grpcClient := pb.NewSchedulerClient(conn)
client, _ := grpcClient.SubscribeControlPlane(
context,
&pb.ControlPlaneSubscriptionRequest{SubscriberName: "dummy"},
)
return conn, client
}

createTestScheduler := func() *SchedulerServer {
logger := log.New()
logger.SetLevel(log.WarnLevel)

eventHub, err := coordinator.NewEventHub(logger)
g.Expect(err).To(BeNil())

sync := synchroniser.NewSimpleSynchroniser(time.Duration(10 * time.Millisecond))

s := NewSchedulerServer(logger, nil, nil, nil, nil, eventHub, sync)
sync.Signals(1)

return s
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
server := createTestScheduler()
port, err := testing_utils.GetFreePortForTest()
if err != nil {
t.Fatal(err)
}
err = server.startServer(uint(port), false)
if err != nil {
t.Fatal(err)
}
time.Sleep(100 * time.Millisecond)

conn, client := getStream(context.Background(), port)

msg, _ := client.Recv()
g.Expect(msg.GetEvent()).To(Equal(pb.ControlPlaneResponse_SEND_SERVERS))

msg, _ = client.Recv()
g.Expect(msg.Event).To(Equal(pb.ControlPlaneResponse_SEND_RESOURCES))

conn.Close()
server.StopSendControlPlaneEvents()
})
}
}
11 changes: 8 additions & 3 deletions scheduler/pkg/store/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ func (m *MemoryStore) addModelVersionIfNotExists(req *agent.ModelVersion) (*Mode
}

func (m *MemoryStore) addNextModelVersion(model *Model, pbmodel *pb.Model) {
version := uint32(1)
// if we start from a clean state, let's use the generation id as the starting version
// this is to ensure that we have monotonic increasing version numbers
// and we never reset back to 1
generation := pbmodel.GetMeta().GetKubernetesMeta().GetGeneration()
version := max(uint32(1), uint32(generation))
if model.Latest() != nil {
version = model.Latest().GetVersion() + 1
}
Expand Down Expand Up @@ -329,7 +333,7 @@ func (m *MemoryStore) updateLoadedModelsImpl(
modelVersion = model.Latest()
}

// resevere memory for existing replicas that are not already loading or loaded
// reserve memory for existing replicas that are not already loading or loaded
replicaStateUpdated := false
for replicaIdx := range assignedReplicaIds {
if existingState, ok := modelVersion.replicas[replicaIdx]; !ok {
Expand Down Expand Up @@ -370,7 +374,8 @@ func (m *MemoryStore) updateLoadedModelsImpl(
// in cases where we did have a previous ScheduleFailed, we need to reflect the change here
// this could be in the cases where we are scaling down a model and the new replica count can be all deployed
// and always send an update for deleted models, so the operator will remove them from k8s
if replicaStateUpdated || modelVersion.state.State == ScheduleFailed || model.IsDeleted() {
// also send an update for progressing models so the operator can update the status in the case of a network glitch where the model generation has been updated
if replicaStateUpdated || modelVersion.state.State == ScheduleFailed || model.IsDeleted() || modelVersion.state.State == ModelProgressing {
logger.Debugf("Updating model status for model %s server %s", modelKey, serverKey)
modelVersion.server = serverKey
m.updateModelStatus(true, model.IsDeleted(), modelVersion, model.GetLastAvailableModelVersion())
Expand Down
Loading

0 comments on commit 7fc3402

Please sign in to comment.