Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(scheduler): account for multiple instances of a model per server when scheduling #6054

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
607 changes: 469 additions & 138 deletions apis/go/mlops/agent/agent.pb.go

Large diffs are not rendered by default.

22 changes: 22 additions & 0 deletions apis/mlops/agent/agent.proto
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ message ModelEventMessage {
Event event = 5;
string message = 6;
uint64 availableMemoryBytes = 7;
ModelRuntimeInfo runtimeInfo = 8;
}

message ModelEventResponse {
Expand Down Expand Up @@ -92,8 +93,29 @@ message ModelOperationMessage {
message ModelVersion {
scheduler.Model model = 1;
uint32 version = 2;
ModelRuntimeInfo runtimeInfo = 3;
}

message ModelRuntimeInfo {
oneof modelRuntimeInfo {
MLServerModelSettings mlserver = 1;
TritonModelConfig triton = 2;
}
}

message MLServerModelSettings {
uint32 parallelWorkers = 1;
}

message TritonModelConfig {
repeated TritonCPU cpu = 1;
}

message TritonCPU {
uint32 instanceCount = 1;
}


// [END Messages]

// [START Services]
Expand Down
5 changes: 3 additions & 2 deletions scheduler/pkg/agent/agent_debug_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func setupService(numModels int, modelPrefix string, capacity int) *agentDebug {
}

func TestAgentDebugServiceSmoke(t *testing.T) {
//TODO break this down in proper tests
// TODO break this down in proper tests
g := NewGomegaWithT(t)

service := setupService(10, "dummy", 10)
Expand All @@ -60,6 +60,7 @@ func TestAgentDebugServiceSmoke(t *testing.T) {
MemoryBytes: &mem,
},
},
RuntimeInfo: getModelRuntimeInfo(1),
},
)
g.Expect(err).To(BeNil())
Expand Down Expand Up @@ -87,7 +88,7 @@ func TestAgentDebugServiceSmoke(t *testing.T) {
}

