diff --git a/pkg/attest/attest.go b/pkg/attest/attest.go index 5dd82c48..56df0b9d 100644 --- a/pkg/attest/attest.go +++ b/pkg/attest/attest.go @@ -124,30 +124,22 @@ func (certState *CertState) Attest(maa MAA, runtimeDataBytes []byte, uvmInformat } if SNPReport.ReportedTCB != certState.Tcbm { - // TCB values still don't match, try retrieving the SNP report again - SNPReportBytes, inittimeDataBytes, err = GetSNPReport(uvmInformation.EncodedSecurityPolicy, runtimeDataBytes) + SNPReportBytes, inittimeDataBytes, vcekCertChain, err = certState.refreshSNPReportAndCertChain(uvmInformation.EncodedSecurityPolicy, runtimeDataBytes) if err != nil { - return "", errors.Wrapf(err, "failed to retrieve new attestation report") - } - - if err = SNPReport.DeserializeReport(SNPReportBytes); err != nil { - return "", errors.Wrapf(err, "failed to deserialize new attestation report") + return "", err } - - // refresh certs again - vcekCertChain, err = certState.RefreshCertChain(SNPReport) + } + } else { + //In case the initialCerts are not properly configured, we should fall back on fetching the vcekCert remotely + if uvmInformation.InitialCerts.VcekCert == "" || uvmInformation.InitialCerts.CertificateChain == "" { + SNPReportBytes, inittimeDataBytes, vcekCertChain, err = certState.refreshSNPReportAndCertChain(uvmInformation.EncodedSecurityPolicy, runtimeDataBytes) if err != nil { return "", err } - - // if no match after refreshing certs and attestation report, fail - if SNPReport.ReportedTCB != certState.Tcbm { - return "", errors.New(fmt.Sprintf("SNP reported TCB value: %d doesn't match Certificate TCB value: %d", SNPReport.ReportedTCB, certState.Tcbm)) - } + } else { + certString := uvmInformation.InitialCerts.VcekCert + uvmInformation.InitialCerts.CertificateChain + vcekCertChain = []byte(certString) } - } else { - certString := uvmInformation.InitialCerts.VcekCert + uvmInformation.InitialCerts.CertificateChain - vcekCertChain = []byte(certString) } uvmReferenceInfoBytes, err := base64.StdEncoding.DecodeString(uvmInformation.EncodedUvmReferenceInfo) @@ -164,3 +156,28 @@ func (certState *CertState) Attest(maa MAA, runtimeDataBytes []byte, uvmInformat return maaToken, nil } + +func (certState *CertState) refreshSNPReportAndCertChain(securityPolicy string, runtimeDataBytes []byte) ([]byte, []byte, []byte, error) { + var SNPReport SNPAttestationReport + SNPReportBytes, inittimeDataBytes, err := GetSNPReport(securityPolicy, runtimeDataBytes) + if err != nil { + return nil, nil, nil, errors.Wrapf(err, "failed to retrieve new attestation report") + } + + if err = SNPReport.DeserializeReport(SNPReportBytes); err != nil { + return nil, nil, nil, errors.Wrapf(err, "failed to deserialize new attestation report") + } + + // refresh certs again + vcekCertChain, err := certState.RefreshCertChain(SNPReport) + if err != nil { + return nil, nil, nil, err + } + + // if no match after refreshing certs and attestation report, fail + if SNPReport.ReportedTCB != certState.Tcbm { + return nil, nil, nil, errors.New(fmt.Sprintf("SNP reported TCB value: %d doesn't match Certificate TCB value: %d", SNPReport.ReportedTCB, certState.Tcbm)) + } + + return SNPReportBytes, inittimeDataBytes, vcekCertChain, nil +}