Skip to content

Commit

Permalink
only implement encoding.BinaryMarshaler for hashers provided by built…
Browse files Browse the repository at this point in the history
…-in providers
  • Loading branch information
qmuntal committed Sep 3, 2024
1 parent f8eea43 commit 5abc52a
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 46 deletions.
116 changes: 92 additions & 24 deletions hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"hash"
"runtime"
"strconv"
"sync"
"unsafe"
)

Expand Down Expand Up @@ -110,6 +111,37 @@ func SHA3_512(p []byte) (sum [64]byte) {
return
}

var isMarshallableMap sync.Map

// isHashMarshallable returns true if its memory layout
// is known by this library, therefore it can be marshalled.
func isHashMarshallable(cb crypto.Hash) bool {
if vMajor == 1 {
return true
}
if v, ok := isMarshallableMap.Load(cb); ok {
return v.(bool)
}
md := cryptoHashToMD(cb)
if md == nil {
return false
}
prov := C.go_openssl_EVP_MD_get0_provider(md)
if prov == nil {
return false
}
cname := C.go_openssl_OSSL_PROVIDER_get0_name(prov)
if cname == nil {
return false
}
name := C.GoString(cname)
// We only known the memory layout of the built-in providers.
// See evpHash.hashState for more details.
marshallable := name == "default" || name == "fips"
isMarshallableMap.Store(cb, marshallable)
return marshallable
}

