Skip to content

Commit

Permalink
fix(scheduler): always send model events for deleted models (#5992)
Browse files Browse the repository at this point in the history
* always send model events for deleted models

* trying to make it easier to read

* fix broken logic

* simplify loop

* replace recursive call with a loop

* add more logs and update one condition for clarity
  • Loading branch information
driev authored Nov 4, 2024
1 parent a17e12d commit ac839c3
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 52 deletions.
92 changes: 42 additions & 50 deletions scheduler/pkg/store/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ func (m *MemoryStore) updateLoadedModelsImpl(
if !ok {
return nil, fmt.Errorf("failed to find model %s", modelKey)
}

modelVersion := model.Latest()
if version != modelVersion.GetVersion() {
return nil, fmt.Errorf(
Expand All @@ -310,83 +311,74 @@ func (m *MemoryStore) updateLoadedModelsImpl(
if !ok {
return nil, fmt.Errorf("failed to find server %s", serverKey)
}

assignedReplicaIds := make(map[int]struct{})
for _, replica := range replicas {
_, ok := server.replicas[replica.GetReplicaIdx()]
if !ok {
if _, ok := server.replicas[replica.replicaIdx]; !ok {
return nil, fmt.Errorf(
"Failed to reserve replica %d as it does not exist on server %s",
replica.GetReplicaIdx(), serverKey,
"failed to reserve replica %d as it does not exist on server %s",
replica.replicaIdx, serverKey,
)
}
assignedReplicaIds[replica.replicaIdx] = struct{}{}
}

if modelVersion.HasServer() && modelVersion.Server() != serverKey {
for modelVersion.HasServer() && modelVersion.Server() != serverKey {
logger.Debugf("Adding new version as server changed to %s from %s", modelVersion.Server(), serverKey)
m.addNextModelVersion(model, model.Latest().modelDefn)
return m.updateLoadedModelsImpl(modelKey, model.Latest().GetVersion(), serverKey, replicas)
modelVersion = model.Latest()
}

// Update model that need to be placed on a replica to request loading
updatedReplicas := make(map[int]bool)
updated := false
for _, replica := range replicas {
existingState := modelVersion.replicas[replica.GetReplicaIdx()]
if !existingState.State.AlreadyLoadingOrLoaded() {
// resevere memory for existing replicas that are not already loading or loaded
replicaStateUpdated := false
for replicaIdx := range assignedReplicaIds {
if existingState, ok := modelVersion.replicas[replicaIdx]; !ok {
logger.Debugf(
"Setting model %s version %d on server %s replica %d to LoadRequested",
modelKey, modelVersion.version, serverKey, replica.GetReplicaIdx(),
"Model %s version %d state %s on server %s replica %d does not exist yet and should be loaded",
modelKey, modelVersion.version, existingState.State.String(), serverKey, replicaIdx,
)
modelVersion.SetReplicaState(replica.GetReplicaIdx(), LoadRequested, "")
m.updateReservedMemory(LoadRequested, serverKey, replica.GetReplicaIdx(), modelVersion.GetRequiredMemory())
updated = true
modelVersion.SetReplicaState(replicaIdx, LoadRequested, "")
m.updateReservedMemory(LoadRequested, serverKey, replicaIdx, modelVersion.GetRequiredMemory())
replicaStateUpdated = true
} else {
logger.Debugf(
"model %s on server %s replica %d already loaded",
modelKey, serverKey, replica.GetReplicaIdx(),
"Checking if model %s version %d state %s on server %s replica %d should be loaded",
modelKey, modelVersion.version, existingState.State.String(), serverKey, replicaIdx,
)
if !existingState.State.AlreadyLoadingOrLoaded() {
modelVersion.SetReplicaState(replicaIdx, LoadRequested, "")
m.updateReservedMemory(LoadRequested, serverKey, replicaIdx, modelVersion.GetRequiredMemory())
replicaStateUpdated = true
}
}
updatedReplicas[replica.GetReplicaIdx()] = true
}
// Unload any existing model replicas assignments no longer needed
for replicaIdx, existingState := range modelVersion.ReplicaState() {
logger.Debugf(
"Looking at replicaidx %d with state %s but ignoring processed %v",
replicaIdx, existingState.State.String(), updatedReplicas,
)

if _, ok := updatedReplicas[replicaIdx]; !ok {
if !existingState.State.UnloadingOrUnloaded() {
if existingState.State == Draining {
logger.Debugf(
"model %s version %d on server %s replica %d is Draining",
modelKey, modelVersion.version, serverKey, replicaIdx,
)
} else {
logger.Debugf(
"Setting model %s version %d on server %s replica %d to UnloadEnvoyRequested",
modelKey, modelVersion.version, serverKey, replicaIdx,
)
modelVersion.SetReplicaState(replicaIdx, UnloadEnvoyRequested, "")
updated = true
}
} else {
logger.Debugf(
"model %s on server %s replica %d already unloading or can't be unloaded",
modelKey, serverKey, replicaIdx,
)
// Unload any existing model replicas assignments that are no longer part of the replica set
for replicaIdx, existingState := range modelVersion.ReplicaState() {
if _, ok := assignedReplicaIds[replicaIdx]; !ok {
logger.Debugf(
"Checking if replicaidx %d with state %s should be unloaded",
replicaIdx, existingState.State.String(),
)
if !existingState.State.UnloadingOrUnloaded() && existingState.State != Draining {
modelVersion.SetReplicaState(replicaIdx, UnloadEnvoyRequested, "")
replicaStateUpdated = true
}
}
}

// in cases where we did have a previous ScheduleFailed, we need to reflect the change here
// this could be in the cases where we are scaling down a model and the new replica count can be all deployed
if updated || modelVersion.state.State == ScheduleFailed {
// and always send an update for deleted models, so the operator will remove them from k8s
if replicaStateUpdated || modelVersion.state.State == ScheduleFailed || model.IsDeleted() {
logger.Debugf("Updating model status for model %s server %s", modelKey, serverKey)
modelVersion.server = serverKey
m.updateModelStatus(true, model.IsDeleted(), modelVersion, model.GetLastAvailableModelVersion())
return &coordinator.ModelEventMsg{ModelName: modelVersion.GetMeta().GetName(), ModelVersion: modelVersion.GetVersion()}, nil
} else {
logger.Debugf("Model status update not required for model %s server %s as no replicas were updated", modelKey, serverKey)
return nil, nil
}
return nil, nil
}

func (m *MemoryStore) UnloadVersionModels(modelKey string, version uint32) (bool, error) {
Expand Down Expand Up @@ -415,7 +407,7 @@ func (m *MemoryStore) unloadVersionModelsImpl(modelKey string, version uint32) (
}
modelVersion := model.GetVersion(version)
if modelVersion == nil {
return nil, false, fmt.Errorf("Version not found for model %s, version %d", modelKey, version)
return nil, false, fmt.Errorf("version not found for model %s, version %d", modelKey, version)
}

updated := false
Expand Down Expand Up @@ -482,7 +474,7 @@ func (m *MemoryStore) updateModelStateImpl(
desiredState ModelReplicaState,
reason string,
) (*coordinator.ModelEventMsg, error) {
logger := m.logger.WithField("func", "UpdateModelState")
logger := m.logger.WithField("func", "updateModelStateImpl")
m.mu.Lock()
defer m.mu.Unlock()

Expand Down
2 changes: 1 addition & 1 deletion scheduler/pkg/store/memory_status.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func updateModelState(isLatest bool, modelVersion *ModelVersion, prevModelVersio
modelReason = stats.lastFailedReason
modelTimestamp = stats.lastFailedStateTime
} else if (modelVersion.GetDeploymentSpec() != nil && stats.replicasAvailable == modelVersion.GetDeploymentSpec().Replicas) || // equal to desired replicas
(stats.replicasAvailable > 0 && prevModelVersion != nil && modelVersion != prevModelVersion && prevModelVersion.state.State == ModelAvailable) { //TODO In future check if available replicas is > minReplicas
(stats.replicasAvailable > 0 && prevModelVersion != nil && modelVersion != prevModelVersion && prevModelVersion.state.State == ModelAvailable) { // TODO In future check if available replicas is > minReplicas
modelState = ModelAvailable
} else {
modelState = ModelProgressing
Expand Down
4 changes: 3 additions & 1 deletion scheduler/pkg/store/memory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -684,9 +684,10 @@ func TestUpdateLoadedModels(t *testing.T) {
test.store.models[test.modelKey].SetDeleted()
}
ms := NewMemoryStore(logger, test.store, eventHub)
err = ms.UpdateLoadedModels(test.modelKey, test.version, test.serverKey, test.replicas)
msg, err := ms.updateLoadedModelsImpl(test.modelKey, test.version, test.serverKey, test.replicas)
if !test.err {
g.Expect(err).To(BeNil())
g.Expect(msg).ToNot(BeNil())
for replicaIdx, state := range test.expectedStates {
mv := test.store.models[test.modelKey].Latest()
g.Expect(mv).ToNot(BeNil())
Expand All @@ -700,6 +701,7 @@ func TestUpdateLoadedModels(t *testing.T) {
}
} else {
g.Expect(err).ToNot(BeNil())
g.Expect(msg).To(BeNil())
}
})
}
Expand Down

0 comments on commit ac839c3

Please sign in to comment.