Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(scheduler): mms send scaling request when model shceduling fails #6235

Draft
wants to merge 5 commits into
base: v2
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions scheduler/pkg/coordinator/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@ package coordinator

import "fmt"

type ModelEventUpdateContext int

const (
MODEL_STATUS_UPDATE ModelEventUpdateContext = iota
MODEL_SCHEDULE_FAILED
)

type ServerEventUpdateContext int

const (
Expand All @@ -19,8 +26,9 @@ const (
)

type ModelEventMsg struct {
ModelName string
ModelVersion uint32
ModelName string
ModelVersion uint32
UpdateContext ModelEventUpdateContext
}

func (m ModelEventMsg) String() string {
Expand Down
6 changes: 4 additions & 2 deletions scheduler/pkg/scheduler/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ func DefaultSchedulerConfig(store store.ModelStore) SchedulerConfig {
func NewSimpleScheduler(logger log.FieldLogger,
store store.ModelStore,
schedulerConfig SchedulerConfig,
synchroniser synchroniser.Synchroniser) *SimpleScheduler {
synchroniser synchroniser.Synchroniser,
) *SimpleScheduler {
s := &SimpleScheduler{
store: store,
logger: logger.WithField("Name", "SimpleScheduler"),
Expand Down Expand Up @@ -190,6 +191,7 @@ func (s *SimpleScheduler) scheduleToServer(modelName string) error {
if okWithMinReplicas {
msg := "Failed to schedule model as no matching server had enough suitable replicas, managed to schedule with min replicas"
logger.Warn(msg)
s.store.PartiallyScheduled(latestModel, msg, !latestModel.HasLiveReplicas() && !latestModel.IsLoadingOrLoadedOnServer())
}
}

Expand All @@ -205,7 +207,7 @@ func (s *SimpleScheduler) scheduleToServer(modelName string) error {
return errors.New(msg)
}

//TODO Cleanup previous version if needed?
// TODO Cleanup previous version if needed?
return nil
}

Expand Down
10 changes: 4 additions & 6 deletions scheduler/pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ const (
sendTimeout = 30 * time.Second // Timeout for sending events to subscribers via grpc `sendMsg`
)

var (
ErrAddServerEmptyServerName = status.Errorf(codes.FailedPrecondition, "Empty server name passed")
)
var ErrAddServerEmptyServerName = status.Errorf(codes.FailedPrecondition, "Empty server name passed")

type SchedulerServer struct {
pb.UnimplementedSchedulerServer
Expand Down Expand Up @@ -445,7 +443,7 @@ func (s *SchedulerServer) ServerStatus(
}

for _, s := range servers {
resp := createServerStatusResponse(s)
resp := createServerStatusUpdateResponse(s)
err := stream.Send(resp)
if err != nil {
return status.Errorf(codes.Internal, err.Error())
Expand All @@ -458,7 +456,7 @@ func (s *SchedulerServer) ServerStatus(
if err != nil {
return status.Errorf(codes.FailedPrecondition, err.Error())
}
resp := createServerStatusResponse(server)
resp := createServerStatusUpdateResponse(server)
err = stream.Send(resp)
if err != nil {
return status.Errorf(codes.Internal, err.Error())
Expand All @@ -467,7 +465,7 @@ func (s *SchedulerServer) ServerStatus(
}
}

func createServerStatusResponse(s *store.ServerSnapshot) *pb.ServerStatusResponse {
func createServerStatusUpdateResponse(s *store.ServerSnapshot) *pb.ServerStatusResponse {
// note we dont count draining replicas in available replicas

resp := &pb.ServerStatusResponse{
Expand Down
81 changes: 60 additions & 21 deletions scheduler/pkg/server/server_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@ package server
import (
"time"

"github.com/pkg/errors"

pb "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler"

"github.com/seldonio/seldon-core/scheduler/v2/pkg/coordinator"
"github.com/seldonio/seldon-core/scheduler/v2/pkg/store"
)

func (s *SchedulerServer) SubscribeModelStatus(req *pb.ModelSubscriptionRequest, stream pb.Scheduler_SubscribeModelStatusServer) error {
Expand Down Expand Up @@ -164,7 +167,7 @@ func (s *SchedulerServer) SubscribeServerStatus(req *pb.ServerSubscriptionReques
}

func (s *SchedulerServer) handleModelEventForServerStatus(event coordinator.ModelEventMsg) {
logger := s.logger.WithField("func", "handleServerEvent")
logger := s.logger.WithField("func", "handleModelEventForServerStatus")
logger.Debugf("Got server state change for %s", event.String())

err := s.updateServerModelsStatus(event)
Expand All @@ -182,7 +185,7 @@ func (s *SchedulerServer) StopSendServerEvents() {
}

func (s *SchedulerServer) updateServerModelsStatus(evt coordinator.ModelEventMsg) error {
logger := s.logger.WithField("func", "sendServerStatusEvent")
logger := s.logger.WithField("func", "updateServerModelStatus")

model, err := s.modelStore.GetModel(evt.ModelName)
if err != nil {
Expand All @@ -198,14 +201,47 @@ func (s *SchedulerServer) updateServerModelsStatus(evt coordinator.ModelEventMsg
return nil
}

s.serverEventStream.pendingLock.Lock()
// we are coalescing events so we only send one event (the latest status) per server
s.serverEventStream.pendingEvents[modelVersion.Server()] = struct{}{}
if s.serverEventStream.trigger == nil {
s.serverEventStream.trigger = time.AfterFunc(defaultBatchWait, s.sendServerStatus)
switch evt.UpdateContext {
case coordinator.MODEL_SCHEDULE_FAILED:
err = s.incrementExpectedReplicas(model, evt)
case coordinator.MODEL_STATUS_UPDATE:
s.serverEventStream.pendingLock.Lock()
// we are coalescing events so we only send one event (the latest status) per server
s.serverEventStream.pendingEvents[modelVersion.Server()] = struct{}{}
if s.serverEventStream.trigger == nil {
s.serverEventStream.trigger = time.AfterFunc(defaultBatchWait, s.sendServerStatus)
}
s.serverEventStream.pendingLock.Unlock()
default:
err = errors.Errorf("unknown update context received: %d", evt.UpdateContext)
}
s.serverEventStream.pendingLock.Unlock()

return err
}

func (s *SchedulerServer) incrementExpectedReplicas(model *store.ModelSnapshot, evt coordinator.ModelEventMsg) error {
// TODO: should there be some sort of velocity check ?
logger := s.logger.WithField("func", "incrementExpectedReplicas")
latestModel := model.GetLatest()
logger.Debugf("will attempt to scale servers to %d for %s", latestModel.DesiredReplicas(), evt.String())
if latestModel != nil && latestModel.GetVersion() == evt.ModelVersion &&
latestModel.DesiredReplicas() > int(latestModel.ModelState().AvailableReplicas) {
server, err := s.modelStore.GetServer(latestModel.Server(), true, true)
if err != nil {
return err
}
newExpectedReplicas := server.ExpectedReplicas + 1
if newExpectedReplicas < latestModel.DesiredReplicas() {
newExpectedReplicas = latestModel.DesiredReplicas()
}
ssr := createServerStatusUpdateResponse(server)
ssr.ExpectedReplicas = int32(newExpectedReplicas)
ssr.Type = pb.ServerStatusResponse_ScalingRequest
s.sendServerStatusResponse(ssr)

} else {
logger.Debugf("skipping scaling request event %s", evt.String())
}
return nil
}

Expand All @@ -228,21 +264,24 @@ func (s *SchedulerServer) sendServerStatus() {
logger.Errorf("Failed to get server %s", serverName)
continue
}
ssr := createServerStatusResponse(server)
ssr := createServerStatusUpdateResponse(server)
s.sendServerStatusResponse(ssr)
}
}

for stream, subscription := range s.serverEventStream.streams {
hasExpired, err := sendWithTimeout(func() error { return stream.Send(ssr) }, s.timeout)
if hasExpired {
// this should trigger a reconnect from the client
close(subscription.fin)
delete(s.serverEventStream.streams, stream)
}
if err != nil {
logger.WithError(err).Errorf("Failed to send server status event to %s", subscription.name)
}
func (s *SchedulerServer) sendServerStatusResponse(ssr *pb.ServerStatusResponse) {
logger := s.logger.WithField("func", "sendServerStatusResponse")
for stream, subscription := range s.serverEventStream.streams {
hasExpired, err := sendWithTimeout(func() error { return stream.Send(ssr) }, s.timeout)
if hasExpired {
// this should trigger a reconnect from the client
close(subscription.fin)
delete(s.serverEventStream.streams, stream)
}
if err != nil {
logger.WithError(err).Errorf("Failed to send server status response to %s", subscription.name)
}
}

}

// initial send of server statuses to a new controller
Expand All @@ -252,7 +291,7 @@ func (s *SchedulerServer) sendCurrentServerStatuses(stream pb.Scheduler_ServerSt
return err
}
for _, server := range servers {
ssr := createServerStatusResponse(server)
ssr := createServerStatusUpdateResponse(server)
_, err := sendWithTimeout(func() error { return stream.Send(ssr) }, s.timeout)
if err != nil {
return err
Expand Down
57 changes: 47 additions & 10 deletions scheduler/pkg/server/server_status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ func TestModelsStatusEvents(t *testing.T) {
}
g.Expect(s.modelEventStream.streams[stream]).ToNot(BeNil())
hub.PublishModelEvent(modelEventHandlerName, coordinator.ModelEventMsg{
ModelName: "foo", ModelVersion: 1})
ModelName: "foo", ModelVersion: 1,
})

// to allow events to propagate
time.Sleep(500 * time.Millisecond)
Expand Down Expand Up @@ -340,10 +341,12 @@ func TestServersStatusEvents(t *testing.T) {
g := NewGomegaWithT(t)

type test struct {
name string
loadReq *pba.AgentSubscribeRequest
timeout time.Duration
err bool
name string
loadReq *pba.AgentSubscribeRequest
timeout time.Duration
desiredModelReplicas uint32
updateContext coordinator.ModelEventUpdateContext
err bool
}

tests := []test{
Expand All @@ -353,7 +356,8 @@ func TestServersStatusEvents(t *testing.T) {
ServerName: "foo",
},
timeout: 10 * time.Millisecond,
err: false,

err: false,
},
{
name: "timeout",
Expand All @@ -363,6 +367,26 @@ func TestServersStatusEvents(t *testing.T) {
timeout: 1 * time.Millisecond,
err: true,
},
{
name: "schedule failed",
loadReq: &pba.AgentSubscribeRequest{
ServerName: "foo",
},
timeout: 1 * time.Millisecond,
desiredModelReplicas: 1,
updateContext: coordinator.MODEL_SCHEDULE_FAILED,
err: false,
},
{
name: "schedule failed - desired replicas matches available replicas",
loadReq: &pba.AgentSubscribeRequest{
ServerName: "foo",
},
timeout: 1 * time.Millisecond,
desiredModelReplicas: 0, // no replicas become availble
updateContext: coordinator.MODEL_SCHEDULE_FAILED,
err: true,
},
}

for _, test := range tests {
Expand All @@ -374,7 +398,8 @@ func TestServersStatusEvents(t *testing.T) {
g.Expect(err).To(BeNil())
err = s.modelStore.UpdateModel(&pb.LoadModelRequest{
Model: &pb.Model{
Meta: &pb.MetaData{Name: "foo"},
Meta: &pb.MetaData{Name: "foo"},
DeploymentSpec: &pb.DeploymentSpec{Replicas: test.desiredModelReplicas},
},
})
g.Expect(err).To(BeNil())
Expand All @@ -394,15 +419,19 @@ func TestServersStatusEvents(t *testing.T) {
}
g.Expect(s.serverEventStream.streams[stream]).ToNot(BeNil())
hub.PublishModelEvent(serverModelEventHandlerName, coordinator.ModelEventMsg{
ModelName: "foo", ModelVersion: 1})
ModelName: "foo", ModelVersion: 1, UpdateContext: test.updateContext,
})

// to allow events to propagate
time.Sleep(500 * time.Millisecond)

if test.err {
g.Expect(s.serverEventStream.streams).To(HaveLen(0))
} else if test.err && test.updateContext == coordinator.MODEL_SCHEDULE_FAILED {
// no scaling requests are sent for models in the desired state
g.Expect(stream.msgs).To(HaveLen(0))
g.Expect(s.serverEventStream.streams).To(HaveLen(0))
} else {

var ssr *pb.ServerStatusResponse
select {
case next := <-stream.msgs:
Expand All @@ -413,7 +442,15 @@ func TestServersStatusEvents(t *testing.T) {

g.Expect(ssr).ToNot(BeNil())
g.Expect(ssr.ServerName).To(Equal("foo"))
g.Expect(s.serverEventStream.streams).To(HaveLen(1))

if test.updateContext == coordinator.MODEL_SCHEDULE_FAILED {
// server events are not coalesced for scaling request
g.Expect(s.serverEventStream.streams).To(HaveLen(0))
g.Expect(ssr.Type).To(Equal(pb.ServerStatusResponse_ScalingRequest))
} else {
g.Expect(s.serverEventStream.streams).To(HaveLen(1))
g.Expect(ssr.Type).To(Equal(pb.ServerStatusResponse_StatusUpdate))
}
}
})
}
Expand Down
16 changes: 13 additions & 3 deletions scheduler/pkg/store/memory_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,19 @@ func updateModelState(isLatest bool, modelVersion *ModelVersion, prevModelVersio
}
}

func (m *MemoryStore) PartiallyScheduled(modelVersion *ModelVersion, reason string, reset bool) {
m.failedScheduling(modelVersion, reason, reset, ModelAvailable)
}

func (m *MemoryStore) FailedScheduling(modelVersion *ModelVersion, reason string, reset bool) {
m.failedScheduling(modelVersion, reason, reset, ScheduleFailed)
}

func (m *MemoryStore) failedScheduling(modelVersion *ModelVersion, reason string, reset bool, state ModelState) {
availableReplicas := modelVersion.state.AvailableReplicas

modelVersion.state = ModelStatus{
State: ScheduleFailed,
State: state,
Reason: reason,
Timestamp: time.Now(),
AvailableReplicas: availableReplicas,
Expand All @@ -119,11 +127,13 @@ func (m *MemoryStore) FailedScheduling(modelVersion *ModelVersion, reason string
if reset {
modelVersion.server = ""
}

m.eventHub.PublishModelEvent(
modelFailureEventSource,
coordinator.ModelEventMsg{
ModelName: modelVersion.GetMeta().GetName(),
ModelVersion: modelVersion.GetVersion(),
ModelName: modelVersion.GetMeta().GetName(),
ModelVersion: modelVersion.GetVersion(),
UpdateContext: coordinator.MODEL_SCHEDULE_FAILED,
},
)
}
Expand Down
1 change: 1 addition & 0 deletions scheduler/pkg/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ type ModelStore interface {
ServerNotify(request *pb.ServerNotify) error
RemoveServerReplica(serverName string, replicaIdx int) ([]string, error) // return previously loaded models
DrainServerReplica(serverName string, replicaIdx int) ([]string, error) // return previously loaded models
PartiallyScheduled(modelVersion *ModelVersion, reason string, reset bool)
FailedScheduling(modelVersion *ModelVersion, reason string, reset bool)
GetAllModels() []string
}