Skip to content

Commit

Permalink
feat: support batch size (#290)
Browse files Browse the repository at this point in the history
Add batch size support in the LMEvalJob which
leverages the `--batch_size` in the `lm-evaluation-harness`.
This only affects the local models. The `--batch_size` doesn't
work for remote inference APIs.

Signed-off-by: Yihong Wang <[email protected]>
  • Loading branch information
yhwang authored Sep 12, 2024
1 parent 0d2393d commit d2b9b2f
Show file tree
Hide file tree
Showing 11 changed files with 118 additions and 19 deletions.
3 changes: 3 additions & 0 deletions api/lmes/v1alpha1/lmevaljob_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ type LMEvalJobSpec struct {
EnvSecrets []EnvSecret `json:"envSecrets,omitempty"`
// Use secrets as files
FileSecrets []FileSecret `json:"fileSecrets,omitempty"`
// Batch size for the evaluation. This is used by the models that run and are loaded
// locally and not apply for the commercial APIs.
BatchSize *int `json:"batchSize,omitempty"`
}

// LMEvalJobStatus defines the observed state of LMEvalJob
Expand Down
5 changes: 5 additions & 0 deletions api/lmes/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions config/base/params.env
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ lmes-pod-checking-interval=10s
lmes-image-pull-policy=Always
lmes-grpc-service=lmes-grpc
lmes-grpc-port=8082
lmes-max-batch-size=24
lmes-default-batch-size=8
5 changes: 5 additions & 0 deletions config/crd/bases/trustyai.opendatahub.io_lmevaljobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ spec:
spec:
description: LMEvalJobSpec defines the desired state of LMEvalJob
properties:
batchSize:
description: Batch size for the evaluation. This is used by the models
that run and are loaded locally and not apply for the commercial
APIs.
type: integer
envSecrets:
description: Assign secrets to the environment variables
items:
Expand Down
2 changes: 1 addition & 1 deletion config/overlays/lmes/kustomization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ replacements:
- .data.lmes-grpc-service

patchesStrategicMerge:
- lmes-only-patch.yaml
- lmes-only-patch.yaml
2 changes: 1 addition & 1 deletion config/overlays/lmes/lmes-only-patch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ spec:
args:
- --leader-elect
- --enable-services
- "LMES"
- "LMES"
2 changes: 1 addition & 1 deletion config/overlays/rhoai/kustomization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ resources:
- ../../base

patchesStrategicMerge:
- tas-only-patch.yaml
- tas-only-patch.yaml
2 changes: 1 addition & 1 deletion config/overlays/rhoai/tas-only-patch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ spec:
args:
- --leader-elect
- --enable-services
- "TAS"
- "TAS"
4 changes: 4 additions & 0 deletions controllers/lmes/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ const (
GrpcServiceKey = "lmes-grpc-service"
GrpcServerSecretKey = "lmes-grpc-server-secret"
GrpcClientSecretKey = "lmes-grpc-client-secret"
MaxBatchSizeKey = "lmes-max-batch-size"
DefaultBatchSizeKey = "lmes-default-batch-size"
DriverReportIntervalKey = "driver-report-interval"
GrpcServerCertEnv = "GRPC_SERVER_CERT"
GrpcServerKeyEnv = "GRPC_SERVER_KEY"
Expand All @@ -47,5 +49,7 @@ const (
DefaultGrpcService = "lm-eval-grpc"
DefaultGrpcServerSecret = "grpc-server-cert"
DefaultGrpcClientSecret = "grpc-client-cert"
DefaultMaxBatchSize = 24
DefaultBatchSize = 8
ServiceName = "LMES"
)
31 changes: 24 additions & 7 deletions controllers/lmes/lmevaljob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ var (
"GrpcServerSecret": GrpcServerSecretKey,
"GrpcClientSecret": GrpcClientSecretKey,
"DriverReportInterval": DriverReportIntervalKey,
"DefaultBatchSize": DefaultBatchSizeKey,
"MaxBatchSize": MaxBatchSizeKey,
}
)

Expand Down Expand Up @@ -97,6 +99,8 @@ type ServiceOptions struct {
GrpcService string
GrpcServerSecret string
GrpcClientSecret string
MaxBatchSize int
DefaultBatchSize int
grpcTLSMode TLSMode
}

Expand Down Expand Up @@ -298,6 +302,8 @@ func (r *LMEvalJobReconciler) constructOptionsFromConfigMap(
GrpcService: DefaultGrpcService,
GrpcServerSecret: DefaultGrpcServerSecret,
GrpcClientSecret: DefaultGrpcClientSecret,
MaxBatchSize: DefaultMaxBatchSize,
DefaultBatchSize: DefaultBatchSize,
}

log := log.FromContext(ctx)
Expand All @@ -318,10 +324,10 @@ func (r *LMEvalJobReconciler) constructOptionsFromConfigMap(
case "string":
frv.SetString(v)
case "int":
var grpcPort int
grpcPort, err = strconv.Atoi(v)
var intVal int
intVal, err = strconv.Atoi(v)
if err == nil {
frv.SetInt(int64(grpcPort))
frv.SetInt(int64(intVal))
}
case "Duration":
var d time.Duration
Expand Down Expand Up @@ -397,7 +403,7 @@ func (r *LMEvalJobReconciler) handleNewCR(ctx context.Context, log logr.Logger,

// construct a new pod and create a pod for the job
currentTime := v1.Now()
pod := r.createPod(job)
pod := r.createPod(job, log)
if err := r.Create(ctx, pod, &client.CreateOptions{}); err != nil {
// Failed to create the pod. Mark the status as complete with failed
job.Status.State = lmesv1alpha1.CompleteJobState
Expand Down Expand Up @@ -562,7 +568,7 @@ func (r *LMEvalJobReconciler) handleCancel(ctx context.Context, log logr.Logger,
return ctrl.Result{}, err
}

func (r *LMEvalJobReconciler) createPod(job *lmesv1alpha1.LMEvalJob) *corev1.Pod {
func (r *LMEvalJobReconciler) createPod(job *lmesv1alpha1.LMEvalJob, log logr.Logger) *corev1.Pod {
var allowPrivilegeEscalation = false
var runAsNonRootUser = true
var ownerRefController = true
Expand Down Expand Up @@ -641,7 +647,7 @@ func (r *LMEvalJobReconciler) createPod(job *lmesv1alpha1.LMEvalJob) *corev1.Pod
ImagePullPolicy: r.options.ImagePullPolicy,
Env: envVars,
Command: r.generateCmd(job),
Args: generateArgs(job),
Args: r.generateArgs(job, log),
SecurityContext: &corev1.SecurityContext{
AllowPrivilegeEscalation: &allowPrivilegeEscalation,
RunAsUser: &runAsUser,
Expand All @@ -667,7 +673,7 @@ func (r *LMEvalJobReconciler) createPod(job *lmesv1alpha1.LMEvalJob) *corev1.Pod
return &pod
}

func generateArgs(job *lmesv1alpha1.LMEvalJob) []string {
func (r *LMEvalJobReconciler) generateArgs(job *lmesv1alpha1.LMEvalJob, log logr.Logger) []string {
if job == nil {
return nil
}
Expand Down Expand Up @@ -699,6 +705,17 @@ func generateArgs(job *lmesv1alpha1.LMEvalJob) []string {
if job.Spec.LogSamples != nil && *job.Spec.LogSamples {
cmds = append(cmds, "--log_samples")
}
// --batch_size
var batchSize = r.options.DefaultBatchSize
if job.Spec.BatchSize != nil && *job.Spec.BatchSize > 0 {
batchSize = *job.Spec.BatchSize
}
// This could be done in the webhook if it's enabled.
if batchSize > r.options.MaxBatchSize {
batchSize = r.options.MaxBatchSize
log.Info("batchSize is greater than max-batch-size of the controller's configuration, use the max-batch-size instead")
}
cmds = append(cmds, "--batch_size", fmt.Sprintf("%d", batchSize))

return []string{"sh", "-ec", strings.Join(cmds, " ")}
}
Expand Down
79 changes: 71 additions & 8 deletions controllers/lmes/lmevaljob_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ limitations under the License.
package lmes

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
lmesv1alpha1 "github.com/trustyai-explainability/trustyai-service-operator/api/lmes/v1alpha1"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/controller-runtime/pkg/log"
)

var (
Expand All @@ -34,6 +36,7 @@ var (
)

func Test_SimplePod(t *testing.T) {
log := log.FromContext(context.Background())
lmevalRec := LMEvalJobReconciler{
Namespace: "test",
options: &ServiceOptions{
Expand Down Expand Up @@ -115,7 +118,7 @@ func Test_SimplePod(t *testing.T) {
ImagePullPolicy: lmevalRec.options.ImagePullPolicy,
Env: []corev1.EnvVar{},
Command: lmevalRec.generateCmd(job),
Args: generateArgs(job),
Args: lmevalRec.generateArgs(job, log),
SecurityContext: &corev1.SecurityContext{
AllowPrivilegeEscalation: &allowPrivilegeEscalation,
RunAsUser: &runAsUser,
Expand Down Expand Up @@ -150,12 +153,13 @@ func Test_SimplePod(t *testing.T) {
},
}

newPod := lmevalRec.createPod(job)
newPod := lmevalRec.createPod(job, log)

assert.Equal(t, expect, newPod)
}

func Test_GrpcMTlsPod(t *testing.T) {
log := log.FromContext(context.Background())
lmevalRec := LMEvalJobReconciler{
Namespace: "test",
options: &ServiceOptions{
Expand Down Expand Up @@ -253,7 +257,7 @@ func Test_GrpcMTlsPod(t *testing.T) {
},
},
Command: lmevalRec.generateCmd(job),
Args: generateArgs(job),
Args: lmevalRec.generateArgs(job, log),
SecurityContext: &corev1.SecurityContext{
AllowPrivilegeEscalation: &allowPrivilegeEscalation,
RunAsUser: &runAsUser,
Expand Down Expand Up @@ -319,12 +323,13 @@ func Test_GrpcMTlsPod(t *testing.T) {
},
}

newPod := lmevalRec.createPod(job)
newPod := lmevalRec.createPod(job, log)

assert.Equal(t, expect, newPod)
}

func Test_EnvSecretsPod(t *testing.T) {
log := log.FromContext(context.Background())
lmevalRec := LMEvalJobReconciler{
Namespace: "test",
options: &ServiceOptions{
Expand Down Expand Up @@ -444,7 +449,7 @@ func Test_EnvSecretsPod(t *testing.T) {
},
},
Command: lmevalRec.generateCmd(job),
Args: generateArgs(job),
Args: lmevalRec.generateArgs(job, log),
SecurityContext: &corev1.SecurityContext{
AllowPrivilegeEscalation: &allowPrivilegeEscalation,
RunAsUser: &runAsUser,
Expand Down Expand Up @@ -510,12 +515,13 @@ func Test_EnvSecretsPod(t *testing.T) {
},
}

newPod := lmevalRec.createPod(job)
newPod := lmevalRec.createPod(job, log)
// maybe only verify the envs: Containers[0].Env
assert.Equal(t, expect, newPod)
}

func Test_FileSecretsPod(t *testing.T) {
log := log.FromContext(context.Background())
lmevalRec := LMEvalJobReconciler{
Namespace: "test",
options: &ServiceOptions{
Expand Down Expand Up @@ -627,7 +633,7 @@ func Test_FileSecretsPod(t *testing.T) {
},
},
Command: lmevalRec.generateCmd(job),
Args: generateArgs(job),
Args: lmevalRec.generateArgs(job, log),
SecurityContext: &corev1.SecurityContext{
AllowPrivilegeEscalation: &allowPrivilegeEscalation,
RunAsUser: &runAsUser,
Expand Down Expand Up @@ -712,7 +718,64 @@ func Test_FileSecretsPod(t *testing.T) {
},
}

newPod := lmevalRec.createPod(job)
newPod := lmevalRec.createPod(job, log)
// maybe only verify the envs: Containers[0].Env
assert.Equal(t, expect, newPod)
}

// Test_GenerateArgBatchSize checks how generateArgs renders the --batch_size
// flag for a reconciler configured with DefaultBatchSize 8 and MaxBatchSize 24:
//   - no batch size on the job: the default batch size is emitted;
//   - a batch size above the maximum: the maximum is emitted instead;
//   - a batch size within the limit: the job's own value is emitted.
func Test_GenerateArgBatchSize(t *testing.T) {
	logger := log.FromContext(context.Background())

	reconciler := LMEvalJobReconciler{
		Namespace: "test",
		options: &ServiceOptions{
			PodImage:         "podimage:latest",
			DriverImage:      "driver:latest",
			ImagePullPolicy:  corev1.PullAlways,
			GrpcPort:         8088,
			GrpcService:      "grpc-service",
			grpcTLSMode:      TLSMode_None,
			MaxBatchSize:     24,
			DefaultBatchSize: 8,
		},
	}

	job := &lmesv1alpha1.LMEvalJob{
		ObjectMeta: metav1.ObjectMeta{
			Name:      "test",
			Namespace: "default",
			UID:       "for-testing",
		},
		TypeMeta: metav1.TypeMeta{
			Kind:       lmesv1alpha1.KindName,
			APIVersion: lmesv1alpha1.Version,
		},
		Spec: lmesv1alpha1.LMEvalJobSpec{
			Model: "test",
			ModelArgs: []lmesv1alpha1.Arg{
				{Name: "arg1", Value: "value1"},
			},
			Tasks: []string{"task1", "task2"},
		},
	}

	overLimit, withinLimit := 30, 16

	// The scenarios run in order against the same job object; only
	// Spec.BatchSize changes between them.
	scenarios := []struct {
		batchSize *int
		wantCmd   string
	}{
		{
			// no batchSize in the job, use default batchSize
			batchSize: nil,
			wantCmd:   "python -m lm_eval --output_path /opt/app-root/src/output --device cpu --model test --model_args arg1=value1 --tasks task1,task2 --batch_size 8",
		},
		{
			// exceed the max-batch-size, use max-batch-size
			batchSize: &overLimit,
			wantCmd:   "python -m lm_eval --output_path /opt/app-root/src/output --device cpu --model test --model_args arg1=value1 --tasks task1,task2 --batch_size 24",
		},
		{
			// normal batchSize
			batchSize: &withinLimit,
			wantCmd:   "python -m lm_eval --output_path /opt/app-root/src/output --device cpu --model test --model_args arg1=value1 --tasks task1,task2 --batch_size 16",
		},
	}

	for _, sc := range scenarios {
		job.Spec.BatchSize = sc.batchSize
		assert.Equal(t, []string{"sh", "-ec", sc.wantCmd}, reconciler.generateArgs(job, logger))
	}
}

0 comments on commit d2b9b2f

Please sign in to comment.