func TestAgentDebugEarlyStop(t *testing.T) {
//TODO break this down in proper tests
// TODO break this down in proper tests
g := NewGomegaWithT(t)

service := setupService(10, "dummy", 10)
Expand Down
17 changes: 14 additions & 3 deletions scheduler/pkg/agent/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -615,15 +615,23 @@ func (c *Client) LoadModel(request *agent.ModelOperationMessage, timestamp int64
}
logger.Infof("Chose path %s for model %s:%d", *chosenVersionPath, modelName, modelVersion)

modelConfig, err := c.ModelRepository.GetModelRuntimeInfo(modelWithVersion)
if err != nil {
logger.Errorf("there was a problem getting the config for model: %s", modelName)
}

// TODO: consider whether we need the actual protos being sent to `LoadModelVersion`?
modifiedModelVersionRequest := getModifiedModelVersion(
modelWithVersion,
pinnedModelVersion,
request.GetModelVersion(),
modelConfig,
)

loaderFn := func() error {
return c.stateManager.LoadModelVersion(modifiedModelVersionRequest)
}

if err := backoffWithMaxNumRetry(loaderFn, c.settings.maxLoadRetryCount, c.settings.maxLoadElapsedTime, logger); err != nil {
c.sendModelEventError(modelName, modelVersion, agent.ModelEventMessage_LOAD_FAILED, err)
c.cleanup(modelWithVersion)
Expand All @@ -641,7 +649,8 @@ func (c *Client) LoadModel(request *agent.ModelOperationMessage, timestamp int64
}

logger.Infof("Load model %s:%d success", modelName, modelVersion)
return c.sendAgentEvent(modelName, modelVersion, agent.ModelEventMessage_LOADED)

return c.sendAgentEvent(modelName, modelVersion, modelConfig, agent.ModelEventMessage_LOADED)
}

func (c *Client) UnloadModel(request *agent.ModelOperationMessage, timestamp int64) error {
Expand Down Expand Up @@ -674,7 +683,7 @@ func (c *Client) UnloadModel(request *agent.ModelOperationMessage, timestamp int
defer c.modelTimestamps.Store(modelWithVersion, timestamp)

// we do not care about model versions here
modifiedModelVersionRequest := getModifiedModelVersion(modelWithVersion, pinnedModelVersion, request.GetModelVersion())
modifiedModelVersionRequest := getModifiedModelVersion(modelWithVersion, pinnedModelVersion, request.GetModelVersion(), nil)

unloaderFn := func() error {
return c.stateManager.UnloadModelVersion(modifiedModelVersionRequest)
Expand Down Expand Up @@ -702,7 +711,7 @@ func (c *Client) UnloadModel(request *agent.ModelOperationMessage, timestamp int
}

logger.Infof("Unload model %s:%d success", modelName, modelVersion)
return c.sendAgentEvent(modelName, modelVersion, agent.ModelEventMessage_UNLOADED)
return c.sendAgentEvent(modelName, modelVersion, nil, agent.ModelEventMessage_UNLOADED)
}

func (c *Client) cleanup(modelWithVersion string) {
Expand Down Expand Up @@ -742,6 +751,7 @@ func (c *Client) sendModelEventError(
func (c *Client) sendAgentEvent(
modelName string,
modelVersion uint32,
modelRuntimeInfo *agent.ModelRuntimeInfo,
event agent.ModelEventMessage_Event,
) error {
// if the server is draining and the model load has succeeded, we need to "cancel"
Expand All @@ -765,6 +775,7 @@ func (c *Client) sendAgentEvent(
ModelVersion: modelVersion,
Event: event,
AvailableMemoryBytes: c.stateManager.GetAvailableMemoryBytesWithOverCommit(),
RuntimeInfo: modelRuntimeInfo,
})
return err
}
Expand Down
19 changes: 19 additions & 0 deletions scheduler/pkg/agent/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes/fake"

"github.com/seldonio/seldon-core/apis/go/v2/mlops/agent"
pb "github.com/seldonio/seldon-core/apis/go/v2/mlops/agent"
pbs "github.com/seldonio/seldon-core/apis/go/v2/mlops/scheduler"

Expand All @@ -51,6 +52,7 @@ type mockAgentV2Server struct {
unloadFailedEvents int
otherEvents int
errors int
events []*pb.ModelEventMessage
}

type FakeModelRepository struct {
Expand All @@ -64,6 +66,10 @@ func (f *FakeModelRepository) RemoveModelVersion(modelName string) error {
return nil
}

func (f *FakeModelRepository) GetModelRuntimeInfo(modelName string) (*pb.ModelRuntimeInfo, error) {
return &pb.ModelRuntimeInfo{ModelRuntimeInfo: &pb.ModelRuntimeInfo_Mlserver{Mlserver: &agent.MLServerModelSettings{ParallelWorkers: uint32(1)}}}, nil
}

func (f *FakeModelRepository) DownloadModelVersion(modelName string, version uint32, modelSpec *pbs.ModelSpec, config []byte) (*string, error) {
f.modelDownloads++
if f.err != nil {
Expand Down Expand Up @@ -147,6 +153,7 @@ func (m *mockAgentV2Server) AgentEvent(ctx context.Context, message *pb.ModelEve
default:
m.otherEvents++
}
m.events = append(m.events, message)
return &pb.ModelEventResponse{}, nil
}

Expand Down Expand Up @@ -247,6 +254,7 @@ func TestLoadModel(t *testing.T) {
models []string
replicaConfig *pb.ReplicaConfig
op *pb.ModelOperationMessage
modelConfig *pb.ModelRuntimeInfo
expectedAvailableMemory uint64
v2Status int
modelRepoErr error
Expand All @@ -270,9 +278,11 @@ func TestLoadModel(t *testing.T) {
},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &smallMemory},
},
RuntimeInfo: getModelRuntimeInfo(1),
},
},
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
modelConfig: getModelRuntimeInfo(1),
expectedAvailableMemory: 500,
v2Status: 200,
success: true,
Expand All @@ -289,10 +299,12 @@ func TestLoadModel(t *testing.T) {
},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &smallMemory},
},
RuntimeInfo: getModelRuntimeInfo(1),
},
AutoscalingEnabled: true,
},
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
modelConfig: getModelRuntimeInfo(1),
expectedAvailableMemory: 500,
v2Status: 200,
success: true,
Expand All @@ -310,9 +322,11 @@ func TestLoadModel(t *testing.T) {
},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &smallMemory},
},
RuntimeInfo: getModelRuntimeInfo(1),
},
},
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
modelConfig: getModelRuntimeInfo(1),
expectedAvailableMemory: 1000,
v2Status: 400,
success: false,
Expand All @@ -329,9 +343,11 @@ func TestLoadModel(t *testing.T) {
},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &largeMemory},
},
RuntimeInfo: getModelRuntimeInfo(1),
},
},
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
modelConfig: getModelRuntimeInfo(1),
expectedAvailableMemory: 1000,
v2Status: 200,
success: false,
Expand Down Expand Up @@ -399,6 +415,9 @@ func TestLoadModel(t *testing.T) {
g.Expect(err).To(BeNil())
g.Expect(mockAgentV2Server.loadedEvents).To(Equal(1))
g.Expect(mockAgentV2Server.loadFailedEvents).To(Equal(0))
g.Expect(len(mockAgentV2Server.events)).To(Equal(1))
g.Expect(mockAgentV2Server.events[0].RuntimeInfo).ToNot(BeNil())
g.Expect(mockAgentV2Server.events[0].RuntimeInfo.GetMlserver().ParallelWorkers).To(Equal(uint32(1)))
g.Expect(client.stateManager.GetAvailableMemoryBytes()).To(Equal(test.expectedAvailableMemory))
g.Expect(modelRepository.modelRemovals).To(Equal(0))
loadedVersions := client.stateManager.modelVersions.getVersionsForAllModels()
Expand Down
3 changes: 2 additions & 1 deletion scheduler/pkg/agent/client_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ func isReady(service interfaces.DependencyServiceInterface, logger *log.Entry, m
return backoff.RetryNotify(readyToError, backoffWithMax, logFailure)
}

func getModifiedModelVersion(modelId string, version uint32, originalModelVersion *agent.ModelVersion) *agent.ModelVersion {
func getModifiedModelVersion(modelId string, version uint32, originalModelVersion *agent.ModelVersion, modelRuntimeInfo *agent.ModelRuntimeInfo) *agent.ModelVersion {
mv := proto.Clone(originalModelVersion)
mv.(*agent.ModelVersion).Model.Meta.Name = modelId
mv.(*agent.ModelVersion).Version = version
mv.(*agent.ModelVersion).RuntimeInfo = modelRuntimeInfo
return mv.(*agent.ModelVersion)
}

Expand Down
19 changes: 15 additions & 4 deletions scheduler/pkg/agent/model_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ func (modelState *ModelState) addModelVersionImpl(modelVersionDetails *agent.Mod
modelName, versionId, exsistingVersion.getVersion())
}
}

}

