From 899bfb0ec5c14f37432ef1fb4246c8126fd46612 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Wed, 21 Aug 2024 12:54:52 +0300 Subject: [PATCH] COCOS-151 - Add compression/decompression option for CLI/Agent (#200) * on the fly compression Signed-off-by: Sammy Oina * rename file-hash to checksum Signed-off-by: Sammy Oina * check error properly Signed-off-by: Sammy Oina * fix lint Signed-off-by: Sammy Oina * fix connection handling Signed-off-by: Sammy Oina --------- Signed-off-by: Sammy Oina --- agent/algorithm/results.go | 59 ----------------- agent/algorithm/results_test.go | 3 +- agent/computations.go | 17 +++++ agent/service.go | 30 +++++---- cli/README.md | 32 ++++++++-- cli/{file_hash.go => checksum.go} | 18 ++---- cli/datasets.go | 39 ++++++++++-- internal/file.go | 37 +++++++++++ internal/zip.go | 102 ++++++++++++++++++++++++++++++ pkg/sdk/agent.go | 5 +- test/computations/main.go | 19 +++--- test/manual/agent-config/main.go | 58 +++++++++-------- 12 files changed, 287 insertions(+), 132 deletions(-) delete mode 100644 agent/algorithm/results.go rename cli/{file_hash.go => checksum.go} (57%) create mode 100644 internal/zip.go diff --git a/agent/algorithm/results.go b/agent/algorithm/results.go deleted file mode 100644 index 114cfe05..00000000 --- a/agent/algorithm/results.go +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) Ultraviolet -// SPDX-License-Identifier: Apache-2.0 -package algorithm - -import ( - "archive/zip" - "bytes" - "fmt" - "io" - "os" - "path/filepath" -) - -// ZipDirectory zips a directory and returns the zipped bytes. -func ZipDirectory() ([]byte, error) { - buf := new(bytes.Buffer) - zipWriter := zip.NewWriter(buf) - - err := filepath.Walk(ResultsDir, func(path string, info os.FileInfo, err error) error { - if err != nil { - return fmt.Errorf("error walking the path %q: %v", path, err) - } - - if info.IsDir() { - return nil - } - - relPath, err := filepath.Rel(ResultsDir, path) - if err != nil { - return fmt.Errorf("error getting relative path for %q: %v", path, err) - } - - file, err := os.Open(path) - if err != nil { - return fmt.Errorf("error opening file %q: %v", path, err) - } - defer file.Close() - - zipFile, err := zipWriter.Create(relPath) - if err != nil { - return fmt.Errorf("error creating zip file for %q: %v", path, err) - } - - if _, err = io.Copy(zipFile, file); err != nil { - return fmt.Errorf("error copying file %q to zip: %v", path, err) - } - - return err - }) - if err != nil { - return nil, err - } - - if err = zipWriter.Close(); err != nil { - return nil, err - } - - return buf.Bytes(), nil -} diff --git a/agent/algorithm/results_test.go b/agent/algorithm/results_test.go index c9745624..594086ab 100644 --- a/agent/algorithm/results_test.go +++ b/agent/algorithm/results_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/ultravioletrs/cocos/agent/algorithm" + "github.com/ultravioletrs/cocos/internal" ) func TestZipDirectory(t *testing.T) { @@ -73,7 +74,7 @@ func TestZipDirectory(t *testing.T) { } } - if _, err := algorithm.ZipDirectory(); err != nil { + if _, err := internal.ZipDirectoryToMemory(algorithm.ResultsDir); err != nil { t.Errorf("ZipDirectory() error = %v", err) } }) diff --git a/agent/computations.go b/agent/computations.go index 79d30a8c..88f2daf0 100644 --- a/agent/computations.go +++ b/agent/computations.go @@ -6,6 +6,8 @@ import ( "context" "encoding/json" "fmt" + + "google.golang.org/grpc/metadata" ) var _ fmt.Stringer = (*Datasets)(nil) @@ -69,3 +71,18 @@ func 
IndexFromContext(ctx context.Context) (int, bool) { index, ok := ctx.Value(ManifestIndexKey{}).(int) return index, ok } + +const DecompressKey = "decompress" + +func DecompressFromContext(ctx context.Context) bool { + vals := metadata.ValueFromIncomingContext(ctx, DecompressKey) + if len(vals) == 0 { + return false + } + + return vals[0] == "true" +} + +func DecompressToContext(ctx context.Context, decompress bool) context.Context { + return metadata.AppendToOutgoingContext(ctx, DecompressKey, fmt.Sprintf("%t", decompress)) +} diff --git a/agent/service.go b/agent/service.go index e8463033..0f1211ef 100644 --- a/agent/service.go +++ b/agent/service.go @@ -18,6 +18,7 @@ import ( "github.com/ultravioletrs/cocos/agent/algorithm/python" "github.com/ultravioletrs/cocos/agent/algorithm/wasm" "github.com/ultravioletrs/cocos/agent/events" + "github.com/ultravioletrs/cocos/internal" "golang.org/x/crypto/sha3" ) @@ -191,16 +192,22 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error { as.computation.Datasets = slices.Delete(as.computation.Datasets, i, i+1) - f, err := os.Create(fmt.Sprintf("%s/%s", algorithm.DatasetsDir, dataset.Filename)) - if err != nil { - return fmt.Errorf("error creating dataset file: %v", err) - } - - if _, err := f.Write(dataset.Dataset); err != nil { - return fmt.Errorf("error writing dataset to file: %v", err) - } - if err := f.Close(); err != nil { - return fmt.Errorf("error closing file: %v", err) + if DecompressFromContext(ctx) { + if err := internal.UnzipFromMemory(dataset.Dataset, algorithm.DatasetsDir); err != nil { + return fmt.Errorf("error decompressing dataset: %v", err) + } + } else { + f, err := os.Create(fmt.Sprintf("%s/%s", algorithm.DatasetsDir, dataset.Filename)) + if err != nil { + return fmt.Errorf("error creating dataset file: %v", err) + } + + if _, err := f.Write(dataset.Dataset); err != nil { + return fmt.Errorf("error writing dataset to file: %v", err) + } + if err := f.Close(); err != nil { + return fmt.Errorf("error closing file: %v", err) + } } matched = true @@ -212,7 +219,6 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error { return ErrUndeclaredDataset } - // Check if all datasets have been received if len(as.computation.Datasets) == 0 { as.sm.SendEvent(dataReceived) } @@ -288,7 +294,7 @@ func (as *agentService) runComputation() { return } - results, err := algorithm.ZipDirectory() + results, err := internal.ZipDirectoryToMemory(algorithm.ResultsDir) if err != nil { as.runError = err as.sm.logger.Warn(fmt.Sprintf("failed to zip results: %s", err.Error())) diff --git a/cli/README.md b/cli/README.md index 2fb59f1b..14d1d419 100644 --- a/cli/README.md +++ b/cli/README.md @@ -16,14 +16,14 @@ make cli Retrieves attestation information from the SEV guest and saves it to a file. To retrieve attestation from agent, use the following command: ```bash -./build/cocos-cli agent attestation get '' +./build/cocos-cli attestation get '' ``` #### Validate attestation Validates the retrieved attestation information against a specified policy and checks its authenticity. To validate and verify attestation from agent, use the following command: ```bash -./build/cocos-cli agent attestation validate '' --report_data '' +./build/cocos-cli attestation validate '' --report_data '' ``` ##### Flags - --config: Path to a JSON file containing the validation configuration. This can be used to override individual flags. 
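Aside (not part of the patch): the DecompressToContext/DecompressFromContext helpers added to agent/computations.go above carry the user's decompression choice as plain gRPC metadata under the "decompress" key. Below is a minimal sketch of how the two sides could interact; the client-to-server metadata hop is simulated in-process purely for illustration.

```go
// Sketch (not from the patch) of how the decompress flag travels as gRPC
// metadata between the CLI and the agent, using the helpers added in
// agent/computations.go. The client->server metadata hop is copied across
// by hand here purely for illustration.
package main

import (
	"context"
	"fmt"

	"google.golang.org/grpc/metadata"

	"github.com/ultravioletrs/cocos/agent"
)

func main() {
	// CLI side: record the user's choice in the outgoing context.
	ctx := agent.DecompressToContext(context.Background(), true)

	// In a real RPC, gRPC delivers the outgoing metadata to the server as
	// incoming metadata; simulate that delivery in-process.
	md, _ := metadata.FromOutgoingContext(ctx)
	serverCtx := metadata.NewIncomingContext(context.Background(), md)

	// Agent side: Data() consults this before deciding whether to unzip.
	fmt.Println(agent.DecompressFromContext(serverCtx)) // prints "true"
}
```

Because the value travels as ordinary outgoing metadata, the SDK change in pkg/sdk/agent.go below appends to the existing metadata rather than replacing it; calling metadata.NewOutgoingContext there would have discarded the decompress key the CLI just set.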
@@ -62,21 +62,41 @@ To validate and verify attestation from agent, use the following command: To upload an algorithm, use the following command: ```bash -./build/cocos-cli agent algo /path/to/algorithm +./build/cocos-cli algo /path/to/algorithm ``` +##### Flags +- -a, --algorithm string Algorithm type to run (default "bin") +- --python-runtime string Python runtime to use (default "python3") +- -r, --requirements string Python requirements file + + #### Upload Dataset To upload a dataset, use the following command: ```bash -./build/cocos-cli agent data /path/to/dataset.csv +./build/cocos-cli data /path/to/dataset.csv ``` +Users can also upload directories, which are compressed in transit. Once received by the agent, they are stored as compressed files, or decompressed if the user passed the decompression flag. + +##### Flags +- -d, --decompress Decompress the dataset on agent + + + #### Retrieve result To retrieve the computation result, use the following command: ```bash -./build/cocos-cli agent result -``` \ No newline at end of file +./build/cocos-cli result +``` + +#### Checksum +When defining the manifest, dataset and algorithm checksums are required. They can be computed as shown below: + +```bash +./build/cocos-cli checksum +``` diff --git a/cli/file_hash.go b/cli/checksum.go similarity index 57% rename from cli/file_hash.go rename to cli/checksum.go index e786c264..c9b2cb0c 100644 --- a/cli/file_hash.go +++ b/cli/checksum.go @@ -3,32 +3,26 @@ package cli import ( - "encoding/hex" "log" - "os" "github.com/spf13/cobra" - "golang.org/x/crypto/sha3" + "github.com/ultravioletrs/cocos/internal" ) func (cli *CLI) NewFileHashCmd() *cobra.Command { return &cobra.Command{ - Use: "file-hash", + Use: "checksum", Short: "Compute the sha3-256 hash of a file", - Example: "file-hash ", + Example: "checksum ", Args: cobra.ExactArgs(1), Run: func(cmd *cobra.Command, args []string) { - fileName := args[0] + path := args[0] - file, err := os.ReadFile(fileName) + hash, err := internal.ChecksumHex(path) if err != nil { - log.Fatalf("Error reading dataset file: %v", err) + log.Fatalf("Error computing hash: %v", err) } - hashBytes := sha3.Sum256(file) - - hash := hex.EncodeToString(hashBytes[:]) - log.Println("Hash of file:", hash) }, } diff --git a/cli/datasets.go b/cli/datasets.go index 1bbe4058..875f1aa2 100644 --- a/cli/datasets.go +++ b/cli/datasets.go @@ -3,6 +3,7 @@ package cli import ( + "context" "crypto/x509" "encoding/pem" "log" @@ -11,27 +12,45 @@ import ( "github.com/spf13/cobra" "github.com/ultravioletrs/cocos/agent" + "github.com/ultravioletrs/cocos/internal" + "google.golang.org/grpc/metadata" ) +var decompressDataset bool + func (cli *CLI) NewDatasetsCmd() *cobra.Command { - return &cobra.Command{ + cmd := &cobra.Command{ Use: "data", Short: "Upload a dataset", Example: "data ", Args: cobra.ExactArgs(2), Run: func(cmd *cobra.Command, args []string) { - datasetFile := args[0] + datasetPath := args[0] - log.Println("Uploading dataset:", datasetFile) + log.Println("Uploading dataset:", datasetPath) - dataset, err := os.ReadFile(datasetFile) + f, err := os.Stat(datasetPath) if err != nil { log.Fatalf("Error reading dataset file: %v", err) } + var dataset []byte + + if f.IsDir() { + dataset, err = internal.ZipDirectoryToMemory(datasetPath) + if err != nil { + log.Fatalf("Error zipping dataset directory: %v", err) + } + } else { + dataset, err = os.ReadFile(datasetPath) + if err != nil { + log.Fatalf("Error reading dataset file: %v", err) + } + } + dataReq := agent.Dataset{ Dataset: dataset, - 
Filename: path.Base(datasetFile), + Filename: path.Base(datasetPath), } privKeyFile, err := os.ReadFile(args[1]) @@ -43,13 +62,17 @@ func (cli *CLI) NewDatasetsCmd() *cobra.Command { privKey := decodeKey(pemBlock) - if err := cli.agentSDK.Data(cmd.Context(), dataReq, privKey); err != nil { + ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string))) + if err := cli.agentSDK.Data(addDatasetMetadata(ctx), dataReq, privKey); err != nil { log.Fatalf("Error uploading dataset: %v", err) } log.Println("Successfully uploaded dataset") }, } + + cmd.Flags().BoolVarP(&decompressDataset, "decompress", "d", false, "Decompress the dataset on agent") + return cmd } func decodeKey(b *pem.Block) interface{} { @@ -74,3 +97,7 @@ func decodeKey(b *pem.Block) interface{} { return nil } } + +func addDatasetMetadata(ctx context.Context) context.Context { + return agent.DecompressToContext(ctx, decompressDataset) +} diff --git a/internal/file.go b/internal/file.go index 50416f51..100dfbb9 100644 --- a/internal/file.go +++ b/internal/file.go @@ -3,9 +3,12 @@ package internal import ( + "encoding/hex" "io" "os" "path/filepath" + + "golang.org/x/crypto/sha3" ) // CopyFile copies a file from srcPath to dstPath. @@ -46,3 +49,37 @@ func DeleteFilesInDir(dirPath string) error { return nil } + +// Checksum calculates the SHA3-256 checksum of the file or directory at path. +func Checksum(path string) ([]byte, error) { + file, err := os.Stat(path) + if err != nil { + return nil, err + } + + if file.IsDir() { + f, err := ZipDirectoryToMemory(path) + if err != nil { + return nil, err + } + sum := sha3.Sum256(f) + return sum[:], nil + } else { + f, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + sum := sha3.Sum256(f) + return sum[:], nil + } +} + +// ChecksumHex calculates the SHA3-256 checksum of the file or directory at path and returns it as a hex-encoded string. 
+func ChecksumHex(path string) (string, error) { + sum, err := Checksum(path) + if err != nil { + return "", err + } + return hex.EncodeToString(sum), nil +} diff --git a/internal/zip.go b/internal/zip.go new file mode 100644 index 00000000..25bce054 --- /dev/null +++ b/internal/zip.go @@ -0,0 +1,102 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package internal + +import ( + "archive/zip" + "bytes" + "io" + "os" + "path/filepath" +) + +func ZipDirectoryToMemory(sourceDir string) ([]byte, error) { + buf := new(bytes.Buffer) + zipWriter := zip.NewWriter(buf) + + err := filepath.Walk(sourceDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if info.IsDir() { + return nil + } + + relPath, err := filepath.Rel(sourceDir, path) + if err != nil { + return err + } + + zipHeader, err := zip.FileInfoHeader(info) + if err != nil { + return err + } + zipHeader.Name = relPath + + zipWriterEntry, err := zipWriter.CreateHeader(zipHeader) + if err != nil { + return err + } + + fileToZip, err := os.Open(path) + if err != nil { + return err + } + defer fileToZip.Close() + + _, err = io.Copy(zipWriterEntry, fileToZip) + return err + }) + if err != nil { + zipWriter.Close() + return nil, err + } + + if err := zipWriter.Close(); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +func UnzipFromMemory(zipData []byte, targetDir string) error { + reader := bytes.NewReader(zipData) + zipReader, err := zip.NewReader(reader, int64(len(zipData))) + if err != nil { + return err + } + + for _, file := range zipReader.File { + filePath := filepath.Join(targetDir, file.Name) + + if file.FileInfo().IsDir() { + if err := os.MkdirAll(filePath, os.ModePerm); err != nil { + return err + } + continue + } + + if err := os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil { + return err + } + + srcFile, err := file.Open() + if err != nil { + return err + } + defer srcFile.Close() + + dstFile, err := os.Create(filePath) + if err != nil { + return err + } + defer dstFile.Close() + + if _, err := io.Copy(dstFile, srcFile); err != nil { + return err + } + } + + return nil +} diff --git a/pkg/sdk/agent.go b/pkg/sdk/agent.go index fd2d75a8..a1ad99de 100644 --- a/pkg/sdk/agent.go +++ b/pkg/sdk/agent.go @@ -82,7 +82,10 @@ func (sdk *agentSDK) Data(ctx context.Context, dataset agent.Dataset, privKey an return err } - ctx = metadata.NewOutgoingContext(ctx, md) + for k, v := range md { + ctx = metadata.AppendToOutgoingContext(ctx, k, v[0]) + } + stream, err := sdk.client.Data(ctx) if err != nil { sdk.logger.Error("Failed to call Data RPC") diff --git a/test/computations/main.go b/test/computations/main.go index e3006451..bcb422ee 100644 --- a/test/computations/main.go +++ b/test/computations/main.go @@ -13,11 +13,11 @@ import ( mglog "github.com/absmach/magistrala/logger" "github.com/caarlos0/env/v11" + "github.com/ultravioletrs/cocos/internal" "github.com/ultravioletrs/cocos/internal/server" grpcserver "github.com/ultravioletrs/cocos/internal/server/grpc" managergrpc "github.com/ultravioletrs/cocos/manager/api/grpc" "github.com/ultravioletrs/cocos/pkg/manager" - "golang.org/x/crypto/sha3" "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -44,11 +44,6 @@ type svc struct { func (s *svc) Run(ipAdress string, reqChan chan *manager.ServerStreamMessage, auth credentials.AuthInfo) { s.logger.Debug(fmt.Sprintf("received who am on ip address %s", ipAdress)) - algo, err := os.ReadFile(algoPath) - if err != 
nil { - s.logger.Error(fmt.Sprintf("failed to read algorithm file: %s", err)) - return - } pubKey, err := os.ReadFile(pubKeyFile) if err != nil { @@ -63,16 +58,20 @@ func (s *svc) Run(ipAdress string, reqChan chan *manager.ServerStreamMessage, au s.logger.Error(fmt.Sprintf("data file does not exist: %s", dataPath)) return } - data, err := os.ReadFile(dataPath) + dataHash, err := internal.Checksum(dataPath) if err != nil { - s.logger.Error(fmt.Sprintf("failed to read data file: %s", err)) + s.logger.Error(fmt.Sprintf("failed to calculate checksum: %s", err)) return } - dataHash := sha3.Sum256(data) + datasets = append(datasets, &manager.Dataset{Hash: dataHash[:], UserKey: pubPem.Bytes}) } - algoHash := sha3.Sum256(algo) + algoHash, err := internal.Checksum(algoPath) + if err != nil { + s.logger.Error(fmt.Sprintf("failed to calculate checksum: %s", err)) + return + } reqChan <- &manager.ServerStreamMessage{ Message: &manager.ServerStreamMessage_RunReq{ diff --git a/test/manual/agent-config/main.go b/test/manual/agent-config/main.go index 27ecd1fe..4a286891 100644 --- a/test/manual/agent-config/main.go +++ b/test/manual/agent-config/main.go @@ -10,15 +10,16 @@ import ( "encoding/pem" "fmt" "log" + "net" "os" "strconv" "github.com/mdlayher/vsock" "github.com/ultravioletrs/cocos/agent" + "github.com/ultravioletrs/cocos/internal" "github.com/ultravioletrs/cocos/manager" "github.com/ultravioletrs/cocos/manager/qemu" pkgmanager "github.com/ultravioletrs/cocos/pkg/manager" - "golang.org/x/crypto/sha3" "google.golang.org/protobuf/proto" ) @@ -35,21 +36,19 @@ func main() { } attestedTLS := attestedTLSParam - algo, err := os.ReadFile(algoPath) + pubKey, err := os.ReadFile(pubKeyFile) if err != nil { - log.Fatalf(fmt.Sprintf("failed to read algorithm file: %s", err)) + log.Fatalf(fmt.Sprintf("failed to read public key file: %s", err)) } - data, err := os.ReadFile(dataPath) + pubPem, _ := pem.Decode(pubKey) + algoHash, err := internal.Checksum(algoPath) if err != nil { - log.Fatalf(fmt.Sprintf("failed to read data file: %s", err)) + log.Fatalf(fmt.Sprintf("failed to calculate checksum: %s", err)) } - pubKey, err := os.ReadFile(pubKeyFile) + dataHash, err := internal.Checksum(dataPath) if err != nil { - log.Fatalf(fmt.Sprintf("failed to read public key file: %s", err)) + log.Fatalf(fmt.Sprintf("failed to calculate checksum: %s", err)) } - pubPem, _ := pem.Decode(pubKey) - algoHash := sha3.Sum256(algo) - dataHash := sha3.Sum256(data) l, err := vsock.Listen(manager.ManagerVsockPort, nil) if err != nil { @@ -57,8 +56,8 @@ func main() { } ac := agent.Computation{ ID: "123", - Datasets: agent.Datasets{agent.Dataset{Hash: dataHash, UserKey: pubPem.Bytes}}, - Algorithm: agent.Algorithm{Hash: algoHash, UserKey: pubPem.Bytes}, + Datasets: agent.Datasets{agent.Dataset{Hash: [32]byte(dataHash), UserKey: pubPem.Bytes}}, + Algorithm: agent.Algorithm{Hash: [32]byte(algoHash), UserKey: pubPem.Bytes}, ResultConsumers: []agent.ResultConsumer{{UserKey: pubPem.Bytes}}, AgentConfig: agent.AgentConfig{ LogLevel: "debug", @@ -66,7 +65,9 @@ func main() { AttestedTls: attestedTLS, }, } - fmt.Println(SendAgentConfig(3, ac)) + if err := SendAgentConfig(3, ac); err != nil { + log.Fatal(err) + } for { conn, err := l.Accept() @@ -74,18 +75,7 @@ func main() { log.Println(err) continue } - b := make([]byte, 1024) - n, err := conn.Read(b) - if err != nil { - log.Println(err) - continue - } - conn.Close() - var mes pkgmanager.ClientStreamMessage - if err := proto.Unmarshal(b[:n], &mes); err != nil { - log.Println(err) - } - 
fmt.Println(mes.String()) + go handleConnections(conn) } } @@ -109,3 +99,21 @@ func SendAgentConfig(cid uint32, ac agent.Computation) error { } return nil } + +func handleConnections(conn net.Conn) { + defer conn.Close() + for { + b := make([]byte, 1024) + n, err := conn.Read(b) + if err != nil { + log.Println(err) + return + } + var message pkgmanager.ClientStreamMessage + if err := proto.Unmarshal(b[:n], &message); err != nil { + log.Println(err) + return + } + fmt.Println(message.String()) + } +}
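Aside (not part of the patch): the new internal helpers are meant to compose — the CLI zips a directory in memory before upload, the checksum command hashes the same zipped bytes for the manifest, and the agent can unzip the archive on receipt. The sketch below shows that round trip under stated assumptions: the paths are placeholders, and because Go "internal" packages are only importable from within the cocos module, such a file would have to live inside the repository (for example under test/).

```go
// Sketch of the zip -> checksum -> unzip round trip enabled by this patch.
// Paths are placeholders; this must be compiled inside the cocos module,
// since "internal" packages cannot be imported from other modules.
package main

import (
	"fmt"
	"log"

	"github.com/ultravioletrs/cocos/internal"
)

func main() {
	// Compress a dataset directory in memory, as the CLI does before upload.
	zipped, err := internal.ZipDirectoryToMemory("./testdata/dataset")
	if err != nil {
		log.Fatalf("zip failed: %v", err)
	}

	// Checksum zips directories the same way, so this is the value that
	// would go into the computation manifest for that dataset.
	sum, err := internal.ChecksumHex("./testdata/dataset")
	if err != nil {
		log.Fatalf("checksum failed: %v", err)
	}
	fmt.Printf("zipped %d bytes, sha3-256 %s\n", len(zipped), sum)

	// Restore the archive, as the agent does when the decompress flag is set.
	if err := internal.UnzipFromMemory(zipped, "/tmp/restored-dataset"); err != nil {
		log.Fatalf("unzip failed: %v", err)
	}
}
```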