diff --git a/api/v1alpha1/params_validation.go b/api/v1alpha1/params_validation.go index e4b034e9a..c73ede68a 100644 --- a/api/v1alpha1/params_validation.go +++ b/api/v1alpha1/params_validation.go @@ -16,6 +16,7 @@ import ( corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/klog/v2" "knative.dev/pkg/apis" "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -107,7 +108,7 @@ func UnmarshalTrainingConfig(cm *corev1.ConfigMap) (*Config, *apis.FieldError) { return &config, nil } -func validateTrainingArgsViaConfigMap(cm *corev1.ConfigMap) *apis.FieldError { +func validateTrainingArgsViaConfigMap(cm *corev1.ConfigMap, modelName, methodLowerCase, sku string) *apis.FieldError { config, err := UnmarshalTrainingConfig(cm) if err != nil { return err @@ -136,13 +137,47 @@ func validateTrainingArgsViaConfigMap(cm *corev1.ConfigMap) *apis.FieldError { } } - // TODO: Here we perform the tuning GPU Memory Checks! - fmt.Println(trainingArgsRaw) + // Validate GPU Memory Requirements for batch size of 1 using model and tuning method + errs := validateTuningParameters(modelName, methodLowerCase, sku) + if errs != nil { + return errs + } } } return nil } +func validateTuningParameters(modelName, methodLowerCase, sku string) *apis.FieldError { + skuHandler, err := utils.GetSKUHandler() + if err != nil { + return apis.ErrInvalidValue(fmt.Sprintf("Failed to get SKU handler: %v", err), "sku") + } + + skuConfig, skuExists := skuHandler.GetGPUConfigs()[sku] + if !skuExists { + return apis.ErrInvalidValue(fmt.Sprintf("Unsupported SKU: '%s'", sku), "sku") + } + skuGPUMem := skuConfig.GPUMem + + modelTuningConfig, modelExists := modelTuningConfigs[modelName] + if !modelExists { + //klog.Infof("Model '%s' hasn't been tested yet for fine-tuning. Proceed at your own risk.", modelName) + return nil + } + + minGPURequired, methodExists := modelTuningConfig[methodLowerCase] + if !methodExists { + //klog.Infof("Tuning method '%s' for model '%s' hasn't been tested yet.", methodLowerCase, modelName) + return nil + } + + if skuGPUMem < minGPURequired { + klog.Warningf("Insufficient GPU memory: For model '%s' with tuning method '%s', the SKU '%s' with %dGi GPU memory does not support even a batch size of 1 in testing. Proceed at your own risk.", modelName, methodLowerCase, sku, skuGPUMem) + return nil + } + return nil +} + func validateMethodViaConfigMap(cm *corev1.ConfigMap, methodLowerCase string) *apis.FieldError { config, err := UnmarshalTrainingConfig(cm) if err != nil { @@ -249,7 +284,7 @@ func validateConfigMapSchema(cm *corev1.ConfigMap) *apis.FieldError { return nil } -func (r *TuningSpec) validateConfigMap(ctx context.Context, namespace string, methodLowerCase string, configMapName string) (errs *apis.FieldError) { +func (r *TuningSpec) validateConfigMap(ctx context.Context, namespace, methodLowerCase, sku string, configMapName string) (errs *apis.FieldError) { var cm corev1.ConfigMap if k8sclient.Client == nil { errs = errs.Also(apis.ErrGeneric("Failed to obtain client from context.Context")) @@ -269,7 +304,10 @@ func (r *TuningSpec) validateConfigMap(ctx context.Context, namespace string, me if err := validateMethodViaConfigMap(&cm, methodLowerCase); err != nil { errs = errs.Also(err) } - if err := validateTrainingArgsViaConfigMap(&cm); err != nil { + + if r.Preset == nil { + errs = errs.Also(apis.ErrMissingField("Preset")) + } else if err := validateTrainingArgsViaConfigMap(&cm, string(r.Preset.Name), methodLowerCase, sku); err != nil { errs = errs.Also(err) } } diff --git a/api/v1alpha1/tuning_config.go b/api/v1alpha1/tuning_config.go new file mode 100644 index 000000000..157730e84 --- /dev/null +++ b/api/v1alpha1/tuning_config.go @@ -0,0 +1,11 @@ +package v1alpha1 + +// Map Representing Minimum Per GPU Memory required for Batch Size of 1 +// ModelName, TuningMethod, MinGPUMemory +var modelTuningConfigs = map[string]map[string]int{ + "falcon-7b": { + //string(TuningMethodLora): 24, + string(TuningMethodQLora): 16, + }, + // Add more configurations as needed +} diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 07e45c9b7..8ad64b0df 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -53,7 +53,7 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) { if w.Tuning != nil { // TODO: Add validate resource based on Tuning Spec errs = errs.Also(w.Resource.validateCreateWithTuning(w.Tuning).ViaField("resource"), - w.Tuning.validateCreate(ctx, w.Namespace).ViaField("tuning")) + w.Tuning.validateCreate(ctx, w.Namespace, w.Resource.InstanceType).ViaField("tuning")) } } else { klog.InfoS("Validate update", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name)) @@ -131,7 +131,7 @@ func (r *AdapterSpec) validateCreateorUpdate() (errs *apis.FieldError) { return errs } -func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace string) (errs *apis.FieldError) { +func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace string, sku string) (errs *apis.FieldError) { methodLowerCase := strings.ToLower(string(r.Method)) if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) { errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method")) @@ -148,11 +148,11 @@ func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace stri } else if methodLowerCase == string(TuningMethodQLora) { defaultConfigMapTemplateName = DefaultQloraConfigMapTemplate } - if err := r.validateConfigMap(ctx, releaseNamespace, methodLowerCase, defaultConfigMapTemplateName); err != nil { + if err := r.validateConfigMap(ctx, releaseNamespace, methodLowerCase, sku, defaultConfigMapTemplateName); err != nil { errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to evaluate validateConfigMap: %v", err), "Config")) } } else { - if err := r.validateConfigMap(ctx, workspaceNamespace, methodLowerCase, r.Config); err != nil { + if err := r.validateConfigMap(ctx, workspaceNamespace, methodLowerCase, sku, r.Config); err != nil { errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to evaluate validateConfigMap: %v", err), "Config")) } } diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index bf9f8ea3e..84ec26c8f 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -1094,7 +1094,8 @@ func TestTuningSpecValidateCreate(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - errs := tt.tuningSpec.validateCreate(ctx, "WORKSPACE_NAMESPACE") + os.Setenv("CLOUD_PROVIDER", "azure") // Manually set for testing env, normally defined in helm chart + errs := tt.tuningSpec.validateCreate(ctx, "WORKSPACE_NAMESPACE", "Standard_NC6s_v3") hasErrs := errs != nil if hasErrs != tt.wantErr {