Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sa: GetRevokedCerts returns explicit shards too #7918

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 133 additions & 73 deletions sa/sa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"os"
"reflect"
"slices"
"sort"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -3353,120 +3354,179 @@ func TestGetRevokedCerts(t *testing.T) {
test.AssertEquals(t, count, 0)
}

func TestGetRevokedCertsByShard(t *testing.T) {
sa, _, cleanUp := initSA(t)
func TestGetRevokedCertsWithShard(t *testing.T) {
sa, fc, cleanUp := initSA(t)
defer cleanUp()

// Add a cert to the DB to test with. We use AddPrecertificate because it sets
// up the certificateStatus row we need. This particular cert has a notAfter
// date of Mar 6 2023, and we lie about its IssuerNameID to make things easy.
reg := createWorkingRegistration(t, sa)
eeCert, err := core.LoadCert("../test/hierarchy/ee-e1.cert.pem")
test.AssertNotError(t, err, "failed to load test cert")
_, err = sa.AddSerial(ctx, &sapb.AddSerialRequest{
RegID: reg.Id,
Serial: core.SerialToString(eeCert.SerialNumber),
Created: timestamppb.New(eeCert.NotBefore),
Expires: timestamppb.New(eeCert.NotAfter),
})
test.AssertNotError(t, err, "failed to add test serial")
_, err = sa.AddPrecertificate(ctx, &sapb.AddCertificateRequest{
Der: eeCert.Raw,
RegID: reg.Id,
Issued: timestamppb.New(eeCert.NotBefore),
IssuerNameID: 1,
})
test.AssertNotError(t, err, "failed to add test cert")

// Check that it worked.
status, err := sa.GetCertificateStatus(
ctx, &sapb.Serial{Serial: core.SerialToString(eeCert.SerialNumber)})
test.AssertNotError(t, err, "GetCertificateStatus failed")
test.AssertEquals(t, core.OCSPStatus(status.Status), core.OCSPStatusGood)
fc.Set(mustTime("2023-03-01 00:00"))

// Here's a little helper func we'll use to call GetRevokedCerts and count
// how many results it returned.
countRevokedCerts := func(req *sapb.GetRevokedCertsRequest) (int, error) {
// Make up an IssuerNameID to make things simpler.
issuerNameID := int64(834)

// Create a certificate and add it to the tables we need.
makeCert := func() *x509.Certificate {
_, cert := test.ThrowAwayCert(t, fc)
// We depend on specifics of the lifetime set by test.ThrowAwayCert, so verify.
lifetime := cert.NotAfter.Sub(cert.NotBefore)
if lifetime != 6*24*time.Hour {
t.Fatalf("cert lifetime: got %s, want 6 days", lifetime)
}
_, err := sa.AddSerial(ctx, &sapb.AddSerialRequest{
RegID: reg.Id,
Serial: core.SerialToString(cert.SerialNumber),
Created: timestamppb.New(cert.NotBefore),
Expires: timestamppb.New(cert.NotAfter),
})
if err != nil {
t.Fatalf("adding serial: %s", err)
}
_, err = sa.AddPrecertificate(ctx, &sapb.AddCertificateRequest{
Der: cert.Raw,
RegID: reg.Id,
Issued: timestamppb.New(cert.NotBefore),
IssuerNameID: issuerNameID,
})
if err != nil {
t.Fatalf("adding cert: %s", err)
}
status, err := sa.GetCertificateStatus(ctx, &sapb.Serial{Serial: core.SerialToString(cert.SerialNumber)})
if err != nil {
t.Fatalf("GetCertificateStatus: %s", err)
}
if status.Status != string(core.OCSPStatusGood) {
t.Fatalf("GetCertificateStatus for new cert: got %s, want %s", status.Status, core.OCSPStatusGood)
}
return cert
}

// Two certs issued at the same time, with the same expiration.
// eeCert1 will be revoked without an explicit ShardIdx.
// eeCert2 will be revoked _with_ an explicit ShardIdx.
eeCert1 := makeCert()
eeCert2 := makeCert()

// eeCert3 is issued two days after the others and will be revoked
// with the same explicit ShardIdx. It will show up in a different
// temporal shard than eeCert1 and eeCert2, because we are querying
// as if the shard width for CRLs is one day.
fc.Add(2 * 24 * time.Hour)
eeCert3 := makeCert()

// Here's a little helper func we'll use to call GetRevokedCerts and return
// a sorted list of serials.
getRevokedCerts := func(req *sapb.GetRevokedCertsRequest) []string {
stream := make(chan *corepb.CRLEntry)
mockServerStream := &fakeServerStream[corepb.CRLEntry]{output: stream}
var err error
go func() {
err = sa.GetRevokedCerts(req, mockServerStream)
close(stream)
}()
entriesReceived := 0
for range stream {
entriesReceived++
var serials []string
for e := range stream {
serials = append(serials, e.Serial)
}
return entriesReceived, err
if err != nil {
t.Fatalf("GetRevokedCerts(%+v): %s", req, err)
}
return serials
}

// The basic request covers a time range and shard that should include this certificate.
// The basic request covers a time range that includes eeCert1's and eeCert2's NotAfter,
// but excludes eeCert3's NotAfter.
// The ExpiresBefore field is set based on the 6-day lifetime of certs from test.ThrowAwayCert
basicRequest := &sapb.GetRevokedCertsRequest{
IssuerNameID: 1,
ShardIdx: 9,
IssuerNameID: issuerNameID,
ShardIdx: 97,
ExpiresAfter: mustTimestamp("2023-03-01 00:00"),
ExpiresBefore: mustTimestamp("2023-03-08 00:00"),
RevokedBefore: mustTimestamp("2023-04-01 00:00"),
}

// Nothing's been revoked yet. Count should be zero.
count, err := countRevokedCerts(basicRequest)
test.AssertNotError(t, err, "zero rows shouldn't result in error")
test.AssertEquals(t, count, 0)

// Revoke the certificate, providing the ShardIdx so it gets written into
// both the certificateStatus and revokedCertificates tables.
_, err = sa.RevokeCertificate(context.Background(), &sapb.RevokeCertificateRequest{
IssuerID: 1,
Serial: core.SerialToString(eeCert.SerialNumber),
Date: mustTimestamp("2023-01-01 00:00"),
Reason: 1,
Response: []byte{1, 2, 3},
ShardIdx: 9,
})
test.AssertNotError(t, err, "failed to revoke test cert")
// Nothing's been revoked yet. Should get no results.
serials := getRevokedCerts(basicRequest)
if len(serials) > 0 {
t.Errorf("GetRevokedCerts (before revocations) = %s, want []", serials)
}

// Check that it worked in the most basic way.
c, err := sa.dbMap.SelectNullInt(
ctx, "SELECT count(*) FROM revokedCertificates")
test.AssertNotError(t, err, "SELECT from revokedCertificates failed")
test.Assert(t, c.Valid, "SELECT from revokedCertificates got no result")
test.AssertEquals(t, c.Int64, int64(1))
revoke := func(cert *x509.Certificate, shardIdx int64) {
t.Logf("revoking %x with shardIdx %d", cert.SerialNumber, shardIdx)
_, err := sa.RevokeCertificate(context.Background(), &sapb.RevokeCertificateRequest{
IssuerID: issuerNameID,
Serial: core.SerialToString(cert.SerialNumber),
Date: mustTimestamp("2023-03-04 00:00"),
Reason: 1,
Response: []byte{1, 2, 3},
ShardIdx: shardIdx,
})
if err != nil {
t.Fatalf("sa.RevokeCertificate %s", err)
}
}

// Asking for revoked certs now should return one result.
count, err = countRevokedCerts(basicRequest)
test.AssertNotError(t, err, "normal usage shouldn't result in error")
test.AssertEquals(t, count, 1)
// First certificate: revoke without ShardIdx
revoke(eeCert1, 0)
// Second certificate: revoke with ShardIdx = 97.
revoke(eeCert2, 97)
// Third certificate: revoke with ShardIdx = 97.
// But note that the temporal shard is different from the other two.
revoke(eeCert3, 97)

// expectSerials registers an error if the provided serials don't match the serials
// of the provided certs (after sorting).
expectSerials := func(message string, serials []string, certs ...*x509.Certificate) {
t.Helper()
var expectedSerials []string
for _, c := range certs {
expectedSerials = append(expectedSerials, core.SerialToString(c.SerialNumber))
}
sort.Strings(expectedSerials)
sort.Strings(serials)
if !reflect.DeepEqual(serials, expectedSerials) {
t.Errorf("%s: want %s, got %s", message, expectedSerials, serials)
}
}
serials = getRevokedCerts(basicRequest)
expectSerials("GetRevokedCerts (after revocations)", serials, eeCert1, eeCert2, eeCert3)

// Asking for revoked certs from a different issuer should return zero results.
count, err = countRevokedCerts(&sapb.GetRevokedCertsRequest{
serials = getRevokedCerts(&sapb.GetRevokedCertsRequest{
IssuerNameID: 5678,
ShardIdx: basicRequest.ShardIdx,
ExpiresAfter: basicRequest.ExpiresAfter,
ExpiresBefore: basicRequest.ExpiresBefore,
RevokedBefore: basicRequest.RevokedBefore,
})
test.AssertNotError(t, err, "zero rows shouldn't result in error")
test.AssertEquals(t, count, 0)
expectSerials("GetRevokedCerts with nonexistent issuer", serials)

// Asking for revoked certs from a different shard should return zero results.
count, err = countRevokedCerts(&sapb.GetRevokedCertsRequest{
serials = getRevokedCerts(&sapb.GetRevokedCertsRequest{
IssuerNameID: basicRequest.IssuerNameID,
ShardIdx: 0,
ExpiresAfter: basicRequest.ExpiresAfter,
ExpiresBefore: basicRequest.ExpiresBefore,
RevokedBefore: basicRequest.RevokedBefore,
})
expectSerials("GetRevokedCerts with no shardIdx specified (temporal sharding only)", serials, eeCert1, eeCert2)

serials = getRevokedCerts(&sapb.GetRevokedCertsRequest{
IssuerNameID: basicRequest.IssuerNameID,
ShardIdx: 8,
ExpiresAfter: basicRequest.ExpiresAfter,
ExpiresBefore: basicRequest.ExpiresBefore,
RevokedBefore: basicRequest.RevokedBefore,
})
test.AssertNotError(t, err, "zero rows shouldn't result in error")
test.AssertEquals(t, count, 0)
expectSerials("GetRevokedCerts for explicit shard with no revocations (temporal sharding only)", serials, eeCert1, eeCert2)

// Asking for revoked certs with an old RevokedBefore should return no results.
count, err = countRevokedCerts(&sapb.GetRevokedCertsRequest{
serials = getRevokedCerts(&sapb.GetRevokedCertsRequest{
IssuerNameID: basicRequest.IssuerNameID,
ShardIdx: basicRequest.ShardIdx,
ExpiresAfter: basicRequest.ExpiresAfter,
ExpiresBefore: basicRequest.ExpiresBefore,
RevokedBefore: mustTimestamp("2020-03-01 00:00"),
})
test.AssertNotError(t, err, "zero rows shouldn't result in error")
test.AssertEquals(t, count, 0)
expectSerials("GetRevokedCerts for old RevokedBefore", serials)
}

func TestGetMaxExpiration(t *testing.T) {
Expand Down
48 changes: 42 additions & 6 deletions sa/saro.go
Original file line number Diff line number Diff line change
Expand Up @@ -1050,18 +1050,54 @@ func (ssa *SQLStorageAuthorityRO) SerialsForIncident(req *sapb.SerialsForInciden
})
}

// GetRevokedCerts gets a request specifying an issuer and a period of time,
// and writes to the output stream the set of all certificates issued by that
// issuer which expire during that period of time and which have been revoked.
// crlDeduper implements grpc.ServerStreamingServer[corepb.CRLEntry].
//
// It passes CRLEntry's to the inner ServerStreamingServer, with the
// exception that it omits any CRLEntry with the same serial as a previously
// sent one.
type crlDeduper struct {
grpc.ServerStreamingServer[corepb.CRLEntry]

seen map[string]bool
}

func (cd crlDeduper) Send(crl *corepb.CRLEntry) error {
if !cd.seen[crl.Serial] {
cd.seen[crl.Serial] = true
return cd.ServerStreamingServer.Send(crl)
}
return nil
}

// GetRevokedCerts returns a stream of revoked certificates for a single CRL shard.
//
// If ShardIdx is zero, GetRevokedCerts calculates shard membership based
// solely on temporal sharding.
//
// If ShardIdx is nonzero, GetRevokedCerts calculates shard membership based
// on temporal sharding _and_ explicit sharding (that is, sharding based on
// the shardIdx field of the revokedCertificates table). Most revoked certificates
// will be present in two shards: one based on explicit sharding and one based
// on temporal sharding (a few will have the same shard for both).
//
// The starting timestamp is treated as inclusive (certs with exactly that
// notAfter date are included), but the ending timestamp is exclusive (certs
// with exactly that notAfter date are *not* included).
func (ssa *SQLStorageAuthorityRO) GetRevokedCerts(req *sapb.GetRevokedCertsRequest, stream grpc.ServerStreamingServer[corepb.CRLEntry]) error {
if core.IsAnyNilOrZero(req.IssuerNameID, req.ExpiresAfter, req.ExpiresBefore, req.RevokedBefore) {
return errors.New("incomplete request for GetRevokedCerts")
}
crlDeduper := crlDeduper{
ServerStreamingServer: stream,
seen: make(map[string]bool),
}
if req.ShardIdx != 0 {
return ssa.getRevokedCertsFromRevokedCertificatesTable(req, stream)
} else {
return ssa.getRevokedCertsFromCertificateStatusTable(req, stream)
err := ssa.getRevokedCertsFromRevokedCertificatesTable(req, crlDeduper)
if err != nil {
return err
}
}
return ssa.getRevokedCertsFromCertificateStatusTable(req, crlDeduper)
}

// getRevokedCertsFromRevokedCertificatesTable uses the new revokedCertificates
Expand Down
6 changes: 5 additions & 1 deletion test/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,18 @@ func LoadSigner(filename string) (crypto.Signer, error) {
// ThrowAwayCert is a small test helper function that creates a self-signed
// certificate with one SAN. It returns the parsed certificate and its serial
// in string form for convenience.
//
// The certificate returned from this function is the bare minimum needed for
// most tests and isn't a robust example of a complete end entity certificate.
//
// Returned certificates have NotBefore == clk.Now(), and NotBefore 6 days in the
// future.
func ThrowAwayCert(t *testing.T, clk clock.Clock) (string, *x509.Certificate) {
var nameBytes [3]byte
_, _ = rand.Read(nameBytes[:])
name := fmt.Sprintf("%s.example.com", hex.EncodeToString(nameBytes[:]))

var serialBytes [16]byte
var serialBytes [18]byte
_, _ = rand.Read(serialBytes[:])
serial := big.NewInt(0).SetBytes(serialBytes[:])

Expand Down
Loading