Skip to content

Commit

Permalink
feat: support unitxt recipes
Browse files Browse the repository at this point in the history
Add new fields in the CRD to support unitxt recipes and
leverage the driver to create corresponding yaml files
of the unitxt recipes.

Signed-off-by: Yihong Wang <[email protected]>
  • Loading branch information
yhwang committed Sep 23, 2024
1 parent a626cf8 commit cf23c97
Show file tree
Hide file tree
Showing 11 changed files with 581 additions and 69 deletions.
10 changes: 6 additions & 4 deletions Dockerfile.lmes-job
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,19 @@ USER default
WORKDIR /opt/app-root/src
RUN mkdir /opt/app-root/src/hf_home && chmod g+rwx /opt/app-root/src/hf_home
RUN mkdir /opt/app-root/src/output && chmod g+rwx /opt/app-root/src/output
RUN mkdir /opt/app-root/src/my_tasks && chmod g+rwx /opt/app-root/src/my_tasks
RUN mkdir /opt/app-root/src/.cache
ENV PATH="/opt/app-root/bin:/opt/app-root/src/.local/bin/:/opt/app-root/src/bin:/usr/local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"

RUN pip install --no-cache-dir --user --upgrade ibm-generative-ai[lm-eval]
COPY --chown=1001:0 patch /opt/app-root/src/patch
# Clone the Git repository and install the Python package
# Clone the Git repository, check out v0.4.4 and install the Python package
RUN git clone https://github.com/opendatahub-io/lm-evaluation-harness.git && \
cd lm-evaluation-harness && git checkout 568af943e315100af3f00937bfd6947844769ab8 && \
cd lm-evaluation-harness && git checkout 543617fef9ba885e87f8db8930fbbff1d4e2ca49 && \
curl --output lm_eval/models/bam.py https://raw.githubusercontent.com/IBM/ibm-generative-ai/main/src/genai/extensions/lm_eval/model.py && \
git apply /opt/app-root/src/patch/lmes/models.patch && pip install --no-cache-dir --user -e .[unitxt] && \
pip install --no-cache-dir --user -e .[openai]
git apply /opt/app-root/src/patch/lmes/models.patch && pip install --no-cache-dir --user -e .[api]

RUN python -c 'from lm_eval.tasks.unitxt import task; import os.path; print("class: !function " + task.__file__.replace("task.py", "task.Unitxt"))' > ./my_tasks/unitxt

ENV PYTHONPATH=/opt/app-root/src/.local/lib/python3.11/site-packages:/opt/app-root/src/lm-evaluation-harness:/opt/app-root/src:/opt/app-root/src/server
ENV HF_HOME=/opt/app-root/src/hf_home
Expand Down
66 changes: 64 additions & 2 deletions api/lmes/v1alpha1/lmevaljob_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ limitations under the License.
package v1alpha1

