diff --git a/pkg/tls/tls.go b/pkg/tls/tls.go index 4683e5d..025f823 100644 --- a/pkg/tls/tls.go +++ b/pkg/tls/tls.go @@ -8,6 +8,7 @@ import ( // Need to embed the default config into the library _ "embed" "encoding/pem" + "errors" "fmt" "math/big" "os" @@ -43,7 +44,8 @@ var ( ) // GenerateAllCerts generates the full set of required certs for chia blockchain -func GenerateAllCerts(outDir string) error { +// If privateCACert and privateCAKey are both nil, a new private CA will be generated +func GenerateAllCerts(outDir string, privateCACert *x509.Certificate, privateCAKey *rsa.PrivateKey) error { // First, ensure that all output directories exist allNodes := append(privateNodeNames, publicNodeNames...) for _, subdir := range append(allNodes, "ca") { @@ -73,19 +75,31 @@ func GenerateAllCerts(outDir string) error { return fmt.Errorf("error parsing chia_ca.key") } - privateCACertBytes, privateCAKeyBytes, err := GenerateNewCA(path.Join(outDir, "ca", "private_ca")) - if err != nil { - return fmt.Errorf("error creating private ca pair: %w", err) - } - privateCACert, err := ParsePemCertificate(privateCACertBytes) - if err != nil { - return fmt.Errorf("error parsing generated private_ca.crt: %w", err) - } - privateCAKey, err := ParsePemKey(privateCAKeyBytes) - if err != nil { - return fmt.Errorf("error parsing generated private_ca.key: %w", err) + if privateCACert == nil && privateCAKey == nil { + // If privateCACert and privateCAKey are both nil, we will generate a new one + privateCACertBytes, privateCAKeyBytes, err := GenerateNewCA(path.Join(outDir, "ca", "private_ca")) + if err != nil { + return fmt.Errorf("error creating private ca pair: %w", err) + } + privateCACert, err = ParsePemCertificate(privateCACertBytes) + if err != nil { + return fmt.Errorf("error parsing generated private_ca.crt: %w", err) + } + privateCAKey, err = ParsePemKey(privateCAKeyBytes) + if err != nil { + return fmt.Errorf("error parsing generated private_ca.key: %w", err) + } + } else if privateCACert == nil || privateCAKey == nil { + // If only one of them is nil, we can't continue + return errors.New("you must provide the CA cert and key if providing a CA, or set both to nil and a new CA will be generated") + } else { + // Must have non-nil values for both, so ensure the cert and key match + if ! CertMatchesPrivateKey(privateCACert, privateCAKey) { + return errors.New("provided private CA Cert and Key do not match") + } } + for _, node := range publicNodeNames { _, _, err = GenerateCASignedCert(chiaCACert, chiaCAKey, path.Join(outDir, node, fmt.Sprintf("public_%s", node))) if err != nil { @@ -103,6 +117,22 @@ func GenerateAllCerts(outDir string) error { return nil } +// CertMatchesPrivateKey tests to make the sure cert and private key match +func CertMatchesPrivateKey(cert *x509.Certificate, privateKey *rsa.PrivateKey) bool { + publicKey := &privateKey.PublicKey + + certPublicKey, ok := cert.PublicKey.(*rsa.PublicKey) + if !ok { + fmt.Println("Certificate public key is not of type RSA") + return false + } + + if publicKey.N.Cmp(certPublicKey.N) == 0 && publicKey.E == certPublicKey.E { + return true + } + return false +} + // ParsePemCertificate parses a certificate func ParsePemCertificate(certPem []byte) (*x509.Certificate, error) { // Load CA certificate