diff --git a/pkg/tls/tls.go b/pkg/tls/tls.go index d939532..030635c 100644 --- a/pkg/tls/tls.go +++ b/pkg/tls/tls.go @@ -100,17 +100,25 @@ func GenerateAllCerts(outDir string, privateCACert *x509.Certificate, privateCAK } for _, node := range publicNodeNames { - _, _, err = GenerateCASignedCert(chiaCACert, chiaCAKey, path.Join(outDir, node, fmt.Sprintf("public_%s", node))) + cert, key, err := GenerateCASignedCert(chiaCACert, chiaCAKey) if err != nil { return fmt.Errorf("error generating public pair for %s: %w", node, err) } + _, _, err = WriteCertAndKey(cert, key, path.Join(outDir, node, fmt.Sprintf("public_%s", node))) + if err != nil { + return fmt.Errorf("error writing public pair for %s: %w", node, err) + } } for _, node := range privateNodeNames { - _, _, err = GenerateCASignedCert(privateCACert, privateCAKey, path.Join(outDir, node, fmt.Sprintf("private_%s", node))) + cert, key, err := GenerateCASignedCert(privateCACert, privateCAKey) if err != nil { return fmt.Errorf("error generating private pair for %s: %w", node, err) } + _, _, err = WriteCertAndKey(cert, key, path.Join(outDir, node, fmt.Sprintf("private_%s", node))) + if err != nil { + return fmt.Errorf("error writing private pair for %s: %w", node, err) + } } return nil @@ -178,26 +186,36 @@ func ParsePemKey(keyPem []byte) (*rsa.PrivateKey, error) { return caKey, nil } +// EncodeCertAndKeyToPEM encodes the cert and key to PEM +func EncodeCertAndKeyToPEM(certDER []byte, certKey *rsa.PrivateKey) ([]byte, []byte, error) { + certPemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + + keyBytes, err := x509.MarshalPKCS8PrivateKey(certKey) + if err != nil { + return nil, nil, fmt.Errorf("error encoding private key to PKCS8: %w", err) + } + keyPemBytes := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: keyBytes}) + + return certPemBytes, keyPemBytes, nil +} + // WriteCertAndKey Returns the written cert bytes, key bytes, and error func WriteCertAndKey(certDER []byte, certKey *rsa.PrivateKey, certKeyBase string) ([]byte, []byte, error) { + certPemBytes, keyPemBytes, err := EncodeCertAndKeyToPEM(certDER, certKey) + if err != nil { + return nil, nil, fmt.Errorf("error encoding certificates: %w", err) + } + // Write the new certificate to file certOut := fmt.Sprintf("%s.crt", certKeyBase) - certPemBytes := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) if err := os.WriteFile(certOut, certPemBytes, 0600); err != nil { - return nil, nil, fmt.Errorf("failed to write cert PEM: %v", err) - } - - // Marshal private key to PKCS#8 - keyBytes, err := x509.MarshalPKCS8PrivateKey(certKey) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal private key to PKCS#8: %v", err) + return nil, nil, fmt.Errorf("failed to write cert PEM: %w", err) } // Write the new private key to file in PKCS#8 format keyOut := fmt.Sprintf("%s.key", certKeyBase) - keyPemBytes := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: keyBytes}) if err := os.WriteFile(keyOut, keyPemBytes, 0600); err != nil { - return nil, nil, fmt.Errorf("failed to write key PEM: %v", err) + return nil, nil, fmt.Errorf("failed to write key PEM: %w", err) } return certPemBytes, keyPemBytes, nil @@ -242,7 +260,7 @@ func GenerateNewCA(certKeyBase string) ([]byte, []byte, error) { } // GenerateCASignedCert generates a new key/cert signed by the given CA -func GenerateCASignedCert(caCert *x509.Certificate, caKey *rsa.PrivateKey, certKeyBase string) ([]byte, []byte, error) { +func GenerateCASignedCert(caCert *x509.Certificate, caKey *rsa.PrivateKey) ([]byte, *rsa.PrivateKey, error) { // Generate new private key certKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { @@ -274,5 +292,5 @@ func GenerateCASignedCert(caCert *x509.Certificate, caKey *rsa.PrivateKey, certK return nil, nil, fmt.Errorf("failed to create certificate: %v", err) } - return WriteCertAndKey(certDER, certKey, certKeyBase) + return certDER, certKey, nil }