Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(scheduler): account for multiple instances of a model per server when scheduling #6054

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
841 changes: 618 additions & 223 deletions apis/go/mlops/agent/agent.pb.go

Large diffs are not rendered by default.

28 changes: 28 additions & 0 deletions apis/mlops/agent/agent.proto
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,35 @@ message ModelEventMessage {
Event event = 5;
string message = 6;
uint64 availableMemoryBytes = 7;
ModelConfig modelConfig = 8;
}

message ModelConfig {
driev marked this conversation as resolved.
Show resolved Hide resolved
enum Type {
UNKNOWN_TYPE = 0;
MLSERVER = 1;
TRITON = 2;
}

Type type = 1;
driev marked this conversation as resolved.
Show resolved Hide resolved
oneof config {
MLServerModelConfig mlserver = 2;
TritonModelConfig triton = 3;
}
}

message MLServerModelConfig {
uint32 instanceCount = 1;
}

message TritonModelConfig {
TritonCPU cpu = 1;
}
driev marked this conversation as resolved.
Show resolved Hide resolved

message TritonCPU {
uint32 instanceCount = 1;
}

message ModelEventResponse {

}
Expand Down Expand Up @@ -92,6 +119,7 @@ message ModelOperationMessage {
message ModelVersion {
scheduler.Model model = 1;
uint32 version = 2;
ModelConfig modelConfig = 3;
}

// [END Messages]
Expand Down
5 changes: 3 additions & 2 deletions scheduler/pkg/agent/agent_debug_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func setupService(numModels int, modelPrefix string, capacity int) *agentDebug {
}

func TestAgentDebugServiceSmoke(t *testing.T) {
//TODO break this down in proper tests
// TODO break this down in proper tests
g := NewGomegaWithT(t)

service := setupService(10, "dummy", 10)
Expand All @@ -60,6 +60,7 @@ func TestAgentDebugServiceSmoke(t *testing.T) {
MemoryBytes: &mem,
},
},
ModelConfig: getModelConfig(1),
},
)
g.Expect(err).To(BeNil())
Expand Down Expand Up @@ -87,7 +88,7 @@ func TestAgentDebugServiceSmoke(t *testing.T) {
}

