DRAFT: server min max replica and scaling #6196

Draft: wants to merge 1 commit into base: v2
1,343 changes: 720 additions & 623 deletions apis/go/mlops/scheduler/scheduler.pb.go

Large diffs are not rendered by default.

16 changes: 15 additions & 1 deletion apis/mlops/scheduler/scheduler.proto
@@ -158,6 +158,18 @@ message ServerStatusRequest {
/* ServerStatusResponse provides details of current server status
*/
message ServerStatusResponse {
/* Type of ServerStatus update. At the moment the scheduler doesn't combine multiple types of
* updates in the same response. However, the Type enum is forward-compatible with this
* possibility, by setting members to power-of-two values. This means enum values can be used
* as flags and combined with bitwise OR, with the exception of StatusResponseTypeUnknown.
*/
enum Type {
StatusResponseTypeUnknown = 0;
StatusUpdate = 1;
NonAuthoritativeReplicaInfo = 2;
ScalingRequest = 4;
}
Type type = 7;
string serverName = 1;
repeated ServerReplicaResources resources = 2;
int32 expectedReplicas = 3;
@@ -191,7 +203,9 @@ message ServerNotifyRequest {

message ServerNotify {
string name = 1;
int32 expectedReplicas = 2;
uint32 expectedReplicas = 2;
uint32 minReplicas = 5;
uint32 maxReplicas = 6;
bool shared = 3;
optional KubernetesMeta kubernetesMeta = 4;
}
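The comment on the Type enum above notes that, apart from StatusResponseTypeUnknown, the values are powers of two and can be combined as bit flags. Below is a minimal Go sketch of how a consumer might test a combined value; the import path is assumed, and combined values are hypothetical since today the scheduler sends exactly one type per response:

package main

import (
	"fmt"

	// Import path assumed; the generated code lives under apis/go/mlops/scheduler.
	scheduler "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler"
)

// hasType reports whether a (possibly OR-combined) Type value includes kind,
// relying on the power-of-two enum values described in the proto comment.
func hasType(t, kind scheduler.ServerStatusResponse_Type) bool {
	return t&kind != 0
}

func main() {
	// Hypothetical future response combining two update kinds.
	combined := scheduler.ServerStatusResponse_StatusUpdate | scheduler.ServerStatusResponse_ScalingRequest
	fmt.Println(hasType(combined, scheduler.ServerStatusResponse_ScalingRequest))              // true
	fmt.Println(hasType(combined, scheduler.ServerStatusResponse_NonAuthoritativeReplicaInfo)) // false
}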
33 changes: 6 additions & 27 deletions operator/apis/mlops/v1alpha1/model_types.go
@@ -191,34 +191,13 @@ func (m Model) AsSchedulerModel() (*scheduler.Model, error) {
md.ModelSpec.Requirements = append(md.ModelSpec.Requirements, *m.Spec.ModelType)
}
// Set Replicas
if m.Spec.Replicas != nil {
md.DeploymentSpec.Replicas = uint32(*m.Spec.Replicas)
} else {
if m.Spec.MinReplicas != nil {
// set replicas to the min replicas if not set
md.DeploymentSpec.Replicas = uint32(*m.Spec.MinReplicas)
} else {
md.DeploymentSpec.Replicas = 1
}
}

if m.Spec.MinReplicas != nil {
md.DeploymentSpec.MinReplicas = uint32(*m.Spec.MinReplicas)
if md.DeploymentSpec.Replicas < md.DeploymentSpec.MinReplicas {
return nil, fmt.Errorf("Number of replicas %d should be >= min replicas %d", md.DeploymentSpec.Replicas, md.DeploymentSpec.MinReplicas)
}
} else {
md.DeploymentSpec.MinReplicas = 0
}

if m.Spec.MaxReplicas != nil {
md.DeploymentSpec.MaxReplicas = uint32(*m.Spec.MaxReplicas)
if md.DeploymentSpec.Replicas > md.DeploymentSpec.MaxReplicas {
return nil, fmt.Errorf("Number of replicas %d should be <= max replicas %d", md.DeploymentSpec.Replicas, md.DeploymentSpec.MaxReplicas)
}
} else {
md.DeploymentSpec.MaxReplicas = 0
scalingSpec, err := GetValidatedScalingSpec(m.Spec.Replicas, m.Spec.MinReplicas, m.Spec.MaxReplicas)
if err != nil {
return nil, err
}
md.DeploymentSpec.Replicas = scalingSpec.Replicas
md.DeploymentSpec.MinReplicas = scalingSpec.MinReplicas
md.DeploymentSpec.MaxReplicas = scalingSpec.MaxReplicas

// Set memory bytes
if m.Spec.Memory != nil {
54 changes: 54 additions & 0 deletions operator/apis/mlops/v1alpha1/utils.go
@@ -0,0 +1,54 @@
/*
Copyright (c) 2024 Seldon Technologies Ltd.

Use of this software is governed by
(1) the license included in the LICENSE file or
(2) if the license included in the LICENSE file is the Business Source License 1.1,
the Change License after the Change Date as each is defined in accordance with the LICENSE file.
*/

package v1alpha1

import "fmt"

type ValidatedScalingSpec struct {
Replicas uint32
MinReplicas uint32
MaxReplicas uint32
}


func GetValidatedScalingSpec(replicas *int32, minReplicas *int32, maxReplicas *int32) (*ValidatedScalingSpec, error) {
var validatedSpec ValidatedScalingSpec

if replicas != nil && *replicas > 0 {
validatedSpec.Replicas = uint32(*replicas)
} else {
if minReplicas != nil && *minReplicas > 0 {
// set replicas to the min replicas when replicas is not set explicitly
validatedSpec.Replicas = uint32(*minReplicas)
} else {
validatedSpec.Replicas = 1
}
}

if minReplicas != nil && *minReplicas > 0 {
validatedSpec.MinReplicas = uint32(*minReplicas)
if validatedSpec.Replicas < validatedSpec.MinReplicas {
return nil, fmt.Errorf("number of replicas %d must be >= min replicas %d", validatedSpec.Replicas, validatedSpec.MinReplicas)
}
} else {
validatedSpec.MinReplicas = 0
}

if maxReplicas != nil && *maxReplicas > 0 {
validatedSpec.MaxReplicas = uint32(*maxReplicas)
if validatedSpec.Replicas > validatedSpec.MaxReplicas {
return nil, fmt.Errorf("number of replicas %d must be <= min replicas %d", validatedSpec.Replicas, validatedSpec.MaxReplicas)
}
} else {
validatedSpec.MaxReplicas = 0
}

return &validatedSpec, nil
}
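A short usage sketch of GetValidatedScalingSpec above, illustrating the defaulting and validation behaviour (import path assumed):

package main

import (
	"fmt"

	// Import path assumed for the v1alpha1 package shown above.
	"github.com/seldonio/seldon-core/operator/v2/apis/mlops/v1alpha1"
)

func main() {
	minReplicas, maxReplicas := int32(2), int32(5)

	// Replicas unset: defaults to minReplicas.
	spec, _ := v1alpha1.GetValidatedScalingSpec(nil, &minReplicas, &maxReplicas)
	fmt.Println(spec.Replicas, spec.MinReplicas, spec.MaxReplicas) // 2 2 5

	// Replicas above maxReplicas: validation fails.
	replicas := int32(10)
	_, err := v1alpha1.GetValidatedScalingSpec(&replicas, &minReplicas, &maxReplicas)
	fmt.Println(err) // number of replicas 10 must be <= max replicas 5
}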
72 changes: 68 additions & 4 deletions operator/scheduler/server.go
@@ -28,7 +28,6 @@ func (s *SchedulerClient) ServerNotify(ctx context.Context, grpcClient scheduler
if len(servers) == 0 {
return nil
}

if grpcClient == nil {
// we assume that all servers are in the same namespace
namespace := servers[0].Namespace
@@ -39,6 +38,20 @@
grpcClient = scheduler.NewSchedulerClient(conn)
}

var scalingSpec *v1alpha1.ValidatedScalingSpec
if !server.ObjectMeta.DeletionTimestamp.IsZero() {
scalingSpec = &v1alpha1.ValidatedScalingSpec{
Replicas: 0,
MinReplicas: 0,
MaxReplicas: 0,
}
} else {
scalingSpec, err = v1alpha1.GetValidatedScalingSpec(server.Spec.Replicas, server.Spec.MinReplicas, server.Spec.MaxReplicas)
if err != nil {
return err
}
}

var requests []*scheduler.ServerNotify
for _, server := range servers {
var replicas int32
@@ -54,6 +67,8 @@
requests = append(requests, &scheduler.ServerNotify{
Name: server.GetName(),
ExpectedReplicas: replicas,
MinReplicas: scalingSpec.MinReplicas,
MaxReplicas: scalingSpec.MaxReplicas,
KubernetesMeta: &scheduler.KubernetesMeta{
Namespace: server.GetNamespace(),
Generation: server.GetGeneration(),
@@ -127,9 +142,29 @@ func (s *SchedulerClient) SubscribeServerEvents(ctx context.Context, grpcClient
logger.Info("Ignoring event for old generation", "currentGeneration", server.Generation, "eventGeneration", event.GetKubernetesMeta().Generation, "server", event.ServerName)
return nil
}
// Handle status update
server.Status.LoadedModelReplicas = event.NumLoadedModelReplicas
return s.updateServerStatus(server)

// The types of updates we may get from the scheduler are:
// 1. Status updates
// 2. Requests for changing the number of server replicas
// 3. Updates containing non-authoritative replica info, because the scheduler is in a
// discovery phase (e.g. it has just started up or restarted)
//
// At the moment, the scheduler doesn't send multiple types of updates in a single event.
switch event.GetType() {
case scheduler.ServerStatusResponse_StatusUpdate:
return s.applyStatusUpdates(ctx, server, event)
case scheduler.ServerStatusResponse_ScalingRequest:
if event.ExpectedReplicas != event.AvailableReplicas {
return s.applyReplicaUpdates(ctx, server, event)
} else {
return nil
}
case scheduler.ServerStatusResponse_NonAuthoritativeReplicaInfo:
// skip updating replica info, only update status
return s.updateServerStatus(server)
default: // we ignore unknown event types
return nil
}
})
if retryErr != nil {
logger.Error(err, "Failed to update status", "model", event.ServerName)
@@ -147,3 +182,32 @@ func (s *SchedulerClient) updateServerStatus(server *v1alpha1.Server) error {
}
return nil
}

// handleRegisteredServers notifies the scheduler about the configuration of existing Servers (e.g. on reconnect)
func handleRegisteredServers(
ctx context.Context, namespace string, s *SchedulerClient, grpcClient scheduler.SchedulerClient) {
serverList := &v1alpha1.ServerList{}
// Get all servers in the namespace
err := s.List(
ctx,
serverList,
client.InNamespace(namespace),
)
if err != nil {
return
}

for _, server := range serverList.Items {
// servers that are not in the process of being deleted have a zero DeletionTimestamp
if server.ObjectMeta.DeletionTimestamp.IsZero() {
s.logger.V(1).Info("Calling NotifyServer (on reconnect)", "server", server.Name)
if err := s.ServerNotify(ctx, &server); err != nil {
s.logger.Error(err, "Failed to notify scheduler about initial Server parameters", "server", server.Name)
} else {
s.logger.V(1).Info("Load model called successfully", "server", server.Name)
}
} else {
s.logger.V(1).Info("Server being deleted, not notifying", "server", server.Name)
}
}
}
3 changes: 2 additions & 1 deletion scheduler/pkg/coordinator/server.go
@@ -32,13 +32,14 @@ func (h *EventHub) RegisterServerEventHandler(
}
}()

handler := h.newServerEventHandler(logger, events)
handler := h.newServerEventHandler(logger, events, handle)
h.bus.RegisterHandler(name, handler)
}

func (h *EventHub) newServerEventHandler(
logger log.FieldLogger,
events chan ServerEventMsg,
_ func(event ServerEventMsg),
) busV3.Handler {
handleServerEventMessage := func(_ context.Context, e busV3.Event) {
l := logger.WithField("func", "handleServerEventMessage")
11 changes: 11 additions & 0 deletions scheduler/pkg/coordinator/types.go
@@ -11,13 +11,24 @@ package coordinator

import "fmt"

type ModelEventUpdateContext int
type ServerEventUpdateContext int

const (
SERVER_STATUS_UPDATE ServerEventUpdateContext = iota
SERVER_REPLICA_CONNECTED
)

type ModelEventMsg struct {
ModelName string
ModelVersion uint32
39 changes: 30 additions & 9 deletions scheduler/pkg/server/server.go
@@ -35,20 +35,29 @@ import (
)

const (
grpcMaxConcurrentStreams = 1_000_000
pendingEventsQueueSize int = 1000
modelEventHandlerName = "scheduler.server.models"
serverEventHandlerName = "scheduler.server.servers"
experimentEventHandlerName = "scheduler.server.experiments"
pipelineEventHandlerName = "scheduler.server.pipelines"
defaultBatchWait = 250 * time.Millisecond
sendTimeout = 30 * time.Second // Timeout for sending events to subscribers via grpc `sendMsg`
grpcMaxConcurrentStreams = 1_000_000
pendingEventsQueueSize int = 1000
modelEventHandlerName = "scheduler.server.models"
serverEventHandlerName = "scheduler.server.servers"
serverModelEventHandlerName = "scheduler.server.servers.models"
experimentEventHandlerName = "scheduler.server.experiments"
pipelineEventHandlerName = "scheduler.server.pipelines"
defaultBatchWait = 250 * time.Millisecond
sendTimeout = 30 * time.Second // Timeout for sending events to subscribers via grpc `sendMsg`
)

var (
ErrAddServerEmptyServerName = status.Errorf(codes.FailedPrecondition, "Empty server name passed")
)

type SchedulerSyncPhase uint8

const (
SCHEDULER_SYNC_INIT SchedulerSyncPhase = iota
SCHEDULER_SYNC_PARTIAL
SCHEDULER_SYNC_ONLINE
)

type SchedulerServer struct {
pb.UnimplementedSchedulerServer
logger log.FieldLogger
@@ -77,6 +86,7 @@ type ServerEventStream struct {
trigger *time.Timer
pendingEvents map[string]struct{}
pendingLock sync.Mutex
syncPhase SchedulerSyncPhase
}

type ExperimentEventStream struct {
@@ -191,6 +201,7 @@ func NewSchedulerServer(
batchWait: defaultBatchWait,
trigger: nil,
pendingEvents: map[string]struct{}{},
syncPhase: SCHEDULER_SYNC_INIT,
},
pipelineEventStream: PipelineEventStream{
streams: make(map[pb.Scheduler_SubscribePipelineStatusServer]*PipelineSubscription),
@@ -208,12 +219,21 @@ func NewSchedulerServer(
s.logger,
s.handleModelEvent,
)
eventHub.RegisterModelEventHandler(

eventHub.RegisterServerEventHandler(
serverEventHandlerName,
pendingEventsQueueSize,
s.logger,
s.handleServerEvent,
)

eventHub.RegisterModelEventHandler(
serverModelEventHandlerName,
pendingEventsQueueSize,
s.logger,
s.handleServerModelEvent,
)

eventHub.RegisterExperimentEventHandler(
experimentEventHandlerName,
pendingEventsQueueSize,
@@ -446,6 +466,7 @@ func createServerStatusResponse(s *store.ServerSnapshot) *pb.ServerStatusRespons
// note we don't count draining replicas in available replicas

resp := &pb.ServerStatusResponse{
Type: pb.ServerStatusResponse_StatusUpdate,
ServerName: s.Name,
ExpectedReplicas: int32(s.ExpectedReplicas),
KubernetesMeta: s.KubernetesMeta,