diff --git a/.github/workflows/checkproto.yaml b/.github/workflows/checkproto.yaml index 798625d5..51be8213 100644 --- a/.github/workflows/checkproto.yaml +++ b/.github/workflows/checkproto.yaml @@ -33,7 +33,7 @@ jobs: - name: Set up protoc run: | - PROTOC_VERSION=25.3 + PROTOC_VERSION=27.2 PROTOC_GEN_VERSION=v1.34.2 PROTOC_GRPC_VERSION=v1.4.0 diff --git a/.gitignore b/.gitignore index f34a3e2e..457ef1d8 100644 --- a/.gitignore +++ b/.gitignore @@ -10,5 +10,5 @@ cmd/manager/tmp *.pem dist/ -result.bin +result.zip *.spec diff --git a/agent/agent.pb.go b/agent/agent.pb.go index 492a75d4..959991e8 100644 --- a/agent/agent.pb.go +++ b/agent/agent.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.34.2 -// protoc v4.25.3 +// protoc v5.27.2 // source: agent/agent.proto package agent diff --git a/agent/agent_grpc.pb.go b/agent/agent_grpc.pb.go index 17358955..ac2bab2e 100644 --- a/agent/agent_grpc.pb.go +++ b/agent/agent_grpc.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.4.0 -// - protoc v4.25.3 +// - protoc v5.27.2 // source: agent/agent.proto package agent diff --git a/agent/algorithm/algorithm.go b/agent/algorithm/algorithm.go index 1e6bd9b0..9c4b5319 100644 --- a/agent/algorithm/algorithm.go +++ b/agent/algorithm/algorithm.go @@ -13,7 +13,11 @@ type AlgorithType string const ( AlgoTypeBin AlgorithType = "bin" AlgoTypePython AlgorithType = "python" + AlgoTypeWasm AlgorithType = "wasm" AlgoTypeKey = "algo_type" + + ResultsDir = "results" + DatasetsDir = "datasets" ) func AlgorithmTypeToContext(ctx context.Context, algoType string) context.Context { @@ -27,8 +31,5 @@ func AlgorithmTypeFromContext(ctx context.Context) string { // Algorithm is an interface that specifies the API for an algorithm. type Algorithm interface { // Run executes the algorithm and returns the result. - Run() ([]byte, error) - - // Add dataset to algorithm. - AddDataset(dataset string) + Run() error } diff --git a/agent/algorithm/binary/binary.go b/agent/algorithm/binary/binary.go index d487dd60..df6766c1 100644 --- a/agent/algorithm/binary/binary.go +++ b/agent/algorithm/binary/binary.go @@ -6,77 +6,40 @@ import ( "fmt" "io" "log/slog" - "os" "os/exec" "github.com/ultravioletrs/cocos/agent/algorithm" "github.com/ultravioletrs/cocos/agent/events" - "github.com/ultravioletrs/cocos/pkg/socket" ) -const socketPath = "unix_socket" - var _ algorithm.Algorithm = (*binary)(nil) type binary struct { algoFile string - datasets []string - logger *slog.Logger stderr io.Writer stdout io.Writer } -func New(logger *slog.Logger, eventsSvc events.Service, algoFile string) algorithm.Algorithm { +func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string) algorithm.Algorithm { return &binary{ algoFile: algoFile, - logger: logger, stderr: &algorithm.Stderr{Logger: logger, EventSvc: eventsSvc}, stdout: &algorithm.Stdout{Logger: logger}, } } -func (b *binary) AddDataset(dataset string) { - b.datasets = append(b.datasets, dataset) -} - -func (b *binary) Run() ([]byte, error) { - defer os.Remove(b.algoFile) - defer func() { - for _, file := range b.datasets { - os.Remove(file) - } - }() - listener, err := socket.StartUnixSocketServer(socketPath) - if err != nil { - return nil, fmt.Errorf("error creating stdout pipe: %v", err) - } - defer listener.Close() - - // Create channels for received data and errors - dataChannel := make(chan []byte) - errorChannel := make(chan error) - - var result []byte - - go socket.AcceptConnection(listener, dataChannel, errorChannel) - - args := append([]string{socketPath}, b.datasets...) - cmd := exec.Command(b.algoFile, args...) +func (b *binary) Run() error { + cmd := exec.Command(b.algoFile) cmd.Stderr = b.stderr cmd.Stdout = b.stdout if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("error starting algorithm: %v", err) + return fmt.Errorf("error starting algorithm: %v", err) } if err := cmd.Wait(); err != nil { - return nil, fmt.Errorf("algorithm execution error: %v", err) + return fmt.Errorf("algorithm execution error: %v", err) } - select { - case result = <-dataChannel: - return result, nil - case err = <-errorChannel: - return nil, fmt.Errorf("error receiving data: %v", err) - } + return nil } diff --git a/agent/algorithm/python/python.go b/agent/algorithm/python/python.go index 0fd88313..a2a2afdb 100644 --- a/agent/algorithm/python/python.go +++ b/agent/algorithm/python/python.go @@ -13,12 +13,10 @@ import ( "github.com/ultravioletrs/cocos/agent/algorithm" "github.com/ultravioletrs/cocos/agent/events" - "github.com/ultravioletrs/cocos/pkg/socket" "google.golang.org/grpc/metadata" ) const ( - socketPath = "unix_socket" PyRuntime = "python3" pyRuntimeKey = "python_runtime" ) @@ -35,18 +33,15 @@ var _ algorithm.Algorithm = (*python)(nil) type python struct { algoFile string - datasets []string - logger *slog.Logger stderr io.Writer stdout io.Writer runtime string requirementsFile string } -func New(logger *slog.Logger, eventsSvc events.Service, runtime, requirementsFile, algoFile string) algorithm.Algorithm { +func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, runtime, requirementsFile, algoFile string) algorithm.Algorithm { p := &python{ algoFile: algoFile, - logger: logger, stderr: &algorithm.Stderr{Logger: logger, EventSvc: eventsSvc}, stdout: &algorithm.Stdout{Logger: logger}, requirementsFile: requirementsFile, @@ -59,17 +54,13 @@ func New(logger *slog.Logger, eventsSvc events.Service, runtime, requirementsFil return p } -func (p *python) AddDataset(dataset string) { - p.datasets = append(p.datasets, dataset) -} - -func (p *python) Run() ([]byte, error) { +func (p *python) Run() error { venvPath := "venv" createVenvCmd := exec.Command(p.runtime, "-m", "venv", venvPath) createVenvCmd.Stderr = p.stderr createVenvCmd.Stdout = p.stdout if err := createVenvCmd.Run(); err != nil { - return nil, fmt.Errorf("error creating virtual environment: %v", err) + return fmt.Errorf("error creating virtual environment: %v", err) } pythonPath := filepath.Join(venvPath, "bin", "python") @@ -79,48 +70,25 @@ func (p *python) Run() ([]byte, error) { rcmd.Stderr = p.stderr rcmd.Stdout = p.stdout if err := rcmd.Run(); err != nil { - return nil, fmt.Errorf("error installing requirements: %v", err) - } - } - - defer os.Remove(p.algoFile) - defer func() { - for _, file := range p.datasets { - os.Remove(file) + return fmt.Errorf("error installing requirements: %v", err) } - }() - defer os.RemoveAll(venvPath) - - listener, err := socket.StartUnixSocketServer(socketPath) - if err != nil { - return nil, fmt.Errorf("error creating stdout pipe: %v", err) } - defer listener.Close() - - dataChannel := make(chan []byte) - errorChannel := make(chan error) - - var result []byte - go socket.AcceptConnection(listener, dataChannel, errorChannel) - - args := append([]string{p.algoFile, socketPath}, p.datasets...) - cmd := exec.Command(pythonPath, args...) + cmd := exec.Command(pythonPath, p.algoFile) cmd.Stderr = p.stderr cmd.Stdout = p.stdout if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("error starting algorithm: %v", err) + return fmt.Errorf("error starting algorithm: %v", err) } if err := cmd.Wait(); err != nil { - return nil, fmt.Errorf("algorithm execution error: %v", err) + return fmt.Errorf("algorithm execution error: %v", err) } - select { - case result = <-dataChannel: - return result, nil - case err = <-errorChannel: - return nil, fmt.Errorf("error receiving data: %v", err) + if err := os.RemoveAll(venvPath); err != nil { + return fmt.Errorf("error removing virtual environment: %v", err) } + + return nil } diff --git a/agent/algorithm/results.go b/agent/algorithm/results.go new file mode 100644 index 00000000..114cfe05 --- /dev/null +++ b/agent/algorithm/results.go @@ -0,0 +1,59 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package algorithm + +import ( + "archive/zip" + "bytes" + "fmt" + "io" + "os" + "path/filepath" +) + +// ZipDirectory zips a directory and returns the zipped bytes. +func ZipDirectory() ([]byte, error) { + buf := new(bytes.Buffer) + zipWriter := zip.NewWriter(buf) + + err := filepath.Walk(ResultsDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return fmt.Errorf("error walking the path %q: %v", path, err) + } + + if info.IsDir() { + return nil + } + + relPath, err := filepath.Rel(ResultsDir, path) + if err != nil { + return fmt.Errorf("error getting relative path for %q: %v", path, err) + } + + file, err := os.Open(path) + if err != nil { + return fmt.Errorf("error opening file %q: %v", path, err) + } + defer file.Close() + + zipFile, err := zipWriter.Create(relPath) + if err != nil { + return fmt.Errorf("error creating zip file for %q: %v", path, err) + } + + if _, err = io.Copy(zipFile, file); err != nil { + return fmt.Errorf("error copying file %q to zip: %v", path, err) + } + + return err + }) + if err != nil { + return nil, err + } + + if err = zipWriter.Close(); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} diff --git a/agent/algorithm/results_test.go b/agent/algorithm/results_test.go new file mode 100644 index 00000000..c9745624 --- /dev/null +++ b/agent/algorithm/results_test.go @@ -0,0 +1,81 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package algorithm_test + +import ( + "os" + "testing" + + "github.com/ultravioletrs/cocos/agent/algorithm" +) + +func TestZipDirectory(t *testing.T) { + cases := []struct { + name string + directories []string + files []string + expected []string + }{ + { + name: "empty directory", + directories: []string{"testdata"}, + }, + { + name: "single file", + files: []string{"file1.txt"}, + }, + { + name: "directory with single file", + directories: []string{"testdata"}, + expected: []string{"testdata/file1.txt"}, + }, + { + name: "directory with multiple files", + directories: []string{"testdata"}, + expected: []string{ + "testdata/file1.txt", + "testdata/file2.txt", + "testdata/file3.txt", + }, + }, + { + name: "nested directories", + directories: []string{"testdata", "testdata/nested"}, + expected: []string{ + "testdata/nested/file1.txt", + "testdata/nested/file2.txt", + "testdata/nested/file3.txt", + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if err := os.Mkdir(algorithm.ResultsDir, 0o755); err != nil { + t.Fatalf("error creating results directory: %s", err.Error()) + } + defer func() { + if err := os.RemoveAll(algorithm.ResultsDir); err != nil { + t.Fatalf("error removing results directory and its contents: %s", err.Error()) + } + }() + + for _, dir := range tc.directories { + if dir != "" { + if err := os.Mkdir(algorithm.ResultsDir+"/"+dir, 0o755); err != nil { + t.Fatalf("error creating test directory: %s", err.Error()) + } + } + } + for _, file := range tc.files { + if _, err := os.Create(algorithm.ResultsDir + "/" + file); err != nil { + t.Fatalf("error creating test file: %s", err.Error()) + } + } + + if _, err := algorithm.ZipDirectory(); err != nil { + t.Errorf("ZipDirectory() error = %v", err) + } + }) + } +} diff --git a/agent/algorithm/wasm/wasm.go b/agent/algorithm/wasm/wasm.go new file mode 100644 index 00000000..9d8bc721 --- /dev/null +++ b/agent/algorithm/wasm/wasm.go @@ -0,0 +1,50 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package wasm + +import ( + "fmt" + "io" + "log/slog" + "os/exec" + + "github.com/ultravioletrs/cocos/agent/algorithm" + "github.com/ultravioletrs/cocos/agent/events" +) + +const wasmRuntime = "wasmedge" + +var mapDirOption = []string{"--dir", ".:" + algorithm.ResultsDir} + +var _ algorithm.Algorithm = (*wasm)(nil) + +type wasm struct { + algoFile string + stderr io.Writer + stdout io.Writer +} + +func NewAlgorithm(logger *slog.Logger, eventsSvc events.Service, algoFile string) algorithm.Algorithm { + return &wasm{ + algoFile: algoFile, + stderr: &algorithm.Stderr{Logger: logger, EventSvc: eventsSvc}, + stdout: &algorithm.Stdout{Logger: logger}, + } +} + +func (w *wasm) Run() error { + args := append(mapDirOption, w.algoFile) + cmd := exec.Command(wasmRuntime, args...) + cmd.Stderr = w.stderr + cmd.Stdout = w.stdout + + if err := cmd.Start(); err != nil { + return fmt.Errorf("error starting algorithm: %v", err) + } + + if err := cmd.Wait(); err != nil { + return fmt.Errorf("algorithm execution error: %v", err) + } + + return nil +} diff --git a/agent/service.go b/agent/service.go index 004f29d1..bc105ed6 100644 --- a/agent/service.go +++ b/agent/service.go @@ -16,6 +16,7 @@ import ( "github.com/ultravioletrs/cocos/agent/algorithm" "github.com/ultravioletrs/cocos/agent/algorithm/binary" "github.com/ultravioletrs/cocos/agent/algorithm/python" + "github.com/ultravioletrs/cocos/agent/algorithm/wasm" "github.com/ultravioletrs/cocos/agent/events" "golang.org/x/crypto/sha3" ) @@ -89,6 +90,7 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp svc.sm.StateFunctions[running] = svc.runComputation svc.computation = cmp + svc.sm.SendEvent(manifestReceived) return svc } @@ -131,21 +133,31 @@ func (as *agentService) Algo(ctx context.Context, algo Algorithm) error { switch algoType { case string(algorithm.AlgoTypeBin): - as.algorithm = binary.New(as.sm.logger, as.eventSvc, f.Name()) + as.algorithm = binary.NewAlgorithm(as.sm.logger, as.eventSvc, f.Name()) case string(algorithm.AlgoTypePython): - fr, err := os.CreateTemp("", "requirements.txt") - if err != nil { - return fmt.Errorf("error creating requirments file: %v", err) - } - - if _, err := fr.Write(algo.Requirements); err != nil { - return fmt.Errorf("error writing requirements to file: %v", err) - } - if err := fr.Close(); err != nil { - return fmt.Errorf("error closing file: %v", err) + var requirementsFile string + if len(algo.Requirements) > 0 { + fr, err := os.CreateTemp("", "requirements.txt") + if err != nil { + return fmt.Errorf("error creating requirments file: %v", err) + } + + if _, err := fr.Write(algo.Requirements); err != nil { + return fmt.Errorf("error writing requirements to file: %v", err) + } + if err := fr.Close(); err != nil { + return fmt.Errorf("error closing file: %v", err) + } + requirementsFile = fr.Name() } runtime := python.PythonRunTimeFromContext(ctx) - as.algorithm = python.New(as.sm.logger, as.eventSvc, runtime, fr.Name(), f.Name()) + as.algorithm = python.NewAlgorithm(as.sm.logger, as.eventSvc, runtime, requirementsFile, f.Name()) + case string(algorithm.AlgoTypeWasm): + as.algorithm = wasm.NewAlgorithm(as.sm.logger, as.eventSvc, f.Name()) + } + + if err := os.Mkdir(algorithm.DatasetsDir, 0o755); err != nil { + return fmt.Errorf("error creating datasets directory: %v", err) } if as.algorithm != nil { @@ -175,7 +187,7 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error { } as.computation.Datasets = slices.Delete(as.computation.Datasets, index, index+1) - f, err := os.CreateTemp("", fmt.Sprintf("dataset-%d", index)) + f, err := os.Create(fmt.Sprintf("%s/dataset-%d", algorithm.DatasetsDir, index)) if err != nil { return fmt.Errorf("error creating dataset file: %v", err) } @@ -187,8 +199,6 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error { return fmt.Errorf("error closing file: %v", err) } - as.algorithm.AddDataset(f.Name()) - if len(as.computation.Datasets) == 0 { as.sm.SendEvent(dataReceived) } @@ -233,16 +243,42 @@ func (as *agentService) runComputation() { as.publishEvent("starting", json.RawMessage{})() as.sm.logger.Debug("computation run started") defer as.sm.SendEvent(runComplete) + + if err := os.Mkdir(algorithm.ResultsDir, 0o755); err != nil { + as.runError = fmt.Errorf("error creating results directory: %s", err.Error()) + as.sm.logger.Warn(as.runError.Error()) + as.publishEvent("failed", json.RawMessage{})() + return + } + + defer func() { + if err := os.RemoveAll(algorithm.ResultsDir); err != nil { + as.sm.logger.Warn(fmt.Sprintf("error removing results directory and its contents: %s", err.Error())) + } + if err := os.RemoveAll(algorithm.DatasetsDir); err != nil { + as.sm.logger.Warn(fmt.Sprintf("error removing datasets directory and its contents: %s", err.Error())) + } + }() + as.publishEvent("in-progress", json.RawMessage{})() - result, err := as.algorithm.Run() + if err := as.algorithm.Run(); err != nil { + as.runError = err + as.sm.logger.Warn(fmt.Sprintf("failed to run computation: %s", err.Error())) + as.publishEvent("failed", json.RawMessage{})() + return + } + + results, err := algorithm.ZipDirectory() if err != nil { as.runError = err - as.sm.logger.Warn(fmt.Sprintf("computation failed with error: %s", err.Error())) + as.sm.logger.Warn(fmt.Sprintf("failed to zip results: %s", err.Error())) as.publishEvent("failed", json.RawMessage{})() return } + as.publishEvent("complete", json.RawMessage{})() - as.result = result + + as.result = results } func (as *agentService) publishEvent(status string, details json.RawMessage) func() { diff --git a/cli/result.go b/cli/result.go index a264c667..73898cfb 100644 --- a/cli/result.go +++ b/cli/result.go @@ -10,7 +10,7 @@ import ( "github.com/spf13/cobra" ) -const resultFilePath = "result.bin" +const resultFilePath = "result.zip" func (cli *CLI) NewResultsCmd() *cobra.Command { return &cobra.Command{ diff --git a/hal/linux/Config.in b/hal/linux/Config.in index cf84594f..03271381 100644 --- a/hal/linux/Config.in +++ b/hal/linux/Config.in @@ -1,2 +1,2 @@ source "$BR2_EXTERNAL_COCOS_PATH/package/agent/Config.in" -source "$BR2_EXTERNAL_COCOS_PATH/package/wasmtime/Config.in" +source "$BR2_EXTERNAL_COCOS_PATH/package/wasmedge/Config.in" diff --git a/hal/linux/package/wasmedge/Config.in b/hal/linux/package/wasmedge/Config.in new file mode 100644 index 00000000..c0a7be4e --- /dev/null +++ b/hal/linux/package/wasmedge/Config.in @@ -0,0 +1,6 @@ +config BR2_PACKAGE_WASMEDGE + bool "wasmedge" + default y + help + Wasmedge is a standalone runtime for WebAssembly. + https://wasmedge.org/docs/ diff --git a/hal/linux/package/wasmedge/wasmedge.mk b/hal/linux/package/wasmedge/wasmedge.mk new file mode 100644 index 00000000..9d62740d --- /dev/null +++ b/hal/linux/package/wasmedge/wasmedge.mk @@ -0,0 +1,8 @@ +WASMEDGE_DOWNLOAD_URL = https://raw.githubusercontent.com/WasmEdge/WasmEdge/master/utils/install.sh + +define WASMEDGE_INSTALL_TARGET_CMDS + curl -sSf $(WASMEDGE_DOWNLOAD_URL) | bash -s -- -p $(TARGET_DIR)/usr + echo "source /usr/env" >> $(TARGET_DIR)/etc/profile +endef + +$(eval $(generic-package)) diff --git a/hal/linux/package/wasmtime/Config.in b/hal/linux/package/wasmtime/Config.in deleted file mode 100644 index 06cbc61f..00000000 --- a/hal/linux/package/wasmtime/Config.in +++ /dev/null @@ -1,5 +0,0 @@ -config BR2_PACKAGE_WASMTIME - bool "wasmtime" - help - Wasmtime is a standalone runtime for WebAssembly. - https://github.com/bytecodealliance/wasmtime diff --git a/hal/linux/package/wasmtime/wasmtime.mk b/hal/linux/package/wasmtime/wasmtime.mk deleted file mode 100644 index 1ec16882..00000000 --- a/hal/linux/package/wasmtime/wasmtime.mk +++ /dev/null @@ -1,11 +0,0 @@ -WASMTIME_SITE = https://wasmtime.dev/install.sh - -define WASMTIME_BUILD_CMDS - curl $(WASMTIME_SITE) -sSf | bash -endef - -define WASMTIME_INSTALL_TARGET_CMDS - $(INSTALL) -D -m 0755 ~/.wasmtime/bin/wasmtime $(TARGET_DIR)/usr/bin/wasmtime -endef - -$(eval $(generic-package)) diff --git a/pkg/manager/manager.pb.go b/pkg/manager/manager.pb.go index 1c8a3478..1aa78cbd 100644 --- a/pkg/manager/manager.pb.go +++ b/pkg/manager/manager.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.34.2 -// protoc v4.25.3 +// protoc v5.27.2 // source: manager/manager.proto package manager diff --git a/pkg/manager/manager_grpc.pb.go b/pkg/manager/manager_grpc.pb.go index 2aee96a5..ee6b7b52 100644 --- a/pkg/manager/manager_grpc.pb.go +++ b/pkg/manager/manager_grpc.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.4.0 -// - protoc v4.25.3 +// - protoc v5.27.2 // source: manager/manager.proto package manager diff --git a/test/manual/algo/README.md b/test/manual/algo/README.md index 60ee1efb..a9cb79f6 100644 --- a/test/manual/algo/README.md +++ b/test/manual/algo/README.md @@ -1,17 +1,105 @@ # Algorithm -Agent accepts binaries programs. To use the python program you need to bundle or compile it. -In this example we'll use [pyinstaller](https://pypi.org/project/pyinstaller/) +Agent accepts binaries programs, python scripts, and wasm files. It runs them in a sandboxed environment and returns the output. -```shell -pip install pandas scikit-learn -pip install -U pyinstaller -pyinstaller --onefile lin_reg.py +## Python Example + +To test this examples work on your local machine, you need to install the following dependencies: + +```bash +pip install -r requirements.txt +``` + +This can be done in a virtual environment. + +```bash +python -m venv venv +source venv/bin/activate +pip install -r requirements.txt +``` + +To run the example, you can use the following command: + +```bash +python3 test/manual/algo/addition.py +``` + +The addition example is a simple algorithm to demonstrate you can run an algorithm without any external dependencies and input arguments. It returns the sum of two numbers. + +```bash +python3 test/manual/algo/lin_reg.py ``` -Make the binary static: +The linear regression example is a more complex algorithm that requires external dependencies.It returns a linear regression model trained on the iris dataset found [here](../data/) for demonstration purposes. + +```bash +python3 test/manual/algo/lin_reg.py predict result.zip test/manual/data +``` + +This will make inference on the results of the linear regression model. + +To run the examples in the agent, you can use the following command: + +```bash +go run ./test/computations/main.go ./test/manual/algo/lin_reg.py public.pem false ./test/manual/data/iris.csv +``` + +This command is run from the root directory of the project. This will start the computation server. + +In another window, you can run the following command: + +```bash +sudo MANAGER_QEMU_SMP_MAXCPUS=4 MANAGER_GRPC_URL=localhost:7001 MANAGER_LOG_LEVEL=debug MANAGER_QEMU_USE_SUDO=false MANAGER_QEMU_ENABLE_SEV=false MANAGER_QEMU_SEV_CBITPOS=51 MANAGER_QEMU_ENABLE_SEV_SNP=false MANAGER_QEMU_OVMF_CODE_FILE=/usr/share/edk2/x64/OVMF_CODE.fd MANAGER_QEMU_OVMF_VARS_FILE=/usr/share/edk2/x64/OVMF_VARS.fd go run main.go +``` + +This command is run from the [manager main directory](../../../cmd/manager/). This will start the manager. Make sure you have already built the [qemu image](../../../hal/linux/README.md). + +In another window, you can run the following command: -```shell -pip install staticx -staticx +```bash +./build/cocos-cli algo ./test/manual/algo/lin_reg.py ./private.pem -a python -r ./test/manual/algo/requirements.txt ``` + +make sure you have built the cocos-cli. This will upload the algorithm and the requirements file. + +Next we need to upload the dataset + +```bash +./build/cocos-cli data ./test/manual/data/iris.csv ./private.pem +``` + +After some time when the results are ready, you can run the following command to get the results: + +```bash +./build/cocos-cli results ./private.pem +``` + +This will return the results of the algorithm. + +To make inference on the results, you can use the following command: + +```bash +python3 test/manual/algo/lin_reg.py predict result.zip test/manual/data +``` + +For addition example, you can use the following command: + +```bash +go run ./test/computations/main.go ./test/manual/algo/addition.py public.pem false +``` + +```bash +./build/cocos-cli algo ./test/manual/algo/addition.py ./private.pem -a python +``` + +```bash +./build/cocos-cli results ./private.pem +``` + +## Wasm Example + +More information on how to run wasm files can be found [here](https://github.com/ultravioletrs/ai/tree/main/burn-algorithms). + +## Binary Example + +More information on how to run binary files can be found [here](https://github.com/ultravioletrs/ai/tree/main/burn-algorithms). diff --git a/test/manual/algo/addition.py b/test/manual/algo/addition.py index 7dfc42ae..9a4467b9 100644 --- a/test/manual/algo/addition.py +++ b/test/manual/algo/addition.py @@ -1,9 +1,14 @@ -import sys, io -import joblib -import socket +import os +import sys +import zipfile + +RESULTS_DIR = "results" +RESULTS_FILE = "result.txt" + class Computation: result = 0 + def __init__(self): """ Initializes a new instance of the Computation class. @@ -16,45 +21,35 @@ def compute(self, a, b): """ self.result = a + b - def send_result(self, socket_path): + def save_result(self): """ - Sends the result to a socket. + Sends the result to a file. """ - buffer = io.BytesIO() - try: - joblib.dump(self.result, buffer) - except Exception as e: - print("Failed to dump the result to the buffer: ", e) - return + os.makedirs(RESULTS_DIR) + except FileExistsError: + pass - data = buffer.getvalue() + with open(RESULTS_DIR + os.sep + RESULTS_FILE, "w") as f: + f.write(str(self.result)) - client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - try: - try: - client.connect(socket_path) - except Exception as e: - print("Failed to connect to the socket: ", e) - return - try: - client.send(data) - except Exception as e: - print("Failed to send data to the socket: ", e) - return - finally: - client.close() - def read_results_from_file(self, results_file): """ Reads the results from a file. """ - try: - results = joblib.load(results_file) - print("Results: ", results) - except Exception as e: - print("Failed to load results from file: ", e) - return + if results_file.endswith(".zip"): + try: + os.makedirs(RESULTS_DIR) + except FileExistsError: + pass + with zipfile.ZipFile(results_file, "r") as zip_ref: + zip_ref.extractall(RESULTS_DIR) + with open(RESULTS_FILE, "r") as f: + print(f.read()) + else: + with open(results_file, "r") as f: + print(f.read()) + if __name__ == "__main__": a = 5 @@ -62,15 +57,10 @@ def read_results_from_file(self, results_file): computation = Computation() if len(sys.argv) == 1: - print("Please provide a socket path or a file path") - exit(1) - - if sys.argv[1] == "test" and len(sys.argv) == 3: - computation.read_results_from_file(sys.argv[2]) - elif len(sys.argv) == 2: computation.compute(a, b) - computation.send_result(sys.argv[1]) + computation.save_result() + elif len(sys.argv) == 3 and sys.argv[1] == "test": + computation.read_results_from_file(sys.argv[2]) else: print("Invalid arguments") exit(1) - diff --git a/test/manual/algo/lin_reg.py b/test/manual/algo/lin_reg.py index 191acaa4..701e0e18 100644 --- a/test/manual/algo/lin_reg.py +++ b/test/manual/algo/lin_reg.py @@ -1,47 +1,116 @@ -import sys, io +import os +import sys import joblib -import socket - import pandas as pd from sklearn.model_selection import train_test_split from sklearn.linear_model import LogisticRegression +import zipfile +from sklearn import metrics + +DATA_DIR = "datasets" +RESULTS_DIR = "results" +RESULTS_FILE = "model.bin" + + +class Computation: + model = None + + def __init__(self): + """ + Initializes a new instance of the Computation class. + """ + pass + + def _read_csv(self, data_path=""): + """ + Reads the CSV file. + """ + files = os.listdir(data_path) + if len(files) != 1: + print("No files found in the directory") + exit(1) + csv_file_path = data_path + os.sep + files[0] + return pd.read_csv(csv_file_path) + + def compute(self): + """ + Trains a logistic regression model. + """ + iris = self._read_csv(DATA_DIR) + + # Droping the Species since we only need the measurements + X = iris.drop(["Species"], axis=1) + + # converting into numpy array and assigning petal length and petal width + X = X.to_numpy()[:, (3, 4)] + y = iris["Species"] + + X_train, _, y_train, _ = train_test_split(X, y, test_size=0.5, random_state=42) -csv_file_path = sys.argv[2] -iris = pd.read_csv(csv_file_path) + log_reg = LogisticRegression() + log_reg.fit(X_train, y_train) + self.model = log_reg -# Droping the Species since we only need the measurements -X = iris.drop(['Species'], axis=1) + def save_result(self): + """ + Sends the result to a file. + """ + try: + os.makedirs(RESULTS_DIR) + except FileExistsError: + pass -# converting into numpy array and assigning petal length and petal width -X = X.to_numpy()[:, (3,4)] -y = iris['Species'] + results_file = RESULTS_DIR + os.sep + RESULTS_FILE + joblib.dump(self.model, results_file) -# Splitting into train and test -X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.5, random_state=42) + def read_results_from_file(self, results_file): + """ + Reads the results from a file. + """ + if results_file.endswith(".zip"): + try: + os.makedirs(RESULTS_DIR) + except FileExistsError: + pass + with zipfile.ZipFile(results_file, "r") as zip_ref: + zip_ref.extractall(RESULTS_DIR) + self.model = joblib.load(RESULTS_DIR + os.sep + RESULTS_FILE) + else: + self.model = joblib.load(results_file) -log_reg = LogisticRegression() -log_reg.fit(X_train,y_train) + def predict(self, data_path=""): + iris = self._read_csv(data_path) -# Serialize the trained model to a byte buffer -model_buffer = io.BytesIO() -joblib.dump(log_reg, model_buffer) + # Droping the Species since we only need the measurements + X = iris.drop(["Species"], axis=1) -# Get the serialized model as a bytes object -model_bytes = model_buffer.getvalue() + # converting into numpy array and assigning petal length and petal width + X = X.to_numpy()[:, (3, 4)] + y = iris["Species"] -# Define the path for the Unix domain socket -socket_path = sys.argv[1] + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.5, random_state=42 + ) -# Create a Unix domain socket client -client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + training_prediction = self.model.predict(X_train) + test_prediction = self.model.predict(X_test) -try: - # Connect to the server - client.connect(socket_path) + print("Precision, Recall, Confusion matrix, in training\n") + print(metrics.classification_report(y_train, training_prediction, digits=3)) + print(metrics.confusion_matrix(y_train, training_prediction)) + print("Precision, Recall, Confusion matrix, in testing\n") + print(metrics.classification_report(y_test, test_prediction, digits=3)) + print(metrics.confusion_matrix(y_test, test_prediction)) - # Send the serialized model over the socket - client.send(model_bytes) -finally: - # Close the socket - client.close() +if __name__ == "__main__": + computation = Computation() + if len(sys.argv) == 1: + computation.compute() + computation.save_result() + elif len(sys.argv) == 4 and sys.argv[1] == "predict": + computation.read_results_from_file(sys.argv[2]) + computation.predict(sys.argv[3]) + else: + print("Invalid arguments") + exit(1) diff --git a/test/manual/algo/lin_reg_test.py b/test/manual/algo/lin_reg_test.py deleted file mode 100644 index 3165873c..00000000 --- a/test/manual/algo/lin_reg_test.py +++ /dev/null @@ -1,51 +0,0 @@ -import pandas as pd - -from sklearn.model_selection import train_test_split -from sklearn import metrics -import joblib - -import sys - -import warnings -warnings.filterwarnings("ignore", category=DeprecationWarning) -warnings.filterwarnings("ignore", category=UserWarning) - -csv_file_path = sys.argv[1] -model_filename = sys.argv[2] - -# Load the CSV file into a Pandas DataFrame -iris = pd.read_csv(csv_file_path) - -log_reg = joblib.load(model_filename) - -# Now you have the Iris dataset loaded into the iris_df DataFrame -print(iris.head()) # Display the first few rows of the DataFrame - -# Droping the Species since we only need the measurements -X = iris.drop(['Species'], axis=1) - -# converting into numpy array and assigning petal length and petal width -X = X.to_numpy()[:, (3,4)] -y = iris['Species'] - -# Splitting into train and test -X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.5, random_state=42) - -training_prediction = log_reg.predict(X_train) -test_prediction = log_reg.predict(X_test) - -print("Precision, Recall, Confusion matrix, in training\n") - -# Precision Recall scores -print(metrics.classification_report(y_train, training_prediction, digits=3)) - -# Confusion matrix -print(metrics.confusion_matrix(y_train, training_prediction)) - -print("Precision, Recall, Confusion matrix, in testing\n") - -# Precision Recall scores -print(metrics.classification_report(y_test, test_prediction, digits=3)) - -# Confusion matrix -print(metrics.confusion_matrix(y_test, test_prediction))