// Remove model version and return true if no versions left (in which case we remove from map)
Expand All @@ -70,7 +69,6 @@ func (modelState *ModelState) removeModelVersion(modelVersionDetails *agent.Mode
}

func (modelState *ModelState) removeModelVersionImpl(modelVersionDetails *agent.ModelVersion) (bool, error) {

modelName := modelVersionDetails.GetModel().GetMeta().GetName()
versionId := modelVersionDetails.GetVersion()

Expand Down Expand Up @@ -143,7 +141,8 @@ func (modelState *ModelState) getVersionsForAllModels() []*agent.ModelVersion {
mv := version.get()
versionedModelName := mv.Model.GetMeta().Name
originalModelName, originalModelVersion, _ := util.GetOrignalModelNameAndVersion(versionedModelName)
loadedModels = append(loadedModels, getModifiedModelVersion(originalModelName, originalModelVersion, mv))
modelRuntimeInfo := mv.RuntimeInfo
loadedModels = append(loadedModels, getModifiedModelVersion(originalModelName, originalModelVersion, mv, modelRuntimeInfo))
}
return loadedModels
}
Expand All @@ -153,7 +152,19 @@ type modelVersion struct {
}

func (version *modelVersion) getVersionMemory() uint64 {
return version.versionInfo.GetModel().GetModelSpec().GetMemoryBytes()
instanceCount := getInstanceCount(version)
return version.versionInfo.GetModel().GetModelSpec().GetMemoryBytes() * instanceCount
}

func getInstanceCount(version *modelVersion) uint64 {
switch version.versionInfo.RuntimeInfo.ModelRuntimeInfo.(type) {
case *agent.ModelRuntimeInfo_Mlserver:
return uint64(version.versionInfo.GetRuntimeInfo().GetMlserver().ParallelWorkers)
case *agent.ModelRuntimeInfo_Triton:
return uint64(version.versionInfo.GetRuntimeInfo().GetTriton().Cpu[0].InstanceCount)
default:
return 1
}
}

func (version *modelVersion) getVersion() uint32 {
Expand Down
Loading
Loading