diff --git a/.github/workflows/ci-lint.yaml b/.github/workflows/ci-lint.yaml new file mode 100644 index 0000000..1afeea9 --- /dev/null +++ b/.github/workflows/ci-lint.yaml @@ -0,0 +1,23 @@ +name: Lint +on: + push: + branches: + - master + - develop + pull_request: + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: v1.62 diff --git a/.github/workflows/ci-test.yaml b/.github/workflows/ci-test.yaml new file mode 100644 index 0000000..061e9de --- /dev/null +++ b/.github/workflows/ci-test.yaml @@ -0,0 +1,31 @@ +name: Test + +on: + push: + branches: + - master + - develop + pull_request: + +jobs: + test: + strategy: + matrix: + containers: [ 1.22 ] + runs-on: ubuntu-latest + container: golang:${{ matrix.containers }} + steps: + - name: Checkout code + uses: actions/checkout@v4 + - uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + /go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + - name: Update go modules + run: go mod tidy + - name: Unit Tests + run: go test -v -race -count=1 ./... \ No newline at end of file diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..625d4bc --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,71 @@ +linters-settings: + govet: + enable-all: true + revive: + confidence: 0.1 + rules: + - name: package-comments + disabled: true + goconst: + min-len: 2 + min-occurrences: 2 + misspell: + locale: US + lll: + line-length: 140 + gocritic: + enabled-tags: + - performance + - style + - experimental + disabled-checks: + - hugeParam + - commentedOutCode + gci: + sections: + - standard + - default + +linters: + enable: + - bodyclose + - revive + - govet + - unconvert + - gosec + - gocyclo + - dupl + - misspell + - unparam + - typecheck + - ineffassign + - stylecheck + - gochecknoinits + - gocritic + - nakedret + - gosimple + - prealloc + - gci + - errcheck + - gofmt + - goimports + - staticcheck + - unused + fast: false + disable-all: true + +issues: + exclude-rules: + - text: "at least one file in a package should have a package comment" + linters: + - stylecheck + - text: "should have a package comment, unless it's in another file for this package" + linters: + - revive + - text: "appendAssign: *" + linters: + - gocritic + - text: "fieldalignment: struct with *" + linters: + - govet + exclude-use-default: false diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..4a6f6b6 --- /dev/null +++ b/Makefile @@ -0,0 +1,8 @@ +test: + go test -v -count=1 ./... + +lint: + golangci-lint run + +lint-fix: + golangci-lint run --fix diff --git a/README.md b/README.md index fe302f3..0d433da 100644 --- a/README.md +++ b/README.md @@ -1 +1,40 @@ # jose-primitives + +This library provides support for creating and parsing JWE (JSON Web Encryption) tokens using the `ECDH-1PU` key agreement protocol. The library is specifically designed to facilitate authenticated encryption (authcrypt) and supports generating a common key between participants using secure cryptographic methods. + +## Features +- **Key Agreement Protocol:** `ECDH-1PU` to derive a shared common key between participants. +- **Supported Curves:** `P-384` and `X25519` (as specified in the [DIDComm Messaging RFC](https://identity.foundation/didcomm-messaging/spec/)). +- **JWE Token Creation:** Supports `alg` and `enc` combinations such as: + - `ECDH-1PU+A256KW` for key agreement. + - `A256CBC-HS512` for content encryption. +- **JWE Token Parsing:** Parses JWE tokens in compressed format with the above `alg` and `enc` combinations. + +## Supported Algorithms + +### Key Agreement (`alg`) +| Algorithm | Description | +| ----------------- | ------------------------------------------ | +| `ECDH-1PU+A256KW` | Authenticated encryption with key wrapping | + +### Content Encryption (`enc`) +| Algorithm | Description | +| --------------- | ----------------------------- | +| `A256CBC-HS512` | AES-256 CBC with HMAC SHA-512 | + +## Supported Key Types +| Curve Name | Description | +| ------------ | ---------------------------- | +| `NIST P-384` | High-security elliptic curve | +| `X25519` | Modern, fast elliptic curve | + +## Limitations +- The library only supports JWE tokens created with: + - `alg`: `ECDH-1PU+A256KW` + - `enc`: `A256CBC-HS512` +- Parsing is restricted to JWE tokens in **compressed format**. +- Only `P-384` and `X25519` curves are supported. + +## License +This library is licensed under the MIT License. + diff --git a/a256cbc_hmac.go b/a256cbc_hmac.go new file mode 100644 index 0000000..cadfc0c --- /dev/null +++ b/a256cbc_hmac.go @@ -0,0 +1,283 @@ +package joseprimitives + +import ( + "crypto/aes" + "crypto/ecdh" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strings" + + josecipher "github.com/go-jose/go-jose/v4/cipher" +) + +const ( + // KeyEncryptionAlgorithm is the key encryption algorithm. + KeyEncryptionAlgorithm = "ECDH-1PU+A256KW" + // ContentEncryptionAlgorithm is the content encryption algorithm. + ContentEncryptionAlgorithm = "A256CBC-HS512" +) + +const ( + HeaderKeyAlg = "alg" + HeaderKeyEnc = "enc" + HeaderKeyApu = "apu" + HeaderKeyApv = "apv" + HeaderKeyEpk = "epk" + HeaderKeySkid = "skid" + HeaderKeyKid = "kid" +) + +type encryptionOption func(*encryptionOptions) + +type encryptionOptions struct { + kid string + skid string +} + +// WithKid sets the 'kid' option. +func WithKid(kid string) encryptionOption { + return func(opts *encryptionOptions) { + opts.kid = kid + } +} + +// WithSkid sets the 'skid' option. +func WithSkid(skid string) encryptionOption { + return func(opts *encryptionOptions) { + opts.skid = skid + } +} + +// Encrypt encrypts a plaintext using the ECDH-1PU+A256KW and A256CBC-HS512 algorithms. +func Encrypt(recipient *ecdh.PublicKey, sender *ecdh.PrivateKey, plaintext []byte, opts ...encryptionOption) (string, error) { + if recipient.Curve() != sender.Curve() { + return "", + fmt.Errorf( + "curve mismatch: recipient's curve '%s', sender's curve '%s'", + recipient.Curve(), sender.Curve(), + ) + } + + o := &encryptionOptions{} + for _, opt := range opts { + opt(o) + } + + var ( + epk *ecdh.PrivateKey + err error + ) + + switch recipient.Curve() { + case ecdh.X25519(): + epk, err = ecdh.X25519().GenerateKey(rand.Reader) + if err != nil { + return "", fmt.Errorf("failed to generate ephemeral key: %w", err) + } + case ecdh.P384(): + epk, err = ecdh.P384().GenerateKey(rand.Reader) + if err != nil { + return "", fmt.Errorf("failed to generate ephemeral key: %w", err) + } + default: + return "", fmt.Errorf("unsupported curve: '%s'", recipient.Curve()) + } + + kek, err := NewECDHPU1Key( + ZxKeyPair{p: epk, pub: recipient}, + ZxKeyPair{p: sender, pub: recipient}, + ) + if err != nil { + return "", fmt.Errorf("failed to key agreement: %w", err) + } + + cek := make([]byte, 64) + _, err = rand.Read(cek) + if err != nil { + return "", fmt.Errorf("failed to generate cek: %w", err) + } + nonce := make([]byte, aes.BlockSize) + _, err = rand.Read(nonce) + if err != nil { + return "", fmt.Errorf("failed to generate nonce: %w", err) + } + + encrypter, err := josecipher.NewCBCHMAC(cek, aes.NewCipher) + if err != nil { + return "", fmt.Errorf("failed to create encrypter: %w", err) + } + add, err := getHeaders( + o.skid, o.kid, recipient, sender, epk) + if err != nil { + return "", fmt.Errorf("failed to create headers: %w", err) + } + headersBytes, err := json.Marshal(add) + if err != nil { + return "", fmt.Errorf("failed to marshal headers: %w", err) + } + + ciphertext := encrypter.Seal(nil, nonce, plaintext, headersBytes) + if len(ciphertext) == 0 { + return "", errors.New("failed to encrypt plaintext") + } + + encryptedCek, err := kek.Wrap(cek) + if err != nil { + return "", fmt.Errorf("failed to wrap cek: %w", err) + } + + noAuthCiphertext, authTag, err := extractAuthTag(ciphertext, len(plaintext), aes.BlockSize, len(cek)/2) + if err != nil { + return "", fmt.Errorf("failed to extract auth tag: %w", err) + } + + compactToken := fmt.Sprintf( + "%s.%s.%s.%s.%s", + base64.URLEncoding.EncodeToString(headersBytes), + base64.URLEncoding.EncodeToString(encryptedCek), + base64.URLEncoding.EncodeToString(nonce), + base64.URLEncoding.EncodeToString(noAuthCiphertext), + base64.URLEncoding.EncodeToString(authTag), + ) + + return compactToken, nil +} + +// Decrypt decrypts a compact token. +func Decrypt(recipient *ecdh.PrivateKey, sender *ecdh.PublicKey, compactToken string) ([]byte, error) { + headersBytes, encryptedCek, nonce, ciphertext, authTag, err := parseCompactToken(compactToken) + if err != nil { + return nil, fmt.Errorf("failed to parse compact token: %w", err) + } + + headers := map[string]string{} + if err = json.Unmarshal(headersBytes, &headers); err != nil { + return nil, fmt.Errorf("failed to decode headers: %w", err) + } + + e, ok := headers[HeaderKeyEpk] + if !ok { + return nil, errors.New("epk not found in headers") + } + epkjwk := &JWK{} + if err = json.Unmarshal([]byte(e), epkjwk); err != nil { + return nil, fmt.Errorf("failed to unmarshal epk: %w", err) + } + + ephemeral, err := Export(epkjwk) + if err != nil { + return nil, fmt.Errorf("failed to export epk: %w", err) + } + + kek, err := NewECDHPU1Key(ZxKeyPair{p: recipient, pub: ephemeral}, ZxKeyPair{p: recipient, pub: sender}) + if err != nil { + return nil, fmt.Errorf("failed to key agreement: %w", err) + } + + cek, err := kek.Unwrap(encryptedCek) + if err != nil { + return nil, fmt.Errorf("failed to unwrap cek: %w", err) + } + + decrypter, err := josecipher.NewCBCHMAC(cek, aes.NewCipher) + if err != nil { + return nil, fmt.Errorf("failed to create decrypter: %w", err) + } + + ciphertext = append(ciphertext, authTag...) + plaintext, err := decrypter.Open(nil, nonce, ciphertext, headersBytes) + if err != nil { + return nil, fmt.Errorf("failed to decrypt ciphertext: %w", err) + } + + return plaintext, nil +} + +//nolint:gocritic // it's okay for the function to have many return statements +func parseCompactToken(compactToken string) (headers, encryptedCek, nonce, ciphertext, authTag []byte, err error) { + parts := strings.Split(compactToken, ".") + if len(parts) != 5 { + return nil, nil, nil, nil, nil, errors.New("invalid compact token") + } + + headers, err = base64.URLEncoding.DecodeString(parts[0]) + if err != nil { + return nil, nil, nil, nil, nil, fmt.Errorf("failed to decode headers: %w", err) + } + + encryptedCek, err = base64.URLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, nil, nil, nil, nil, fmt.Errorf("failed to decode encrypted cek: %w", err) + } + + nonce, err = base64.URLEncoding.DecodeString(parts[2]) + if err != nil { + return nil, nil, nil, nil, nil, fmt.Errorf("failed to decode nonce: %w", err) + } + + ciphertext, err = base64.URLEncoding.DecodeString(parts[3]) + if err != nil { + return nil, nil, nil, nil, nil, fmt.Errorf("failed to decode ciphertext: %w", err) + } + + authTag, err = base64.URLEncoding.DecodeString(parts[4]) + if err != nil { + return nil, nil, nil, nil, nil, fmt.Errorf("failed to decode auth tag: %w", err) + } + + return headers, encryptedCek, nonce, ciphertext, authTag, nil +} + +func getHeaders( + skid, kid string, + recipient *ecdh.PublicKey, + sender *ecdh.PrivateKey, + epk *ecdh.PrivateKey, +) (map[string]string, error) { + epkjwk, err := Import(epk.PublicKey()) + if err != nil { + return nil, fmt.Errorf("failed to import epk to jwt: %w", err) + } + epkstr, err := json.Marshal(epkjwk) + if err != nil { + return nil, fmt.Errorf("failed to encode epk: %w", err) + } + + apuBytes := append(epk.PublicKey().Bytes(), sender.PublicKey().Bytes()...) + apuHash := sha256.Sum256(apuBytes) + apvHash := sha256.Sum256(recipient.Bytes()) + + headers := map[string]string{} + headers[HeaderKeyAlg] = KeyEncryptionAlgorithm + headers[HeaderKeyEnc] = ContentEncryptionAlgorithm + headers[HeaderKeyApu] = base64.URLEncoding.EncodeToString(apuHash[:]) + headers[HeaderKeyApv] = base64.URLEncoding.EncodeToString(apvHash[:]) + headers[HeaderKeyEpk] = string(epkstr) + + if skid != "" { + headers[HeaderKeySkid] = skid + } + if kid != "" { + headers[HeaderKeyKid] = kid + } + + return headers, nil +} + +func extractAuthTag(ciphertextWithAuthTag []byte, plaintextLength, blockSize, authTagLength int) ( + ciphertext []byte, authTag []byte, err error) { + paddedLength := (plaintextLength + blockSize - 1) / blockSize * blockSize + + if len(ciphertextWithAuthTag) < paddedLength+authTagLength { + return nil, nil, errors.New("invalid ciphertext length") + } + + ciphertext = ciphertextWithAuthTag[:paddedLength] + authTag = ciphertextWithAuthTag[paddedLength : paddedLength+authTagLength] + + return ciphertext, authTag, nil +} diff --git a/a256cbc_hmac_test.go b/a256cbc_hmac_test.go new file mode 100644 index 0000000..240ec36 --- /dev/null +++ b/a256cbc_hmac_test.go @@ -0,0 +1,92 @@ +package joseprimitives + +import ( + "crypto/ecdh" + "crypto/rand" + "encoding/base64" + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func mustGenerateKey(t *testing.T, c ecdh.Curve) *ecdh.PrivateKey { + var ( + priv *ecdh.PrivateKey + err error + ) + + switch c { + case ecdh.P256(): + priv, err = ecdh.P256().GenerateKey(rand.Reader) + case ecdh.P384(): + priv, err = ecdh.P384().GenerateKey(rand.Reader) + case ecdh.P521(): + priv, err = ecdh.P521().GenerateKey(rand.Reader) + case ecdh.X25519(): + priv, err = ecdh.X25519().GenerateKey(rand.Reader) + default: + require.Fail(t, "unsupported curve") + } + require.NoError(t, err) + return priv +} + +func TestEncryptDecryptPxx(t *testing.T) { + tests := []struct { + name string + recipient *ecdh.PrivateKey + sender *ecdh.PrivateKey + plaintext string + encriptionOptions []encryptionOption + expectedHeaders map[string]interface{} + }{ + { + name: "Valid encryption and decryption: P-384", + recipient: mustGenerateKey(t, ecdh.P384()), + sender: mustGenerateKey(t, ecdh.P384()), + plaintext: "plaintext", + encriptionOptions: []encryptionOption{ + WithKid("kid"), + WithSkid("skid"), + }, + expectedHeaders: map[string]interface{}{ + HeaderKeyKid: "kid", + HeaderKeySkid: "skid", + }, + }, + { + name: "Valid encryption and decryption: x25519", + recipient: mustGenerateKey(t, ecdh.X25519()), + sender: mustGenerateKey(t, ecdh.X25519()), + plaintext: "plaintext", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + jweToken, err := Encrypt( + tt.recipient.PublicKey(), tt.sender, []byte(tt.plaintext), tt.encriptionOptions...) + require.NoError(t, err) + + h := decodeHeaders(t, jweToken) + require.Equal(t, tt.expectedHeaders[HeaderKeyKid], h[HeaderKeyKid]) + require.Equal(t, tt.expectedHeaders[HeaderKeySkid], h[HeaderKeySkid]) + + raw, err := Decrypt(tt.recipient, tt.sender.PublicKey(), jweToken) + require.NoError(t, err) + require.Equal(t, tt.plaintext, string(raw)) + }) + } +} + +func decodeHeaders(t *testing.T, token string) map[string]interface{} { + var h map[string]interface{} + parts := strings.Split(token, ".") + require.Len(t, parts, 5) + headersBytes, err := base64.URLEncoding.DecodeString(parts[0]) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(headersBytes, &h)) + return h +} diff --git a/ecdhpu1.go b/ecdhpu1.go new file mode 100644 index 0000000..a8147dd --- /dev/null +++ b/ecdhpu1.go @@ -0,0 +1,75 @@ +package joseprimitives + +import ( + "crypto" + "crypto/aes" + "crypto/ecdh" + "fmt" + + josecipher "github.com/go-jose/go-jose/v4/cipher" +) + +// ZxKeyPair is a pair of ECDH private and public keys. +type ZxKeyPair struct { + p *ecdh.PrivateKey + pub *ecdh.PublicKey +} + +// NewZxKeyPair creates a new ZxKeyPair. +func NewZxKeyPair(p *ecdh.PrivateKey, pub *ecdh.PublicKey) ZxKeyPair { + return ZxKeyPair{p: p, pub: pub} +} + +// ECDH generates a shared secret using the ECDH algorithm. +func (z ZxKeyPair) ECDH() ([]byte, error) { + zx, err := z.p.ECDH(z.pub) + if err != nil { + return nil, fmt.Errorf("failed to generate shared secret: %w", err) + } + return zx, nil +} + +// ECDHPU1Key is a key for the ECDH-1PU+A256KW algorithm. +type ECDHPU1Key struct { + kek []byte +} + +// Wrap wraps a content encryption key (CEK) using the key encryption key (KEK). +func (k *ECDHPU1Key) Wrap(cek []byte) ([]byte, error) { + b, err := aes.NewCipher(k.kek) + if err != nil { + return nil, fmt.Errorf("failed to create new cipher: %w", err) + } + return josecipher.KeyWrap(b, cek) +} + +// Unwrap unwraps a content encryption key (CEK) using the key encryption key (KEK). +func (k *ECDHPU1Key) Unwrap(cek []byte) ([]byte, error) { + b, err := aes.NewCipher(k.kek) + if err != nil { + return nil, fmt.Errorf("failed to create new cipher: %w", err) + } + return josecipher.KeyUnwrap(b, cek) +} + +// NewECDHPU1Key creates a new ECDHPU1Key. +func NewECDHPU1Key(zeKeyPair, zsKeyPair ZxKeyPair) (*ECDHPU1Key, error) { + ze, err := zeKeyPair.ECDH() + if err != nil { + return nil, fmt.Errorf("failed to generate shared ze secret: %w", err) + } + zs, err := zsKeyPair.ECDH() + if err != nil { + return nil, fmt.Errorf("failed to generate shared zs secret: %w", err) + } + z := append(ze, zs...) + + empty := make([]byte, 0) + r := josecipher.NewConcatKDF(crypto.SHA256, z, empty, empty, empty, empty, empty) + kek := make([]byte, 32) + _, err = r.Read(kek) + if err != nil { + return nil, fmt.Errorf("failed to generate kek: %w", err) + } + return &ECDHPU1Key{kek}, nil +} diff --git a/ecdhpu1_test.go b/ecdhpu1_test.go new file mode 100644 index 0000000..417c697 --- /dev/null +++ b/ecdhpu1_test.go @@ -0,0 +1,92 @@ +package joseprimitives + +import ( + "crypto/ecdh" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewECDHPU1Key(t *testing.T) { + type testCase struct { + name string + senderSide [2]ZxKeyPair + recipientSide [2]ZxKeyPair + cek []byte + } + + senderStaticKeyNist := mustGenerateKey(t, ecdh.P384()) + recipientStaticKeyNist := mustGenerateKey(t, ecdh.P384()) + ephemeralKeyNist := mustGenerateKey(t, ecdh.P384()) + + senderStaticKeyX25519 := mustGenerateKey(t, ecdh.X25519()) + recipientStaticKeyX25519 := mustGenerateKey(t, ecdh.X25519()) + ephemeralKeyX25519 := mustGenerateKey(t, ecdh.X25519()) + + testCases := []testCase{ + { + name: "Valid keys, should encrypt and decrypt correctly. Nist curves", + senderSide: [2]ZxKeyPair{ + { + p: ephemeralKeyNist, + pub: recipientStaticKeyNist.PublicKey(), + }, + { + p: senderStaticKeyNist, + pub: recipientStaticKeyNist.PublicKey(), + }, + }, + recipientSide: [2]ZxKeyPair{ + { + p: recipientStaticKeyNist, + pub: ephemeralKeyNist.PublicKey(), + }, + { + p: recipientStaticKeyNist, + pub: senderStaticKeyNist.PublicKey(), + }, + }, + cek: []byte("1234567890123456"), + }, + { + name: "Valid keys, should encrypt and decrypt correctly. X25519 curves", + senderSide: [2]ZxKeyPair{ + { + p: ephemeralKeyX25519, + pub: recipientStaticKeyX25519.PublicKey(), + }, + { + p: senderStaticKeyX25519, + pub: recipientStaticKeyX25519.PublicKey(), + }, + }, + recipientSide: [2]ZxKeyPair{ + { + p: recipientStaticKeyX25519, + pub: ephemeralKeyX25519.PublicKey(), + }, + { + p: recipientStaticKeyX25519, + pub: senderStaticKeyX25519.PublicKey(), + }, + }, + cek: []byte("1234567890123456"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + senderSideKek, err := NewECDHPU1Key(tc.senderSide[0], tc.senderSide[1]) + require.NoError(t, err) + encryptedCek, err := senderSideKek.Wrap(tc.cek) + require.NoError(t, err) + + userSideKek, err := NewECDHPU1Key(tc.recipientSide[0], tc.recipientSide[1]) + require.NoError(t, err) + decryptedCek, err := userSideKek.Unwrap(encryptedCek) + require.NoError(t, err) + + require.Equal(t, tc.cek, decryptedCek) + }) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..5765bb2 --- /dev/null +++ b/go.mod @@ -0,0 +1,16 @@ +module github.com/iden3/jose-primitives + +go 1.22.6 + +toolchain go1.22.10 + +require ( + github.com/go-jose/go-jose/v4 v4.0.4 + github.com/stretchr/testify v1.9.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..d0aa611 --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-jose/go-jose/v4 v4.0.4 h1:VsjPI33J0SB9vQM6PLmNjoHqMQNGPiZ0rHL7Ni7Q6/E= +github.com/go-jose/go-jose/v4 v4.0.4/go.mod h1:NKb5HO1EZccyMpiZNbdUw/14tiXNyUJh188dfnMCAfc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/jwk.go b/jwk.go new file mode 100644 index 0000000..17e0bfb --- /dev/null +++ b/jwk.go @@ -0,0 +1,129 @@ +// Description: This file contains the JWK struct and the Import function. +// We need this custom package because the go-jose and lestrrat-go/jwx/v3/jwk packages don't support ecdh.PublicKey +// in proper way. + +package joseprimitives + +import ( + "crypto/ecdh" + "crypto/elliptic" + "encoding/base64" + "errors" + "fmt" + "math/big" +) + +// JWK represents a JSON Web Key. +type JWK struct { + Kty string `json:"kty"` + Crv string `json:"crv"` + X string `json:"x"` + Y string `json:"y,omitempty"` +} + +// Import converts an ecdh.PublicKey to a JWK. +func Import(key *ecdh.PublicKey) (*JWK, error) { + switch key.Curve() { + case ecdh.X25519(): + return &JWK{ + Kty: "OKP", + Crv: fmt.Sprintf("%s", ecdh.X25519()), + X: base64.RawURLEncoding.EncodeToString(key.Bytes()), + }, nil + case ecdh.P256(), ecdh.P384(), ecdh.P521(): + c, err := convertCurve(key.Curve()) + if err != nil { + return nil, fmt.Errorf("failed to convert curve: %w", err) + } + //nolint:staticcheck // there is no another way to extract x and y from ecdh.PublicKey + x, y := elliptic.Unmarshal(c, key.Bytes()) + if x == nil || y == nil { + return nil, errors.New("invalid public key") + } + return &JWK{ + Kty: "EC", + Crv: c.Params().Name, + X: base64.RawURLEncoding.EncodeToString(x.Bytes()), + Y: base64.RawURLEncoding.EncodeToString(y.Bytes()), + }, nil + default: + return nil, fmt.Errorf("unsupported curve: '%s'", key.Curve()) + } +} + +// Export converts a JWK to an ecdh.PublicKey. +func Export(jwk *JWK) (*ecdh.PublicKey, error) { + switch jwk.Kty { + case "OKP": + switch jwk.Crv { + case "X25519": + x, err := base64.RawURLEncoding.DecodeString(jwk.X) + if err != nil { + return nil, fmt.Errorf("failed to decode X25519: %w", err) + } + key, err := ecdh.X25519().NewPublicKey(x) + if err != nil { + return nil, fmt.Errorf("failed to parse X25519 public key: %w", err) + } + return key, nil + default: + return nil, fmt.Errorf("unsupported OKP curve: '%s'", jwk.Crv) + } + case "EC": + switch jwk.Crv { + case "P-256": + pubBytes, err := convertNistJWK(jwk.X, jwk.Y, ecdh.P256()) + if err != nil { + return nil, fmt.Errorf("failed convert JWK with NIST P256: %w", err) + } + return ecdh.P256().NewPublicKey(pubBytes) + case "P-384": + pubBytes, err := convertNistJWK(jwk.X, jwk.Y, ecdh.P384()) + if err != nil { + return nil, fmt.Errorf("failed convert JWK with NIST P384: %w", err) + } + return ecdh.P384().NewPublicKey(pubBytes) + case "P-521": + pubBytes, err := convertNistJWK(jwk.X, jwk.Y, ecdh.P521()) + if err != nil { + return nil, fmt.Errorf("failed convert JWK with NIST P521: %w", err) + } + return ecdh.P521().NewPublicKey(pubBytes) + default: + return nil, fmt.Errorf("unsupported EC curve: '%s'", jwk.Crv) + } + default: + return nil, fmt.Errorf("unsupported kty: '%s'", jwk.Kty) + } + +} + +func convertNistJWK(xBase64, yBase64 string, curve ecdh.Curve) ([]byte, error) { + x, err := base64.RawURLEncoding.DecodeString(xBase64) + if err != nil { + return nil, fmt.Errorf("failed to decode64 X: %w", err) + } + y, err := base64.RawURLEncoding.DecodeString(yBase64) + if err != nil { + return nil, fmt.Errorf("failed to decode64 Y: %w", err) + } + c, err := convertCurve(curve) + if err != nil { + return nil, fmt.Errorf("failed to convert curve: %w", err) + } + //nolint:staticcheck // there is no another way to build ecdh.PublicKey from x and y + pubBytes := elliptic.Marshal(c, big.NewInt(0).SetBytes(x), big.NewInt(0).SetBytes(y)) + return pubBytes, nil +} + +func convertCurve(c ecdh.Curve) (elliptic.Curve, error) { + switch c { + case ecdh.P256(): + return elliptic.P256(), nil + case ecdh.P384(): + return elliptic.P384(), nil + case ecdh.P521(): + return elliptic.P521(), nil + } + return nil, fmt.Errorf("unsupported curve: '%s'", c) +} diff --git a/jwk_test.go b/jwk_test.go new file mode 100644 index 0000000..2dda3f2 --- /dev/null +++ b/jwk_test.go @@ -0,0 +1,28 @@ +package joseprimitives + +import ( + "crypto/ecdh" + "crypto/rand" + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestImportExport(t *testing.T) { + p, err := ecdh.P384().GenerateKey(rand.Reader) + require.NoError(t, err) + jwk, err := Import(p.PublicKey()) + require.NoError(t, err) + require.NotNil(t, jwk) + + jsonJWK, err := json.Marshal(jwk) + require.NoError(t, err) + t.Logf("JWK: %s", jsonJWK) + + exported, err := Export(jwk) + require.NoError(t, err) + require.NotNil(t, exported) + + require.True(t, p.PublicKey().Equal(exported)) +}