// evpHash implements generic hash methods.
type evpHash struct {
ctx C.GO_EVP_MD_CTX_PTR
Expand Down Expand Up @@ -244,9 +276,11 @@ func (h *md4Hash) Sum(in []byte) []byte {

// NewMD5 returns a new MD5 hash.
func NewMD5() hash.Hash {
return &md5Hash{
evpHash: newEvpHash(crypto.MD5, 16, 64),
h := newEvpHash(crypto.MD5, 16, 64)
if isHashMarshallable(crypto.MD5) {
return &md5Marshal{md5Hash{evpHash: h}}
}
return &md5Hash{evpHash: h}
}

// md5State layout is taken from
Expand All @@ -273,7 +307,11 @@ const (
md5MarshaledSize = len(md5Magic) + 4*4 + 64 + 8
)

func (h *md5Hash) MarshalBinary() ([]byte, error) {
type md5Marshal struct {
md5Hash
}

func (h *md5Marshal) MarshalBinary() ([]byte, error) {
d := (*md5State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/md5: can't retrieve hash state")
Expand All @@ -290,7 +328,7 @@ func (h *md5Hash) MarshalBinary() ([]byte, error) {
return b, nil
}

func (h *md5Hash) UnmarshalBinary(b []byte) error {
func (h *md5Marshal) UnmarshalBinary(b []byte) error {
if len(b) < len(md5Magic) || string(b[:len(md5Magic)]) != md5Magic {
return errors.New("crypto/md5: invalid hash state identifier")
}
Expand All @@ -316,9 +354,11 @@ func (h *md5Hash) UnmarshalBinary(b []byte) error {

// NewSHA1 returns a new SHA1 hash.
func NewSHA1() hash.Hash {
return &sha1Hash{
evpHash: newEvpHash(crypto.SHA1, 20, 64),
h := newEvpHash(crypto.SHA1, 20, 64)
if isHashMarshallable(crypto.SHA1) {
return &sha1Marshal{sha1Hash{evpHash: h}}
}
return &sha1Hash{evpHash: h}
}

type sha1Hash struct {
Expand All @@ -345,7 +385,11 @@ const (
sha1MarshaledSize = len(sha1Magic) + 5*4 + 64 + 8
)

func (h *sha1Hash) MarshalBinary() ([]byte, error) {
type sha1Marshal struct {
sha1Hash
}

func (h *sha1Marshal) MarshalBinary() ([]byte, error) {
d := (*sha1State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha1: can't retrieve hash state")
Expand All @@ -363,7 +407,7 @@ func (h *sha1Hash) MarshalBinary() ([]byte, error) {
return b, nil
}

func (h *sha1Hash) UnmarshalBinary(b []byte) error {
func (h *sha1Marshal) UnmarshalBinary(b []byte) error {
if len(b) < len(sha1Magic) || string(b[:len(sha1Magic)]) != sha1Magic {
return errors.New("crypto/sha1: invalid hash state identifier")
}
Expand All @@ -390,9 +434,11 @@ func (h *sha1Hash) UnmarshalBinary(b []byte) error {

// NewSHA224 returns a new SHA224 hash.
func NewSHA224() hash.Hash {
return &sha224Hash{
evpHash: newEvpHash(crypto.SHA224, 224/8, 64),
h := newEvpHash(crypto.SHA224, 224/8, 64)
if isHashMarshallable(crypto.SHA224) {
return &sha224Marshal{sha224Hash{evpHash: h}}
}
return &sha224Hash{evpHash: h}
}

type sha224Hash struct {
Expand All @@ -407,9 +453,11 @@ func (h *sha224Hash) Sum(in []byte) []byte {

// NewSHA256 returns a new SHA256 hash.
func NewSHA256() hash.Hash {
return &sha256Hash{
evpHash: newEvpHash(crypto.SHA256, 256/8, 64),
h := newEvpHash(crypto.SHA256, 256/8, 64)
if isHashMarshallable(crypto.SHA256) {
return &sha256Marshal{sha256Hash{evpHash: h}}
}
return &sha256Hash{evpHash: h}
}

type sha256Hash struct {
Expand Down Expand Up @@ -437,7 +485,15 @@ type sha256State struct {
nx uint32
}

func (h *sha224Hash) MarshalBinary() ([]byte, error) {
type sha224Marshal struct {
sha224Hash
}

type sha256Marshal struct {
sha256Hash
}

func (h *sha224Marshal) MarshalBinary() ([]byte, error) {
d := (*sha256State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha256: can't retrieve hash state")
Expand All @@ -458,7 +514,7 @@ func (h *sha224Hash) MarshalBinary() ([]byte, error) {
return b, nil
}

func (h *sha256Hash) MarshalBinary() ([]byte, error) {
func (h *sha256Marshal) MarshalBinary() ([]byte, error) {
d := (*sha256State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha256: can't retrieve hash state")
Expand All @@ -479,7 +535,7 @@ func (h *sha256Hash) MarshalBinary() ([]byte, error) {
return b, nil
}

func (h *sha224Hash) UnmarshalBinary(b []byte) error {
func (h *sha224Marshal) UnmarshalBinary(b []byte) error {
if len(b) < len(magic224) || string(b[:len(magic224)]) != magic224 {
return errors.New("crypto/sha256: invalid hash state identifier")
}
Expand Down Expand Up @@ -507,7 +563,7 @@ func (h *sha224Hash) UnmarshalBinary(b []byte) error {
return nil
}

func (h *sha256Hash) UnmarshalBinary(b []byte) error {
func (h *sha256Marshal) UnmarshalBinary(b []byte) error {
if len(b) < len(magic256) || string(b[:len(magic256)]) != magic256 {
return errors.New("crypto/sha256: invalid hash state identifier")
}
Expand Down Expand Up @@ -537,9 +593,11 @@ func (h *sha256Hash) UnmarshalBinary(b []byte) error {

// NewSHA384 returns a new SHA384 hash.
func NewSHA384() hash.Hash {
return &sha384Hash{
evpHash: newEvpHash(crypto.SHA384, 384/8, 128),
h := newEvpHash(crypto.SHA384, 384/8, 128)
if isHashMarshallable(crypto.SHA384) {
return &sha384Marshal{sha384Hash{evpHash: h}}
}
return &sha384Hash{evpHash: h}
}

type sha384Hash struct {
Expand All @@ -554,9 +612,11 @@ func (h *sha384Hash) Sum(in []byte) []byte {

// NewSHA512 returns a new SHA512 hash.
func NewSHA512() hash.Hash {
return &sha512Hash{
evpHash: newEvpHash(crypto.SHA512, 512/8, 128),
h := newEvpHash(crypto.SHA512, 512/8, 128)
if isHashMarshallable(crypto.SHA512) {
return &sha512Marshal{sha512Hash{evpHash: h}}
}
return &sha512Hash{evpHash: h}
}

type sha512Hash struct {
Expand Down Expand Up @@ -586,7 +646,15 @@ const (
marshaledSize512 = len(magic512) + 8*8 + 128 + 8
)

func (h *sha384Hash) MarshalBinary() ([]byte, error) {
type sha384Marshal struct {
sha384Hash
}

type sha512Marshal struct {
sha512Hash
}

func (h *sha384Marshal) MarshalBinary() ([]byte, error) {
d := (*sha512State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha512: can't retrieve hash state")
Expand All @@ -607,7 +675,7 @@ func (h *sha384Hash) MarshalBinary() ([]byte, error) {
return b, nil
}

func (h *sha512Hash) MarshalBinary() ([]byte, error) {
func (h *sha512Marshal) MarshalBinary() ([]byte, error) {
d := (*sha512State)(h.hashState())
if d == nil {
return nil, errors.New("crypto/sha512: can't retrieve hash state")
Expand All @@ -628,7 +696,7 @@ func (h *sha512Hash) MarshalBinary() ([]byte, error) {
return b, nil
}

func (h *sha384Hash) UnmarshalBinary(b []byte) error {
func (h *sha384Marshal) UnmarshalBinary(b []byte) error {
if len(b) < len(magic512) {
return errors.New("crypto/sha512: invalid hash state identifier")
}
Expand Down Expand Up @@ -659,7 +727,7 @@ func (h *sha384Hash) UnmarshalBinary(b []byte) error {
return nil
}

func (h *sha512Hash) UnmarshalBinary(b []byte) error {
func (h *sha512Marshal) UnmarshalBinary(b []byte) error {
if len(b) < len(magic512) {
return errors.New("crypto/sha512: invalid hash state identifier")
}
Expand Down
41 changes: 19 additions & 22 deletions hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,30 +41,27 @@ func cryptoToHash(h crypto.Hash) func() hash.Hash {

func TestHash(t *testing.T) {
msg := []byte("testing")
var tests = []struct {
h crypto.Hash
hasMarshaler bool
}{
{crypto.MD4, false},
{crypto.MD5, true},
{crypto.SHA1, true},
{crypto.SHA224, true},
{crypto.SHA256, true},
{crypto.SHA384, true},
{crypto.SHA512, true},
{crypto.SHA3_224, false},
{crypto.SHA3_256, false},
{crypto.SHA3_384, false},
{crypto.SHA3_512, false},
var tests = []crypto.Hash{
crypto.MD4,
crypto.MD5,
crypto.SHA1,
crypto.SHA224,
crypto.SHA256,
crypto.SHA384,
crypto.SHA512,
crypto.SHA3_224,
crypto.SHA3_256,
crypto.SHA3_384,
crypto.SHA3_512,
}
for _, tt := range tests {
tt := tt
t.Run(tt.h.String(), func(t *testing.T) {
for _, cb := range tests {
cb := cb
t.Run(cb.String(), func(t *testing.T) {
t.Parallel()
if !openssl.SupportsHash(tt.h) {
if !openssl.SupportsHash(cb) {
t.Skip("skipping: not supported")
}
h := cryptoToHash(tt.h)()
h := cryptoToHash(cb)()
initSum := h.Sum(nil)
n, err := h.Write(msg)
if err != nil {
Expand All @@ -80,12 +77,12 @@ func TestHash(t *testing.T) {
if bytes.Equal(sum, initSum) {
t.Error("Write didn't change internal hash state")
}
if tt.hasMarshaler {
if _, ok := h.(encoding.BinaryMarshaler); ok {
state, err := h.(encoding.BinaryMarshaler).MarshalBinary()
if err != nil {
t.Errorf("could not marshal: %v", err)
}
h2 := cryptoToHash(tt.h)()
h2 := cryptoToHash(cb)()
if err := h2.(encoding.BinaryUnmarshaler).UnmarshalBinary(state); err != nil {
t.Errorf("could not unmarshal: %v", err)
}
Expand Down
2 changes: 2 additions & 0 deletions shims.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,11 @@ DEFINEFUNC_3_0(int, EVP_default_properties_is_fips_enabled, (GO_OSSL_LIB_CTX_PTR
DEFINEFUNC_3_0(int, EVP_default_properties_enable_fips, (GO_OSSL_LIB_CTX_PTR libctx, int enable), (libctx, enable)) \
DEFINEFUNC_3_0(int, OSSL_PROVIDER_available, (GO_OSSL_LIB_CTX_PTR libctx, const char *name), (libctx, name)) \
DEFINEFUNC_3_0(GO_OSSL_PROVIDER_PTR, OSSL_PROVIDER_load, (GO_OSSL_LIB_CTX_PTR libctx, const char *name), (libctx, name)) \
DEFINEFUNC_3_0(const char *, OSSL_PROVIDER_get0_name, (const GO_OSSL_PROVIDER_PTR prov), (prov)) \
DEFINEFUNC_3_0(GO_EVP_MD_PTR, EVP_MD_fetch, (GO_OSSL_LIB_CTX_PTR ctx, const char *algorithm, const char *properties), (ctx, algorithm, properties)) \
DEFINEFUNC_3_0(void, EVP_MD_free, (GO_EVP_MD_PTR md), (md)) \
DEFINEFUNC_3_0(const char *, EVP_MD_get0_name, (const GO_EVP_MD_PTR md), (md)) \
DEFINEFUNC_3_0(const GO_OSSL_PROVIDER_PTR, EVP_MD_get0_provider, (const GO_EVP_MD_PTR md), (md)) \
DEFINEFUNC(int, RAND_bytes, (unsigned char *arg0, int arg1), (arg0, arg1)) \
DEFINEFUNC_RENAMED_1_1(GO_EVP_MD_CTX_PTR, EVP_MD_CTX_new, EVP_MD_CTX_create, (void), ()) \
DEFINEFUNC_RENAMED_1_1(void, EVP_MD_CTX_free, EVP_MD_CTX_destroy, (GO_EVP_MD_CTX_PTR ctx), (ctx)) \
Expand Down

0 comments on commit 5abc52a

Please sign in to comment.