diff --git a/sa/sa_test.go b/sa/sa_test.go index 45afe52f711..65bbd376740 100644 --- a/sa/sa_test.go +++ b/sa/sa_test.go @@ -3391,6 +3391,13 @@ func TestGetRevokedCertsWithShard(t *testing.T) { if err != nil { t.Fatalf("adding cert: %s", err) } + status, err := sa.GetCertificateStatus(ctx, &sapb.Serial{Serial: core.SerialToString(cert.SerialNumber)}) + if err != nil { + t.Fatalf("GetCertificateStatus: %s", err) + } + if status.Status != string(core.OCSPStatusGood) { + t.Fatalf("GetCertificateStatus for new cert: got %s, want %s", status.Status, core.OCSPStatusGood) + } return cert } @@ -3407,18 +3414,6 @@ func TestGetRevokedCertsWithShard(t *testing.T) { fc.Add(2 * 24 * time.Hour) eeCert3 := makeCert() - // Check that it worked. - for _, c := range []*x509.Certificate{eeCert1, eeCert2, eeCert3} { - status, err := sa.GetCertificateStatus( - ctx, &sapb.Serial{Serial: core.SerialToString(c.SerialNumber)}) - if err != nil { - t.Fatalf("GetCertificateStatus: %s", err) - } - if status.Status != string(core.OCSPStatusGood) { - t.Fatalf("GetCertificateStatus for new cert: got %s, want %s", status.Status, core.OCSPStatusGood) - } - } - // Here's a little helper func we'll use to call GetRevokedCerts and return // a sorted list of serials. getRevokedCerts := func(req *sapb.GetRevokedCertsRequest) []string { @@ -3436,11 +3431,11 @@ func TestGetRevokedCertsWithShard(t *testing.T) { if err != nil { t.Fatalf("GetRevokedCerts(%+v): %s", req, err) } - sort.Strings(serials) return serials } - // The basic request covers a time range and shard that should include both certificates. + // The basic request covers a time range that includes eeCert1's and eeCert2's NotAfter, + // but excludes eeCert3's NotAfter. // The ExpiresBefore field is set based on the 6-day lifetime of certs from test.ThrowAwayCert basicRequest := &sapb.GetRevokedCertsRequest{ IssuerNameID: issuerNameID, @@ -3450,7 +3445,7 @@ func TestGetRevokedCertsWithShard(t *testing.T) { RevokedBefore: mustTimestamp("2023-04-01 00:00"), } - // Nothing's been revoked yet. Count should be zero. + // Nothing's been revoked yet. Should get no results. serials := getRevokedCerts(basicRequest) if len(serials) > 0 { t.Errorf("GetRevokedCerts (before revocations) = %s, want []", serials) @@ -3479,19 +3474,8 @@ func TestGetRevokedCertsWithShard(t *testing.T) { // But note that the temporal shard is different from the other two. revoke(eeCert3, 97) - // Check that it worked in the most basic way. - query := "SELECT count(*) FROM revokedCertificates where shardIdx = 97;" - c, err := sa.dbMap.SelectNullInt(ctx, query) - if err != nil { - t.Fatalf("query %q: %s", query, err) - } - if !c.Valid { - t.Fatalf("query %q: no results", query) - } - if c.Int64 != 2 { - t.Fatalf("query %q: got %d results, want %d", query, c.Int64, 2) - } - + // expectSerials registers an error if the provided serials don't match the serials + // of the provded certs (after sorting). expectSerials := func(message string, serials []string, certs ...*x509.Certificate) { t.Helper() var expectedSerials []string @@ -3499,6 +3483,7 @@ func TestGetRevokedCertsWithShard(t *testing.T) { expectedSerials = append(expectedSerials, core.SerialToString(c.SerialNumber)) } sort.Strings(expectedSerials) + sort.Strings(serials) if !reflect.DeepEqual(serials, expectedSerials) { t.Errorf("%s: want %s, got %s", message, expectedSerials, serials) } diff --git a/sa/saro.go b/sa/saro.go index ad2f83fd8cb..7b53888d278 100644 --- a/sa/saro.go +++ b/sa/saro.go @@ -1084,6 +1084,9 @@ func (cd crlDeduper) Send(crl *corepb.CRLEntry) error { // notAfter date are included), but the ending timestamp is exclusive (certs // with exactly that notAfter date are *not* included). func (ssa *SQLStorageAuthorityRO) GetRevokedCerts(req *sapb.GetRevokedCertsRequest, stream grpc.ServerStreamingServer[corepb.CRLEntry]) error { + if core.IsAnyNilOrZero(req.IssuerNameID, req.ExpiresAfter, req.ExpiresBefore, req.RevokedBefore) { + return errors.New("incomplete request for GetRevokedCerts") + } crlDeduper := crlDeduper{ ServerStreamingServer: stream, seen: make(map[string]bool),