diff --git a/evp.go b/evp.go index a9237a6a..fa557d86 100644 --- a/evp.go +++ b/evp.go @@ -20,15 +20,15 @@ var cacheMD sync.Map func hashToMD(h hash.Hash) C.GO_EVP_MD_PTR { var ch crypto.Hash switch h.(type) { - case *sha1Hash: + case *sha1Hash, *sha1Marshal: ch = crypto.SHA1 - case *sha224Hash: + case *sha224Hash, *sha224Marshal: ch = crypto.SHA224 - case *sha256Hash: + case *sha256Hash, *sha256Marshal: ch = crypto.SHA256 - case *sha384Hash: + case *sha384Hash, *sha384Marshal: ch = crypto.SHA384 - case *sha512Hash: + case *sha512Hash, *sha512Marshal: ch = crypto.SHA512 case *sha3_224Hash: ch = crypto.SHA3_224 diff --git a/hash.go b/hash.go index 646b4ce2..bdc54b59 100644 --- a/hash.go +++ b/hash.go @@ -10,6 +10,7 @@ import ( "hash" "runtime" "strconv" + "sync" "unsafe" ) @@ -110,18 +111,50 @@ func SHA3_512(p []byte) (sum [64]byte) { return } +var isMarshallableCache sync.Map + +// isHashMarshallable returns true if the memory layout of cb +// is known by this library and can therefore be marshalled. +func isHashMarshallable(ch crypto.Hash) bool { + if vMajor == 1 { + return true + } + if v, ok := isMarshallableCache.Load(ch); ok { + return v.(bool) + } + md := cryptoHashToMD(ch) + if md == nil { + return false + } + prov := C.go_openssl_EVP_MD_get0_provider(md) + if prov == nil { + return false + } + cname := C.go_openssl_OSSL_PROVIDER_get0_name(prov) + if cname == nil { + return false + } + name := C.GoString(cname) + // We only know the memory layout of the built-in providers. + // See evpHash.hashState for more details. + marshallable := name == "default" || name == "fips" + isMarshallableCache.Store(ch, marshallable) + return marshallable +} + // evpHash implements generic hash methods. type evpHash struct { ctx C.GO_EVP_MD_CTX_PTR // ctx2 is used in evpHash.sum to avoid changing // the state of ctx. Having it here allows reusing the // same allocated object multiple times. - ctx2 C.GO_EVP_MD_CTX_PTR - size int - blockSize int + ctx2 C.GO_EVP_MD_CTX_PTR + size int + blockSize int + marshallable bool } -func newEvpHash(ch crypto.Hash, size, blockSize int) *evpHash { +func newEvpHash(ch crypto.Hash) *evpHash { md := cryptoHashToMD(ch) if md == nil { panic("openssl: unsupported hash function: " + strconv.Itoa(int(ch))) @@ -132,11 +165,13 @@ func newEvpHash(ch crypto.Hash, size, blockSize int) *evpHash { panic(newOpenSSLError("EVP_DigestInit_ex")) } ctx2 := C.go_openssl_EVP_MD_CTX_new() + blockSize := int(C.go_openssl_EVP_MD_get_block_size(md)) h := &evpHash{ - ctx: ctx, - ctx2: ctx2, - size: size, - blockSize: blockSize, + ctx: ctx, + ctx2: ctx2, + size: ch.Size(), + blockSize: blockSize, + marshallable: isHashMarshallable(ch), } runtime.SetFinalizer(h, (*evpHash).finalize) return h @@ -200,6 +235,9 @@ func (h *evpHash) sum(out []byte) { // The EVP_MD_CTX memory layout has changed in OpenSSL 3 // and the property holding the internal structure is no longer md_data but algctx. func (h *evpHash) hashState() unsafe.Pointer { + if !h.marshallable { + panic("openssl: hash state is not marshallable") + } switch vMajor { case 1: // https://github.com/openssl/openssl/blob/0418e993c717a6863f206feaa40673a261de7395/crypto/evp/evp_local.h#L12. @@ -228,7 +266,7 @@ func (h *evpHash) hashState() unsafe.Pointer { // encoding.BinaryUnmarshaler. func NewMD4() hash.Hash { return &md4Hash{ - evpHash: newEvpHash(crypto.MD4, 16, 64), + evpHash: newEvpHash(crypto.MD4), } } @@ -244,9 +282,11 @@ func (h *md4Hash) Sum(in []byte) []byte { // NewMD5 returns a new MD5 hash. func NewMD5() hash.Hash { - return &md5Hash{ - evpHash: newEvpHash(crypto.MD5, 16, 64), + h := md5Hash{evpHash: newEvpHash(crypto.MD5)} + if h.marshallable { + return &md5Marshal{h} } + return &h } // md5State layout is taken from @@ -273,7 +313,11 @@ const ( md5MarshaledSize = len(md5Magic) + 4*4 + 64 + 8 ) -func (h *md5Hash) MarshalBinary() ([]byte, error) { +type md5Marshal struct { + md5Hash +} + +func (h *md5Marshal) MarshalBinary() ([]byte, error) { d := (*md5State)(h.hashState()) if d == nil { return nil, errors.New("crypto/md5: can't retrieve hash state") @@ -290,7 +334,7 @@ func (h *md5Hash) MarshalBinary() ([]byte, error) { return b, nil } -func (h *md5Hash) UnmarshalBinary(b []byte) error { +func (h *md5Marshal) UnmarshalBinary(b []byte) error { if len(b) < len(md5Magic) || string(b[:len(md5Magic)]) != md5Magic { return errors.New("crypto/md5: invalid hash state identifier") } @@ -316,9 +360,11 @@ func (h *md5Hash) UnmarshalBinary(b []byte) error { // NewSHA1 returns a new SHA1 hash. func NewSHA1() hash.Hash { - return &sha1Hash{ - evpHash: newEvpHash(crypto.SHA1, 20, 64), + h := sha1Hash{evpHash: newEvpHash(crypto.SHA1)} + if h.marshallable { + return &sha1Marshal{h} } + return &h } type sha1Hash struct { @@ -345,7 +391,11 @@ const ( sha1MarshaledSize = len(sha1Magic) + 5*4 + 64 + 8 ) -func (h *sha1Hash) MarshalBinary() ([]byte, error) { +type sha1Marshal struct { + sha1Hash +} + +func (h *sha1Marshal) MarshalBinary() ([]byte, error) { d := (*sha1State)(h.hashState()) if d == nil { return nil, errors.New("crypto/sha1: can't retrieve hash state") @@ -363,7 +413,7 @@ func (h *sha1Hash) MarshalBinary() ([]byte, error) { return b, nil } -func (h *sha1Hash) UnmarshalBinary(b []byte) error { +func (h *sha1Marshal) UnmarshalBinary(b []byte) error { if len(b) < len(sha1Magic) || string(b[:len(sha1Magic)]) != sha1Magic { return errors.New("crypto/sha1: invalid hash state identifier") } @@ -390,9 +440,11 @@ func (h *sha1Hash) UnmarshalBinary(b []byte) error { // NewSHA224 returns a new SHA224 hash. func NewSHA224() hash.Hash { - return &sha224Hash{ - evpHash: newEvpHash(crypto.SHA224, 224/8, 64), + h := sha224Hash{evpHash: newEvpHash(crypto.SHA224)} + if h.marshallable { + return &sha224Marshal{h} } + return &h } type sha224Hash struct { @@ -407,9 +459,11 @@ func (h *sha224Hash) Sum(in []byte) []byte { // NewSHA256 returns a new SHA256 hash. func NewSHA256() hash.Hash { - return &sha256Hash{ - evpHash: newEvpHash(crypto.SHA256, 256/8, 64), + h := sha256Hash{evpHash: newEvpHash(crypto.SHA256)} + if h.marshallable { + return &sha256Marshal{h} } + return &h } type sha256Hash struct { @@ -437,7 +491,15 @@ type sha256State struct { nx uint32 } -func (h *sha224Hash) MarshalBinary() ([]byte, error) { +type sha224Marshal struct { + sha224Hash +} + +type sha256Marshal struct { + sha256Hash +} + +func (h *sha224Marshal) MarshalBinary() ([]byte, error) { d := (*sha256State)(h.hashState()) if d == nil { return nil, errors.New("crypto/sha256: can't retrieve hash state") @@ -458,7 +520,7 @@ func (h *sha224Hash) MarshalBinary() ([]byte, error) { return b, nil } -func (h *sha256Hash) MarshalBinary() ([]byte, error) { +func (h *sha256Marshal) MarshalBinary() ([]byte, error) { d := (*sha256State)(h.hashState()) if d == nil { return nil, errors.New("crypto/sha256: can't retrieve hash state") @@ -479,7 +541,7 @@ func (h *sha256Hash) MarshalBinary() ([]byte, error) { return b, nil } -func (h *sha224Hash) UnmarshalBinary(b []byte) error { +func (h *sha224Marshal) UnmarshalBinary(b []byte) error { if len(b) < len(magic224) || string(b[:len(magic224)]) != magic224 { return errors.New("crypto/sha256: invalid hash state identifier") } @@ -507,7 +569,7 @@ func (h *sha224Hash) UnmarshalBinary(b []byte) error { return nil } -func (h *sha256Hash) UnmarshalBinary(b []byte) error { +func (h *sha256Marshal) UnmarshalBinary(b []byte) error { if len(b) < len(magic256) || string(b[:len(magic256)]) != magic256 { return errors.New("crypto/sha256: invalid hash state identifier") } @@ -537,9 +599,11 @@ func (h *sha256Hash) UnmarshalBinary(b []byte) error { // NewSHA384 returns a new SHA384 hash. func NewSHA384() hash.Hash { - return &sha384Hash{ - evpHash: newEvpHash(crypto.SHA384, 384/8, 128), + h := sha384Hash{evpHash: newEvpHash(crypto.SHA384)} + if h.marshallable { + return &sha384Marshal{h} } + return &h } type sha384Hash struct { @@ -554,9 +618,11 @@ func (h *sha384Hash) Sum(in []byte) []byte { // NewSHA512 returns a new SHA512 hash. func NewSHA512() hash.Hash { - return &sha512Hash{ - evpHash: newEvpHash(crypto.SHA512, 512/8, 128), + h := sha512Hash{evpHash: newEvpHash(crypto.SHA512)} + if h.marshallable { + return &sha512Marshal{h} } + return &h } type sha512Hash struct { @@ -586,7 +652,15 @@ const ( marshaledSize512 = len(magic512) + 8*8 + 128 + 8 ) -func (h *sha384Hash) MarshalBinary() ([]byte, error) { +type sha384Marshal struct { + sha384Hash +} + +type sha512Marshal struct { + sha512Hash +} + +func (h *sha384Marshal) MarshalBinary() ([]byte, error) { d := (*sha512State)(h.hashState()) if d == nil { return nil, errors.New("crypto/sha512: can't retrieve hash state") @@ -607,7 +681,7 @@ func (h *sha384Hash) MarshalBinary() ([]byte, error) { return b, nil } -func (h *sha512Hash) MarshalBinary() ([]byte, error) { +func (h *sha512Marshal) MarshalBinary() ([]byte, error) { d := (*sha512State)(h.hashState()) if d == nil { return nil, errors.New("crypto/sha512: can't retrieve hash state") @@ -628,7 +702,7 @@ func (h *sha512Hash) MarshalBinary() ([]byte, error) { return b, nil } -func (h *sha384Hash) UnmarshalBinary(b []byte) error { +func (h *sha384Marshal) UnmarshalBinary(b []byte) error { if len(b) < len(magic512) { return errors.New("crypto/sha512: invalid hash state identifier") } @@ -659,7 +733,7 @@ func (h *sha384Hash) UnmarshalBinary(b []byte) error { return nil } -func (h *sha512Hash) UnmarshalBinary(b []byte) error { +func (h *sha512Marshal) UnmarshalBinary(b []byte) error { if len(b) < len(magic512) { return errors.New("crypto/sha512: invalid hash state identifier") } @@ -693,7 +767,7 @@ func (h *sha512Hash) UnmarshalBinary(b []byte) error { // NewSHA3_224 returns a new SHA3-224 hash. func NewSHA3_224() hash.Hash { return &sha3_224Hash{ - evpHash: newEvpHash(crypto.SHA3_224, 224/8, 64), + evpHash: newEvpHash(crypto.SHA3_224), } } @@ -710,7 +784,7 @@ func (h *sha3_224Hash) Sum(in []byte) []byte { // NewSHA3_256 returns a new SHA3-256 hash. func NewSHA3_256() hash.Hash { return &sha3_256Hash{ - evpHash: newEvpHash(crypto.SHA3_256, 256/8, 64), + evpHash: newEvpHash(crypto.SHA3_256), } } @@ -727,7 +801,7 @@ func (h *sha3_256Hash) Sum(in []byte) []byte { // NewSHA3_384 returns a new SHA3-384 hash. func NewSHA3_384() hash.Hash { return &sha3_384Hash{ - evpHash: newEvpHash(crypto.SHA3_384, 384/8, 128), + evpHash: newEvpHash(crypto.SHA3_384), } } @@ -744,7 +818,7 @@ func (h *sha3_384Hash) Sum(in []byte) []byte { // NewSHA3_512 returns a new SHA3-512 hash. func NewSHA3_512() hash.Hash { return &sha3_512Hash{ - evpHash: newEvpHash(crypto.SHA3_512, 512/8, 128), + evpHash: newEvpHash(crypto.SHA3_512), } } diff --git a/hash_test.go b/hash_test.go index 7244038a..d8f7e78a 100644 --- a/hash_test.go +++ b/hash_test.go @@ -41,30 +41,27 @@ func cryptoToHash(h crypto.Hash) func() hash.Hash { func TestHash(t *testing.T) { msg := []byte("testing") - var tests = []struct { - h crypto.Hash - hasMarshaler bool - }{ - {crypto.MD4, false}, - {crypto.MD5, true}, - {crypto.SHA1, true}, - {crypto.SHA224, true}, - {crypto.SHA256, true}, - {crypto.SHA384, true}, - {crypto.SHA512, true}, - {crypto.SHA3_224, false}, - {crypto.SHA3_256, false}, - {crypto.SHA3_384, false}, - {crypto.SHA3_512, false}, + var tests = []crypto.Hash{ + crypto.MD4, + crypto.MD5, + crypto.SHA1, + crypto.SHA224, + crypto.SHA256, + crypto.SHA384, + crypto.SHA512, + crypto.SHA3_224, + crypto.SHA3_256, + crypto.SHA3_384, + crypto.SHA3_512, } - for _, tt := range tests { - tt := tt - t.Run(tt.h.String(), func(t *testing.T) { + for _, ch := range tests { + ch := ch + t.Run(ch.String(), func(t *testing.T) { t.Parallel() - if !openssl.SupportsHash(tt.h) { + if !openssl.SupportsHash(ch) { t.Skip("skipping: not supported") } - h := cryptoToHash(tt.h)() + h := cryptoToHash(ch)() initSum := h.Sum(nil) n, err := h.Write(msg) if err != nil { @@ -80,12 +77,12 @@ func TestHash(t *testing.T) { if bytes.Equal(sum, initSum) { t.Error("Write didn't change internal hash state") } - if tt.hasMarshaler { + if _, ok := h.(encoding.BinaryMarshaler); ok { state, err := h.(encoding.BinaryMarshaler).MarshalBinary() if err != nil { t.Errorf("could not marshal: %v", err) } - h2 := cryptoToHash(tt.h)() + h2 := cryptoToHash(ch)() if err := h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(state); err != nil { t.Errorf("could not unmarshal: %v", err) } diff --git a/shims.h b/shims.h index 99656f0c..f3a623d1 100644 --- a/shims.h +++ b/shims.h @@ -190,9 +190,12 @@ DEFINEFUNC_3_0(int, EVP_default_properties_is_fips_enabled, (GO_OSSL_LIB_CTX_PTR DEFINEFUNC_3_0(int, EVP_default_properties_enable_fips, (GO_OSSL_LIB_CTX_PTR libctx, int enable), (libctx, enable)) \ DEFINEFUNC_3_0(int, OSSL_PROVIDER_available, (GO_OSSL_LIB_CTX_PTR libctx, const char *name), (libctx, name)) \ DEFINEFUNC_3_0(GO_OSSL_PROVIDER_PTR, OSSL_PROVIDER_load, (GO_OSSL_LIB_CTX_PTR libctx, const char *name), (libctx, name)) \ +DEFINEFUNC_3_0(const char *, OSSL_PROVIDER_get0_name, (const GO_OSSL_PROVIDER_PTR prov), (prov)) \ DEFINEFUNC_3_0(GO_EVP_MD_PTR, EVP_MD_fetch, (GO_OSSL_LIB_CTX_PTR ctx, const char *algorithm, const char *properties), (ctx, algorithm, properties)) \ DEFINEFUNC_3_0(void, EVP_MD_free, (GO_EVP_MD_PTR md), (md)) \ DEFINEFUNC_3_0(const char *, EVP_MD_get0_name, (const GO_EVP_MD_PTR md), (md)) \ +DEFINEFUNC_3_0(const GO_OSSL_PROVIDER_PTR, EVP_MD_get0_provider, (const GO_EVP_MD_PTR md), (md)) \ +DEFINEFUNC_RENAMED_3_0(int, EVP_MD_get_block_size, EVP_MD_block_size, (const GO_EVP_MD_PTR md), (md)) \ DEFINEFUNC(int, RAND_bytes, (unsigned char *arg0, int arg1), (arg0, arg1)) \ DEFINEFUNC_RENAMED_1_1(GO_EVP_MD_CTX_PTR, EVP_MD_CTX_new, EVP_MD_CTX_create, (void), ()) \ DEFINEFUNC_RENAMED_1_1(void, EVP_MD_CTX_free, EVP_MD_CTX_destroy, (GO_EVP_MD_CTX_PTR ctx), (ctx)) \