Skip to content

Commit

Permalink
More test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jsha committed Jan 9, 2025
1 parent 4708117 commit 824e6e2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 28 deletions.
41 changes: 13 additions & 28 deletions sa/sa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3391,6 +3391,13 @@ func TestGetRevokedCertsWithShard(t *testing.T) {
if err != nil {
t.Fatalf("adding cert: %s", err)
}
status, err := sa.GetCertificateStatus(ctx, &sapb.Serial{Serial: core.SerialToString(cert.SerialNumber)})
if err != nil {
t.Fatalf("GetCertificateStatus: %s", err)
}
if status.Status != string(core.OCSPStatusGood) {
t.Fatalf("GetCertificateStatus for new cert: got %s, want %s", status.Status, core.OCSPStatusGood)
}
return cert
}

Expand All @@ -3407,18 +3414,6 @@ func TestGetRevokedCertsWithShard(t *testing.T) {
fc.Add(2 * 24 * time.Hour)
eeCert3 := makeCert()

// Check that it worked.
for _, c := range []*x509.Certificate{eeCert1, eeCert2, eeCert3} {
status, err := sa.GetCertificateStatus(
ctx, &sapb.Serial{Serial: core.SerialToString(c.SerialNumber)})
if err != nil {
t.Fatalf("GetCertificateStatus: %s", err)
}
if status.Status != string(core.OCSPStatusGood) {
t.Fatalf("GetCertificateStatus for new cert: got %s, want %s", status.Status, core.OCSPStatusGood)
}
}

// Here's a little helper func we'll use to call GetRevokedCerts and return
// a sorted list of serials.
getRevokedCerts := func(req *sapb.GetRevokedCertsRequest) []string {
Expand All @@ -3436,11 +3431,11 @@ func TestGetRevokedCertsWithShard(t *testing.T) {
if err != nil {
t.Fatalf("GetRevokedCerts(%+v): %s", req, err)
}
sort.Strings(serials)
return serials
}

// The basic request covers a time range and shard that should include both certificates.
// The basic request covers a time range that includes eeCert1's and eeCert2's NotAfter,
// but excludes eeCert3's NotAfter.
// The ExpiresBefore field is set based on the 6-day lifetime of certs from test.ThrowAwayCert
basicRequest := &sapb.GetRevokedCertsRequest{
IssuerNameID: issuerNameID,
Expand All @@ -3450,7 +3445,7 @@ func TestGetRevokedCertsWithShard(t *testing.T) {
RevokedBefore: mustTimestamp("2023-04-01 00:00"),
}

// Nothing's been revoked yet. Count should be zero.
// Nothing's been revoked yet. Should get no results.
serials := getRevokedCerts(basicRequest)
if len(serials) > 0 {
t.Errorf("GetRevokedCerts (before revocations) = %s, want []", serials)
Expand Down Expand Up @@ -3479,26 +3474,16 @@ func TestGetRevokedCertsWithShard(t *testing.T) {
// But note that the temporal shard is different from the other two.
revoke(eeCert3, 97)

// Check that it worked in the most basic way.
query := "SELECT count(*) FROM revokedCertificates where shardIdx = 97;"
c, err := sa.dbMap.SelectNullInt(ctx, query)
if err != nil {
t.Fatalf("query %q: %s", query, err)
}
if !c.Valid {
t.Fatalf("query %q: no results", query)
}
if c.Int64 != 2 {
t.Fatalf("query %q: got %d results, want %d", query, c.Int64, 2)
}

// expectSerials registers an error if the provided serials don't match the serials
// of the provded certs (after sorting).
expectSerials := func(message string, serials []string, certs ...*x509.Certificate) {
t.Helper()
var expectedSerials []string
for _, c := range certs {
expectedSerials = append(expectedSerials, core.SerialToString(c.SerialNumber))
}
sort.Strings(expectedSerials)
sort.Strings(serials)
if !reflect.DeepEqual(serials, expectedSerials) {
t.Errorf("%s: want %s, got %s", message, expectedSerials, serials)
}
Expand Down
3 changes: 3 additions & 0 deletions sa/saro.go
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,9 @@ func (cd crlDeduper) Send(crl *corepb.CRLEntry) error {
// notAfter date are included), but the ending timestamp is exclusive (certs
// with exactly that notAfter date are *not* included).
func (ssa *SQLStorageAuthorityRO) GetRevokedCerts(req *sapb.GetRevokedCertsRequest, stream grpc.ServerStreamingServer[corepb.CRLEntry]) error {
if core.IsAnyNilOrZero(req.IssuerNameID, req.ExpiresAfter, req.ExpiresBefore, req.RevokedBefore) {
return errors.New("incomplete request for GetRevokedCerts")
}
crlDeduper := crlDeduper{
ServerStreamingServer: stream,
seen: make(map[string]bool),
Expand Down

0 comments on commit 824e6e2

Please sign in to comment.