diff --git a/cert/ca_pool.go b/cert/ca_pool.go index 8d71b85b6..e61a38daf 100644 --- a/cert/ca_pool.go +++ b/cert/ca_pool.go @@ -3,6 +3,8 @@ package cert import ( "errors" "fmt" + "net/netip" + "slices" "strings" "time" ) @@ -12,7 +14,7 @@ type CAPool struct { certBlocklist map[string]struct{} } -// NewCAPool creates a CAPool +// NewCAPool creates an empty CAPool func NewCAPool() *CAPool { ca := CAPool{ CAs: make(map[string]*CachedCertificate), @@ -51,7 +53,7 @@ func NewCAPoolFromPEM(caPEMs []byte) (*CAPool, error) { return pool, nil } -// AddCAFromPEM verifies a Nebula CA certificate and adds it to the pool +// AddCAFromPEM verifies a Nebula CA certificate and adds it to the pool. // Only the first pem encoded object will be consumed, any remaining bytes are returned. // Parsed certificates will be verified and must be a CA func (ncp *CAPool) AddCAFromPEM(pemBytes []byte) ([]byte, error) { @@ -68,7 +70,7 @@ func (ncp *CAPool) AddCAFromPEM(pemBytes []byte) ([]byte, error) { return pemBytes, nil } -// TODO: +// AddCA verifies a Nebula CA certificate and adds it to the pool. func (ncp *CAPool) AddCA(c Certificate) error { if !c.IsCA() { return fmt.Errorf("%s: %w", c.Name(), ErrNotCA) @@ -78,7 +80,7 @@ func (ncp *CAPool) AddCA(c Certificate) error { return fmt.Errorf("%s: %w", c.Name(), ErrNotSelfSigned) } - sum, err := c.Sha256Sum() + sum, err := c.Fingerprint() if err != nil { return fmt.Errorf("could not calculate shasum for provided CA; error: %s; %s", err, c.Name()) } @@ -112,9 +114,10 @@ func (ncp *CAPool) ResetCertBlocklist() { ncp.certBlocklist = make(map[string]struct{}) } -// TODO: -func (ncp *CAPool) IsBlocklisted(sha string) bool { - if _, ok := ncp.certBlocklist[sha]; ok { +// IsBlocklisted tests the provided fingerprint against the pools blocklist. +// Returns true if the fingerprint is blocked. +func (ncp *CAPool) IsBlocklisted(fingerprint string) bool { + if _, ok := ncp.certBlocklist[fingerprint]; ok { return true } @@ -125,7 +128,7 @@ func (ncp *CAPool) IsBlocklisted(sha string) bool { // If the certificate is valid then the returned CachedCertificate can be used in subsequent verification attempts // to increase performance. func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCertificate, error) { - sha, err := c.Sha256Sum() + sha, err := c.Fingerprint() if err != nil { return nil, fmt.Errorf("could not calculate shasum to verify: %w", err) } @@ -141,7 +144,7 @@ func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCerti ShaSum: sha, signerShaSum: signer.ShaSum, } - + for _, g := range c.Groups() { cc.InvertedGroups[g] = struct{}{} } @@ -149,6 +152,8 @@ func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCerti return &cc, nil } +// VerifyCachedCertificate is the same as VerifyCertificate other than it operates on a pre-verified structure and +// is a cheaper operation to perform as a result. func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error { _, err := ncp.verify(c.Certificate, now, c.ShaSum, c.signerShaSum) return err @@ -174,6 +179,9 @@ func (ncp *CAPool) verify(c Certificate, now time.Time, certSha string, signerSh // If we are checking a cached certificate then we can bail early here // Either the root is no longer trusted or everything is fine + //TODO: this is slightly different than v1.9.3 and earlier where we were matching the public key + // The reason to switch is that the public key can be reused and the constraints in the ca can change + // but there may be history here, double check if len(signerSha) > 0 { if signerSha != signer.ShaSum { return nil, ErrSignatureMismatch @@ -184,7 +192,7 @@ func (ncp *CAPool) verify(c Certificate, now time.Time, certSha string, signerSh return nil, ErrSignatureMismatch } - err = c.CheckRootConstraints(signer.Certificate) + err = CheckCAConstraints(signer.Certificate, c) if err != nil { return nil, err } @@ -219,3 +227,70 @@ func (ncp *CAPool) GetFingerprints() []string { return fp } + +// CheckCAConstraints returns an error if the sub certificate violates constraints present in the signer certificate. +func CheckCAConstraints(signer Certificate, sub Certificate) error { + return checkCAConstraints(signer, sub.NotAfter(), sub.NotBefore(), sub.Groups(), sub.Networks(), sub.UnsafeNetworks()) +} + +// checkCAConstraints is a very generic function allowing both Certificates and TBSCertificates to be tested. +func checkCAConstraints(signer Certificate, notBefore, notAfter time.Time, groups []string, networks, unsafeNetworks []netip.Prefix) error { + // Make sure this cert wasn't valid before the root + if signer.NotAfter().Before(notAfter) { + return fmt.Errorf("certificate expires after signing certificate") + } + + // Make sure this cert isn't valid after the root + if signer.NotBefore().After(notBefore) { + return fmt.Errorf("certificate is valid before the signing certificate") + } + + // If the signer has a limited set of groups make sure the cert only contains a subset + signerGroups := signer.Groups() + if len(signerGroups) > 0 { + for _, g := range groups { + if !slices.Contains(signerGroups, g) { + //TODO: since we no longer pre-compute the inverted groups then this is kind of slow + return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g) + } + } + } + + // If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset + signingNetworks := signer.Networks() + if len(signingNetworks) > 0 { + for _, subNetwork := range networks { + found := false + for _, signingNetwork := range signingNetworks { + if signingNetwork.Contains(subNetwork.Addr()) && signingNetwork.Bits() <= subNetwork.Bits() { + found = true + break + } + } + + if !found { + return fmt.Errorf("certificate contained an ip assignment outside the limitations of the signing ca: %s", subNetwork.String()) + } + } + } + + // If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset + signingUnsafeNetworks := signer.UnsafeNetworks() + if len(signingUnsafeNetworks) > 0 { + for _, subUnsafeNetwork := range unsafeNetworks { + found := false + for _, caNetwork := range signingUnsafeNetworks { + if caNetwork.Contains(subUnsafeNetwork.Addr()) && caNetwork.Bits() <= subUnsafeNetwork.Bits() { + found = true + break + } + } + + if !found { + return fmt.Errorf("certificate contained a subnet assignment outside the limitations of the signing ca: %s", subUnsafeNetwork.String()) + } + } + } + + return nil +} diff --git a/cert/cert.go b/cert/cert.go index 3a312b677..4e41c4349 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -13,7 +13,9 @@ const ( ) type Certificate interface { - //TODO: describe this + // Version defines the underlying certificate structure and wire protocol version + // Version1 certificates are ipv4 only and uses protobuf serialization + // Version2 certificates are ipv4 or ipv6 and uses asn.1 serialization Version() Version // Name is the human-readable name that identifies this certificate. @@ -65,20 +67,14 @@ type Certificate interface { // computed signature. A true result means this certificate has not been tampered with. CheckSignature(signingPublicKey []byte) bool - // Sha256Sum returns the hex encoded sha256 sum of the certificate. + // Fingerprint returns the hex encoded sha256 sum of the certificate. // This acts as a unique fingerprint and can be used to blocklist certificates. - Sha256Sum() (string, error) + Fingerprint() (string, error) // Expired tests if the certificate is valid for the provided time. Expired(t time.Time) bool - // CheckRootConstraints tests if the certificate meets all constraints in the - // signing certificate, returning the first violated constraint or nil if the - // certificate conforms to all constraints. - //TODO: feels better to have this on the CAPool I think - CheckRootConstraints(signer Certificate) error - - //TODO + // VerifyPrivateKey returns an error if the private key is not a pair with the certificates public key. VerifyPrivateKey(curve Curve, privateKey []byte) error // Marshal will return the byte representation of this certificate @@ -88,10 +84,9 @@ type Certificate interface { // MarshalForHandshakes prepares the bytes needed to use directly in a handshake MarshalForHandshakes() ([]byte, error) - // MarshalToPEM will return a PEM encoded representation of this certificate + // MarshalPEM will return a PEM encoded representation of this certificate // This is primarily the format stored on disk - //TODO: MarshalPEM? - MarshalToPEM() ([]byte, error) + MarshalPEM() ([]byte, error) // MarshalJSON will return the json representation of this certificate MarshalJSON() ([]byte, error) @@ -99,7 +94,7 @@ type Certificate interface { // String will return a human-readable representation of this certificate String() string - //TODO + // Copy creates a copy of the certificate Copy() Certificate } @@ -112,7 +107,7 @@ type CachedCertificate struct { signerShaSum string } -// TODO: +// UnmarshalCertificate will attempt to unmarshal a wire protocol level certificate. func UnmarshalCertificate(b []byte) (Certificate, error) { c, err := unmarshalCertificateV1(b, true) if err != nil { @@ -121,7 +116,9 @@ func UnmarshalCertificate(b []byte) (Certificate, error) { return c, nil } -// TODO: +// UnmarshalCertificateFromHandshake will attempt to unmarshal a certificate received in a handshake. +// Handshakes save space by placing the peers public key in a different part of the packet, we have to +// reassemble the actual certificate structure with that in mind. func UnmarshalCertificateFromHandshake(b []byte, publicKey []byte) (Certificate, error) { c, err := unmarshalCertificateV1(b, false) if err != nil { diff --git a/cert/cert_test.go b/cert/cert_test.go index 462d6e8c4..8e90aabab 100644 --- a/cert/cert_test.go +++ b/cert/cert_test.go @@ -212,7 +212,7 @@ func TestNebulaCertificate_Verify(t *testing.T) { caPool := NewCAPool() assert.NoError(t, caPool.AddCA(ca)) - f, err := c.Sha256Sum() + f, err := c.Fingerprint() assert.Nil(t, err) caPool.BlocklistFingerprint(f) @@ -235,7 +235,7 @@ func TestNebulaCertificate_Verify(t *testing.T) { ca, _, caKey, err = newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) assert.Nil(t, err) - caPem, err := ca.MarshalToPEM() + caPem, err := ca.MarshalPEM() assert.Nil(t, err) caPool = NewCAPool() @@ -264,7 +264,7 @@ func TestNebulaCertificate_VerifyP256(t *testing.T) { caPool := NewCAPool() assert.NoError(t, caPool.AddCA(ca)) - f, err := c.Sha256Sum() + f, err := c.Fingerprint() assert.Nil(t, err) caPool.BlocklistFingerprint(f) @@ -287,7 +287,7 @@ func TestNebulaCertificate_VerifyP256(t *testing.T) { ca, _, caKey, err = newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) assert.Nil(t, err) - caPem, err := ca.MarshalToPEM() + caPem, err := ca.MarshalPEM() assert.Nil(t, err) caPool = NewCAPool() @@ -312,7 +312,7 @@ func TestNebulaCertificate_Verify_IPs(t *testing.T) { ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) assert.Nil(t, err) - caPem, err := ca.MarshalToPEM() + caPem, err := ca.MarshalPEM() assert.Nil(t, err) caPool := NewCAPool() @@ -385,7 +385,7 @@ func TestNebulaCertificate_Verify_Subnets(t *testing.T) { ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) assert.Nil(t, err) - caPem, err := ca.MarshalToPEM() + caPem, err := ca.MarshalPEM() assert.Nil(t, err) caPool := NewCAPool() @@ -643,7 +643,7 @@ func newTestCaCertP256(before, after time.Time, ips, subnets []netip.Prefix, gro } func newTestCert(ca *certificateV1, key []byte, before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*certificateV1, []byte, []byte, error) { - issuer, err := ca.Sha256Sum() + issuer, err := ca.Fingerprint() if err != nil { return nil, nil, nil, err } diff --git a/cert/cert_v1.go b/cert/cert_v1.go index 3ae078e40..4740e726c 100644 --- a/cert/cert_v1.go +++ b/cert/cert_v1.go @@ -16,7 +16,6 @@ import ( "math/big" "net" "net/netip" - "slices" "time" "golang.org/x/crypto/curve25519" @@ -94,139 +93,16 @@ func (nc *certificateV1) UnsafeNetworks() []netip.Prefix { return nc.details.Subnets } -func (nc *certificateV1) MarshalForHandshakes() ([]byte, error) { - pubKey := nc.details.PublicKey - nc.details.PublicKey = nil - rawCertNoKey, err := nc.Marshal() - if err != nil { - return nil, err - } - nc.details.PublicKey = pubKey - return rawCertNoKey, nil -} - -// unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert -func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, error) { - if len(b) == 0 { - return nil, fmt.Errorf("nil byte array") - } - var rc RawNebulaCertificate - err := proto.Unmarshal(b, &rc) - if err != nil { - return nil, err - } - - if rc.Details == nil { - return nil, fmt.Errorf("encoded Details was nil") - } - - if len(rc.Details.Ips)%2 != 0 { - return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found") - } - - if len(rc.Details.Subnets)%2 != 0 { - return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found") - } - - nc := certificateV1{ - details: detailsV1{ - Name: rc.Details.Name, - Groups: make([]string, len(rc.Details.Groups)), - Ips: make([]netip.Prefix, len(rc.Details.Ips)/2), - Subnets: make([]netip.Prefix, len(rc.Details.Subnets)/2), - NotBefore: time.Unix(rc.Details.NotBefore, 0), - NotAfter: time.Unix(rc.Details.NotAfter, 0), - PublicKey: make([]byte, len(rc.Details.PublicKey)), - IsCA: rc.Details.IsCA, - Curve: rc.Details.Curve, - }, - signature: make([]byte, len(rc.Signature)), - } - - copy(nc.signature, rc.Signature) - copy(nc.details.Groups, rc.Details.Groups) - nc.details.Issuer = hex.EncodeToString(rc.Details.Issuer) - - if len(rc.Details.PublicKey) < publicKeyLen && assertPublicKey { - return nil, fmt.Errorf("Public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey)) - } - copy(nc.details.PublicKey, rc.Details.PublicKey) - - var ip netip.Addr - for i, rawIp := range rc.Details.Ips { - if i%2 == 0 { - ip = int2addr(rawIp) - } else { - ones, _ := net.IPMask(int2ip(rawIp)).Size() - nc.details.Ips[i/2] = netip.PrefixFrom(ip, ones) - } - } - - for i, rawIp := range rc.Details.Subnets { - if i%2 == 0 { - ip = int2addr(rawIp) - } else { - ones, _ := net.IPMask(int2ip(rawIp)).Size() - nc.details.Subnets[i/2] = netip.PrefixFrom(ip, ones) - } - } - - return &nc, nil -} - -func signV1(t *TBSCertificate, curve Curve, key []byte) (*certificateV1, error) { - c := &certificateV1{ - details: detailsV1{ - Name: t.Name, - Ips: t.Networks, - Subnets: t.UnsafeNetworks, - Groups: t.Groups, - NotBefore: t.NotBefore, - NotAfter: t.NotAfter, - PublicKey: t.PublicKey, - IsCA: t.IsCA, - Curve: t.Curve, - Issuer: t.issuer, - }, - } - b, err := proto.Marshal(c.getRawDetails()) +func (nc *certificateV1) Fingerprint() (string, error) { + b, err := nc.Marshal() if err != nil { - return nil, err - } - - var sig []byte - - switch curve { - case Curve_CURVE25519: - signer := ed25519.PrivateKey(key) - sig = ed25519.Sign(signer, b) - case Curve_P256: - signer := &ecdsa.PrivateKey{ - PublicKey: ecdsa.PublicKey{ - Curve: elliptic.P256(), - }, - // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95 - D: new(big.Int).SetBytes(key), - } - // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119 - signer.X, signer.Y = signer.Curve.ScalarBaseMult(key) - - // We need to hash first for ECDSA - // - https://pkg.go.dev/crypto/ecdsa#SignASN1 - hashed := sha256.Sum256(b) - sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:]) - if err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("invalid curve: %s", c.details.Curve) + return "", err } - c.signature = sig - return c, nil + sum := sha256.Sum256(b) + return hex.EncodeToString(sum[:]), nil } -// CheckSignature verifies the signature against the provided public key func (nc *certificateV1) CheckSignature(key []byte) bool { b, err := proto.Marshal(nc.getRawDetails()) if err != nil { @@ -245,75 +121,10 @@ func (nc *certificateV1) CheckSignature(key []byte) bool { } } -// Expired will return true if the nebula cert is too young or too old compared to the provided time, otherwise false func (nc *certificateV1) Expired(t time.Time) bool { return nc.details.NotBefore.After(t) || nc.details.NotAfter.Before(t) } -// CheckRootConstraints returns an error if the certificate violates constraints set on the root (groups, ips, subnets) -// TODO: we could use cachedcert here to make it better, maybe who cares cuz its only once? Or move this entirely to caPool -func (nc *certificateV1) CheckRootConstraints(signer Certificate) error { - // Make sure this cert wasn't valid before the root - if signer.NotAfter().Before(nc.details.NotAfter) { - return fmt.Errorf("certificate expires after signing certificate") - } - - // Make sure this cert isn't valid after the root - if signer.NotBefore().After(nc.details.NotBefore) { - return fmt.Errorf("certificate is valid before the signing certificate") - } - - // If the signer has a limited set of groups make sure the cert only contains a subset - groups := signer.Groups() - if len(groups) > 0 { - for _, g := range nc.details.Groups { - if !slices.Contains(groups, g) { - //TODO: since we no longer pre-compute the inverted groups then this is kind of slow - return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g) - } - } - } - - // If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset - networks := signer.Networks() - if len(networks) > 0 { - for _, cNetwork := range nc.details.Ips { - found := false - for _, caNetwork := range networks { - if caNetwork.Contains(cNetwork.Addr()) && caNetwork.Bits() <= cNetwork.Bits() { - found = true - break - } - } - - if !found { - return fmt.Errorf("certificate contained an ip assignment outside the limitations of the signing ca: %s", cNetwork.String()) - } - } - } - - // If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset - unsafeNetworks := signer.UnsafeNetworks() - if len(unsafeNetworks) > 0 { - for _, cNetwork := range nc.details.Subnets { - found := false - for _, caNetwork := range unsafeNetworks { - if caNetwork.Contains(cNetwork.Addr()) && caNetwork.Bits() <= cNetwork.Bits() { - found = true - break - } - } - - if !found { - return fmt.Errorf("certificate contained a subnet assignment outside the limitations of the signing ca: %s", cNetwork.String()) - } - } - } - - return nil -} - -// VerifyPrivateKey checks that the public key in the Nebula certificate and a supplied private key match func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { if curve != nc.details.Curve { return fmt.Errorf("curve in cert and private key supplied don't match") @@ -368,7 +179,36 @@ func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { return nil } -// String will return a pretty printed representation of a nebula cert +// getRawDetails marshals the raw details into protobuf ready struct +func (nc *certificateV1) getRawDetails() *RawNebulaCertificateDetails { + rd := &RawNebulaCertificateDetails{ + Name: nc.details.Name, + Groups: nc.details.Groups, + NotBefore: nc.details.NotBefore.Unix(), + NotAfter: nc.details.NotAfter.Unix(), + PublicKey: make([]byte, len(nc.details.PublicKey)), + IsCA: nc.details.IsCA, + Curve: nc.details.Curve, + } + + for _, ipNet := range nc.details.Ips { + mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) + rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask)) + } + + for _, ipNet := range nc.details.Subnets { + mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) + rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask)) + } + + copy(rd.PublicKey, nc.details.PublicKey[:]) + + // I know, this is terrible + rd.Issuer, _ = hex.DecodeString(nc.details.Issuer) + + return rd +} + func (nc *certificateV1) String() string { if nc == nil { return "Certificate {}\n" @@ -415,7 +255,7 @@ func (nc *certificateV1) String() string { s += fmt.Sprintf("\t\tPublic key: %x\n", nc.details.PublicKey) s += fmt.Sprintf("\t\tCurve: %s\n", nc.details.Curve) s += "\t}\n" - fp, err := nc.Sha256Sum() + fp, err := nc.Fingerprint() if err == nil { s += fmt.Sprintf("\tFingerprint: %s\n", fp) } @@ -425,37 +265,17 @@ func (nc *certificateV1) String() string { return s } -// getRawDetails marshals the raw details into protobuf ready struct -func (nc *certificateV1) getRawDetails() *RawNebulaCertificateDetails { - rd := &RawNebulaCertificateDetails{ - Name: nc.details.Name, - Groups: nc.details.Groups, - NotBefore: nc.details.NotBefore.Unix(), - NotAfter: nc.details.NotAfter.Unix(), - PublicKey: make([]byte, len(nc.details.PublicKey)), - IsCA: nc.details.IsCA, - Curve: nc.details.Curve, - } - - for _, ipNet := range nc.details.Ips { - mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) - rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask)) - } - - for _, ipNet := range nc.details.Subnets { - mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) - rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask)) +func (nc *certificateV1) MarshalForHandshakes() ([]byte, error) { + pubKey := nc.details.PublicKey + nc.details.PublicKey = nil + rawCertNoKey, err := nc.Marshal() + if err != nil { + return nil, err } - - copy(rd.PublicKey, nc.details.PublicKey[:]) - - // I know, this is terrible - rd.Issuer, _ = hex.DecodeString(nc.details.Issuer) - - return rd + nc.details.PublicKey = pubKey + return rawCertNoKey, nil } -// Marshal will marshal a nebula cert into a protobuf byte array func (nc *certificateV1) Marshal() ([]byte, error) { rc := RawNebulaCertificate{ Details: nc.getRawDetails(), @@ -465,8 +285,7 @@ func (nc *certificateV1) Marshal() ([]byte, error) { return proto.Marshal(&rc) } -// MarshalToPEM will marshal a nebula cert into a protobuf byte array and pem encode the result -func (nc *certificateV1) MarshalToPEM() ([]byte, error) { +func (nc *certificateV1) MarshalPEM() ([]byte, error) { b, err := nc.Marshal() if err != nil { return nil, err @@ -474,19 +293,8 @@ func (nc *certificateV1) MarshalToPEM() ([]byte, error) { return pem.EncodeToMemory(&pem.Block{Type: CertificateBanner, Bytes: b}), nil } -// Sha256Sum calculates a sha-256 sum of the marshaled certificate -func (nc *certificateV1) Sha256Sum() (string, error) { - b, err := nc.Marshal() - if err != nil { - return "", err - } - - sum := sha256.Sum256(b) - return hex.EncodeToString(sum[:]), nil -} - func (nc *certificateV1) MarshalJSON() ([]byte, error) { - fp, _ := nc.Sha256Sum() + fp, _ := nc.Fingerprint() jc := m{ "details": m{ "name": nc.details.Name, @@ -506,64 +314,157 @@ func (nc *certificateV1) MarshalJSON() ([]byte, error) { return json.Marshal(jc) } -// TODO: func (nc *certificateV1) Copy() Certificate { - // r, err := nc.Marshal() - // if err != nil { - // //TODO - // return nil - // } - // - // c, err := UnmarshalNebulaCertificate(r) - // return c - return nc + c := &certificateV1{ + details: detailsV1{ + Name: nc.details.Name, + Groups: make([]string, len(nc.details.Groups)), + Ips: make([]netip.Prefix, len(nc.details.Ips)), + Subnets: make([]netip.Prefix, len(nc.details.Subnets)), + NotBefore: nc.details.NotBefore, + NotAfter: nc.details.NotAfter, + PublicKey: make([]byte, len(nc.details.PublicKey)), + IsCA: nc.details.IsCA, + Issuer: nc.details.Issuer, + }, + signature: make([]byte, len(nc.signature)), + } + + copy(c.signature, nc.signature) + copy(c.details.Groups, nc.details.Groups) + copy(c.details.PublicKey, nc.details.PublicKey) + + for i, p := range nc.details.Ips { + c.details.Ips[i] = p + } + + for i, p := range nc.details.Subnets { + c.details.Subnets[i] = p + } + + return c +} + +// unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert +func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, error) { + if len(b) == 0 { + return nil, fmt.Errorf("nil byte array") + } + var rc RawNebulaCertificate + err := proto.Unmarshal(b, &rc) + if err != nil { + return nil, err + } + + if rc.Details == nil { + return nil, fmt.Errorf("encoded Details was nil") + } + + if len(rc.Details.Ips)%2 != 0 { + return nil, fmt.Errorf("encoded IPs should be in pairs, an odd number was found") + } + + if len(rc.Details.Subnets)%2 != 0 { + return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found") + } + + nc := certificateV1{ + details: detailsV1{ + Name: rc.Details.Name, + Groups: make([]string, len(rc.Details.Groups)), + Ips: make([]netip.Prefix, len(rc.Details.Ips)/2), + Subnets: make([]netip.Prefix, len(rc.Details.Subnets)/2), + NotBefore: time.Unix(rc.Details.NotBefore, 0), + NotAfter: time.Unix(rc.Details.NotAfter, 0), + PublicKey: make([]byte, len(rc.Details.PublicKey)), + IsCA: rc.Details.IsCA, + Curve: rc.Details.Curve, + }, + signature: make([]byte, len(rc.Signature)), + } + + copy(nc.signature, rc.Signature) + copy(nc.details.Groups, rc.Details.Groups) + nc.details.Issuer = hex.EncodeToString(rc.Details.Issuer) + + if len(rc.Details.PublicKey) < publicKeyLen && assertPublicKey { + return nil, fmt.Errorf("public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey)) + } + copy(nc.details.PublicKey, rc.Details.PublicKey) + + var ip netip.Addr + for i, rawIp := range rc.Details.Ips { + if i%2 == 0 { + ip = int2addr(rawIp) + } else { + ones, _ := net.IPMask(int2ip(rawIp)).Size() + nc.details.Ips[i/2] = netip.PrefixFrom(ip, ones) + } + } + + for i, rawIp := range rc.Details.Subnets { + if i%2 == 0 { + ip = int2addr(rawIp) + } else { + ones, _ := net.IPMask(int2ip(rawIp)).Size() + nc.details.Subnets[i/2] = netip.PrefixFrom(ip, ones) + } + } + + return &nc, nil } -//func (nc *certificateV1) Copy() *certificateV1 { -// c := &certificateV1{ -// Details: detailsV1{ -// Name: nc.Details.Name, -// Groups: make([]string, len(nc.Details.Groups)), -// Ips: make([]*net.IPNet, len(nc.Details.Ips)), -// Subnets: make([]*net.IPNet, len(nc.Details.Subnets)), -// NotBefore: nc.Details.NotBefore, -// NotAfter: nc.Details.NotAfter, -// PublicKey: make([]byte, len(nc.Details.PublicKey)), -// IsCA: nc.Details.IsCA, -// Issuer: nc.Details.Issuer, -// InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)), -// }, -// Signature: make([]byte, len(nc.Signature)), -// } -// -// copy(c.Signature, nc.Signature) -// copy(c.Details.Groups, nc.Details.Groups) -// copy(c.Details.PublicKey, nc.Details.PublicKey) -// -// for i, p := range nc.Details.Ips { -// c.Details.Ips[i] = &net.IPNet{ -// IP: make(net.IP, len(p.IP)), -// Mask: make(net.IPMask, len(p.Mask)), -// } -// copy(c.Details.Ips[i].IP, p.IP) -// copy(c.Details.Ips[i].Mask, p.Mask) -// } -// -// for i, p := range nc.Details.Subnets { -// c.Details.Subnets[i] = &net.IPNet{ -// IP: make(net.IP, len(p.IP)), -// Mask: make(net.IPMask, len(p.Mask)), -// } -// copy(c.Details.Subnets[i].IP, p.IP) -// copy(c.Details.Subnets[i].Mask, p.Mask) -// } -// -// for g := range nc.Details.InvertedGroups { -// c.Details.InvertedGroups[g] = struct{}{} -// } -// -// return c -//} +func signV1(t *TBSCertificate, curve Curve, key []byte) (*certificateV1, error) { + c := &certificateV1{ + details: detailsV1{ + Name: t.Name, + Ips: t.Networks, + Subnets: t.UnsafeNetworks, + Groups: t.Groups, + NotBefore: t.NotBefore, + NotAfter: t.NotAfter, + PublicKey: t.PublicKey, + IsCA: t.IsCA, + Curve: t.Curve, + Issuer: t.issuer, + }, + } + b, err := proto.Marshal(c.getRawDetails()) + if err != nil { + return nil, err + } + + var sig []byte + + switch curve { + case Curve_CURVE25519: + signer := ed25519.PrivateKey(key) + sig = ed25519.Sign(signer, b) + case Curve_P256: + signer := &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{ + Curve: elliptic.P256(), + }, + // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L95 + D: new(big.Int).SetBytes(key), + } + // ref: https://github.com/golang/go/blob/go1.19/src/crypto/x509/sec1.go#L119 + signer.X, signer.Y = signer.Curve.ScalarBaseMult(key) + + // We need to hash first for ECDSA + // - https://pkg.go.dev/crypto/ecdsa#SignASN1 + hashed := sha256.Sum256(b) + sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:]) + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("invalid curve: %s", c.details.Curve) + } + + c.signature = sig + return c, nil +} func ip2int(ip []byte) uint32 { if len(ip) == 16 { diff --git a/cert/sign.go b/cert/sign.go index 430c944bf..ff32fc3e3 100644 --- a/cert/sign.go +++ b/cert/sign.go @@ -6,6 +6,8 @@ import ( "time" ) +// TBSCertificate represents a certificate intended to be signed. +// It is invalid to use this structure as a Certificate. type TBSCertificate struct { Version Version Name string @@ -20,18 +22,25 @@ type TBSCertificate struct { issuer string } -// TODO: +// Sign will create a sealed certificate using details provided by the TBSCertificate as long as those +// details do not violate constraints of the signing certificate. +// If the TBSCertificate is a CA then signer must be nil. func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) { if curve != t.Curve { return nil, fmt.Errorf("curve in cert and private key supplied don't match") } - //TODO: signer should assert its constraints on the TBSCertificate, once you do nebula-cert sign needs to not double do it if signer != nil { if t.IsCA { return nil, fmt.Errorf("can not sign a CA certificate with another") } - issuer, err := signer.Sha256Sum() + + err := checkCAConstraints(signer, t.NotAfter, t.NotBefore, t.Groups, t.Networks, t.UnsafeNetworks) + if err != nil { + return nil, err + } + + issuer, err := signer.Fingerprint() if err != nil { return nil, fmt.Errorf("error computing issuer: %v", err) } diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go index 3f3785ab6..c0a856b63 100644 --- a/cmd/nebula-cert/ca.go +++ b/cmd/nebula-cert/ca.go @@ -230,7 +230,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error return fmt.Errorf("error while writing out-key: %s", err) } - b, err = c.MarshalToPEM() + b, err = c.MarshalPEM() if err != nil { return fmt.Errorf("error while marshalling certificate: %s", err) } diff --git a/cmd/nebula-cert/print.go b/cmd/nebula-cert/print.go index 77edc01a8..a62c22338 100644 --- a/cmd/nebula-cert/print.go +++ b/cmd/nebula-cert/print.go @@ -66,7 +66,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { } if *pf.outQRPath != "" { - b, err := c.MarshalToPEM() + b, err := c.MarshalPEM() if err != nil { return fmt.Errorf("error while marshalling cert to PEM: %s", err) } diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index 2c2047a3c..e50f0d6e2 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -234,7 +234,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } } - b, err := c.MarshalToPEM() + b, err := c.MarshalPEM() if err != nil { return fmt.Errorf("error while marshalling certificate: %s", err) } diff --git a/connection_manager_test.go b/connection_manager_test.go index d377722b4..c1398c4f9 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -388,11 +388,11 @@ func (d *dummyCert) Marshal() ([]byte, error) { return nil, nil } -func (d *dummyCert) MarshalToPEM() ([]byte, error) { +func (d *dummyCert) MarshalPEM() ([]byte, error) { return nil, nil } -func (d *dummyCert) Sha256Sum() (string, error) { +func (d *dummyCert) Fingerprint() (string, error) { return "", nil } diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 7c1ac4c4d..8ba2723e2 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -541,7 +541,7 @@ func TestRehandshakingRelays(t *testing.T) { r.Log("Renew relay certificate and spin until me and them sees it") _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"}) - caB, err := ca.MarshalToPEM() + caB, err := ca.MarshalPEM() if err != nil { panic(err) } @@ -645,7 +645,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { r.Log("Renew relay certificate and spin until me and them sees it") _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"}) - caB, err := ca.MarshalToPEM() + caB, err := ca.MarshalPEM() if err != nil { panic(err) } @@ -740,7 +740,7 @@ func TestRehandshaking(t *testing.T) { r.Log("Renew my certificate and spin until their sees it") _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{myVpnIpNet}, nil, []string{"new group"}) - caB, err := ca.MarshalToPEM() + caB, err := ca.MarshalPEM() if err != nil { panic(err) } @@ -841,7 +841,7 @@ func TestRehandshakingLoser(t *testing.T) { r.Log("Renew their certificate and spin until mine sees it") _, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{theirVpnIpNet}, nil, []string{"their new group"}) - caB, err := ca.MarshalToPEM() + caB, err := ca.MarshalPEM() if err != nil { panic(err) } diff --git a/e2e/helpers.go b/e2e/helpers.go index 3c5d5ad17..c0893aca2 100644 --- a/e2e/helpers.go +++ b/e2e/helpers.go @@ -38,7 +38,7 @@ func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Pre panic(err) } - pem, err := c.MarshalToPEM() + pem, err := c.MarshalPEM() if err != nil { panic(err) } @@ -75,7 +75,7 @@ func NewTestCert(ca cert.Certificate, key []byte, name string, before, after tim panic(err) } - pem, err := c.MarshalToPEM() + pem, err := c.MarshalPEM() if err != nil { panic(err) } diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 885ba6a7c..77996f3da 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -46,7 +46,7 @@ func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNe } _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{vpnIpNet}, nil, []string{}) - caB, err := caCrt.MarshalToPEM() + caB, err := caCrt.MarshalPEM() if err != nil { panic(err) } diff --git a/ssh.go b/ssh.go index 93f162efc..881ee4696 100644 --- a/ssh.go +++ b/ssh.go @@ -825,7 +825,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit } if args.Raw { - b, err := cert.MarshalToPEM() + b, err := cert.MarshalPEM() if err != nil { //TODO: handle it return nil