From bbc488cce554e2c014b9ebb0716aef965cfe7ad1 Mon Sep 17 00:00:00 2001 From: ilya-korotya Date: Fri, 13 Dec 2024 18:16:21 +0100 Subject: [PATCH 1/6] draft: support alg ECDH-1PU+A256KW for jwe tokens --- .github/workflows/ci-lint.yaml | 23 +++ .github/workflows/ci-test.yaml | 31 ++++ .golangci.yml | 68 ++++++++ Makefile | 8 + README.md | 39 +++++ a256cbc_hmac.go | 309 +++++++++++++++++++++++++++++++++ a256cbc_hmac_test.go | 101 +++++++++++ ecdhpu1.go | 68 ++++++++ ecdhpu1_test.go | 92 ++++++++++ go.mod | 16 ++ go.sum | 12 ++ jwk.go | 128 ++++++++++++++ jwk_test.go | 28 +++ 13 files changed, 923 insertions(+) create mode 100644 .github/workflows/ci-lint.yaml create mode 100644 .github/workflows/ci-test.yaml create mode 100644 .golangci.yml create mode 100644 Makefile create mode 100644 a256cbc_hmac.go create mode 100644 a256cbc_hmac_test.go create mode 100644 ecdhpu1.go create mode 100644 ecdhpu1_test.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 jwk.go create mode 100644 jwk_test.go diff --git a/.github/workflows/ci-lint.yaml b/.github/workflows/ci-lint.yaml new file mode 100644 index 0000000..1afeea9 --- /dev/null +++ b/.github/workflows/ci-lint.yaml @@ -0,0 +1,23 @@ +name: Lint +on: + push: + branches: + - master + - develop + pull_request: + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: v1.62 diff --git a/.github/workflows/ci-test.yaml b/.github/workflows/ci-test.yaml new file mode 100644 index 0000000..9644212 --- /dev/null +++ b/.github/workflows/ci-test.yaml @@ -0,0 +1,31 @@ +name: Test + +on: + push: + branches: + - master + - develop + pull_request: + +jobs: + test: + strategy: + matrix: + containers: [ 1.22, 1.21, 1.20 ] + runs-on: ubuntu-latest + container: golang:${{ matrix.containers }} + steps: + - name: Checkout code + uses: actions/checkout@v4 + - uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + /go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + - name: Update go modules + run: go mod tidy + - name: Unit Tests + run: go test -v -race -count=1 ./... \ No newline at end of file diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..50d3ed8 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,68 @@ +linters-settings: + govet: + enable-all: true + revive: + confidence: 0.1 + rules: + - name: package-comments + disabled: true + goconst: + min-len: 2 + min-occurrences: 2 + misspell: + locale: US + lll: + line-length: 140 + gocritic: + enabled-tags: + - performance + - style + - experimental + disabled-checks: + - hugeParam + - commentedOutCode + gci: + sections: + - standard + - default + +linters: + enable: + - bodyclose + - revive + - govet + - unconvert + - gosec + - gocyclo + - dupl + - misspell + - unparam + - typecheck + - ineffassign + - stylecheck + - gochecknoinits + - gocritic + - nakedret + - gosimple + - prealloc + - gci + - errcheck + - gofmt + - goimports + - staticcheck + - unused + fast: false + disable-all: true + +issues: + exclude-rules: + - text: "at least one file in a package should have a package comment" + linters: + - stylecheck + - text: "should have a package comment, unless it's in another file for this package" + linters: + - revive + - text: "appendAssign: *" + linters: + - gocritic + exclude-use-default: false diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..4a6f6b6 --- /dev/null +++ b/Makefile @@ -0,0 +1,8 @@ +test: + go test -v -count=1 ./... + +lint: + golangci-lint run + +lint-fix: + golangci-lint run --fix diff --git a/README.md b/README.md index fe302f3..0d433da 100644 --- a/README.md +++ b/README.md @@ -1 +1,40 @@ # jose-primitives + +This library provides support for creating and parsing JWE (JSON Web Encryption) tokens using the `ECDH-1PU` key agreement protocol. The library is specifically designed to facilitate authenticated encryption (authcrypt) and supports generating a common key between participants using secure cryptographic methods. + +## Features +- **Key Agreement Protocol:** `ECDH-1PU` to derive a shared common key between participants. +- **Supported Curves:** `P-384` and `X25519` (as specified in the [DIDComm Messaging RFC](https://identity.foundation/didcomm-messaging/spec/)). +- **JWE Token Creation:** Supports `alg` and `enc` combinations such as: + - `ECDH-1PU+A256KW` for key agreement. + - `A256CBC-HS512` for content encryption. +- **JWE Token Parsing:** Parses JWE tokens in compressed format with the above `alg` and `enc` combinations. + +## Supported Algorithms + +### Key Agreement (`alg`) +| Algorithm | Description | +| ----------------- | ------------------------------------------ | +| `ECDH-1PU+A256KW` | Authenticated encryption with key wrapping | + +### Content Encryption (`enc`) +| Algorithm | Description | +| --------------- | ----------------------------- | +| `A256CBC-HS512` | AES-256 CBC with HMAC SHA-512 | + +## Supported Key Types +| Curve Name | Description | +| ------------ | ---------------------------- | +| `NIST P-384` | High-security elliptic curve | +| `X25519` | Modern, fast elliptic curve | + +## Limitations +- The library only supports JWE tokens created with: + - `alg`: `ECDH-1PU+A256KW` + - `enc`: `A256CBC-HS512` +- Parsing is restricted to JWE tokens in **compressed format**. +- Only `P-384` and `X25519` curves are supported. + +## License +This library is licensed under the MIT License. + diff --git a/a256cbc_hmac.go b/a256cbc_hmac.go new file mode 100644 index 0000000..d7417d6 --- /dev/null +++ b/a256cbc_hmac.go @@ -0,0 +1,309 @@ +package joseprimitives + +import ( + "crypto/aes" + "crypto/ecdh" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strings" + + josecipher "github.com/go-jose/go-jose/v4/cipher" +) + +type PrivateKeyResolver func(kid string) (*ecdh.PrivateKey, error) +type PublicKeyResolver func(kid string) (*ecdh.PublicKey, error) + +type Encrypter struct { + recipientResolver PublicKeyResolver + senderResolver PrivateKeyResolver +} + +func NewEncrypter( + recipientResolver PublicKeyResolver, + senderResolver PrivateKeyResolver, +) *Encrypter { + return &Encrypter{ + recipientResolver: recipientResolver, + senderResolver: senderResolver, + } +} + +func (e *Encrypter) Encrypt(recipientKid, senderKid string, plaintext []byte) (string, error) { + recipient, err := e.recipientResolver(recipientKid) + if err != nil { + return "", fmt.Errorf("failed to resolve recipient key: %w", err) + } + sender, err := e.senderResolver(senderKid) + if err != nil { + return "", fmt.Errorf("failed to resolve sender key: %w", err) + } + + if recipient.Curve() != sender.Curve() { + return "", + fmt.Errorf( + "curve mismatch: recipient's curve '%s', sender's curve '%s'", + recipient.Curve(), sender.Curve(), + ) + } + + var epk *ecdh.PrivateKey + switch recipient.Curve() { + case ecdh.X25519(): + epk, err = ecdh.X25519().GenerateKey(rand.Reader) + if err != nil { + return "", fmt.Errorf("failed to generate ephemeral key: %w", err) + } + case ecdh.P384(): + epk, err = ecdh.P384().GenerateKey(rand.Reader) + if err != nil { + return "", fmt.Errorf("failed to generate ephemeral key: %w", err) + } + default: + return "", fmt.Errorf("unsupported curve: '%s'", recipient.Curve()) + } + + kek, err := NewECDHPU1Key( + ZxKeyPair{p: epk, pub: recipient}, + ZxKeyPair{p: sender, pub: recipient}, + ) + if err != nil { + return "", fmt.Errorf("failed to key agreement: %w", err) + } + + cek := make([]byte, 64) + _, err = rand.Read(cek) + if err != nil { + return "", fmt.Errorf("failed to generate cek: %w", err) + } + nonce := make([]byte, aes.BlockSize) + _, err = rand.Read(nonce) + if err != nil { + return "", fmt.Errorf("failed to generate nonce: %w", err) + } + + encrypter, err := josecipher.NewCBCHMAC(cek, aes.NewCipher) + if err != nil { + return "", fmt.Errorf("failed to create encrypter: %w", err) + } + add, err := getHeaders( + senderKid, recipientKid, recipient, sender, epk) + if err != nil { + return "", fmt.Errorf("failed to create headers: %w", err) + } + headersBytes, err := json.Marshal(add) + if err != nil { + return "", fmt.Errorf("failed to marshal headers: %w", err) + } + + ciphertext := encrypter.Seal(nil, nonce, plaintext, headersBytes) + if len(ciphertext) == 0 { + return "", errors.New("failed to encrypt plaintext") + } + + encryptedCek, err := kek.Wrap(cek) + if err != nil { + return "", fmt.Errorf("failed to wrap cek: %w", err) + } + + noAuthCiphertext, authTag, err := extractAuthTag(ciphertext, len(plaintext), aes.BlockSize, len(cek)/2) + if err != nil { + return "", fmt.Errorf("failed to extract auth tag: %w", err) + } + + compactToken := fmt.Sprintf( + "%s.%s.%s.%s.%s", + base64.URLEncoding.EncodeToString(headersBytes), + base64.URLEncoding.EncodeToString(encryptedCek), + base64.URLEncoding.EncodeToString(nonce), + base64.URLEncoding.EncodeToString(noAuthCiphertext), + base64.URLEncoding.EncodeToString(authTag), + ) + + return compactToken, nil +} + +type decryptionOption func(*decrypterOptions) + +type decrypterOptions struct { + kid string + skid string +} + +// WithKid sets the 'kid' option. +func WithKid(kid string) decryptionOption { + return func(opts *decrypterOptions) { + opts.kid = kid + } +} + +// WithSkid sets the 'skid' option. +func WithSkid(skid string) decryptionOption { + return func(opts *decrypterOptions) { + opts.skid = skid + } +} + +type Decrypter struct { + recipientResolver PrivateKeyResolver + senderResolver PublicKeyResolver +} + +func NewDecrypter( + recipientResolver PrivateKeyResolver, + senderResolver PublicKeyResolver, +) *Decrypter { + return &Decrypter{ + recipientResolver: recipientResolver, + senderResolver: senderResolver, + } +} + +// Decrypt decrypts a compact token. +func (d *Decrypter) Decrypt(compactToken string, opts ...decryptionOption) ([]byte, error) { + headersBytes, encryptedCek, nonce, ciphertext, authTag, err := parseCompactToken(compactToken) + if err != nil { + return nil, fmt.Errorf("failed to parse compact token: %w", err) + } + + headers := map[string]string{} + if err = json.Unmarshal(headersBytes, &headers); err != nil { + return nil, fmt.Errorf("failed to decode headers: %w", err) + } + + o := &decrypterOptions{ + kid: headers["kid"], + skid: headers["skid"], + } + for _, opt := range opts { + opt(o) + } + + e, ok := headers["epk"] + if !ok { + return nil, errors.New("epk not found in headers") + } + epkjwk := &JWK{} + if err = json.Unmarshal([]byte(e), epkjwk); err != nil { + return nil, fmt.Errorf("failed to unmarshal epk: %w", err) + } + + ephemeral, err := Export(epkjwk) + if err != nil { + return nil, fmt.Errorf("failed to export epk: %w", err) + } + recipient, err := d.recipientResolver(o.kid) + if err != nil { + return nil, fmt.Errorf("failed to resolve recipient: %w", err) + } + sender, err := d.senderResolver(o.skid) + if err != nil { + return nil, fmt.Errorf("failed to resolve sender: %w", err) + } + + kek, err := NewECDHPU1Key(ZxKeyPair{p: recipient, pub: ephemeral}, ZxKeyPair{p: recipient, pub: sender}) + if err != nil { + return nil, fmt.Errorf("failed to key agreement: %w", err) + } + + cek, err := kek.Unwrap(encryptedCek) + if err != nil { + return nil, fmt.Errorf("failed to unwrap cek: %w", err) + } + + decrypter, err := josecipher.NewCBCHMAC(cek, aes.NewCipher) + if err != nil { + return nil, fmt.Errorf("failed to create decrypter: %w", err) + } + + ciphertext = append(ciphertext, authTag...) + plaintext, err := decrypter.Open(nil, nonce, ciphertext, headersBytes) + if err != nil { + return nil, fmt.Errorf("failed to decrypt ciphertext: %w", err) + } + + return plaintext, nil +} + +//nolint:gocritic // it's okay for the function to have many return statements +func parseCompactToken(compactToken string) (headers, encryptedCek, nonce, ciphertext, authTag []byte, err error) { + parts := strings.Split(compactToken, ".") + if len(parts) != 5 { + return nil, nil, nil, nil, nil, errors.New("invalid compact token") + } + + headers, err = base64.URLEncoding.DecodeString(parts[0]) + if err != nil { + return nil, nil, nil, nil, nil, fmt.Errorf("failed to decode headers: %w", err) + } + + encryptedCek, err = base64.URLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, nil, nil, nil, nil, fmt.Errorf("failed to decode encrypted cek: %w", err) + } + + nonce, err = base64.URLEncoding.DecodeString(parts[2]) + if err != nil { + return nil, nil, nil, nil, nil, fmt.Errorf("failed to decode nonce: %w", err) + } + + ciphertext, err = base64.URLEncoding.DecodeString(parts[3]) + if err != nil { + return nil, nil, nil, nil, nil, fmt.Errorf("failed to decode ciphertext: %w", err) + } + + authTag, err = base64.URLEncoding.DecodeString(parts[4]) + if err != nil { + return nil, nil, nil, nil, nil, fmt.Errorf("failed to decode auth tag: %w", err) + } + + return headers, encryptedCek, nonce, ciphertext, authTag, nil +} + +func getHeaders( + skid, kid string, + recipient *ecdh.PublicKey, + sender *ecdh.PrivateKey, + epk *ecdh.PrivateKey, +) (map[string]string, error) { + epkjwk, err := Import(epk.PublicKey()) + if err != nil { + return nil, fmt.Errorf("failed to import epk to jwt: %w", err) + } + epkstr, err := json.Marshal(epkjwk) + if err != nil { + return nil, fmt.Errorf("failed to encode epk: %w", err) + } + + apuBytes := append(epk.PublicKey().Bytes(), sender.PublicKey().Bytes()...) + apuHash := sha256.Sum256(apuBytes) + apvHash := sha256.Sum256(recipient.Bytes()) + + headers := map[string]string{} + headers["alg"] = "ECDH-1PU+A256KW" + headers["enc"] = "A256CBC-HS512" + headers["apu"] = base64.URLEncoding.EncodeToString(apuHash[:]) + headers["apv"] = base64.URLEncoding.EncodeToString(apvHash[:]) + headers["epk"] = string(epkstr) + headers["skid"] = skid + headers["kid"] = kid + + return headers, nil +} + +func extractAuthTag(ciphertextWithAuthTag []byte, plaintextLength, blockSize, authTagLength int) ( + ciphertext []byte, authTag []byte, err error) { + paddedLength := (plaintextLength + blockSize - 1) / blockSize * blockSize + + if len(ciphertextWithAuthTag) < paddedLength+authTagLength { + return nil, nil, errors.New("invalid ciphertext length") + } + + ciphertext = ciphertextWithAuthTag[:paddedLength] + authTag = ciphertextWithAuthTag[paddedLength : paddedLength+authTagLength] + + return ciphertext, authTag, nil +} diff --git a/a256cbc_hmac_test.go b/a256cbc_hmac_test.go new file mode 100644 index 0000000..016c8ee --- /dev/null +++ b/a256cbc_hmac_test.go @@ -0,0 +1,101 @@ +package joseprimitives + +import ( + "crypto/ecdh" + "crypto/rand" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func privateKeyresolver(priv *ecdh.PrivateKey, kidclosure string) func(string) (*ecdh.PrivateKey, error) { + return func(kid string) (*ecdh.PrivateKey, error) { + if kid != kidclosure { + return nil, fmt.Errorf("kid '%s' not found", kid) + } + return priv, nil + } +} + +func publicKeyResolver(pub *ecdh.PublicKey, kidclosure string) func(string) (*ecdh.PublicKey, error) { + return func(kid string) (*ecdh.PublicKey, error) { + if kid != kidclosure { + return nil, fmt.Errorf("kid '%s' not found", kid) + } + return pub, nil + } +} + +func mustGenerateKey(t *testing.T, c ecdh.Curve) *ecdh.PrivateKey { + var ( + priv *ecdh.PrivateKey + err error + ) + + switch c { + case ecdh.P256(): + priv, err = ecdh.P256().GenerateKey(rand.Reader) + case ecdh.P384(): + priv, err = ecdh.P384().GenerateKey(rand.Reader) + case ecdh.P521(): + priv, err = ecdh.P521().GenerateKey(rand.Reader) + case ecdh.X25519(): + priv, err = ecdh.X25519().GenerateKey(rand.Reader) + default: + require.Fail(t, "unsupported curve") + } + require.NoError(t, err) + return priv +} + +func TestEncryptDecryptPxx(t *testing.T) { + tests := []struct { + name string + recipient *ecdh.PrivateKey + sender *ecdh.PrivateKey + plaintext string + encryptExpectErr bool + decryptExpectErr bool + }{ + { + name: "Valid encryption and decryption: P-384", + recipient: mustGenerateKey(t, ecdh.P384()), + sender: mustGenerateKey(t, ecdh.P384()), + plaintext: "plaintext", + encryptExpectErr: false, + decryptExpectErr: false, + }, + { + name: "Valid encryption and decryption: x25519", + recipient: mustGenerateKey(t, ecdh.X25519()), + sender: mustGenerateKey(t, ecdh.X25519()), + plaintext: "plaintext", + encryptExpectErr: false, + decryptExpectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + encrypter := NewEncrypter( + publicKeyResolver(tt.recipient.PublicKey(), "did:recipient"), + privateKeyresolver(tt.sender, "did:sender"), + ) + jweToken, err := encrypter.Encrypt("did:recipient", "did:sender", []byte(tt.plaintext)) + if tt.encryptExpectErr { + require.Error(t, err) + return + } + require.NoError(t, err) + + decrypter := NewDecrypter( + privateKeyresolver(tt.recipient, "did:recipient"), + publicKeyResolver(tt.sender.PublicKey(), "did:sender"), + ) + raw, err := decrypter.Decrypt(jweToken) + require.NoError(t, err) + require.Equal(t, tt.plaintext, string(raw)) + }) + } +} diff --git a/ecdhpu1.go b/ecdhpu1.go new file mode 100644 index 0000000..0397989 --- /dev/null +++ b/ecdhpu1.go @@ -0,0 +1,68 @@ +package joseprimitives + +import ( + "crypto" + "crypto/aes" + "crypto/ecdh" + "fmt" + + josecipher "github.com/go-jose/go-jose/v4/cipher" +) + +type ZxKeyPair struct { + p *ecdh.PrivateKey + pub *ecdh.PublicKey +} + +func NewZxKeyPair(p *ecdh.PrivateKey, pub *ecdh.PublicKey) ZxKeyPair { + return ZxKeyPair{p: p, pub: pub} +} + +func (z ZxKeyPair) ECDH() ([]byte, error) { + zx, err := z.p.ECDH(z.pub) + if err != nil { + return nil, fmt.Errorf("failed to generate shared secret: %w", err) + } + return zx, nil +} + +type ECDHPU1Key struct { + kek []byte +} + +func (k *ECDHPU1Key) Wrap(cek []byte) ([]byte, error) { + b, err := aes.NewCipher(k.kek) + if err != nil { + return nil, fmt.Errorf("failed to create new cipher: %w", err) + } + return josecipher.KeyWrap(b, cek) +} + +func (k *ECDHPU1Key) Unwrap(cek []byte) ([]byte, error) { + b, err := aes.NewCipher(k.kek) + if err != nil { + return nil, fmt.Errorf("failed to create new cipher: %w", err) + } + return josecipher.KeyUnwrap(b, cek) +} + +func NewECDHPU1Key(zeKeyPair, zsKeyPair ZxKeyPair) (*ECDHPU1Key, error) { + ze, err := zeKeyPair.ECDH() + if err != nil { + return nil, fmt.Errorf("failed to generate shared ze secret: %w", err) + } + zs, err := zsKeyPair.ECDH() + if err != nil { + return nil, fmt.Errorf("failed to generate shared zs secret: %w", err) + } + z := append(ze, zs...) + + empty := make([]byte, 0) + r := josecipher.NewConcatKDF(crypto.SHA256, z, empty, empty, empty, empty, empty) + kek := make([]byte, 32) + _, err = r.Read(kek) + if err != nil { + return nil, fmt.Errorf("failed to generate kek: %w", err) + } + return &ECDHPU1Key{kek}, nil +} diff --git a/ecdhpu1_test.go b/ecdhpu1_test.go new file mode 100644 index 0000000..417c697 --- /dev/null +++ b/ecdhpu1_test.go @@ -0,0 +1,92 @@ +package joseprimitives + +import ( + "crypto/ecdh" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewECDHPU1Key(t *testing.T) { + type testCase struct { + name string + senderSide [2]ZxKeyPair + recipientSide [2]ZxKeyPair + cek []byte + } + + senderStaticKeyNist := mustGenerateKey(t, ecdh.P384()) + recipientStaticKeyNist := mustGenerateKey(t, ecdh.P384()) + ephemeralKeyNist := mustGenerateKey(t, ecdh.P384()) + + senderStaticKeyX25519 := mustGenerateKey(t, ecdh.X25519()) + recipientStaticKeyX25519 := mustGenerateKey(t, ecdh.X25519()) + ephemeralKeyX25519 := mustGenerateKey(t, ecdh.X25519()) + + testCases := []testCase{ + { + name: "Valid keys, should encrypt and decrypt correctly. Nist curves", + senderSide: [2]ZxKeyPair{ + { + p: ephemeralKeyNist, + pub: recipientStaticKeyNist.PublicKey(), + }, + { + p: senderStaticKeyNist, + pub: recipientStaticKeyNist.PublicKey(), + }, + }, + recipientSide: [2]ZxKeyPair{ + { + p: recipientStaticKeyNist, + pub: ephemeralKeyNist.PublicKey(), + }, + { + p: recipientStaticKeyNist, + pub: senderStaticKeyNist.PublicKey(), + }, + }, + cek: []byte("1234567890123456"), + }, + { + name: "Valid keys, should encrypt and decrypt correctly. X25519 curves", + senderSide: [2]ZxKeyPair{ + { + p: ephemeralKeyX25519, + pub: recipientStaticKeyX25519.PublicKey(), + }, + { + p: senderStaticKeyX25519, + pub: recipientStaticKeyX25519.PublicKey(), + }, + }, + recipientSide: [2]ZxKeyPair{ + { + p: recipientStaticKeyX25519, + pub: ephemeralKeyX25519.PublicKey(), + }, + { + p: recipientStaticKeyX25519, + pub: senderStaticKeyX25519.PublicKey(), + }, + }, + cek: []byte("1234567890123456"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + senderSideKek, err := NewECDHPU1Key(tc.senderSide[0], tc.senderSide[1]) + require.NoError(t, err) + encryptedCek, err := senderSideKek.Wrap(tc.cek) + require.NoError(t, err) + + userSideKek, err := NewECDHPU1Key(tc.recipientSide[0], tc.recipientSide[1]) + require.NoError(t, err) + decryptedCek, err := userSideKek.Unwrap(encryptedCek) + require.NoError(t, err) + + require.Equal(t, tc.cek, decryptedCek) + }) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..5765bb2 --- /dev/null +++ b/go.mod @@ -0,0 +1,16 @@ +module github.com/iden3/jose-primitives + +go 1.22.6 + +toolchain go1.22.10 + +require ( + github.com/go-jose/go-jose/v4 v4.0.4 + github.com/stretchr/testify v1.9.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..d0aa611 --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-jose/go-jose/v4 v4.0.4 h1:VsjPI33J0SB9vQM6PLmNjoHqMQNGPiZ0rHL7Ni7Q6/E= +github.com/go-jose/go-jose/v4 v4.0.4/go.mod h1:NKb5HO1EZccyMpiZNbdUw/14tiXNyUJh188dfnMCAfc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/jwk.go b/jwk.go new file mode 100644 index 0000000..84a08a4 --- /dev/null +++ b/jwk.go @@ -0,0 +1,128 @@ +// Description: This file contains the JWK struct and the Import function. +// We need this custom package because the go-jose and lestrrat-go/jwx/v3/jwk packages don't support ecdh.PublicKey +// in proper way. + +package joseprimitives + +import ( + "crypto/ecdh" + "crypto/elliptic" + "encoding/base64" + "errors" + "fmt" + "math/big" +) + +// JWK represents a JSON Web Key. +type JWK struct { + Kty string `json:"kty"` + Crv string `json:"crv"` + X string `json:"x"` + Y string `json:"y,omitempty"` +} + +// Import converts an ecdh.PublicKey to a JWK. +func Import(key *ecdh.PublicKey) (*JWK, error) { + switch key.Curve() { + case ecdh.X25519(): + return &JWK{ + Kty: "OKP", + Crv: fmt.Sprintf("%s", ecdh.X25519()), + X: base64.RawURLEncoding.EncodeToString(key.Bytes()), + }, nil + case ecdh.P256(), ecdh.P384(), ecdh.P521(): + c, err := convertCurve(key.Curve()) + if err != nil { + return nil, fmt.Errorf("failed to convert curve: %w", err) + } + //nolint:staticcheck // there is no another way to extract x and y from ecdh.PublicKey + x, y := elliptic.Unmarshal(c, key.Bytes()) + if x == nil || y == nil { + return nil, errors.New("invalid public key") + } + return &JWK{ + Kty: "EC", + Crv: c.Params().Name, + X: base64.RawURLEncoding.EncodeToString(x.Bytes()), + Y: base64.RawURLEncoding.EncodeToString(y.Bytes()), + }, nil + default: + return nil, fmt.Errorf("unsupported curve: '%s'", key.Curve()) + } +} + +func Export(jwk *JWK) (*ecdh.PublicKey, error) { + switch jwk.Kty { + case "OKP": + switch jwk.Crv { + case "X25519": + x, err := base64.RawURLEncoding.DecodeString(jwk.X) + if err != nil { + return nil, fmt.Errorf("failed to decode X25519: %w", err) + } + key, err := ecdh.X25519().NewPublicKey(x) + if err != nil { + return nil, fmt.Errorf("failed to parse X25519 public key: %w", err) + } + return key, nil + default: + return nil, fmt.Errorf("unsupported OKP curve: '%s'", jwk.Crv) + } + case "EC": + switch jwk.Crv { + case "P-256": + pubBytes, err := convertNistJWK(jwk.X, jwk.Y, ecdh.P256()) + if err != nil { + return nil, fmt.Errorf("failed convert JWK with NIST P256: %w", err) + } + return ecdh.P256().NewPublicKey(pubBytes) + case "P-384": + pubBytes, err := convertNistJWK(jwk.X, jwk.Y, ecdh.P384()) + if err != nil { + return nil, fmt.Errorf("failed convert JWK with NIST P384: %w", err) + } + return ecdh.P384().NewPublicKey(pubBytes) + case "P-521": + pubBytes, err := convertNistJWK(jwk.X, jwk.Y, ecdh.P521()) + if err != nil { + return nil, fmt.Errorf("failed convert JWK with NIST P521: %w", err) + } + return ecdh.P521().NewPublicKey(pubBytes) + default: + return nil, fmt.Errorf("unsupported EC curve: '%s'", jwk.Crv) + } + default: + return nil, fmt.Errorf("unsupported kty: '%s'", jwk.Kty) + } + +} + +func convertNistJWK(xBase64, yBase64 string, curve ecdh.Curve) ([]byte, error) { + x, err := base64.RawURLEncoding.DecodeString(xBase64) + if err != nil { + return nil, fmt.Errorf("failed to decode64 X: %w", err) + } + y, err := base64.RawURLEncoding.DecodeString(yBase64) + if err != nil { + return nil, fmt.Errorf("failed to decode64 Y: %w", err) + } + c, err := convertCurve(curve) + if err != nil { + return nil, fmt.Errorf("failed to convert curve: %w", err) + } + //nolint:staticcheck // there is no another way to build ecdh.PublicKey from x and y + pubBytes := elliptic.Marshal(c, big.NewInt(0).SetBytes(x), big.NewInt(0).SetBytes(y)) + return pubBytes, nil +} + +func convertCurve(c ecdh.Curve) (elliptic.Curve, error) { + switch c { + case ecdh.P256(): + return elliptic.P256(), nil + case ecdh.P384(): + return elliptic.P384(), nil + case ecdh.P521(): + return elliptic.P521(), nil + } + return nil, fmt.Errorf("unsupported curve: '%s'", c) +} diff --git a/jwk_test.go b/jwk_test.go new file mode 100644 index 0000000..2dda3f2 --- /dev/null +++ b/jwk_test.go @@ -0,0 +1,28 @@ +package joseprimitives + +import ( + "crypto/ecdh" + "crypto/rand" + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestImportExport(t *testing.T) { + p, err := ecdh.P384().GenerateKey(rand.Reader) + require.NoError(t, err) + jwk, err := Import(p.PublicKey()) + require.NoError(t, err) + require.NotNil(t, jwk) + + jsonJWK, err := json.Marshal(jwk) + require.NoError(t, err) + t.Logf("JWK: %s", jsonJWK) + + exported, err := Export(jwk) + require.NoError(t, err) + require.NotNil(t, exported) + + require.True(t, p.PublicKey().Equal(exported)) +} From 995efe6525e5aa13e2547e2e93a676262a73810e Mon Sep 17 00:00:00 2001 From: ilya-korotya Date: Fri, 13 Dec 2024 18:29:26 +0100 Subject: [PATCH 2/6] support only latest golang version --- .github/workflows/ci-test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-test.yaml b/.github/workflows/ci-test.yaml index 9644212..061e9de 100644 --- a/.github/workflows/ci-test.yaml +++ b/.github/workflows/ci-test.yaml @@ -11,7 +11,7 @@ jobs: test: strategy: matrix: - containers: [ 1.22, 1.21, 1.20 ] + containers: [ 1.22 ] runs-on: ubuntu-latest container: golang:${{ matrix.containers }} steps: From 1b781d19dc966fbb8422c173b355d6f3208966f1 Mon Sep 17 00:00:00 2001 From: ilya-korotya Date: Wed, 8 Jan 2025 12:59:06 +0100 Subject: [PATCH 3/6] fix comments --- a256cbc_hmac.go | 47 +++++++++++++++++++++++++++++++++++++---------- ecdhpu1.go | 7 +++++++ jwk.go | 1 + 3 files changed, 45 insertions(+), 10 deletions(-) diff --git a/a256cbc_hmac.go b/a256cbc_hmac.go index d7417d6..8dcb220 100644 --- a/a256cbc_hmac.go +++ b/a256cbc_hmac.go @@ -14,14 +14,37 @@ import ( josecipher "github.com/go-jose/go-jose/v4/cipher" ) +const ( + // KeyEncryptionAlgorithm is the key encryption algorithm. + KeyEncryptionAlgorithm = "ECDH-1PU+A256KW" + // ContentEncryptionAlgorithm is the content encryption algorithm. + ContentEncryptionAlgorithm = "A256CBC-HS512" +) + +const ( + HeaderKeyAlg = "alg" + HeaderKeyEnc = "enc" + HeaderKeyApu = "apu" + HeaderKeyApv = "apv" + HeaderKeyEpk = "epk" + HeaderKeySkid = "skid" + HeaderKeyKid = "kid" +) + +// PrivateKeyResolver resolves a private key by its key ID. type PrivateKeyResolver func(kid string) (*ecdh.PrivateKey, error) + +// PublicKeyResolver resolves a public key by its key ID. type PublicKeyResolver func(kid string) (*ecdh.PublicKey, error) +// Encrypter encrypts plaintext using the ECDH-1PU+A256KW and A256CBC-HS512 algorithms. +// Supported curves are X25519 and P384. type Encrypter struct { recipientResolver PublicKeyResolver senderResolver PrivateKeyResolver } +// NewEncrypter creates a new Encrypter. func NewEncrypter( recipientResolver PublicKeyResolver, senderResolver PrivateKeyResolver, @@ -32,6 +55,7 @@ func NewEncrypter( } } +// Encrypt encrypts a plaintext using the ECDH-1PU+A256KW and A256CBC-HS512 algorithms. func (e *Encrypter) Encrypt(recipientKid, senderKid string, plaintext []byte) (string, error) { recipient, err := e.recipientResolver(recipientKid) if err != nil { @@ -147,11 +171,14 @@ func WithSkid(skid string) decryptionOption { } } +// Decrypter decrypts a compact token. +// Supported curves are X25519 and P384. type Decrypter struct { recipientResolver PrivateKeyResolver senderResolver PublicKeyResolver } +// NewDecrypter creates a new Decrypter. func NewDecrypter( recipientResolver PrivateKeyResolver, senderResolver PublicKeyResolver, @@ -175,14 +202,14 @@ func (d *Decrypter) Decrypt(compactToken string, opts ...decryptionOption) ([]by } o := &decrypterOptions{ - kid: headers["kid"], - skid: headers["skid"], + kid: headers[HeaderKeyKid], + skid: headers[HeaderKeySkid], } for _, opt := range opts { opt(o) } - e, ok := headers["epk"] + e, ok := headers[HeaderKeyEpk] if !ok { return nil, errors.New("epk not found in headers") } @@ -283,13 +310,13 @@ func getHeaders( apvHash := sha256.Sum256(recipient.Bytes()) headers := map[string]string{} - headers["alg"] = "ECDH-1PU+A256KW" - headers["enc"] = "A256CBC-HS512" - headers["apu"] = base64.URLEncoding.EncodeToString(apuHash[:]) - headers["apv"] = base64.URLEncoding.EncodeToString(apvHash[:]) - headers["epk"] = string(epkstr) - headers["skid"] = skid - headers["kid"] = kid + headers[HeaderKeyAlg] = KeyEncryptionAlgorithm + headers[HeaderKeyEnc] = ContentEncryptionAlgorithm + headers[HeaderKeyApu] = base64.URLEncoding.EncodeToString(apuHash[:]) + headers[HeaderKeyApv] = base64.URLEncoding.EncodeToString(apvHash[:]) + headers[HeaderKeyEpk] = string(epkstr) + headers[HeaderKeySkid] = skid + headers[HeaderKeyKid] = kid return headers, nil } diff --git a/ecdhpu1.go b/ecdhpu1.go index 0397989..a8147dd 100644 --- a/ecdhpu1.go +++ b/ecdhpu1.go @@ -9,15 +9,18 @@ import ( josecipher "github.com/go-jose/go-jose/v4/cipher" ) +// ZxKeyPair is a pair of ECDH private and public keys. type ZxKeyPair struct { p *ecdh.PrivateKey pub *ecdh.PublicKey } +// NewZxKeyPair creates a new ZxKeyPair. func NewZxKeyPair(p *ecdh.PrivateKey, pub *ecdh.PublicKey) ZxKeyPair { return ZxKeyPair{p: p, pub: pub} } +// ECDH generates a shared secret using the ECDH algorithm. func (z ZxKeyPair) ECDH() ([]byte, error) { zx, err := z.p.ECDH(z.pub) if err != nil { @@ -26,10 +29,12 @@ func (z ZxKeyPair) ECDH() ([]byte, error) { return zx, nil } +// ECDHPU1Key is a key for the ECDH-1PU+A256KW algorithm. type ECDHPU1Key struct { kek []byte } +// Wrap wraps a content encryption key (CEK) using the key encryption key (KEK). func (k *ECDHPU1Key) Wrap(cek []byte) ([]byte, error) { b, err := aes.NewCipher(k.kek) if err != nil { @@ -38,6 +43,7 @@ func (k *ECDHPU1Key) Wrap(cek []byte) ([]byte, error) { return josecipher.KeyWrap(b, cek) } +// Unwrap unwraps a content encryption key (CEK) using the key encryption key (KEK). func (k *ECDHPU1Key) Unwrap(cek []byte) ([]byte, error) { b, err := aes.NewCipher(k.kek) if err != nil { @@ -46,6 +52,7 @@ func (k *ECDHPU1Key) Unwrap(cek []byte) ([]byte, error) { return josecipher.KeyUnwrap(b, cek) } +// NewECDHPU1Key creates a new ECDHPU1Key. func NewECDHPU1Key(zeKeyPair, zsKeyPair ZxKeyPair) (*ECDHPU1Key, error) { ze, err := zeKeyPair.ECDH() if err != nil { diff --git a/jwk.go b/jwk.go index 84a08a4..17e0bfb 100644 --- a/jwk.go +++ b/jwk.go @@ -51,6 +51,7 @@ func Import(key *ecdh.PublicKey) (*JWK, error) { } } +// Export converts a JWK to an ecdh.PublicKey. func Export(jwk *JWK) (*ecdh.PublicKey, error) { switch jwk.Kty { case "OKP": From 4b9f172df11f93432309b76c05b702dda3c71ab6 Mon Sep 17 00:00:00 2001 From: ilya-korotya Date: Wed, 8 Jan 2025 13:12:15 +0100 Subject: [PATCH 4/6] fix linter --- a256cbc_hmac.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/a256cbc_hmac.go b/a256cbc_hmac.go index 8dcb220..af79b9b 100644 --- a/a256cbc_hmac.go +++ b/a256cbc_hmac.go @@ -171,8 +171,7 @@ func WithSkid(skid string) decryptionOption { } } -// Decrypter decrypts a compact token. -// Supported curves are X25519 and P384. +// Decrypter decrypts a compact token. Supported curves are X25519 and P384. type Decrypter struct { recipientResolver PrivateKeyResolver senderResolver PublicKeyResolver From 8dfb6f06e64ea0369386204d4925b42bdc7801ab Mon Sep 17 00:00:00 2001 From: ilya-korotya Date: Wed, 8 Jan 2025 15:44:08 +0100 Subject: [PATCH 5/6] use keys directly --- a256cbc_hmac.go | 120 ++++++++++++------------------------------- a256cbc_hmac_test.go | 93 +++++++++++++++------------------ 2 files changed, 76 insertions(+), 137 deletions(-) diff --git a/a256cbc_hmac.go b/a256cbc_hmac.go index af79b9b..cadfc0c 100644 --- a/a256cbc_hmac.go +++ b/a256cbc_hmac.go @@ -31,41 +31,29 @@ const ( HeaderKeyKid = "kid" ) -// PrivateKeyResolver resolves a private key by its key ID. -type PrivateKeyResolver func(kid string) (*ecdh.PrivateKey, error) +type encryptionOption func(*encryptionOptions) -// PublicKeyResolver resolves a public key by its key ID. -type PublicKeyResolver func(kid string) (*ecdh.PublicKey, error) - -// Encrypter encrypts plaintext using the ECDH-1PU+A256KW and A256CBC-HS512 algorithms. -// Supported curves are X25519 and P384. -type Encrypter struct { - recipientResolver PublicKeyResolver - senderResolver PrivateKeyResolver +type encryptionOptions struct { + kid string + skid string } -// NewEncrypter creates a new Encrypter. -func NewEncrypter( - recipientResolver PublicKeyResolver, - senderResolver PrivateKeyResolver, -) *Encrypter { - return &Encrypter{ - recipientResolver: recipientResolver, - senderResolver: senderResolver, +// WithKid sets the 'kid' option. +func WithKid(kid string) encryptionOption { + return func(opts *encryptionOptions) { + opts.kid = kid } } -// Encrypt encrypts a plaintext using the ECDH-1PU+A256KW and A256CBC-HS512 algorithms. -func (e *Encrypter) Encrypt(recipientKid, senderKid string, plaintext []byte) (string, error) { - recipient, err := e.recipientResolver(recipientKid) - if err != nil { - return "", fmt.Errorf("failed to resolve recipient key: %w", err) - } - sender, err := e.senderResolver(senderKid) - if err != nil { - return "", fmt.Errorf("failed to resolve sender key: %w", err) +// WithSkid sets the 'skid' option. +func WithSkid(skid string) encryptionOption { + return func(opts *encryptionOptions) { + opts.skid = skid } +} +// Encrypt encrypts a plaintext using the ECDH-1PU+A256KW and A256CBC-HS512 algorithms. +func Encrypt(recipient *ecdh.PublicKey, sender *ecdh.PrivateKey, plaintext []byte, opts ...encryptionOption) (string, error) { if recipient.Curve() != sender.Curve() { return "", fmt.Errorf( @@ -74,7 +62,16 @@ func (e *Encrypter) Encrypt(recipientKid, senderKid string, plaintext []byte) (s ) } - var epk *ecdh.PrivateKey + o := &encryptionOptions{} + for _, opt := range opts { + opt(o) + } + + var ( + epk *ecdh.PrivateKey + err error + ) + switch recipient.Curve() { case ecdh.X25519(): epk, err = ecdh.X25519().GenerateKey(rand.Reader) @@ -114,7 +111,7 @@ func (e *Encrypter) Encrypt(recipientKid, senderKid string, plaintext []byte) (s return "", fmt.Errorf("failed to create encrypter: %w", err) } add, err := getHeaders( - senderKid, recipientKid, recipient, sender, epk) + o.skid, o.kid, recipient, sender, epk) if err != nil { return "", fmt.Errorf("failed to create headers: %w", err) } @@ -150,46 +147,8 @@ func (e *Encrypter) Encrypt(recipientKid, senderKid string, plaintext []byte) (s return compactToken, nil } -type decryptionOption func(*decrypterOptions) - -type decrypterOptions struct { - kid string - skid string -} - -// WithKid sets the 'kid' option. -func WithKid(kid string) decryptionOption { - return func(opts *decrypterOptions) { - opts.kid = kid - } -} - -// WithSkid sets the 'skid' option. -func WithSkid(skid string) decryptionOption { - return func(opts *decrypterOptions) { - opts.skid = skid - } -} - -// Decrypter decrypts a compact token. Supported curves are X25519 and P384. -type Decrypter struct { - recipientResolver PrivateKeyResolver - senderResolver PublicKeyResolver -} - -// NewDecrypter creates a new Decrypter. -func NewDecrypter( - recipientResolver PrivateKeyResolver, - senderResolver PublicKeyResolver, -) *Decrypter { - return &Decrypter{ - recipientResolver: recipientResolver, - senderResolver: senderResolver, - } -} - // Decrypt decrypts a compact token. -func (d *Decrypter) Decrypt(compactToken string, opts ...decryptionOption) ([]byte, error) { +func Decrypt(recipient *ecdh.PrivateKey, sender *ecdh.PublicKey, compactToken string) ([]byte, error) { headersBytes, encryptedCek, nonce, ciphertext, authTag, err := parseCompactToken(compactToken) if err != nil { return nil, fmt.Errorf("failed to parse compact token: %w", err) @@ -200,14 +159,6 @@ func (d *Decrypter) Decrypt(compactToken string, opts ...decryptionOption) ([]by return nil, fmt.Errorf("failed to decode headers: %w", err) } - o := &decrypterOptions{ - kid: headers[HeaderKeyKid], - skid: headers[HeaderKeySkid], - } - for _, opt := range opts { - opt(o) - } - e, ok := headers[HeaderKeyEpk] if !ok { return nil, errors.New("epk not found in headers") @@ -221,14 +172,6 @@ func (d *Decrypter) Decrypt(compactToken string, opts ...decryptionOption) ([]by if err != nil { return nil, fmt.Errorf("failed to export epk: %w", err) } - recipient, err := d.recipientResolver(o.kid) - if err != nil { - return nil, fmt.Errorf("failed to resolve recipient: %w", err) - } - sender, err := d.senderResolver(o.skid) - if err != nil { - return nil, fmt.Errorf("failed to resolve sender: %w", err) - } kek, err := NewECDHPU1Key(ZxKeyPair{p: recipient, pub: ephemeral}, ZxKeyPair{p: recipient, pub: sender}) if err != nil { @@ -314,8 +257,13 @@ func getHeaders( headers[HeaderKeyApu] = base64.URLEncoding.EncodeToString(apuHash[:]) headers[HeaderKeyApv] = base64.URLEncoding.EncodeToString(apvHash[:]) headers[HeaderKeyEpk] = string(epkstr) - headers[HeaderKeySkid] = skid - headers[HeaderKeyKid] = kid + + if skid != "" { + headers[HeaderKeySkid] = skid + } + if kid != "" { + headers[HeaderKeyKid] = kid + } return headers, nil } diff --git a/a256cbc_hmac_test.go b/a256cbc_hmac_test.go index 016c8ee..240ec36 100644 --- a/a256cbc_hmac_test.go +++ b/a256cbc_hmac_test.go @@ -3,30 +3,14 @@ package joseprimitives import ( "crypto/ecdh" "crypto/rand" - "fmt" + "encoding/base64" + "encoding/json" + "strings" "testing" "github.com/stretchr/testify/require" ) -func privateKeyresolver(priv *ecdh.PrivateKey, kidclosure string) func(string) (*ecdh.PrivateKey, error) { - return func(kid string) (*ecdh.PrivateKey, error) { - if kid != kidclosure { - return nil, fmt.Errorf("kid '%s' not found", kid) - } - return priv, nil - } -} - -func publicKeyResolver(pub *ecdh.PublicKey, kidclosure string) func(string) (*ecdh.PublicKey, error) { - return func(kid string) (*ecdh.PublicKey, error) { - if kid != kidclosure { - return nil, fmt.Errorf("kid '%s' not found", kid) - } - return pub, nil - } -} - func mustGenerateKey(t *testing.T, c ecdh.Curve) *ecdh.PrivateKey { var ( priv *ecdh.PrivateKey @@ -51,51 +35,58 @@ func mustGenerateKey(t *testing.T, c ecdh.Curve) *ecdh.PrivateKey { func TestEncryptDecryptPxx(t *testing.T) { tests := []struct { - name string - recipient *ecdh.PrivateKey - sender *ecdh.PrivateKey - plaintext string - encryptExpectErr bool - decryptExpectErr bool + name string + recipient *ecdh.PrivateKey + sender *ecdh.PrivateKey + plaintext string + encriptionOptions []encryptionOption + expectedHeaders map[string]interface{} }{ { - name: "Valid encryption and decryption: P-384", - recipient: mustGenerateKey(t, ecdh.P384()), - sender: mustGenerateKey(t, ecdh.P384()), - plaintext: "plaintext", - encryptExpectErr: false, - decryptExpectErr: false, + name: "Valid encryption and decryption: P-384", + recipient: mustGenerateKey(t, ecdh.P384()), + sender: mustGenerateKey(t, ecdh.P384()), + plaintext: "plaintext", + encriptionOptions: []encryptionOption{ + WithKid("kid"), + WithSkid("skid"), + }, + expectedHeaders: map[string]interface{}{ + HeaderKeyKid: "kid", + HeaderKeySkid: "skid", + }, }, { - name: "Valid encryption and decryption: x25519", - recipient: mustGenerateKey(t, ecdh.X25519()), - sender: mustGenerateKey(t, ecdh.X25519()), - plaintext: "plaintext", - encryptExpectErr: false, - decryptExpectErr: false, + name: "Valid encryption and decryption: x25519", + recipient: mustGenerateKey(t, ecdh.X25519()), + sender: mustGenerateKey(t, ecdh.X25519()), + plaintext: "plaintext", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - encrypter := NewEncrypter( - publicKeyResolver(tt.recipient.PublicKey(), "did:recipient"), - privateKeyresolver(tt.sender, "did:sender"), - ) - jweToken, err := encrypter.Encrypt("did:recipient", "did:sender", []byte(tt.plaintext)) - if tt.encryptExpectErr { - require.Error(t, err) - return - } + jweToken, err := Encrypt( + tt.recipient.PublicKey(), tt.sender, []byte(tt.plaintext), tt.encriptionOptions...) require.NoError(t, err) - decrypter := NewDecrypter( - privateKeyresolver(tt.recipient, "did:recipient"), - publicKeyResolver(tt.sender.PublicKey(), "did:sender"), - ) - raw, err := decrypter.Decrypt(jweToken) + h := decodeHeaders(t, jweToken) + require.Equal(t, tt.expectedHeaders[HeaderKeyKid], h[HeaderKeyKid]) + require.Equal(t, tt.expectedHeaders[HeaderKeySkid], h[HeaderKeySkid]) + + raw, err := Decrypt(tt.recipient, tt.sender.PublicKey(), jweToken) require.NoError(t, err) require.Equal(t, tt.plaintext, string(raw)) }) } } + +func decodeHeaders(t *testing.T, token string) map[string]interface{} { + var h map[string]interface{} + parts := strings.Split(token, ".") + require.Len(t, parts, 5) + headersBytes, err := base64.URLEncoding.DecodeString(parts[0]) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(headersBytes, &h)) + return h +} From b908bd2acbe6cf3d2a3b6c2a45069772f848314b Mon Sep 17 00:00:00 2001 From: ilya-korotya Date: Wed, 8 Jan 2025 15:52:13 +0100 Subject: [PATCH 6/6] fix linter rules --- .golangci.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.golangci.yml b/.golangci.yml index 50d3ed8..625d4bc 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -65,4 +65,7 @@ issues: - text: "appendAssign: *" linters: - gocritic + - text: "fieldalignment: struct with *" + linters: + - govet exclude-use-default: false