diff --git a/cert/ca_pool.go b/cert/ca_pool.go index d1231fd04..8d71b85b6 100644 --- a/cert/ca_pool.go +++ b/cert/ca_pool.go @@ -8,14 +8,14 @@ import ( ) type CAPool struct { - CAs map[string]*NebulaCertificate + CAs map[string]*CachedCertificate certBlocklist map[string]struct{} } // NewCAPool creates a CAPool func NewCAPool() *CAPool { ca := CAPool{ - CAs: make(map[string]*NebulaCertificate), + CAs: make(map[string]*CachedCertificate), certBlocklist: make(map[string]struct{}), } @@ -60,25 +60,46 @@ func (ncp *CAPool) AddCAFromPEM(pemBytes []byte) ([]byte, error) { return pemBytes, err } - if !c.Details.IsCA { - return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotCA) + err = ncp.AddCA(c) + if err != nil { + return pemBytes, err + } + + return pemBytes, nil +} + +// TODO: +func (ncp *CAPool) AddCA(c Certificate) error { + if !c.IsCA() { + return fmt.Errorf("%s: %w", c.Name(), ErrNotCA) } - if !c.CheckSignature(c.Details.PublicKey) { - return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrNotSelfSigned) + if !c.CheckSignature(c.PublicKey()) { + return fmt.Errorf("%s: %w", c.Name(), ErrNotSelfSigned) } sum, err := c.Sha256Sum() if err != nil { - return pemBytes, fmt.Errorf("could not calculate shasum for provided CA; error: %s; %s", err, c.Details.Name) + return fmt.Errorf("could not calculate shasum for provided CA; error: %s; %s", err, c.Name()) + } + + cc := &CachedCertificate{ + Certificate: c, + ShaSum: sum, + InvertedGroups: make(map[string]struct{}), } - ncp.CAs[sum] = c + for _, g := range c.Groups() { + cc.InvertedGroups[g] = struct{}{} + } + + ncp.CAs[sum] = cc + if c.Expired(time.Now()) { - return pemBytes, fmt.Errorf("%s: %w", c.Details.Name, ErrExpired) + return fmt.Errorf("%s: %w", c.Name(), ErrExpired) } - return pemBytes, nil + return nil } // BlocklistFingerprint adds a cert fingerprint to the blocklist @@ -91,34 +112,94 @@ func (ncp *CAPool) ResetCertBlocklist() { ncp.certBlocklist = make(map[string]struct{}) } -// NOTE: This uses an internal cache for Sha256Sum() that will not be invalidated -// automatically if you manually change any fields in the NebulaCertificate. -func (ncp *CAPool) IsBlocklisted(c *NebulaCertificate) bool { - return ncp.isBlocklistedWithCache(c, false) +// TODO: +func (ncp *CAPool) IsBlocklisted(sha string) bool { + if _, ok := ncp.certBlocklist[sha]; ok { + return true + } + + return false } -// IsBlocklisted returns true if the fingerprint fails to generate or has been explicitly blocklisted -func (ncp *CAPool) isBlocklistedWithCache(c *NebulaCertificate, useCache bool) bool { - h, err := c.sha256SumWithCache(useCache) +// VerifyCertificate verifies the certificate is valid and is signed by a trusted CA in the pool. +// If the certificate is valid then the returned CachedCertificate can be used in subsequent verification attempts +// to increase performance. +func (ncp *CAPool) VerifyCertificate(now time.Time, c Certificate) (*CachedCertificate, error) { + sha, err := c.Sha256Sum() if err != nil { - return true + return nil, fmt.Errorf("could not calculate shasum to verify: %w", err) } - if _, ok := ncp.certBlocklist[h]; ok { - return true + signer, err := ncp.verify(c, now, sha, "") + if err != nil { + return nil, err } - return false + cc := CachedCertificate{ + Certificate: c, + InvertedGroups: make(map[string]struct{}), + ShaSum: sha, + signerShaSum: signer.ShaSum, + } + + for _, g := range c.Groups() { + cc.InvertedGroups[g] = struct{}{} + } + + return &cc, nil +} + +func (ncp *CAPool) VerifyCachedCertificate(now time.Time, c *CachedCertificate) error { + _, err := ncp.verify(c.Certificate, now, c.ShaSum, c.signerShaSum) + return err +} + +func (ncp *CAPool) verify(c Certificate, now time.Time, certSha string, signerSha string) (*CachedCertificate, error) { + if ncp.IsBlocklisted(certSha) { + return nil, ErrBlockListed + } + + signer, err := ncp.GetCAForCert(c) + if err != nil { + return nil, err + } + + if signer.Certificate.Expired(now) { + return nil, ErrRootExpired + } + + if c.Expired(now) { + return nil, ErrExpired + } + + // If we are checking a cached certificate then we can bail early here + // Either the root is no longer trusted or everything is fine + if len(signerSha) > 0 { + if signerSha != signer.ShaSum { + return nil, ErrSignatureMismatch + } + return signer, nil + } + if !c.CheckSignature(signer.Certificate.PublicKey()) { + return nil, ErrSignatureMismatch + } + + err = c.CheckRootConstraints(signer.Certificate) + if err != nil { + return nil, err + } + + return signer, nil } // GetCAForCert attempts to return the signing certificate for the provided certificate. // No signature validation is performed -func (ncp *CAPool) GetCAForCert(c *NebulaCertificate) (*NebulaCertificate, error) { - if c.Details.Issuer == "" { +func (ncp *CAPool) GetCAForCert(c Certificate) (*CachedCertificate, error) { + if c.Issuer() == "" { return nil, fmt.Errorf("no issuer in certificate") } - signer, ok := ncp.CAs[c.Details.Issuer] + signer, ok := ncp.CAs[c.Issuer()] if ok { return signer, nil } diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go index c22b4b4aa..053640d98 100644 --- a/cert/ca_pool_test.go +++ b/cert/ca_pool_test.go @@ -61,49 +61,49 @@ IBNWYMep3ysx9zCgknfG5dKtwGTaqF++BWKDYdyl34KX -----END NEBULA CERTIFICATE----- ` - rootCA := NebulaCertificate{ - Details: NebulaCertificateDetails{ + rootCA := certificateV1{ + details: detailsV1{ Name: "nebula root ca", }, } - rootCA01 := NebulaCertificate{ - Details: NebulaCertificateDetails{ + rootCA01 := certificateV1{ + details: detailsV1{ Name: "nebula root ca 01", }, } - rootCAP256 := NebulaCertificate{ - Details: NebulaCertificateDetails{ + rootCAP256 := certificateV1{ + details: detailsV1{ Name: "nebula P256 test", }, } p, err := NewCAPoolFromPEM([]byte(noNewLines)) assert.Nil(t, err) - assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name) - assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name) + assert.Equal(t, p.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) + assert.Equal(t, p.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) pp, err := NewCAPoolFromPEM([]byte(withNewLines)) assert.Nil(t, err) - assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name) - assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name) + assert.Equal(t, pp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) + assert.Equal(t, pp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) // expired cert, no valid certs ppp, err := NewCAPoolFromPEM([]byte(expired)) assert.Equal(t, ErrExpired, err) - assert.Equal(t, ppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired") + assert.Equal(t, ppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired") // expired cert, with valid certs pppp, err := NewCAPoolFromPEM(append([]byte(expired), noNewLines...)) assert.Equal(t, ErrExpired, err) - assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Details.Name, rootCA.Details.Name) - assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Details.Name, rootCA01.Details.Name) - assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Details.Name, "expired") + assert.Equal(t, pppp.CAs[string("c9bfaf7ce8e84b2eeda2e27b469f4b9617bde192efd214b68891ecda6ed49522")].Certificate.Name(), rootCA.details.Name) + assert.Equal(t, pppp.CAs[string("5c9c3f23e7ee7fe97637cbd3a0a5b854154d1d9aaaf7b566a51f4a88f76b64cd")].Certificate.Name(), rootCA01.details.Name) + assert.Equal(t, pppp.CAs[string("152070be6bb19bc9e3bde4c2f0e7d8f4ff5448b4c9856b8eccb314fade0229b0")].Certificate.Name(), "expired") assert.Equal(t, len(pppp.CAs), 3) ppppp, err := NewCAPoolFromPEM([]byte(p256)) assert.Nil(t, err) - assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Details.Name, rootCAP256.Details.Name) + assert.Equal(t, ppppp.CAs[string("a7938893ec8c4ef769b06d7f425e5e46f7a7f5ffa49c3bcf4a86b608caba9159")].Certificate.Name(), rootCAP256.details.Name) assert.Equal(t, len(ppppp.CAs), 1) } diff --git a/cert/cert.go b/cert/cert.go new file mode 100644 index 000000000..3a312b677 --- /dev/null +++ b/cert/cert.go @@ -0,0 +1,132 @@ +package cert + +import ( + "net/netip" + "time" +) + +type Version int + +const ( + Version1 Version = 1 + Version2 Version = 2 +) + +type Certificate interface { + //TODO: describe this + Version() Version + + // Name is the human-readable name that identifies this certificate. + Name() string + + // Networks is a list of ip addresses and network sizes assigned to this certificate. + // If IsCA is true then certificates signed by this CA can only have ip addresses and + // networks that are contained by an entry in this list. + Networks() []netip.Prefix + + // UnsafeNetworks is a list of networks that this host can act as an unsafe router for. + // If IsCA is true then certificates signed by this CA can only have networks that are + // contained by an entry in this list. + UnsafeNetworks() []netip.Prefix + + // Groups is a list of identities that can be used to write more general firewall rule + // definitions. + // If IsCA is true then certificates signed by this CA can only use groups that are + // in this list. + Groups() []string + + // IsCA signifies if this is a certificate authority (true) or a host certificate (false). + // It is invalid to use a CA certificate as a host certificate. + IsCA() bool + + // NotBefore is the time at which this certificate becomes valid. + // If IsCA is true then certificate signed by this CA can not have a time before this. + NotBefore() time.Time + + // NotAfter is the time at which this certificate becomes invalid. + // If IsCA is true then certificate signed by this CA can not have a time after this. + NotAfter() time.Time + + // Issuer is the fingerprint of the CA that signed this certificate. + // If IsCA is true then this will be empty. + Issuer() string //TODO: string or bytes? + + // PublicKey is the raw bytes to be used in asymmetric cryptographic operations. + PublicKey() []byte + + // Curve identifies which curve was used for the PublicKey and Signature. + Curve() Curve + + // Signature is the cryptographic seal for all the details of this certificate. + // CheckSignature can be used to verify that the details of this certificate are valid. + Signature() []byte //TODO: string or bytes? + + // CheckSignature will check that the certificate Signature() matches the + // computed signature. A true result means this certificate has not been tampered with. + CheckSignature(signingPublicKey []byte) bool + + // Sha256Sum returns the hex encoded sha256 sum of the certificate. + // This acts as a unique fingerprint and can be used to blocklist certificates. + Sha256Sum() (string, error) + + // Expired tests if the certificate is valid for the provided time. + Expired(t time.Time) bool + + // CheckRootConstraints tests if the certificate meets all constraints in the + // signing certificate, returning the first violated constraint or nil if the + // certificate conforms to all constraints. + //TODO: feels better to have this on the CAPool I think + CheckRootConstraints(signer Certificate) error + + //TODO + VerifyPrivateKey(curve Curve, privateKey []byte) error + + // Marshal will return the byte representation of this certificate + // This is primarily the format transmitted on the wire. + Marshal() ([]byte, error) + + // MarshalForHandshakes prepares the bytes needed to use directly in a handshake + MarshalForHandshakes() ([]byte, error) + + // MarshalToPEM will return a PEM encoded representation of this certificate + // This is primarily the format stored on disk + //TODO: MarshalPEM? + MarshalToPEM() ([]byte, error) + + // MarshalJSON will return the json representation of this certificate + MarshalJSON() ([]byte, error) + + // String will return a human-readable representation of this certificate + String() string + + //TODO + Copy() Certificate +} + +// CachedCertificate represents a verified certificate with some cached fields to improve +// performance. +type CachedCertificate struct { + Certificate Certificate + InvertedGroups map[string]struct{} + ShaSum string + signerShaSum string +} + +// TODO: +func UnmarshalCertificate(b []byte) (Certificate, error) { + c, err := unmarshalCertificateV1(b, true) + if err != nil { + return nil, err + } + return c, nil +} + +// TODO: +func UnmarshalCertificateFromHandshake(b []byte, publicKey []byte) (Certificate, error) { + c, err := unmarshalCertificateV1(b, false) + if err != nil { + return nil, err + } + c.details.PublicKey = publicKey + return c, nil +} diff --git a/cert/cert_test.go b/cert/cert_test.go index 26b5543c8..462d6e8c4 100644 --- a/cert/cert_test.go +++ b/cert/cert_test.go @@ -7,7 +7,7 @@ import ( "crypto/rand" "fmt" "io" - "net" + "net/netip" "testing" "time" @@ -15,7 +15,6 @@ import ( "github.com/stretchr/testify/assert" "golang.org/x/crypto/curve25519" "golang.org/x/crypto/ed25519" - "google.golang.org/protobuf/proto" ) func TestMarshalingNebulaCertificate(t *testing.T) { @@ -23,18 +22,20 @@ func TestMarshalingNebulaCertificate(t *testing.T) { after := time.Now().Add(time.Second * 60).Round(time.Second) pubKey := []byte("1234567890abcedfghij1234567890ab") - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ + nc := certificateV1{ + details: detailsV1{ Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + Ips: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + //TODO: netip cant represent this netmask + //{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, + Subnets: []netip.Prefix{ + //TODO: netip cant represent this netmask + //{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), }, Groups: []string{"test-group1", "test-group2", "test-group3"}, NotBefore: before, @@ -43,35 +44,27 @@ func TestMarshalingNebulaCertificate(t *testing.T) { IsCA: false, Issuer: "1234567890abcedfghij1234567890ab", }, - Signature: []byte("1234567890abcedfghij1234567890ab"), + signature: []byte("1234567890abcedfghij1234567890ab"), } b, err := nc.Marshal() assert.Nil(t, err) //t.Log("Cert size:", len(b)) - nc2, err := UnmarshalNebulaCertificate(b) + nc2, err := unmarshalCertificateV1(b, true) assert.Nil(t, err) - assert.Equal(t, nc.Signature, nc2.Signature) - assert.Equal(t, nc.Details.Name, nc2.Details.Name) - assert.Equal(t, nc.Details.NotBefore, nc2.Details.NotBefore) - assert.Equal(t, nc.Details.NotAfter, nc2.Details.NotAfter) - assert.Equal(t, nc.Details.PublicKey, nc2.Details.PublicKey) - assert.Equal(t, nc.Details.IsCA, nc2.Details.IsCA) + assert.Equal(t, nc.signature, nc2.Signature()) + assert.Equal(t, nc.details.Name, nc2.Name()) + assert.Equal(t, nc.details.NotBefore, nc2.NotBefore()) + assert.Equal(t, nc.details.NotAfter, nc2.NotAfter()) + assert.Equal(t, nc.details.PublicKey, nc2.PublicKey()) + assert.Equal(t, nc.details.IsCA, nc2.IsCA()) - // IP byte arrays can be 4 or 16 in length so we have to go this route - assert.Equal(t, len(nc.Details.Ips), len(nc2.Details.Ips)) - for i, wIp := range nc.Details.Ips { - assert.Equal(t, wIp.String(), nc2.Details.Ips[i].String()) - } - - assert.Equal(t, len(nc.Details.Subnets), len(nc2.Details.Subnets)) - for i, wIp := range nc.Details.Subnets { - assert.Equal(t, wIp.String(), nc2.Details.Subnets[i].String()) - } + assert.Equal(t, nc.details.Ips, nc2.Networks()) + assert.Equal(t, nc.details.Subnets, nc2.UnsafeNetworks()) - assert.EqualValues(t, nc.Details.Groups, nc2.Details.Groups) + assert.Equal(t, nc.details.Groups, nc2.Groups()) } func TestNebulaCertificate_Sign(t *testing.T) { @@ -79,18 +72,20 @@ func TestNebulaCertificate_Sign(t *testing.T) { after := time.Now().Add(time.Second * 60).Round(time.Second) pubKey := []byte("1234567890abcedfghij1234567890ab") - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ + nc := certificateV1{ + details: detailsV1{ Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + Ips: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + //TODO: netip cant do it + //{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, + Subnets: []netip.Prefix{ + //TODO: netip cant do it + //{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/24"), }, Groups: []string{"test-group1", "test-group2", "test-group3"}, NotBefore: before, @@ -117,18 +112,20 @@ func TestNebulaCertificate_SignP256(t *testing.T) { after := time.Now().Add(time.Second * 60).Round(time.Second) pubKey := []byte("01234567890abcedfghij1234567890ab1234567890abcedfghij1234567890ab") - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ + nc := certificateV1{ + details: detailsV1{ Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + Ips: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + //TODO: netip no can do + //{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, + Subnets: []netip.Prefix{ + //TODO: netip bad + //{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), }, Groups: []string{"test-group1", "test-group2", "test-group3"}, NotBefore: before, @@ -155,8 +152,8 @@ func TestNebulaCertificate_SignP256(t *testing.T) { } func TestNebulaCertificate_Expired(t *testing.T) { - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ + nc := certificateV1{ + details: detailsV1{ NotBefore: time.Now().Add(time.Second * -60).Round(time.Second), NotAfter: time.Now().Add(time.Second * 60).Round(time.Second), }, @@ -171,18 +168,20 @@ func TestNebulaCertificate_MarshalJSON(t *testing.T) { time.Local = time.UTC pubKey := []byte("1234567890abcedfghij1234567890ab") - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ + nc := certificateV1{ + details: detailsV1{ Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + Ips: []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + //TODO: netip bad + //{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, + Subnets: []netip.Prefix{ + //TODO: netip bad + //{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), }, Groups: []string{"test-group1", "test-group2", "test-group3"}, NotBefore: time.Date(1, 0, 0, 1, 0, 0, 0, time.UTC), @@ -191,56 +190,49 @@ func TestNebulaCertificate_MarshalJSON(t *testing.T) { IsCA: false, Issuer: "1234567890abcedfghij1234567890ab", }, - Signature: []byte("1234567890abcedfghij1234567890ab"), + signature: []byte("1234567890abcedfghij1234567890ab"), } b, err := nc.MarshalJSON() assert.Nil(t, err) assert.Equal( t, - "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\",\"10.1.1.3/ff00ff00\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.1/ff00ff00\",\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"26cb1c30ad7872c804c166b5150fa372f437aa3856b04edb4334b4470ec728e4\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}", + "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"ips\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"subnets\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\"}", string(b), ) } func TestNebulaCertificate_Verify(t *testing.T) { - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) assert.Nil(t, err) - h, err := ca.Sha256Sum() + c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) assert.Nil(t, err) caPool := NewCAPool() - caPool.CAs[h] = ca + assert.NoError(t, caPool.AddCA(ca)) f, err := c.Sha256Sum() assert.Nil(t, err) caPool.BlocklistFingerprint(f) - v, err := c.Verify(time.Now(), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) - v, err = c.Verify(time.Now().Add(time.Hour*1000), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) assert.EqualError(t, err, "root certificate is expired") - c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) + c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil) assert.Nil(t, err) - v, err = c.Verify(time.Now().Add(time.Minute*6), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now().Add(time.Minute*6), c) assert.EqualError(t, err, "certificate is expired") // Test group assertion - ca, _, caKey, err = newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "test2"}) + ca, _, caKey, err = newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) assert.Nil(t, err) caPem, err := ca.MarshalToPEM() @@ -251,57 +243,48 @@ func TestNebulaCertificate_Verify(t *testing.T) { assert.NoError(t, err) assert.Empty(t, b) - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "bad"}) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1"}) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) } func TestNebulaCertificate_VerifyP256(t *testing.T) { - ca, _, caKey, err := newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) - assert.Nil(t, err) - - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, err := newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) assert.Nil(t, err) - h, err := ca.Sha256Sum() + c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) assert.Nil(t, err) caPool := NewCAPool() - caPool.CAs[h] = ca + assert.NoError(t, caPool.AddCA(ca)) f, err := c.Sha256Sum() assert.Nil(t, err) caPool.BlocklistFingerprint(f) - v, err := c.Verify(time.Now(), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) - v, err = c.Verify(time.Now().Add(time.Hour*1000), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) assert.EqualError(t, err, "root certificate is expired") - c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) + c, _, _, err = newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil) assert.Nil(t, err) - v, err = c.Verify(time.Now().Add(time.Minute*6), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now().Add(time.Minute*6), c) assert.EqualError(t, err, "certificate is expired") // Test group assertion - ca, _, caKey, err = newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "test2"}) + ca, _, caKey, err = newTestCaCertP256(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) assert.Nil(t, err) caPem, err := ca.MarshalToPEM() @@ -312,23 +295,21 @@ func TestNebulaCertificate_VerifyP256(t *testing.T) { assert.NoError(t, err) assert.Empty(t, b) - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1", "bad"}) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1", "bad"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.EqualError(t, err, "certificate contained a group not present on the signing ca: bad") - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{"test1"}) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) } func TestNebulaCertificate_Verify_IPs(t *testing.T) { - _, caIp1, _ := net.ParseCIDR("10.0.0.0/16") - _, caIp2, _ := net.ParseCIDR("192.168.0.0/24") - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{caIp1, caIp2}, []*net.IPNet{}, []string{"test"}) + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) assert.Nil(t, err) caPem, err := ca.MarshalToPEM() @@ -340,76 +321,68 @@ func TestNebulaCertificate_Verify_IPs(t *testing.T) { assert.Empty(t, b) // ip is outside the network - cIp1 := &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}} - cIp2 := &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 0, 0}} - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) assert.Nil(t, err) - v, err := c.Verify(time.Now(), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.1.0.0/24") // ip is outside the network reversed order of above - cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.1.0.0/24") // ip is within the network but mask is outside - cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.0.1.0/15") // ip is within the network but mask is outside reversed order of above - cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.EqualError(t, err, "certificate contained an ip assignment outside the limitations of the signing ca: 10.0.1.0/15") // ip and mask are within the network - cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 255, 0, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 128}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{cIp1, cIp2}, []*net.IPNet{}, []string{"test"}) + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp1, caIp2}, []*net.IPNet{}, []string{"test"}) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches reversed - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp2, caIp1}, []*net.IPNet{}, []string{"test"}) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches reversed with just 1 - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{caIp1}, []*net.IPNet{}, []string{"test"}) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) } func TestNebulaCertificate_Verify_Subnets(t *testing.T) { - _, caIp1, _ := net.ParseCIDR("10.0.0.0/16") - _, caIp2, _ := net.ParseCIDR("192.168.0.0/24") - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1, caIp2}, []string{"test"}) + caIp1 := mustParsePrefixUnmapped("10.0.0.0/16") + caIp2 := mustParsePrefixUnmapped("192.168.0.0/24") + ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) assert.Nil(t, err) caPem, err := ca.MarshalToPEM() @@ -421,84 +394,76 @@ func TestNebulaCertificate_Verify_Subnets(t *testing.T) { assert.Empty(t, b) // ip is outside the network - cIp1 := &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}} - cIp2 := &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 0, 0}} - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) + cIp1 := mustParsePrefixUnmapped("10.1.0.0/24") + cIp2 := mustParsePrefixUnmapped("192.168.0.1/16") + c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) assert.Nil(t, err) - v, err := c.Verify(time.Now(), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.1.0.0/24") // ip is outside the network reversed order of above - cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("10.1.0.0"), Mask: []byte{255, 255, 255, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.1.0.0/24") + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.1.0.0/24") // ip is within the network but mask is outside - cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) + cIp1 = mustParsePrefixUnmapped("10.0.1.0/15") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/24") + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.0.1.0/15") // ip is within the network but mask is outside reversed order of above - cIp1 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 254, 0, 0}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) + cIp1 = mustParsePrefixUnmapped("192.168.0.1/24") + cIp2 = mustParsePrefixUnmapped("10.0.1.0/15") + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.False(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.EqualError(t, err, "certificate contained a subnet assignment outside the limitations of the signing ca: 10.0.1.0/15") // ip and mask are within the network - cIp1 = &net.IPNet{IP: net.ParseIP("10.0.1.0"), Mask: []byte{255, 255, 0, 0}} - cIp2 = &net.IPNet{IP: net.ParseIP("192.168.0.1"), Mask: []byte{255, 255, 255, 128}} - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{cIp1, cIp2}, []string{"test"}) + cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") + cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1, caIp2}, []string{"test"}) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches reversed - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp2, caIp1}, []string{"test"}) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) // Exact matches reversed with just 1 - c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{caIp1}, []string{"test"}) + c, _, _, err = newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) assert.Nil(t, err) - v, err = c.Verify(time.Now(), caPool) - assert.True(t, v) + _, err = caPool.VerifyCertificate(time.Now(), c) assert.Nil(t, err) } func TestNebulaCertificate_VerifyPrivateKey(t *testing.T) { - ca, _, caKey, err := newTestCaCert(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, err := newTestCaCert(time.Time{}, time.Time{}, nil, nil, nil) assert.Nil(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) assert.Nil(t, err) - _, _, caKey2, err := newTestCaCert(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) + _, _, caKey2, err := newTestCaCert(time.Time{}, time.Time{}, nil, nil, nil) assert.Nil(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) assert.NotNil(t, err) - c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) + c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil) err = c.VerifyPrivateKey(Curve_CURVE25519, priv) assert.Nil(t, err) @@ -508,17 +473,17 @@ func TestNebulaCertificate_VerifyPrivateKey(t *testing.T) { } func TestNebulaCertificate_VerifyPrivateKeyP256(t *testing.T) { - ca, _, caKey, err := newTestCaCertP256(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, err := newTestCaCertP256(time.Time{}, time.Time{}, nil, nil, nil) assert.Nil(t, err) err = ca.VerifyPrivateKey(Curve_P256, caKey) assert.Nil(t, err) - _, _, caKey2, err := newTestCaCertP256(time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) + _, _, caKey2, err := newTestCaCertP256(time.Time{}, time.Time{}, nil, nil, nil) assert.Nil(t, err) err = ca.VerifyPrivateKey(Curve_P256, caKey2) assert.NotNil(t, err) - c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, []*net.IPNet{}, []*net.IPNet{}, []string{}) + c, _, priv, err := newTestCert(ca, caKey, time.Time{}, time.Time{}, nil, nil, nil) err = c.VerifyPrivateKey(Curve_P256, priv) assert.Nil(t, err) @@ -537,50 +502,53 @@ func appendByteSlices(b ...[]byte) []byte { // Ensure that upgrading the protobuf library does not change how certificates // are marshalled, since this would break signature verification -func TestMarshalingNebulaCertificateConsistency(t *testing.T) { - before := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) - after := time.Date(2017, time.January, 18, 28, 40, 0, 0, time.UTC) - pubKey := []byte("1234567890abcedfghij1234567890ab") - - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "testing", - Ips: []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("10.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - {IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - }, - Subnets: []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, - {IP: net.ParseIP("9.1.1.2"), Mask: net.IPMask(net.ParseIP("255.255.255.0"))}, - {IP: net.ParseIP("9.1.1.3"), Mask: net.IPMask(net.ParseIP("255.255.0.0"))}, - }, - Groups: []string{"test-group1", "test-group2", "test-group3"}, - NotBefore: before, - NotAfter: after, - PublicKey: pubKey, - IsCA: false, - Issuer: "1234567890abcedfghij1234567890ab", - }, - Signature: []byte("1234567890abcedfghij1234567890ab"), - } - - b, err := nc.Marshal() - assert.Nil(t, err) - //t.Log("Cert size:", len(b)) - assert.Equal(t, "0aa2010a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) - - b, err = proto.Marshal(nc.getRawDetails()) - assert.Nil(t, err) - //t.Log("Raw cert size:", len(b)) - assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) -} +//TODO: since netip cant represent 255.0.255.0 netmask we can't verify the old certs are ok +//func TestMarshalingNebulaCertificateConsistency(t *testing.T) { +// before := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) +// after := time.Date(2017, time.January, 18, 28, 40, 0, 0, time.UTC) +// pubKey := []byte("1234567890abcedfghij1234567890ab") +// +// nc := certificateV1{ +// details: detailsV1{ +// Name: "testing", +// Ips: []netip.Prefix{ +// mustParsePrefixUnmapped("10.1.1.1/24"), +// mustParsePrefixUnmapped("10.1.1.2/16"), +// //TODO: netip bad +// //{IP: net.ParseIP("10.1.1.3"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, +// }, +// Subnets: []netip.Prefix{ +// //TODO: netip bad +// //{IP: net.ParseIP("9.1.1.1"), Mask: net.IPMask(net.ParseIP("255.0.255.0"))}, +// mustParsePrefixUnmapped("9.1.1.2/24"), +// mustParsePrefixUnmapped("9.1.1.3/16"), +// }, +// Groups: []string{"test-group1", "test-group2", "test-group3"}, +// NotBefore: before, +// NotAfter: after, +// PublicKey: pubKey, +// IsCA: false, +// Issuer: "1234567890abcedfghij1234567890ab", +// }, +// signature: []byte("1234567890abcedfghij1234567890ab"), +// } +// +// b, err := nc.Marshal() +// assert.Nil(t, err) +// //t.Log("Cert size:", len(b)) +// assert.Equal(t, "0aa2010a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) +// +// b, err = proto.Marshal(nc.getRawDetails()) +// assert.Nil(t, err) +// //t.Log("Raw cert size:", len(b)) +// assert.Equal(t, "0a0774657374696e67121b8182845080feffff0f828284508080fcff0f8382845080fe83f80f1a1b8182844880fe83f80f8282844880feffff0f838284488080fcff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328f0e0e7d70430a08681c4053a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) +//} func TestNebulaCertificate_Copy(t *testing.T) { - ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + ca, _, caKey, err := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, nil) assert.Nil(t, err) - c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{}) + c, _, _, err := newTestCert(ca, caKey, time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) assert.Nil(t, err) cc := c.Copy() @@ -590,11 +558,11 @@ func TestNebulaCertificate_Copy(t *testing.T) { func TestUnmarshalNebulaCertificate(t *testing.T) { // Test that we don't panic with an invalid certificate (#332) data := []byte("\x98\x00\x00") - _, err := UnmarshalNebulaCertificate(data) + _, err := unmarshalCertificateV1(data, true) assert.EqualError(t, err, "encoded Details was nil") } -func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) { +func newTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*certificateV1, []byte, []byte, error) { pub, priv, err := ed25519.GenerateKey(rand.Reader) if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) @@ -603,27 +571,26 @@ func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups [] after = time.Now().Add(time.Second * 60).Round(time.Second) } - nc := &NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "test ca", - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: true, - InvertedGroups: make(map[string]struct{}), + nc := &certificateV1{ + details: detailsV1{ + Name: "test ca", + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: true, }, } if len(ips) > 0 { - nc.Details.Ips = ips + nc.details.Ips = ips } if len(subnets) > 0 { - nc.Details.Subnets = subnets + nc.details.Subnets = subnets } if len(groups) > 0 { - nc.Details.Groups = groups + nc.details.Groups = groups } err = nc.Sign(Curve_CURVE25519, priv) @@ -633,7 +600,7 @@ func newTestCaCert(before, after time.Time, ips, subnets []*net.IPNet, groups [] return nc, pub, priv, nil } -func newTestCaCertP256(before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) { +func newTestCaCertP256(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*certificateV1, []byte, []byte, error) { priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) rawPriv := priv.D.FillBytes(make([]byte, 32)) @@ -645,28 +612,27 @@ func newTestCaCertP256(before, after time.Time, ips, subnets []*net.IPNet, group after = time.Now().Add(time.Second * 60).Round(time.Second) } - nc := &NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "test ca", - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: true, - Curve: Curve_P256, - InvertedGroups: make(map[string]struct{}), + nc := &certificateV1{ + details: detailsV1{ + Name: "test ca", + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: true, + Curve: Curve_P256, }, } if len(ips) > 0 { - nc.Details.Ips = ips + nc.details.Ips = ips } if len(subnets) > 0 { - nc.Details.Subnets = subnets + nc.details.Subnets = subnets } if len(groups) > 0 { - nc.Details.Groups = groups + nc.details.Groups = groups } err = nc.Sign(Curve_P256, rawPriv) @@ -676,7 +642,7 @@ func newTestCaCertP256(before, after time.Time, ips, subnets []*net.IPNet, group return nc, pub, rawPriv, nil } -func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips, subnets []*net.IPNet, groups []string) (*NebulaCertificate, []byte, []byte, error) { +func newTestCert(ca *certificateV1, key []byte, before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*certificateV1, []byte, []byte, error) { issuer, err := ca.Sha256Sum() if err != nil { return nil, nil, nil, err @@ -694,49 +660,50 @@ func newTestCert(ca *NebulaCertificate, key []byte, before, after time.Time, ips } if len(ips) == 0 { - ips = []*net.IPNet{ - {IP: net.ParseIP("10.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())}, - {IP: net.ParseIP("10.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())}, - {IP: net.ParseIP("10.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())}, + ips = []netip.Prefix{ + mustParsePrefixUnmapped("10.1.1.1/24"), + mustParsePrefixUnmapped("10.1.1.2/16"), + //TODO: netip bad + //{IP: net.ParseIP("10.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())}, } } if len(subnets) == 0 { - subnets = []*net.IPNet{ - {IP: net.ParseIP("9.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())}, - {IP: net.ParseIP("9.1.1.2").To4(), Mask: net.IPMask(net.ParseIP("255.255.255.0").To4())}, - {IP: net.ParseIP("9.1.1.3").To4(), Mask: net.IPMask(net.ParseIP("255.255.0.0").To4())}, + subnets = []netip.Prefix{ + //TODO: netip bad + //{IP: net.ParseIP("9.1.1.1").To4(), Mask: net.IPMask(net.ParseIP("255.0.255.0").To4())}, + mustParsePrefixUnmapped("9.1.1.2/24"), + mustParsePrefixUnmapped("9.1.1.3/16"), } } var pub, rawPriv []byte - switch ca.Details.Curve { + switch ca.details.Curve { case Curve_CURVE25519: pub, rawPriv = x25519Keypair() case Curve_P256: pub, rawPriv = p256Keypair() default: - return nil, nil, nil, fmt.Errorf("unknown curve: %v", ca.Details.Curve) - } - - nc := &NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: "testing", - Ips: ips, - Subnets: subnets, - Groups: groups, - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: false, - Curve: ca.Details.Curve, - Issuer: issuer, - InvertedGroups: make(map[string]struct{}), + return nil, nil, nil, fmt.Errorf("unknown curve: %v", ca.details.Curve) + } + + nc := &certificateV1{ + details: detailsV1{ + Name: "testing", + Ips: ips, + Subnets: subnets, + Groups: groups, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: false, + Curve: ca.details.Curve, + Issuer: issuer, }, } - err = nc.Sign(ca.Details.Curve, key) + err = nc.Sign(ca.details.Curve, key) if err != nil { return nil, nil, nil, err } @@ -766,3 +733,8 @@ func p256Keypair() ([]byte, []byte) { pubkey := privkey.PublicKey() return pubkey.Bytes(), privkey.Bytes() } + +func mustParsePrefixUnmapped(s string) netip.Prefix { + prefix := netip.MustParsePrefix(s) + return netip.PrefixFrom(prefix.Addr().Unmap(), prefix.Bits()) +} diff --git a/cert/cert_v1.go b/cert/cert_v1.go index e31922923..3ae078e40 100644 --- a/cert/cert_v1.go +++ b/cert/cert_v1.go @@ -15,7 +15,8 @@ import ( "fmt" "math/big" "net" - "sync/atomic" + "net/netip" + "slices" "time" "golang.org/x/crypto/curve25519" @@ -24,23 +25,15 @@ import ( const publicKeyLen = 32 -type NebulaCertificate struct { - Details NebulaCertificateDetails - Signature []byte - - // the cached hex string of the calculated sha256sum - // for VerifyWithCache - sha256sum atomic.Pointer[string] - - // the cached public key bytes if they were verified as the signer - // for VerifyWithCache - signatureVerified atomic.Pointer[[]byte] +type certificateV1 struct { + details detailsV1 + signature []byte } -type NebulaCertificateDetails struct { +type detailsV1 struct { Name string - Ips []*net.IPNet - Subnets []*net.IPNet + Ips []netip.Prefix + Subnets []netip.Prefix Groups []string NotBefore time.Time NotAfter time.Time @@ -48,16 +41,72 @@ type NebulaCertificateDetails struct { IsCA bool Issuer string - // Map of groups for faster lookup - InvertedGroups map[string]struct{} - Curve Curve } type m map[string]interface{} -// UnmarshalNebulaCertificate will unmarshal a protobuf byte representation of a nebula cert -func UnmarshalNebulaCertificate(b []byte) (*NebulaCertificate, error) { +func (nc *certificateV1) Version() Version { + return Version1 +} + +func (nc *certificateV1) Curve() Curve { + return nc.details.Curve +} + +func (nc *certificateV1) Groups() []string { + return nc.details.Groups +} + +func (nc *certificateV1) IsCA() bool { + return nc.details.IsCA +} + +func (nc *certificateV1) Issuer() string { + return nc.details.Issuer +} + +func (nc *certificateV1) Name() string { + return nc.details.Name +} + +func (nc *certificateV1) Networks() []netip.Prefix { + return nc.details.Ips +} + +func (nc *certificateV1) NotAfter() time.Time { + return nc.details.NotAfter +} + +func (nc *certificateV1) NotBefore() time.Time { + return nc.details.NotBefore +} + +func (nc *certificateV1) PublicKey() []byte { + return nc.details.PublicKey +} + +func (nc *certificateV1) Signature() []byte { + return nc.signature +} + +func (nc *certificateV1) UnsafeNetworks() []netip.Prefix { + return nc.details.Subnets +} + +func (nc *certificateV1) MarshalForHandshakes() ([]byte, error) { + pubKey := nc.details.PublicKey + nc.details.PublicKey = nil + rawCertNoKey, err := nc.Marshal() + if err != nil { + return nil, err + } + nc.details.PublicKey = pubKey + return rawCertNoKey, nil +} + +// unmarshalCertificateV1 will unmarshal a protobuf byte representation of a nebula cert +func unmarshalCertificateV1(b []byte, assertPublicKey bool) (*certificateV1, error) { if len(b) == 0 { return nil, fmt.Errorf("nil byte array") } @@ -79,63 +128,70 @@ func UnmarshalNebulaCertificate(b []byte) (*NebulaCertificate, error) { return nil, fmt.Errorf("encoded Subnets should be in pairs, an odd number was found") } - nc := NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: rc.Details.Name, - Groups: make([]string, len(rc.Details.Groups)), - Ips: make([]*net.IPNet, len(rc.Details.Ips)/2), - Subnets: make([]*net.IPNet, len(rc.Details.Subnets)/2), - NotBefore: time.Unix(rc.Details.NotBefore, 0), - NotAfter: time.Unix(rc.Details.NotAfter, 0), - PublicKey: make([]byte, len(rc.Details.PublicKey)), - IsCA: rc.Details.IsCA, - InvertedGroups: make(map[string]struct{}), - Curve: rc.Details.Curve, + nc := certificateV1{ + details: detailsV1{ + Name: rc.Details.Name, + Groups: make([]string, len(rc.Details.Groups)), + Ips: make([]netip.Prefix, len(rc.Details.Ips)/2), + Subnets: make([]netip.Prefix, len(rc.Details.Subnets)/2), + NotBefore: time.Unix(rc.Details.NotBefore, 0), + NotAfter: time.Unix(rc.Details.NotAfter, 0), + PublicKey: make([]byte, len(rc.Details.PublicKey)), + IsCA: rc.Details.IsCA, + Curve: rc.Details.Curve, }, - Signature: make([]byte, len(rc.Signature)), + signature: make([]byte, len(rc.Signature)), } - copy(nc.Signature, rc.Signature) - copy(nc.Details.Groups, rc.Details.Groups) - nc.Details.Issuer = hex.EncodeToString(rc.Details.Issuer) + copy(nc.signature, rc.Signature) + copy(nc.details.Groups, rc.Details.Groups) + nc.details.Issuer = hex.EncodeToString(rc.Details.Issuer) - if len(rc.Details.PublicKey) < publicKeyLen { + if len(rc.Details.PublicKey) < publicKeyLen && assertPublicKey { return nil, fmt.Errorf("Public key was fewer than 32 bytes; %v", len(rc.Details.PublicKey)) } - copy(nc.Details.PublicKey, rc.Details.PublicKey) + copy(nc.details.PublicKey, rc.Details.PublicKey) + var ip netip.Addr for i, rawIp := range rc.Details.Ips { if i%2 == 0 { - nc.Details.Ips[i/2] = &net.IPNet{IP: int2ip(rawIp)} + ip = int2addr(rawIp) } else { - nc.Details.Ips[i/2].Mask = net.IPMask(int2ip(rawIp)) + ones, _ := net.IPMask(int2ip(rawIp)).Size() + nc.details.Ips[i/2] = netip.PrefixFrom(ip, ones) } } for i, rawIp := range rc.Details.Subnets { if i%2 == 0 { - nc.Details.Subnets[i/2] = &net.IPNet{IP: int2ip(rawIp)} + ip = int2addr(rawIp) } else { - nc.Details.Subnets[i/2].Mask = net.IPMask(int2ip(rawIp)) + ones, _ := net.IPMask(int2ip(rawIp)).Size() + nc.details.Subnets[i/2] = netip.PrefixFrom(ip, ones) } } - for _, g := range rc.Details.Groups { - nc.Details.InvertedGroups[g] = struct{}{} - } - return &nc, nil } -// Sign signs a nebula cert with the provided private key -func (nc *NebulaCertificate) Sign(curve Curve, key []byte) error { - if curve != nc.Details.Curve { - return fmt.Errorf("curve in cert and private key supplied don't match") +func signV1(t *TBSCertificate, curve Curve, key []byte) (*certificateV1, error) { + c := &certificateV1{ + details: detailsV1{ + Name: t.Name, + Ips: t.Networks, + Subnets: t.UnsafeNetworks, + Groups: t.Groups, + NotBefore: t.NotBefore, + NotAfter: t.NotAfter, + PublicKey: t.PublicKey, + IsCA: t.IsCA, + Curve: t.Curve, + Issuer: t.issuer, + }, } - - b, err := proto.Marshal(nc.getRawDetails()) + b, err := proto.Marshal(c.getRawDetails()) if err != nil { - return err + return nil, err } var sig []byte @@ -160,145 +216,96 @@ func (nc *NebulaCertificate) Sign(curve Curve, key []byte) error { hashed := sha256.Sum256(b) sig, err = ecdsa.SignASN1(rand.Reader, signer, hashed[:]) if err != nil { - return err + return nil, err } default: - return fmt.Errorf("invalid curve: %s", nc.Details.Curve) + return nil, fmt.Errorf("invalid curve: %s", c.details.Curve) } - nc.Signature = sig - return nil + c.signature = sig + return c, nil } // CheckSignature verifies the signature against the provided public key -func (nc *NebulaCertificate) CheckSignature(key []byte) bool { +func (nc *certificateV1) CheckSignature(key []byte) bool { b, err := proto.Marshal(nc.getRawDetails()) if err != nil { return false } - switch nc.Details.Curve { + switch nc.details.Curve { case Curve_CURVE25519: - return ed25519.Verify(ed25519.PublicKey(key), b, nc.Signature) + return ed25519.Verify(key, b, nc.signature) case Curve_P256: x, y := elliptic.Unmarshal(elliptic.P256(), key) pubKey := &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y} hashed := sha256.Sum256(b) - return ecdsa.VerifyASN1(pubKey, hashed[:], nc.Signature) + return ecdsa.VerifyASN1(pubKey, hashed[:], nc.signature) default: return false } } -// NOTE: This uses an internal cache that will not be invalidated automatically -// if you manually change any fields in the NebulaCertificate. -func (nc *NebulaCertificate) checkSignatureWithCache(key []byte, useCache bool) bool { - if !useCache { - return nc.CheckSignature(key) - } - - if v := nc.signatureVerified.Load(); v != nil { - return bytes.Equal(*v, key) - } - - verified := nc.CheckSignature(key) - if verified { - keyCopy := make([]byte, len(key)) - copy(keyCopy, key) - nc.signatureVerified.Store(&keyCopy) - } - - return verified -} - // Expired will return true if the nebula cert is too young or too old compared to the provided time, otherwise false -func (nc *NebulaCertificate) Expired(t time.Time) bool { - return nc.Details.NotBefore.After(t) || nc.Details.NotAfter.Before(t) -} - -// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc) -func (nc *NebulaCertificate) Verify(t time.Time, ncp *CAPool) (bool, error) { - return nc.verify(t, ncp, false) -} - -// VerifyWithCache will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc) -// -// NOTE: This uses an internal cache that will not be invalidated automatically -// if you manually change any fields in the NebulaCertificate. -func (nc *NebulaCertificate) VerifyWithCache(t time.Time, ncp *CAPool) (bool, error) { - return nc.verify(t, ncp, true) +func (nc *certificateV1) Expired(t time.Time) bool { + return nc.details.NotBefore.After(t) || nc.details.NotAfter.Before(t) } -// ResetCache resets the cache used by VerifyWithCache. -func (nc *NebulaCertificate) ResetCache() { - nc.sha256sum.Store(nil) - nc.signatureVerified.Store(nil) -} - -// Verify will ensure a certificate is good in all respects (expiry, group membership, signature, cert blocklist, etc) -func (nc *NebulaCertificate) verify(t time.Time, ncp *CAPool, useCache bool) (bool, error) { - if ncp.isBlocklistedWithCache(nc, useCache) { - return false, ErrBlockListed - } - - signer, err := ncp.GetCAForCert(nc) - if err != nil { - return false, err - } - - if signer.Expired(t) { - return false, ErrRootExpired - } - - if nc.Expired(t) { - return false, ErrExpired - } - - if !nc.checkSignatureWithCache(signer.Details.PublicKey, useCache) { - return false, ErrSignatureMismatch - } - - if err := nc.CheckRootConstrains(signer); err != nil { - return false, err - } - - return true, nil -} - -// CheckRootConstrains returns an error if the certificate violates constraints set on the root (groups, ips, subnets) -func (nc *NebulaCertificate) CheckRootConstrains(signer *NebulaCertificate) error { +// CheckRootConstraints returns an error if the certificate violates constraints set on the root (groups, ips, subnets) +// TODO: we could use cachedcert here to make it better, maybe who cares cuz its only once? Or move this entirely to caPool +func (nc *certificateV1) CheckRootConstraints(signer Certificate) error { // Make sure this cert wasn't valid before the root - if signer.Details.NotAfter.Before(nc.Details.NotAfter) { + if signer.NotAfter().Before(nc.details.NotAfter) { return fmt.Errorf("certificate expires after signing certificate") } // Make sure this cert isn't valid after the root - if signer.Details.NotBefore.After(nc.Details.NotBefore) { + if signer.NotBefore().After(nc.details.NotBefore) { return fmt.Errorf("certificate is valid before the signing certificate") } // If the signer has a limited set of groups make sure the cert only contains a subset - if len(signer.Details.InvertedGroups) > 0 { - for _, g := range nc.Details.Groups { - if _, ok := signer.Details.InvertedGroups[g]; !ok { + groups := signer.Groups() + if len(groups) > 0 { + for _, g := range nc.details.Groups { + if !slices.Contains(groups, g) { + //TODO: since we no longer pre-compute the inverted groups then this is kind of slow return fmt.Errorf("certificate contained a group not present on the signing ca: %s", g) } } } // If the signer has a limited set of ip ranges to issue from make sure the cert only contains a subset - if len(signer.Details.Ips) > 0 { - for _, ip := range nc.Details.Ips { - if !netMatch(ip, signer.Details.Ips) { - return fmt.Errorf("certificate contained an ip assignment outside the limitations of the signing ca: %s", ip.String()) + networks := signer.Networks() + if len(networks) > 0 { + for _, cNetwork := range nc.details.Ips { + found := false + for _, caNetwork := range networks { + if caNetwork.Contains(cNetwork.Addr()) && caNetwork.Bits() <= cNetwork.Bits() { + found = true + break + } + } + + if !found { + return fmt.Errorf("certificate contained an ip assignment outside the limitations of the signing ca: %s", cNetwork.String()) } } } // If the signer has a limited set of subnet ranges to issue from make sure the cert only contains a subset - if len(signer.Details.Subnets) > 0 { - for _, subnet := range nc.Details.Subnets { - if !netMatch(subnet, signer.Details.Subnets) { - return fmt.Errorf("certificate contained a subnet assignment outside the limitations of the signing ca: %s", subnet) + unsafeNetworks := signer.UnsafeNetworks() + if len(unsafeNetworks) > 0 { + for _, cNetwork := range nc.details.Subnets { + found := false + for _, caNetwork := range unsafeNetworks { + if caNetwork.Contains(cNetwork.Addr()) && caNetwork.Bits() <= cNetwork.Bits() { + found = true + break + } + } + + if !found { + return fmt.Errorf("certificate contained a subnet assignment outside the limitations of the signing ca: %s", cNetwork.String()) } } } @@ -307,11 +314,11 @@ func (nc *NebulaCertificate) CheckRootConstrains(signer *NebulaCertificate) erro } // VerifyPrivateKey checks that the public key in the Nebula certificate and a supplied private key match -func (nc *NebulaCertificate) VerifyPrivateKey(curve Curve, key []byte) error { - if curve != nc.Details.Curve { +func (nc *certificateV1) VerifyPrivateKey(curve Curve, key []byte) error { + if curve != nc.details.Curve { return fmt.Errorf("curve in cert and private key supplied don't match") } - if nc.Details.IsCA { + if nc.details.IsCA { switch curve { case Curve_CURVE25519: // the call to PublicKey below will panic slice bounds out of range otherwise @@ -319,7 +326,7 @@ func (nc *NebulaCertificate) VerifyPrivateKey(curve Curve, key []byte) error { return fmt.Errorf("key was not 64 bytes, is invalid ed25519 private key") } - if !ed25519.PublicKey(nc.Details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) { + if !ed25519.PublicKey(nc.details.PublicKey).Equal(ed25519.PrivateKey(key).Public()) { return fmt.Errorf("public key in cert and private key supplied don't match") } case Curve_P256: @@ -328,7 +335,7 @@ func (nc *NebulaCertificate) VerifyPrivateKey(curve Curve, key []byte) error { return fmt.Errorf("cannot parse private key as P256") } pub := privkey.PublicKey().Bytes() - if !bytes.Equal(pub, nc.Details.PublicKey) { + if !bytes.Equal(pub, nc.details.PublicKey) { return fmt.Errorf("public key in cert and private key supplied don't match") } default: @@ -354,7 +361,7 @@ func (nc *NebulaCertificate) VerifyPrivateKey(curve Curve, key []byte) error { default: return fmt.Errorf("invalid curve: %s", curve) } - if !bytes.Equal(pub, nc.Details.PublicKey) { + if !bytes.Equal(pub, nc.details.PublicKey) { return fmt.Errorf("public key in cert and private key supplied don't match") } @@ -362,18 +369,18 @@ func (nc *NebulaCertificate) VerifyPrivateKey(curve Curve, key []byte) error { } // String will return a pretty printed representation of a nebula cert -func (nc *NebulaCertificate) String() string { +func (nc *certificateV1) String() string { if nc == nil { - return "NebulaCertificate {}\n" + return "Certificate {}\n" } s := "NebulaCertificate {\n" s += "\tDetails {\n" - s += fmt.Sprintf("\t\tName: %v\n", nc.Details.Name) + s += fmt.Sprintf("\t\tName: %v\n", nc.details.Name) - if len(nc.Details.Ips) > 0 { + if len(nc.details.Ips) > 0 { s += "\t\tIps: [\n" - for _, ip := range nc.Details.Ips { + for _, ip := range nc.details.Ips { s += fmt.Sprintf("\t\t\t%v\n", ip.String()) } s += "\t\t]\n" @@ -381,9 +388,9 @@ func (nc *NebulaCertificate) String() string { s += "\t\tIps: []\n" } - if len(nc.Details.Subnets) > 0 { + if len(nc.details.Subnets) > 0 { s += "\t\tSubnets: [\n" - for _, ip := range nc.Details.Subnets { + for _, ip := range nc.details.Subnets { s += fmt.Sprintf("\t\t\t%v\n", ip.String()) } s += "\t\t]\n" @@ -391,9 +398,9 @@ func (nc *NebulaCertificate) String() string { s += "\t\tSubnets: []\n" } - if len(nc.Details.Groups) > 0 { + if len(nc.details.Groups) > 0 { s += "\t\tGroups: [\n" - for _, g := range nc.Details.Groups { + for _, g := range nc.details.Groups { s += fmt.Sprintf("\t\t\t\"%v\"\n", g) } s += "\t\t]\n" @@ -401,63 +408,65 @@ func (nc *NebulaCertificate) String() string { s += "\t\tGroups: []\n" } - s += fmt.Sprintf("\t\tNot before: %v\n", nc.Details.NotBefore) - s += fmt.Sprintf("\t\tNot After: %v\n", nc.Details.NotAfter) - s += fmt.Sprintf("\t\tIs CA: %v\n", nc.Details.IsCA) - s += fmt.Sprintf("\t\tIssuer: %s\n", nc.Details.Issuer) - s += fmt.Sprintf("\t\tPublic key: %x\n", nc.Details.PublicKey) - s += fmt.Sprintf("\t\tCurve: %s\n", nc.Details.Curve) + s += fmt.Sprintf("\t\tNot before: %v\n", nc.details.NotBefore) + s += fmt.Sprintf("\t\tNot After: %v\n", nc.details.NotAfter) + s += fmt.Sprintf("\t\tIs CA: %v\n", nc.details.IsCA) + s += fmt.Sprintf("\t\tIssuer: %s\n", nc.details.Issuer) + s += fmt.Sprintf("\t\tPublic key: %x\n", nc.details.PublicKey) + s += fmt.Sprintf("\t\tCurve: %s\n", nc.details.Curve) s += "\t}\n" fp, err := nc.Sha256Sum() if err == nil { s += fmt.Sprintf("\tFingerprint: %s\n", fp) } - s += fmt.Sprintf("\tSignature: %x\n", nc.Signature) + s += fmt.Sprintf("\tSignature: %x\n", nc.Signature()) s += "}" return s } // getRawDetails marshals the raw details into protobuf ready struct -func (nc *NebulaCertificate) getRawDetails() *RawNebulaCertificateDetails { +func (nc *certificateV1) getRawDetails() *RawNebulaCertificateDetails { rd := &RawNebulaCertificateDetails{ - Name: nc.Details.Name, - Groups: nc.Details.Groups, - NotBefore: nc.Details.NotBefore.Unix(), - NotAfter: nc.Details.NotAfter.Unix(), - PublicKey: make([]byte, len(nc.Details.PublicKey)), - IsCA: nc.Details.IsCA, - Curve: nc.Details.Curve, + Name: nc.details.Name, + Groups: nc.details.Groups, + NotBefore: nc.details.NotBefore.Unix(), + NotAfter: nc.details.NotAfter.Unix(), + PublicKey: make([]byte, len(nc.details.PublicKey)), + IsCA: nc.details.IsCA, + Curve: nc.details.Curve, } - for _, ipNet := range nc.Details.Ips { - rd.Ips = append(rd.Ips, ip2int(ipNet.IP), ip2int(ipNet.Mask)) + for _, ipNet := range nc.details.Ips { + mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) + rd.Ips = append(rd.Ips, addr2int(ipNet.Addr()), ip2int(mask)) } - for _, ipNet := range nc.Details.Subnets { - rd.Subnets = append(rd.Subnets, ip2int(ipNet.IP), ip2int(ipNet.Mask)) + for _, ipNet := range nc.details.Subnets { + mask := net.CIDRMask(ipNet.Bits(), ipNet.Addr().BitLen()) + rd.Subnets = append(rd.Subnets, addr2int(ipNet.Addr()), ip2int(mask)) } - copy(rd.PublicKey, nc.Details.PublicKey[:]) + copy(rd.PublicKey, nc.details.PublicKey[:]) // I know, this is terrible - rd.Issuer, _ = hex.DecodeString(nc.Details.Issuer) + rd.Issuer, _ = hex.DecodeString(nc.details.Issuer) return rd } // Marshal will marshal a nebula cert into a protobuf byte array -func (nc *NebulaCertificate) Marshal() ([]byte, error) { +func (nc *certificateV1) Marshal() ([]byte, error) { rc := RawNebulaCertificate{ Details: nc.getRawDetails(), - Signature: nc.Signature, + Signature: nc.signature, } return proto.Marshal(&rc) } // MarshalToPEM will marshal a nebula cert into a protobuf byte array and pem encode the result -func (nc *NebulaCertificate) MarshalToPEM() ([]byte, error) { +func (nc *certificateV1) MarshalToPEM() ([]byte, error) { b, err := nc.Marshal() if err != nil { return nil, err @@ -466,7 +475,7 @@ func (nc *NebulaCertificate) MarshalToPEM() ([]byte, error) { } // Sha256Sum calculates a sha-256 sum of the marshaled certificate -func (nc *NebulaCertificate) Sha256Sum() (string, error) { +func (nc *certificateV1) Sha256Sum() (string, error) { b, err := nc.Marshal() if err != nil { return "", err @@ -476,160 +485,86 @@ func (nc *NebulaCertificate) Sha256Sum() (string, error) { return hex.EncodeToString(sum[:]), nil } -// NOTE: This uses an internal cache that will not be invalidated automatically -// if you manually change any fields in the NebulaCertificate. -func (nc *NebulaCertificate) sha256SumWithCache(useCache bool) (string, error) { - if !useCache { - return nc.Sha256Sum() - } - - if s := nc.sha256sum.Load(); s != nil { - return *s, nil - } - s, err := nc.Sha256Sum() - if err != nil { - return s, err - } - - nc.sha256sum.Store(&s) - return s, nil -} - -func (nc *NebulaCertificate) MarshalJSON() ([]byte, error) { - toString := func(ips []*net.IPNet) []string { - s := []string{} - for _, ip := range ips { - s = append(s, ip.String()) - } - return s - } - +func (nc *certificateV1) MarshalJSON() ([]byte, error) { fp, _ := nc.Sha256Sum() jc := m{ "details": m{ - "name": nc.Details.Name, - "ips": toString(nc.Details.Ips), - "subnets": toString(nc.Details.Subnets), - "groups": nc.Details.Groups, - "notBefore": nc.Details.NotBefore, - "notAfter": nc.Details.NotAfter, - "publicKey": fmt.Sprintf("%x", nc.Details.PublicKey), - "isCa": nc.Details.IsCA, - "issuer": nc.Details.Issuer, - "curve": nc.Details.Curve.String(), + "name": nc.details.Name, + "ips": nc.details.Ips, + "subnets": nc.details.Subnets, + "groups": nc.details.Groups, + "notBefore": nc.details.NotBefore, + "notAfter": nc.details.NotAfter, + "publicKey": fmt.Sprintf("%x", nc.details.PublicKey), + "isCa": nc.details.IsCA, + "issuer": nc.details.Issuer, + "curve": nc.details.Curve.String(), }, "fingerprint": fp, - "signature": fmt.Sprintf("%x", nc.Signature), + "signature": fmt.Sprintf("%x", nc.Signature()), } return json.Marshal(jc) } -//func (nc *NebulaCertificate) Copy() *NebulaCertificate { -// r, err := nc.Marshal() -// if err != nil { -// //TODO -// return nil +// TODO: +func (nc *certificateV1) Copy() Certificate { + // r, err := nc.Marshal() + // if err != nil { + // //TODO + // return nil + // } + // + // c, err := UnmarshalNebulaCertificate(r) + // return c + return nc +} + +//func (nc *certificateV1) Copy() *certificateV1 { +// c := &certificateV1{ +// Details: detailsV1{ +// Name: nc.Details.Name, +// Groups: make([]string, len(nc.Details.Groups)), +// Ips: make([]*net.IPNet, len(nc.Details.Ips)), +// Subnets: make([]*net.IPNet, len(nc.Details.Subnets)), +// NotBefore: nc.Details.NotBefore, +// NotAfter: nc.Details.NotAfter, +// PublicKey: make([]byte, len(nc.Details.PublicKey)), +// IsCA: nc.Details.IsCA, +// Issuer: nc.Details.Issuer, +// InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)), +// }, +// Signature: make([]byte, len(nc.Signature)), +// } +// +// copy(c.Signature, nc.Signature) +// copy(c.Details.Groups, nc.Details.Groups) +// copy(c.Details.PublicKey, nc.Details.PublicKey) +// +// for i, p := range nc.Details.Ips { +// c.Details.Ips[i] = &net.IPNet{ +// IP: make(net.IP, len(p.IP)), +// Mask: make(net.IPMask, len(p.Mask)), +// } +// copy(c.Details.Ips[i].IP, p.IP) +// copy(c.Details.Ips[i].Mask, p.Mask) +// } +// +// for i, p := range nc.Details.Subnets { +// c.Details.Subnets[i] = &net.IPNet{ +// IP: make(net.IP, len(p.IP)), +// Mask: make(net.IPMask, len(p.Mask)), +// } +// copy(c.Details.Subnets[i].IP, p.IP) +// copy(c.Details.Subnets[i].Mask, p.Mask) +// } +// +// for g := range nc.Details.InvertedGroups { +// c.Details.InvertedGroups[g] = struct{}{} // } // -// c, err := UnmarshalNebulaCertificate(r) // return c //} -func (nc *NebulaCertificate) Copy() *NebulaCertificate { - c := &NebulaCertificate{ - Details: NebulaCertificateDetails{ - Name: nc.Details.Name, - Groups: make([]string, len(nc.Details.Groups)), - Ips: make([]*net.IPNet, len(nc.Details.Ips)), - Subnets: make([]*net.IPNet, len(nc.Details.Subnets)), - NotBefore: nc.Details.NotBefore, - NotAfter: nc.Details.NotAfter, - PublicKey: make([]byte, len(nc.Details.PublicKey)), - IsCA: nc.Details.IsCA, - Issuer: nc.Details.Issuer, - InvertedGroups: make(map[string]struct{}, len(nc.Details.InvertedGroups)), - }, - Signature: make([]byte, len(nc.Signature)), - } - - copy(c.Signature, nc.Signature) - copy(c.Details.Groups, nc.Details.Groups) - copy(c.Details.PublicKey, nc.Details.PublicKey) - - for i, p := range nc.Details.Ips { - c.Details.Ips[i] = &net.IPNet{ - IP: make(net.IP, len(p.IP)), - Mask: make(net.IPMask, len(p.Mask)), - } - copy(c.Details.Ips[i].IP, p.IP) - copy(c.Details.Ips[i].Mask, p.Mask) - } - - for i, p := range nc.Details.Subnets { - c.Details.Subnets[i] = &net.IPNet{ - IP: make(net.IP, len(p.IP)), - Mask: make(net.IPMask, len(p.Mask)), - } - copy(c.Details.Subnets[i].IP, p.IP) - copy(c.Details.Subnets[i].Mask, p.Mask) - } - - for g := range nc.Details.InvertedGroups { - c.Details.InvertedGroups[g] = struct{}{} - } - - return c -} - -func netMatch(certIp *net.IPNet, rootIps []*net.IPNet) bool { - for _, net := range rootIps { - if net.Contains(certIp.IP) && maskContains(net.Mask, certIp.Mask) { - return true - } - } - - return false -} - -func maskContains(caMask, certMask net.IPMask) bool { - caM := maskTo4(caMask) - cM := maskTo4(certMask) - // Make sure forcing to ipv4 didn't nuke us - if caM == nil || cM == nil { - return false - } - - // Make sure the cert mask is not greater than the ca mask - for i := 0; i < len(caMask); i++ { - if caM[i] > cM[i] { - return false - } - } - - return true -} - -func maskTo4(ip net.IPMask) net.IPMask { - if len(ip) == net.IPv4len { - return ip - } - - if len(ip) == net.IPv6len && isZeros(ip[0:10]) && ip[10] == 0xff && ip[11] == 0xff { - return ip[12:16] - } - - return nil -} - -func isZeros(b []byte) bool { - for i := 0; i < len(b); i++ { - if b[i] != 0 { - return false - } - } - return true -} - func ip2int(ip []byte) uint32 { if len(ip) == 16 { return binary.BigEndian.Uint32(ip[12:16]) @@ -642,3 +577,14 @@ func int2ip(nn uint32) net.IP { binary.BigEndian.PutUint32(ip, nn) return ip } + +func addr2int(addr netip.Addr) uint32 { + b := addr.Unmap().As4() + return binary.BigEndian.Uint32(b[:]) +} + +func int2addr(nn uint32) netip.Addr { + ip := [4]byte{} + binary.BigEndian.PutUint32(ip[:], nn) + return netip.AddrFrom4(ip).Unmap() +} diff --git a/cert/pem.go b/cert/pem.go index e442ef7d1..744ae2edf 100644 --- a/cert/pem.go +++ b/cert/pem.go @@ -9,6 +9,7 @@ import ( const ( CertificateBanner = "NEBULA CERTIFICATE" + CertificateV2Banner = "NEBULA CERTIFICATE V2" X25519PrivateKeyBanner = "NEBULA X25519 PRIVATE KEY" X25519PublicKeyBanner = "NEBULA X25519 PUBLIC KEY" EncryptedEd25519PrivateKeyBanner = "NEBULA ED25519 ENCRYPTED PRIVATE KEY" @@ -23,16 +24,25 @@ const ( // UnmarshalCertificateFromPEM will try to unmarshal the first pem block in a byte array, returning any non consumed // data or an error on failure -func UnmarshalCertificateFromPEM(b []byte) (*NebulaCertificate, []byte, error) { +func UnmarshalCertificateFromPEM(b []byte) (Certificate, []byte, error) { p, r := pem.Decode(b) if p == nil { return nil, r, ErrInvalidPEMBlock } - if p.Type != CertificateBanner { + + switch p.Type { + case CertificateBanner: + c, err := unmarshalCertificateV1(p.Bytes, true) + if err != nil { + return nil, nil, err + } + return c, r, nil + case CertificateV2Banner: + //TODO + panic("TODO") + default: return nil, r, ErrInvalidPEMCertificateBanner } - nc, err := UnmarshalNebulaCertificate(p.Bytes) - return nc, r, err } func MarshalPublicKeyToPEM(curve Curve, b []byte) []byte { diff --git a/cert/sign.go b/cert/sign.go new file mode 100644 index 000000000..430c944bf --- /dev/null +++ b/cert/sign.go @@ -0,0 +1,51 @@ +package cert + +import ( + "fmt" + "net/netip" + "time" +) + +type TBSCertificate struct { + Version Version + Name string + Networks []netip.Prefix + UnsafeNetworks []netip.Prefix + Groups []string + IsCA bool + NotBefore time.Time + NotAfter time.Time + PublicKey []byte + Curve Curve + issuer string +} + +// TODO: +func (t *TBSCertificate) Sign(signer Certificate, curve Curve, key []byte) (Certificate, error) { + if curve != t.Curve { + return nil, fmt.Errorf("curve in cert and private key supplied don't match") + } + + //TODO: signer should assert its constraints on the TBSCertificate, once you do nebula-cert sign needs to not double do it + if signer != nil { + if t.IsCA { + return nil, fmt.Errorf("can not sign a CA certificate with another") + } + issuer, err := signer.Sha256Sum() + if err != nil { + return nil, fmt.Errorf("error computing issuer: %v", err) + } + t.issuer = issuer + } else { + if !t.IsCA { + return nil, fmt.Errorf("self signed certificates must have IsCA set to true") + } + } + + switch t.Version { + case Version1: + return signV1(t, curve, key) + default: + return nil, fmt.Errorf("unknown cert version %d", t.Version) + } +} diff --git a/cmd/nebula-cert/ca.go b/cmd/nebula-cert/ca.go index 96f4eafcb..3f3785ab6 100644 --- a/cmd/nebula-cert/ca.go +++ b/cmd/nebula-cert/ca.go @@ -8,7 +8,7 @@ import ( "fmt" "io" "math" - "net" + "net/netip" "os" "strings" "time" @@ -106,38 +106,36 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error } } - var ips []*net.IPNet + var ips []netip.Prefix if *cf.ips != "" { for _, rs := range strings.Split(*cf.ips, ",") { rs := strings.Trim(rs, " ") if rs != "" { - ip, ipNet, err := net.ParseCIDR(rs) + n, err := netip.ParsePrefix(rs) if err != nil { return newHelpErrorf("invalid ip definition: %s", err) } - if ip.To4() == nil { + if !n.Addr().Is4() { return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", rs) } - - ipNet.IP = ip - ips = append(ips, ipNet) + ips = append(ips, n) } } } - var subnets []*net.IPNet + var subnets []netip.Prefix if *cf.subnets != "" { for _, rs := range strings.Split(*cf.subnets, ",") { rs := strings.Trim(rs, " ") if rs != "" { - _, s, err := net.ParseCIDR(rs) + n, err := netip.ParsePrefix(rs) if err != nil { return newHelpErrorf("invalid subnet definition: %s", err) } - if s.IP.To4() == nil { + if !n.Addr().Is4() { return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs) } - subnets = append(subnets, s) + subnets = append(subnets, n) } } } @@ -191,18 +189,17 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error pub = eKey.PublicKey().Bytes() } - nc := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: *cf.name, - Groups: groups, - Ips: ips, - Subnets: subnets, - NotBefore: time.Now(), - NotAfter: time.Now().Add(*cf.duration), - PublicKey: pub, - IsCA: true, - Curve: curve, - }, + t := &cert.TBSCertificate{ + Version: cert.Version1, + Name: *cf.name, + Groups: groups, + Networks: ips, + UnsafeNetworks: subnets, + NotBefore: time.Now(), + NotAfter: time.Now().Add(*cf.duration), + PublicKey: pub, + IsCA: true, + Curve: curve, } if _, err := os.Stat(*cf.outKeyPath); err == nil { @@ -213,7 +210,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error return fmt.Errorf("refusing to overwrite existing CA cert: %s", *cf.outCertPath) } - err = nc.Sign(curve, rawPriv) + c, err := t.Sign(nil, curve, rawPriv) if err != nil { return fmt.Errorf("error while signing: %s", err) } @@ -233,7 +230,7 @@ func ca(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) error return fmt.Errorf("error while writing out-key: %s", err) } - b, err = nc.MarshalToPEM() + b, err = c.MarshalToPEM() if err != nil { return fmt.Errorf("error while marshalling certificate: %s", err) } diff --git a/cmd/nebula-cert/print.go b/cmd/nebula-cert/print.go index beb5718c9..77edc01a8 100644 --- a/cmd/nebula-cert/print.go +++ b/cmd/nebula-cert/print.go @@ -45,7 +45,7 @@ func printCert(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("unable to read cert; %s", err) } - var c *cert.NebulaCertificate + var c cert.Certificate var qrBytes []byte part := 0 diff --git a/cmd/nebula-cert/sign.go b/cmd/nebula-cert/sign.go index c15707a15..2c2047a3c 100644 --- a/cmd/nebula-cert/sign.go +++ b/cmd/nebula-cert/sign.go @@ -3,10 +3,11 @@ package main import ( "crypto/ecdh" "crypto/rand" + "errors" "flag" "fmt" "io" - "net" + "net/netip" "os" "strings" "time" @@ -82,14 +83,14 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) // naively attempt to decode the private key as though it is not encrypted caKey, _, curve, err = cert.UnmarshalSigningPrivateKeyFromPEM(rawCAKey) - if err == cert.ErrPrivateKeyEncrypted { + if errors.Is(err, cert.ErrPrivateKeyEncrypted) { // ask for a passphrase until we get one var passphrase []byte for i := 0; i < 5; i++ { out.Write([]byte("Enter passphrase: ")) passphrase, err = pr.ReadPassword() - if err == ErrNoTerminal { + if errors.Is(err, ErrNoTerminal) { return fmt.Errorf("ca-key is encrypted and must be decrypted interactively") } else if err != nil { return fmt.Errorf("error reading password: %s", err) @@ -125,30 +126,24 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("refusing to sign, root certificate does not match private key") } - issuer, err := caCert.Sha256Sum() - if err != nil { - return fmt.Errorf("error while getting -ca-crt fingerprint: %s", err) - } - if caCert.Expired(time.Now()) { return fmt.Errorf("ca certificate is expired") } // if no duration is given, expire one second before the root expires if *sf.duration <= 0 { - *sf.duration = time.Until(caCert.Details.NotAfter) - time.Second*1 + *sf.duration = time.Until(caCert.NotAfter()) - time.Second*1 } - ip, ipNet, err := net.ParseCIDR(*sf.ip) + network, err := netip.ParsePrefix(*sf.ip) if err != nil { return newHelpErrorf("invalid ip definition: %s", err) } - if ip.To4() == nil { + if !network.Addr().Is4() { return newHelpErrorf("invalid ip definition: can only be ipv4, have %s", *sf.ip) } - ipNet.IP = ip - groups := []string{} + var groups []string if *sf.groups != "" { for _, rg := range strings.Split(*sf.groups, ",") { g := strings.TrimSpace(rg) @@ -158,16 +153,16 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } } - subnets := []*net.IPNet{} + var subnets []netip.Prefix if *sf.subnets != "" { for _, rs := range strings.Split(*sf.subnets, ",") { rs := strings.Trim(rs, " ") if rs != "" { - _, s, err := net.ParseCIDR(rs) + s, err := netip.ParsePrefix(rs) if err != nil { return newHelpErrorf("invalid subnet definition: %s", err) } - if s.IP.To4() == nil { + if !s.Addr().Is4() { return newHelpErrorf("invalid subnet definition: can only be ipv4, have %s", rs) } subnets = append(subnets, s) @@ -193,24 +188,23 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) pub, rawPriv = newKeypair(curve) } - nc := cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: *sf.name, - Ips: []*net.IPNet{ipNet}, - Groups: groups, - Subnets: subnets, - NotBefore: time.Now(), - NotAfter: time.Now().Add(*sf.duration), - PublicKey: pub, - IsCA: false, - Issuer: issuer, - Curve: curve, - }, + t := &cert.TBSCertificate{ + Version: cert.Version1, + Name: *sf.name, + Networks: []netip.Prefix{network}, + Groups: groups, + UnsafeNetworks: subnets, + NotBefore: time.Now(), + NotAfter: time.Now().Add(*sf.duration), + PublicKey: pub, + IsCA: false, + Curve: curve, } - if err := nc.CheckRootConstrains(caCert); err != nil { - return fmt.Errorf("refusing to sign, root certificate constraints violated: %s", err) - } + //TODO: + //if err := nc.CheckRootConstrains(caCert); err != nil { + // return fmt.Errorf("refusing to sign, root certificate constraints violated: %s", err) + //} if *sf.outKeyPath == "" { *sf.outKeyPath = *sf.name + ".key" @@ -224,7 +218,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) return fmt.Errorf("refusing to overwrite existing cert: %s", *sf.outCertPath) } - err = nc.Sign(curve, caKey) + c, err := t.Sign(caCert, curve, caKey) if err != nil { return fmt.Errorf("error while signing: %s", err) } @@ -240,7 +234,7 @@ func signCert(args []string, out io.Writer, errOut io.Writer, pr PasswordReader) } } - b, err := nc.MarshalToPEM() + b, err := c.MarshalToPEM() if err != nil { return fmt.Errorf("error while marshalling certificate: %s", err) } diff --git a/cmd/nebula-cert/verify.go b/cmd/nebula-cert/verify.go index b57a6b1dc..80cfef3c0 100644 --- a/cmd/nebula-cert/verify.go +++ b/cmd/nebula-cert/verify.go @@ -66,8 +66,8 @@ func verify(args []string, out io.Writer, errOut io.Writer) error { return fmt.Errorf("error while parsing crt: %s", err) } - good, err := c.Verify(time.Now(), caPool) - if !good { + _, err = caPool.VerifyCertificate(time.Now(), c) + if err != nil { return err } diff --git a/connection_manager.go b/connection_manager.go index d2e861647..a0de842db 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -415,7 +415,7 @@ func (n *connectionManager) shouldSwapPrimary(current, primary *HostInfo) bool { } certState := n.intf.pki.GetCertState() - return bytes.Equal(current.ConnectionState.myCert.Signature, certState.Certificate.Signature) + return bytes.Equal(current.ConnectionState.myCert.Signature(), certState.Certificate.Signature()) } func (n *connectionManager) swapPrimary(current, primary *HostInfo) { @@ -436,8 +436,9 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn return false } - valid, err := remoteCert.VerifyWithCache(now, n.intf.pki.GetCAPool()) - if valid { + caPool := n.intf.pki.GetCAPool() + err := caPool.VerifyCachedCertificate(now, remoteCert) + if err == nil { return false } @@ -446,9 +447,8 @@ func (n *connectionManager) isInvalidCertificate(now time.Time, hostinfo *HostIn return false } - fingerprint, _ := remoteCert.Sha256Sum() hostinfo.logger(n.l).WithError(err). - WithField("fingerprint", fingerprint). + WithField("fingerprint", remoteCert.ShaSum). Info("Remote certificate is no longer valid, tearing down the tunnel") return true @@ -474,7 +474,7 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { func (n *connectionManager) tryRehandshake(hostinfo *HostInfo) { certState := n.intf.pki.GetCertState() - if bytes.Equal(hostinfo.ConnectionState.myCert.Signature, certState.Certificate.Signature) { + if bytes.Equal(hostinfo.ConnectionState.myCert.Signature(), certState.Certificate.Signature()) { return } diff --git a/connection_manager_test.go b/connection_manager_test.go index 3014df2b6..d377722b4 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -47,7 +47,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { cs := &CertState{ RawCertificate: []byte{}, PrivateKey: []byte{}, - Certificate: &cert.NebulaCertificate{}, + Certificate: &dummyCert{}, RawCertificateNoKey: []byte{}, } @@ -80,7 +80,7 @@ func Test_NewConnectionManagerTest(t *testing.T) { remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - myCert: &cert.NebulaCertificate{}, + myCert: &dummyCert{}, H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -130,7 +130,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { cs := &CertState{ RawCertificate: []byte{}, PrivateKey: []byte{}, - Certificate: &cert.NebulaCertificate{}, + Certificate: &dummyCert{}, RawCertificateNoKey: []byte{}, } @@ -163,7 +163,7 @@ func Test_NewConnectionManagerTest2(t *testing.T) { remoteIndexId: 9901, } hostinfo.ConnectionState = &ConnectionState{ - myCert: &cert.NebulaCertificate{}, + myCert: &dummyCert{}, H: &noise.HandshakeState{}, } nc.hostMap.unlockedAddHostInfo(hostinfo, ifce) @@ -253,7 +253,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { cs := &CertState{ RawCertificate: []byte{}, PrivateKey: []byte{}, - Certificate: &cert.NebulaCertificate{}, + Certificate: &dummyCert{}, RawCertificateNoKey: []byte{}, } @@ -282,7 +282,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { hostinfo := &HostInfo{ vpnIp: vpnIp, ConnectionState: &ConnectionState{ - myCert: &cert.NebulaCertificate{}, + myCert: &dummyCert{}, peerCert: &peerCert, H: &noise.HandshakeState{}, }, @@ -303,3 +303,103 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { invalid = nc.isInvalidCertificate(nextTick, hostinfo) assert.True(t, invalid) } + +type dummyCert struct { + isCa bool +} + +func (d *dummyCert) Version() cert.Version { + return cert.Version1 +} + +func (d *dummyCert) Curve() cert.Curve { + return cert.Curve_CURVE25519 +} + +func (d *dummyCert) Groups() []string { + return nil +} + +func (d *dummyCert) IsCA() bool { + return d.isCa +} + +func (d *dummyCert) Issuer() string { + return "" +} + +func (d *dummyCert) Name() string { + return "" +} + +func (d *dummyCert) Networks() []netip.Prefix { + return nil +} + +func (d *dummyCert) NotAfter() time.Time { + return time.Now().Add(time.Hour * -1) +} + +func (d *dummyCert) NotBefore() time.Time { + return time.Now().Add(time.Hour) +} + +func (d *dummyCert) PublicKey() []byte { + return nil +} + +func (d *dummyCert) Signature() []byte { + return nil +} + +func (d *dummyCert) UnsafeNetworks() []netip.Prefix { + return nil +} + +func (d *dummyCert) MarshalForHandshakes() ([]byte, error) { + return nil, nil +} + +func (d *dummyCert) Sign(curve cert.Curve, key []byte) error { + return nil +} + +func (d *dummyCert) CheckSignature(key []byte) bool { + return true +} + +func (d *dummyCert) Expired(t time.Time) bool { + return false +} + +func (d *dummyCert) CheckRootConstraints(signer cert.Certificate) error { + return nil +} + +func (d *dummyCert) VerifyPrivateKey(curve cert.Curve, key []byte) error { + return nil +} + +func (d *dummyCert) String() string { + return "" +} + +func (d *dummyCert) Marshal() ([]byte, error) { + return nil, nil +} + +func (d *dummyCert) MarshalToPEM() ([]byte, error) { + return nil, nil +} + +func (d *dummyCert) Sha256Sum() (string, error) { + return "", nil +} + +func (d *dummyCert) MarshalJSON() ([]byte, error) { + return nil, nil +} + +func (d *dummyCert) Copy() cert.Certificate { + return d +} diff --git a/connection_state.go b/connection_state.go index 1dd3c8cf5..e3963a19d 100644 --- a/connection_state.go +++ b/connection_state.go @@ -18,8 +18,8 @@ type ConnectionState struct { eKey *NebulaCipherState dKey *NebulaCipherState H *noise.HandshakeState - myCert *cert.NebulaCertificate - peerCert *cert.NebulaCertificate + myCert cert.Certificate + peerCert *cert.CachedCertificate initiator bool messageCounter atomic.Uint64 window *Bits @@ -28,13 +28,13 @@ type ConnectionState struct { func NewConnectionState(l *logrus.Logger, cipher string, certState *CertState, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState { var dhFunc noise.DHFunc - switch certState.Certificate.Details.Curve { + switch certState.Certificate.Curve() { case cert.Curve_CURVE25519: dhFunc = noise.DH25519 case cert.Curve_P256: dhFunc = noiseutil.DHP256 default: - l.Errorf("invalid curve: %s", certState.Certificate.Details.Curve) + l.Errorf("invalid curve: %s", certState.Certificate.Curve()) return nil } diff --git a/control.go b/control.go index 3468b3536..839c46f99 100644 --- a/control.go +++ b/control.go @@ -37,15 +37,15 @@ type Control struct { } type ControlHostInfo struct { - VpnIp netip.Addr `json:"vpnIp"` - LocalIndex uint32 `json:"localIndex"` - RemoteIndex uint32 `json:"remoteIndex"` - RemoteAddrs []netip.AddrPort `json:"remoteAddrs"` - Cert *cert.NebulaCertificate `json:"cert"` - MessageCounter uint64 `json:"messageCounter"` - CurrentRemote netip.AddrPort `json:"currentRemote"` - CurrentRelaysToMe []netip.Addr `json:"currentRelaysToMe"` - CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` + VpnIp netip.Addr `json:"vpnIp"` + LocalIndex uint32 `json:"localIndex"` + RemoteIndex uint32 `json:"remoteIndex"` + RemoteAddrs []netip.AddrPort `json:"remoteAddrs"` + Cert cert.Certificate `json:"cert"` + MessageCounter uint64 `json:"messageCounter"` + CurrentRemote netip.AddrPort `json:"currentRemote"` + CurrentRelaysToMe []netip.Addr `json:"currentRelaysToMe"` + CurrentRelaysThroughMe []netip.Addr `json:"currentRelaysThroughMe"` } // Start actually runs nebula, this is a nonblocking call. To block use Control.ShutdownBlock() @@ -130,7 +130,8 @@ func (c *Control) ListHostmapIndexes(pendingMap bool) []ControlHostInfo { } // GetCertByVpnIp returns the authenticated certificate of the given vpn IP, or nil if not found -func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) *cert.NebulaCertificate { +// TODO: this should copy! +func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) cert.Certificate { if c.f.myVpnNet.Addr() == vpnIp { return c.f.pki.GetCertState().Certificate } @@ -138,7 +139,7 @@ func (c *Control) GetCertByVpnIp(vpnIp netip.Addr) *cert.NebulaCertificate { if hi == nil { return nil } - return hi.GetCert() + return hi.GetCert().Certificate } // CreateTunnel creates a new tunnel to the given vpn ip. @@ -290,7 +291,7 @@ func copyHostInfo(h *HostInfo, preferredRanges []netip.Prefix) ControlHostInfo { } if c := h.GetCert(); c != nil { - chi.Cert = c.Copy() + chi.Cert = c.Certificate.Copy() } return chi diff --git a/control_tester.go b/control_tester.go index d46540f04..fa87e5300 100644 --- a/control_tester.go +++ b/control_tester.go @@ -153,7 +153,7 @@ func (c *Control) GetHostmap() *HostMap { return c.f.hostMap } -func (c *Control) GetCert() *cert.NebulaCertificate { +func (c *Control) GetCert() cert.Certificate { return c.f.pki.GetCertState().Certificate } diff --git a/dns_server.go b/dns_server.go index 5fea65c47..750123122 100644 --- a/dns_server.go +++ b/dns_server.go @@ -57,9 +57,11 @@ func (d *dnsRecords) QueryCert(data string) string { return "" } - cert := q.Details - c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer) - return c + b, err := q.Certificate.MarshalJSON() + if err != nil { + return "" + } + return string(b) } func (d *dnsRecords) Add(host, data string) { diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 3d42a560c..7c1ac4c4d 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -15,6 +15,7 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" + "golang.org/x/exp/slices" "gopkg.in/yaml.v2" ) @@ -538,7 +539,7 @@ func TestRehandshakingRelays(t *testing.T) { // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"}) caB, err := ca.MarshalToPEM() if err != nil { @@ -558,7 +559,7 @@ func TestRehandshakingRelays(t *testing.T) { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) - if len(c.Cert.Details.Groups) != 0 { + if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") break @@ -571,7 +572,7 @@ func TestRehandshakingRelays(t *testing.T) { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) - if len(c.Cert.Details.Groups) != 0 { + if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") break @@ -642,7 +643,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { // When I update the certificate for the relay, both me and them will have 2 host infos for the relay, // and the main host infos will not have any relay state to handle the me<->relay<->them tunnel. r.Log("Renew relay certificate and spin until me and them sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), relayVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "relay", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{relayVpnIpNet}, nil, []string{"new group"}) caB, err := ca.MarshalToPEM() if err != nil { @@ -662,7 +663,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { r.Log("Assert the tunnel works between myVpnIpNet and relayVpnIpNet") assertTunnel(t, myVpnIpNet.Addr(), relayVpnIpNet.Addr(), myControl, relayControl, r) c := myControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) - if len(c.Cert.Details.Groups) != 0 { + if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between my and relay is updated!") break @@ -675,7 +676,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { r.Log("Assert the tunnel works between theirVpnIpNet and relayVpnIpNet") assertTunnel(t, theirVpnIpNet.Addr(), relayVpnIpNet.Addr(), theirControl, relayControl, r) c := theirControl.GetHostInfoByVpnIp(relayVpnIpNet.Addr(), false) - if len(c.Cert.Details.Groups) != 0 { + if len(c.Cert.Groups()) != 0 { // We have a new certificate now r.Log("Certificate between their and relay is updated!") break @@ -737,7 +738,7 @@ func TestRehandshaking(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew my certificate and spin until their sees it") - _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), myVpnIpNet, nil, []string{"new group"}) + _, _, myNextPrivKey, myNextPEM := NewTestCert(ca, caKey, "me", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{myVpnIpNet}, nil, []string{"new group"}) caB, err := ca.MarshalToPEM() if err != nil { @@ -756,7 +757,7 @@ func TestRehandshaking(t *testing.T) { for { assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) - if len(c.Cert.Details.Groups) != 0 { + if len(c.Cert.Groups()) != 0 { // We have a new certificate now break } @@ -764,6 +765,7 @@ func TestRehandshaking(t *testing.T) { time.Sleep(time.Second) } + r.Log("Got the new cert") // Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly rc, err = yaml.Marshal(theirConfig.Settings) assert.NoError(t, err) @@ -794,7 +796,7 @@ func TestRehandshaking(t *testing.T) { // Make sure the correct tunnel won c := theirControl.GetHostInfoByVpnIp(myVpnIpNet.Addr(), false) - assert.Contains(t, c.Cert.Details.Groups, "new group") + assert.Contains(t, c.Cert.Groups(), "new group") // We should only have a single tunnel now on both sides assert.Len(t, myFinalHostmapHosts, 1) @@ -837,7 +839,7 @@ func TestRehandshakingLoser(t *testing.T) { r.RenderHostmaps("Starting hostmaps", myControl, theirControl) r.Log("Renew their certificate and spin until mine sees it") - _, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), theirVpnIpNet, nil, []string{"their new group"}) + _, _, theirNextPrivKey, theirNextPEM := NewTestCert(ca, caKey, "them", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{theirVpnIpNet}, nil, []string{"their new group"}) caB, err := ca.MarshalToPEM() if err != nil { @@ -857,8 +859,7 @@ func TestRehandshakingLoser(t *testing.T) { assertTunnel(t, myVpnIpNet.Addr(), theirVpnIpNet.Addr(), myControl, theirControl, r) theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) - _, theirNewGroup := theirCertInMe.Cert.Details.InvertedGroups["their new group"] - if theirNewGroup { + if slices.Contains(theirCertInMe.Cert.Groups(), "their new group") { break } @@ -895,7 +896,7 @@ func TestRehandshakingLoser(t *testing.T) { // Make sure the correct tunnel won theirCertInMe := myControl.GetHostInfoByVpnIp(theirVpnIpNet.Addr(), false) - assert.Contains(t, theirCertInMe.Cert.Details.Groups, "their new group") + assert.Contains(t, theirCertInMe.Cert.Groups(), "their new group") // We should only have a single tunnel now on both sides assert.Len(t, myFinalHostmapHosts, 1) diff --git a/e2e/helpers.go b/e2e/helpers.go index 19ba90be6..3c5d5ad17 100644 --- a/e2e/helpers.go +++ b/e2e/helpers.go @@ -3,7 +3,6 @@ package e2e import ( "crypto/rand" "io" - "net" "net/netip" "time" @@ -13,7 +12,7 @@ import ( ) // NewTestCaCert will generate a CA cert -func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { +func NewTestCaCert(before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { pub, priv, err := ed25519.GenerateKey(rand.Reader) if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) @@ -22,56 +21,34 @@ func NewTestCaCert(before, after time.Time, ips, subnets []netip.Prefix, groups after = time.Now().Add(time.Second * 60).Round(time.Second) } - nc := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: "test ca", - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: true, - InvertedGroups: make(map[string]struct{}), - }, + t := &cert.TBSCertificate{ + Version: cert.Version1, + Name: "test ca", + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + IsCA: true, } - if len(ips) > 0 { - nc.Details.Ips = make([]*net.IPNet, len(ips)) - for i, ip := range ips { - nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())} - } - } - - if len(subnets) > 0 { - nc.Details.Subnets = make([]*net.IPNet, len(subnets)) - for i, ip := range subnets { - nc.Details.Ips[i] = &net.IPNet{IP: ip.Addr().AsSlice(), Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())} - } - } - - if len(groups) > 0 { - nc.Details.Groups = groups - } - - err = nc.Sign(cert.Curve_CURVE25519, priv) + c, err := t.Sign(nil, cert.Curve_CURVE25519, priv) if err != nil { panic(err) } - pem, err := nc.MarshalToPEM() + pem, err := c.MarshalToPEM() if err != nil { panic(err) } - return nc, pub, priv, pem + return c, pub, priv, pem } // NewTestCert will generate a signed certificate with the provided details. // Expiry times are defaulted if you do not pass them in -func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, after time.Time, ip netip.Prefix, subnets []netip.Prefix, groups []string) (*cert.NebulaCertificate, []byte, []byte, []byte) { - issuer, err := ca.Sha256Sum() - if err != nil { - panic(err) - } - +func NewTestCert(ca cert.Certificate, key []byte, name string, before, after time.Time, networks, unsafeNetworks []netip.Prefix, groups []string) (cert.Certificate, []byte, []byte, []byte) { if before.IsZero() { before = time.Now().Add(time.Second * -60).Round(time.Second) } @@ -81,33 +58,29 @@ func NewTestCert(ca *cert.NebulaCertificate, key []byte, name string, before, af } pub, rawPriv := x25519Keypair() - ipb := ip.Addr().AsSlice() - nc := &cert.NebulaCertificate{ - Details: cert.NebulaCertificateDetails{ - Name: name, - Ips: []*net.IPNet{{IP: ipb[:], Mask: net.CIDRMask(ip.Bits(), ip.Addr().BitLen())}}, - //Subnets: subnets, - Groups: groups, - NotBefore: time.Unix(before.Unix(), 0), - NotAfter: time.Unix(after.Unix(), 0), - PublicKey: pub, - IsCA: false, - Issuer: issuer, - InvertedGroups: make(map[string]struct{}), - }, - } - - err = nc.Sign(ca.Details.Curve, key) + nc := &cert.TBSCertificate{ + Version: cert.Version1, + Name: name, + Networks: networks, + UnsafeNetworks: unsafeNetworks, + Groups: groups, + NotBefore: time.Unix(before.Unix(), 0), + NotAfter: time.Unix(after.Unix(), 0), + PublicKey: pub, + IsCA: false, + } + + c, err := nc.Sign(ca, ca.Curve(), key) if err != nil { panic(err) } - pem, err := nc.MarshalToPEM() + pem, err := c.MarshalToPEM() if err != nil { panic(err) } - return nc, pub, cert.MarshalPrivateKeyToPEM(cert.Curve_CURVE25519, rawPriv), pem + return c, pub, cert.MarshalPrivateKeyToPEM(cert.Curve_CURVE25519, rawPriv), pem } func x25519Keypair() ([]byte, []byte) { diff --git a/e2e/helpers_test.go b/e2e/helpers_test.go index 527f55bc7..885ba6a7c 100644 --- a/e2e/helpers_test.go +++ b/e2e/helpers_test.go @@ -26,7 +26,7 @@ import ( type m map[string]interface{} // newSimpleServer creates a nebula instance with many assumptions -func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) { +func newSimpleServer(caCrt cert.Certificate, caKey []byte, name string, sVpnIpNet string, overrides m) (*nebula.Control, netip.Prefix, netip.AddrPort, *config.C) { l := NewTestLogger() vpnIpNet, err := netip.ParsePrefix(sVpnIpNet) @@ -44,7 +44,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, s budpIp[13] -= 128 udpAddr = netip.AddrPortFrom(netip.AddrFrom16(budpIp), 4242) } - _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{}) + _, _, myPrivKey, myPEM := NewTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{vpnIpNet}, nil, []string{}) caB, err := caCrt.MarshalToPEM() if err != nil { diff --git a/e2e/router/hostmap.go b/e2e/router/hostmap.go index c14ab2e77..29fa95991 100644 --- a/e2e/router/hostmap.go +++ b/e2e/router/hostmap.go @@ -58,8 +58,8 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { var lines []string var globalLines []*edge - clusterName := strings.Trim(c.GetCert().Details.Name, " ") - clusterVpnIp := c.GetCert().Details.Ips[0].IP + clusterName := strings.Trim(c.GetCert().Name(), " ") + clusterVpnIp := c.GetCert().Networks()[0].Addr() r := fmt.Sprintf("\tsubgraph %s[\"%s (%s)\"]\n", clusterName, clusterName, clusterVpnIp) hm := c.GetHostmap() @@ -102,7 +102,7 @@ func renderHostmap(c *nebula.Control) (string, []*edge) { hi, ok := hm.Indexes[idx] if ok { r += fmt.Sprintf("\t\t\t%v.%v[\"%v (%v)\"]\n", clusterName, idx, idx, hi.GetVpnIp()) - remoteClusterName := strings.Trim(hi.GetCert().Details.Name, " ") + remoteClusterName := strings.Trim(hi.GetCert().Certificate.Name(), " ") globalLines = append(globalLines, &edge{from: fmt.Sprintf("%v.%v", clusterName, idx), to: fmt.Sprintf("%v.%v", remoteClusterName, hi.GetRemoteIndex())}) _ = hi } diff --git a/firewall.go b/firewall.go index 255a889a8..80a828057 100644 --- a/firewall.go +++ b/firewall.go @@ -52,9 +52,9 @@ type Firewall struct { DefaultTimeout time.Duration //linux: 600s // Used to ensure we don't emit local packets for ips we don't own - localIps *bart.Table[struct{}] - assignedCIDR netip.Prefix - hasSubnets bool + localIps *bart.Table[struct{}] + assignedCIDR netip.Prefix + hasUnsafeNetworks bool rules string rulesVersion uint16 @@ -126,7 +126,7 @@ type firewallLocalCIDR struct { } // NewFirewall creates a new Firewall object. A TimerWheel is created for you from the provided timeouts. -func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c *cert.NebulaCertificate) *Firewall { +func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.Duration, c cert.Certificate) *Firewall { //TODO: error on 0 duration var min, max time.Duration @@ -147,11 +147,8 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D localIps := new(bart.Table[struct{}]) var assignedCIDR netip.Prefix var assignedSet bool - for _, ip := range c.Details.Ips { - //TODO: IPV6-WORK the unmap is a bit unfortunate - nip, _ := netip.AddrFromSlice(ip.IP) - nip = nip.Unmap() - nprefix := netip.PrefixFrom(nip, nip.BitLen()) + for _, network := range c.Networks() { + nprefix := netip.PrefixFrom(network.Addr(), network.Addr().BitLen()) localIps.Insert(nprefix, struct{}{}) if !assignedSet { @@ -161,11 +158,10 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D } } - for _, n := range c.Details.Subnets { - nip, _ := netip.AddrFromSlice(n.IP) - ones, _ := n.Mask.Size() - nip = nip.Unmap() - localIps.Insert(netip.PrefixFrom(nip, ones), struct{}{}) + hasUnsafeNetworks := false + for _, n := range c.UnsafeNetworks() { + localIps.Insert(n, struct{}{}) + hasUnsafeNetworks = true } return &Firewall{ @@ -173,15 +169,15 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D Conns: make(map[firewall.Packet]*conn), TimerWheel: NewTimerWheel[firewall.Packet](min, max), }, - InRules: newFirewallTable(), - OutRules: newFirewallTable(), - TCPTimeout: tcpTimeout, - UDPTimeout: UDPTimeout, - DefaultTimeout: defaultTimeout, - localIps: localIps, - assignedCIDR: assignedCIDR, - hasSubnets: len(c.Details.Subnets) > 0, - l: l, + InRules: newFirewallTable(), + OutRules: newFirewallTable(), + TCPTimeout: tcpTimeout, + UDPTimeout: UDPTimeout, + DefaultTimeout: defaultTimeout, + localIps: localIps, + assignedCIDR: assignedCIDR, + hasUnsafeNetworks: hasUnsafeNetworks, + l: l, incomingMetrics: firewallMetrics{ droppedLocalIP: metrics.GetOrRegisterCounter("firewall.incoming.dropped.local_ip", nil), @@ -196,7 +192,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D } } -func NewFirewallFromConfig(l *logrus.Logger, nc *cert.NebulaCertificate, c *config.C) (*Firewall, error) { +func NewFirewallFromConfig(l *logrus.Logger, nc cert.Certificate, c *config.C) (*Firewall, error) { fw := NewFirewall( l, c.GetDuration("firewall.conntrack.tcp_timeout", time.Minute*12), @@ -619,7 +615,7 @@ func (f *Firewall) evict(p firewall.Packet) { delete(conntrack.Conns, p) } -func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.CAPool) bool { +func (ft *FirewallTable) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool { if ft.AnyProto.match(p, incoming, c, caPool) { return true } @@ -663,7 +659,7 @@ func (fp firewallPort) addRule(f *Firewall, startPort int32, endPort int32, grou return nil } -func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.NebulaCertificate, caPool *cert.CAPool) bool { +func (fp firewallPort) match(p firewall.Packet, incoming bool, c *cert.CachedCertificate, caPool *cert.CAPool) bool { // We don't have any allowed ports, bail if fp == nil { return false @@ -726,7 +722,7 @@ func (fc *FirewallCA) addRule(f *Firewall, groups []string, host string, ip, loc return nil } -func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool *cert.CAPool) bool { +func (fc *FirewallCA) match(p firewall.Packet, c *cert.CachedCertificate, caPool *cert.CAPool) bool { if fc == nil { return false } @@ -735,18 +731,18 @@ func (fc *FirewallCA) match(p firewall.Packet, c *cert.NebulaCertificate, caPool return true } - if t, ok := fc.CAShas[c.Details.Issuer]; ok { + if t, ok := fc.CAShas[c.Certificate.Issuer()]; ok { if t.match(p, c) { return true } } - s, err := caPool.GetCAForCert(c) + s, err := caPool.GetCAForCert(c.Certificate) if err != nil { return false } - return fc.CANames[s.Details.Name].match(p, c) + return fc.CANames[s.Certificate.Name()].match(p, c) } func (fr *FirewallRule) addRule(f *Firewall, groups []string, host string, ip, localCIDR netip.Prefix) error { @@ -826,7 +822,7 @@ func (fr *FirewallRule) isAny(groups []string, host string, ip netip.Prefix) boo return false } -func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool { +func (fr *FirewallRule) match(p firewall.Packet, c *cert.CachedCertificate) bool { if fr == nil { return false } @@ -841,7 +837,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool found := false for _, g := range sg.Groups { - if _, ok := c.Details.InvertedGroups[g]; !ok { + if _, ok := c.InvertedGroups[g]; !ok { found = false break } @@ -855,7 +851,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool } if fr.Hosts != nil { - if flc, ok := fr.Hosts[c.Details.Name]; ok { + if flc, ok := fr.Hosts[c.Certificate.Name()]; ok { if flc.match(p, c) { return true } @@ -876,7 +872,7 @@ func (fr *FirewallRule) match(p firewall.Packet, c *cert.NebulaCertificate) bool func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { if !localIp.IsValid() { - if !f.hasSubnets || f.defaultLocalCIDRAny { + if !f.hasUnsafeNetworks || f.defaultLocalCIDRAny { flc.Any = true return nil } @@ -890,7 +886,7 @@ func (flc *firewallLocalCIDR) addRule(f *Firewall, localIp netip.Prefix) error { return nil } -func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.NebulaCertificate) bool { +func (flc *firewallLocalCIDR) match(p firewall.Packet, c *cert.CachedCertificate) bool { if flc == nil { return false } diff --git a/handshake_ix.go b/handshake_ix.go index 8cf534112..24c423d64 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -99,8 +99,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } - vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP) - if !ok { + if len(remoteCert.Certificate.Networks()) == 0 { e := f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) @@ -112,10 +111,10 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } - vpnIp = vpnIp.Unmap() - certName := remoteCert.Details.Name - fingerprint, _ := remoteCert.Sha256Sum() - issuer := remoteCert.Details.Issuer + vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap() + certName := remoteCert.Certificate.Name() + fingerprint := remoteCert.ShaSum + issuer := remoteCert.Certificate.Issuer() if vpnIp == f.myVpnNet.Addr() { f.l.WithField("vpnIp", vpnIp).WithField("udpAddr", addr). @@ -216,7 +215,7 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet hostinfo.remotes = f.lightHouse.QueryCache(vpnIp) hostinfo.SetRemote(addr) - hostinfo.CreateRemoteCIDR(remoteCert) + hostinfo.CreateRemoteCIDR(remoteCert.Certificate) existing, err := f.handshakeManager.CheckAndComplete(hostinfo, 0, f) if err != nil { @@ -402,8 +401,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha return true } - vpnIp, ok := netip.AddrFromSlice(remoteCert.Details.Ips[0].IP) - if !ok { + if len(remoteCert.Certificate.Networks()) == 0 { e := f.l.WithError(err).WithField("udpAddr", addr). WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) @@ -415,10 +413,10 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha return true } - vpnIp = vpnIp.Unmap() - certName := remoteCert.Details.Name - fingerprint, _ := remoteCert.Sha256Sum() - issuer := remoteCert.Details.Issuer + vpnIp := remoteCert.Certificate.Networks()[0].Addr().Unmap() + certName := remoteCert.Certificate.Name() + fingerprint := remoteCert.ShaSum + issuer := remoteCert.Certificate.Issuer() // Ensure the right host responded if vpnIp != hostinfo.vpnIp { @@ -486,7 +484,7 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha } // Build up the radix for the firewall if we have subnets in the cert - hostinfo.CreateRemoteCIDR(remoteCert) + hostinfo.CreateRemoteCIDR(remoteCert.Certificate) // Complete our handshake and update metrics, this will replace any existing tunnels for this vpnIp f.handshakeManager.Complete(hostinfo, f) diff --git a/hostmap.go b/hostmap.go index fb97b76d7..d83151eb3 100644 --- a/hostmap.go +++ b/hostmap.go @@ -491,7 +491,7 @@ func (hm *HostMap) queryVpnIp(vpnIp netip.Addr, promoteIfce *Interface) *HostInf func (hm *HostMap) unlockedAddHostInfo(hostinfo *HostInfo, f *Interface) { if f.serveDns { remoteCert := hostinfo.ConnectionState.peerCert - dnsR.Add(remoteCert.Details.Name+".", remoteCert.Details.Ips[0].IP.String()) + dnsR.Add(remoteCert.Certificate.Name()+".", remoteCert.Certificate.Networks()[0].Addr().String()) } existing := hm.Hosts[hostinfo.vpnIp] @@ -585,7 +585,7 @@ func (i *HostInfo) TryPromoteBest(preferredRanges []netip.Prefix, ifce *Interfac } } -func (i *HostInfo) GetCert() *cert.NebulaCertificate { +func (i *HostInfo) GetCert() *cert.CachedCertificate { if i.ConnectionState != nil { return i.ConnectionState.peerCert } @@ -647,27 +647,19 @@ func (i *HostInfo) RecvErrorExceeded() bool { return true } -func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) { - if len(c.Details.Ips) == 1 && len(c.Details.Subnets) == 0 { +func (i *HostInfo) CreateRemoteCIDR(c cert.Certificate) { + if len(c.Networks()) == 1 && len(c.UnsafeNetworks()) == 0 { // Simple case, no CIDRTree needed return } remoteCidr := new(bart.Table[struct{}]) - for _, ip := range c.Details.Ips { - //TODO: IPV6-WORK what to do when ip is invalid? - nip, _ := netip.AddrFromSlice(ip.IP) - nip = nip.Unmap() - bits, _ := ip.Mask.Size() - remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{}) + for _, network := range c.Networks() { + remoteCidr.Insert(network, struct{}{}) } - for _, n := range c.Details.Subnets { - //TODO: IPV6-WORK what to do when ip is invalid? - nip, _ := netip.AddrFromSlice(n.IP) - nip = nip.Unmap() - bits, _ := n.Mask.Size() - remoteCidr.Insert(netip.PrefixFrom(nip, bits), struct{}{}) + for _, network := range c.UnsafeNetworks() { + remoteCidr.Insert(network, struct{}{}) } i.remoteCidr = remoteCidr } @@ -683,7 +675,7 @@ func (i *HostInfo) logger(l *logrus.Logger) *logrus.Entry { if connState := i.ConnectionState; connState != nil { if peerCert := connState.peerCert; peerCert != nil { - li = li.WithField("certName", peerCert.Details.Name) + li = li.WithField("certName", peerCert.Certificate.Name()) } } diff --git a/interface.go b/interface.go index f2519076c..9308aae84 100644 --- a/interface.go +++ b/interface.go @@ -2,7 +2,6 @@ package nebula import ( "context" - "encoding/binary" "errors" "fmt" "io" @@ -157,26 +156,6 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { certificate := c.pki.GetCertState().Certificate - myVpnAddr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP) - if !ok { - return nil, fmt.Errorf("invalid ip address in certificate: %s", certificate.Details.Ips[0].IP) - } - - myVpnMask, ok := netip.AddrFromSlice(certificate.Details.Ips[0].Mask) - if !ok { - return nil, fmt.Errorf("invalid ip mask in certificate: %s", certificate.Details.Ips[0].Mask) - } - - myVpnAddr = myVpnAddr.Unmap() - myVpnMask = myVpnMask.Unmap() - - if myVpnAddr.BitLen() != myVpnMask.BitLen() { - return nil, fmt.Errorf("ip address and mask are different lengths in certificate") - } - - ones, _ := certificate.Details.Ips[0].Mask.Size() - myVpnNet := netip.PrefixFrom(myVpnAddr, ones) - ifce := &Interface{ pki: c.pki, hostMap: c.HostMap, @@ -194,7 +173,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { version: c.version, writers: make([]udp.Conn, c.routines), readers: make([]io.ReadWriteCloser, c.routines), - myVpnNet: myVpnNet, + myVpnNet: certificate.Networks()[0], relayManager: c.relayManager, conntrackCacheTimeout: c.ConntrackCacheTimeout, @@ -209,10 +188,11 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) { l: c.l, } - if myVpnAddr.Is4() { - addr := myVpnNet.Masked().Addr().As4() - binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask)) - ifce.myBroadcastAddr = netip.AddrFrom4(addr) + if ifce.myVpnNet.Addr().Is4() { + //TODO: + //addr := myVpnNet.Masked().Addr().As4() + //binary.BigEndian.PutUint32(addr[:], binary.BigEndian.Uint32(addr[:])|^binary.BigEndian.Uint32(certificate.Details.Ips[0].Mask)) + //ifce.myBroadcastAddr = netip.AddrFrom4(addr) } ifce.tryPromoteEvery.Store(c.tryPromoteEvery) @@ -434,7 +414,7 @@ func (f *Interface) emitStats(ctx context.Context, i time.Duration) { f.firewall.EmitStats() f.handshakeManager.EmitStats() udpStats() - certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.Details.NotAfter.Sub(time.Now()) / time.Second)) + certExpirationGauge.Update(int64(f.pki.GetCertState().Certificate.NotAfter().Sub(time.Now()) / time.Second)) } } } diff --git a/main.go b/main.go index c6edc9133..8f4535951 100644 --- a/main.go +++ b/main.go @@ -68,17 +68,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg } l.WithField("firewallHashes", fw.GetRuleHashes()).Info("Firewall started") - ones, _ := certificate.Details.Ips[0].Mask.Size() - addr, ok := netip.AddrFromSlice(certificate.Details.Ips[0].IP) - if !ok { - err = util.NewContextualError( - "Invalid ip address in certificate", - m{"vpnIp": certificate.Details.Ips[0].IP}, - nil, - ) - return nil, err - } - tunCidr := netip.PrefixFrom(addr, ones) + tunCidr := certificate.Networks()[0] ssh, err := sshd.NewSSHServer(l.WithField("subsystem", "sshd")) if err != nil { diff --git a/outside.go b/outside.go index a0f7ad01d..6a71fe77f 100644 --- a/outside.go +++ b/outside.go @@ -14,7 +14,6 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" "golang.org/x/net/ipv4" - "google.golang.org/protobuf/proto" ) const ( @@ -494,7 +493,7 @@ func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *N } */ -func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.CAPool) (*cert.NebulaCertificate, error) { +func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.CAPool) (*cert.CachedCertificate, error) { pk := h.PeerStatic() if pk == nil { @@ -505,31 +504,15 @@ func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPo return nil, errors.New("provided payload was empty") } - r := &cert.RawNebulaCertificate{} - err := proto.Unmarshal(rawCertBytes, r) + c, err := cert.UnmarshalCertificateFromHandshake(rawCertBytes, pk) if err != nil { - return nil, fmt.Errorf("error unmarshaling cert: %s", err) + return nil, fmt.Errorf("error unmarshaling cert: %w", err) } - // If the Details are nil, just exit to avoid crashing - if r.Details == nil { - return nil, fmt.Errorf("certificate did not contain any details") - } - - r.Details.PublicKey = pk - recombined, err := proto.Marshal(r) - if err != nil { - return nil, fmt.Errorf("error while recombining certificate: %s", err) - } - - c, _ := cert.UnmarshalNebulaCertificate(recombined) - isValid, err := c.Verify(time.Now(), caPool) + cc, err := caPool.VerifyCertificate(time.Now(), c) if err != nil { - return c, fmt.Errorf("certificate validation failed: %s", err) - } else if !isValid { - // This case should never happen but here's to defensive programming! - return c, errors.New("certificate validation failed but did not return an error") + return nil, fmt.Errorf("certificate validation failed: %w", err) } - return c, nil + return cc, nil } diff --git a/pki.go b/pki.go index 70ca944e7..8677f8a8d 100644 --- a/pki.go +++ b/pki.go @@ -21,7 +21,7 @@ type PKI struct { } type CertState struct { - Certificate *cert.NebulaCertificate + Certificate cert.Certificate RawCertificate []byte RawCertificateNoKey []byte PublicKey []byte @@ -84,8 +84,8 @@ func (p *PKI) reloadCert(c *config.C, initial bool) *util.ContextualError { // did IP in cert change? if so, don't set currentCert := p.cs.Load().Certificate - oldIPs := currentCert.Details.Ips - newIPs := cs.Certificate.Details.Ips + oldIPs := currentCert.Networks() + newIPs := cs.Certificate.Networks() if len(oldIPs) > 0 && len(newIPs) > 0 && oldIPs[0].String() != newIPs[0].String() { return util.NewContextualError( "IP in new cert was different from old", @@ -115,14 +115,14 @@ func (p *PKI) reloadCAPool(c *config.C) *util.ContextualError { return nil } -func newCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*CertState, error) { +func newCertState(certificate cert.Certificate, privateKey []byte) (*CertState, error) { // Marshal the certificate to ensure it is valid rawCertificate, err := certificate.Marshal() if err != nil { return nil, fmt.Errorf("invalid nebula certificate on interface: %s", err) } - publicKey := certificate.Details.PublicKey + publicKey := certificate.PublicKey() cs := &CertState{ RawCertificate: rawCertificate, Certificate: certificate, @@ -130,14 +130,12 @@ func newCertState(certificate *cert.NebulaCertificate, privateKey []byte) (*Cert PublicKey: publicKey, } - cs.Certificate.Details.PublicKey = nil - rawCertNoKey, err := cs.Certificate.Marshal() + rawCertNoKey, err := cs.Certificate.MarshalForHandshakes() if err != nil { return nil, fmt.Errorf("error marshalling certificate no key: %s", err) } cs.RawCertificateNoKey = rawCertNoKey - // put public key back - cs.Certificate.Details.PublicKey = cs.PublicKey + return cs, nil } @@ -193,7 +191,7 @@ func newCertStateFromConfig(c *config.C) (*CertState, error) { return nil, fmt.Errorf("nebula certificate for this host is expired") } - if len(nebulaCert.Details.Ips) == 0 { + if len(nebulaCert.Networks()) == 0 { return nil, fmt.Errorf("no IPs encoded in certificate") } @@ -227,7 +225,7 @@ func loadCAPoolFromConfig(l *logrus.Logger, c *config.C) (*cert.CAPool, error) { if errors.Is(err, cert.ErrExpired) { var expired int for _, crt := range caPool.CAs { - if crt.Expired(time.Now()) { + if crt.Certificate.Expired(time.Now()) { expired++ l.WithField("cert", crt).Warn("expired certificate present in CA pool") } diff --git a/ssh.go b/ssh.go index 2ff0954d6..93f162efc 100644 --- a/ssh.go +++ b/ssh.go @@ -801,7 +801,7 @@ func sshPrintCert(ifce *Interface, fs interface{}, a []string, w sshd.StringWrit return w.WriteLine(fmt.Sprintf("Could not find tunnel for vpn ip: %v", a[0])) } - cert = hostInfo.GetCert() + cert = hostInfo.GetCert().Certificate } if args.Json || args.Pretty {