func TestAgentDebugEarlyStop(t *testing.T) {
//TODO break this down in proper tests
// TODO break this down in proper tests
g := NewGomegaWithT(t)

service := setupService(10, "dummy", 10)
Expand Down
17 changes: 14 additions & 3 deletions scheduler/pkg/agent/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -615,15 +615,23 @@ func (c *Client) LoadModel(request *agent.ModelOperationMessage, timestamp int64
}
logger.Infof("Chose path %s for model %s:%d", *chosenVersionPath, modelName, modelVersion)

modelConfig, err := c.ModelRepository.GetModelConfig(modelName)
if err != nil {
logger.Errorf("there was a problem getting the config for model: %s", modelName)
}

// TODO: consider whether we need the actual protos being sent to `LoadModelVersion`?
modifiedModelVersionRequest := getModifiedModelVersion(
modelWithVersion,
pinnedModelVersion,
request.GetModelVersion(),
modelConfig,
)

loaderFn := func() error {
return c.stateManager.LoadModelVersion(modifiedModelVersionRequest)
}

if err := backoffWithMaxNumRetry(loaderFn, c.settings.maxLoadRetryCount, c.settings.maxLoadElapsedTime, logger); err != nil {
c.sendModelEventError(modelName, modelVersion, agent.ModelEventMessage_LOAD_FAILED, err)
c.cleanup(modelWithVersion)
Expand All @@ -641,7 +649,8 @@ func (c *Client) LoadModel(request *agent.ModelOperationMessage, timestamp int64
}

logger.Infof("Load model %s:%d success", modelName, modelVersion)
return c.sendAgentEvent(modelName, modelVersion, agent.ModelEventMessage_LOADED)

return c.sendAgentEvent(modelName, modelVersion, modelConfig, agent.ModelEventMessage_LOADED)
}

func (c *Client) UnloadModel(request *agent.ModelOperationMessage, timestamp int64) error {
Expand Down Expand Up @@ -674,7 +683,7 @@ func (c *Client) UnloadModel(request *agent.ModelOperationMessage, timestamp int
defer c.modelTimestamps.Store(modelWithVersion, timestamp)

// we do not care about model versions here
modifiedModelVersionRequest := getModifiedModelVersion(modelWithVersion, pinnedModelVersion, request.GetModelVersion())
modifiedModelVersionRequest := getModifiedModelVersion(modelWithVersion, pinnedModelVersion, request.GetModelVersion(), nil)

unloaderFn := func() error {
return c.stateManager.UnloadModelVersion(modifiedModelVersionRequest)
Expand Down Expand Up @@ -702,7 +711,7 @@ func (c *Client) UnloadModel(request *agent.ModelOperationMessage, timestamp int
}

logger.Infof("Unload model %s:%d success", modelName, modelVersion)
return c.sendAgentEvent(modelName, modelVersion, agent.ModelEventMessage_UNLOADED)
return c.sendAgentEvent(modelName, modelVersion, nil, agent.ModelEventMessage_UNLOADED)
}

func (c *Client) cleanup(modelWithVersion string) {
Expand Down Expand Up @@ -742,6 +751,7 @@ func (c *Client) sendModelEventError(
func (c *Client) sendAgentEvent(
modelName string,
modelVersion uint32,
modelConfig *agent.ModelConfig,
event agent.ModelEventMessage_Event,
) error {
// if the server is draining and the model load has succeeded, we need to "cancel"
Expand All @@ -765,6 +775,7 @@ func (c *Client) sendAgentEvent(
ModelVersion: modelVersion,
Event: event,
AvailableMemoryBytes: c.stateManager.GetAvailableMemoryBytesWithOverCommit(),
ModelConfig: modelConfig,
})
return err
}
Expand Down
19 changes: 19 additions & 0 deletions scheduler/pkg/agent/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ type mockAgentV2Server struct {
unloadFailedEvents int
otherEvents int
errors int
events []*pb.ModelEventMessage
}

type FakeModelRepository struct {
Expand All @@ -64,6 +65,11 @@ func (f *FakeModelRepository) RemoveModelVersion(modelName string) error {
return nil
}

func (f *FakeModelRepository) GetModelConfig(modelName string) (*pb.ModelConfig, error) {
modelConfig := &pb.ModelConfig_Mlserver{Mlserver: &pb.MLServerModelConfig{InstanceCount: uint32(1)}}
return &pb.ModelConfig{Type: pb.ModelConfig_MLSERVER, Config: modelConfig}, nil
}

func (f *FakeModelRepository) DownloadModelVersion(modelName string, version uint32, modelSpec *pbs.ModelSpec, config []byte) (*string, error) {
f.modelDownloads++
if f.err != nil {
Expand Down Expand Up @@ -147,6 +153,7 @@ func (m *mockAgentV2Server) AgentEvent(ctx context.Context, message *pb.ModelEve
default:
m.otherEvents++
}
m.events = append(m.events, message)
return &pb.ModelEventResponse{}, nil
}

Expand Down Expand Up @@ -247,6 +254,7 @@ func TestLoadModel(t *testing.T) {
models []string
replicaConfig *pb.ReplicaConfig
op *pb.ModelOperationMessage
modelConfig *pb.ModelConfig
expectedAvailableMemory uint64
v2Status int
modelRepoErr error
Expand All @@ -270,9 +278,11 @@ func TestLoadModel(t *testing.T) {
},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &smallMemory},
},
ModelConfig: getModelConfig(1),
},
},
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
modelConfig: getModelConfig(1),
expectedAvailableMemory: 500,
v2Status: 200,
success: true,
Expand All @@ -289,10 +299,12 @@ func TestLoadModel(t *testing.T) {
},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &smallMemory},
},
ModelConfig: getModelConfig(1),
},
AutoscalingEnabled: true,
},
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
modelConfig: getModelConfig(1),
expectedAvailableMemory: 500,
v2Status: 200,
success: true,
Expand All @@ -310,9 +322,11 @@ func TestLoadModel(t *testing.T) {
},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &smallMemory},
},
ModelConfig: getModelConfig(1),
},
},
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
modelConfig: getModelConfig(1),
expectedAvailableMemory: 1000,
v2Status: 400,
success: false,
Expand All @@ -329,9 +343,11 @@ func TestLoadModel(t *testing.T) {
},
ModelSpec: &pbs.ModelSpec{Uri: "gs://model", MemoryBytes: &largeMemory},
},
ModelConfig: getModelConfig(1),
},
},
replicaConfig: &pb.ReplicaConfig{MemoryBytes: 1000},
modelConfig: getModelConfig(1),
expectedAvailableMemory: 1000,
v2Status: 200,
success: false,
Expand Down Expand Up @@ -399,6 +415,9 @@ func TestLoadModel(t *testing.T) {
g.Expect(err).To(BeNil())
g.Expect(mockAgentV2Server.loadedEvents).To(Equal(1))
g.Expect(mockAgentV2Server.loadFailedEvents).To(Equal(0))
g.Expect(len(mockAgentV2Server.events)).To(Equal(1))
g.Expect(mockAgentV2Server.events[0].ModelConfig).ToNot(BeNil())
g.Expect(mockAgentV2Server.events[0].ModelConfig.GetMlserver().InstanceCount).To(Equal(uint32(1)))
g.Expect(client.stateManager.GetAvailableMemoryBytes()).To(Equal(test.expectedAvailableMemory))
g.Expect(modelRepository.modelRemovals).To(Equal(0))
loadedVersions := client.stateManager.modelVersions.getVersionsForAllModels()
Expand Down
3 changes: 2 additions & 1 deletion scheduler/pkg/agent/client_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@ func isReady(service interfaces.DependencyServiceInterface, logger *log.Entry, m
return backoff.RetryNotify(readyToError, backoffWithMax, logFailure)
}

func getModifiedModelVersion(modelId string, version uint32, originalModelVersion *agent.ModelVersion) *agent.ModelVersion {
func getModifiedModelVersion(modelId string, version uint32, originalModelVersion *agent.ModelVersion, modelConfig *agent.ModelConfig) *agent.ModelVersion {
mv := proto.Clone(originalModelVersion)
mv.(*agent.ModelVersion).Model.Meta.Name = modelId
mv.(*agent.ModelVersion).Version = version
mv.(*agent.ModelVersion).ModelConfig = modelConfig
return mv.(*agent.ModelVersion)
}

Expand Down
20 changes: 16 additions & 4 deletions scheduler/pkg/agent/model_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ func (modelState *ModelState) addModelVersionImpl(modelVersionDetails *agent.Mod
modelName, versionId, exsistingVersion.getVersion())
}
}

}

