diff --git a/cmd/lmes_driver/main.go b/cmd/lmes_driver/main.go index e7ba343..537310f 100644 --- a/cmd/lmes_driver/main.go +++ b/cmd/lmes_driver/main.go @@ -18,12 +18,12 @@ package main import ( "context" + "encoding/json" "flag" "fmt" "io" "os" "strings" - "time" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/log" @@ -50,17 +50,14 @@ func (t *strArrayArg) String() string { } var ( - taskRecipes strArrayArg - customCards strArrayArg - copy = flag.String("copy", "", "copy this binary to specified destination path") - jobNameSpace = flag.String("job-namespace", "", "Job's namespace ") - jobName = flag.String("job-name", "", "Job's name") - grpcService = flag.String("grpc-service", "", "grpc service name") - grpcPort = flag.Int("grpc-port", 8082, "grpc port") - outputPath = flag.String("output-path", OutputPath, "output path") - detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU") - reportInterval = flag.Duration("report-interval", time.Second*10, "specify the druation interval to report the progress") - driverLog = ctrl.Log.WithName("driver") + taskRecipes strArrayArg + customCards strArrayArg + copy = flag.String("copy", "", "copy this binary to specified destination path") + getStatus = flag.Bool("get-status", false, "Get current status") + shutdown = flag.Bool("shutdown", false, "Shutdown the driver") + outputPath = flag.String("output-path", OutputPath, "output path") + detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU") + driverLog = ctrl.Log.WithName("driver") ) func init() { @@ -83,7 +80,7 @@ func main() { if *copy != "" { // copy exec to destination - if err := CopyExec(*copy); err != nil { + if err := copyExec(*copy); err != nil { driverLog.Error(err, "failed to copy binary") os.Exit(1) return @@ -92,29 +89,34 @@ func main() { return } + if *getStatus { + getStatusOrDie(ctx) + return + } + + if *shutdown { + shutdownOrDie(ctx) + return + } + if len(args) == 0 { driverLog.Error(fmt.Errorf("no user program"), "empty args") os.Exit(1) } driverOpt := driver.DriverOption{ - Context: ctx, - JobNamespace: *jobNameSpace, - JobName: *jobName, - OutputPath: *outputPath, - GrpcService: *grpcService, - GrpcPort: *grpcPort, - DetectDevice: *detectDevice, - Logger: driverLog, - TaskRecipes: taskRecipes, - CustomCards: customCards, - Args: args, - ReportInterval: *reportInterval, + Context: ctx, + OutputPath: *outputPath, + DetectDevice: *detectDevice, + Logger: driverLog, + TaskRecipes: taskRecipes, + CustomCards: customCards, + Args: args, } driver, err := driver.NewDriver(&driverOpt) if err != nil { - driverLog.Error(err, "Driver.Run failed") + driverLog.Error(err, "Driver.NewDriver failed") os.Exit(1) } @@ -123,11 +125,10 @@ func main() { driverLog.Error(err, "Driver.Run failed") exitCode = 1 } - driver.Cleanup() os.Exit(exitCode) } -func CopyExec(destination string) (err error) { +func copyExec(destination string) (err error) { defer func() { if err != nil { err = fmt.Errorf("copy this binary to %s: %w", destination, err) @@ -161,3 +162,53 @@ func findThisBinary() (string, error) { } return bin, nil } + +func getStatusOrDie(ctx context.Context) { + driver, err := driver.NewDriver(&driver.DriverOption{ + Context: ctx, + OutputPath: *outputPath, + DetectDevice: *detectDevice, + Logger: driverLog, + }) + + if err != nil { + driverLog.Error(err, "failed to initialize the driver") + os.Exit(1) + } + + status, err := driver.GetStatus() + if err != nil { + driverLog.Error(err, "failed to get status", "error", err.Error()) + os.Exit(1) + } + + b, err := json.Marshal(status) + if err != nil { + driverLog.Error(err, "json serialization failed", "error", err.Error()) + os.Exit(1) + } + + fmt.Print(string(b)) + os.Exit(0) +} + +func shutdownOrDie(ctx context.Context) { + driver, err := driver.NewDriver(&driver.DriverOption{ + Context: ctx, + OutputPath: *outputPath, + DetectDevice: *detectDevice, + Logger: driverLog, + }) + + if err != nil { + driverLog.Error(err, "failed to initialize the driver") + os.Exit(1) + } + + err = driver.Shutdown() + if err != nil { + driverLog.Error(err, "failed to shutdown", "error", err.Error()) + os.Exit(1) + } + os.Exit(0) +} diff --git a/cmd/lmes_driver/main_test.go b/cmd/lmes_driver/main_test.go index 723f3ca..070ec88 100644 --- a/cmd/lmes_driver/main_test.go +++ b/cmd/lmes_driver/main_test.go @@ -21,7 +21,6 @@ import ( "flag" "os" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/trustyai-explainability/trustyai-service-operator/controllers/lmes/driver" @@ -31,13 +30,8 @@ import ( func Test_ArgParsing(t *testing.T) { os.Args = []string{ "/opt/app-root/src/bin/driver", - "--job-namespace", "default", - "--job-name", "test", - "--grpc-service", "grpc-service.test.svc", - "--grpc-port", "8088", "--output-path", "/opt/app-root/src/output", "--detect-device", - "--report-interval", "10s", "--task-recipe", "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", "--task-recipe", "card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10", "--", @@ -53,30 +47,19 @@ func Test_ArgParsing(t *testing.T) { args := flag.Args() - assert.Equal(t, "default", *jobNameSpace) - assert.Equal(t, "test", *jobName) - assert.Equal(t, "grpc-service.test.svc", *grpcService) - assert.Equal(t, 8088, *grpcPort) - assert.Equal(t, "/opt/app-root/src/output", *outputPath) assert.Equal(t, true, *detectDevice) - assert.Equal(t, time.Second*10, *reportInterval) assert.Equal(t, strArrayArg{ "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", "card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10", }, taskRecipes) dOption := driver.DriverOption{ - Context: context.Background(), - JobNamespace: *jobNameSpace, - JobName: *jobName, - OutputPath: *outputPath, - GrpcService: *grpcService, - GrpcPort: *grpcPort, - DetectDevice: *detectDevice, - Logger: driverLog, - TaskRecipes: taskRecipes, - Args: args, - ReportInterval: *reportInterval, + Context: context.Background(), + OutputPath: *outputPath, + DetectDevice: *detectDevice, + Logger: driverLog, + TaskRecipes: taskRecipes, + Args: args, } assert.Equal(t, []string{ diff --git a/config/base/params.env b/config/base/params.env index befcd01..f78a120 100644 --- a/config/base/params.env +++ b/config/base/params.env @@ -6,7 +6,6 @@ lmes-driver-image=quay.io/trustyai/ta-lmes-driver:latest lmes-pod-image=quay.io/trustyai/ta-lmes-job:latest lmes-pod-checking-interval=10s lmes-image-pull-policy=Always -lmes-grpc-service=lmes-grpc -lmes-grpc-port=8082 lmes-max-batch-size=24 lmes-default-batch-size=8 +lmes-detect-device=true diff --git a/config/manager/kustomization.yaml b/config/manager/kustomization.yaml index 91db94a..be0410a 100644 --- a/config/manager/kustomization.yaml +++ b/config/manager/kustomization.yaml @@ -1,5 +1,4 @@ resources: - manager.yaml - - service.yaml apiVersion: kustomize.config.k8s.io/v1beta1 kind: Kustomization diff --git a/config/manager/service.yaml b/config/manager/service.yaml deleted file mode 100644 index 3e440bd..0000000 --- a/config/manager/service.yaml +++ /dev/null @@ -1,14 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - labels: - app.kubernetes.io/name: deployment - app.kubernetes.io/managed-by: kustomize - name: lmes-grpc -spec: - ports: - - port: 8082 - protocol: TCP - targetPort: 8082 - selector: - control-plane: controller-manager diff --git a/config/overlays/lmes/kustomization.yaml b/config/overlays/lmes/kustomization.yaml index af5277d..216c929 100644 --- a/config/overlays/lmes/kustomization.yaml +++ b/config/overlays/lmes/kustomization.yaml @@ -4,19 +4,5 @@ kind: Kustomization resources: - ../../base -replacements: - - source: - kind: Service - version: v1 - name: lmes-grpc - fieldPath: .metadata.name - targets: - - select: - kind: ConfigMap - version: v1 - name: config - fieldPaths: - - .data.lmes-grpc-service - patchesStrategicMerge: - lmes-only-patch.yaml diff --git a/config/overlays/odh/kustomization.yaml b/config/overlays/odh/kustomization.yaml index 1e13e5b..e6e984b 100644 --- a/config/overlays/odh/kustomization.yaml +++ b/config/overlays/odh/kustomization.yaml @@ -3,17 +3,3 @@ apiVersion: kustomize.config.k8s.io/v1beta1 kind: Kustomization resources: - ../../base - -replacements: - - source: - kind: Service - version: v1 - name: lmes-grpc - fieldPath: .metadata.name - targets: - - select: - kind: ConfigMap - version: v1 - name: config - fieldPaths: - - .data.lmes-grpc-service diff --git a/config/rbac/role.yaml b/config/rbac/role.yaml index 0e0e3a6..56cb051 100644 --- a/config/rbac/role.yaml +++ b/config/rbac/role.yaml @@ -37,6 +37,16 @@ rules: - patch - update - watch +- apiGroups: + - "" + resources: + - pods/exec + verbs: + - create + - delete + - get + - list + - watch - apiGroups: - "" resources: diff --git a/controllers/lmes/api/v1beta1/update_status.pb.go b/controllers/lmes/api/v1beta1/update_status.pb.go deleted file mode 100644 index da64b7f..0000000 --- a/controllers/lmes/api/v1beta1/update_status.pb.go +++ /dev/null @@ -1,344 +0,0 @@ -// Copyright 2024. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by protoc-gen-go. DO NOT EDIT. -// versions: -// protoc-gen-go v1.34.2 -// protoc v3.12.4 -// source: controllers/lmes/api/v1beta1/update_status.proto - -package v1beta1 - -import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" - reflect "reflect" - sync "sync" -) - -const ( - // Verify that this generated code is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) - // Verify that runtime/protoimpl is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) -) - -type ResponseCode int32 - -const ( - ResponseCode_OK ResponseCode = 0 - ResponseCode_ERROR ResponseCode = 1 -) - -// Enum value maps for ResponseCode. -var ( - ResponseCode_name = map[int32]string{ - 0: "OK", - 1: "ERROR", - } - ResponseCode_value = map[string]int32{ - "OK": 0, - "ERROR": 1, - } -) - -func (x ResponseCode) Enum() *ResponseCode { - p := new(ResponseCode) - *p = x - return p -} - -func (x ResponseCode) String() string { - return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) -} - -func (ResponseCode) Descriptor() protoreflect.EnumDescriptor { - return file_backend_api_v1beta1_update_status_proto_enumTypes[0].Descriptor() -} - -func (ResponseCode) Type() protoreflect.EnumType { - return &file_backend_api_v1beta1_update_status_proto_enumTypes[0] -} - -func (x ResponseCode) Number() protoreflect.EnumNumber { - return protoreflect.EnumNumber(x) -} - -// Deprecated: Use ResponseCode.Descriptor instead. -func (ResponseCode) EnumDescriptor() ([]byte, []int) { - return file_backend_api_v1beta1_update_status_proto_rawDescGZIP(), []int{0} -} - -// the JobState, Reason, message, and optional Results -type JobStatus struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - JobName string `protobuf:"bytes,1,opt,name=job_name,json=jobName,proto3" json:"job_name,omitempty"` - JobNamespace string `protobuf:"bytes,2,opt,name=job_namespace,json=jobNamespace,proto3" json:"job_namespace,omitempty"` - State string `protobuf:"bytes,3,opt,name=state,proto3" json:"state,omitempty"` - Reason string `protobuf:"bytes,4,opt,name=reason,proto3" json:"reason,omitempty"` - StatusMessage string `protobuf:"bytes,5,opt,name=status_message,json=statusMessage,proto3" json:"status_message,omitempty"` - Results *string `protobuf:"bytes,6,opt,name=results,proto3,oneof" json:"results,omitempty"` -} - -func (x *JobStatus) Reset() { - *x = JobStatus{} - if protoimpl.UnsafeEnabled { - mi := &file_backend_api_v1beta1_update_status_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *JobStatus) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*JobStatus) ProtoMessage() {} - -func (x *JobStatus) ProtoReflect() protoreflect.Message { - mi := &file_backend_api_v1beta1_update_status_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use JobStatus.ProtoReflect.Descriptor instead. -func (*JobStatus) Descriptor() ([]byte, []int) { - return file_backend_api_v1beta1_update_status_proto_rawDescGZIP(), []int{0} -} - -func (x *JobStatus) GetJobName() string { - if x != nil { - return x.JobName - } - return "" -} - -func (x *JobStatus) GetJobNamespace() string { - if x != nil { - return x.JobNamespace - } - return "" -} - -func (x *JobStatus) GetState() string { - if x != nil { - return x.State - } - return "" -} - -func (x *JobStatus) GetReason() string { - if x != nil { - return x.Reason - } - return "" -} - -func (x *JobStatus) GetStatusMessage() string { - if x != nil { - return x.StatusMessage - } - return "" -} - -func (x *JobStatus) GetResults() string { - if x != nil && x.Results != nil { - return *x.Results - } - return "" -} - -type Response struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Code ResponseCode `protobuf:"varint,1,opt,name=code,proto3,enum=ResponseCode" json:"code,omitempty"` - Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` -} - -func (x *Response) Reset() { - *x = Response{} - if protoimpl.UnsafeEnabled { - mi := &file_backend_api_v1beta1_update_status_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *Response) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*Response) ProtoMessage() {} - -func (x *Response) ProtoReflect() protoreflect.Message { - mi := &file_backend_api_v1beta1_update_status_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use Response.ProtoReflect.Descriptor instead. -func (*Response) Descriptor() ([]byte, []int) { - return file_backend_api_v1beta1_update_status_proto_rawDescGZIP(), []int{1} -} - -func (x *Response) GetCode() ResponseCode { - if x != nil { - return x.Code - } - return ResponseCode_OK -} - -func (x *Response) GetMessage() string { - if x != nil { - return x.Message - } - return "" -} - -var File_backend_api_v1beta1_update_status_proto protoreflect.FileDescriptor - -var file_backend_api_v1beta1_update_status_proto_rawDesc = []byte{ - 0x0a, 0x27, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, - 0x62, 0x65, 0x74, 0x61, 0x31, 0x2f, 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, 0x5f, 0x73, 0x74, 0x61, - 0x74, 0x75, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xcb, 0x01, 0x0a, 0x09, 0x4a, 0x6f, - 0x62, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x19, 0x0a, 0x08, 0x6a, 0x6f, 0x62, 0x5f, 0x6e, - 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6a, 0x6f, 0x62, 0x4e, 0x61, - 0x6d, 0x65, 0x12, 0x23, 0x0a, 0x0d, 0x6a, 0x6f, 0x62, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, - 0x61, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x6a, 0x6f, 0x62, 0x4e, 0x61, - 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x74, 0x65, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x73, 0x74, 0x61, 0x74, 0x65, 0x12, 0x16, 0x0a, - 0x06, 0x72, 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x72, - 0x65, 0x61, 0x73, 0x6f, 0x6e, 0x12, 0x25, 0x0a, 0x0e, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x5f, - 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x73, - 0x74, 0x61, 0x74, 0x75, 0x73, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x1d, 0x0a, 0x07, - 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, - 0x07, 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x88, 0x01, 0x01, 0x42, 0x0a, 0x0a, 0x08, 0x5f, - 0x72, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x73, 0x22, 0x47, 0x0a, 0x08, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x21, 0x0a, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0e, 0x32, 0x0d, 0x2e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x43, 0x6f, 0x64, 0x65, - 0x52, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x2a, 0x21, 0x0a, 0x0c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x43, 0x6f, 0x64, 0x65, - 0x12, 0x06, 0x0a, 0x02, 0x4f, 0x4b, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, - 0x52, 0x10, 0x01, 0x32, 0x3f, 0x0a, 0x16, 0x4c, 0x4d, 0x45, 0x76, 0x61, 0x6c, 0x4a, 0x6f, 0x62, - 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x25, 0x0a, - 0x0c, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x0a, 0x2e, - 0x4a, 0x6f, 0x62, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x1a, 0x09, 0x2e, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x4b, 0x5a, 0x49, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, - 0x6f, 0x6d, 0x2f, 0x66, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2d, 0x6d, 0x6f, - 0x64, 0x65, 0x6c, 0x2d, 0x73, 0x74, 0x61, 0x63, 0x6b, 0x2f, 0x66, 0x6d, 0x73, 0x2d, 0x6c, 0x6d, - 0x2d, 0x65, 0x76, 0x61, 0x6c, 0x2d, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2f, 0x62, 0x61, - 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, 0x62, 0x65, 0x74, 0x61, - 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, -} - -var ( - file_backend_api_v1beta1_update_status_proto_rawDescOnce sync.Once - file_backend_api_v1beta1_update_status_proto_rawDescData = file_backend_api_v1beta1_update_status_proto_rawDesc -) - -func file_backend_api_v1beta1_update_status_proto_rawDescGZIP() []byte { - file_backend_api_v1beta1_update_status_proto_rawDescOnce.Do(func() { - file_backend_api_v1beta1_update_status_proto_rawDescData = protoimpl.X.CompressGZIP(file_backend_api_v1beta1_update_status_proto_rawDescData) - }) - return file_backend_api_v1beta1_update_status_proto_rawDescData -} - -var file_backend_api_v1beta1_update_status_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_backend_api_v1beta1_update_status_proto_msgTypes = make([]protoimpl.MessageInfo, 2) -var file_backend_api_v1beta1_update_status_proto_goTypes = []any{ - (ResponseCode)(0), // 0: ResponseCode - (*JobStatus)(nil), // 1: JobStatus - (*Response)(nil), // 2: Response -} -var file_backend_api_v1beta1_update_status_proto_depIdxs = []int32{ - 0, // 0: Response.code:type_name -> ResponseCode - 1, // 1: LMEvalJobUpdateService.UpdateStatus:input_type -> JobStatus - 2, // 2: LMEvalJobUpdateService.UpdateStatus:output_type -> Response - 2, // [2:3] is the sub-list for method output_type - 1, // [1:2] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name -} - -func init() { file_backend_api_v1beta1_update_status_proto_init() } -func file_backend_api_v1beta1_update_status_proto_init() { - if File_backend_api_v1beta1_update_status_proto != nil { - return - } - if !protoimpl.UnsafeEnabled { - file_backend_api_v1beta1_update_status_proto_msgTypes[0].Exporter = func(v any, i int) any { - switch v := v.(*JobStatus); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_backend_api_v1beta1_update_status_proto_msgTypes[1].Exporter = func(v any, i int) any { - switch v := v.(*Response); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } - file_backend_api_v1beta1_update_status_proto_msgTypes[0].OneofWrappers = []any{} - type x struct{} - out := protoimpl.TypeBuilder{ - File: protoimpl.DescBuilder{ - GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_backend_api_v1beta1_update_status_proto_rawDesc, - NumEnums: 1, - NumMessages: 2, - NumExtensions: 0, - NumServices: 1, - }, - GoTypes: file_backend_api_v1beta1_update_status_proto_goTypes, - DependencyIndexes: file_backend_api_v1beta1_update_status_proto_depIdxs, - EnumInfos: file_backend_api_v1beta1_update_status_proto_enumTypes, - MessageInfos: file_backend_api_v1beta1_update_status_proto_msgTypes, - }.Build() - File_backend_api_v1beta1_update_status_proto = out.File - file_backend_api_v1beta1_update_status_proto_rawDesc = nil - file_backend_api_v1beta1_update_status_proto_goTypes = nil - file_backend_api_v1beta1_update_status_proto_depIdxs = nil -} diff --git a/controllers/lmes/api/v1beta1/update_status.proto b/controllers/lmes/api/v1beta1/update_status.proto deleted file mode 100644 index f202a12..0000000 --- a/controllers/lmes/api/v1beta1/update_status.proto +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2024. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -syntax = "proto3"; - -option go_package = "github.com/trustyai-explainability/trustyai-service-operator/controllers/lmes/api/v1beta1"; - -enum ResponseCode { - OK = 0; - ERROR = 1; -} - -// the JobState, Reason, message, and optional Results -message JobStatus { - string job_name = 1; - string job_namespace = 2; - string state = 3; - string reason = 4; - string status_message = 5; - optional string results = 6; -} - -message Response { - ResponseCode code = 1; - string message = 2; -} - -service LMEvalJobUpdateService { - rpc UpdateStatus(JobStatus) returns (Response); -} diff --git a/controllers/lmes/api/v1beta1/update_status_grpc.pb.go b/controllers/lmes/api/v1beta1/update_status_grpc.pb.go deleted file mode 100644 index 293ae3d..0000000 --- a/controllers/lmes/api/v1beta1/update_status_grpc.pb.go +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright 2024. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Code generated by protoc-gen-go-grpc. DO NOT EDIT. -// versions: -// - protoc-gen-go-grpc v1.2.0 -// - protoc v3.12.4 -// source: controllers/lmes/api/v1beta1/update_status.proto - -package v1beta1 - -import ( - context "context" - grpc "google.golang.org/grpc" - codes "google.golang.org/grpc/codes" - status "google.golang.org/grpc/status" -) - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.32.0 or later. -const _ = grpc.SupportPackageIsVersion7 - -// LMEvalJobUpdateServiceClient is the client API for LMEvalJobUpdateService service. -// -// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. -type LMEvalJobUpdateServiceClient interface { - UpdateStatus(ctx context.Context, in *JobStatus, opts ...grpc.CallOption) (*Response, error) -} - -type lMEvalJobUpdateServiceClient struct { - cc grpc.ClientConnInterface -} - -func NewLMEvalJobUpdateServiceClient(cc grpc.ClientConnInterface) LMEvalJobUpdateServiceClient { - return &lMEvalJobUpdateServiceClient{cc} -} - -func (c *lMEvalJobUpdateServiceClient) UpdateStatus(ctx context.Context, in *JobStatus, opts ...grpc.CallOption) (*Response, error) { - out := new(Response) - err := c.cc.Invoke(ctx, "/LMEvalJobUpdateService/UpdateStatus", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -// LMEvalJobUpdateServiceServer is the server API for LMEvalJobUpdateService service. -// All implementations must embed UnimplementedLMEvalJobUpdateServiceServer -// for forward compatibility -type LMEvalJobUpdateServiceServer interface { - UpdateStatus(context.Context, *JobStatus) (*Response, error) - mustEmbedUnimplementedLMEvalJobUpdateServiceServer() -} - -// UnimplementedLMEvalJobUpdateServiceServer must be embedded to have forward compatible implementations. -type UnimplementedLMEvalJobUpdateServiceServer struct { -} - -func (UnimplementedLMEvalJobUpdateServiceServer) UpdateStatus(context.Context, *JobStatus) (*Response, error) { - return nil, status.Errorf(codes.Unimplemented, "method UpdateStatus not implemented") -} -func (UnimplementedLMEvalJobUpdateServiceServer) mustEmbedUnimplementedLMEvalJobUpdateServiceServer() { -} - -// UnsafeLMEvalJobUpdateServiceServer may be embedded to opt out of forward compatibility for this service. -// Use of this interface is not recommended, as added methods to LMEvalJobUpdateServiceServer will -// result in compilation errors. -type UnsafeLMEvalJobUpdateServiceServer interface { - mustEmbedUnimplementedLMEvalJobUpdateServiceServer() -} - -func RegisterLMEvalJobUpdateServiceServer(s grpc.ServiceRegistrar, srv LMEvalJobUpdateServiceServer) { - s.RegisterService(&LMEvalJobUpdateService_ServiceDesc, srv) -} - -func _LMEvalJobUpdateService_UpdateStatus_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(JobStatus) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(LMEvalJobUpdateServiceServer).UpdateStatus(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/LMEvalJobUpdateService/UpdateStatus", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(LMEvalJobUpdateServiceServer).UpdateStatus(ctx, req.(*JobStatus)) - } - return interceptor(ctx, in, info, handler) -} - -// LMEvalJobUpdateService_ServiceDesc is the grpc.ServiceDesc for LMEvalJobUpdateService service. -// It's only intended for direct use with grpc.RegisterService, -// and not to be introspected or modified (even as a copy) -var LMEvalJobUpdateService_ServiceDesc = grpc.ServiceDesc{ - ServiceName: "LMEvalJobUpdateService", - HandlerType: (*LMEvalJobUpdateServiceServer)(nil), - Methods: []grpc.MethodDesc{ - { - MethodName: "UpdateStatus", - Handler: _LMEvalJobUpdateService_UpdateStatus_Handler, - }, - }, - Streams: []grpc.StreamDesc{}, - Metadata: "backend/api/v1beta1/update_status.proto", -} diff --git a/controllers/lmes/constants.go b/controllers/lmes/constants.go index 0ddf207..ec2c594 100644 --- a/controllers/lmes/constants.go +++ b/controllers/lmes/constants.go @@ -23,34 +23,21 @@ import ( ) const ( - DriverPath = "/bin/driver" - DestDriverPath = "/opt/app-root/src/bin/driver" - PodImageKey = "lmes-pod-image" - DriverImageKey = "lmes-driver-image" - PodCheckingIntervalKey = "lmes-pod-checking-interval" - ImagePullPolicyKey = "lmes-image-pull-policy" - GrpcPortKey = "lmes-grpc-port" - GrpcServiceKey = "lmes-grpc-service" - GrpcServerSecretKey = "lmes-grpc-server-secret" - GrpcClientSecretKey = "lmes-grpc-client-secret" - MaxBatchSizeKey = "lmes-max-batch-size" - DefaultBatchSizeKey = "lmes-default-batch-size" - DetectDeviceKey = "lmes-detect-device" - DriverReportIntervalKey = "driver-report-interval" - GrpcServerCertEnv = "GRPC_SERVER_CERT" - GrpcServerKeyEnv = "GRPC_SERVER_KEY" - GrpcClientCaEnv = "GRPC_CLIENT_CA" - DefaultPodImage = "quay.io/trustyai/ta-lmes-job:latest" - DefaultDriverImage = "quay.io/trustyai/ta-lmes-driver:latest" - DefaultPodCheckingInterval = time.Second * 10 - DefaultDriverReportInterval = time.Second * 10 - DefaultImagePullPolicy = corev1.PullAlways - DefaultGrpcPort = 8082 - DefaultGrpcService = "lm-eval-grpc" - DefaultGrpcServerSecret = "grpc-server-cert" - DefaultGrpcClientSecret = "grpc-client-cert" - DefaultMaxBatchSize = 24 - DefaultBatchSize = 8 - DefaultDetectDevice = true - ServiceName = "LMES" + DriverPath = "/bin/driver" + DestDriverPath = "/opt/app-root/src/bin/driver" + PodImageKey = "lmes-pod-image" + DriverImageKey = "lmes-driver-image" + PodCheckingIntervalKey = "lmes-pod-checking-interval" + ImagePullPolicyKey = "lmes-image-pull-policy" + MaxBatchSizeKey = "lmes-max-batch-size" + DefaultBatchSizeKey = "lmes-default-batch-size" + DetectDeviceKey = "lmes-detect-device" + DefaultPodImage = "quay.io/trustyai/ta-lmes-job:latest" + DefaultDriverImage = "quay.io/trustyai/ta-lmes-driver:latest" + DefaultPodCheckingInterval = time.Second * 10 + DefaultImagePullPolicy = corev1.PullAlways + DefaultMaxBatchSize = 24 + DefaultBatchSize = 8 + DefaultDetectDevice = true + ServiceName = "LMES" ) diff --git a/controllers/lmes/driver/driver.go b/controllers/lmes/driver/driver.go index fb65c74..2c10add 100644 --- a/controllers/lmes/driver/driver.go +++ b/controllers/lmes/driver/driver.go @@ -19,59 +19,41 @@ package driver import ( "bufio" "context" - "crypto/tls" - "crypto/x509" + "encoding/json" "fmt" "io" "io/fs" + "net" + "net/http" "os" "os/exec" "path/filepath" "regexp" "strings" "sync" - "time" "unicode" "github.com/go-logr/logr" - "github.com/spf13/viper" lmesv1alpha1 "github.com/trustyai-explainability/trustyai-service-operator/api/lmes/v1alpha1" - "github.com/trustyai-explainability/trustyai-service-operator/controllers/lmes/api/v1beta1" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - "k8s.io/apimachinery/pkg/runtime" - utilruntime "k8s.io/apimachinery/pkg/util/runtime" - clientgoscheme "k8s.io/client-go/kubernetes/scheme" ) var ( - scheme = runtime.NewScheme() progressMegPattern = regexp.MustCompile(`^(.*?:\s*?\d*?%)\|`) ) const ( - GrpcClientKeyEnv = "GRPC_CLIENT_KEY" - GrpcClientCertEnv = "GRPC_CLIENT_CERT" - GrpcServerCaEnv = "GRPC_SERVER_CA" - DefaultDriverReportInterval = time.Second * 10 - DefaultTaskRecipesPath = "/opt/app-root/src/my_tasks" - DefaultCatalogPath = "/opt/app-root/src/my_catalogs" - TaskRecipePrefix = "tr" - CustomCardPrefix = "custom" + // put the domain socket under /tmp. may move to emptydir to share across containers + socketPath = "/tmp/ta-lmes-driver.sock" + DefaultTaskRecipesPath = "/opt/app-root/src/my_tasks" + DefaultCatalogPath = "/opt/app-root/src/my_catalogs" + TaskRecipePrefix = "tr" + CustomCardPrefix = "custom" + ShutdownURI = "/Shutdown" + GetStatusURI = "/GetStatus" ) -func init() { - utilruntime.Must(clientgoscheme.AddToScheme(scheme)) - utilruntime.Must(lmesv1alpha1.AddToScheme(scheme)) -} - type DriverOption struct { Context context.Context - JobNamespace string - JobName string - GrpcService string - GrpcPort int OutputPath string DetectDevice bool TaskRecipesPath string @@ -80,20 +62,29 @@ type DriverOption struct { CustomCards []string Logger logr.Logger Args []string - ReportInterval time.Duration + SocketPath string } type Driver interface { Run() error - Cleanup() + GetStatus() (*lmesv1alpha1.LMEvalJobStatus, error) + Shutdown() error +} + +// the communication server that is used by the driverImpl to +// send and recive messages using a domain socket +type driverComm struct { + connection chan int + server *http.Server + path string } type driverImpl struct { - client v1beta1.LMEvalJobUpdateServiceClient - grpcConn *grpc.ClientConn Option *DriverOption - lastReportTime time.Time lastProgressMsg string + status lmesv1alpha1.LMEvalJobStatus + err error + comm *driverComm } func NewDriver(opt *DriverOption) (Driver, error) { @@ -101,17 +92,10 @@ func NewDriver(opt *DriverOption) (Driver, error) { return nil, nil } - if opt.ReportInterval == 0 { - opt.ReportInterval = DefaultDriverReportInterval - } if opt.Context == nil { return nil, fmt.Errorf("context is nil") } - if opt.JobNamespace == "" || opt.JobName == "" { - return nil, fmt.Errorf("JobNamespace or JobName is empty") - } - if opt.TaskRecipesPath == "" { opt.TaskRecipesPath = DefaultTaskRecipesPath } @@ -120,24 +104,22 @@ func NewDriver(opt *DriverOption) (Driver, error) { opt.CatalogPath = DefaultCatalogPath } - conn, err := getGRPCClientConn(opt) - if err != nil { - return nil, err + if opt.SocketPath == "" { + opt.SocketPath = socketPath } return &driverImpl{ - client: v1beta1.NewLMEvalJobUpdateServiceClient(conn), - grpcConn: conn, - Option: opt, + Option: opt, }, nil } // Run implements Driver. func (d *driverImpl) Run() error { - if err := d.updateStatus(lmesv1alpha1.RunningJobState, - "update status from the driver: running"); err != nil { + d.updateStatus(lmesv1alpha1.RunningJobState, lmesv1alpha1.NoReason, "initializing the evaluation job") - return err + if err := d.setupComm(); err != nil { + d.err = err + return d.err } execErr := d.exec() @@ -151,71 +133,41 @@ func (d *driverImpl) Run() error { toConsole(filepath.Join(d.Option.OutputPath, "stdout.log")) toConsole(filepath.Join(d.Option.OutputPath, "stderr.log")) - return d.updateCompleteStatus(execErr) -} + d.updateCompleteStatus(execErr) -func (d *driverImpl) Cleanup() { - if d != nil && d.grpcConn != nil { - d.grpcConn.Close() - } + // wait for shutdown signal then properly clean up the resources + d.comm.wait4Sutdownload() + d.comm.close() + return d.err } -func getGRPCClientConn(option *DriverOption) (clientConn *grpc.ClientConn, err error) { - // Set up a connection to the server. - if option.GrpcPort == 0 || option.GrpcService == "" { - return nil, fmt.Errorf("GrpcService or GrpcPort is not valid") +func (d *driverImpl) GetStatus() (*lmesv1alpha1.LMEvalJobStatus, error) { + client := createClient(d.Option.SocketPath) + resp, err := client.Get(fmt.Sprintf("http://unix%s", GetStatusURI)) + if err != nil { + return nil, err + } + defer resp.Body.Close() + content, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err } - serverAddr := fmt.Sprintf("%s:%d", option.GrpcService, option.GrpcPort) - - if viper.IsSet(GrpcServerCaEnv) { - serverCAPath := viper.GetString(GrpcServerCaEnv) - - if viper.IsSet(GrpcClientCertEnv) && viper.IsSet(GrpcClientKeyEnv) { - // mTLS - certPath, keyPath := viper.GetString(GrpcClientCertEnv), viper.GetString(GrpcClientKeyEnv) - var cert tls.Certificate - cert, err = tls.LoadX509KeyPair(certPath, keyPath) - if err != nil { - return nil, err - } - - ca := x509.NewCertPool() - var caBytes []byte - caBytes, err = os.ReadFile(serverCAPath) - if err != nil { - return nil, fmt.Errorf("failed to read server CA %q: %v", serverCAPath, err) - } - if ok := ca.AppendCertsFromPEM(caBytes); !ok { - return nil, fmt.Errorf("failed to parse server CA %q", serverCAPath) - } - - tlsConfig := &tls.Config{ - ServerName: serverAddr, - Certificates: []tls.Certificate{cert}, - RootCAs: ca, - } - - clientConn, err = grpc.NewClient(serverAddr, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig))) - } else { - // TLS - creds, err := credentials.NewClientTLSFromFile(serverCAPath, serverAddr) - if err != nil { - return nil, fmt.Errorf("failed to load server CA: %v", err) - } + var status lmesv1alpha1.LMEvalJobStatus + err = json.Unmarshal(content, &status) - clientConn, err = grpc.NewClient(serverAddr, grpc.WithTransportCredentials(creds)) - if err != nil { - return nil, fmt.Errorf("failed to connect to GRPC server: %v", err) - } - } - } else { - clientConn, err = grpc.NewClient( - serverAddr, - grpc.WithTransportCredentials(insecure.NewCredentials()), - ) + return &status, err +} +func (d *driverImpl) Shutdown() error { + client := createClient(d.Option.SocketPath) + resp, err := client.Post(fmt.Sprintf("http://unix%s", ShutdownURI), "application/json", nil) + if err != nil { + return err } - return + + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + return err } func (d *driverImpl) detectDevice() error { @@ -262,6 +214,75 @@ func patchDevice(args []string, hasCuda bool) { } } +// Create a domain socket and use HTTP protocal to handle communication +func (d *driverImpl) setupComm() error { + + serve := http.NewServeMux() + d.comm = &driverComm{ + server: &http.Server{Handler: serve}, + connection: make(chan int), + path: d.Option.SocketPath, + } + + // handle the `GetStatus` API: return the complete lmesv1alpha1.LMEvalJobStatus + // or error if the JSON marshaling fails. + serve.HandleFunc(GetStatusURI, func(w http.ResponseWriter, _ *http.Request) { + status, err := json.Marshal(d.status) + if err == nil { + w.Write(status) + } else { + w.Write([]byte(fmt.Sprintf(`{"err": "%s"}`, err.Error()))) + } + }) + + // handle the `Shutdown` API: tear down the communication server. + serve.HandleFunc(ShutdownURI, func(w http.ResponseWriter, _ *http.Request) { + w.Write([]byte(`{"msg": "ok"}`)) + d.comm.notifyShutdownWait() + }) + + go func() { + d.comm.serve() + }() + + return nil +} + +func createClient(path string) *http.Client { + return &http.Client{ + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", path) + }, + }, + } +} + +func (dc *driverComm) wait4Sutdownload() { + <-dc.connection +} + +func (dc *driverComm) serve() error { + socket, err := net.Listen("unix", dc.path) + if err != nil { + return err + } + + return dc.server.Serve(socket) +} + +func (dc *driverComm) close() { + if dc.server != nil && dc.connection != nil { + dc.server.Shutdown(context.Background()) + close(dc.connection) + os.Remove(dc.path) + } +} + +func (dc *driverComm) notifyShutdownWait() { + dc.connection <- 1 +} + func (d *driverImpl) exec() error { // create Unitxt task recipes if err := d.createTaskRecipes(); err != nil { @@ -287,7 +308,6 @@ func (d *driverImpl) exec() error { if err != nil { return err } - bout := bufio.NewWriter(stdout) stderr, err := os.Create(filepath.Join(d.Option.OutputPath, "stderr.log")) if err != nil { @@ -298,7 +318,6 @@ func (d *driverImpl) exec() error { // lm-eval's outputs are in the stderr pr, pw := io.Pipe() mwriter := io.MultiWriter(stderr, pw) - berr := bufio.NewWriter(mwriter) scanner := bufio.NewScanner(pr) executor := exec.Command(d.Option.Args[0], args...) @@ -306,18 +325,16 @@ func (d *driverImpl) exec() error { if err != nil { return err } - executor.Stdout = bout - executor.Stderr = berr + executor.Stdout = stdout + executor.Stderr = mwriter executor.Env = append(os.Environ(), "UNITXT_ALLOW_UNVERIFIED_CODE=True", ) var freeRes = func() { stdin.Close() - bout.Flush() stdout.Sync() stdout.Close() - berr.Flush() stderr.Sync() stderr.Close() pr.Close() @@ -335,9 +352,7 @@ func (d *driverImpl) exec() error { go func() { for scanner.Scan() { msg := scanner.Text() - if err := d.reportProgress(msg); err != nil { - d.Option.Logger.Error(err, "report progress failed") - } + d.updateProgress(msg) } wg.Done() }() @@ -348,59 +363,30 @@ func (d *driverImpl) exec() error { return finalError } -func (d *driverImpl) updateStatus(state lmesv1alpha1.JobState, msg string) error { - ctx, cancel := context.WithTimeout(d.Option.Context, time.Second*10) - defer cancel() - - r, err := d.client.UpdateStatus(ctx, &v1beta1.JobStatus{ - JobName: d.Option.JobName, - JobNamespace: d.Option.JobNamespace, - State: string(state), - Reason: string(lmesv1alpha1.NoReason), - StatusMessage: msg, - }) - - if r != nil && err == nil { - d.Option.Logger.Info(fmt.Sprintf("UpdateStatus done: %s", r.Message)) - d.lastReportTime = time.Now() - } - - return err -} - -func (d *driverImpl) updateCompleteStatus(err error) error { - ctx, cancel := context.WithTimeout(d.Option.Context, time.Second*10) - defer cancel() - newStatus := v1beta1.JobStatus{ - JobName: d.Option.JobName, - JobNamespace: d.Option.JobNamespace, - State: string(lmesv1alpha1.CompleteJobState), - Reason: string(lmesv1alpha1.SucceedReason), - StatusMessage: "update status from the driver: completed", - } +func (d *driverImpl) updateCompleteStatus(err error) { + d.status.State = lmesv1alpha1.CompleteJobState + d.status.Reason = lmesv1alpha1.SucceedReason + d.status.Message = "job completed" - var setErr = func(err error) { - newStatus.Reason = string(lmesv1alpha1.FailedReason) - newStatus.StatusMessage = err.Error() + if err == nil { + var results string + results, err = d.getResults() + d.status.Results = results } if err != nil { - setErr(err) - } else { - results, err := d.getResults() - if err != nil { - setErr(err) - } else { - newStatus.Results = &results - } + d.status.Reason = lmesv1alpha1.FailedReason + d.status.Message = err.Error() + d.err = err } - r, err := d.client.UpdateStatus(ctx, &newStatus) - if r != nil && err == nil { - d.Option.Logger.Info(fmt.Sprintf("UpdateStatus with the results: %s", r.Message)) - } + d.Option.Logger.Info("update status: job completed", "state", d.status) +} - return err +func (d *driverImpl) updateStatus(state lmesv1alpha1.JobState, reason lmesv1alpha1.Reason, msg string) { + d.status.State = state + d.status.Reason = reason + d.status.Message = msg } func (d *driverImpl) getResults() (string, error) { @@ -427,7 +413,7 @@ func (d *driverImpl) getResults() (string, error) { return results, nil } -func (d *driverImpl) reportProgress(msg string) error { +func (d *driverImpl) updateProgress(msg string) { msg = strings.Map(func(r rune) rune { if unicode.IsPrint(r) { return r @@ -445,14 +431,9 @@ func (d *driverImpl) reportProgress(msg string) error { if matches := progressMegPattern.FindStringSubmatch(msglist[len(msglist)-1]); len(matches) == 2 { if matches[1] != d.lastProgressMsg { d.lastProgressMsg = strings.Trim(matches[1], " \r") + d.updateStatus(lmesv1alpha1.RunningJobState, lmesv1alpha1.NoReason, d.lastProgressMsg) } } - if time.Since(d.lastReportTime) >= d.Option.ReportInterval { - if err := d.updateStatus(lmesv1alpha1.RunningJobState, d.lastProgressMsg); err != nil { - return err - } - } - return nil } func (d *driverImpl) createTaskRecipes() error { diff --git a/controllers/lmes/driver/driver_test.go b/controllers/lmes/driver/driver_test.go index 25d567e..b74d039 100644 --- a/controllers/lmes/driver/driver_test.go +++ b/controllers/lmes/driver/driver_test.go @@ -18,16 +18,15 @@ package driver import ( "context" + "crypto/rand" "flag" "fmt" - "net" "os" "testing" "time" "github.com/stretchr/testify/assert" - "github.com/trustyai-explainability/trustyai-service-operator/controllers/lmes/api/v1beta1" - "google.golang.org/grpc" + "github.com/trustyai-explainability/trustyai-service-operator/api/lmes/v1alpha1" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" @@ -48,120 +47,121 @@ func TestMain(m *testing.M) { m.Run() } -type DummyUpdateServer struct { - v1beta1.UnimplementedLMEvalJobUpdateServiceServer +func genRandomSocketPath() string { + b := make([]byte, 10) + rand.Read(b) + p := fmt.Sprintf("/tmp/ta-lmes-%x.sock", b) + return p } -func (*DummyUpdateServer) UpdateStatus(context.Context, *v1beta1.JobStatus) (*v1beta1.Response, error) { - return &v1beta1.Response{ - Code: v1beta1.ResponseCode_OK, - Message: "updated the job status successfully", - }, nil +func runDirverAndWait4Complete(t *testing.T, driver Driver, returnError bool) (progressMsgs []string, results string) { + go func() { + if returnError { + assert.NotNil(t, driver.Run()) + } else { + assert.Nil(t, driver.Run()) + } + }() + + for { + time.Sleep(time.Second) + status, err := driver.GetStatus() + assert.Nil(t, err) + if len(progressMsgs) == 0 || progressMsgs[len(progressMsgs)-1] != status.Message { + progressMsgs = append(progressMsgs, status.Message) + } + if status.State == v1alpha1.CompleteJobState { + results = status.Results + break + } + } + return progressMsgs, results } func Test_Driver(t *testing.T) { - server := grpc.NewServer() - v1beta1.RegisterLMEvalJobUpdateServiceServer(server, &DummyUpdateServer{}) - lis, err := net.Listen("tcp", fmt.Sprintf(":%d", 8082)) - assert.Nil(t, err) - go server.Serve(lis) - driver, err := NewDriver(&DriverOption{ - Context: context.Background(), - JobNamespace: "fms-lm-eval-service-system", - JobName: "evaljob-sample", - GrpcService: "localhost", - GrpcPort: 8082, - OutputPath: ".", - Logger: driverLog, - Args: []string{"sh", "-ec", "echo tttttttttttttttttttt"}, + Context: context.Background(), + OutputPath: ".", + Logger: driverLog, + Args: []string{"sh", "-ec", "echo tttttttttttttttttttt"}, + SocketPath: genRandomSocketPath(), }) assert.Nil(t, err) - assert.Nil(t, driver.Run()) + runDirverAndWait4Complete(t, driver, false) - server.Stop() + assert.Nil(t, driver.Shutdown()) assert.Nil(t, os.Remove("./stderr.log")) assert.Nil(t, os.Remove("./stdout.log")) } -type ProgressUpdateServer struct { - v1beta1.UnimplementedLMEvalJobUpdateServiceServer - progressMsgs []string -} +func Test_Wait4Shutdown(t *testing.T) { + driver, err := NewDriver(&DriverOption{ + Context: context.Background(), + OutputPath: ".", + Logger: driverLog, + Args: []string{"sh", "-ec", "echo test"}, + SocketPath: genRandomSocketPath(), + }) + assert.Nil(t, err) -func (s *ProgressUpdateServer) UpdateStatus(_ context.Context, status *v1beta1.JobStatus) (*v1beta1.Response, error) { - if status.StatusMessage != "" { - s.progressMsgs = append(s.progressMsgs, status.StatusMessage) - } - return &v1beta1.Response{ - Code: v1beta1.ResponseCode_OK, - Message: "updated the job status successfully", - }, nil -} + runDirverAndWait4Complete(t, driver, false) -func Test_ProgressUpdate(t *testing.T) { - server := grpc.NewServer() - progresssServer := ProgressUpdateServer{} - v1beta1.RegisterLMEvalJobUpdateServiceServer(server, &progresssServer) - lis, err := net.Listen("tcp", fmt.Sprintf(":%d", 8082)) + // can still get the status even the user program finishes + time.Sleep(time.Second * 3) + status, err := driver.GetStatus() assert.Nil(t, err) - go server.Serve(lis) + assert.Equal(t, v1alpha1.CompleteJobState, status.State) + + assert.Nil(t, driver.Shutdown()) + + _, err = driver.GetStatus() + assert.ErrorContains(t, err, "no such file or directory") + assert.Nil(t, os.Remove("./stderr.log")) + assert.Nil(t, os.Remove("./stdout.log")) +} + +func Test_ProgressUpdate(t *testing.T) { driver, err := NewDriver(&DriverOption{ - Context: context.Background(), - JobNamespace: "fms-lm-eval-service-system", - JobName: "evaljob-sample", - GrpcService: "localhost", - GrpcPort: 8082, - OutputPath: ".", - Logger: driverLog, - Args: []string{"sh", "-ec", "echo 'testing progress: 100%|' >&2; sleep 6"}, - ReportInterval: time.Second * 5, + Context: context.Background(), + OutputPath: ".", + Logger: driverLog, + Args: []string{"sh", "-ec", "sleep 2; echo 'testing progress: 100%|' >&2; sleep 4"}, + SocketPath: genRandomSocketPath(), }) assert.Nil(t, err) - assert.Nil(t, driver.Run()) + msgs, _ := runDirverAndWait4Complete(t, driver, false) + assert.Equal(t, []string{ - "update status from the driver: running", + "initializing the evaluation job", "testing progress: 100%", - "update status from the driver: completed", - }, progresssServer.progressMsgs) + "job completed", + }, msgs) - server.Stop() + assert.Nil(t, driver.Shutdown()) assert.Nil(t, os.Remove("./stderr.log")) assert.Nil(t, os.Remove("./stdout.log")) } func Test_DetectDeviceError(t *testing.T) { - server := grpc.NewServer() - progresssServer := ProgressUpdateServer{} - v1beta1.RegisterLMEvalJobUpdateServiceServer(server, &progresssServer) - lis, err := net.Listen("tcp", fmt.Sprintf(":%d", 8082)) - assert.Nil(t, err) - go server.Serve(lis) - driver, err := NewDriver(&DriverOption{ - Context: context.Background(), - JobNamespace: "fms-lm-eval-service-system", - JobName: "evaljob-sample", - GrpcService: "localhost", - GrpcPort: 8082, - OutputPath: ".", - DetectDevice: true, - Logger: driverLog, - Args: []string{"sh", "-ec", "python -m lm_eval --output_path ./output --model test --model_args arg1=value1 --tasks task1,task2"}, - ReportInterval: time.Second * 5, + Context: context.Background(), + OutputPath: ".", + DetectDevice: true, + Logger: driverLog, + Args: []string{"sh", "-ec", "python -m lm_eval --output_path ./output --model test --model_args arg1=value1 --tasks task1,task2"}, + SocketPath: genRandomSocketPath(), }) assert.Nil(t, err) - assert.Nil(t, driver.Run()) + msgs, _ := runDirverAndWait4Complete(t, driver, true) assert.Equal(t, []string{ - "update status from the driver: running", "failed to detect available device(s): exit status 1", - }, progresssServer.progressMsgs) + }, msgs) - server.Stop() + assert.Nil(t, driver.Shutdown()) // the following files don't exist for this case assert.NotNil(t, os.Remove("./stderr.log")) @@ -170,16 +170,11 @@ func Test_DetectDeviceError(t *testing.T) { func Test_PatchDevice(t *testing.T) { driverOpt := DriverOption{ - Context: context.Background(), - JobNamespace: "fms-lm-eval-service-system", - JobName: "evaljob-sample", - GrpcService: "localhost", - GrpcPort: 8082, - OutputPath: ".", - DetectDevice: true, - Logger: driverLog, - Args: []string{"sh", "-ec", "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2"}, - ReportInterval: time.Second * 5, + Context: context.Background(), + OutputPath: ".", + DetectDevice: true, + Logger: driverLog, + Args: []string{"sh", "-ec", "python -m lm_eval --output_path /opt/app-root/src/output --model test --model_args arg1=value1 --tasks task1,task2"}, } // append `--device cuda` @@ -207,19 +202,8 @@ func Test_PatchDevice(t *testing.T) { } func Test_TaskRecipes(t *testing.T) { - server := grpc.NewServer() - progresssServer := ProgressUpdateServer{} - v1beta1.RegisterLMEvalJobUpdateServiceServer(server, &progresssServer) - lis, err := net.Listen("tcp", fmt.Sprintf(":%d", 8082)) - assert.Nil(t, err) - go server.Serve(lis) - driver, err := NewDriver(&DriverOption{ Context: context.Background(), - JobNamespace: "fms-lm-eval-service-system", - JobName: "evaljob-sample", - GrpcService: "localhost", - GrpcPort: 8082, OutputPath: ".", Logger: driverLog, TaskRecipesPath: "./", @@ -227,19 +211,20 @@ func Test_TaskRecipes(t *testing.T) { "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", "card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10", }, - Args: []string{"sh", "-ec", "echo 'testing progress: 100%|' >&2; sleep 6"}, - ReportInterval: time.Second * 5, + Args: []string{"sh", "-ec", "sleep 2; echo 'testing progress: 100%|' >&2; sleep 4"}, + SocketPath: genRandomSocketPath(), }) assert.Nil(t, err) - assert.Nil(t, driver.Run()) + msgs, _ := runDirverAndWait4Complete(t, driver, false) + assert.Equal(t, []string{ - "update status from the driver: running", + "initializing the evaluation job", "testing progress: 100%", - "update status from the driver: completed", - }, progresssServer.progressMsgs) + "job completed", + }, msgs) - server.Stop() + assert.Nil(t, driver.Shutdown()) tr0, err := os.ReadFile("./tr_0.yaml") assert.Nil(t, err) @@ -260,19 +245,8 @@ func Test_TaskRecipes(t *testing.T) { } func Test_CustomCards(t *testing.T) { - server := grpc.NewServer() - progresssServer := ProgressUpdateServer{} - v1beta1.RegisterLMEvalJobUpdateServiceServer(server, &progresssServer) - lis, err := net.Listen("tcp", fmt.Sprintf(":%d", 8082)) - assert.Nil(t, err) - go server.Serve(lis) - driver, err := NewDriver(&DriverOption{ Context: context.Background(), - JobNamespace: "fms-lm-eval-service-system", - JobName: "evaljob-sample", - GrpcService: "localhost", - GrpcPort: 8082, OutputPath: ".", Logger: driverLog, TaskRecipesPath: "./", @@ -283,21 +257,22 @@ func Test_CustomCards(t *testing.T) { CustomCards: []string{ `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "deutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, }, - Args: []string{"sh", "-ec", "echo 'testing progress: 100%|' >&2; sleep 6"}, - ReportInterval: time.Second * 5, + Args: []string{"sh", "-ec", "sleep 1; echo 'testing progress: 100%|' >&2; sleep 3"}, + SocketPath: genRandomSocketPath(), }) assert.Nil(t, err) os.Mkdir("cards", 0750) - assert.Nil(t, driver.Run()) + msgs, _ := runDirverAndWait4Complete(t, driver, false) + assert.Equal(t, []string{ - "update status from the driver: running", + "initializing the evaluation job", "testing progress: 100%", - "update status from the driver: completed", - }, progresssServer.progressMsgs) + "job completed", + }, msgs) - server.Stop() + assert.Nil(t, driver.Shutdown()) tr0, err := os.ReadFile("./tr_0.yaml") assert.Nil(t, err) @@ -319,33 +294,24 @@ func Test_CustomCards(t *testing.T) { } func Test_ProgramError(t *testing.T) { - server := grpc.NewServer() - progresssServer := ProgressUpdateServer{} - v1beta1.RegisterLMEvalJobUpdateServiceServer(server, &progresssServer) - lis, err := net.Listen("tcp", fmt.Sprintf(":%d", 8082)) - assert.Nil(t, err) - go server.Serve(lis) - driver, err := NewDriver(&DriverOption{ - Context: context.Background(), - JobNamespace: "fms-lm-eval-service-system", - JobName: "evaljob-sample", - GrpcService: "localhost", - GrpcPort: 8082, - OutputPath: ".", - Logger: driverLog, - Args: []string{"sh", "-ec", "exit 1"}, - ReportInterval: time.Second * 5, + Context: context.Background(), + OutputPath: ".", + Logger: driverLog, + Args: []string{"sh", "-ec", "sleep 1; exit 1"}, + SocketPath: genRandomSocketPath(), }) assert.Nil(t, err) - assert.Nil(t, driver.Run()) + msgs, _ := runDirverAndWait4Complete(t, driver, true) + assert.Equal(t, []string{ - "update status from the driver: running", + "initializing the evaluation job", "exit status 1", - }, progresssServer.progressMsgs) + }, msgs) + + assert.Nil(t, driver.Shutdown()) - server.Stop() assert.Nil(t, os.Remove("./stderr.log")) assert.Nil(t, os.Remove("./stdout.log")) } diff --git a/controllers/lmes/grpc_server.go b/controllers/lmes/grpc_server.go deleted file mode 100644 index 2aef9b4..0000000 --- a/controllers/lmes/grpc_server.go +++ /dev/null @@ -1,119 +0,0 @@ -/* -Copyright 2024. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package lmes - -import ( - "context" - "crypto/tls" - "crypto/x509" - "fmt" - "net" - "os" - - "github.com/spf13/viper" - "github.com/trustyai-explainability/trustyai-service-operator/controllers/lmes/api/v1beta1" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "sigs.k8s.io/controller-runtime/pkg/log" -) - -var server *grpc.Server - -func StartGrpcServer(ctx context.Context, ctor *LMEvalJobReconciler) error { - log := log.FromContext(ctx) - // checking the cer/key envs - if viper.IsSet(GrpcServerCertEnv) && viper.IsSet(GrpcServerKeyEnv) { - serverKey, serverCert := viper.GetString(GrpcServerKeyEnv), viper.GetString(GrpcServerCertEnv) - - if viper.IsSet(GrpcClientCaEnv) { - // mTLS case - var err error - if server, err = getMTLSServer(serverCert, serverKey); err != nil { - return err - } - ctor.options.grpcTLSMode = TLSMode_mTLS - log.Info("GRPC server uses the mTLS") - } else { - // TLS case - creds, err := credentials.NewServerTLSFromFile(serverCert, serverKey) - if err != nil { - return err - } - server = grpc.NewServer(grpc.Creds(creds)) - ctor.options.grpcTLSMode = TLSMode_TLS - log.Info("GRPC server uses the TLS") - } - } else { - ctor.options.grpcTLSMode = TLSMode_None - log.Info("GRPC server uses insecure protocol") - server = grpc.NewServer() - } - - v1beta1.RegisterLMEvalJobUpdateServiceServer(server, &updateStatusServer{ - controller: ctor, - }) - lis, err := net.Listen("tcp", fmt.Sprintf(":%d", ctor.options.GrpcPort)) - if err != nil { - return err - } - log.Info("GRPC server started") - err = server.Serve(lis) - return err -} - -func getMTLSServer(serverCert string, serverKey string) (*grpc.Server, error) { - clientCaPath := viper.GetString(GrpcClientCaEnv) - cert, err := tls.LoadX509KeyPair(serverCert, serverKey) - if err != nil { - return nil, err - } - ca := x509.NewCertPool() - caBytes, err := os.ReadFile(clientCaPath) - if err != nil { - return nil, err - } - if ok := ca.AppendCertsFromPEM(caBytes); !ok { - return nil, fmt.Errorf("failed to parse %q", clientCaPath) - } - tlsConfig := &tls.Config{ - ClientAuth: tls.RequireAndVerifyClientCert, - Certificates: []tls.Certificate{cert}, - ClientCAs: ca, - } - - return grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))), nil -} - -type updateStatusServer struct { - v1beta1.UnimplementedLMEvalJobUpdateServiceServer - controller *LMEvalJobReconciler -} - -func (s *updateStatusServer) UpdateStatus(ctx context.Context, newStatus *v1beta1.JobStatus) (resp *v1beta1.Response, err error) { - resp = &v1beta1.Response{ - Code: v1beta1.ResponseCode_OK, - Message: "updated the job status successfully", - } - - err = s.controller.updateStatus(ctx, newStatus) - if err != nil { - resp.Code = v1beta1.ResponseCode_ERROR - resp.Message = err.Error() - } - - return resp, err -} diff --git a/controllers/lmes/grpc_server_test.go b/controllers/lmes/grpc_server_test.go deleted file mode 100644 index 6c32a0a..0000000 --- a/controllers/lmes/grpc_server_test.go +++ /dev/null @@ -1,138 +0,0 @@ -/* -Copyright 2024. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package lmes - -import ( - "bytes" - "context" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "math/big" - "net" - "os" - "time" - - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" - "github.com/spf13/viper" -) - -var _ = Describe("GRPC Server", func() { - Context("Start the GRPC Server", func() { - - ctx := context.Background() - BeforeEach(func() { - By("create ca, key, cert for the test case") - ca := &x509.Certificate{ - SerialNumber: big.NewInt(2019), - Subject: pkix.Name{ - Organization: []string{"LM-Eval"}, - Country: []string{"US"}, - Province: []string{""}, - Locality: []string{"San Jose"}, - StreetAddress: []string{"large lanaguage"}, - PostalCode: []string{"95141"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), - IsCA: true, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - BasicConstraintsValid: true, - } - caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) - Expect(err).NotTo(HaveOccurred()) - caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey) - Expect(err).NotTo(HaveOccurred()) - caPEM := new(bytes.Buffer) - pem.Encode(caPEM, &pem.Block{ - Type: "CERTIFICATE", - Bytes: caBytes, - }) - - caPrivKeyPEM := new(bytes.Buffer) - pem.Encode(caPrivKeyPEM, &pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(caPrivKey), - }) - - cert := &x509.Certificate{ - SerialNumber: big.NewInt(1658), - Subject: pkix.Name{ - Organization: []string{"LM-Eval"}, - Country: []string{"US"}, - Province: []string{""}, - Locality: []string{"San Jose"}, - StreetAddress: []string{"large lanaguage"}, - PostalCode: []string{"95141"}, - }, - IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), - SubjectKeyId: []byte{1, 2, 3, 4, 6}, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature, - } - - certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) - Expect(err).NotTo(HaveOccurred()) - - certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey) - Expect(err).NotTo(HaveOccurred()) - - certPEM := new(bytes.Buffer) - pem.Encode(certPEM, &pem.Block{ - Type: "CERTIFICATE", - Bytes: certBytes, - }) - - certPrivKeyPEM := new(bytes.Buffer) - pem.Encode(certPrivKeyPEM, &pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey), - }) - - var writeToFile = func(prefix string, content *bytes.Buffer, env string) { - tmpFile, err := os.CreateTemp("", prefix) - Expect(err).NotTo(HaveOccurred()) - _, err = tmpFile.Write(content.Bytes()) - Expect(err).NotTo(HaveOccurred()) - viper.Set(env, tmpFile.Name()) - } - - writeToFile("ca", caPEM, GrpcClientCaEnv) - writeToFile("cert", certPEM, GrpcServerCertEnv) - writeToFile("key", certPrivKeyPEM, GrpcServerKeyEnv) - }) - It("should successfully reconcile the resource", func() { - By("Reconciling the created resource") - controllerReconciler := &LMEvalJobReconciler{ - options: &ServiceOptions{ - GrpcPort: 8082, - GrpcService: "localhost", - }, - } - - go StartGrpcServer(ctx, controllerReconciler) - time.Sleep(time.Second * 10) - server.Stop() - }) - }) -}) diff --git a/controllers/lmes/lmevaljob_controller.go b/controllers/lmes/lmevaljob_controller.go index 2b77909..3240036 100644 --- a/controllers/lmes/lmevaljob_controller.go +++ b/controllers/lmes/lmevaljob_controller.go @@ -17,6 +17,7 @@ limitations under the License. package lmes import ( + "bytes" "context" "fmt" "maps" @@ -24,6 +25,7 @@ import ( "slices" "strconv" "strings" + "sync" "time" corev1 "k8s.io/api/core/v1" @@ -31,7 +33,11 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/json" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/kubernetes/scheme" + "k8s.io/client-go/rest" "k8s.io/client-go/tools/record" + "k8s.io/client-go/tools/remotecommand" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/client" @@ -45,9 +51,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/source" "github.com/go-logr/logr" - "github.com/spf13/viper" lmesv1alpha1 "github.com/trustyai-explainability/trustyai-service-operator/api/lmes/v1alpha1" - backendv1beta1 "github.com/trustyai-explainability/trustyai-service-operator/controllers/lmes/api/v1beta1" "github.com/trustyai-explainability/trustyai-service-operator/controllers/lmes/driver" ) @@ -59,72 +63,102 @@ var ( } optionKeys = map[string]string{ - "PodImage": PodImageKey, - "DriverImage": DriverImageKey, - "PodCheckingInterval": PodCheckingIntervalKey, - "ImagePullPolicy": ImagePullPolicyKey, - "GrpcPort": GrpcPortKey, - "GrpcService": GrpcServiceKey, - "GrpcServerSecret": GrpcServerSecretKey, - "GrpcClientSecret": GrpcClientSecretKey, - "DriverReportInterval": DriverReportIntervalKey, - "DefaultBatchSize": DefaultBatchSizeKey, - "MaxBatchSize": MaxBatchSizeKey, - "DetectDevice": DetectDeviceKey, + "PodImage": PodImageKey, + "DriverImage": DriverImageKey, + "PodCheckingInterval": PodCheckingIntervalKey, + "ImagePullPolicy": ImagePullPolicyKey, + "DefaultBatchSize": DefaultBatchSizeKey, + "MaxBatchSize": MaxBatchSizeKey, + "DetectDevice": DetectDeviceKey, } labelFilterPrefixes = []string{} annotationFilterPrefixes = []string{} ) -type TLSMode int - -const ( - TLSMode_None TLSMode = 0 - TLSMode_TLS TLSMode = 1 - TLSMode_mTLS TLSMode = 2 -) +// maintain a list of key-time pair data. +// provide a function to add the key and update the time +// atomitcally and return a reconcile requeue event +// if needed. +type syncedMap4Reconciler struct { + data map[string]time.Time + mutex sync.Mutex +} // LMEvalJobReconciler reconciles a LMEvalJob object type LMEvalJobReconciler struct { client.Client - Scheme *runtime.Scheme - Recorder record.EventRecorder - ConfigMap string - Namespace string - options *ServiceOptions + Scheme *runtime.Scheme + Recorder record.EventRecorder + ConfigMap string + Namespace string + options *ServiceOptions + restConfig *rest.Config + restClient rest.Interface + pullingJobs *syncedMap4Reconciler } type ServiceOptions struct { - PodImage string - DriverImage string - DriverReportInterval time.Duration - PodCheckingInterval time.Duration - ImagePullPolicy corev1.PullPolicy - GrpcPort int - GrpcService string - GrpcServerSecret string - GrpcClientSecret string - MaxBatchSize int - DefaultBatchSize int - DetectDevice bool - grpcTLSMode TLSMode + PodImage string + DriverImage string + PodCheckingInterval time.Duration + ImagePullPolicy corev1.PullPolicy + MaxBatchSize int + DefaultBatchSize int + DetectDevice bool } +// The registered function to set up LMES controller func ControllerSetUp(mgr manager.Manager, ns, configmap string, recorder record.EventRecorder) error { + clientset, err := kubernetes.NewForConfig(mgr.GetConfig()) + if err != nil { + return err + } + return (&LMEvalJobReconciler{ - ConfigMap: configmap, - Namespace: ns, - Client: mgr.GetClient(), - Scheme: mgr.GetScheme(), - Recorder: mgr.GetEventRecorderFor("lm-eval-service-controller"), + ConfigMap: configmap, + Namespace: ns, + Client: mgr.GetClient(), + Scheme: mgr.GetScheme(), + Recorder: mgr.GetEventRecorderFor("lm-eval-service-controller"), + restConfig: mgr.GetConfig(), + restClient: clientset.CoreV1().RESTClient(), + pullingJobs: newSyncedMap4Reconciler(), }).SetupWithManager(mgr) } +func newSyncedMap4Reconciler() *syncedMap4Reconciler { + return &syncedMap4Reconciler{data: make(map[string]time.Time)} +} + +// check if the paired time of the key is passed. if yes, update the time and +// return a requeue result. otherwise an empty result +func (q *syncedMap4Reconciler) addOrUpdate(key string, after time.Duration) reconcile.Result { + q.mutex.Lock() + defer q.mutex.Unlock() + + v, ok := q.data[key] + if ok && time.Now().Before(v) { + // no need to requeue since there is an existing one + return reconcile.Result{} + } + value := time.Now().Add(after) + q.data[key] = value + return reconcile.Result{Requeue: true, RequeueAfter: after} +} + +// remove the key from the list +func (q *syncedMap4Reconciler) remove(key string) { + q.mutex.Lock() + defer q.mutex.Unlock() + delete(q.data, key) +} + // +kubebuilder:rbac:groups=trustyai.opendatahub.io,resources=lmevaljobs,verbs=get;list;watch;create;update;patch;delete // +kubebuilder:rbac:groups=trustyai.opendatahub.io,resources=lmevaljobs/status,verbs=get;update;patch // +kubebuilder:rbac:groups=trustyai.opendatahub.io,resources=lmevaljobs/finalizers,verbs=update // +kubebuilder:rbac:groups="",resources=pods,verbs=get;list;watch;create;delete +// +kubebuilder:rbac:groups="",resources=pods/exec,verbs=get;list;watch;create;delete // +kubebuilder:rbac:groups="",resources=configmaps,verbs=get;watch;list // +kubebuilder:rbac:groups="",resources=secrets,verbs=get;watch;list @@ -192,19 +226,6 @@ func (r *LMEvalJobReconciler) SetupWithManager(mgr ctrl.Manager) error { return err } - if err := r.checkSecrets(ctx); err != nil { - // if mTLS/TLS is not enable, then we are good. Otherwise error out - if viper.IsSet(GrpcServerCertEnv) || - viper.IsSet(GrpcServerKeyEnv) || - viper.IsSet(GrpcClientCaEnv) { - return fmt.Errorf("TLS or mTLS is enabled for GRPC server but secrets don't exist") - } - } - - // ideally, this should be call in the main.go, but GRPC server depends on the - // constructOptionsFromConfigMap to get the settings. - go StartGrpcServer(ctx, r) - return nil })); err != nil { return err @@ -241,54 +262,27 @@ func (r *LMEvalJobReconciler) SetupWithManager(mgr ctrl.Manager) error { Complete(r) } -func (r *LMEvalJobReconciler) checkSecrets(ctx context.Context) error { - - var isSecretExists = func(name string) bool { - var secret corev1.Secret - err := r.Get(ctx, types.NamespacedName{Namespace: r.Namespace, Name: name}, &secret) - return err == nil - } - - if !isSecretExists(r.options.GrpcServerSecret) { - return fmt.Errorf("secret %s not found", r.options.GrpcServerSecret) - } - if !isSecretExists(r.options.GrpcClientSecret) { - return fmt.Errorf("secret %s not found", r.options.GrpcServerSecret) - } - return nil -} - -func (r *LMEvalJobReconciler) updateStatus(ctx context.Context, newStatus *backendv1beta1.JobStatus) (err error) { - log := log.FromContext(ctx) - - if strings.Trim(newStatus.GetJobName(), " ") == "" || - strings.Trim(newStatus.GetJobNamespace(), " ") == "" { - - return fmt.Errorf("JobName or JobNameSpace is empty") +func (r *LMEvalJobReconciler) updateStatus(ctx context.Context, log logr.Logger, job *lmesv1alpha1.LMEvalJob) error { + stdin, _, err := r.remoteCommand(ctx, job, fmt.Sprintf("%s %s", DestDriverPath, "--get-status")) + if err != nil { + return err } - - job := &lmesv1alpha1.LMEvalJob{} - if err = r.Get(ctx, types.NamespacedName{ - Namespace: newStatus.JobNamespace, - Name: newStatus.JobName, - }, job); err != nil { - log.Info("unable to fetch LMEvalJob") + newStatus := lmesv1alpha1.LMEvalJobStatus{} + if err = json.Unmarshal(stdin, &newStatus); err != nil { return err } - newJobStatus := job.Status.DeepCopy() - newJobStatus.State = lmesv1alpha1.JobState(newStatus.GetState()) - newJobStatus.Reason = lmesv1alpha1.Reason(newStatus.GetReason()) + // driver only provides updates for these fields + if newStatus.State != job.Status.State || + newStatus.Message != job.Status.Message || + newStatus.Reason != job.Status.Reason || + newStatus.Results != job.Status.Results { - if newStatus.GetStatusMessage() != "" { - newJobStatus.Message = newStatus.GetStatusMessage() - } - if newStatus.Results != nil { - newJobStatus.Results = newStatus.GetResults() - } + job.Status.State = newStatus.State + job.Status.Message = newStatus.Message + job.Status.Reason = newStatus.Reason + job.Status.Results = newStatus.Results - if !reflect.DeepEqual(job.Status, newJobStatus) { - job.Status = *newJobStatus err = r.Status().Update(ctx, job) if err != nil { log.Error(err, "failed to update status") @@ -297,21 +291,49 @@ func (r *LMEvalJobReconciler) updateStatus(ctx context.Context, newStatus *backe return err } +func (r *LMEvalJobReconciler) shutdownDriver(ctx context.Context, job *lmesv1alpha1.LMEvalJob) error { + _, _, err := r.remoteCommand(ctx, job, fmt.Sprintf("%s %s", DestDriverPath, "--shutdown")) + return err +} + +func (r *LMEvalJobReconciler) remoteCommand(ctx context.Context, job *lmesv1alpha1.LMEvalJob, command string) ([]byte, []byte, error) { + request := r.restClient.Post(). + Namespace(job.GetNamespace()). + Resource("pods"). + Name(job.GetName()). + SubResource("exec"). + VersionedParams(&corev1.PodExecOptions{ + Command: []string{"/bin/sh", "-c", command}, + Stdin: false, + Stdout: true, + Stderr: true, + }, scheme.ParameterCodec) + + outBuff := &bytes.Buffer{} + errBuf := &bytes.Buffer{} + exec, err := remotecommand.NewSPDYExecutor(r.restConfig, "POST", request.URL()) + if err != nil { + return nil, nil, err + } + if err = exec.StreamWithContext(ctx, remotecommand.StreamOptions{ + Stdout: outBuff, + Stderr: errBuf, + }); err != nil { + return nil, nil, err + } + return outBuff.Bytes(), errBuf.Bytes(), nil +} + func (r *LMEvalJobReconciler) constructOptionsFromConfigMap( ctx context.Context, configmap *corev1.ConfigMap) error { r.options = &ServiceOptions{ - DriverImage: DefaultDriverImage, - PodImage: DefaultPodImage, - DriverReportInterval: driver.DefaultDriverReportInterval, - PodCheckingInterval: DefaultPodCheckingInterval, - ImagePullPolicy: DefaultImagePullPolicy, - GrpcPort: DefaultGrpcPort, - GrpcService: DefaultGrpcService, - GrpcServerSecret: DefaultGrpcServerSecret, - GrpcClientSecret: DefaultGrpcClientSecret, - MaxBatchSize: DefaultMaxBatchSize, - DetectDevice: DefaultDetectDevice, - DefaultBatchSize: DefaultBatchSize, + DriverImage: DefaultDriverImage, + PodImage: DefaultPodImage, + PodCheckingInterval: DefaultPodCheckingInterval, + ImagePullPolicy: DefaultImagePullPolicy, + MaxBatchSize: DefaultMaxBatchSize, + DetectDevice: DefaultDetectDevice, + DefaultBatchSize: DefaultBatchSize, } log := log.FromContext(ctx) @@ -374,6 +396,8 @@ func (r *LMEvalJobReconciler) constructOptionsFromConfigMap( } func (r *LMEvalJobReconciler) handleDeletion(ctx context.Context, job *lmesv1alpha1.LMEvalJob, log logr.Logger) (reconcile.Result, error) { + defer r.pullingJobs.remove(string(job.GetUID())) + if controllerutil.ContainsFinalizer(job, lmesv1alpha1.FinalizerName) { // delete the correspondling pod if needed // remove our finalizer from the list and update it. @@ -459,7 +483,7 @@ func (r *LMEvalJobReconciler) handleNewCR(ctx context.Context, log logr.Logger, job.Namespace)) log.Info("Successfully create a Pod for the Job") // Check the pod after the config interval - return ctrl.Result{Requeue: true, RequeueAfter: r.options.PodCheckingInterval}, nil + return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil } func (r *LMEvalJobReconciler) checkScheduledPod(ctx context.Context, log logr.Logger, job *lmesv1alpha1.LMEvalJob) (ctrl.Result, error) { @@ -482,38 +506,31 @@ func (r *LMEvalJobReconciler) checkScheduledPod(ctx context.Context, log logr.Lo return ctrl.Result{}, err } - if pod.Status.ContainerStatuses == nil { - // wait for the pod to initialize and run the containers - return ctrl.Result{Requeue: true, RequeueAfter: r.options.PodCheckingInterval}, nil - } - - mainIndex := slices.IndexFunc(pod.Status.ContainerStatuses, func(s corev1.ContainerStatus) bool { - return s.Name == "main" - }) - - if mainIndex == -1 || pod.Status.ContainerStatuses[mainIndex].State.Terminated == nil { - // wait for the main container to finish - return ctrl.Result{Requeue: true, RequeueAfter: r.options.PodCheckingInterval}, nil - } - - // main container finished. update status - job.Status.State = lmesv1alpha1.CompleteJobState - if pod.Status.ContainerStatuses[mainIndex].State.Terminated.ExitCode == 0 { - job.Status.Reason = lmesv1alpha1.SucceedReason - } else { + if mainIdx := getContainerByName(&pod.Status, "main"); mainIdx == -1 { + // waiting for the main container to be up + return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil + } else if podFailed, msg := isContainerFailed(&pod.Status.ContainerStatuses[mainIdx]); podFailed { + job.Status.State = lmesv1alpha1.CompleteJobState job.Status.Reason = lmesv1alpha1.FailedReason - job.Status.Message = pod.Status.ContainerStatuses[mainIndex].State.Terminated.Reason + job.Status.Message = msg + if err := r.Status().Update(ctx, job); err != nil { + log.Error(err, "unable to update LMEvalJob status for pod failure") + } + log.Info("detect an error on the job's pod. marked the job as done", "name", job.Name) + return ctrl.Result{}, err + } else if pod.Status.ContainerStatuses[mainIdx].State.Running == nil { + return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil } - err = r.Status().Update(ctx, job) + // pull status from the driver + if err = r.updateStatus(ctx, log, job); err == nil && job.Status.State == lmesv1alpha1.CompleteJobState { + // the update will trigger another reconcile + return ctrl.Result{}, nil + } if err != nil { - log.Error(err, "unable to update LMEvalJob status", "state", job.Status.State) + log.Error(err, "unable to retrieve the status from the job's pod. retry after the pulling interval") } - r.Recorder.Event(job, "Normal", "PodCompleted", - fmt.Sprintf("The pod for the LMEvalJob %s in namespace %s has completed", - job.Name, - job.Namespace)) - return ctrl.Result{}, err + return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil } func (r *LMEvalJobReconciler) getPod(ctx context.Context, job *lmesv1alpha1.LMEvalJob) (*corev1.Pod, error) { @@ -555,17 +572,36 @@ func (r *LMEvalJobReconciler) deleteJobPod(ctx context.Context, job *lmesv1alpha func (r *LMEvalJobReconciler) handleComplete(ctx context.Context, log logr.Logger, job *lmesv1alpha1.LMEvalJob) (ctrl.Result, error) { if job.Status.CompleteTime == nil { + // make sure the pod is in the complete state. if not, run the shutdown command + pod, err := r.getPod(ctx, job) + if err == nil { + if getRunningContainerByName(&pod.Status, "main") != -1 { + // send shutdown command if the main container is running + if err := r.shutdownDriver(ctx, job); err != nil { + log.Error(err, "failed to shutdown the job pod. retry after the pulling interval") + return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), nil + } + } + } else { + // the pod is gone ?? + log.Error(err, "LMEvalJob is marked as Complete but the pod is gone") + } + r.Recorder.Event(job, "Normal", "JobCompleted", fmt.Sprintf("The LMEvalJob %s in namespace %s has completed", job.Name, job.Namespace)) - // TODO: final wrap up/clean up + + // record the CompleteTime current := v1.Now() job.Status.CompleteTime = ¤t if err := r.Status().Update(ctx, job); err != nil { log.Error(err, "failed to update status for completion") } } + + // make sure to clean up the pullingJobs + r.pullingJobs.remove(string(job.GetUID())) return ctrl.Result{}, nil } @@ -582,7 +618,7 @@ func (r *LMEvalJobReconciler) handleCancel(ctx context.Context, log logr.Logger, if err := r.deleteJobPod(ctx, job); err != nil { // leave the state as is and retry again log.Error(err, "failed to delete pod. scheduled a retry", "interval", r.options.PodCheckingInterval.String()) - return ctrl.Result{Requeue: true, RequeueAfter: r.options.PodCheckingInterval}, err + return r.pullingJobs.addOrUpdate(string(job.GetUID()), r.options.PodCheckingInterval), err } } @@ -594,6 +630,7 @@ func (r *LMEvalJobReconciler) handleCancel(ctx context.Context, log logr.Logger, fmt.Sprintf("The LMEvalJob %s in namespace %s has cancelled and changed its state to Complete", job.Name, job.Namespace)) + r.pullingJobs.remove(string(job.GetUID())) return ctrl.Result{}, err } @@ -644,7 +681,6 @@ func (r *LMEvalJobReconciler) createPod(job *lmesv1alpha1.LMEvalJob, log logr.Lo }, } - envVars, volumes, volumeMounts = r.patch4GrpcTLS(envVars, volumes, volumeMounts, r.options.grpcTLSMode) volumes = append(volumes, job.Spec.Pod.GetVolumes()...) volumeMounts = append(volumeMounts, job.Spec.Pod.GetContainer().GetVolumMounts()...) labels := getPodLabels(job.Labels, log) @@ -840,12 +876,7 @@ func (r *LMEvalJobReconciler) generateCmd(job *lmesv1alpha1.LMEvalJob) []string } cmds := []string{ DestDriverPath, - "--job-namespace", job.Namespace, - "--job-name", job.Name, - "--grpc-service", fmt.Sprintf("%s.%s.svc", r.options.GrpcService, r.Namespace), - "--grpc-port", strconv.Itoa(r.options.GrpcPort), "--output-path", "/opt/app-root/src/output", - "--report-interval", r.options.DriverReportInterval.String(), } if r.options.DetectDevice { @@ -883,96 +914,32 @@ func argsToString(args []lmesv1alpha1.Arg) string { return strings.Join(equalForms, ",") } -func (r *LMEvalJobReconciler) patch4GrpcTLS( - envVars []corev1.EnvVar, - volumes []corev1.Volume, - volumeMounts []corev1.VolumeMount, - tlsMode TLSMode) ([]corev1.EnvVar, []corev1.Volume, []corev1.VolumeMount) { - - var secretMode int32 = 420 - - if tlsMode == TLSMode_mTLS { - envVars = append(envVars, - corev1.EnvVar{ - Name: driver.GrpcClientKeyEnv, - Value: "/tmp/k8s-grpc-client/certs/tls.key", - }, - corev1.EnvVar{ - Name: driver.GrpcClientCertEnv, - Value: "/tmp/k8s-grpc-client/certs/tls.crt", - }, - corev1.EnvVar{ - Name: driver.GrpcServerCaEnv, - Value: "/tmp/k8s-grpc-server/certs/ca.crt", - }, - ) - - volumeMounts = append(volumeMounts, - corev1.VolumeMount{ - Name: "client-cert", - MountPath: "/tmp/k8s-grpc-client/certs", - ReadOnly: true, - }, - corev1.VolumeMount{ - Name: "server-cert", - MountPath: "/tmp/k8s-grpc-server/certs", - ReadOnly: true, - }, - ) - - volumes = append(volumes, - corev1.Volume{ - Name: "client-cert", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: r.options.GrpcClientSecret, - DefaultMode: &secretMode, - }, - }, - }, - corev1.Volume{ - Name: "server-cert", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: r.options.GrpcServerSecret, - DefaultMode: &secretMode, - Items: []corev1.KeyToPath{ - {Key: "ca.crt", Path: "ca.crt"}, - }, - }, - }, - }, - ) - } else if tlsMode == TLSMode_TLS { - envVars = append(envVars, - corev1.EnvVar{ - Name: driver.GrpcServerCaEnv, - Value: "/tmp/k8s-grpc-server/certs/ca.crt", - }, - ) +func isContainerFailed(status *corev1.ContainerStatus) (bool, string) { + if status.State.Waiting != nil && + status.State.Waiting.Reason != "PodInitializing" { + return true, status.State.Waiting.Reason + } + if status.State.Terminated != nil && + status.State.Terminated.Reason != "Complete" { + return true, status.State.Terminated.Reason + } + return false, "" +} - volumeMounts = append(volumeMounts, - corev1.VolumeMount{ - Name: "server-cert", - MountPath: "/tmp/k8s-grpc-server/certs", - }, - ) - - volumes = append(volumes, - corev1.Volume{ - Name: "server-cert", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: r.options.GrpcServerSecret, - DefaultMode: &secretMode, - Items: []corev1.KeyToPath{ - {Key: "ca.crt", Path: "ca.crt"}, - }, - }, - }, - }, - ) +// return the index of the container which is in running state and with the specified name +// otherwise return -1 +func getRunningContainerByName(status *corev1.PodStatus, name string) int { + if idx := getContainerByName(status, name); idx != -1 && status.ContainerStatuses[idx].State.Running != nil { + return idx } + return -1 +} - return envVars, volumes, volumeMounts +func getContainerByName(status *corev1.PodStatus, name string) int { + if status.ContainerStatuses == nil { + return -1 + } + return slices.IndexFunc(status.ContainerStatuses, func(s corev1.ContainerStatus) bool { + return s.Name == name + }) } diff --git a/controllers/lmes/lmevaljob_controller_test.go b/controllers/lmes/lmevaljob_controller_test.go index 8139032..b04b21d 100644 --- a/controllers/lmes/lmevaljob_controller_test.go +++ b/controllers/lmes/lmevaljob_controller_test.go @@ -34,7 +34,6 @@ var ( allowPrivilegeEscalation = false runAsNonRootUser = true runAsUser int64 = 1001030000 - secretMode int32 = 420 ) func Test_SimplePod(t *testing.T) { @@ -45,8 +44,6 @@ func Test_SimplePod(t *testing.T) { PodImage: "podimage:latest", DriverImage: "driver:latest", ImagePullPolicy: corev1.PullAlways, - GrpcPort: 8088, - GrpcService: "grpc-service", }, } var job = &lmesv1alpha1.LMEvalJob{ @@ -169,8 +166,6 @@ func Test_WithLabelsAnnotationsResourcesVolumes(t *testing.T) { PodImage: "podimage:latest", DriverImage: "driver:latest", ImagePullPolicy: corev1.PullAlways, - GrpcPort: 8088, - GrpcService: "grpc-service", }, } var job = &lmesv1alpha1.LMEvalJob{ @@ -357,191 +352,14 @@ func Test_WithLabelsAnnotationsResourcesVolumes(t *testing.T) { assert.Equal(t, expect, newPod) } -func Test_GrpcMTlsPod(t *testing.T) { - log := log.FromContext(context.Background()) - lmevalRec := LMEvalJobReconciler{ - Namespace: "test", - options: &ServiceOptions{ - PodImage: "podimage:latest", - DriverImage: "driver:latest", - ImagePullPolicy: corev1.PullAlways, - GrpcPort: 8088, - GrpcService: "grpc-service", - GrpcServerSecret: "server-secret", - GrpcClientSecret: "client-secret", - grpcTLSMode: TLSMode_mTLS, - }, - } - var job = &lmesv1alpha1.LMEvalJob{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - Namespace: "default", - UID: "for-testing", - }, - TypeMeta: metav1.TypeMeta{ - Kind: lmesv1alpha1.KindName, - APIVersion: lmesv1alpha1.Version, - }, - Spec: lmesv1alpha1.LMEvalJobSpec{ - Model: "test", - ModelArgs: []lmesv1alpha1.Arg{ - {Name: "arg1", Value: "value1"}, - }, - TaskList: lmesv1alpha1.TaskList{ - TaskNames: []string{"task1", "task2"}, - }, - }, - } - - expect := &corev1.Pod{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - Namespace: "default", - Labels: map[string]string{ - "app.kubernetes.io/name": "ta-lmes", - }, - OwnerReferences: []metav1.OwnerReference{ - { - APIVersion: lmesv1alpha1.Version, - Kind: lmesv1alpha1.KindName, - Name: "test", - Controller: &isController, - UID: "for-testing", - }, - }, - }, - TypeMeta: metav1.TypeMeta{ - Kind: "Pod", - APIVersion: "v1", - }, - Spec: corev1.PodSpec{ - InitContainers: []corev1.Container{ - { - Name: "driver", - Image: lmevalRec.options.DriverImage, - ImagePullPolicy: lmevalRec.options.ImagePullPolicy, - Command: []string{DriverPath, "--copy", DestDriverPath}, - SecurityContext: &corev1.SecurityContext{ - AllowPrivilegeEscalation: &allowPrivilegeEscalation, - RunAsUser: &runAsUser, - Capabilities: &corev1.Capabilities{ - Drop: []corev1.Capability{ - "ALL", - }, - }, - }, - VolumeMounts: []corev1.VolumeMount{ - { - Name: "shared", - MountPath: "/opt/app-root/src/bin", - }, - }, - }, - }, - Containers: []corev1.Container{ - { - Name: "main", - Image: lmevalRec.options.PodImage, - ImagePullPolicy: lmevalRec.options.ImagePullPolicy, - Env: []corev1.EnvVar{ - { - Name: "GRPC_CLIENT_KEY", - Value: "/tmp/k8s-grpc-client/certs/tls.key", - }, - { - Name: "GRPC_CLIENT_CERT", - Value: "/tmp/k8s-grpc-client/certs/tls.crt", - }, - { - Name: "GRPC_SERVER_CA", - Value: "/tmp/k8s-grpc-server/certs/ca.crt", - }, - }, - Command: lmevalRec.generateCmd(job), - Args: lmevalRec.generateArgs(job, log), - SecurityContext: &corev1.SecurityContext{ - AllowPrivilegeEscalation: &allowPrivilegeEscalation, - RunAsUser: &runAsUser, - Capabilities: &corev1.Capabilities{ - Drop: []corev1.Capability{ - "ALL", - }, - }, - }, - VolumeMounts: []corev1.VolumeMount{ - { - Name: "shared", - MountPath: "/opt/app-root/src/bin", - }, - { - Name: "client-cert", - MountPath: "/tmp/k8s-grpc-client/certs", - ReadOnly: true, - }, - { - Name: "server-cert", - MountPath: "/tmp/k8s-grpc-server/certs", - ReadOnly: true, - }, - }, - }, - }, - SecurityContext: &corev1.PodSecurityContext{ - RunAsNonRoot: &runAsNonRootUser, - SeccompProfile: &corev1.SeccompProfile{ - Type: corev1.SeccompProfileTypeRuntimeDefault, - }, - }, - Volumes: []corev1.Volume{ - { - Name: "shared", VolumeSource: corev1.VolumeSource{ - EmptyDir: &corev1.EmptyDirVolumeSource{}, - }, - }, - { - Name: "client-cert", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: lmevalRec.options.GrpcClientSecret, - DefaultMode: &secretMode, - }, - }, - }, - { - Name: "server-cert", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: lmevalRec.options.GrpcServerSecret, - DefaultMode: &secretMode, - Items: []corev1.KeyToPath{ - {Key: "ca.crt", Path: "ca.crt"}, - }, - }, - }, - }, - }, - RestartPolicy: corev1.RestartPolicyNever, - }, - } - - newPod := lmevalRec.createPod(job, log) - - assert.Equal(t, expect, newPod) -} - func Test_EnvSecretsPod(t *testing.T) { log := log.FromContext(context.Background()) lmevalRec := LMEvalJobReconciler{ Namespace: "test", options: &ServiceOptions{ - PodImage: "podimage:latest", - DriverImage: "driver:latest", - ImagePullPolicy: corev1.PullAlways, - GrpcPort: 8088, - GrpcService: "grpc-service", - GrpcServerSecret: "server-secret", - GrpcClientSecret: "client-secret", - grpcTLSMode: TLSMode_mTLS, + PodImage: "podimage:latest", + DriverImage: "driver:latest", + ImagePullPolicy: corev1.PullAlways, }, } var job = &lmesv1alpha1.LMEvalJob{ @@ -644,18 +462,6 @@ func Test_EnvSecretsPod(t *testing.T) { }, }, }, - { - Name: "GRPC_CLIENT_KEY", - Value: "/tmp/k8s-grpc-client/certs/tls.key", - }, - { - Name: "GRPC_CLIENT_CERT", - Value: "/tmp/k8s-grpc-client/certs/tls.crt", - }, - { - Name: "GRPC_SERVER_CA", - Value: "/tmp/k8s-grpc-server/certs/ca.crt", - }, }, Command: lmevalRec.generateCmd(job), Args: lmevalRec.generateArgs(job, log), @@ -673,16 +479,6 @@ func Test_EnvSecretsPod(t *testing.T) { Name: "shared", MountPath: "/opt/app-root/src/bin", }, - { - Name: "client-cert", - MountPath: "/tmp/k8s-grpc-client/certs", - ReadOnly: true, - }, - { - Name: "server-cert", - MountPath: "/tmp/k8s-grpc-server/certs", - ReadOnly: true, - }, }, }, }, @@ -698,27 +494,6 @@ func Test_EnvSecretsPod(t *testing.T) { EmptyDir: &corev1.EmptyDirVolumeSource{}, }, }, - { - Name: "client-cert", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: lmevalRec.options.GrpcClientSecret, - DefaultMode: &secretMode, - }, - }, - }, - { - Name: "server-cert", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: lmevalRec.options.GrpcServerSecret, - DefaultMode: &secretMode, - Items: []corev1.KeyToPath{ - {Key: "ca.crt", Path: "ca.crt"}, - }, - }, - }, - }, }, RestartPolicy: corev1.RestartPolicyNever, }, @@ -734,14 +509,9 @@ func Test_FileSecretsPod(t *testing.T) { lmevalRec := LMEvalJobReconciler{ Namespace: "test", options: &ServiceOptions{ - PodImage: "podimage:latest", - DriverImage: "driver:latest", - ImagePullPolicy: corev1.PullAlways, - GrpcPort: 8088, - GrpcService: "grpc-service", - GrpcServerSecret: "server-secret", - GrpcClientSecret: "client-secret", - grpcTLSMode: TLSMode_mTLS, + PodImage: "podimage:latest", + DriverImage: "driver:latest", + ImagePullPolicy: corev1.PullAlways, }, } var job = &lmesv1alpha1.LMEvalJob{ @@ -842,22 +612,8 @@ func Test_FileSecretsPod(t *testing.T) { Name: "main", Image: lmevalRec.options.PodImage, ImagePullPolicy: lmevalRec.options.ImagePullPolicy, - Env: []corev1.EnvVar{ - { - Name: "GRPC_CLIENT_KEY", - Value: "/tmp/k8s-grpc-client/certs/tls.key", - }, - { - Name: "GRPC_CLIENT_CERT", - Value: "/tmp/k8s-grpc-client/certs/tls.crt", - }, - { - Name: "GRPC_SERVER_CA", - Value: "/tmp/k8s-grpc-server/certs/ca.crt", - }, - }, - Command: lmevalRec.generateCmd(job), - Args: lmevalRec.generateArgs(job, log), + Command: lmevalRec.generateCmd(job), + Args: lmevalRec.generateArgs(job, log), SecurityContext: &corev1.SecurityContext{ AllowPrivilegeEscalation: &allowPrivilegeEscalation, RunAsUser: &runAsUser, @@ -872,16 +628,6 @@ func Test_FileSecretsPod(t *testing.T) { Name: "shared", MountPath: "/opt/app-root/src/bin", }, - { - Name: "client-cert", - MountPath: "/tmp/k8s-grpc-client/certs", - ReadOnly: true, - }, - { - Name: "server-cert", - MountPath: "/tmp/k8s-grpc-server/certs", - ReadOnly: true, - }, { Name: "secVol1", MountPath: "the_path", @@ -902,27 +648,6 @@ func Test_FileSecretsPod(t *testing.T) { EmptyDir: &corev1.EmptyDirVolumeSource{}, }, }, - { - Name: "client-cert", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: lmevalRec.options.GrpcClientSecret, - DefaultMode: &secretMode, - }, - }, - }, - { - Name: "server-cert", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: lmevalRec.options.GrpcServerSecret, - DefaultMode: &secretMode, - Items: []corev1.KeyToPath{ - {Key: "ca.crt", Path: "ca.crt"}, - }, - }, - }, - }, { Name: "secVol1", VolumeSource: corev1.VolumeSource{ @@ -955,9 +680,6 @@ func Test_GenerateArgBatchSize(t *testing.T) { PodImage: "podimage:latest", DriverImage: "driver:latest", ImagePullPolicy: corev1.PullAlways, - GrpcPort: 8088, - GrpcService: "grpc-service", - grpcTLSMode: TLSMode_None, MaxBatchSize: 24, DefaultBatchSize: 8, }, @@ -1014,9 +736,6 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) { PodImage: "podimage:latest", DriverImage: "driver:latest", ImagePullPolicy: corev1.PullAlways, - GrpcPort: 8088, - GrpcService: "grpc-service", - grpcTLSMode: TLSMode_None, DefaultBatchSize: DefaultBatchSize, MaxBatchSize: DefaultMaxBatchSize, }, @@ -1063,12 +782,7 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) { assert.Equal(t, []string{ "/opt/app-root/src/bin/driver", - "--job-namespace", "default", - "--job-name", "test", - "--grpc-service", "grpc-service.test.svc", - "--grpc-port", "8088", "--output-path", "/opt/app-root/src/output", - "--report-interval", "0s", "--task-recipe", "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", "--", }, lmevalRec.generateCmd(job)) @@ -1093,12 +807,7 @@ func Test_GenerateArgCmdTaskRecipes(t *testing.T) { assert.Equal(t, []string{ "/opt/app-root/src/bin/driver", - "--job-namespace", "default", - "--job-name", "test", - "--grpc-service", "grpc-service.test.svc", - "--grpc-port", "8088", "--output-path", "/opt/app-root/src/output", - "--report-interval", "0s", "--task-recipe", "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", "--task-recipe", "card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10", "--", @@ -1113,9 +822,6 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) { PodImage: "podimage:latest", DriverImage: "driver:latest", ImagePullPolicy: corev1.PullAlways, - GrpcPort: 8088, - GrpcService: "grpc-service", - grpcTLSMode: TLSMode_None, DefaultBatchSize: DefaultBatchSize, MaxBatchSize: DefaultMaxBatchSize, }, @@ -1163,12 +869,7 @@ func Test_GenerateArgCmdCustomCard(t *testing.T) { assert.Equal(t, []string{ "/opt/app-root/src/bin/driver", - "--job-namespace", "default", - "--job-name", "test", - "--grpc-service", "grpc-service.test.svc", - "--grpc-port", "8088", "--output-path", "/opt/app-root/src/output", - "--report-interval", "0s", "--task-recipe", "card=cards.custom_0,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10", "--custom-card", `{ "__type__": "task_card", "loader": { "__type__": "load_hf", "path": "wmt16", "name": "de-en" }, "preprocess_steps": [ { "__type__": "copy", "field": "translation/en", "to_field": "text" }, { "__type__": "copy", "field": "translation/de", "to_field": "translation" }, { "__type__": "set", "fields": { "source_language": "english", "target_language": "deutch" } } ], "task": "tasks.translation.directed", "templates": "templates.translation.directed.all" }`, "--", @@ -1183,8 +884,6 @@ func Test_CustomCardValidation(t *testing.T) { PodImage: "podimage:latest", DriverImage: "driver:latest", ImagePullPolicy: corev1.PullAlways, - GrpcPort: 8088, - GrpcService: "grpc-service", }, } var job = &lmesv1alpha1.LMEvalJob{ diff --git a/go.mod b/go.mod index 9e8ed7c..2bd5302 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/hashicorp/hcl v1.0.0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/moby/spdystream v0.2.0 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect diff --git a/go.sum b/go.sum index c99eba9..4075533 100644 --- a/go.sum +++ b/go.sum @@ -29,6 +29,7 @@ github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbt github.com/PuerkitoBio/urlesc v0.0.0-20160726150825-5bd2802263f2/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= github.com/aws/aws-sdk-go v1.44.264 h1:5klL62ebn6uv3oJ0ixF7K12hKItj8lV3QqWeQPlkFSs= @@ -252,6 +253,7 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfr github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/moby/spdystream v0.2.0 h1:cjW1zVyyoiM0T7b6UoySUFqzXMoqRckQtXwGPiBhOM8= github.com/moby/spdystream v0.2.0/go.mod h1:f7i0iNDQJ059oMTcWxx8MA/zKFIuD/lY+0GqbN2Wy8c= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=