import (
"fmt"
"strings"

corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
Expand Down Expand Up @@ -78,6 +81,65 @@ type FileSecret struct {
MountPath string `json:"mountPath"`
}

// TaskRecipe defines a custom evaluation task as a Unitxt recipe.
// Card and Template are required; every other field is optional and is
// left out of the rendered recipe string when unset (see String).
// Details of the Unitxt recipe can be found here:
// https://www.unitxt.ai/en/latest/unitxt.standard.html#unitxt.standard.StandardRecipe
type TaskRecipe struct {
// Card identifies the Unitxt dataset card.
Card string `json:"card"`
// Template identifies the Unitxt template.
Template string `json:"template"`
// Task identifies the Unitxt task.
// +optional
Task *string `json:"task,omitempty"`
// Metrics lists the metrics to use for the task.
// +optional
Metrics []string `json:"metrics,omitempty"`
// Format identifies the Unitxt format.
// +optional
Format *string `json:"format,omitempty"`
// LoaderLimit limits the number of records to load.
// +optional
LoaderLimit *int `json:"loaderLimit,omitempty"`
// NumDemos is the number of few-shot demonstrations.
// +optional
NumDemos *int `json:"numDemos,omitempty"`
// DemosPoolSize is the pool size for the few-shot demonstrations.
// +optional
DemosPoolSize *int `json:"demosPoolSize,omitempty"`
}

// TaskList groups the evaluation tasks of a job: tasks referenced by name
// and tasks defined inline as Unitxt recipes.
type TaskList struct {
// TaskNames lists task names taken from lm-eval's task list.
TaskNames []string `json:"taskNames,omitempty"`
// TaskRecipes lists task recipes that are specific to Unitxt.
TaskRecipes []TaskRecipe `json:"taskRecipes,omitempty"`
}

// String renders the recipe as a comma-separated list of key=value pairs,
// e.g. "card=<card>,template=<template>,num_demos=5".
//
// card and template are always emitted first. The optional fields follow in
// a fixed order (task, metrics, format, loader_limit, num_demos,
// demos_pool_size) and are emitted only when set. Metrics are written as a
// bracketed, comma-separated list: metrics=[m1,m2].
func (t *TaskRecipe) String() string {
	fields := []string{
		"card=" + t.Card,
		"template=" + t.Template,
	}

	if t.Task != nil {
		fields = append(fields, "task="+*t.Task)
	}
	if len(t.Metrics) > 0 {
		fields = append(fields, "metrics=["+strings.Join(t.Metrics, ",")+"]")
	}
	if t.Format != nil {
		fields = append(fields, "format="+*t.Format)
	}
	if t.LoaderLimit != nil {
		fields = append(fields, fmt.Sprintf("loader_limit=%d", *t.LoaderLimit))
	}
	if t.NumDemos != nil {
		fields = append(fields, fmt.Sprintf("num_demos=%d", *t.NumDemos))
	}
	if t.DemosPoolSize != nil {
		fields = append(fields, fmt.Sprintf("demos_pool_size=%d", *t.DemosPoolSize))
	}

	return strings.Join(fields, ",")
}

// LMEvalJobSpec defines the desired state of LMEvalJob
type LMEvalJobSpec struct {
// INSERT ADDITIONAL SPEC FIELDS - desired state of cluster
Expand All @@ -88,8 +150,8 @@ type LMEvalJobSpec struct {
// Args for the model
// +optional
ModelArgs []Arg `json:"modelArgs,omitempty"`
// Evaluation tasks
Tasks []string `json:"tasks"`
// Evaluation task list
TaskList TaskList `json:"taskList"`
// Sets the number of few-shot examples to place in context
// +optional
NumFewShot *int `json:"numFewShot,omitempty"`
Expand Down
78 changes: 73 additions & 5 deletions api/lmes/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 20 additions & 1 deletion cmd/lmes_driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"io"
"os"
"strings"
"time"

ctrl "sigs.k8s.io/controller-runtime"
Expand All @@ -36,18 +37,35 @@ const (
OutputPath = "/opt/app-root/src/output"
)

// taskRecipeArg collects the values of a repeatable command-line flag.
// Its Set and String methods have the signatures required by flag.Value.
type taskRecipeArg []string

// Set appends value to the collected recipes. It always returns nil.
func (r *taskRecipeArg) Set(value string) error {
	*r = append(*r, value)
	return nil
}

// String renders all collected recipes as a single string.
func (r *taskRecipeArg) String() string {
	// ":" is assumed not to occur inside a task recipe, which makes it a
	// safe separator between recipes.
	recipes := []string(*r)
	return strings.Join(recipes, ":")
}

var (
taskRecipes taskRecipeArg
copy = flag.String("copy", "", "copy this binary to specified destination path")
jobNameSpace = flag.String("job-namespace", "", "Job's namespace ")
jobName = flag.String("job-name", "", "Job's name")
grpcService = flag.String("grpc-service", "", "grpc service name")
grpcPort = flag.Int("grpc-port", 8082, "grpc port")
outputPath = flag.String("output-path", OutputPath, "output path")
detectDevice = flag.Bool("detect-device", true, "detect available device(s), CUDA or CPU")
detectDevice = flag.Bool("detect-device", false, "detect available device(s), CUDA or CPU")
reportInterval = flag.Duration("report-interval", time.Second*10, "specify the druation interval to report the progress")
driverLog = ctrl.Log.WithName("driver")
)

// init registers the repeatable --task-recipe flag. It is bound with
// flag.Var because taskRecipes is a custom flag value: each occurrence of
// the flag is handed to its Set method, which appends it to the list.
func init() {
flag.Var(&taskRecipes, "task-recipe", "task recipe")
}

func main() {
opts := zap.Options{
Development: true,
Expand Down Expand Up @@ -86,6 +104,7 @@ func main() {
GrpcPort: *grpcPort,
DetectDevice: *detectDevice,
Logger: driverLog,
TaskRecipes: taskRecipes,
Args: args,
ReportInterval: *reportInterval,
}
Expand Down
90 changes: 90 additions & 0 deletions cmd/lmes_driver/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
Copyright 2024.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package main

import (
"context"
"flag"
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/trustyai-explainability/trustyai-service-operator/controllers/lmes/driver"
"sigs.k8s.io/controller-runtime/pkg/log/zap"
)

// Test_ArgParsing verifies that the driver's command-line flags, including
// the repeatable --task-recipe flag, are parsed into the package-level flag
// variables, that the parsed values are carried over into
// driver.DriverOption, and that everything after "--" is kept as the
// command arguments.
func Test_ArgParsing(t *testing.T) {
	const (
		recipe1 = "card=unitxt.card1,template=unitxt.template,metrics=[unitxt.metric1,unitxt.metric2],format=unitxt.format,num_demos=5,demos_pool_size=10"
		recipe2 = "card=unitxt.card2,template=unitxt.template2,metrics=[unitxt.metric3,unitxt.metric4],format=unitxt.format,num_demos=5,demos_pool_size=10"
	)

	// os.Args is process-wide state: restore it when the test finishes so
	// the fake command line does not leak into other tests of this package.
	savedArgs := os.Args
	t.Cleanup(func() { os.Args = savedArgs })

	os.Args = []string{
		"/opt/app-root/src/bin/driver",
		"--job-namespace", "default",
		"--job-name", "test",
		"--grpc-service", "grpc-service.test.svc",
		"--grpc-port", "8088",
		"--output-path", "/opt/app-root/src/output",
		"--detect-device",
		"--report-interval", "10s",
		"--task-recipe", recipe1,
		"--task-recipe", recipe2,
		"--",
		"sh", "-c", "python",
	}

	// Mirror main(): the zap options register their flags on the default
	// flag set before it is parsed.
	opts := zap.Options{
		Development: true,
	}
	opts.BindFlags(flag.CommandLine)

	flag.Parse()

	args := flag.Args()

	assert.Equal(t, "default", *jobNameSpace)
	assert.Equal(t, "test", *jobName)
	assert.Equal(t, "grpc-service.test.svc", *grpcService)
	assert.Equal(t, 8088, *grpcPort)
	assert.Equal(t, "/opt/app-root/src/output", *outputPath)
	assert.Equal(t, true, *detectDevice)
	assert.Equal(t, time.Second*10, *reportInterval)
	// Each --task-recipe occurrence must be kept as its own entry, in order.
	assert.Equal(t, taskRecipeArg{recipe1, recipe2}, taskRecipes)

	dOption := driver.DriverOption{
		Context:        context.Background(),
		JobNamespace:   *jobNameSpace,
		JobName:        *jobName,
		OutputPath:     *outputPath,
		GrpcService:    *grpcService,
		GrpcPort:       *grpcPort,
		DetectDevice:   *detectDevice,
		Logger:         driverLog,
		TaskRecipes:    taskRecipes,
		Args:           args,
		ReportInterval: *reportInterval,
	}

	assert.Equal(t, []string{recipe1, recipe2}, dOption.TaskRecipes)

	// Everything after "--" is the command to run, not a driver flag.
	assert.Equal(t, []string{
		"sh", "-c", "python",
	}, dOption.Args)
}
Loading

0 comments on commit cf23c97

Please sign in to comment.