diff --git a/data/common.go b/data/common.go index f5fe830..090eee4 100644 --- a/data/common.go +++ b/data/common.go @@ -1,8 +1,9 @@ // This module contains Flyte CoPilot related code. // Currently it only has 2 utilities - downloader and an uploader. // Usage Downloader: -// downloader := NewDownloader(...) -// downloader.DownloadInputs(...) // will recursively download all inputs +// +// downloader := NewDownloader(...) +// downloader.DownloadInputs(...) // will recursively download all inputs // // Usage uploader: // uploader := NewUploader(...) diff --git a/data/upload.go b/data/upload.go index 34049c1..adb0c3d 100644 --- a/data/upload.go +++ b/data/upload.go @@ -9,6 +9,7 @@ import ( "path" "path/filepath" "reflect" + "strings" "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" "github.com/flyteorg/flytestdlib/futures" @@ -56,6 +57,39 @@ func (u Uploader) handleSimpleType(_ context.Context, t core.SimpleType, filePat return coreutils.MakeLiteralForSimpleType(t, string(b)) } +func (u Uploader) handleCollectionType(_ context.Context, t *core.LiteralType, filePath string) (*core.Literal, error) { + fpath, info, err := IsFileReadable(filePath, true) + if err != nil { + return nil, err + } + if info.IsDir() { + return nil, fmt.Errorf("expected file for type [%s], found dir at path [%s]", t.String(), filePath) + } + if info.Size() > maxPrimitiveSize { + return nil, fmt.Errorf("maximum allowed filesize is [%d], but found [%d]", maxPrimitiveSize, info.Size()) + } + b, err := ioutil.ReadFile(fpath) + if err != nil { + return nil, err + } + literalString := strings.Split(strings.ReplaceAll(string(b), " ", ""), ",") + literals := make([]*core.Literal, 0, len(literalString)) + for _, val := range literalString { + lv, err := coreutils.MakeLiteralForType(t.GetCollectionType(), val) + if err != nil { + return nil, err + } + literals = append(literals, lv) + } + res := &core.Literal{} + res.Value = &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: literals, + }, + } + return res, nil +} + func (u Uploader) handleBlobType(ctx context.Context, localPath string, toPath storage.DataReference) (*core.Literal, error) { fpath, info, err := IsFileReadable(localPath, true) if err != nil { @@ -158,6 +192,10 @@ func (u Uploader) RecursiveUpload(ctx context.Context, vars *core.VariableMap, f varFutures[varName] = futures.NewAsyncFuture(childCtx, func(ctx2 context.Context) (interface{}, error) { return u.handleSimpleType(ctx2, varType.GetSimple(), varPath) }) + case *core.LiteralType_CollectionType: + varFutures[varName] = futures.NewAsyncFuture(childCtx, func(ctx2 context.Context) (interface{}, error) { + return u.handleCollectionType(ctx2, varType, varPath) + }) default: return fmt.Errorf("currently CoPilot uploader does not support [%s], system error", varType) } diff --git a/data/upload_test.go b/data/upload_test.go index 2e3c84e..1e781ec 100644 --- a/data/upload_test.go +++ b/data/upload_test.go @@ -60,4 +60,41 @@ func TestUploader_RecursiveUpload(t *testing.T) { assert.NoError(t, err) assert.Equal(t, string(data), string(b), "content dont match") }) + + t.Run("upload-collection", func(t *testing.T) { + tmpDir, err := ioutil.TempDir(tmpFolderLocation, tmpPrefix) + assert.NoError(t, err) + defer func() { + assert.NoError(t, os.RemoveAll(tmpDir)) + }() + + lt := core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}} + vmap := &core.VariableMap{ + Variables: map[string]*core.Variable{ + "y": { + Type: &core.LiteralType{Type: &core.LiteralType_CollectionType{CollectionType: <}}, + }, + }, + } + + data := []byte("1, 2, 3, 4") + assert.NoError(t, ioutil.WriteFile(path.Join(tmpDir, "y"), data, os.ModePerm)) + fmt.Printf("Written to %s ", path.Join(tmpDir, "y")) + + store, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + + outputRef := storage.DataReference("output") + rawRef := storage.DataReference("raw") + u := NewUploader(context.TODO(), store, core.DataLoadingConfig_JSON, core.IOStrategy_UPLOAD_ON_EXIT, "error") + assert.NoError(t, u.RecursiveUpload(context.TODO(), vmap, tmpDir, outputRef, rawRef)) + + outputs := &core.LiteralMap{} + assert.NoError(t, store.ReadProtobuf(context.TODO(), outputRef, outputs)) + assert.Len(t, outputs.Literals, 1) + assert.NotNil(t, outputs.Literals["y"]) + assert.NotNil(t, outputs.Literals["y"].GetCollection()) + assert.NotNil(t, outputs.Literals["y"].GetCollection().GetLiterals()) + assert.NotNil(t, outputs.Literals["y"].GetCollection().GetLiterals()[0].GetScalar()) + }) }