Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only implement encoding.BinaryMarshaler for hashers provided by built-in providers #161

Merged
merged 6 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions evp.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ var cacheMD sync.Map
func hashToMD(h hash.Hash) C.GO_EVP_MD_PTR {
var ch crypto.Hash
switch h.(type) {
case *sha1Hash:
case *sha1Hash, *sha1Marshal:
ch = crypto.SHA1
case *sha224Hash:
case *sha224Hash, *sha224Marshal:
ch = crypto.SHA224
case *sha256Hash:
case *sha256Hash, *sha256Marshal:
ch = crypto.SHA256
case *sha384Hash:
case *sha384Hash, *sha384Marshal:
ch = crypto.SHA384
case *sha512Hash:
case *sha512Hash, *sha512Marshal:
ch = crypto.SHA512
case *sha3_224Hash:
ch = crypto.SHA3_224
Expand Down
116 changes: 92 additions & 24 deletions hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"hash"
"runtime"
"strconv"
"sync"
"unsafe"
)

Expand Down Expand Up @@ -110,6 +111,37 @@ func SHA3_512(p []byte) (sum [64]byte) {
return
}

var isMarshallableMap sync.Map

// isHashMarshallable returns true if its memory layout
// is known by this library, therefore it can be marshalled.
func isHashMarshallable(cb crypto.Hash) bool {
if vMajor == 1 {
return true
}
if v, ok := isMarshallableMap.Load(cb); ok {
return v.(bool)
}
md := cryptoHashToMD(cb)
if md == nil {
return false
}
prov := C.go_openssl_EVP_MD_get0_provider(md)
if prov == nil {
return false
}
cname := C.go_openssl_OSSL_PROVIDER_get0_name(prov)
if cname == nil {
return false
}
name := C.GoString(cname)
// We only known the memory layout of the built-in providers.
// See evpHash.hashState for more details.
marshallable := name == "default" || name == "fips"
isMarshallableMap.Store(cb, marshallable)
return marshallable
}

// evpHash implements generic hash methods.
type evpHash struct {
ctx C.GO_EVP_MD_CTX_PTR
Expand Down Expand Up @@ -244,9 +276,11 @@ func (h *md4Hash) Sum(in []byte) []byte {

// NewMD5 returns a new MD5 hash.
func NewMD5() hash.Hash {
return &md5Hash{
evpHash: newEvpHash(crypto.MD5, 16, 64),
h := newEvpHash(crypto.MD5, 16, 64)
if isHashMarshallable(crypto.MD5) {
return &md5Marshal{md5Hash{evpHash: h}}
}
return &md5Hash{evpHash: h}
}

// md5State layout is taken from
Expand All @@ -273,7 +307,11 @@ const (
md5MarshaledSize = len(md5Magic) + 4*4 + 64 + 8
)

func (h *md5Hash) MarshalBinary() ([]byte, error) {
type md5Marshal struct {
md5Hash
}

func (h *md5Marshal) MarshalBinary() ([]byte, error) {
d := (*md5State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/md5: can't retrieve hash state")
Expand All @@ -290,7 +328,7 @@ func (h *md5Hash) MarshalBinary() ([]byte, error) {
return b, nil
}

func (h *md5Hash) UnmarshalBinary(b []byte) error {
func (h *md5Marshal) UnmarshalBinary(b []byte) error {
if len(b) < len(md5Magic) || string(b[:len(md5Magic)]) != md5Magic {
return errors.New("crypto/md5: invalid hash state identifier")
}
Expand All @@ -316,9 +354,11 @@ func (h *md5Hash) UnmarshalBinary(b []byte) error {

// NewSHA1 returns a new SHA1 hash.
func NewSHA1() hash.Hash {
return &sha1Hash{
evpHash: newEvpHash(crypto.SHA1, 20, 64),
h := newEvpHash(crypto.SHA1, 20, 64)
if isHashMarshallable(crypto.SHA1) {
return &sha1Marshal{sha1Hash{evpHash: h}}
}
return &sha1Hash{evpHash: h}
}

type sha1Hash struct {
Expand All @@ -345,7 +385,11 @@ const (
sha1MarshaledSize = len(sha1Magic) + 5*4 + 64 + 8
)

func (h *sha1Hash) MarshalBinary() ([]byte, error) {
type sha1Marshal struct {
sha1Hash
}

func (h *sha1Marshal) MarshalBinary() ([]byte, error) {
d := (*sha1State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha1: can't retrieve hash state")
Expand All @@ -363,7 +407,7 @@ func (h *sha1Hash) MarshalBinary() ([]byte, error) {
return b, nil
}

func (h *sha1Hash) UnmarshalBinary(b []byte) error {
func (h *sha1Marshal) UnmarshalBinary(b []byte) error {
if len(b) < len(sha1Magic) || string(b[:len(sha1Magic)]) != sha1Magic {
return errors.New("crypto/sha1: invalid hash state identifier")
}
Expand All @@ -390,9 +434,11 @@ func (h *sha1Hash) UnmarshalBinary(b []byte) error {

// NewSHA224 returns a new SHA224 hash.
func NewSHA224() hash.Hash {
return &sha224Hash{
evpHash: newEvpHash(crypto.SHA224, 224/8, 64),
h := newEvpHash(crypto.SHA224, 224/8, 64)
if isHashMarshallable(crypto.SHA224) {
return &sha224Marshal{sha224Hash{evpHash: h}}
}
return &sha224Hash{evpHash: h}
}

type sha224Hash struct {
Expand All @@ -407,9 +453,11 @@ func (h *sha224Hash) Sum(in []byte) []byte {

// NewSHA256 returns a new SHA256 hash.
func NewSHA256() hash.Hash {
return &sha256Hash{
evpHash: newEvpHash(crypto.SHA256, 256/8, 64),
h := newEvpHash(crypto.SHA256, 256/8, 64)
if isHashMarshallable(crypto.SHA256) {
return &sha256Marshal{sha256Hash{evpHash: h}}
}
return &sha256Hash{evpHash: h}
}

type sha256Hash struct {
Expand Down Expand Up @@ -437,7 +485,15 @@ type sha256State struct {
nx uint32
}

func (h *sha224Hash) MarshalBinary() ([]byte, error) {
type sha224Marshal struct {
sha224Hash
}

type sha256Marshal struct {
sha256Hash
}

func (h *sha224Marshal) MarshalBinary() ([]byte, error) {
d := (*sha256State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha256: can't retrieve hash state")
Expand All @@ -458,7 +514,7 @@ func (h *sha224Hash) MarshalBinary() ([]byte, error) {
return b, nil
}

func (h *sha256Hash) MarshalBinary() ([]byte, error) {
func (h *sha256Marshal) MarshalBinary() ([]byte, error) {
d := (*sha256State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha256: can't retrieve hash state")
Expand All @@ -479,7 +535,7 @@ func (h *sha256Hash) MarshalBinary() ([]byte, error) {
return b, nil
}

func (h *sha224Hash) UnmarshalBinary(b []byte) error {
func (h *sha224Marshal) UnmarshalBinary(b []byte) error {
if len(b) < len(magic224) || string(b[:len(magic224)]) != magic224 {
return errors.New("crypto/sha256: invalid hash state identifier")
}
Expand Down Expand Up @@ -507,7 +563,7 @@ func (h *sha224Hash) UnmarshalBinary(b []byte) error {
return nil
}

func (h *sha256Hash) UnmarshalBinary(b []byte) error {
func (h *sha256Marshal) UnmarshalBinary(b []byte) error {
if len(b) < len(magic256) || string(b[:len(magic256)]) != magic256 {
return errors.New("crypto/sha256: invalid hash state identifier")
}
Expand Down Expand Up @@ -537,9 +593,11 @@ func (h *sha256Hash) UnmarshalBinary(b []byte) error {

// NewSHA384 returns a new SHA384 hash.
func NewSHA384() hash.Hash {
return &sha384Hash{
evpHash: newEvpHash(crypto.SHA384, 384/8, 128),
h := newEvpHash(crypto.SHA384, 384/8, 128)
if isHashMarshallable(crypto.SHA384) {
return &sha384Marshal{sha384Hash{evpHash: h}}
}
return &sha384Hash{evpHash: h}
}

type sha384Hash struct {
Expand All @@ -554,9 +612,11 @@ func (h *sha384Hash) Sum(in []byte) []byte {

// NewSHA512 returns a new SHA512 hash.
func NewSHA512() hash.Hash {
return &sha512Hash{
evpHash: newEvpHash(crypto.SHA512, 512/8, 128),
h := newEvpHash(crypto.SHA512, 512/8, 128)
if isHashMarshallable(crypto.SHA512) {
return &sha512Marshal{sha512Hash{evpHash: h}}
}
return &sha512Hash{evpHash: h}
}

type sha512Hash struct {
Expand Down Expand Up @@ -586,7 +646,15 @@ const (
marshaledSize512 = len(magic512) + 8*8 + 128 + 8
)

func (h *sha384Hash) MarshalBinary() ([]byte, error) {
type sha384Marshal struct {
sha384Hash
}

type sha512Marshal struct {
sha512Hash
}

func (h *sha384Marshal) MarshalBinary() ([]byte, error) {
d := (*sha512State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha512: can't retrieve hash state")
Expand All @@ -607,7 +675,7 @@ func (h *sha384Hash) MarshalBinary() ([]byte, error) {
return b, nil
}

func (h *sha512Hash) MarshalBinary() ([]byte, error) {
func (h *sha512Marshal) MarshalBinary() ([]byte, error) {
d := (*sha512State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha512: can't retrieve hash state")
Expand All @@ -628,7 +696,7 @@ func (h *sha512Hash) MarshalBinary() ([]byte, error) {
return b, nil
}

func (h *sha384Hash) UnmarshalBinary(b []byte) error {
func (h *sha384Marshal) UnmarshalBinary(b []byte) error {
if len(b) < len(magic512) {
return errors.New("crypto/sha512: invalid hash state identifier")
}
Expand Down Expand Up @@ -659,7 +727,7 @@ func (h *sha384Hash) UnmarshalBinary(b []byte) error {
return nil
}

func (h *sha512Hash) UnmarshalBinary(b []byte) error {
func (h *sha512Marshal) UnmarshalBinary(b []byte) error {
if len(b) < len(magic512) {
return errors.New("crypto/sha512: invalid hash state identifier")
}
Expand Down
41 changes: 19 additions & 22 deletions hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,30 +41,27 @@ func cryptoToHash(h crypto.Hash) func() hash.Hash {

func TestHash(t *testing.T) {
msg := []byte("testing")
var tests = []struct {
h crypto.Hash
hasMarshaler bool
}{
{crypto.MD4, false},
{crypto.MD5, true},
{crypto.SHA1, true},
{crypto.SHA224, true},
{crypto.SHA256, true},
{crypto.SHA384, true},
{crypto.SHA512, true},
{crypto.SHA3_224, false},
{crypto.SHA3_256, false},
{crypto.SHA3_384, false},
{crypto.SHA3_512, false},
var tests = []crypto.Hash{
crypto.MD4,
crypto.MD5,
crypto.SHA1,
crypto.SHA224,
crypto.SHA256,
crypto.SHA384,
crypto.SHA512,
crypto.SHA3_224,
crypto.SHA3_256,
crypto.SHA3_384,
crypto.SHA3_512,
}
for _, tt := range tests {
tt := tt
t.Run(tt.h.String(), func(t *testing.T) {
for _, cb := range tests {
cb := cb
t.Run(cb.String(), func(t *testing.T) {
t.Parallel()
if !openssl.SupportsHash(tt.h) {
if !openssl.SupportsHash(cb) {
t.Skip("skipping: not supported")
}
h := cryptoToHash(tt.h)()
h := cryptoToHash(cb)()
initSum := h.Sum(nil)
n, err := h.Write(msg)
if err != nil {
Expand All @@ -80,12 +77,12 @@ func TestHash(t *testing.T) {
if bytes.Equal(sum, initSum) {
t.Error("Write didn't change internal hash state")
}
if tt.hasMarshaler {
if _, ok := h.(encoding.BinaryMarshaler); ok {
state, err := h.(encoding.BinaryMarshaler).MarshalBinary()
if err != nil {
t.Errorf("could not marshal: %v", err)
}
h2 := cryptoToHash(tt.h)()
h2 := cryptoToHash(cb)()
if err := h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(state); err != nil {
t.Errorf("could not unmarshal: %v", err)
}
Expand Down
2 changes: 2 additions & 0 deletions shims.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,11 @@ DEFINEFUNC_3_0(int, EVP_default_properties_is_fips_enabled, (GO_OSSL_LIB_CTX_PTR
DEFINEFUNC_3_0(int, EVP_default_properties_enable_fips, (GO_OSSL_LIB_CTX_PTR libctx, int enable), (libctx, enable)) \
DEFINEFUNC_3_0(int, OSSL_PROVIDER_available, (GO_OSSL_LIB_CTX_PTR libctx, const char *name), (libctx, name)) \
DEFINEFUNC_3_0(GO_OSSL_PROVIDER_PTR, OSSL_PROVIDER_load, (GO_OSSL_LIB_CTX_PTR libctx, const char *name), (libctx, name)) \
DEFINEFUNC_3_0(const char *, OSSL_PROVIDER_get0_name, (const GO_OSSL_PROVIDER_PTR prov), (prov)) \
DEFINEFUNC_3_0(GO_EVP_MD_PTR, EVP_MD_fetch, (GO_OSSL_LIB_CTX_PTR ctx, const char *algorithm, const char *properties), (ctx, algorithm, properties)) \
DEFINEFUNC_3_0(void, EVP_MD_free, (GO_EVP_MD_PTR md), (md)) \
DEFINEFUNC_3_0(const char *, EVP_MD_get0_name, (const GO_EVP_MD_PTR md), (md)) \
DEFINEFUNC_3_0(const GO_OSSL_PROVIDER_PTR, EVP_MD_get0_provider, (const GO_EVP_MD_PTR md), (md)) \
DEFINEFUNC(int, RAND_bytes, (unsigned char *arg0, int arg1), (arg0, arg1)) \
DEFINEFUNC_RENAMED_1_1(GO_EVP_MD_CTX_PTR, EVP_MD_CTX_new, EVP_MD_CTX_create, (void), ()) \
DEFINEFUNC_RENAMED_1_1(void, EVP_MD_CTX_free, EVP_MD_CTX_destroy, (GO_EVP_MD_CTX_PTR ctx), (ctx)) \
Expand Down
Loading