// Remove model version and return true if no versions left (in which case we remove from map)
Expand All @@ -70,7 +69,6 @@ func (modelState *ModelState) removeModelVersion(modelVersionDetails *agent.Mode
}

func (modelState *ModelState) removeModelVersionImpl(modelVersionDetails *agent.ModelVersion) (bool, error) {

modelName := modelVersionDetails.GetModel().GetMeta().GetName()
versionId := modelVersionDetails.GetVersion()

Expand Down Expand Up @@ -143,7 +141,8 @@ func (modelState *ModelState) getVersionsForAllModels() []*agent.ModelVersion {
mv := version.get()
versionedModelName := mv.Model.GetMeta().Name
originalModelName, originalModelVersion, _ := util.GetOrignalModelNameAndVersion(versionedModelName)
loadedModels = append(loadedModels, getModifiedModelVersion(originalModelName, originalModelVersion, mv))
modelConfig := mv.ModelConfig
loadedModels = append(loadedModels, getModifiedModelVersion(originalModelName, originalModelVersion, mv, modelConfig))
}
return loadedModels
}
Expand All @@ -153,7 +152,20 @@ type modelVersion struct {
}

func (version *modelVersion) getVersionMemory() uint64 {
return version.versionInfo.GetModel().GetModelSpec().GetMemoryBytes()
instanceCount := getInstanceCount(version)
return version.versionInfo.GetModel().GetModelSpec().GetMemoryBytes() * instanceCount
}

func getInstanceCount(version *modelVersion) uint64 {
modelConfigType := version.versionInfo.ModelConfig.Type
switch modelConfigType {
case agent.ModelConfig_MLSERVER:
return uint64(version.versionInfo.ModelConfig.GetMlserver().InstanceCount)
case agent.ModelConfig_TRITON:
return uint64(version.versionInfo.ModelConfig.GetTriton().Cpu.InstanceCount)
default:
return 1
}
}

func (version *modelVersion) getVersion() uint32 {
Expand Down
Loading
Loading