From 67d01e39bebc3fb1e4b8a2477ee7343f87ebb517 Mon Sep 17 00:00:00 2001 From: Sammy Kerata Oina <44265300+SammyOina@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:38:03 +0300 Subject: [PATCH] COCOS-155 - Add python algo support (#178) * * feat(algorithm.go): add support for algorithm type context * feat(python.go): implement Python algorithm runtime * fix(cocos_defconfig): add IPTABLES package Signed-off-by: SammyOina * update proto Signed-off-by: Sammy Oina * small fixes Signed-off-by: Sammy Oina * add metadata Signed-off-by: Sammy Oina * debug Signed-off-by: Sammy Oina * debug Signed-off-by: Sammy Oina * chunk logger Signed-off-by: Sammy Oina * debug logger Signed-off-by: Sammy Oina * test lock Signed-off-by: Sammy Oina * add req file Signed-off-by: SammyOina * stream result Signed-off-by: SammyOina * test with venv Signed-off-by: Sammy Oina * fix missing requirements file Signed-off-by: Sammy Oina * result stream Signed-off-by: Sammy Oina * modify test server Signed-off-by: Sammy Oina * remove debugging and cleaning up Signed-off-by: Sammy Oina * original repo Signed-off-by: Sammy Oina * add missing header Signed-off-by: Sammy Oina * downgrade protoc Signed-off-by: Sammy Oina --------- Signed-off-by: SammyOina Signed-off-by: Sammy Oina --- .env | 11 +-- .github/workflows/checkproto.yaml | 2 +- agent/agent.pb.go | 80 ++++++++++--------- agent/agent.proto | 3 +- agent/agent_grpc.pb.go | 83 +++++++++++++------- agent/algorithm/algorithm.go | 25 ++++++ agent/algorithm/binary/binary.go | 7 +- agent/algorithm/python/python.go | 126 ++++++++++++++++++++++++++++++ agent/api/grpc/endpoint.go | 2 +- agent/api/grpc/interceptors.go | 7 ++ agent/api/grpc/requests.go | 3 +- agent/api/grpc/server.go | 38 +++++++-- agent/computations.go | 7 +- agent/service.go | 54 +++++++++---- agent/state.go | 22 +++--- cli/algorithms.go | 39 ++++++++- hal/linux/configs/cocos_defconfig | 15 ++++ internal/logger/protohandler.go | 52 +++++++++--- pkg/manager/manager.pb.go | 2 +- pkg/manager/manager_grpc.pb.go | 2 +- pkg/progressbar/progressbar.go | 54 ++++++++++++- pkg/sdk/agent.go | 25 +++++- test/computations/main.go | 24 +++--- test/manual/algo/requirements.txt | 3 + 24 files changed, 537 insertions(+), 149 deletions(-) create mode 100644 agent/algorithm/python/python.go create mode 100644 test/manual/algo/requirements.txt diff --git a/.env b/.env index 7d7ad6cc..35bb0201 100644 --- a/.env +++ b/.env @@ -11,15 +11,11 @@ COCOS_JAEGER_OLTP_HTTP_PORT=4318 ## Core Services ### Manager -MANAGER_HTTP_HOST="cocos-manager" -MANAGER_HTTP_PORT=9021 -MANAGER_HTTP_SERVER_CERT="" -MANAGER_HTTP_SERVER_KEY="" -MANAGER_GRPC_HOST="cocos-manager" +MANAGER_GRPC_HOST="" MANAGER_GRPC_PORT=7003 MANAGER_GRPC_SERVER_CERT="" MANAGER_GRPC_SERVER_KEY="" -AGENT_GRPC_URL="192.168.100.4:7002" +AGENT_GRPC_URL="localhost:7002" AGENT_GRPC_TIMEOUT="" AGENT_GRPC_CA_CERTS="" AGENT_GRPC_CLIENT_TLS="" @@ -30,6 +26,3 @@ MANAGER_QEMU_ENABLE_SEV=false MANAGER_QEMU_SEV_CBITPOS=51 MANAGER_QEMU_OVMF_CODE_FILE=/usr/share/OVMF/OVMF_CODE.fd MANAGER_QEMU_OVMF_VARS_FILE=/usr/share/OVMF/OVMF_VARS.fd - -# Docker image tag -COCOS_RELEASE_TAG=latest diff --git a/.github/workflows/checkproto.yaml b/.github/workflows/checkproto.yaml index 51be8213..798625d5 100644 --- a/.github/workflows/checkproto.yaml +++ b/.github/workflows/checkproto.yaml @@ -33,7 +33,7 @@ jobs: - name: Set up protoc run: | - PROTOC_VERSION=27.2 + PROTOC_VERSION=25.3 PROTOC_GEN_VERSION=v1.34.2 PROTOC_GRPC_VERSION=v1.4.0 diff --git a/agent/agent.pb.go b/agent/agent.pb.go index 3a8940b2..492a75d4 100644 --- a/agent/agent.pb.go +++ b/agent/agent.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.34.2 -// protoc v5.27.2 +// protoc v4.25.3 // source: agent/agent.proto package agent @@ -28,7 +28,8 @@ type AlgoRequest struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Algorithm []byte `protobuf:"bytes,1,opt,name=algorithm,proto3" json:"algorithm,omitempty"` + Algorithm []byte `protobuf:"bytes,1,opt,name=algorithm,proto3" json:"algorithm,omitempty"` + Requirements []byte `protobuf:"bytes,2,opt,name=requirements,proto3" json:"requirements,omitempty"` } func (x *AlgoRequest) Reset() { @@ -70,6 +71,13 @@ func (x *AlgoRequest) GetAlgorithm() []byte { return nil } +func (x *AlgoRequest) GetRequirements() []byte { + if x != nil { + return x.Requirements + } + return nil +} + type AlgoResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -376,41 +384,43 @@ var File_agent_agent_proto protoreflect.FileDescriptor var file_agent_agent_proto_rawDesc = []byte{ 0x0a, 0x11, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x22, 0x2b, 0x0a, 0x0b, 0x41, 0x6c, + 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x22, 0x4f, 0x0a, 0x0b, 0x41, 0x6c, 0x67, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6c, 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x61, 0x6c, - 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x22, 0x0e, 0x0a, 0x0c, 0x41, 0x6c, 0x67, 0x6f, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x27, 0x0a, 0x0b, 0x44, 0x61, 0x74, 0x61, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x61, 0x74, 0x61, 0x73, 0x65, - 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x64, 0x61, 0x74, 0x61, 0x73, 0x65, 0x74, - 0x22, 0x0e, 0x0a, 0x0c, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x0f, 0x0a, 0x0d, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x22, 0x24, 0x0a, 0x0e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0c, 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x22, 0x35, 0x0a, 0x12, 0x41, 0x74, 0x74, 0x65, 0x73, - 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1f, 0x0a, - 0x0b, 0x72, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x5f, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x0c, 0x52, 0x0a, 0x72, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x44, 0x61, 0x74, 0x61, 0x22, 0x29, - 0x0a, 0x13, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0c, 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x32, 0xf9, 0x01, 0x0a, 0x0c, 0x41, 0x67, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x04, 0x41, 0x6c, - 0x67, 0x6f, 0x12, 0x12, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x6c, 0x67, 0x6f, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, - 0x6c, 0x67, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x12, - 0x33, 0x0a, 0x04, 0x44, 0x61, 0x74, 0x61, 0x12, 0x12, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, - 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x67, - 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x28, 0x01, 0x12, 0x37, 0x0a, 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x14, - 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, - 0x75, 0x6c, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, - 0x0b, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x19, 0x2e, 0x61, - 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, - 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x61, 0x67, 0x65, 0x6e, 0x74, - 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x67, 0x6f, 0x72, 0x69, 0x74, 0x68, 0x6d, 0x12, 0x22, 0x0a, 0x0c, 0x72, 0x65, 0x71, 0x75, 0x69, + 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0c, 0x72, + 0x65, 0x71, 0x75, 0x69, 0x72, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x22, 0x0e, 0x0a, 0x0c, 0x41, + 0x6c, 0x67, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x27, 0x0a, 0x0b, 0x44, + 0x61, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x61, + 0x74, 0x61, 0x73, 0x65, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x64, 0x61, 0x74, + 0x61, 0x73, 0x65, 0x74, 0x22, 0x0e, 0x0a, 0x0c, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x0f, 0x0a, 0x0d, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x24, 0x0a, 0x0e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x22, 0x35, 0x0a, 0x12, 0x41, + 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x1f, 0x0a, 0x0b, 0x72, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x5f, 0x64, 0x61, 0x74, 0x61, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0a, 0x72, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x44, 0x61, + 0x74, 0x61, 0x22, 0x29, 0x0a, 0x13, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, + 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x32, 0xfb, 0x01, + 0x0a, 0x0c, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, + 0x0a, 0x04, 0x41, 0x6c, 0x67, 0x6f, 0x12, 0x12, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, + 0x6c, 0x67, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x67, 0x65, + 0x6e, 0x74, 0x2e, 0x41, 0x6c, 0x67, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x28, 0x01, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x61, 0x74, 0x61, 0x12, 0x12, 0x2e, 0x61, 0x67, + 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x13, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x12, 0x39, 0x0a, 0x06, 0x52, 0x65, 0x73, 0x75, + 0x6c, 0x74, 0x12, 0x14, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, + 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, + 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x30, 0x01, 0x12, 0x46, 0x0a, 0x0b, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x12, 0x19, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x74, 0x74, 0x65, 0x73, + 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, + 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x2e, + 0x2f, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/agent/agent.proto b/agent/agent.proto index 305fa9c1..551dbc0b 100644 --- a/agent/agent.proto +++ b/agent/agent.proto @@ -10,12 +10,13 @@ option go_package = "./agent"; service AgentService { rpc Algo(stream AlgoRequest) returns (AlgoResponse) {} rpc Data(stream DataRequest) returns (DataResponse) {} - rpc Result(ResultRequest) returns (ResultResponse) {} + rpc Result(ResultRequest) returns (stream ResultResponse) {} rpc Attestation(AttestationRequest) returns (AttestationResponse) {} } message AlgoRequest { bytes algorithm = 1; + bytes requirements = 2; } message AlgoResponse {} diff --git a/agent/agent_grpc.pb.go b/agent/agent_grpc.pb.go index c666126f..17358955 100644 --- a/agent/agent_grpc.pb.go +++ b/agent/agent_grpc.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.4.0 -// - protoc v5.27.2 +// - protoc v4.25.3 // source: agent/agent.proto package agent @@ -34,7 +34,7 @@ const ( type AgentServiceClient interface { Algo(ctx context.Context, opts ...grpc.CallOption) (AgentService_AlgoClient, error) Data(ctx context.Context, opts ...grpc.CallOption) (AgentService_DataClient, error) - Result(ctx context.Context, in *ResultRequest, opts ...grpc.CallOption) (*ResultResponse, error) + Result(ctx context.Context, in *ResultRequest, opts ...grpc.CallOption) (AgentService_ResultClient, error) Attestation(ctx context.Context, in *AttestationRequest, opts ...grpc.CallOption) (*AttestationResponse, error) } @@ -116,14 +116,37 @@ func (x *agentServiceDataClient) CloseAndRecv() (*DataResponse, error) { return m, nil } -func (c *agentServiceClient) Result(ctx context.Context, in *ResultRequest, opts ...grpc.CallOption) (*ResultResponse, error) { +func (c *agentServiceClient) Result(ctx context.Context, in *ResultRequest, opts ...grpc.CallOption) (AgentService_ResultClient, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - out := new(ResultResponse) - err := c.cc.Invoke(ctx, AgentService_Result_FullMethodName, in, out, cOpts...) + stream, err := c.cc.NewStream(ctx, &AgentService_ServiceDesc.Streams[2], AgentService_Result_FullMethodName, cOpts...) if err != nil { return nil, err } - return out, nil + x := &agentServiceResultClient{ClientStream: stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type AgentService_ResultClient interface { + Recv() (*ResultResponse, error) + grpc.ClientStream +} + +type agentServiceResultClient struct { + grpc.ClientStream +} + +func (x *agentServiceResultClient) Recv() (*ResultResponse, error) { + m := new(ResultResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil } func (c *agentServiceClient) Attestation(ctx context.Context, in *AttestationRequest, opts ...grpc.CallOption) (*AttestationResponse, error) { @@ -142,7 +165,7 @@ func (c *agentServiceClient) Attestation(ctx context.Context, in *AttestationReq type AgentServiceServer interface { Algo(AgentService_AlgoServer) error Data(AgentService_DataServer) error - Result(context.Context, *ResultRequest) (*ResultResponse, error) + Result(*ResultRequest, AgentService_ResultServer) error Attestation(context.Context, *AttestationRequest) (*AttestationResponse, error) mustEmbedUnimplementedAgentServiceServer() } @@ -157,8 +180,8 @@ func (UnimplementedAgentServiceServer) Algo(AgentService_AlgoServer) error { func (UnimplementedAgentServiceServer) Data(AgentService_DataServer) error { return status.Errorf(codes.Unimplemented, "method Data not implemented") } -func (UnimplementedAgentServiceServer) Result(context.Context, *ResultRequest) (*ResultResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method Result not implemented") +func (UnimplementedAgentServiceServer) Result(*ResultRequest, AgentService_ResultServer) error { + return status.Errorf(codes.Unimplemented, "method Result not implemented") } func (UnimplementedAgentServiceServer) Attestation(context.Context, *AttestationRequest) (*AttestationResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method Attestation not implemented") @@ -228,22 +251,25 @@ func (x *agentServiceDataServer) Recv() (*DataRequest, error) { return m, nil } -func _AgentService_Result_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(ResultRequest) - if err := dec(in); err != nil { - return nil, err +func _AgentService_Result_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(ResultRequest) + if err := stream.RecvMsg(m); err != nil { + return err } - if interceptor == nil { - return srv.(AgentServiceServer).Result(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: AgentService_Result_FullMethodName, - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(AgentServiceServer).Result(ctx, req.(*ResultRequest)) - } - return interceptor(ctx, in, info, handler) + return srv.(AgentServiceServer).Result(m, &agentServiceResultServer{ServerStream: stream}) +} + +type AgentService_ResultServer interface { + Send(*ResultResponse) error + grpc.ServerStream +} + +type agentServiceResultServer struct { + grpc.ServerStream +} + +func (x *agentServiceResultServer) Send(m *ResultResponse) error { + return x.ServerStream.SendMsg(m) } func _AgentService_Attestation_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { @@ -271,10 +297,6 @@ var AgentService_ServiceDesc = grpc.ServiceDesc{ ServiceName: "agent.AgentService", HandlerType: (*AgentServiceServer)(nil), Methods: []grpc.MethodDesc{ - { - MethodName: "Result", - Handler: _AgentService_Result_Handler, - }, { MethodName: "Attestation", Handler: _AgentService_Attestation_Handler, @@ -291,6 +313,11 @@ var AgentService_ServiceDesc = grpc.ServiceDesc{ Handler: _AgentService_Data_Handler, ClientStreams: true, }, + { + StreamName: "Result", + Handler: _AgentService_Result_Handler, + ServerStreams: true, + }, }, Metadata: "agent/agent.proto", } diff --git a/agent/algorithm/algorithm.go b/agent/algorithm/algorithm.go index 38d9072a..1e6bd9b0 100644 --- a/agent/algorithm/algorithm.go +++ b/agent/algorithm/algorithm.go @@ -2,8 +2,33 @@ // SPDX-License-Identifier: Apache-2.0 package algorithm +import ( + "context" + + "google.golang.org/grpc/metadata" +) + +type AlgorithType string + +const ( + AlgoTypeBin AlgorithType = "bin" + AlgoTypePython AlgorithType = "python" + AlgoTypeKey = "algo_type" +) + +func AlgorithmTypeToContext(ctx context.Context, algoType string) context.Context { + return metadata.AppendToOutgoingContext(ctx, AlgoTypeKey, algoType) +} + +func AlgorithmTypeFromContext(ctx context.Context) string { + return metadata.ValueFromIncomingContext(ctx, AlgoTypeKey)[0] +} + // Algorithm is an interface that specifies the API for an algorithm. type Algorithm interface { // Run executes the algorithm and returns the result. Run() ([]byte, error) + + // Add dataset to algorithm. + AddDataset(dataset string) } diff --git a/agent/algorithm/binary/binary.go b/agent/algorithm/binary/binary.go index a89cc618..d487dd60 100644 --- a/agent/algorithm/binary/binary.go +++ b/agent/algorithm/binary/binary.go @@ -26,16 +26,19 @@ type binary struct { stdout io.Writer } -func New(logger *slog.Logger, eventsSvc events.Service, algoFile string, datasets ...string) algorithm.Algorithm { +func New(logger *slog.Logger, eventsSvc events.Service, algoFile string) algorithm.Algorithm { return &binary{ algoFile: algoFile, - datasets: datasets, logger: logger, stderr: &algorithm.Stderr{Logger: logger, EventSvc: eventsSvc}, stdout: &algorithm.Stdout{Logger: logger}, } } +func (b *binary) AddDataset(dataset string) { + b.datasets = append(b.datasets, dataset) +} + func (b *binary) Run() ([]byte, error) { defer os.Remove(b.algoFile) defer func() { diff --git a/agent/algorithm/python/python.go b/agent/algorithm/python/python.go new file mode 100644 index 00000000..0fd88313 --- /dev/null +++ b/agent/algorithm/python/python.go @@ -0,0 +1,126 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 +package python + +import ( + "context" + "fmt" + "io" + "log/slog" + "os" + "os/exec" + "path/filepath" + + "github.com/ultravioletrs/cocos/agent/algorithm" + "github.com/ultravioletrs/cocos/agent/events" + "github.com/ultravioletrs/cocos/pkg/socket" + "google.golang.org/grpc/metadata" +) + +const ( + socketPath = "unix_socket" + PyRuntime = "python3" + pyRuntimeKey = "python_runtime" +) + +func PythonRunTimeToContext(ctx context.Context, runtime string) context.Context { + return metadata.AppendToOutgoingContext(ctx, pyRuntimeKey, runtime) +} + +func PythonRunTimeFromContext(ctx context.Context) string { + return metadata.ValueFromIncomingContext(ctx, pyRuntimeKey)[0] +} + +var _ algorithm.Algorithm = (*python)(nil) + +type python struct { + algoFile string + datasets []string + logger *slog.Logger + stderr io.Writer + stdout io.Writer + runtime string + requirementsFile string +} + +func New(logger *slog.Logger, eventsSvc events.Service, runtime, requirementsFile, algoFile string) algorithm.Algorithm { + p := &python{ + algoFile: algoFile, + logger: logger, + stderr: &algorithm.Stderr{Logger: logger, EventSvc: eventsSvc}, + stdout: &algorithm.Stdout{Logger: logger}, + requirementsFile: requirementsFile, + } + if runtime != "" { + p.runtime = runtime + } else { + p.runtime = PyRuntime + } + return p +} + +func (p *python) AddDataset(dataset string) { + p.datasets = append(p.datasets, dataset) +} + +func (p *python) Run() ([]byte, error) { + venvPath := "venv" + createVenvCmd := exec.Command(p.runtime, "-m", "venv", venvPath) + createVenvCmd.Stderr = p.stderr + createVenvCmd.Stdout = p.stdout + if err := createVenvCmd.Run(); err != nil { + return nil, fmt.Errorf("error creating virtual environment: %v", err) + } + + pythonPath := filepath.Join(venvPath, "bin", "python") + + if p.requirementsFile != "" { + rcmd := exec.Command(pythonPath, "-m", "pip", "install", "-r", p.requirementsFile) + rcmd.Stderr = p.stderr + rcmd.Stdout = p.stdout + if err := rcmd.Run(); err != nil { + return nil, fmt.Errorf("error installing requirements: %v", err) + } + } + + defer os.Remove(p.algoFile) + defer func() { + for _, file := range p.datasets { + os.Remove(file) + } + }() + defer os.RemoveAll(venvPath) + + listener, err := socket.StartUnixSocketServer(socketPath) + if err != nil { + return nil, fmt.Errorf("error creating stdout pipe: %v", err) + } + defer listener.Close() + + dataChannel := make(chan []byte) + errorChannel := make(chan error) + + var result []byte + + go socket.AcceptConnection(listener, dataChannel, errorChannel) + + args := append([]string{p.algoFile, socketPath}, p.datasets...) + cmd := exec.Command(pythonPath, args...) + cmd.Stderr = p.stderr + cmd.Stdout = p.stdout + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("error starting algorithm: %v", err) + } + + if err := cmd.Wait(); err != nil { + return nil, fmt.Errorf("algorithm execution error: %v", err) + } + + select { + case result = <-dataChannel: + return result, nil + case err = <-errorChannel: + return nil, fmt.Errorf("error receiving data: %v", err) + } +} diff --git a/agent/api/grpc/endpoint.go b/agent/api/grpc/endpoint.go index 4641d8b8..9d7fd123 100644 --- a/agent/api/grpc/endpoint.go +++ b/agent/api/grpc/endpoint.go @@ -17,7 +17,7 @@ func algoEndpoint(svc agent.Service) endpoint.Endpoint { return algoRes{}, err } - algo := agent.Algorithm{Algorithm: req.Algorithm} + algo := agent.Algorithm{Algorithm: req.Algorithm, Requirements: req.Requirements} err := svc.Algo(ctx, algo) if err != nil { diff --git a/agent/api/grpc/interceptors.go b/agent/api/grpc/interceptors.go index 8085caf4..97250c7b 100644 --- a/agent/api/grpc/interceptors.go +++ b/agent/api/grpc/interceptors.go @@ -45,6 +45,13 @@ func (s *authInterceptor) AuthStreamInterceptor() grpc.StreamServerInterceptor { } wrapped := &wrappedServerStream{ServerStream: stream, ctx: ctx} return handler(srv, wrapped) + case agent.AgentService_Result_FullMethodName: + ctx, err := s.auth.AuthenticateUser(stream.Context(), auth.ConsumerRole) + if err != nil { + return status.Errorf(codes.Unauthenticated, err.Error()) + } + wrapped := &wrappedServerStream{ServerStream: stream, ctx: ctx} + return handler(srv, wrapped) default: return handler(srv, stream) } diff --git a/agent/api/grpc/requests.go b/agent/api/grpc/requests.go index 9f5e83e7..eecc22a2 100644 --- a/agent/api/grpc/requests.go +++ b/agent/api/grpc/requests.go @@ -7,7 +7,8 @@ import ( ) type algoReq struct { - Algorithm []byte `protobuf:"bytes,1,opt,name=algorithm,proto3" json:"algorithm,omitempty"` + Algorithm []byte `protobuf:"bytes,1,opt,name=algorithm,proto3" json:"algorithm,omitempty"` + Requirements []byte } func (req algoReq) validate() error { diff --git a/agent/api/grpc/server.go b/agent/api/grpc/server.go index 11c6ab3c..f52c0bbb 100644 --- a/agent/api/grpc/server.go +++ b/agent/api/grpc/server.go @@ -3,6 +3,7 @@ package grpc import ( + "bytes" "context" "errors" "io" @@ -13,6 +14,8 @@ import ( "google.golang.org/grpc/status" ) +const bufferSize = 1024 * 1024 + var _ agent.AgentServiceServer = (*grpcServer)(nil) type grpcServer struct { @@ -53,7 +56,8 @@ func decodeAlgoRequest(_ context.Context, grpcReq interface{}) (interface{}, err req := grpcReq.(*agent.AlgoRequest) return algoReq{ - Algorithm: req.Algorithm, + Algorithm: req.Algorithm, + Requirements: req.Requirements, }, nil } @@ -101,7 +105,7 @@ func encodeAttestationResponse(_ context.Context, response interface{}) (interfa // Algo implements agent.AgentServiceServer. func (s *grpcServer) Algo(stream agent.AgentService_AlgoServer) error { - var algoFile []byte + var algoFile, reqFile []byte for { algoChunk, err := stream.Recv() if err == io.EOF { @@ -111,8 +115,9 @@ func (s *grpcServer) Algo(stream agent.AgentService_AlgoServer) error { return status.Error(codes.Internal, err.Error()) } algoFile = append(algoFile, algoChunk.Algorithm...) + reqFile = append(reqFile, algoChunk.Requirements...) } - _, res, err := s.algo.ServeGRPC(stream.Context(), &agent.AlgoRequest{Algorithm: algoFile}) + _, res, err := s.algo.ServeGRPC(stream.Context(), &agent.AlgoRequest{Algorithm: algoFile, Requirements: reqFile}) if err != nil { return err } @@ -141,13 +146,32 @@ func (s *grpcServer) Data(stream agent.AgentService_DataServer) error { return stream.SendAndClose(ar) } -func (s *grpcServer) Result(ctx context.Context, req *agent.ResultRequest) (*agent.ResultResponse, error) { - _, res, err := s.result.ServeGRPC(ctx, req) +func (s *grpcServer) Result(req *agent.ResultRequest, stream agent.AgentService_ResultServer) error { + _, res, err := s.result.ServeGRPC(stream.Context(), req) if err != nil { - return nil, err + return err } rr := res.(*agent.ResultResponse) - return rr, nil + + reusltBuffer := bytes.NewBuffer(rr.File) + + buf := make([]byte, bufferSize) + + for { + n, err := reusltBuffer.Read(buf) + if err == io.EOF { + break + } + if err != nil { + return status.Error(codes.Internal, err.Error()) + } + + if err := stream.Send(&agent.ResultResponse{File: buf[:n]}); err != nil { + return status.Error(codes.Internal, err.Error()) + } + } + + return nil } func (s *grpcServer) Attestation(ctx context.Context, req *agent.AttestationRequest) (*agent.AttestationResponse, error) { diff --git a/agent/computations.go b/agent/computations.go index 67f09f42..22d966f5 100644 --- a/agent/computations.go +++ b/agent/computations.go @@ -52,9 +52,10 @@ type Dataset struct { type Datasets []Dataset type Algorithm struct { - Algorithm []byte `json:"-"` - Hash [32]byte `json:"hash,omitempty"` - UserKey []byte `json:"user_key,omitempty"` + Algorithm []byte `json:"-"` + Hash [32]byte `json:"hash,omitempty"` + UserKey []byte `json:"user_key,omitempty"` + Requirements []byte `json:"-"` } type ManifestIndexKey struct{} diff --git a/agent/service.go b/agent/service.go index 5d11d2bc..004f29d1 100644 --- a/agent/service.go +++ b/agent/service.go @@ -13,7 +13,9 @@ import ( "slices" "github.com/google/go-sev-guest/client" + "github.com/ultravioletrs/cocos/agent/algorithm" "github.com/ultravioletrs/cocos/agent/algorithm/binary" + "github.com/ultravioletrs/cocos/agent/algorithm/python" "github.com/ultravioletrs/cocos/agent/events" "golang.org/x/crypto/sha3" ) @@ -59,13 +61,12 @@ type Service interface { } type agentService struct { - computation Computation // Holds the current computation request details. - algorithm string // Filepath to the algorithm received for the computation. - datasets []string // Filepath to the datasets received for the computation. - result []byte // Stores the result of the computation. - sm *StateMachine // Manages the state transitions of the agent service. - runError error // Stores any error encountered during the computation run. - eventSvc events.Service // Service for publishing events related to computation. + computation Computation // Holds the current computation request details. + algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation. + result []byte // Stores the result of the computation. + sm *StateMachine // Manages the state transitions of the agent service. + runError error // Stores any error encountered during the computation run. + eventSvc events.Service // Service for publishing events related to computation. } var _ Service = (*agentService)(nil) @@ -92,15 +93,15 @@ func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, cmp return svc } -func (as *agentService) Algo(ctx context.Context, algorithm Algorithm) error { +func (as *agentService) Algo(ctx context.Context, algo Algorithm) error { if as.sm.GetState() != receivingAlgorithm { return ErrStateNotReady } - if as.algorithm != "" { + if as.algorithm != nil { return ErrAllManifestItemsReceived } - hash := sha3.Sum256(algorithm.Algorithm) + hash := sha3.Sum256(algo.Algorithm) if hash != as.computation.Algorithm.Hash { return ErrHashMismatch @@ -111,7 +112,7 @@ func (as *agentService) Algo(ctx context.Context, algorithm Algorithm) error { return fmt.Errorf("error creating algorithm file: %v", err) } - if _, err := f.Write(algorithm.Algorithm); err != nil { + if _, err := f.Write(algo.Algorithm); err != nil { return fmt.Errorf("error writing algorithm to file: %v", err) } @@ -123,9 +124,31 @@ func (as *agentService) Algo(ctx context.Context, algorithm Algorithm) error { return fmt.Errorf("error closing file: %v", err) } - as.algorithm = f.Name() + algoType := algorithm.AlgorithmTypeFromContext(ctx) + if algoType == "" { + algoType = string(algorithm.AlgoTypeBin) + } + + switch algoType { + case string(algorithm.AlgoTypeBin): + as.algorithm = binary.New(as.sm.logger, as.eventSvc, f.Name()) + case string(algorithm.AlgoTypePython): + fr, err := os.CreateTemp("", "requirements.txt") + if err != nil { + return fmt.Errorf("error creating requirments file: %v", err) + } + + if _, err := fr.Write(algo.Requirements); err != nil { + return fmt.Errorf("error writing requirements to file: %v", err) + } + if err := fr.Close(); err != nil { + return fmt.Errorf("error closing file: %v", err) + } + runtime := python.PythonRunTimeFromContext(ctx) + as.algorithm = python.New(as.sm.logger, as.eventSvc, runtime, fr.Name(), f.Name()) + } - if as.algorithm != "" { + if as.algorithm != nil { as.sm.SendEvent(algorithmReceived) } @@ -164,7 +187,7 @@ func (as *agentService) Data(ctx context.Context, dataset Dataset) error { return fmt.Errorf("error closing file: %v", err) } - as.datasets = append(as.datasets, f.Name()) + as.algorithm.AddDataset(f.Name()) if len(as.computation.Datasets) == 0 { as.sm.SendEvent(dataReceived) @@ -211,8 +234,7 @@ func (as *agentService) runComputation() { as.sm.logger.Debug("computation run started") defer as.sm.SendEvent(runComplete) as.publishEvent("in-progress", json.RawMessage{})() - algorithm := binary.New(as.sm.logger, as.eventSvc, as.algorithm, as.datasets...) - result, err := algorithm.Run() + result, err := as.algorithm.Run() if err != nil { as.runError = err as.sm.logger.Warn(fmt.Sprintf("computation failed with error: %s", err.Error())) diff --git a/agent/state.go b/agent/state.go index ddd5b46a..f61a978d 100644 --- a/agent/state.go +++ b/agent/state.go @@ -88,16 +88,19 @@ func (sm *StateMachine) Start(ctx context.Context) { for { select { case event := <-sm.EventChan: - nextState, valid := sm.Transitions[sm.GetState()][event] + sm.mu.Lock() + nextState, valid := sm.Transitions[sm.State][event] if valid { - sm.mu.Lock() sm.State = nextState - sm.mu.Unlock() - sm.logger.Debug(fmt.Sprintf("Transition: %v -> %v\n", sm.GetState(), nextState)) + sm.logger.Debug(fmt.Sprintf("Transition: %v -> %v\n", sm.State, nextState)) } else { - sm.logger.Error(fmt.Sprintf("Invalid transition: %v -> ???\n", sm.GetState())) + sm.logger.Error(fmt.Sprintf("Invalid transition: %v -> ???\n", sm.State)) } - stateFunc, exists := sm.StateFunctions[sm.GetState()] + sm.mu.Unlock() + + sm.mu.Lock() + stateFunc, exists := sm.StateFunctions[sm.State] + sm.mu.Unlock() if exists { go stateFunc() } @@ -114,13 +117,12 @@ func (sm *StateMachine) SendEvent(event event) { func (sm *StateMachine) GetState() state { sm.mu.Lock() - state := sm.State - sm.mu.Unlock() - return state + defer sm.mu.Unlock() + return sm.State } func (sm *StateMachine) SetState(state state) { sm.mu.Lock() + defer sm.mu.Unlock() sm.State = state - sm.mu.Unlock() } diff --git a/cli/algorithms.go b/cli/algorithms.go index 328bc78b..ec262661 100644 --- a/cli/algorithms.go +++ b/cli/algorithms.go @@ -3,16 +3,26 @@ package cli import ( + "context" "encoding/pem" "log" "os" "github.com/spf13/cobra" "github.com/ultravioletrs/cocos/agent" + "github.com/ultravioletrs/cocos/agent/algorithm" + "github.com/ultravioletrs/cocos/agent/algorithm/python" + "google.golang.org/grpc/metadata" +) + +var ( + pythonRuntime string + algoType string + requirementsFile string ) func (cli *CLI) NewAlgorithmCmd() *cobra.Command { - return &cobra.Command{ + cmd := &cobra.Command{ Use: "algo", Short: "Upload an algorithm binary", Example: "algo ", @@ -27,8 +37,17 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command { log.Fatalf("Error reading algorithm file: %v", err) } + var req []byte + if requirementsFile != "" { + req, err = os.ReadFile(requirementsFile) + if err != nil { + log.Fatalf("Error reading requirments file: %v", err) + } + } + algoReq := agent.Algorithm{ - Algorithm: algorithm, + Algorithm: algorithm, + Requirements: req, } privKeyFile, err := os.ReadFile(args[1]) @@ -40,11 +59,25 @@ func (cli *CLI) NewAlgorithmCmd() *cobra.Command { privKey := decodeKey(pemBlock) - if err := cli.agentSDK.Algo(cmd.Context(), algoReq, privKey); err != nil { + ctx := metadata.NewOutgoingContext(cmd.Context(), metadata.New(make(map[string]string))) + + if err := cli.agentSDK.Algo(addAlgoMetadata(ctx), algoReq, privKey); err != nil { log.Fatalf("Error uploading algorithm with error: %v", err) } log.Println("Successfully uploaded algorithm") }, } + + cmd.Flags().StringVarP(&algoType, "algorithm", "a", string(algorithm.AlgoTypeBin), "Algorithm type to run") + cmd.Flags().StringVar(&pythonRuntime, "python-runtime", python.PyRuntime, "Python runtime to use") + cmd.Flags().StringVarP(&requirementsFile, "requirements", "r", "", "Python requirements file") + + return cmd +} + +func addAlgoMetadata(ctx context.Context) context.Context { + ctx = algorithm.AlgorithmTypeToContext(ctx, algoType) + ctx = python.PythonRunTimeToContext(ctx, pythonRuntime) + return ctx } diff --git a/hal/linux/configs/cocos_defconfig b/hal/linux/configs/cocos_defconfig index 04830604..c80f25e6 100644 --- a/hal/linux/configs/cocos_defconfig +++ b/hal/linux/configs/cocos_defconfig @@ -53,3 +53,18 @@ BR2_PACKAGE_DOCKER_COMPOSE=y BR2_PACKAGE_DOCKER_ENGINE=y BR2_PACKAGE_CONTAINERD=y BR2_PACKAGE_RUNC=y +BR2_PACKAGE_IPTABLES=y + +# Python +BR2_PACKAGE_PYTHON3=y +BR2_PACKAGE_PYTHON_PIP=y +BR2_PACKAGE_BZIP2=y +BR2_PACKAGE_XZ=y +BR2_PACKAGE_ZIP=y +BR2_PACKAGE_PYTHON3_ZLIB=y +BR2_PACKAGE_PYTHON3_XZ=y +BR2_PACKAGE_PYTHON3_BZIP2=y +BR2_INSTALL_LIBSTDCPP=y +BR2_TOOLCHAIN_BUILDROOT_CXX=y +BR2_PACKAGE_HOST_GCC_TARGET=y +BR2_TOOLCHAIN_BUILDROOT_LIBSTDCPP=y diff --git a/internal/logger/protohandler.go b/internal/logger/protohandler.go index bfb7003e..db0b9a5e 100644 --- a/internal/logger/protohandler.go +++ b/internal/logger/protohandler.go @@ -40,19 +40,47 @@ func (h *handler) Enabled(_ context.Context, l slog.Level) bool { // Handle implements slog.Handler. func (h *handler) Handle(_ context.Context, r slog.Record) error { - agentLog := manager.ClientStreamMessage{Message: &manager.ClientStreamMessage_AgentLog{AgentLog: &manager.AgentLog{ - Timestamp: timestamppb.New(r.Time), - Message: r.Message, - Level: r.Level.String(), - }}} - - b, err := proto.Marshal(&agentLog) - if err != nil { - return err - } - if _, err := h.w.Write(b); err != nil { - return err + message := r.Message + timestamp := timestamppb.New(r.Time) + level := r.Level.String() + + // Calculate the number of chunks + chunkSize := 500 + numChunks := (len(message) + chunkSize - 1) / chunkSize + + for i := 0; i < numChunks; i++ { + start := i * chunkSize + end := start + chunkSize + if end > len(message) { + end = len(message) + } + + // Create a chunk of the message + chunk := message[start:end] + + // Create the agent log with the chunk + agentLog := manager.ClientStreamMessage{ + Message: &manager.ClientStreamMessage_AgentLog{ + AgentLog: &manager.AgentLog{ + Timestamp: timestamp, + Message: chunk, + Level: level, + }, + }, + } + + // Marshal the chunk to protobuf + b, err := proto.Marshal(&agentLog) + if err != nil { + return err + } + + // Write the chunk to the writer + if _, err := h.w.Write(b); err != nil { + return err + } } + return nil } diff --git a/pkg/manager/manager.pb.go b/pkg/manager/manager.pb.go index 8ceaf2b4..64bbca75 100644 --- a/pkg/manager/manager.pb.go +++ b/pkg/manager/manager.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.34.2 -// protoc v5.27.2 +// protoc v4.25.3 // source: manager/manager.proto package manager diff --git a/pkg/manager/manager_grpc.pb.go b/pkg/manager/manager_grpc.pb.go index ee6b7b52..2aee96a5 100644 --- a/pkg/manager/manager_grpc.pb.go +++ b/pkg/manager/manager_grpc.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.4.0 -// - protoc v5.27.2 +// - protoc v4.25.3 // source: manager/manager.proto package manager diff --git a/pkg/progressbar/progressbar.go b/pkg/progressbar/progressbar.go index 20ddc654..c809ee1e 100644 --- a/pkg/progressbar/progressbar.go +++ b/pkg/progressbar/progressbar.go @@ -79,10 +79,32 @@ func New() *ProgressBar { return &ProgressBar{} } -func (p *ProgressBar) SendAlgorithm(description string, buffer *bytes.Buffer, stream *agent.AgentService_AlgoClient) error { - return p.sendData(description, buffer, &algoClientWrapper{client: stream}, func(data []byte) interface{} { +func (p *ProgressBar) SendAlgorithm(description string, algobuffer, reqBuffer *bytes.Buffer, stream *agent.AgentService_AlgoClient) error { + totalSize := algobuffer.Len() + reqBuffer.Len() + p.reset(description, totalSize) + + wrapper := &algoClientWrapper{client: stream} + + // Send reqBuffer first + if err := p.sendBuffer(reqBuffer, wrapper, func(data []byte) interface{} { + return &agent.AlgoRequest{Requirements: data} + }); err != nil { + return err + } + + // Then send algobuffer + if err := p.sendBuffer(algobuffer, wrapper, func(data []byte) interface{} { return &agent.AlgoRequest{Algorithm: data} - }) + }); err != nil { + return err + } + + if _, err := io.WriteString(os.Stdout, "\n"); err != nil { + return err + } + + _, err := wrapper.CloseAndRecv() + return err } func (p *ProgressBar) SendData(description string, buffer *bytes.Buffer, stream *agent.AgentService_DataClient) error { @@ -123,6 +145,32 @@ func (p *ProgressBar) sendData(description string, buffer *bytes.Buffer, stream return err } +func (p *ProgressBar) sendBuffer(buffer *bytes.Buffer, stream streamSender, createRequest func([]byte) interface{}) error { + buf := make([]byte, bufferSize) + + for { + n, err := buffer.Read(buf) + if err == io.EOF { + break + } + if err != nil { + return err + } + + p.updateProgress(n) + + if err := stream.Send(createRequest(buf[:n])); err != nil { + return err + } + + if err := p.renderProgressBar(); err != nil { + return err + } + } + + return nil +} + func (p *ProgressBar) reset(description string, totalBytes int) { p.currentUploadedBytes = 0 p.currentUploadPercentage = 0 diff --git a/pkg/sdk/agent.go b/pkg/sdk/agent.go index 8a21b50d..fdc812ac 100644 --- a/pkg/sdk/agent.go +++ b/pkg/sdk/agent.go @@ -13,6 +13,7 @@ import ( "crypto/sha256" "encoding/base64" "errors" + "io" "log/slog" "github.com/ultravioletrs/cocos/agent" @@ -53,16 +54,20 @@ func (sdk *agentSDK) Algo(ctx context.Context, algorithm agent.Algorithm, privKe return err } - ctx = metadata.NewOutgoingContext(ctx, md) + for k, v := range md { + ctx = metadata.AppendToOutgoingContext(ctx, k, v[0]) + } + stream, err := sdk.client.Algo(ctx) if err != nil { sdk.logger.Error("Failed to call Algo RPC") return err } algoBuffer := bytes.NewBuffer(algorithm.Algorithm) + reqBuffer := bytes.NewBuffer(algorithm.Requirements) pb := progressbar.New() - if err := pb.SendAlgorithm(algoProgressBarDescription, algoBuffer, &stream); err != nil { + if err := pb.SendAlgorithm(algoProgressBarDescription, algoBuffer, reqBuffer, &stream); err != nil { sdk.logger.Error("Failed to send Algorithm") return err } @@ -104,13 +109,25 @@ func (sdk *agentSDK) Result(ctx context.Context, privKey any) ([]byte, error) { } ctx = metadata.NewOutgoingContext(ctx, md) - response, err := sdk.client.Result(ctx, request) + stream, err := sdk.client.Result(ctx, request) if err != nil { sdk.logger.Error("Failed to call Result RPC") return nil, err } - return response.File, nil + var result []byte + for { + response, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + result = append(result, response.File...) + } + + return result, nil } func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte) ([]byte, error) { diff --git a/test/computations/main.go b/test/computations/main.go index 38ce280e..281e4444 100644 --- a/test/computations/main.go +++ b/test/computations/main.go @@ -33,7 +33,7 @@ const ( var ( algoPath = "./test/manual/algo/lin_reg.py" - dataPath = "./test/manual/data/iris.csv" + dataPaths []string attestedTLS = false pubKeyFile string ) @@ -57,16 +57,15 @@ func (s *svc) Run(ipAdress string, reqChan chan *manager.ServerStreamMessage, au } pubPem, _ := pem.Decode(pubKey) - var dataset []*manager.Dataset - if dataPath != "" { + var datasets []*manager.Dataset + for _, dataPath := range dataPaths { data, err := os.ReadFile(dataPath) if err != nil { s.logger.Error(fmt.Sprintf("failed to read data file: %s", err)) return } dataHash := sha3.Sum256(data) - - dataset = []*manager.Dataset{{Hash: dataHash[:], UserKey: pubPem.Bytes}} + datasets = append(datasets, &manager.Dataset{Hash: dataHash[:], UserKey: pubPem.Bytes}) } algoHash := sha3.Sum256(algo) @@ -76,7 +75,7 @@ func (s *svc) Run(ipAdress string, reqChan chan *manager.ServerStreamMessage, au Id: "1", Name: "sample computation", Description: "sample descrption", - Datasets: dataset, + Datasets: datasets, Algorithm: &manager.Algorithm{Hash: algoHash[:], UserKey: pubPem.Bytes}, ResultConsumers: []*manager.ResultConsumer{{UserKey: pubPem.Bytes}}, AgentConfig: &manager.AgentConfig{ @@ -91,17 +90,20 @@ func (s *svc) Run(ipAdress string, reqChan chan *manager.ServerStreamMessage, au func main() { if len(os.Args) < 5 { - log.Fatalf("usage: %s ", os.Args[0]) + log.Fatalf("usage: %s ", os.Args[0]) } - dataPath = os.Args[1] - algoPath = os.Args[2] - pubKeyFile = os.Args[3] - attestedTLSParam, err := strconv.ParseBool(os.Args[4]) + algoPath = os.Args[1] + pubKeyFile = os.Args[2] + attestedTLSParam, err := strconv.ParseBool(os.Args[3]) if err != nil { log.Fatalf("usage: %s , must be a bool value", os.Args[0]) } attestedTLS = attestedTLSParam + for i := 4; i < len(os.Args); i++ { + dataPaths = append(dataPaths, os.Args[i]) + } + ctx, cancel := context.WithCancel(context.Background()) g, ctx := errgroup.WithContext(ctx) incomingChan := make(chan *manager.ClientStreamMessage) diff --git a/test/manual/algo/requirements.txt b/test/manual/algo/requirements.txt new file mode 100644 index 00000000..3bc94b02 --- /dev/null +++ b/test/manual/algo/requirements.txt @@ -0,0 +1,3 @@ +pandas +scikit-learn +joblib