diff --git a/.github/workflows/checkproto.yaml b/.github/workflows/checkproto.yaml index e00dd358..bc0b2804 100644 --- a/.github/workflows/checkproto.yaml +++ b/.github/workflows/checkproto.yaml @@ -33,8 +33,8 @@ jobs: - name: Set up protoc run: | - PROTOC_VERSION=29.0 - PROTOC_GEN_VERSION=v1.36.0 + PROTOC_VERSION=29.3 + PROTOC_GEN_VERSION=v1.36.4 PROTOC_GRPC_VERSION=v1.5.1 # Download and install protoc @@ -55,7 +55,7 @@ jobs: - name: Set up Cocos-AI run: | # Rename .pb.go files to .pb.go.tmp to prevent conflicts - for p in $(ls pkg/manager/*.pb.go); do + for p in $(ls manager/*.pb.go); do mv $p $p.tmp done @@ -67,7 +67,7 @@ jobs: make protoc # Compare generated Go files with the original ones - for p in $(ls pkg/manager/*.pb.go); do + for p in $(ls manager/*.pb.go); do if ! cmp -s $p $p.tmp; then echo "Proto file and generated Go file $p are out of sync!" exit 1 diff --git a/agent/agent.pb.go b/agent/agent.pb.go index 8a090d6b..19daa5b6 100644 --- a/agent/agent.pb.go +++ b/agent/agent.pb.go @@ -3,8 +3,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.0 -// protoc v5.29.0 +// protoc-gen-go v1.36.4 +// protoc v5.29.3 // source: agent/agent.proto package agent @@ -14,6 +14,7 @@ import ( protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" + unsafe "unsafe" ) const ( @@ -281,7 +282,9 @@ func (x *ResultResponse) GetFile() []byte { type AttestationRequest struct { state protoimpl.MessageState `protogen:"open.v1"` - ReportData []byte `protobuf:"bytes,1,opt,name=report_data,json=reportData,proto3" json:"report_data,omitempty"` // Should be of length 64. + TeeNonce []byte `protobuf:"bytes,1,opt,name=teeNonce,proto3" json:"teeNonce,omitempty"` // Should be less or equal 64 bytes. + VtpmNonce []byte `protobuf:"bytes,2,opt,name=vtpmNonce,proto3" json:"vtpmNonce,omitempty"` // Should be less or equal 32 bytes. + Type int32 `protobuf:"varint,3,opt,name=type,proto3" json:"type,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -316,13 +319,27 @@ func (*AttestationRequest) Descriptor() ([]byte, []int) { return file_agent_agent_proto_rawDescGZIP(), []int{6} } -func (x *AttestationRequest) GetReportData() []byte { +func (x *AttestationRequest) GetTeeNonce() []byte { if x != nil { - return x.ReportData + return x.TeeNonce } return nil } +func (x *AttestationRequest) GetVtpmNonce() []byte { + if x != nil { + return x.VtpmNonce + } + return nil +} + +func (x *AttestationRequest) GetType() int32 { + if x != nil { + return x.Type + } + return 0 +} + type AttestationResponse struct { state protoimpl.MessageState `protogen:"open.v1"` File []byte `protobuf:"bytes,1,opt,name=file,proto3" json:"file,omitempty"` @@ -369,7 +386,7 @@ func (x *AttestationResponse) GetFile() []byte { var File_agent_agent_proto protoreflect.FileDescriptor -var file_agent_agent_proto_rawDesc = []byte{ +var file_agent_agent_proto_rawDesc = string([]byte{ 0x0a, 0x11, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2f, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x22, 0x4f, 0x0a, 0x0b, 0x41, 0x6c, 0x67, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6c, 0x67, @@ -386,40 +403,43 @@ var file_agent_agent_proto_rawDesc = []byte{ 0x22, 0x0f, 0x0a, 0x0d, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x24, 0x0a, 0x0e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0c, 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x22, 0x35, 0x0a, 0x12, 0x41, 0x74, 0x74, 0x65, 0x73, - 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1f, 0x0a, - 0x0b, 0x72, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x5f, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x0c, 0x52, 0x0a, 0x72, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x44, 0x61, 0x74, 0x61, 0x22, 0x29, - 0x0a, 0x13, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0c, 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x32, 0xfd, 0x01, 0x0a, 0x0c, 0x41, 0x67, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x04, 0x41, 0x6c, - 0x67, 0x6f, 0x12, 0x12, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x6c, 0x67, 0x6f, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, - 0x6c, 0x67, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x12, - 0x33, 0x0a, 0x04, 0x44, 0x61, 0x74, 0x61, 0x12, 0x12, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, - 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x67, - 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x28, 0x01, 0x12, 0x39, 0x0a, 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x14, - 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, - 0x75, 0x6c, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, - 0x48, 0x0a, 0x0b, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x19, - 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x61, 0x67, 0x65, 0x6e, - 0x74, 0x2e, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x61, - 0x67, 0x65, 0x6e, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, -} + 0x0c, 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x22, 0x62, 0x0a, 0x12, 0x41, 0x74, 0x74, 0x65, 0x73, + 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, + 0x08, 0x74, 0x65, 0x65, 0x4e, 0x6f, 0x6e, 0x63, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, + 0x08, 0x74, 0x65, 0x65, 0x4e, 0x6f, 0x6e, 0x63, 0x65, 0x12, 0x1c, 0x0a, 0x09, 0x76, 0x74, 0x70, + 0x6d, 0x4e, 0x6f, 0x6e, 0x63, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x76, 0x74, + 0x70, 0x6d, 0x4e, 0x6f, 0x6e, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x22, 0x29, 0x0a, 0x13, 0x41, + 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, + 0x52, 0x04, 0x66, 0x69, 0x6c, 0x65, 0x32, 0xfd, 0x01, 0x0a, 0x0c, 0x41, 0x67, 0x65, 0x6e, 0x74, + 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x33, 0x0a, 0x04, 0x41, 0x6c, 0x67, 0x6f, 0x12, + 0x12, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x6c, 0x67, 0x6f, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x6c, 0x67, 0x6f, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x12, 0x33, 0x0a, 0x04, + 0x44, 0x61, 0x74, 0x61, 0x12, 0x12, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x61, 0x74, + 0x61, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x13, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, + 0x2e, 0x44, 0x61, 0x74, 0x61, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28, + 0x01, 0x12, 0x39, 0x0a, 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x14, 0x2e, 0x61, 0x67, + 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x15, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x48, 0x0a, 0x0b, + 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x19, 0x2e, 0x61, 0x67, + 0x65, 0x6e, 0x74, 0x2e, 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x2e, 0x41, + 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x42, 0x09, 0x5a, 0x07, 0x2e, 0x2f, 0x61, 0x67, 0x65, 0x6e, + 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +}) var ( file_agent_agent_proto_rawDescOnce sync.Once - file_agent_agent_proto_rawDescData = file_agent_agent_proto_rawDesc + file_agent_agent_proto_rawDescData []byte ) func file_agent_agent_proto_rawDescGZIP() []byte { file_agent_agent_proto_rawDescOnce.Do(func() { - file_agent_agent_proto_rawDescData = protoimpl.X.CompressGZIP(file_agent_agent_proto_rawDescData) + file_agent_agent_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_agent_agent_proto_rawDesc), len(file_agent_agent_proto_rawDesc))) }) return file_agent_agent_proto_rawDescData } @@ -460,7 +480,7 @@ func file_agent_agent_proto_init() { out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_agent_agent_proto_rawDesc, + RawDescriptor: unsafe.Slice(unsafe.StringData(file_agent_agent_proto_rawDesc), len(file_agent_agent_proto_rawDesc)), NumEnums: 0, NumMessages: 8, NumExtensions: 0, @@ -471,7 +491,6 @@ func file_agent_agent_proto_init() { MessageInfos: file_agent_agent_proto_msgTypes, }.Build() File_agent_agent_proto = out.File - file_agent_agent_proto_rawDesc = nil file_agent_agent_proto_goTypes = nil file_agent_agent_proto_depIdxs = nil } diff --git a/agent/agent.proto b/agent/agent.proto index f07b8f02..d9877426 100644 --- a/agent/agent.proto +++ b/agent/agent.proto @@ -36,7 +36,9 @@ message ResultResponse { } message AttestationRequest { - bytes report_data = 1; // Should be of length 64. + bytes teeNonce = 1; // Should be less or equal 64 bytes. + bytes vtpmNonce = 2; // Should be less or equal 32 bytes. + int32 type = 3; } message AttestationResponse { diff --git a/agent/agent_grpc.pb.go b/agent/agent_grpc.pb.go index 6459ecec..80f9211a 100644 --- a/agent/agent_grpc.pb.go +++ b/agent/agent_grpc.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.5.1 -// - protoc v5.29.0 +// - protoc v5.29.3 // source: agent/agent.proto package agent diff --git a/agent/api/grpc/endpoint.go b/agent/api/grpc/endpoint.go index da41add6..6f1a4f61 100644 --- a/agent/api/grpc/endpoint.go +++ b/agent/api/grpc/endpoint.go @@ -70,7 +70,7 @@ func attestationEndpoint(svc agent.Service) endpoint.Endpoint { if err := req.validate(); err != nil { return attestationRes{}, err } - file, err := svc.Attestation(ctx, req.ReportData) + file, err := svc.Attestation(ctx, req.TeeNonce, req.VtpmNonce, req.AttType) if err != nil { return attestationRes{}, err } diff --git a/agent/api/grpc/endpoint_test.go b/agent/api/grpc/endpoint_test.go index 4ab3494d..086bf9bc 100644 --- a/agent/api/grpc/endpoint_test.go +++ b/agent/api/grpc/endpoint_test.go @@ -141,11 +141,11 @@ func TestAttestationEndpoint(t *testing.T) { }{ { name: "Success", - req: attestationReq{ReportData: sha3.Sum512([]byte("report data"))}, + req: attestationReq{TeeNonce: sha3.Sum512([]byte("report data")), VtpmNonce: sha3.Sum256([]byte("vtpm nonce")), AttType: 0}, }, { name: "Service Error", - req: attestationReq{ReportData: sha3.Sum512([]byte("report data"))}, + req: attestationReq{TeeNonce: sha3.Sum512([]byte("report data")), VtpmNonce: sha3.Sum256([]byte("vtpm nonce")), AttType: 0}, expectedErr: true, }, } @@ -153,9 +153,9 @@ func TestAttestationEndpoint(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.name == svcErr { - svc.On("Attestation", context.Background(), tt.req.ReportData).Return([]byte{}, errors.New("")).Once() + svc.On("Attestation", context.Background(), tt.req.TeeNonce, tt.req.VtpmNonce, tt.req.AttType).Return([]byte{}, errors.New("")).Once() } else { - svc.On("Attestation", context.Background(), tt.req.ReportData).Return([]byte{}, nil).Once() + svc.On("Attestation", context.Background(), tt.req.TeeNonce, tt.req.VtpmNonce, tt.req.AttType).Return([]byte{}, nil).Once() } endpoint := attestationEndpoint(svc) res, err := endpoint(context.Background(), tt.req) diff --git a/agent/api/grpc/requests.go b/agent/api/grpc/requests.go index 4d7d2206..37f1a7ef 100644 --- a/agent/api/grpc/requests.go +++ b/agent/api/grpc/requests.go @@ -38,7 +38,9 @@ func (req resultReq) validate() error { } type attestationReq struct { - ReportData [64]byte + TeeNonce [64]byte + VtpmNonce [32]byte + AttType int32 } func (req attestationReq) validate() error { diff --git a/agent/api/grpc/server.go b/agent/api/grpc/server.go index 2ac9a387..ae51c393 100644 --- a/agent/api/grpc/server.go +++ b/agent/api/grpc/server.go @@ -11,6 +11,7 @@ import ( "github.com/go-kit/kit/transport/grpc" "github.com/ultravioletrs/cocos/agent" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" @@ -96,10 +97,20 @@ func encodeResultResponse(_ context.Context, response interface{}) (interface{}, func decodeAttestationRequest(_ context.Context, grpcReq interface{}) (interface{}, error) { req := grpcReq.(*agent.AttestationRequest) - if len(req.ReportData) != agent.ReportDataSize { - return nil, errors.New("malformed report data, expect 64 bytes") + var reportData [agent.Nonce]byte + var nonce [vtpm.Nonce]byte + + if len(req.TeeNonce) > agent.Nonce { + return nil, errors.New("malformed report data, expect less or equal to 64 bytes") + } + + if len(req.VtpmNonce) > vtpm.Nonce { + return nil, errors.New("malformed vTPM nonce, expect less or equal to 32 bytes") } - return attestationReq{ReportData: [agent.ReportDataSize]byte(req.ReportData)}, nil + + copy(reportData[:], req.TeeNonce) + copy(nonce[:], req.VtpmNonce) + return attestationReq{TeeNonce: reportData, VtpmNonce: nonce, AttType: req.Type}, nil } func encodeAttestationResponse(_ context.Context, response interface{}) (interface{}, error) { diff --git a/agent/api/grpc/server_test.go b/agent/api/grpc/server_test.go index 8f07db99..df3c7bbc 100644 --- a/agent/api/grpc/server_test.go +++ b/agent/api/grpc/server_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/agent/mocks" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) @@ -151,10 +152,12 @@ func TestAttestation(t *testing.T) { mockStream := &MockAgentService_AttestationServer{ctx: context.Background()} mockStream.On("Send", mock.AnythingOfType("*agent.AttestationResponse")).Return(nil) - reportData := [agent.ReportDataSize]byte{} - mockService.On("Attestation", mock.Anything, reportData).Return([]byte("attestation data"), nil) + reportData := [agent.Nonce]byte{} + vtpmNonce := [vtpm.Nonce]byte{} + attestationType := 0 + mockService.On("Attestation", mock.Anything, reportData, vtpmNonce, int32(attestationType)).Return([]byte("attestation data"), nil) - err := server.Attestation(&agent.AttestationRequest{ReportData: reportData[:]}, mockStream) + err := server.Attestation(&agent.AttestationRequest{TeeNonce: reportData[:]}, mockStream) assert.NoError(t, err) mockService.AssertExpectations(t) @@ -199,11 +202,11 @@ func TestEncodeResultResponse(t *testing.T) { } func TestDecodeAttestationRequest(t *testing.T) { - reportData := [agent.ReportDataSize]byte{} - req := &agent.AttestationRequest{ReportData: reportData[:]} + nonce := [agent.Nonce]byte{} + req := &agent.AttestationRequest{TeeNonce: nonce[:]} decoded, err := decodeAttestationRequest(context.Background(), req) assert.NoError(t, err) - assert.Equal(t, attestationReq{ReportData: reportData}, decoded) + assert.Equal(t, attestationReq{TeeNonce: nonce}, decoded) } func TestEncodeAttestationResponse(t *testing.T) { diff --git a/agent/api/logging.go b/agent/api/logging.go index 60f65d5f..dd22d2d9 100644 --- a/agent/api/logging.go +++ b/agent/api/logging.go @@ -13,6 +13,7 @@ import ( "time" "github.com/ultravioletrs/cocos/agent" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" ) var _ agent.Service = (*loggingMiddleware)(nil) @@ -103,7 +104,7 @@ func (lm *loggingMiddleware) Result(ctx context.Context) (response []byte, err e return lm.svc.Result(ctx) } -func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [agent.ReportDataSize]byte) (response []byte, err error) { +func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [agent.Nonce]byte, nonce [vtpm.Nonce]byte, attType int32) (response []byte, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method Attestation took %s to complete", time.Since(begin)) if err != nil { @@ -113,5 +114,5 @@ func (lm *loggingMiddleware) Attestation(ctx context.Context, reportData [agent. lm.logger.Info(fmt.Sprintf("%s without errors", message)) }(time.Now()) - return lm.svc.Attestation(ctx, reportData) + return lm.svc.Attestation(ctx, reportData, nonce, attType) } diff --git a/agent/api/metrics.go b/agent/api/metrics.go index de4fdb91..f7a77703 100644 --- a/agent/api/metrics.go +++ b/agent/api/metrics.go @@ -12,6 +12,7 @@ import ( "github.com/go-kit/kit/metrics" "github.com/ultravioletrs/cocos/agent" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" ) var _ agent.Service = (*metricsMiddleware)(nil) @@ -89,11 +90,11 @@ func (ms *metricsMiddleware) Result(ctx context.Context) ([]byte, error) { return ms.svc.Result(ctx) } -func (ms *metricsMiddleware) Attestation(ctx context.Context, reportData [agent.ReportDataSize]byte) ([]byte, error) { +func (ms *metricsMiddleware) Attestation(ctx context.Context, reportData [agent.Nonce]byte, nonce [vtpm.Nonce]byte, attType int32) ([]byte, error) { defer func(begin time.Time) { ms.counter.With("method", "attestation").Add(1) ms.latency.With("method", "attestation").Observe(time.Since(begin).Seconds()) }(time.Now()) - return ms.svc.Attestation(ctx, reportData) + return ms.svc.Attestation(ctx, reportData, nonce, attType) } diff --git a/agent/cvms/server/cvm.go b/agent/cvms/server/cvm.go index fadca8f7..fa4d33e9 100644 --- a/agent/cvms/server/cvm.go +++ b/agent/cvms/server/cvm.go @@ -71,7 +71,7 @@ func (as *agentServer) Start(cfg agent.AgentConfig, cmp agent.Computation) error return err } - qp, err := quoteprovider.GetQuoteProvider() + qp, err := quoteprovider.GetLeveledQuoteProvider() if err != nil { as.logger.Error(fmt.Sprintf("failed to create quote provider %s", err.Error())) return err diff --git a/agent/mocks/agent.go b/agent/mocks/agent.go index fc717fdd..62635e8f 100644 --- a/agent/mocks/agent.go +++ b/agent/mocks/agent.go @@ -73,9 +73,9 @@ func (_c *Service_Algo_Call) RunAndReturn(run func(context.Context, agent.Algori return _c } -// Attestation provides a mock function with given fields: ctx, reportData -func (_m *Service) Attestation(ctx context.Context, reportData [64]byte) ([]byte, error) { - ret := _m.Called(ctx, reportData) +// Attestation provides a mock function with given fields: ctx, reportData, nonce, attType +func (_m *Service) Attestation(ctx context.Context, reportData [64]byte, nonce [32]byte, attType int32) ([]byte, error) { + ret := _m.Called(ctx, reportData, nonce, attType) if len(ret) == 0 { panic("no return value specified for Attestation") @@ -83,19 +83,19 @@ func (_m *Service) Attestation(ctx context.Context, reportData [64]byte) ([]byte var r0 []byte var r1 error - if rf, ok := ret.Get(0).(func(context.Context, [64]byte) ([]byte, error)); ok { - return rf(ctx, reportData) + if rf, ok := ret.Get(0).(func(context.Context, [64]byte, [32]byte, int32) ([]byte, error)); ok { + return rf(ctx, reportData, nonce, attType) } - if rf, ok := ret.Get(0).(func(context.Context, [64]byte) []byte); ok { - r0 = rf(ctx, reportData) + if rf, ok := ret.Get(0).(func(context.Context, [64]byte, [32]byte, int32) []byte); ok { + r0 = rf(ctx, reportData, nonce, attType) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]byte) } } - if rf, ok := ret.Get(1).(func(context.Context, [64]byte) error); ok { - r1 = rf(ctx, reportData) + if rf, ok := ret.Get(1).(func(context.Context, [64]byte, [32]byte, int32) error); ok { + r1 = rf(ctx, reportData, nonce, attType) } else { r1 = ret.Error(1) } @@ -111,13 +111,15 @@ type Service_Attestation_Call struct { // Attestation is a helper method to define mock.On call // - ctx context.Context // - reportData [64]byte -func (_e *Service_Expecter) Attestation(ctx interface{}, reportData interface{}) *Service_Attestation_Call { - return &Service_Attestation_Call{Call: _e.mock.On("Attestation", ctx, reportData)} +// - nonce [32]byte +// - attType int32 +func (_e *Service_Expecter) Attestation(ctx interface{}, reportData interface{}, nonce interface{}, attType interface{}) *Service_Attestation_Call { + return &Service_Attestation_Call{Call: _e.mock.On("Attestation", ctx, reportData, nonce, attType)} } -func (_c *Service_Attestation_Call) Run(run func(ctx context.Context, reportData [64]byte)) *Service_Attestation_Call { +func (_c *Service_Attestation_Call) Run(run func(ctx context.Context, reportData [64]byte, nonce [32]byte, attType int32)) *Service_Attestation_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].([64]byte)) + run(args[0].(context.Context), args[1].([64]byte), args[2].([32]byte), args[3].(int32)) }) return _c } @@ -127,7 +129,7 @@ func (_c *Service_Attestation_Call) Return(_a0 []byte, _a1 error) *Service_Attes return _c } -func (_c *Service_Attestation_Call) RunAndReturn(run func(context.Context, [64]byte) ([]byte, error)) *Service_Attestation_Call { +func (_c *Service_Attestation_Call) RunAndReturn(run func(context.Context, [64]byte, [32]byte, int32) ([]byte, error)) *Service_Attestation_Call { _c.Call.Return(run) return _c } diff --git a/agent/service.go b/agent/service.go index 40871b0f..87563a70 100644 --- a/agent/service.go +++ b/agent/service.go @@ -23,6 +23,7 @@ import ( "github.com/ultravioletrs/cocos/agent/events" "github.com/ultravioletrs/cocos/agent/statemachine" "github.com/ultravioletrs/cocos/internal" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "golang.org/x/crypto/sha3" ) @@ -55,6 +56,14 @@ const ( RunFailed ) +type AttestationType int + +const ( + SNP AttestationType = iota + VTPM + SNPvTPM +) + //go:generate stringer -type=Status type Status uint8 @@ -70,8 +79,9 @@ const ( const ( // ReportDataSize is the size of the report data expected by the attestation service. - ReportDataSize = 64 + Nonce = 64 algoFilePermission = 0o700 + VMPL = 2 ) var ( @@ -99,6 +109,8 @@ var ( ErrAllResultsConsumed = errors.New("all results have been consumed by declared consumers") // ErrAttestationFailed attestation failed. ErrAttestationFailed = errors.New("failed to get raw quote") + // ErrAttType indicates that the attestation type that is requested does not exist or is not supported. + ErrAttestationType = errors.New("attestation type does not exist or is not supported") ) // Service specifies an API that must be fullfiled by the domain service @@ -109,28 +121,28 @@ type Service interface { Algo(ctx context.Context, algorithm Algorithm) error Data(ctx context.Context, dataset Dataset) error Result(ctx context.Context) ([]byte, error) - Attestation(ctx context.Context, reportData [ReportDataSize]byte) ([]byte, error) + Attestation(ctx context.Context, reportData [Nonce]byte, nonce [vtpm.Nonce]byte, attType int32) ([]byte, error) State() string } type agentService struct { mu sync.Mutex - computation Computation // Holds the current computation request details. - algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation. - result []byte // Stores the result of the computation. - sm statemachine.StateMachine // Manages the state transitions of the agent service. - runError error // Stores any error encountered during the computation run. - eventSvc events.Service // Service for publishing events related to computation. - quoteProvider client.QuoteProvider // Provider for generating attestation quotes. - logger *slog.Logger // Logger for the agent service. - resultsConsumed bool // Indicates if the results have been consumed. - cancel context.CancelFunc // Cancels the computation context. + computation Computation // Holds the current computation request details. + algorithm algorithm.Algorithm // Filepath to the algorithm received for the computation. + result []byte // Stores the result of the computation. + sm statemachine.StateMachine // Manages the state transitions of the agent service. + runError error // Stores any error encountered during the computation run. + eventSvc events.Service // Service for publishing events related to computation. + quoteProvider client.LeveledQuoteProvider // Provider for generating attestation quotes. + logger *slog.Logger // Logger for the agent service. + resultsConsumed bool // Indicates if the results have been consumed. + cancel context.CancelFunc // Cancels the computation context. } var _ Service = (*agentService)(nil) // New instantiates the agent service implementation. -func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, quoteProvider client.QuoteProvider) Service { +func New(ctx context.Context, logger *slog.Logger, eventSvc events.Service, quoteProvider client.LeveledQuoteProvider) Service { sm := statemachine.NewStateMachine(Idle) ctx, cancel := context.WithCancel(ctx) svc := &agentService{ @@ -397,13 +409,32 @@ func (as *agentService) Result(ctx context.Context) ([]byte, error) { return as.result, as.runError } -func (as *agentService) Attestation(ctx context.Context, reportData [ReportDataSize]byte) ([]byte, error) { - rawQuote, err := as.quoteProvider.GetRawQuote(reportData) - if err != nil { - return []byte{}, err +func (as *agentService) Attestation(ctx context.Context, reportData [Nonce]byte, nonce [vtpm.Nonce]byte, attType int32) ([]byte, error) { + switch AttestationType(attType) { + case SNP: + fmt.Println("SEV") + rawQuote, err := as.quoteProvider.GetRawQuoteAtLevel(reportData, VMPL) + if err != nil { + return []byte{}, err + } + return rawQuote, nil + case VTPM: + fmt.Println("vTPM") + vTPMQuote, err := vtpm.Attest(reportData[:], nonce[:], false) + if err != nil { + return []byte{}, err + } + return vTPMQuote, nil + case SNPvTPM: + fmt.Println("SEV and vTPM") + vTPMQuote, err := vtpm.Attest(reportData[:], nonce[:], true) + if err != nil { + return []byte{}, err + } + return vTPMQuote, nil + default: + return []byte{}, ErrAttestationType } - - return rawQuote, nil } func (as *agentService) runComputation(state statemachine.State) { diff --git a/agent/service_test.go b/agent/service_test.go index 48e21a88..1030072b 100644 --- a/agent/service_test.go +++ b/agent/service_test.go @@ -22,6 +22,7 @@ import ( smmocks "github.com/ultravioletrs/cocos/agent/statemachine/mocks" "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" mocks2 "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider/mocks" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "golang.org/x/crypto/sha3" "google.golang.org/grpc/metadata" ) @@ -35,7 +36,7 @@ var ( const datasetFile = "iris.csv" func TestAlgo(t *testing.T) { - qp, err := quoteprovider.GetQuoteProvider() + qp, err := quoteprovider.GetLeveledQuoteProvider() require.NoError(t, err) algo, err := os.ReadFile(algoPath) @@ -139,7 +140,7 @@ func TestAlgo(t *testing.T) { } func TestData(t *testing.T) { - qp, err := quoteprovider.GetQuoteProvider() + qp, err := quoteprovider.GetLeveledQuoteProvider() require.NoError(t, err) algo, err := os.ReadFile(algoPath) @@ -240,7 +241,7 @@ func TestData(t *testing.T) { } func TestResult(t *testing.T) { - qp, err := quoteprovider.GetQuoteProvider() + qp, err := quoteprovider.GetLeveledQuoteProvider() require.NoError(t, err) cases := []struct { @@ -323,23 +324,26 @@ func TestResult(t *testing.T) { } func TestAttestation(t *testing.T) { - qp := new(mocks2.QuoteProvider) + qp := new(mocks2.LeveledQuoteProvider) cases := []struct { name string - reportData [ReportDataSize]byte + reportData [Nonce]byte + nonce [vtpm.Nonce]byte rawQuote []uint8 err error }{ { name: "Test attestation successful", reportData: generateReportData(), + nonce: [32]byte{}, rawQuote: make([]uint8, 0), err: nil, }, { name: "Test attestation failed", reportData: generateReportData(), + nonce: [32]byte{}, rawQuote: nil, err: ErrAttestationFailed, }, @@ -355,22 +359,22 @@ func TestAttestation(t *testing.T) { ctx, cancel := context.WithCancel(ctx) defer cancel() - getQuote := qp.On("GetRawQuote", mock.Anything).Return(tc.rawQuote, tc.err) + getQuote := qp.On("GetRawQuoteAtLevel", mock.Anything, mock.Anything).Return(tc.rawQuote, tc.err) if tc.err != ErrAttestationFailed { - getQuote = qp.On("GetRawQuote", mock.Anything).Return(tc.reportData, nil) + getQuote = qp.On("GetRawQuoteAtLevel", mock.Anything, mock.Anything).Return(tc.nonce, nil) } defer getQuote.Unset() svc := New(ctx, mglog.NewMock(), events, qp) time.Sleep(300 * time.Millisecond) - _, err := svc.Attestation(ctx, tc.reportData) + _, err := svc.Attestation(ctx, tc.reportData, tc.nonce, 0) assert.True(t, errors.Contains(err, tc.err), "expected %v, got %v", tc.err, err) }) } } -func generateReportData() [ReportDataSize]byte { - bytes := make([]byte, ReportDataSize) +func generateReportData() [Nonce]byte { + bytes := make([]byte, Nonce) _, err := rand.Read(bytes) if err != nil { log.Fatalf("Failed to generate random bytes: %v", err) diff --git a/cli/attestation.go b/cli/attestation.go index 38762919..a34776f6 100644 --- a/cli/attestation.go +++ b/cli/attestation.go @@ -27,6 +27,7 @@ import ( "github.com/spf13/pflag" "github.com/ultravioletrs/cocos/agent" "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" @@ -109,33 +110,39 @@ const ( } } ` + SNP = "snp" + VTPM = "vtpm" + SNPvTPM = "snp-vtpm" ) var ( - mode string - cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}} - cfgString string - timeout time.Duration - maxRetryDelay time.Duration - platformInfo string - stepping string - trustedAuthorKeys []string - trustedAuthorHashes []string - trustedIdKeys []string - trustedIdKeyHashes []string - attestationFile string - tpmAttestationFile string - attestation []byte - empty16 = [size16]byte{} - empty32 = [size32]byte{} - empty64 = [size64]byte{} - defaultReportIdMa = []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255} - getJsonAttestation bool - errReportSize = errors.New("attestation contents too small") - output string - nonce []byte - format string - teeNonce []byte + mode string + cfg = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}} + cfgString string + timeout time.Duration + maxRetryDelay time.Duration + platformInfo string + stepping string + trustedAuthorKeys []string + trustedAuthorHashes []string + trustedIdKeys []string + trustedIdKeyHashes []string + attestationFile string + tpmAttestationFile string + attestation []byte + empty16 = [size16]byte{} + empty32 = [size32]byte{} + empty64 = [size64]byte{} + defaultReportIdMa = []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255} + errReportSize = errors.New("attestation contents too small") + ErrBadType = errors.New("bad type provided to the CLI attestation command") + ErrBadAttestation = errors.New("attestation file is corrupted or in wrong format") + output string + nonce []byte + format string + teeNonce []byte + attestationType string + getTextProtoAttestation bool ) var errEmptyFile = errors.New("input file is empty") @@ -179,30 +186,65 @@ func (cli *CLI) NewAttestationCmd() *cobra.Command { func (cli *CLI) NewGetAttestationCmd() *cobra.Command { cmd := &cobra.Command{ Use: "get", - Short: "Retrieve attestation information from agent. Report data expected in hex enoded string of length 64 bytes.", - Example: "get ", - Args: cobra.ExactArgs(1), + Short: "Retrieve attestation information from agent. The nonce parameter needs to be a hex encoded string.", + Example: "get ", Run: func(cmd *cobra.Command, args []string) { if cli.connectErr != nil { printError(cmd, "Failed to connect to agent: %v ❌ ", cli.connectErr) return } - cmd.Println("Getting attestation") + attType := agent.SNP + switch attestationType { + case SNP: + cmd.Println("Fetching SEV-SNP attestation report") + case VTPM: + cmd.Println("Fetching vTPM report") + attType = agent.VTPM + case SNPvTPM: + cmd.Println("Fetching SEV-SNP and vTPM report") + attType = agent.SNPvTPM + default: + printError(cmd, "Possible attestation types are snp, vtpm and snp-vtpm: %v ❌ ", ErrBadType) + return + } - reportData, err := hex.DecodeString(args[0]) - if err != nil { - printError(cmd, "Error decoding report data: %v ❌ ", err) + if (attType == agent.VTPM || attType == agent.SNPvTPM) && len(nonce) == 0 { + msg := color.New(color.FgRed).Sprint("vTPM nonce must be defined for vTPM attestation ❌ ") + cmd.Println(msg) return } - if len(reportData) != agent.ReportDataSize { - msg := color.New(color.FgRed).Sprintf("report data must be a hex encoded string of length %d bytes ❌ ", agent.ReportDataSize) + + if (attType == agent.SNP || attType == agent.SNPvTPM) && len(teeNonce) == 0 { + msg := color.New(color.FgRed).Sprint("TEE nonce must be defined for SEV-SNP attestation ❌ ") cmd.Println(msg) return } + var fixedReportData [agent.Nonce]byte + if attType != agent.VTPM { + if len(teeNonce) > agent.Nonce { + msg := color.New(color.FgRed).Sprintf("nonce must be a hex encoded string of length lesser or equal %d bytes ❌ ", agent.Nonce) + cmd.Println(msg) + return + } + + copy(fixedReportData[:], teeNonce) + } + + var fixedVtpmNonceByte [vtpm.Nonce]byte + if attType != agent.SNP { + if len(nonce) > vtpm.Nonce { + msg := color.New(color.FgRed).Sprintf("vTPM nonce must be a hex encoded string of length lesser or equal %d bytes ❌ ", vtpm.Nonce) + cmd.Println(msg) + return + } + + copy(fixedVtpmNonceByte[:], nonce) + } + filename := attestationFilePath - if getJsonAttestation { + if getTextProtoAttestation { filename = attestationJson } @@ -212,7 +254,7 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command { return } - if err := cli.agentSDK.Attestation(cmd.Context(), [agent.ReportDataSize]byte(reportData), attestationFile); err != nil { + if err := cli.agentSDK.Attestation(cmd.Context(), fixedReportData, fixedVtpmNonceByte, int(attType), attestationFile); err != nil { printError(cmd, "Failed to get attestation due to error: %v ❌ ", err) return } @@ -222,16 +264,35 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command { return } - if getJsonAttestation { + if getTextProtoAttestation { result, err := os.ReadFile(filename) if err != nil { printError(cmd, "Error reading attestation file: %v ❌ ", err) return } - result, err = attesationToJSON(result) + switch attestationType { + case SNP: + result, err = attesationToJSON(result) + case VTPM, SNPvTPM: + marshalOptions := prototext.MarshalOptions{ + Multiline: true, + EmitASCII: true, + } + var attvTPM tpmAttest.Attestation + err = proto.Unmarshal(result, &attvTPM) + if err != nil { + printError(cmd, "failed to unmarshal the attestation report: %v ❌ ", ErrBadAttestation) + } + + result = []byte(marshalOptions.Format(&attvTPM)) + default: + printError(cmd, "Possible attestation types are snp, vtpm and snp-vtpm: %v ❌ ", ErrBadType) + return + } + if err != nil { - printError(cmd, "Error converting attestation to json: %v ❌ ", err) + printError(cmd, "Error converting attestation to textproto: %v ❌ ", err) return } @@ -245,7 +306,10 @@ func (cli *CLI) NewGetAttestationCmd() *cobra.Command { }, } - cmd.Flags().BoolVarP(&getJsonAttestation, "json", "j", false, "Get attestation in json format") + cmd.Flags().BoolVarP(&getTextProtoAttestation, "textproto", "p", false, "Get attestation in textproto format") + cmd.Flags().BytesHexVarP(&teeNonce, "tee", "e", []byte{}, "Define the nonce for the SNP attestation report (must be used with attestation type snp and snp-vtpm)") + cmd.Flags().BytesHexVarP(&nonce, "vtpmnonce", "n", []byte{}, "Define the nonce for the vTPM attestation report (must be used with attestation type vtpm and snp-vtpm)") + cmd.Flags().StringVarP(&attestationType, "type", "t", "", "Get SEV or/and vTPM attestation report (snp, vtpm or snp-vtpm)") return cmd } @@ -585,7 +649,12 @@ func sevsnpverify(cmd *cobra.Command, args []string) error { return fmt.Errorf("error validating input: %v ❌ ", err) } - if err := quoteprovider.VerifyAndValidate(attestation, &cfg); err != nil { + attestationPB, err := abi.ReportCertsToProto(attestation) + if err != nil { + return fmt.Errorf("failed to convert attestation bytes to struct %v ❌ ", err) + } + + if err := quoteprovider.VerifyAndValidate(attestationPB, &cfg); err != nil { return fmt.Errorf("attestation validation and verification failed with error: %v ❌ ", err) } cmd.Println("Attestation validation and verification is successful!") diff --git a/cli/attestation_test.go b/cli/attestation_test.go index e8542568..8a9a9e63 100644 --- a/cli/attestation_test.go +++ b/cli/attestation_test.go @@ -19,6 +19,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/ultravioletrs/cocos/agent" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "github.com/ultravioletrs/cocos/pkg/sdk/mocks" ) @@ -35,8 +36,8 @@ func TestNewAttestationCmd(t *testing.T) { cmd.SetOutput(&buf) - reportData := bytes.Repeat([]byte{0x01}, agent.ReportDataSize) - mockSDK.On("Attestation", mock.Anything, [agent.ReportDataSize]byte(reportData), mock.Anything).Return(nil) + reportData := bytes.Repeat([]byte{0x01}, agent.Nonce) + mockSDK.On("Attestation", mock.Anything, [agent.Nonce]byte(reportData), mock.Anything).Return(nil) cmd.SetArgs([]string{hex.EncodeToString(reportData)}) err := cmd.Execute() @@ -47,6 +48,10 @@ func TestNewAttestationCmd(t *testing.T) { func TestNewGetAttestationCmd(t *testing.T) { validattestation, err := os.ReadFile("../attestation.bin") require.NoError(t, err) + + teeNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, agent.Nonce)) + vtpmNonce := hex.EncodeToString(bytes.Repeat([]byte{0x00}, vtpm.Nonce)) + testCases := []struct { name string args []string @@ -56,57 +61,85 @@ func TestNewGetAttestationCmd(t *testing.T) { expectedOut string }{ { - name: "successful attestation retrieval", - args: []string{hex.EncodeToString(bytes.Repeat([]byte{0x01}, agent.ReportDataSize))}, + name: "successful SNP attestation retrieval", + args: []string{"--tee", teeNonce, "-t", "snp"}, mockResponse: []byte("mock attestation"), mockError: nil, expectedOut: "Attestation result retrieved and saved successfully!", }, { - name: "invalid report data (decoding error)", - args: []string{"invalid"}, + name: "successful vTPM attestation retrieval", + args: []string{"--vtpmnonce", vtpmNonce, "-t", "vtpm"}, + mockResponse: []byte("mock attestation"), + mockError: nil, + expectedOut: "Attestation result retrieved and saved successfully!", + }, + { + name: "successful SNP-vTPM attestation retrieval", + args: []string{"--tee", teeNonce, "--vtpmnonce", vtpmNonce, "-t", "vtpm"}, + mockResponse: []byte("mock attestation"), + mockError: nil, + expectedOut: "Attestation result retrieved and saved successfully!", + }, + { + name: "missing vTPM nonce", + args: []string{"--tee", teeNonce, "-t", "snp-vtpm"}, + mockResponse: []byte("mock attestation"), + mockError: nil, + expectedOut: "vTPM nonce must be defined for vTPM attestation", + }, + { + name: "missing TEE nonce", + args: []string{"--vtpmnonce", vtpmNonce, "-t", "snp-vtpm"}, + mockResponse: []byte("mock attestation"), + mockError: nil, + expectedOut: "TEE nonce must be defined for SEV-SNP attestation", + }, + { + name: "invalid report data size", + args: []string{"-e", hex.EncodeToString(bytes.Repeat([]byte{0x00}, 65)), "-t", "snp"}, mockResponse: nil, mockError: errors.New("error"), - expectedErr: "Error decoding report data", + expectedErr: "nonce must be a hex encoded string of length lesser or equal 64 bytes", }, { - name: "invalid report data size", - args: []string{hex.EncodeToString(bytes.Repeat([]byte{0x01}, 32))}, + name: "invalid vTPM data size", + args: []string{"-n", hex.EncodeToString(bytes.Repeat([]byte{0x00}, 33)), "-t", "vtpm"}, mockResponse: nil, mockError: errors.New("error"), - expectedErr: "report data must be a hex encoded string of length 64 bytes", + expectedErr: "vTPM nonce must be a hex encoded string of length lesser or equal 32 bytes", }, { - name: "invalid report data hex", + name: "invalid arguments", args: []string{"invalid"}, mockResponse: nil, mockError: errors.New("error"), - expectedErr: "Error decoding report data", + expectedErr: "Possible attestation types are snp, vtpm and snp-vtpm: bad type provided to the CLI attestation command", }, { name: "failed to get attestation", - args: []string{hex.EncodeToString(bytes.Repeat([]byte{0x01}, agent.ReportDataSize))}, + args: []string{"-e", teeNonce, "-t", "snp"}, mockResponse: nil, mockError: errors.New("error"), expectedErr: "Failed to get attestation due to error", }, { - name: "JSON report error", - args: []string{hex.EncodeToString(bytes.Repeat([]byte{0x01}, agent.ReportDataSize)), "--json"}, + name: "Textproto report error", + args: []string{"-e", teeNonce, "-t", "snp", "--textproto"}, mockResponse: []byte("mock attestation"), mockError: nil, - expectedErr: "Error converting attestation to json", + expectedErr: "Error converting attestation to textproto", }, { - name: "successful JSON report", - args: []string{hex.EncodeToString(bytes.Repeat([]byte{0x01}, agent.ReportDataSize)), "--json"}, + name: "successful Textproto report", + args: []string{"-e", teeNonce, "-t", "snp", "--textproto"}, mockResponse: validattestation, mockError: nil, expectedOut: "Attestation result retrieved and saved successfully!", }, { name: "connection error", - args: []string{hex.EncodeToString(bytes.Repeat([]byte{0x01}, agent.ReportDataSize))}, + args: []string{"-e", teeNonce, "-t", "snp"}, mockResponse: nil, mockError: errors.New("failed to connect to agent"), expectedErr: "Failed to connect to agent", @@ -128,8 +161,8 @@ func TestNewGetAttestationCmd(t *testing.T) { var buf bytes.Buffer cmd.SetOutput(&buf) - mockSDK.On("Attestation", mock.Anything, [agent.ReportDataSize]byte(bytes.Repeat([]byte{0x01}, agent.ReportDataSize)), mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) { - _, err := args.Get(2).(*os.File).Write(tc.mockResponse) + mockSDK.On("Attestation", mock.Anything, [agent.Nonce]byte(bytes.Repeat([]byte{0x00}, agent.Nonce)), [vtpm.Nonce]byte(bytes.Repeat([]byte{0x00}, vtpm.Nonce)), mock.Anything, mock.Anything).Return(tc.mockError).Run(func(args mock.Arguments) { + _, err := args.Get(4).(*os.File).Write(tc.mockResponse) require.NoError(t, err) }) diff --git a/cmd/agent/main.go b/cmd/agent/main.go index 419805a7..389172dd 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -72,7 +72,7 @@ func main() { return } - qp, err := quoteprovider.GetQuoteProvider() + qp, err := quoteprovider.GetLeveledQuoteProvider() if err != nil { logger.Error(fmt.Sprintf("failed to create quote provider %s", err.Error())) exitCode = 1 @@ -150,7 +150,7 @@ func main() { } } -func newService(ctx context.Context, logger *slog.Logger, eventSvc events.Service, qp client.QuoteProvider) agent.Service { +func newService(ctx context.Context, logger *slog.Logger, eventSvc events.Service, qp client.LeveledQuoteProvider) agent.Service { svc := agent.New(ctx, logger, eventSvc, qp) svc = api.LoggingMiddleware(svc, logger) diff --git a/hal/linux/configs/cocos_defconfig b/hal/linux/configs/cocos_defconfig index 42dd193e..0162c941 100644 --- a/hal/linux/configs/cocos_defconfig +++ b/hal/linux/configs/cocos_defconfig @@ -27,7 +27,7 @@ BR2_ROOTFS_POST_SCRIPT_ARGS="$(BR2_DEFCONFIG)" # Linux headers same as kernel BR2_PACKAGE_HOST_LINUX_HEADERS_CUSTOM_6_11=y BR2_TOOLCHAIN_HEADERS_LATEST=y -BR2_TOOLCHAIN_HEADERS_AT_LEAST="6.12-rc6" +BR2_TOOLCHAIN_HEADERS_AT_LEAST="6.11-rc7" # Kernel BR2_LINUX_KERNEL=y diff --git a/internal/server/grpc/grpc.go b/internal/server/grpc/grpc.go index e7009b7e..bff7ac49 100644 --- a/internal/server/grpc/grpc.go +++ b/internal/server/grpc/grpc.go @@ -25,6 +25,7 @@ import ( "github.com/ultravioletrs/cocos/agent/auth" "github.com/ultravioletrs/cocos/internal/server" "github.com/ultravioletrs/cocos/pkg/atls" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -51,7 +52,7 @@ type Server struct { server.BaseServer server *grpc.Server registerService serviceRegister - quoteProvider client.QuoteProvider + quoteProvider client.LeveledQuoteProvider authSvc auth.Authenticator health *health.Server } @@ -60,7 +61,7 @@ type serviceRegister func(srv *grpc.Server) var _ server.Server = (*Server)(nil) -func New(ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration, registerService serviceRegister, logger *slog.Logger, qp client.QuoteProvider, authSvc auth.Authenticator) server.Server { +func New(ctx context.Context, cancel context.CancelFunc, name string, config server.ServerConfiguration, registerService serviceRegister, logger *slog.Logger, qp client.LeveledQuoteProvider, authSvc auth.Authenticator) server.Server { base := config.GetBaseConfig() listenFullAddress := fmt.Sprintf("%s:%s", base.Host, base.Port) return &Server{ @@ -301,5 +302,19 @@ func generateCertificatesForATLS() ([]byte, []byte, error) { Bytes: privateKeyBytes, }) + cert, err := x509.ParseCertificate(certDERBytes) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse certificate: %w", err) + } + + pubKeyDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey) + if err != nil { + return nil, nil, fmt.Errorf("failed to marshal public key to DER format: %w", err) + } + + if err := vtpm.ExtendPCR(vtpm.PCR15, pubKeyDER); err != nil { + return nil, nil, fmt.Errorf("failed to extend vTPM PCR with public key: %w", err) + } + return certBytes, keyBytes, nil } diff --git a/internal/server/grpc/grpc_test.go b/internal/server/grpc/grpc_test.go index b5bcbaeb..5f721ae6 100644 --- a/internal/server/grpc/grpc_test.go +++ b/internal/server/grpc/grpc_test.go @@ -22,6 +22,7 @@ import ( authmocks "github.com/ultravioletrs/cocos/agent/auth/mocks" "github.com/ultravioletrs/cocos/internal/server" "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider/mocks" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "google.golang.org/grpc" "google.golang.org/grpc/test/bufconn" ) @@ -30,6 +31,14 @@ const bufSize = 1024 * 1024 var lis *bufconn.Listener +// type BufferRW struct { +// *bytes.Buffer +// } + +// func (b *BufferRW) Close() error { +// return nil +// } + func init() { lis = bufconn.Listen(bufSize) } @@ -47,7 +56,7 @@ func TestNew(t *testing.T) { }, } logger := slog.Default() - qp := new(mocks.QuoteProvider) + qp := new(mocks.LeveledQuoteProvider) authSvc := new(authmocks.Authenticator) srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc) @@ -97,7 +106,7 @@ func TestServerStartWithTLSFile(t *testing.T) { logBuffer := &ThreadSafeBuffer{} logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) - qp := new(mocks.QuoteProvider) + qp := new(mocks.LeveledQuoteProvider) authSvc := new(authmocks.Authenticator) srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc) @@ -144,7 +153,7 @@ func TestServerStartWithmTLSFile(t *testing.T) { logBuffer := &ThreadSafeBuffer{} logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) - qp := new(mocks.QuoteProvider) + qp := new(mocks.LeveledQuoteProvider) authSvc := new(authmocks.Authenticator) srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc) @@ -184,7 +193,7 @@ func TestServerStop(t *testing.T) { } buf := &ThreadSafeBuffer{} logger := slog.New(slog.NewTextHandler(buf, &slog.HandlerOptions{Level: slog.LevelDebug})) - qp := new(mocks.QuoteProvider) + qp := new(mocks.LeveledQuoteProvider) authSvc := new(authmocks.Authenticator) srv := New(ctx, cancel, "TestServer", config, func(srv *grpc.Server) {}, logger, qp, authSvc) @@ -259,6 +268,8 @@ func (b *ThreadSafeBuffer) String() string { } func TestServerInitializationAndStartup(t *testing.T) { + vtpm.ExternalTPM = &vtpm.DummyRWC{} + testCases := []struct { name string config server.AgentConfig @@ -374,7 +385,7 @@ func TestServerInitializationAndStartup(t *testing.T) { logBuffer := &ThreadSafeBuffer{} logger := slog.New(slog.NewTextHandler(logBuffer, &slog.HandlerOptions{Level: slog.LevelDebug})) - qp := new(mocks.QuoteProvider) + qp := new(mocks.LeveledQuoteProvider) authSvc := new(authmocks.Authenticator) srv := New(ctx, cancel, "TestServer", tc.config, func(srv *grpc.Server) {}, logger, qp, authSvc) diff --git a/manager/attestation_policy.go b/manager/attestation_policy.go index 12d74c1e..ce5fe016 100644 --- a/manager/attestation_policy.go +++ b/manager/attestation_policy.go @@ -76,8 +76,8 @@ func (ms *managerService) FetchAttestationPolicy(_ context.Context, computationI attestationPolicy.Policy.Measurement = measurement } - if vmi.Config.HostData != "" { - hostData, err := base64.StdEncoding.DecodeString(vmi.Config.HostData) + if vmi.Config.SevConfig.EnableHostData { + hostData, err := base64.StdEncoding.DecodeString(vmi.Config.SevConfig.HostData) if err != nil { return nil, err } diff --git a/manager/attestation_policy_test.go b/manager/attestation_policy_test.go index 54f87b03..a5e9230a 100644 --- a/manager/attestation_policy_test.go +++ b/manager/attestation_policy_test.go @@ -57,9 +57,10 @@ func TestFetchAttestationPolicy(t *testing.T) { binaryBehavior: "success", vmConfig: qemu.VMInfo{ Config: qemu.Config{ - EnableSEV: true, - SMPCount: 2, - CPU: "EPYC", + EnableSEV: true, + EnableSEVSNP: false, + SMPCount: 2, + CPU: "EPYC", OVMFCodeConfig: qemu.OVMFCodeConfig{ File: "/path/to/OVMF_CODE.fd", }, @@ -68,23 +69,6 @@ func TestFetchAttestationPolicy(t *testing.T) { }, expectedError: "open /path/to/OVMF_CODE.fd: no such file or directory", }, - { - name: "Valid SEV-SNP configuration", - computationId: "sev-snp-computation", - binaryBehavior: "success", - vmConfig: qemu.VMInfo{ - Config: qemu.Config{ - EnableSEVSNP: true, - SMPCount: 4, - CPU: "EPYC-v2", - OVMFCodeConfig: qemu.OVMFCodeConfig{ - File: "/path/to/OVMF_CODE_SNP.fd", - }, - }, - LaunchTCB: 0, - }, - expectedError: "open /path/to/OVMF_CODE_SNP.fd: no such file or director", - }, { name: "Invalid computation ID", computationId: "non-existent", diff --git a/manager/manager.pb.go b/manager/manager.pb.go index a1a6da7f..34f4aec2 100644 --- a/manager/manager.pb.go +++ b/manager/manager.pb.go @@ -3,8 +3,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.0 -// protoc v5.29.0 +// protoc-gen-go v1.36.4 +// protoc v5.29.3 // source: manager/manager.proto package manager @@ -15,6 +15,7 @@ import ( emptypb "google.golang.org/protobuf/types/known/emptypb" reflect "reflect" sync "sync" + unsafe "unsafe" ) const ( @@ -422,7 +423,7 @@ func (x *SVMInfoReq) GetId() string { var File_manager_manager_proto protoreflect.FileDescriptor -var file_manager_manager_proto_rawDesc = []byte{ +var file_manager_manager_proto_rawDesc = string([]byte{ 0x0a, 0x15, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x2f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x1a, 0x1b, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, @@ -488,16 +489,16 @@ var file_manager_manager_proto_rawDesc = []byte{ 0x41, 0x74, 0x74, 0x65, 0x73, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x6c, 0x69, 0x63, 0x79, 0x52, 0x65, 0x73, 0x22, 0x00, 0x42, 0x0b, 0x5a, 0x09, 0x2e, 0x2f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x72, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, -} +}) var ( file_manager_manager_proto_rawDescOnce sync.Once - file_manager_manager_proto_rawDescData = file_manager_manager_proto_rawDesc + file_manager_manager_proto_rawDescData []byte ) func file_manager_manager_proto_rawDescGZIP() []byte { file_manager_manager_proto_rawDescOnce.Do(func() { - file_manager_manager_proto_rawDescData = protoimpl.X.CompressGZIP(file_manager_manager_proto_rawDescData) + file_manager_manager_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_manager_manager_proto_rawDesc), len(file_manager_manager_proto_rawDesc))) }) return file_manager_manager_proto_rawDescData } @@ -538,7 +539,7 @@ func file_manager_manager_proto_init() { out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_manager_manager_proto_rawDesc, + RawDescriptor: unsafe.Slice(unsafe.StringData(file_manager_manager_proto_rawDesc), len(file_manager_manager_proto_rawDesc)), NumEnums: 0, NumMessages: 7, NumExtensions: 0, @@ -549,7 +550,6 @@ func file_manager_manager_proto_init() { MessageInfos: file_manager_manager_proto_msgTypes, }.Build() File_manager_manager_proto = out.File - file_manager_manager_proto_rawDesc = nil file_manager_manager_proto_goTypes = nil file_manager_manager_proto_depIdxs = nil } diff --git a/manager/manager_grpc.pb.go b/manager/manager_grpc.pb.go index b8111ce4..77a8fd26 100644 --- a/manager/manager_grpc.pb.go +++ b/manager/manager_grpc.pb.go @@ -4,7 +4,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.5.1 -// - protoc v5.29.0 +// - protoc v5.29.3 // source: manager/manager.proto package manager diff --git a/manager/qemu/config.go b/manager/qemu/config.go index 86267353..fcc0c3d4 100644 --- a/manager/qemu/config.go +++ b/manager/qemu/config.go @@ -56,7 +56,13 @@ type SevConfig struct { ID string `env:"SEV_ID" envDefault:"sev0"` CBitPos int `env:"SEV_CBITPOS" envDefault:"51"` ReducedPhysBits int `env:"SEV_REDUCED_PHYS_BITS" envDefault:"1"` - HostData string `env:"HOST_DATA" envDefault:""` + EnableHostData bool `env:"ENABLE_HOST_DATA" envDefault:"false"` + HostData string `env:"HOST_DATA" envDefault:""` +} + +type IGVMConfig struct { + ID string `env:"IGVM_ID" envDefault:"igvm0"` + File string `env:"IGVM_FILE" envDefault:"/root/coconut-qemu.igvm"` } type VSockConfig struct { @@ -80,9 +86,6 @@ type Config struct { MemID string `env:"MEM_ID" envDefault:"ram1"` MemoryConfig - // Kernel hash - KernelHash bool `env:"KERNEL_HASH" envDefault:"false"` - // OVMF OVMFCodeConfig OVMFVarsConfig @@ -100,6 +103,9 @@ type Config struct { // SEV SevConfig + // vTPM + IGVMConfig + // display NoGraphic bool `env:"NO_GRAPHIC" envDefault:"true"` Monitor string `env:"MONITOR" envDefault:"pty"` @@ -173,40 +179,39 @@ func (config Config) ConstructQemuArgs() []string { // SEV if config.EnableSEV || config.EnableSEVSNP { sevType := "sev-guest" - kernelHash := "" hostData := "" args = append(args, "-machine", - fmt.Sprintf("confidential-guest-support=%s,memory-backend=%s", + fmt.Sprintf("confidential-guest-support=%s,memory-backend=%s,igvm-cfg=%s", config.SevConfig.ID, - config.MemID)) + config.MemID, + config.IGVMConfig.ID)) if config.EnableSEVSNP { - args = append(args, "-bios", config.OVMFCodeConfig.File) sevType = "sev-snp-guest" - if config.SevConfig.HostData != "" { + if config.SevConfig.EnableHostData { hostData = fmt.Sprintf(",host-data=%s", config.SevConfig.HostData) } } - if config.KernelHash { - kernelHash = ",kernel-hashes=on" - } - args = append(args, "-object", fmt.Sprintf("memory-backend-memfd,id=%s,size=%s,share=true,prealloc=false", config.MemID, config.MemoryConfig.Size)) args = append(args, "-object", - fmt.Sprintf("%s,id=%s,cbitpos=%d,reduced-phys-bits=%d%s%s", + fmt.Sprintf("%s,id=%s,cbitpos=%d,reduced-phys-bits=%d%s", sevType, config.SevConfig.ID, config.SevConfig.CBitPos, config.SevConfig.ReducedPhysBits, - kernelHash, hostData)) + + args = append(args, "-object", + fmt.Sprintf("igvm-cfg,id=%s,file=%s", + config.IGVMConfig.ID, + config.IGVMConfig.File)) } args = append(args, "-kernel", config.DiskImgConfig.KernelFile) diff --git a/manager/qemu/config_test.go b/manager/qemu/config_test.go index 67d4d2af..e21e2124 100644 --- a/manager/qemu/config_test.go +++ b/manager/qemu/config_test.go @@ -132,6 +132,10 @@ func TestConstructQemuArgs(t *testing.T) { CBitPos: 51, ReducedPhysBits: 1, }, + IGVMConfig: IGVMConfig{ + ID: "igvm0", + File: "/test/path/cocos-igvm.igvm", + }, NoGraphic: true, Monitor: "pty", }, @@ -144,10 +148,10 @@ func TestConstructQemuArgs(t *testing.T) { "-netdev", "user,id=vmnic,hostfwd=tcp::7020-:7002", "-device", "virtio-net-pci,disable-legacy=on,iommu_platform=true,netdev=vmnic,addr=0x2,romfile=", "-device", "vhost-vsock-pci,id=vhost-vsock-pci0,guest-cid=3", - "-machine", "confidential-guest-support=sev0,memory-backend=ram1", - "-bios", "/usr/share/OVMF/OVMF_CODE.fd", + "-machine", "confidential-guest-support=sev0,memory-backend=ram1,igvm-cfg=igvm0", "-object", "memory-backend-memfd,id=ram1,size=2048M,share=true,prealloc=false", "-object", "sev-snp-guest,id=sev0,cbitpos=51,reduced-phys-bits=1", + "-object", "igvm-cfg,id=igvm0,file=/test/path/cocos-igvm.igvm", "-kernel", "img/bzImage", "-append", "\"quiet console=null\"", "-initrd", "img/rootfs.cpio.gz", @@ -167,37 +171,6 @@ func TestConstructQemuArgs(t *testing.T) { } } -func TestConstructQemuArgs_KernelHash(t *testing.T) { - config := Config{ - EnableSEVSNP: true, - KernelHash: true, - SevConfig: SevConfig{ - ID: "sev0", - CBitPos: 51, - ReducedPhysBits: 1, - }, - } - - result := config.ConstructQemuArgs() - - expected := "-object" - expectedValue := "sev-snp-guest,id=sev0,cbitpos=51,reduced-phys-bits=1,kernel-hashes=on" - - found := false - for i, arg := range result { - if arg == expected && i+1 < len(result) { - if result[i+1] == expectedValue { - found = true - break - } - } - } - - if !found { - t.Errorf("ConstructQemuArgs() did not contain expected SEV-SNP configuration with kernel hashes enabled") - } -} - func TestConstructQemuArgs_HostData(t *testing.T) { config := Config{ EnableSEVSNP: true, @@ -205,6 +178,7 @@ func TestConstructQemuArgs_HostData(t *testing.T) { ID: "sev0", CBitPos: 51, ReducedPhysBits: 1, + EnableHostData: true, HostData: "test-host-data", }, } diff --git a/manager/qemu/vm.go b/manager/qemu/vm.go index c41e48c0..2c55689f 100644 --- a/manager/qemu/vm.go +++ b/manager/qemu/vm.go @@ -59,7 +59,7 @@ func (v *qemuVM) Start() (err error) { v.vmi.Config.NetDevConfig.ID = fmt.Sprintf("%s-%s", v.vmi.Config.NetDevConfig.ID, id) v.vmi.Config.SevConfig.ID = fmt.Sprintf("%s-%s", v.vmi.Config.SevConfig.ID, id) - if !v.vmi.Config.KernelHash { + if !v.vmi.Config.EnableSEVSNP { // Copy firmware vars file. srcFile := v.vmi.Config.OVMFVarsConfig.File dstFile := fmt.Sprintf("%s/%s-%s.fd", tmpDir, firmwareVars, id) diff --git a/mockery.yml b/mockery.yml index 87f0bd08..e9a304f5 100644 --- a/mockery.yml +++ b/mockery.yml @@ -107,7 +107,7 @@ packages: mockname: "{{.InterfaceName}}" github.com/google/go-sev-guest/client: interfaces: - QuoteProvider: + LeveledQuoteProvider: config: dir: "./pkg/attestation/quoteprovider/mocks" filename: "QuoteProvider.go" diff --git a/pkg/atls/atlsListener.go b/pkg/atls/atlsListener.go index 0e059294..74482f14 100644 --- a/pkg/atls/atlsListener.go +++ b/pkg/atls/atlsListener.go @@ -21,7 +21,7 @@ import ( "github.com/absmach/magistrala/pkg/errors" "github.com/ultravioletrs/cocos/agent" - "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" ) const ( @@ -51,8 +51,8 @@ var ( errConnCreate = errors.New("could not create connection") ) -type ValidationVerification func(data1, data2 []byte) error -type FetchAttestation func(data1 []byte) ([]byte, error) +type ValidationVerification func(data1, data2, data3, data4 []byte) error +type FetchAttestation func(data1, data2, data3 []byte) ([]byte, error) func registerFetchAttestation(callback FetchAttestation) uintptr { handle := cgo.NewHandle(callback) @@ -70,7 +70,7 @@ func validationVerificationCallback(teeType C.int) uintptr { case NoTee: return uintptr(0) case AmdSevSnp: - return registerValidationVerification(quoteprovider.VerifyAttestationReportTLS) + return registerValidationVerification(vtpm.VTPMVerify) default: return uintptr(0) } @@ -82,22 +82,24 @@ func fetchAttestationCallback(teeType C.int) uintptr { case NoTee: return uintptr(0) case AmdSevSnp: - return registerFetchAttestation(quoteprovider.FetchAttestation) + return registerFetchAttestation(vtpm.FetchATLSQuote) default: return uintptr(0) } } //export callVerificationValidationCallback -func callVerificationValidationCallback(callbackHandle uintptr, attReport *C.uchar, attReportSize C.int, repData *C.uchar) C.int { +func callVerificationValidationCallback(callbackHandle uintptr, pubKey *C.uchar, pubKeyLen C.int, quote *C.uchar, quoteSize C.int, teeNonce *C.uchar, nonce *C.uchar) C.int { handle := cgo.Handle(callbackHandle) defer handle.Delete() callback := handle.Value().(ValidationVerification) - attestationReport := C.GoBytes(unsafe.Pointer(attReport), attReportSize) - reportData := C.GoBytes(unsafe.Pointer(repData), agent.ReportDataSize) + pubKeyCert := C.GoBytes(unsafe.Pointer(pubKey), pubKeyLen) + attestationReport := C.GoBytes(unsafe.Pointer(quote), quoteSize) + teeData := C.GoBytes(unsafe.Pointer(teeNonce), agent.Nonce) + nonceData := C.GoBytes(unsafe.Pointer(nonce), vtpm.Nonce) - err := callback(attestationReport, reportData) + err := callback(attestationReport, pubKeyCert, teeData, nonceData) if err != nil { fmt.Fprintf(os.Stderr, "callback failed %v", err) return C.int(-1) @@ -107,20 +109,22 @@ func callVerificationValidationCallback(callbackHandle uintptr, attReport *C.uch } //export callFetchAttestationCallback -func callFetchAttestationCallback(callbackHandle uintptr, reportDataByte *C.uchar, outlen *C.int) *C.uchar { +func callFetchAttestationCallback(callbackHandle uintptr, pubKey *C.uchar, pubKeyLen C.int, teeNonceByte *C.uchar, vTPMNonceByte *C.uchar, outlen *C.ulong) *C.uchar { handle := cgo.Handle(callbackHandle) defer handle.Delete() callback := handle.Value().(FetchAttestation) - reportData := C.GoBytes(unsafe.Pointer(reportDataByte), agent.ReportDataSize) + pubKeyCert := C.GoBytes(unsafe.Pointer(pubKey), pubKeyLen) + teeNonceData := C.GoBytes(unsafe.Pointer(teeNonceByte), agent.Nonce) + vTPMNonce := C.GoBytes(unsafe.Pointer(vTPMNonceByte), vtpm.Nonce) - quote, err := callback(reportData) + quote, err := callback(pubKeyCert, teeNonceData, vTPMNonce) if err != nil { fmt.Fprintf(os.Stderr, "attestation callback returned nil") return nil } - *outlen = C.int(len(quote)) + *outlen = C.ulong(len(quote)) resultC := C.malloc(C.size_t(len(quote))) if resultC == nil { fmt.Fprintf(os.Stderr, "could not allocate memory for fetch attestation callback") diff --git a/pkg/atls/extensions.c b/pkg/atls/extensions.c index 6e27f82d..57511026 100644 --- a/pkg/atls/extensions.c +++ b/pkg/atls/extensions.c @@ -7,30 +7,27 @@ #include #include -extern int callVerificationValidationCallback(uintptr_t callbackHandle, const u_char* attReport, int attReportSize, const u_char* repData); -extern u_char* callFetchAttestationCallback(uintptr_t callbackHandle, const u_char* reportDataByte, int* outlen); +extern int callVerificationValidationCallback(uintptr_t callbackHandle, const u_char* pubKey, int pubKeyLen, const u_char* quote, int quoteSize, const u_char* teeNonce, const u_char* nonce); +extern u_char* callFetchAttestationCallback(uintptr_t callbackHandle, const u_char* pubKey, int pubKeyLen, const u_char* teeNonceByte, const u_char* vTPMNonceByte, unsigned long* outlen); extern uintptr_t validationVerificationCallback(int teeType); extern uintptr_t fetchAttestationCallback(int teeType); -int triggerVerificationValidationCallback(uintptr_t callbackHandle, u_char *attestationReport, int reportSize, u_char *reportData) { - if (attestationReport == NULL || reportData == NULL) { - fprintf(stderr, "attestation data and report data cannot be NULL\n"); +int triggerVerificationValidationCallback(uintptr_t callbackHandle, u_char* pub_key, int pub_key_len, u_char *quote, int quote_size, u_char *tee_nonce, u_char *nonce) { + if (quote == NULL || nonce == NULL || tee_nonce == NULL || pub_key == NULL) { + fprintf(stderr, "attestation and noce and public key cannot be NULL\n"); return -1; } - - return callVerificationValidationCallback(callbackHandle, attestationReport, reportSize, reportData); + return callVerificationValidationCallback(callbackHandle, pub_key, pub_key_len, quote, quote_size, tee_nonce, nonce); } -u_char* triggerFetchAttestationCallback(uintptr_t callbackHandle, char *reportData) { - int outlen = REPORT_DATA_SIZE; - - if(reportData == NULL) { +u_char* triggerFetchAttestationCallback(uintptr_t callback_handle, u_char* pub_key, int pub_key_len, char *tee_nonce, char *vtpm_nonce, unsigned long *outlen) { + if(tee_nonce == NULL || vtpm_nonce == NULL) { fprintf(stderr, "Report data cannot be NULL"); return NULL; } - return callFetchAttestationCallback(callbackHandle, reportData, &outlen); + return callFetchAttestationCallback(callback_handle, pub_key, pub_key_len, tee_nonce, vtpm_nonce, outlen); } int check_sev_snp() { @@ -47,46 +44,6 @@ int check_sev_snp() { return 1; } -int compute_sha256_of_public_key_nonce(X509 *cert, u_char *nonce, u_char *hash) { - EVP_PKEY *pkey = NULL; - u_char *pubkey_buf = NULL; - u_char *concatinated = NULL; - int pubkey_len = 0; - int totla_len = 0; - - pkey = X509_get_pubkey(cert); - if (pkey == NULL) { - fprintf(stderr, "Failed to extract public key from certificate\n"); - return 0; - } - - pubkey_len = i2d_PUBKEY(pkey, &pubkey_buf); - if (pubkey_len <= 0) { - fprintf(stderr, "Failed to convert public key to DER format\n"); - EVP_PKEY_free(pkey); - return -1; - } - - totla_len = pubkey_len + CLIENT_RANDOM_SIZE; - concatinated = (u_char*)malloc(totla_len); - if (concatinated == NULL) { - perror("failed to allocate memory"); - return -1; - } - memcpy(concatinated, nonce, CLIENT_RANDOM_SIZE); - memcpy(concatinated + CLIENT_RANDOM_SIZE, pubkey_buf, pubkey_len); - - // Compute the SHA-512 hash of the DER-encoded public key and the random nonce - SHA512(concatinated, totla_len, hash); - - // Clean up - EVP_PKEY_free(pkey); - OPENSSL_free(pubkey_buf); - free(concatinated); - - return 0; // Success -} - /* Evidence request extension - Contains a random nonce that goes into the attestation report @@ -121,9 +78,14 @@ int evidence_request_ext_add_cb(SSL *s, unsigned int ext_type, } if (ext_data != NULL) { - if (RAND_bytes(ext_data->er.data, CLIENT_RANDOM_SIZE) != 1) { - perror("could not generate random bytes, will use SSL client random"); - SSL_get_client_random(s, ext_data->er.data, CLIENT_RANDOM_SIZE); + if (RAND_bytes(ext_data->er.vtpm_nonce, CLIENT_RANDOM_SIZE) != 1) { + perror("could not generate random bytes for vtpm nonce, will use SSL client random"); + SSL_get_client_random(s, ext_data->er.vtpm_nonce, CLIENT_RANDOM_SIZE); + } + + if (RAND_bytes(ext_data->er.tee_nonce, CLIENT_RANDOM_SIZE) != 1) { + perror("could not generate random bytes for tee nonce, will use SSL client random"); + SSL_get_client_random(s, ext_data->er.tee_nonce, CLIENT_RANDOM_SIZE); } } else { fprintf(stderr, "add_arg is NULL\n"); @@ -132,7 +94,8 @@ int evidence_request_ext_add_cb(SSL *s, unsigned int ext_type, return -1; } - memcpy(er->data, ext_data->er.data, CLIENT_RANDOM_SIZE); + memcpy(er->vtpm_nonce, ext_data->er.vtpm_nonce, CLIENT_RANDOM_SIZE); + memcpy(er->tee_nonce, ext_data->er.tee_nonce, CLIENT_RANDOM_SIZE); er->tee_type = AMD_TEE; ext_data->er.tee_type = AMD_TEE; @@ -201,7 +164,8 @@ int evidence_request_ext_parse_cb(SSL *s, unsigned int ext_type, evidence_request *er = (evidence_request*)in; if (ext_data != NULL) { - memcpy(ext_data->er.data, er->data, CLIENT_RANDOM_SIZE); + memcpy(ext_data->er.vtpm_nonce, er->vtpm_nonce, CLIENT_RANDOM_SIZE); + memcpy(ext_data->er.tee_nonce, er->tee_nonce, CLIENT_RANDOM_SIZE); ext_data->er.tee_type = er->tee_type; } else { fprintf(stderr, "parse_arg is NULL\n"); @@ -238,7 +202,7 @@ int evidence_request_ext_parse_cb(SSL *s, unsigned int ext_type, /* Attestation Certificate extension - Contains the attestation report - - The attestation report contains the hash of the nonce and the Public Key of the x.509 Agent certificate + - The attestation report contains the hash of the nonce, the Public Key of the x.509 Agent certificate, and the vTPM AK */ void attestation_certificate_ext_free_cb(SSL *s, unsigned int ext_type, unsigned int context, @@ -263,40 +227,48 @@ int attestation_certificate_ext_add_cb(SSL *s, unsigned int ext_type, { tls_extension_data *ext_data = (tls_extension_data*)add_arg; if (ext_data != NULL) { - u_char *attestation_report; - u_char *hash = (u_char*)malloc(REPORT_DATA_SIZE*sizeof(u_char)); - - if (hash == NULL) { - perror("could not allocate memory"); - *al = SSL_AD_INTERNAL_ERROR; - return -1; - } + u_char *quote; + size_t len = 0; + EVP_PKEY *pkey = NULL; + u_char *pubkey_buf = NULL; + int pubkey_len = 0; + if (x != NULL) { - int ret = compute_sha256_of_public_key_nonce(x, ext_data->er.data, hash); - if (ret != 0) { - fprintf(stderr, "error while calculating hash\n"); - free(hash); - *al = SSL_AD_INTERNAL_ERROR; + pkey = X509_get_pubkey(x); + if (pkey == NULL) { + fprintf(stderr, "Failed to extract public key from certificate\n"); + return -1; + } + + pubkey_len = i2d_PUBKEY(pkey, &pubkey_buf); + if (pubkey_len <= 0) { + fprintf(stderr, "Failed to convert public key to DER format\n"); + EVP_PKEY_free(pkey); return -1; } } else { fprintf(stderr, "agent certificate must be used for aTLS\n"); - free(hash); *al = SSL_AD_INTERNAL_ERROR; return -1; } - attestation_report = triggerFetchAttestationCallback(ext_data->fetch_attestation_handler, hash); - if (attestation_report == NULL) { + quote = triggerFetchAttestationCallback(ext_data->fetch_attestation_handler, pubkey_buf, pubkey_len, ext_data->er.tee_nonce, ext_data->er.vtpm_nonce, &len); + if (quote == NULL) { fprintf(stderr, "attestation report is NULL\n"); *al = SSL_AD_INTERNAL_ERROR; + EVP_PKEY_free(pkey); + OPENSSL_free(pubkey_buf); return -1; } - free(hash); - *out = attestation_report; - *outlen = ATTESTATION_REPORT_SIZE; + EVP_PKEY_free(pkey); + OPENSSL_free(pubkey_buf); + + fprintf(stderr, "QUOTE_SIZE: %ld\n", strlen(quote)); + + *out = quote; + *outlen = len; return 1; } else { fprintf(stderr, "add_arg is NULL\n"); @@ -329,34 +301,41 @@ int attestation_certificate_ext_parse_cb(SSL *s, unsigned int ext_type, tls_extension_data *ext_data = (tls_extension_data*)parse_arg; if (ext_data != NULL) { - char *attestation_report = (char*)malloc(ATTESTATION_REPORT_SIZE*sizeof(char)); - u_char *hash = (u_char*)malloc(REPORT_DATA_SIZE*sizeof(u_char)); + char *quote = (char*)malloc(inlen*sizeof(char)); + EVP_PKEY *pkey = NULL; + u_char *pubkey_buf = NULL; + int pubkey_len = 0; int res = 0; - if (hash == NULL || attestation_report == NULL) { + if (quote == NULL) { perror("could not allocate memory"); - - if (hash != NULL) free(hash); - if (attestation_report != NULL) free(attestation_report); - return 0; } - if (compute_sha256_of_public_key_nonce(x, ext_data->er.data, hash) != 0) { - fprintf(stderr, "calculating hash failed\n"); - free(attestation_report); - free(hash); - return 0; + pkey = X509_get_pubkey(x); + if (pkey == NULL) { + fprintf(stderr, "Failed to extract public key from certificate\n"); + return -1; } - - memcpy(attestation_report, in, inlen); - - res = triggerVerificationValidationCallback(ext_data->verification_validation_handler, - attestation_report, - ATTESTATION_REPORT_SIZE, - hash); - free(attestation_report); - free(hash); + + pubkey_len = i2d_PUBKEY(pkey, &pubkey_buf); + if (pubkey_len <= 0) { + fprintf(stderr, "Failed to convert public key to DER format\n"); + EVP_PKEY_free(pkey); + return -1; + } + memcpy(quote, in, inlen); + + res = triggerVerificationValidationCallback(ext_data->verification_validation_handler, + pubkey_buf, + pubkey_len, + quote, + inlen, + (u_char*)&ext_data->er.tee_nonce, + (u_char*)&ext_data->er.vtpm_nonce); + free(quote); + EVP_PKEY_free(pkey); + OPENSSL_free(pubkey_buf); if (res != 0) { fprintf(stderr, "verification and validation failed, aborting connection\n"); diff --git a/pkg/atls/extensions.h b/pkg/atls/extensions.h index 5fb0d230..ae8f93d4 100644 --- a/pkg/atls/extensions.h +++ b/pkg/atls/extensions.h @@ -6,7 +6,6 @@ #define EVIDENCE_REQUEST_HELLO_EXTENSION_TYPE 65 #define ATTESTATION_CERTIFICATE_EXTENSION_TYPE 66 -#define ATTESTATION_REPORT_SIZE 0x4A0 #define REPORT_DATA_SIZE 64 #define CLIENT_RANDOM_SIZE 32 #define TLS_CLIENT_CTX 0 @@ -19,7 +18,8 @@ typedef struct evidence_request { int tee_type; - char data[CLIENT_RANDOM_SIZE]; + char vtpm_nonce[CLIENT_RANDOM_SIZE]; + char tee_nonce[CLIENT_RANDOM_SIZE]; } evidence_request; typedef struct tls_extension_data @@ -32,7 +32,7 @@ typedef struct tls_extension_data typedef struct tls_server_connection { int server_fd; - char* cert; + char* cert; int cert_len; char* key; int key_len; diff --git a/pkg/attestation/quoteprovider/embed.go b/pkg/attestation/quoteprovider/embed.go index c92f797d..7b368c1a 100644 --- a/pkg/attestation/quoteprovider/embed.go +++ b/pkg/attestation/quoteprovider/embed.go @@ -22,12 +22,12 @@ var _ client.QuoteProvider = (*embeddedQuoteProvider)(nil) type embeddedQuoteProvider struct { } -func GetQuoteProvider() (client.QuoteProvider, error) { +func GetLeveledQuoteProvider() (client.QuoteProvider, error) { return &embeddedQuoteProvider{}, nil } -// GetQuote returns the SEV quote for the given report data. -func (e *embeddedQuoteProvider) GetRawQuote(reportData [64]byte) ([]byte, error) { +// GetRawQuoteAtLevel returns the SEV quote for the given report data and VMPL. +func (e *embeddedQuoteProvider) GetRawQuoteAtLevel(reportData [64]byte, vmpl uint) ([]byte, error) { return cocosai.EmbeddedAttestation, nil } diff --git a/pkg/attestation/quoteprovider/mocks/QuoteProvider.go b/pkg/attestation/quoteprovider/mocks/QuoteProvider.go index 179e2e01..0c636159 100644 --- a/pkg/attestation/quoteprovider/mocks/QuoteProvider.go +++ b/pkg/attestation/quoteprovider/mocks/QuoteProvider.go @@ -10,42 +10,42 @@ import ( mock "github.com/stretchr/testify/mock" ) -// QuoteProvider is an autogenerated mock type for the QuoteProvider type -type QuoteProvider struct { +// LeveledQuoteProvider is an autogenerated mock type for the LeveledQuoteProvider type +type LeveledQuoteProvider struct { mock.Mock } -type QuoteProvider_Expecter struct { +type LeveledQuoteProvider_Expecter struct { mock *mock.Mock } -func (_m *QuoteProvider) EXPECT() *QuoteProvider_Expecter { - return &QuoteProvider_Expecter{mock: &_m.Mock} +func (_m *LeveledQuoteProvider) EXPECT() *LeveledQuoteProvider_Expecter { + return &LeveledQuoteProvider_Expecter{mock: &_m.Mock} } -// GetRawQuote provides a mock function with given fields: reportData -func (_m *QuoteProvider) GetRawQuote(reportData [64]byte) ([]uint8, error) { - ret := _m.Called(reportData) +// GetRawQuoteAtLevel provides a mock function with given fields: reportData, vmpl +func (_m *LeveledQuoteProvider) GetRawQuoteAtLevel(reportData [64]byte, vmpl uint) ([]uint8, error) { + ret := _m.Called(reportData, vmpl) if len(ret) == 0 { - panic("no return value specified for GetRawQuote") + panic("no return value specified for GetRawQuoteAtLevel") } var r0 []uint8 var r1 error - if rf, ok := ret.Get(0).(func([64]byte) ([]uint8, error)); ok { - return rf(reportData) + if rf, ok := ret.Get(0).(func([64]byte, uint) ([]uint8, error)); ok { + return rf(reportData, vmpl) } - if rf, ok := ret.Get(0).(func([64]byte) []uint8); ok { - r0 = rf(reportData) + if rf, ok := ret.Get(0).(func([64]byte, uint) []uint8); ok { + r0 = rf(reportData, vmpl) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]uint8) } } - if rf, ok := ret.Get(1).(func([64]byte) error); ok { - r1 = rf(reportData) + if rf, ok := ret.Get(1).(func([64]byte, uint) error); ok { + r1 = rf(reportData, vmpl) } else { r1 = ret.Error(1) } @@ -53,36 +53,37 @@ func (_m *QuoteProvider) GetRawQuote(reportData [64]byte) ([]uint8, error) { return r0, r1 } -// QuoteProvider_GetRawQuote_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRawQuote' -type QuoteProvider_GetRawQuote_Call struct { +// LeveledQuoteProvider_GetRawQuoteAtLevel_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRawQuoteAtLevel' +type LeveledQuoteProvider_GetRawQuoteAtLevel_Call struct { *mock.Call } -// GetRawQuote is a helper method to define mock.On call +// GetRawQuoteAtLevel is a helper method to define mock.On call // - reportData [64]byte -func (_e *QuoteProvider_Expecter) GetRawQuote(reportData interface{}) *QuoteProvider_GetRawQuote_Call { - return &QuoteProvider_GetRawQuote_Call{Call: _e.mock.On("GetRawQuote", reportData)} +// - vmpl uint +func (_e *LeveledQuoteProvider_Expecter) GetRawQuoteAtLevel(reportData interface{}, vmpl interface{}) *LeveledQuoteProvider_GetRawQuoteAtLevel_Call { + return &LeveledQuoteProvider_GetRawQuoteAtLevel_Call{Call: _e.mock.On("GetRawQuoteAtLevel", reportData, vmpl)} } -func (_c *QuoteProvider_GetRawQuote_Call) Run(run func(reportData [64]byte)) *QuoteProvider_GetRawQuote_Call { +func (_c *LeveledQuoteProvider_GetRawQuoteAtLevel_Call) Run(run func(reportData [64]byte, vmpl uint)) *LeveledQuoteProvider_GetRawQuoteAtLevel_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].([64]byte)) + run(args[0].([64]byte), args[1].(uint)) }) return _c } -func (_c *QuoteProvider_GetRawQuote_Call) Return(_a0 []uint8, _a1 error) *QuoteProvider_GetRawQuote_Call { +func (_c *LeveledQuoteProvider_GetRawQuoteAtLevel_Call) Return(_a0 []uint8, _a1 error) *LeveledQuoteProvider_GetRawQuoteAtLevel_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *QuoteProvider_GetRawQuote_Call) RunAndReturn(run func([64]byte) ([]uint8, error)) *QuoteProvider_GetRawQuote_Call { +func (_c *LeveledQuoteProvider_GetRawQuoteAtLevel_Call) RunAndReturn(run func([64]byte, uint) ([]uint8, error)) *LeveledQuoteProvider_GetRawQuoteAtLevel_Call { _c.Call.Return(run) return _c } // IsSupported provides a mock function with given fields: -func (_m *QuoteProvider) IsSupported() bool { +func (_m *LeveledQuoteProvider) IsSupported() bool { ret := _m.Called() if len(ret) == 0 { @@ -99,35 +100,35 @@ func (_m *QuoteProvider) IsSupported() bool { return r0 } -// QuoteProvider_IsSupported_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsSupported' -type QuoteProvider_IsSupported_Call struct { +// LeveledQuoteProvider_IsSupported_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsSupported' +type LeveledQuoteProvider_IsSupported_Call struct { *mock.Call } // IsSupported is a helper method to define mock.On call -func (_e *QuoteProvider_Expecter) IsSupported() *QuoteProvider_IsSupported_Call { - return &QuoteProvider_IsSupported_Call{Call: _e.mock.On("IsSupported")} +func (_e *LeveledQuoteProvider_Expecter) IsSupported() *LeveledQuoteProvider_IsSupported_Call { + return &LeveledQuoteProvider_IsSupported_Call{Call: _e.mock.On("IsSupported")} } -func (_c *QuoteProvider_IsSupported_Call) Run(run func()) *QuoteProvider_IsSupported_Call { +func (_c *LeveledQuoteProvider_IsSupported_Call) Run(run func()) *LeveledQuoteProvider_IsSupported_Call { _c.Call.Run(func(args mock.Arguments) { run() }) return _c } -func (_c *QuoteProvider_IsSupported_Call) Return(_a0 bool) *QuoteProvider_IsSupported_Call { +func (_c *LeveledQuoteProvider_IsSupported_Call) Return(_a0 bool) *LeveledQuoteProvider_IsSupported_Call { _c.Call.Return(_a0) return _c } -func (_c *QuoteProvider_IsSupported_Call) RunAndReturn(run func() bool) *QuoteProvider_IsSupported_Call { +func (_c *LeveledQuoteProvider_IsSupported_Call) RunAndReturn(run func() bool) *LeveledQuoteProvider_IsSupported_Call { _c.Call.Return(run) return _c } // Product provides a mock function with given fields: -func (_m *QuoteProvider) Product() *sevsnp.SevProduct { +func (_m *LeveledQuoteProvider) Product() *sevsnp.SevProduct { ret := _m.Called() if len(ret) == 0 { @@ -146,40 +147,40 @@ func (_m *QuoteProvider) Product() *sevsnp.SevProduct { return r0 } -// QuoteProvider_Product_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Product' -type QuoteProvider_Product_Call struct { +// LeveledQuoteProvider_Product_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Product' +type LeveledQuoteProvider_Product_Call struct { *mock.Call } // Product is a helper method to define mock.On call -func (_e *QuoteProvider_Expecter) Product() *QuoteProvider_Product_Call { - return &QuoteProvider_Product_Call{Call: _e.mock.On("Product")} +func (_e *LeveledQuoteProvider_Expecter) Product() *LeveledQuoteProvider_Product_Call { + return &LeveledQuoteProvider_Product_Call{Call: _e.mock.On("Product")} } -func (_c *QuoteProvider_Product_Call) Run(run func()) *QuoteProvider_Product_Call { +func (_c *LeveledQuoteProvider_Product_Call) Run(run func()) *LeveledQuoteProvider_Product_Call { _c.Call.Run(func(args mock.Arguments) { run() }) return _c } -func (_c *QuoteProvider_Product_Call) Return(_a0 *sevsnp.SevProduct) *QuoteProvider_Product_Call { +func (_c *LeveledQuoteProvider_Product_Call) Return(_a0 *sevsnp.SevProduct) *LeveledQuoteProvider_Product_Call { _c.Call.Return(_a0) return _c } -func (_c *QuoteProvider_Product_Call) RunAndReturn(run func() *sevsnp.SevProduct) *QuoteProvider_Product_Call { +func (_c *LeveledQuoteProvider_Product_Call) RunAndReturn(run func() *sevsnp.SevProduct) *LeveledQuoteProvider_Product_Call { _c.Call.Return(run) return _c } -// NewQuoteProvider creates a new instance of QuoteProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// NewLeveledQuoteProvider creates a new instance of LeveledQuoteProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. -func NewQuoteProvider(t interface { +func NewLeveledQuoteProvider(t interface { mock.TestingT Cleanup(func()) -}) *QuoteProvider { - mock := &QuoteProvider{} +}) *LeveledQuoteProvider { + mock := &LeveledQuoteProvider{} mock.Mock.Test(t) t.Cleanup(func() { mock.AssertExpectations(t) }) diff --git a/pkg/attestation/quoteprovider/sev.go b/pkg/attestation/quoteprovider/sev.go index fc8bf482..dc63e08a 100644 --- a/pkg/attestation/quoteprovider/sev.go +++ b/pkg/attestation/quoteprovider/sev.go @@ -14,7 +14,6 @@ import ( "time" "github.com/absmach/magistrala/pkg/errors" - "github.com/google/go-sev-guest/abi" "github.com/google/go-sev-guest/client" "github.com/google/go-sev-guest/proto/check" "github.com/google/go-sev-guest/proto/sevsnp" @@ -32,6 +31,7 @@ const ( reportDataSize = 64 sevProductNameMilan = "Milan" sevProductNameGenoa = "Genoa" + sevVMPL = 2 ) var ( @@ -42,7 +42,6 @@ var ( var ( errProductLine = errors.New(fmt.Sprintf("product name must be %s or %s", sevProductNameMilan, sevProductNameGenoa)) - errReportSize = errors.New("attestation report size mismatch") errAttVerification = errors.New("attestation verification failed") errAttValidation = errors.New("attestation validation failed") ) @@ -138,38 +137,28 @@ func validateReport(attestationPB *sevsnp.Attestation, cfg *check.Config) error return nil } -func GetQuoteProvider() (client.QuoteProvider, error) { - return client.GetQuoteProvider() +func GetLeveledQuoteProvider() (client.LeveledQuoteProvider, error) { + return client.GetLeveledQuoteProvider() } -func VerifyAttestationReportTLS(attestationBytes []byte, reportData []byte) error { +func VerifyAttestationReportTLS(attestationPB *sevsnp.Attestation, reportData []byte) error { config, err := copyConfig(&AttConfigurationSEVSNP) if err != nil { return errors.Wrap(fmt.Errorf("failed to create a copy of attestation policy"), err) } config.Policy.ReportData = reportData[:] - return VerifyAndValidate(attestationBytes, config) + return VerifyAndValidate(attestationPB, config) } -func VerifyAndValidate(attestationReport []byte, cfg *check.Config) error { +func VerifyAndValidate(attestationPB *sevsnp.Attestation, cfg *check.Config) error { logger.Init("", false, false, io.Discard) - if len(attestationReport) < attestationReportSize { - return errReportSize - } - attestationBytes := attestationReport[:attestationReportSize] - - attestationPB, err := abi.ReportCertsToProto(attestationBytes) - if err != nil { - return fmt.Errorf("failed to convert attestation bytes to struct %v", errors.Wrap(errAttVerification, err)) - } - - if err = verifyReport(attestationPB, cfg); err != nil { + if err := verifyReport(attestationPB, cfg); err != nil { return err } - if err = validateReport(attestationPB, cfg); err != nil { + if err := validateReport(attestationPB, cfg); err != nil { return err } @@ -179,7 +168,7 @@ func VerifyAndValidate(attestationReport []byte, cfg *check.Config) error { func FetchAttestation(reportDataSlice []byte) ([]byte, error) { var reportData [reportDataSize]byte - qp, err := GetQuoteProvider() + qp, err := GetLeveledQuoteProvider() if err != nil { return []byte{}, fmt.Errorf("could not get quote provider") } @@ -189,7 +178,7 @@ func FetchAttestation(reportDataSlice []byte) ([]byte, error) { } copy(reportData[:], reportDataSlice) - rawQuote, err := qp.GetRawQuote(reportData) + rawQuote, err := qp.GetRawQuoteAtLevel(reportData, sevVMPL) if err != nil { return []byte{}, fmt.Errorf("failed to get raw quote") } diff --git a/pkg/attestation/quoteprovider/sev_test.go b/pkg/attestation/quoteprovider/sev_test.go index 6f2e4250..e5b6c667 100644 --- a/pkg/attestation/quoteprovider/sev_test.go +++ b/pkg/attestation/quoteprovider/sev_test.go @@ -20,11 +20,6 @@ import ( "google.golang.org/protobuf/encoding/protojson" ) -const ( - measurementOffset = 0x90 - signatureOffset = 0x2A0 -) - func TestFillInAttestationLocal(t *testing.T) { tempDir, err := os.MkdirTemp("", "test_home") require.NoError(t, err) @@ -76,18 +71,18 @@ func TestFillInAttestationLocal(t *testing.T) { } func TestVerifyAttestationReportSuccess(t *testing.T) { - file, reportData := prepareForTestVerifyAttestationReport(t) + attestationPB, reportData := prepareForTestVerifyAttestationReport(t) tests := []struct { name string - attestationReport []byte + attestationReport *sevsnp.Attestation reportData []byte goodProduct int err error }{ { name: "Valid attestation, validation and verification is performed succsessfully", - attestationReport: file, + attestationReport: attestationPB, reportData: reportData, goodProduct: 1, err: nil, @@ -103,20 +98,20 @@ func TestVerifyAttestationReportSuccess(t *testing.T) { } func TestVerifyAttestationReportMalformedSignature(t *testing.T) { - file, reportData := prepareForTestVerifyAttestationReport(t) + attestationPB, reportData := prepareForTestVerifyAttestationReport(t) // Change random data so in the signature so the signature failes - file[signatureOffset] = file[signatureOffset] ^ 0x01 + attestationPB.Report.Signature[0] = attestationPB.Report.Signature[0] ^ 0x01 tests := []struct { name string - attestationReport []byte + attestationReport *sevsnp.Attestation reportData []byte err error }{ { name: "Valid attestation, distorted signature", - attestationReport: file, + attestationReport: attestationPB, reportData: reportData, err: errAttVerification, }, @@ -131,17 +126,17 @@ func TestVerifyAttestationReportMalformedSignature(t *testing.T) { } func TestVerifyAttestationReportUnknownProduct(t *testing.T) { - file, reportData := prepareForTestVerifyAttestationReport(t) + attestationPB, reportData := prepareForTestVerifyAttestationReport(t) tests := []struct { name string - attestationReport []byte + attestationReport *sevsnp.Attestation reportData []byte err error }{ { name: "Valid attestation, unknown product", - attestationReport: file, + attestationReport: attestationPB, reportData: reportData, err: errProductLine, }, @@ -158,20 +153,20 @@ func TestVerifyAttestationReportUnknownProduct(t *testing.T) { } func TestVerifyAttestationReportMalformedPolicy(t *testing.T) { - file, reportData := prepareForTestVerifyAttestationReport(t) + attestationPB, reportData := prepareForTestVerifyAttestationReport(t) // Change random data in the measurement so the measurement does not match - file[measurementOffset] = file[measurementOffset] ^ 0x01 + attestationPB.Report.Measurement[0] = attestationPB.Report.Measurement[0] ^ 0x01 tests := []struct { name string - attestationReport []byte + attestationReport *sevsnp.Attestation reportData []byte err error }{ { name: "Valid attestation, malformed policy (measurement)", - attestationReport: file, + attestationReport: attestationPB, reportData: reportData, err: errAttVerification, }, @@ -185,17 +180,17 @@ func TestVerifyAttestationReportMalformedPolicy(t *testing.T) { } } -func prepareForTestVerifyAttestationReport(t *testing.T) ([]byte, []byte) { +func prepareForTestVerifyAttestationReport(t *testing.T) (*sevsnp.Attestation, []byte) { file, err := os.ReadFile("../../../attestation.bin") require.NoError(t, err) - rr, err := abi.ReportCertsToProto(file) - require.NoError(t, err) - if len(file) < attestationReportSize { file = append(file, make([]byte, attestationReportSize-len(file))...) } + rr, err := abi.ReportCertsToProto(file) + require.NoError(t, err) + AttConfigurationSEVSNP = check.Config{Policy: &check.Policy{}, RootOfTrust: &check.RootOfTrust{}} attestationPolicyFile, err := os.ReadFile("../../../scripts/attestation_policy/attestation_policy.json") @@ -212,5 +207,5 @@ func prepareForTestVerifyAttestationReport(t *testing.T) ([]byte, []byte) { AttConfigurationSEVSNP.Policy.ReportIdMa = rr.Report.ReportIdMa AttConfigurationSEVSNP.RootOfTrust.ProductLine = sevProductNameMilan - return file, rr.Report.ReportData + return rr, rr.Report.ReportData } diff --git a/pkg/attestation/vtpm/dummy.go b/pkg/attestation/vtpm/dummy.go new file mode 100644 index 00000000..39f13f73 --- /dev/null +++ b/pkg/attestation/vtpm/dummy.go @@ -0,0 +1,27 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package vtpm + +type DummyRWC struct{} + +// Read fills p with byte(len(p)) and returns len(p). +func (l *DummyRWC) Read(p []byte) (int, error) { + n := len(p) + // Fill each byte in p with the value of n as a byte. + for i := range p { + p[i] = byte(n) + } + return n, nil +} + +// Write simply returns len(p) indicating that all bytes were written. +func (l *DummyRWC) Write(p []byte) (int, error) { + // In this simple implementation, we ignore the data. + return len(p), nil +} + +// Close does nothing. +func (l *DummyRWC) Close() error { + return nil +} diff --git a/pkg/attestation/vtpm/vtpm.go b/pkg/attestation/vtpm/vtpm.go new file mode 100644 index 00000000..8bb4e138 --- /dev/null +++ b/pkg/attestation/vtpm/vtpm.go @@ -0,0 +1,230 @@ +// Copyright (c) Ultraviolet +// SPDX-License-Identifier: Apache-2.0 + +package vtpm + +import ( + "crypto" + "crypto/x509" + "fmt" + "io" + "os" + + "github.com/google/go-sev-guest/abi" + "github.com/google/go-tpm-tools/client" + pb "github.com/google/go-tpm-tools/proto/attest" + "github.com/google/go-tpm-tools/server" + "github.com/google/go-tpm/legacy/tpm2" + "github.com/google/go-tpm/tpmutil" + "github.com/ultravioletrs/cocos/pkg/attestation/quoteprovider" + "golang.org/x/crypto/sha3" + "google.golang.org/protobuf/proto" +) + +const ( + eventLog = "/sys/kernel/security/tpm0/binary_bios_measurements" + Nonce = 32 + PCR15 = 15 +) + +var ExternalTPM io.ReadWriteCloser + +type tpmWrapper struct { + io.ReadWriteCloser +} + +func (et tpmWrapper) EventLog() ([]byte, error) { + return os.ReadFile(eventLog) +} + +func OpenTpm() (io.ReadWriteCloser, error) { + if ExternalTPM != nil { + return tpmWrapper{ExternalTPM}, nil + } + + tw := tpmWrapper{} + var err error + + tw.ReadWriteCloser, err = tpm2.OpenTPM("/dev/tpmrm0") + if os.IsNotExist(err) { + tw.ReadWriteCloser, err = tpm2.OpenTPM("/dev/tpm0") + } + + return tw, err +} + +func ExtendPCR(pcrIndex int, value []byte) error { + rwc, err := OpenTpm() + if err != nil { + return err + } + defer rwc.Close() + + if err := tpm2.PCRExtend(rwc, tpmutil.Handle(pcrIndex), tpm2.AlgSHA256, value, ""); err != nil { + return err + } + + if err := tpm2.PCRExtend(rwc, tpmutil.Handle(pcrIndex), tpm2.AlgSHA384, value, ""); err != nil { + return err + } + + return nil +} + +func Attest(teeNonce []byte, vTPMNonce []byte, teeAttestaion bool) ([]byte, error) { + attestation, err := fetchVTPMQuote(vTPMNonce) + if err != nil { + return []byte{}, err + } + + if teeAttestaion { + attestation, err = addTEEAttestation(attestation, teeNonce) + if err != nil { + return []byte{}, err + } + } + + return marshalQuote(attestation) +} + +func FetchATLSQuote(pubKey, teeNonce, vTPMNonce []byte) ([]byte, error) { + attestation, err := fetchVTPMQuote(vTPMNonce) + if err != nil { + return []byte{}, err + } + + reportData, err := createTEEAttestationReportNonce(pubKey, attestation.GetAkPub(), teeNonce) + if err != nil { + return []byte{}, err + } + + attestation, err = addTEEAttestation(attestation, reportData) + if err != nil { + return []byte{}, err + } + + return marshalQuote(attestation) +} + +func VTPMVerify(quote []byte, pubKeyTLS []byte, teeNonce []byte, vtpmNonce []byte) error { + attestation := &pb.Attestation{} + + err := proto.Unmarshal(quote, attestation) + if err != nil { + return fmt.Errorf("fail to unmarshal quote: %v", err) + } + + ak := attestation.GetAkPub() + pub, err := tpm2.DecodePublic(ak) + if err != nil { + return err + } + + cryptoPub, err := pub.Key() + if err != nil { + return err + } + + reportData, err := createTEEAttestationReportNonce(pubKeyTLS, ak, teeNonce) + if err != nil { + return fmt.Errorf("fail to calculate report data: %v", err) + } + + if err := quoteprovider.VerifyAttestationReportTLS(attestation.GetSevSnpAttestation(), reportData); err != nil { + return fmt.Errorf("failed to verify TEE attestation report: %v", err) + } + + _, err = server.VerifyAttestation(attestation, server.VerifyOpts{Nonce: vtpmNonce, TrustedAKs: []crypto.PublicKey{cryptoPub}}) + if err != nil { + return fmt.Errorf("verifying attestation: %w", err) + } + + return nil +} + +func publicKeyToBytes(pubKey interface{}) ([]byte, error) { + derBytes, err := x509.MarshalPKIXPublicKey(pubKey) + if err != nil { + return nil, err + } + return derBytes, nil +} + +func createTEEAttestationReportNonce(pubKeyTLS []byte, ak []byte, nonce []byte) ([]byte, error) { + pub, err := tpm2.DecodePublic(ak) + if err != nil { + return []byte{}, err + } + + cryptoPub, err := pub.Key() + if err != nil { + return []byte{}, err + } + + pubKeyBytes, err := publicKeyToBytes(cryptoPub) + if err != nil { + return []byte{}, err + } + + reportData := append(append(pubKeyTLS, pubKeyBytes...), nonce...) + hash := sha3.Sum512(reportData) + + return hash[:], nil +} + +func marshalQuote(attestation *pb.Attestation) ([]byte, error) { + out, err := proto.Marshal(attestation) + if err != nil { + return []byte{}, fmt.Errorf("failed to marshal vTPM attestation report: %v", err) + } + + return out, nil +} + +func fetchVTPMQuote(nonce []byte) (*pb.Attestation, error) { + rwc, err := OpenTpm() + if err != nil { + return nil, err + } + defer rwc.Close() + + attestationKey, err := client.AttestationKeyRSA(rwc) + if err != nil { + return nil, fmt.Errorf("failed to create attestation key: %v", err) + } + defer attestationKey.Close() + + var fixedNonce [Nonce]byte + copy(fixedNonce[:], nonce) + attestOpts := client.AttestOpts{} + attestOpts.Nonce = fixedNonce[:] + + attestOpts.TCGEventLog, err = client.GetEventLog(rwc) + if err != nil { + return nil, fmt.Errorf("failed to retrieve TCG Event Log: %w", err) + } + + attestation, err := attestationKey.Attest(attestOpts) + if err != nil { + return nil, fmt.Errorf("failed to collect attestation report: %v", err) + } + + return attestation, nil +} + +func addTEEAttestation(attestation *pb.Attestation, nonce []byte) (*pb.Attestation, error) { + rawTeeAttestation, err := quoteprovider.FetchAttestation(nonce) + if err != nil { + return attestation, fmt.Errorf("failed to fetch TEE attestation report: %v", err) + } + + extReport, err := abi.ReportCertsToProto(rawTeeAttestation) + if err != nil { + return attestation, fmt.Errorf("failed to export the TEE report: %v", err) + } + attestation.TeeAttestation = &pb.Attestation_SevSnpAttestation{ + SevSnpAttestation: extReport, + } + + return attestation, nil +} diff --git a/pkg/sdk/agent.go b/pkg/sdk/agent.go index 7b8cbc68..5639c29a 100644 --- a/pkg/sdk/agent.go +++ b/pkg/sdk/agent.go @@ -26,11 +26,12 @@ type SDK interface { Algo(ctx context.Context, algorithm, requirements *os.File, privKey any) error Data(ctx context.Context, dataset *os.File, filename string, privKey any) error Result(ctx context.Context, privKey any, resultFile *os.File) error - Attestation(ctx context.Context, reportData [size64]byte, attestationFile *os.File) error + Attestation(ctx context.Context, reportData [size64]byte, nonce [size32]byte, attType int, attestationFile *os.File) error } const ( size64 = 64 + size32 = 32 algoProgressBarDescription = "Uploading algorithm" dataProgressBarDescription = "Uploading data" resultProgressDescription = "Downloading result" @@ -120,9 +121,11 @@ func (sdk *agentSDK) Result(ctx context.Context, privKey any, resultFile *os.Fil return pb.ReceiveResult(resultProgressDescription, fileSize, stream, resultFile) } -func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte, attestationFile *os.File) error { +func (sdk *agentSDK) Attestation(ctx context.Context, reportData [size64]byte, nonce [size32]byte, attType int, attestationFile *os.File) error { request := &agent.AttestationRequest{ - ReportData: reportData[:], + TeeNonce: reportData[:], + VtpmNonce: nonce[:], + Type: int32(attType), } stream, err := sdk.client.Attestation(ctx, request) diff --git a/pkg/sdk/agent_test.go b/pkg/sdk/agent_test.go index c3379722..c73b7cec 100644 --- a/pkg/sdk/agent_test.go +++ b/pkg/sdk/agent_test.go @@ -19,6 +19,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/ultravioletrs/cocos/agent" + "github.com/ultravioletrs/cocos/pkg/attestation/vtpm" "github.com/ultravioletrs/cocos/pkg/sdk" "golang.org/x/crypto/sha3" "google.golang.org/grpc" @@ -364,6 +365,7 @@ func TestAttestation(t *testing.T) { resultConsumer1Key, _ := generateKeys(t, "ed25519") reportData := make([]byte, 64) + nonce := make([]byte, 64) report := []byte{ 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, @@ -385,7 +387,8 @@ func TestAttestation(t *testing.T) { cases := []struct { name string userKey any - reportData [agent.ReportDataSize]byte + reportData [agent.Nonce]byte + nonce [vtpm.Nonce]byte response *agent.AttestationResponse svcRes []byte err error @@ -393,7 +396,8 @@ func TestAttestation(t *testing.T) { { name: "fetch attestation report successfully", userKey: resultConsumerKey, - reportData: [agent.ReportDataSize]byte(reportData), + reportData: [agent.Nonce]byte(reportData), + nonce: [vtpm.Nonce]byte(nonce), response: &agent.AttestationResponse{ File: report, }, @@ -403,7 +407,8 @@ func TestAttestation(t *testing.T) { { name: "fetch attestation report with different key type", userKey: resultConsumer1Key, - reportData: [agent.ReportDataSize]byte(reportData), + reportData: [agent.Nonce]byte(reportData), + nonce: [vtpm.Nonce]byte(nonce), response: &agent.AttestationResponse{ File: report, }, @@ -413,7 +418,8 @@ func TestAttestation(t *testing.T) { { name: "failed to fetch attestation report", userKey: resultConsumerKey, - reportData: [agent.ReportDataSize]byte(reportData), + reportData: [agent.Nonce]byte(reportData), + nonce: [vtpm.Nonce]byte(nonce), response: &agent.AttestationResponse{ File: []byte{}, }, @@ -422,7 +428,8 @@ func TestAttestation(t *testing.T) { { name: "invalid report data", userKey: resultConsumerKey, - reportData: [agent.ReportDataSize]byte{}, + reportData: [agent.Nonce]byte{}, + nonce: [vtpm.Nonce]byte(nonce), response: &agent.AttestationResponse{ File: []byte{}, }, @@ -442,7 +449,7 @@ func TestAttestation(t *testing.T) { os.Remove(file.Name()) }) - err = sdk.Attestation(context.Background(), tc.reportData, file) + err = sdk.Attestation(context.Background(), tc.reportData, tc.nonce, 0, file) require.NoError(t, file.Close()) diff --git a/pkg/sdk/mocks/sdk.go b/pkg/sdk/mocks/sdk.go index 47e5909a..1c0d93ea 100644 --- a/pkg/sdk/mocks/sdk.go +++ b/pkg/sdk/mocks/sdk.go @@ -74,17 +74,17 @@ func (_c *SDK_Algo_Call) RunAndReturn(run func(context.Context, *os.File, *os.Fi return _c } -// Attestation provides a mock function with given fields: ctx, reportData, attestationFile -func (_m *SDK) Attestation(ctx context.Context, reportData [64]byte, attestationFile *os.File) error { - ret := _m.Called(ctx, reportData, attestationFile) +// Attestation provides a mock function with given fields: ctx, reportData, nonce, attType, attestationFile +func (_m *SDK) Attestation(ctx context.Context, reportData [64]byte, nonce [32]byte, attType int, attestationFile *os.File) error { + ret := _m.Called(ctx, reportData, nonce, attType, attestationFile) if len(ret) == 0 { panic("no return value specified for Attestation") } var r0 error - if rf, ok := ret.Get(0).(func(context.Context, [64]byte, *os.File) error); ok { - r0 = rf(ctx, reportData, attestationFile) + if rf, ok := ret.Get(0).(func(context.Context, [64]byte, [32]byte, int, *os.File) error); ok { + r0 = rf(ctx, reportData, nonce, attType, attestationFile) } else { r0 = ret.Error(0) } @@ -100,14 +100,16 @@ type SDK_Attestation_Call struct { // Attestation is a helper method to define mock.On call // - ctx context.Context // - reportData [64]byte +// - nonce [32]byte +// - attType int // - attestationFile *os.File -func (_e *SDK_Expecter) Attestation(ctx interface{}, reportData interface{}, attestationFile interface{}) *SDK_Attestation_Call { - return &SDK_Attestation_Call{Call: _e.mock.On("Attestation", ctx, reportData, attestationFile)} +func (_e *SDK_Expecter) Attestation(ctx interface{}, reportData interface{}, nonce interface{}, attType interface{}, attestationFile interface{}) *SDK_Attestation_Call { + return &SDK_Attestation_Call{Call: _e.mock.On("Attestation", ctx, reportData, nonce, attType, attestationFile)} } -func (_c *SDK_Attestation_Call) Run(run func(ctx context.Context, reportData [64]byte, attestationFile *os.File)) *SDK_Attestation_Call { +func (_c *SDK_Attestation_Call) Run(run func(ctx context.Context, reportData [64]byte, nonce [32]byte, attType int, attestationFile *os.File)) *SDK_Attestation_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].([64]byte), args[2].(*os.File)) + run(args[0].(context.Context), args[1].([64]byte), args[2].([32]byte), args[3].(int), args[4].(*os.File)) }) return _c } @@ -117,7 +119,7 @@ func (_c *SDK_Attestation_Call) Return(_a0 error) *SDK_Attestation_Call { return _c } -func (_c *SDK_Attestation_Call) RunAndReturn(run func(context.Context, [64]byte, *os.File) error) *SDK_Attestation_Call { +func (_c *SDK_Attestation_Call) RunAndReturn(run func(context.Context, [64]byte, [32]byte, int, *os.File) error) *SDK_Attestation_Call { _c.Call.Return(run) return _c }