diff --git a/README.md b/README.md index 82bcff4..254d3c6 100644 --- a/README.md +++ b/README.md @@ -46,4 +46,69 @@ func TestMyApp(t *testing.T) { } ``` -For a full list of possible functions, please [check the Go docs](https://pkg.go.dev/github.com/opentofu/tofutestutils). \ No newline at end of file +For a full list of possible functions, please [check the Go docs](https://pkg.go.dev/github.com/opentofu/tofutestutils). + +## Certificate authority + +When you need an x509 certificate for a server or a client, you can use the `tofutestutils.CA` function to obtain a `testca.CertificateAuthority` implementation using a pseudo-random number generator. You can use this to create a certificate for a socket server: + +```go +package your_test + +import ( + "crypto/tls" + "io" + "net" + "strconv" + "testing" + + "github.com/opentofu/tofutestutils" +) + +func TestMySocket(t *testing.T) { + ca := tofutestutils.CA(t) + + // Server side: + tlsListener, err := tls.Listen("tcp", "127.0.0.1:0", ca.CreateLocalhostServerCert().GetServerTLSConfig()) + if err != nil { + t.Fatalf("Failed to open server: %v", err) + } + defer func() { + if err = tlsListener.Close(); err != nil { + t.Fatalf("Failed to close server listener: %v", err) + } + }() + go func() { + conn, serverErr := tlsListener.Accept() + if serverErr != nil { + return + } + defer func() { + if err := conn.Close(); err != nil { + t.Logf("Failed to close connection: %v", err) + } + }() + if _, err = conn.Write([]byte("Hello world!")); err != nil { + t.Logf("Failed to write to client: %v", err) + } + }() + + // Client side: + port := tlsListener.Addr().(*net.TCPAddr).Port + client, err := tls.Dial("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)), ca.GetClientTLSConfig()) + if err != nil { + t.Fatalf("Failed to open connection to server: %v", err) + } + defer func() { + if err = client.Close(); err != nil { + t.Fatalf("Failed to close client: %v", err) + } + }() + + data, err := io.ReadAll(client) + if err != nil { + t.Fatal(err) + } + t.Logf("%s", data) +} +``` diff --git a/ca.go b/ca.go new file mode 100644 index 0000000..23716cf --- /dev/null +++ b/ca.go @@ -0,0 +1,18 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package tofutestutils + +import ( + "testing" + + "github.com/opentofu/tofutestutils/testca" +) + +// CA returns a certificate authority configured for the provided test. This implementation will configure the CA to use +// a pseudorandom source. You can call testca.New() for more configuration options. +func CA(t *testing.T) testca.CertificateAuthority { + return testca.New(t, RandomSource()) +} diff --git a/testca/README.md b/testca/README.md new file mode 100644 index 0000000..26ef0df --- /dev/null +++ b/testca/README.md @@ -0,0 +1,48 @@ +# Certificate authority + +This folder contains a basic x509 certificate authority implementation for testing purposes. You can use it whenever you need a certificate for servers or clients. + +```go +package your_test + +import ( + "crypto/tls" + "io" + "net" + "strconv" + "testing" + "time" + + "github.com/opentofu/tofutestutils" + "github.com/opentofu/tofutestutils/testca" + "github.com/opentofu/tofutestutils/testrandom" +) + +func TestMySocket(t *testing.T) { + // Configure a desired randomness and time source. You can use this to create deterministic behavior. + currentTimeSource := time.Now + ca := testca.New(t, testrandom.DeterministicSource(t), currentTimeSource) + + // Server side: + tlsListener := tofutestutils.Must2(tls.Listen("tcp", "127.0.0.1:0", ca.CreateLocalhostServerCert().GetServerTLSConfig())) + go func() { + conn, serverErr := tlsListener.Accept() + if serverErr != nil { + return + } + defer func() { + _ = conn.Close() + }() + _, _ = conn.Write([]byte("Hello world!")) + }() + + // Client side: + port := tlsListener.Addr().(*net.TCPAddr).Port + client := tofutestutils.Must2(tls.Dial("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)), ca.GetClientTLSConfig())) + defer func() { + _ = client.Close() + }() + + t.Logf("%s", tofutestutils.Must2(io.ReadAll(client))) +} +``` \ No newline at end of file diff --git a/testca/ca.go b/testca/ca.go new file mode 100644 index 0000000..4b27314 --- /dev/null +++ b/testca/ca.go @@ -0,0 +1,276 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package testca + +import ( + "bytes" + "crypto" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "io" + "math/big" + "net" + "sync" + "testing" + "time" +) + +const caKeySize = 2048 + +// startDate holds the date of the fork announcement. This is the starting date for the validity of the certificate by +// default. +var startDate = time.Date(2023, 9, 5, 0, 0, 0, 0, time.UTC) + +// expirationYears is the number of years the certificate is valid by default. +const expirationYears = 30 + +// New creates an x509 CA certificate that can produce certificates for testing purposes. Pass a desired deterministic +// randomSource to create a deterministic certificate. +func New(t *testing.T, randomSource io.Reader) CertificateAuthority { + caCert := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"OpenTofu a Series of LF Projects, LLC"}, + Country: []string{"US"}, + }, + NotBefore: startDate, + NotAfter: startDate.AddDate(expirationYears, 0, 0), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + caPrivateKey, err := rsa.GenerateKey(randomSource, caKeySize) + if err != nil { + t.Skipf("Failed to create private key: %v", err) + } + caCertData, err := x509.CreateCertificate(randomSource, caCert, caCert, &caPrivateKey.PublicKey, caPrivateKey) + if err != nil { + t.Skipf("Failed to create CA certificate: %v", err) + } + caPEM := new(bytes.Buffer) + if err := pem.Encode(caPEM, &pem.Block{ + Type: "CERTIFICATE", + Bytes: caCertData, + }); err != nil { + t.Skipf("Failed to encode CA cert: %v", err) + } + return &ca{ + t: t, + random: randomSource, + caCert: caCert, + caCertPEM: caPEM.Bytes(), + privateKey: caPrivateKey, + serial: big.NewInt(0), + lock: &sync.Mutex{}, + } +} + +// CertConfig is the configuration structure for creating specialized certificates using +// CertificateAuthority.CreateConfiguredServerCert. +type CertConfig struct { + // IPAddresses contains a list of IP addresses that should be added to the SubjectAltName field of the certificate. + IPAddresses []string + // Hosts contains a list of host names that should be added to the SubjectAltName field of the certificate. + Hosts []string + // Subject is the subject (CN, etc) setting for the certificate. Most commonly, you will want the CN field to match + // one of hour host names. + Subject pkix.Name + // ExtKeyUsage describes the extended key usage. Typically, this should be: + // + // []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} + ExtKeyUsage []x509.ExtKeyUsage + // StartTime indicates when the certificate should start to be valid. This defaults to startDate. + NotBefore *time.Time + // NotAfter indicates a time at which the certificate should stop being valid. This defaults to expirationYears + // after startDate + NotAfter *time.Time +} + +// KeyPair contains a certificate and private key in PEM format. +type KeyPair struct { + // Certificate contains an x509 certificate in PEM format. + Certificate []byte + // PrivateKey contains an RSA or other private key in PEM format. + PrivateKey []byte +} + +// GetPrivateKey returns a crypto.Signer for the private key. +func (k KeyPair) GetPrivateKey() crypto.PrivateKey { + block, _ := pem.Decode(k.PrivateKey) + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + panic(err) + } + return key +} + +// GetTLSCertificate returns the tls.Certificate based on this key pair. +func (k KeyPair) GetTLSCertificate() tls.Certificate { + cert, err := tls.X509KeyPair(k.Certificate, k.PrivateKey) + if err != nil { + panic(err) + } + return cert +} + +// GetServerTLSConfig returns a tls.Config suitable for a TLS server with this key pair. +func (k KeyPair) GetServerTLSConfig() *tls.Config { + return &tls.Config{ + Certificates: []tls.Certificate{ + k.GetTLSCertificate(), + }, + MinVersion: tls.VersionTLS12, + } +} + +// CertificateAuthority provides simple access to x509 CA functions for testing purposes only. +type CertificateAuthority interface { + // GetPEMCACert returns the CA certificate in PEM format. + GetPEMCACert() []byte + // GetCertPool returns an x509.CertPool configured for this CA. + GetCertPool() *x509.CertPool + // GetClientTLSConfig returns a *tls.Config with a valid cert pool configured for this CA. + GetClientTLSConfig() *tls.Config + // CreateLocalhostServerCert creates a server certificate pre-configured for "localhost", which is sufficient for + // most test cases. + CreateLocalhostServerCert() KeyPair + // CreateLocalhostClientCert creates a client certificate pre-configured for "localhost", which is sufficient for + // most test cases. + CreateLocalhostClientCert() KeyPair + // CreateConfiguredCert creates a certificate with a specialized configuration. + CreateConfiguredCert(config CertConfig) KeyPair +} + +type ca struct { + caCert *x509.Certificate + caCertPEM []byte + privateKey *rsa.PrivateKey + serial *big.Int + lock *sync.Mutex + t *testing.T + random io.Reader +} + +func (c *ca) GetClientTLSConfig() *tls.Config { + certPool := c.GetCertPool() + + return &tls.Config{ + RootCAs: certPool, + MinVersion: tls.VersionTLS12, + } +} + +func (c *ca) GetCertPool() *x509.CertPool { + certPool := x509.NewCertPool() + certPool.AppendCertsFromPEM(c.caCertPEM) + return certPool +} + +func (c *ca) GetPEMCACert() []byte { + return c.caCertPEM +} + +func (c *ca) CreateConfiguredCert(config CertConfig) KeyPair { + c.lock.Lock() + defer c.lock.Unlock() + c.serial.Add(c.serial, big.NewInt(1)) + + ipAddresses := make([]net.IP, len(config.IPAddresses)) + for i, ip := range config.IPAddresses { + ipAddresses[i] = net.ParseIP(ip) + } + + var notBefore time.Time + if config.NotBefore != nil { + notBefore = *config.NotBefore + } else { + notBefore = startDate + } + var notAfter time.Time + if config.NotAfter != nil { + notAfter = *config.NotAfter + } else { + notAfter = notBefore.Add(expirationYears * 365 * 24 * time.Hour) + } + + cert := &x509.Certificate{ + SerialNumber: c.serial, + Subject: config.Subject, + NotBefore: notBefore, + NotAfter: notAfter, + SubjectKeyId: []byte{1}, + ExtKeyUsage: config.ExtKeyUsage, + KeyUsage: x509.KeyUsageDigitalSignature, + DNSNames: config.Hosts, + IPAddresses: ipAddresses, + } + certPrivKey, err := rsa.GenerateKey(c.random, caKeySize) + if err != nil { + c.t.Skipf("Failed to generate private key: %v", err) + } + certBytes, err := x509.CreateCertificate( + c.random, + cert, + c.caCert, + &certPrivKey.PublicKey, + c.privateKey, + ) + if err != nil { + c.t.Skipf("Failed to create certificate: %v", err) + } + certPrivKeyPEM := new(bytes.Buffer) + if err := pem.Encode(certPrivKeyPEM, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey), + }); err != nil { + c.t.Skipf("Failed to encode private key: %v", err) + } + certPEM := new(bytes.Buffer) + if err := pem.Encode(certPEM, + &pem.Block{Type: "CERTIFICATE", Bytes: certBytes}, + ); err != nil { + c.t.Skipf("Failed to encode certificate: %v", err) + } + return KeyPair{ + Certificate: certPEM.Bytes(), + PrivateKey: certPrivKeyPEM.Bytes(), + } +} + +func (c *ca) CreateLocalhostServerCert() KeyPair { + return c.CreateConfiguredCert(CertConfig{ + IPAddresses: []string{"127.0.0.1", "::1"}, + Subject: pkix.Name{ + Country: []string{"US"}, + Organization: []string{"OpenTofu a Series of LF Projects, LLC"}, + CommonName: "localhost", + }, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + Hosts: []string{ + "localhost", + }, + }) +} + +func (c *ca) CreateLocalhostClientCert() KeyPair { + return c.CreateConfiguredCert(CertConfig{ + IPAddresses: []string{"127.0.0.1", "::1"}, + Subject: pkix.Name{ + Country: []string{"US"}, + Organization: []string{"OpenTofu a Series of LF Projects, LLC"}, + CommonName: "localhost", + }, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + Hosts: []string{ + "localhost", + }, + }) +} diff --git a/testca/ca_test.go b/testca/ca_test.go new file mode 100644 index 0000000..780ee2b --- /dev/null +++ b/testca/ca_test.go @@ -0,0 +1,161 @@ +// Copyright (c) The OpenTofu Authors +// SPDX-License-Identifier: MPL-2.0 +// Copyright (c) 2023 HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package testca_test + +import ( + "bytes" + "context" + "crypto/tls" + "io" + "net" + "strconv" + "testing" + + "github.com/opentofu/tofutestutils/testca" + "github.com/opentofu/tofutestutils/testrandom" +) + +func TestCA(t *testing.T) { + t.Run("correct", testCACorrectCertificate) + t.Run("incorrect", testCAIncorrectCertificate) +} + +func testCAIncorrectCertificate(t *testing.T) { + ca1 := testca.New(t, testrandom.Source()) + ca2 := testca.New(t, testrandom.Source()) + + if bytes.Equal(ca1.GetPEMCACert(), ca2.GetPEMCACert()) { + t.Fatalf("The two CA's have the same CA PEM!") + } + + done := make(chan struct{}) + var serverErr error + t.Logf("🍦 Setting up TLS server...") + tlsListener, err := tls.Listen( + "tcp", + "127.0.0.1:0", + ca1.CreateLocalhostServerCert().GetServerTLSConfig(), + ) + if err != nil { + t.Fatalf("❌ Failed to set up listener: %v", err) + } + t.Cleanup(func() { + t.Logf("🍦 Server closing listener...") + if err := tlsListener.Close(); err != nil { + t.Logf("❌ Failed to close server listener (%v)", err) + } + }) + port := tlsListener.Addr().(*net.TCPAddr).Port + go func() { + defer close(done) + t.Logf("🍦 Server accepting connection...") + var conn net.Conn + conn, serverErr = tlsListener.Accept() + if serverErr != nil { + t.Logf("🍦 Server correctly received an error: %v", serverErr) + return + } + // Force a handshake even without read/write. The client automatically performs + // the handshake, but the server listener doesn't before reading. + serverErr = conn.(*tls.Conn).HandshakeContext(context.Background()) + if serverErr == nil { + t.Logf("❌ Server unexpectedly did not receive an error.") + } else { + t.Logf("🍦 Server correctly received an error: %v", serverErr) + } + if err := conn.Close(); err != nil { + t.Logf("❌ Could not close the connection on the server side: %v", err) + } + }() + t.Logf("🔌 Client connecting to server...") + conn, err := tls.Dial( + "tcp", + net.JoinHostPort("127.0.0.1", strconv.Itoa(port)), + ca2.GetClientTLSConfig(), + ) + if err == nil { + if err := conn.Close(); err != nil { + t.Logf("❌ Could not close the connection on the client side: %v", err) + } + t.Fatalf("❌ The TLS connection succeeded despite the incorrect CA certificate.") + } + t.Logf("🔌 Client correctly received an error: %v", err) + <-done + if serverErr == nil { + t.Fatalf("❌ The TLS server didn't error despite the incorrect CA certificate.") + } +} + +func testCACorrectCertificate(t *testing.T) { + ca := testca.New(t, testrandom.Source()) + const testGreeting = "Hello world!" + + var serverErr error + t.Cleanup(func() { + if serverErr != nil { + t.Fatalf("❌ TLS server failed: %v", serverErr) + } + }) + + done := make(chan struct{}) + + t.Logf("🍦 Setting up TLS server...") + tlsListener, err := tls.Listen("tcp", "127.0.0.1:0", ca.CreateLocalhostServerCert().GetServerTLSConfig()) + if err != nil { + t.Fatalf("❌ Failed to set up listener: %v", err) + } + t.Cleanup(func() { + t.Logf("🍦 Server closing listener...") + if err := tlsListener.Close(); err != nil { + t.Logf("❌ Could not close the server listener: %v", err) + } + }) + t.Logf("🍦 Starting TLS server...") + go func() { + defer close(done) + var conn net.Conn + t.Logf("🍦 Server accepting connection...") + conn, serverErr = tlsListener.Accept() + if serverErr != nil { + t.Errorf("❌ Server accept failed: %v", serverErr) + return + } + defer func() { + t.Logf("🍦 Server closing connection.") + if err := conn.Close(); err != nil { + t.Logf("❌ Could not close the server connection: %v", err) + } + }() + t.Logf("🍦 Server writing greeting...") + _, serverErr = conn.Write([]byte(testGreeting)) + if serverErr != nil { + t.Errorf("❌ Server write failed: %v", serverErr) + return + } + }() + t.Logf("🔌 Client connecting to server...") + port := tlsListener.Addr().(*net.TCPAddr).Port + client, err := tls.Dial("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port)), ca.GetClientTLSConfig()) + if err != nil { + t.Fatalf("❌ Failed to connect to server: %v", err) + } + defer func() { + t.Logf("🔌 Client closing connection...") + if err := client.Close(); err != nil { + t.Logf("❌ Could not close the client connection: %v", err) + } + }() + t.Logf("🔌 Client reading greeting...") + greeting, err := io.ReadAll(client) + if err != nil { + t.Fatalf("❌ Failed to read greeting: %v", err) + } + if string(greeting) != testGreeting { + t.Fatalf("❌ Client received incorrect greeting: %s", greeting) + } + t.Logf("🔌 Waiting for server to finish...") + <-done +}