Skip to content

Commit

Permalink
autogenerate renamed functions
Browse files Browse the repository at this point in the history
  • Loading branch information
qmuntal committed Mar 5, 2025
1 parent 4975f61 commit a3c0336
Show file tree
Hide file tree
Showing 12 changed files with 145 additions and 154 deletions.
6 changes: 1 addition & 5 deletions cipher.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,7 @@ func newEVPCipher(key []byte, kind cipherKind) (*evpCipher, error) {
}
c := &evpCipher{key: make([]byte, len(key)), kind: kind}
copy(c.key, key)
if vMajor == 1 {
c.blockSize = int(go_openssl_EVP_CIPHER_block_size(cipher))
} else {
c.blockSize = int(go_openssl_EVP_CIPHER_get_block_size(cipher))
}
c.blockSize = int(go_openssl_EVP_CIPHER_get_block_size(cipher))
return c, nil
}

Expand Down
22 changes: 11 additions & 11 deletions cmd/checkheader/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,17 @@ func generate(header string) (string, error) {
case "EVP_PKEY_size", "EVP_PKEY_bits":
specialCond = "OPENSSL_VERSION_NUMBER >= 0x10101000L"
}
switch fn.Tag {
case "legacy_1":
tagCond = "OPENSSL_VERSION_NUMBER < 0x30000000L"
case "111":
tagCond = "OPENSSL_VERSION_NUMBER >= 0x10101000L"
case "3":
tagCond = "OPENSSL_VERSION_NUMBER >= 0x30000000L"
case "":
// No tag, the function is available in all versions.
default:
panic("unexpected tag: " + fn.Tag)
if len(fn.Tags) == 1 {
switch fn.Tags[0].Tag {
case "legacy_1":
tagCond = "OPENSSL_VERSION_NUMBER < 0x30000000L"
case "111":
tagCond = "OPENSSL_VERSION_NUMBER >= 0x10101000L"
case "3":
tagCond = "OPENSSL_VERSION_NUMBER >= 0x30000000L"
default:
panic("unexpected tag: " + fn.Tags[0].Tag)
}
}
if specialCond != "" {
fmt.Fprintf(w, "#if %s\n", specialCond)
Expand Down
28 changes: 22 additions & 6 deletions cmd/mkcgo/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,21 +139,37 @@ func generateC(src *mkcgo.Source, w io.Writer) {
}
fmt.Fprintf(w, "\n")

fmt.Fprintf(w, "#define __mkcgo__dlsym(name) \\\n")
fmt.Fprintf(w, "\t_g_##name = (typeof(_g_##name))dlsym(handle, #name); \\\n")
fmt.Fprintf(w, "\tif (_g_##name == NULL) { \\\n")
fmt.Fprintf(w, "\t\tfprintf(stderr, \"Cannot get required symbol \" #name \"\\n\"); \\\n")
fmt.Fprintf(w, "#define __mkcgo__dlsym(name) __mkcgo__dlsym2(name, name)\n\n")

fmt.Fprintf(w, "#define __mkcgo__dlsym2(varname, funcname) \\\n")
fmt.Fprintf(w, "\t_g_##varname = (typeof(_g_##varname))dlsym(handle, #funcname); \\\n")
fmt.Fprintf(w, "\tif (_g_##varname == NULL) { \\\n")
fmt.Fprintf(w, "\t\tfprintf(stderr, \"Cannot get required symbol \" #funcname \"\\n\"); \\\n")
fmt.Fprintf(w, "\t\tabort(); \\\n")
fmt.Fprintf(w, "\t}\n\n")

// Loader functions for each tag.
for _, tag := range src.Tags() {
fmt.Fprintf(w, "void __mkcgoLoad_%s(void* handle) {\n", tag)
for _, fn := range src.Funcs {
if fn.VariadicInst || fn.Tag != tag {
if fn.VariadicInst {
continue
}
fmt.Fprintf(w, "\t__mkcgo__dlsym(%s)\n", fn.ImportName)
if len(fn.Tags) == 0 && tag == "" {
// Default tag.
fmt.Fprintf(w, "\t__mkcgo__dlsym(%s)\n", fn.ImportName)
} else {
for _, tagAttr := range fn.Tags {
if tagAttr.Tag == tag {
if tagAttr.Name != "" {
fmt.Fprintf(w, "\t__mkcgo__dlsym2(%s, %s)\n", fn.ImportName, tagAttr.Name)
} else {
fmt.Fprintf(w, "\t__mkcgo__dlsym(%s)\n", fn.ImportName)
}
break
}
}
}
}
fmt.Fprintf(w, "}\n\n")
}
Expand Down
7 changes: 1 addition & 6 deletions ecdh.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,7 @@ func GenerateKeyECDH(curve string) (*PrivateKeyECDH, []byte, error) {
// The fixed length is the order of the large prime subgroup of the curve,
// returned by EVP_PKEY_get_bits, which is generally the upper bound for
// generating a private ECDH key.
var bits int32
if vMajor == 1 {
bits = go_openssl_EVP_PKEY_bits(pkey)
} else {
bits = go_openssl_EVP_PKEY_get_bits(pkey)
}
bits := go_openssl_EVP_PKEY_get_bits(pkey)
bytes := make([]byte, (bits+7)/8)
if err := bnToBinPad(priv, bytes); err != nil {
return nil, nil, err
Expand Down
12 changes: 2 additions & 10 deletions evp.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,8 @@ func loadHash(ch crypto.Hash) *hashAlgorithm {
return nil
}
hash.ch = ch
if vMajor == 1 {
hash.size = int(go_openssl_EVP_MD_size(hash.md))
hash.blockSize = int(go_openssl_EVP_MD_block_size(hash.md))
} else {
hash.size = int(go_openssl_EVP_MD_get_size(hash.md))
hash.blockSize = int(go_openssl_EVP_MD_get_block_size(hash.md))
}
hash.size = int(go_openssl_EVP_MD_get_size(hash.md))
hash.blockSize = int(go_openssl_EVP_MD_get_block_size(hash.md))
if vMajor == 3 {
// On OpenSSL 3, directly operating on a EVP_MD object
// not created by EVP_MD_fetch has negative performance
Expand Down Expand Up @@ -355,9 +350,6 @@ func cryptEVP(withKey withKeyFunc, padding int32,
}
defer go_openssl_EVP_PKEY_CTX_free(ctx)
pkeySize := withKey(func(pkey _EVP_PKEY_PTR) int32 {
if vMajor == 1 {
return go_openssl_EVP_PKEY_size(pkey)
}
return go_openssl_EVP_PKEY_get_size(pkey)
})
outLen := int(pkeySize)
Expand Down
7 changes: 1 addition & 6 deletions hkdf.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,7 @@ func ExpandHKDF(h func() hash.Hash, pseudorandomKey, info []byte) (io.Reader, er
return nil, err
}

var size int
if vMajor == 1 {
size = int(go_openssl_EVP_MD_size(md))
} else {
size = int(go_openssl_EVP_MD_get_size(md))
}
size := int(go_openssl_EVP_MD_get_size(md))

switch vMajor {
case 1:
Expand Down
18 changes: 14 additions & 4 deletions internal/mkcgo/mkcgo.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type TypeDef struct {
Type string
}

// Enum describes an enum definition.
type Enum struct {
Name string
Value string
Expand All @@ -29,7 +30,7 @@ type Func struct {
GoName string
CName string
ImportName string
Tag string
Tags []TagAttr // if TagAttr.Name is set, it's the import name for the tag
Params []*Param
Ret *Return
VariadicInst bool // true if the function is a variadic instantiation
Expand All @@ -39,6 +40,12 @@ func (f *Func) Variadic() bool {
return len(f.Params) > 0 && f.Params[len(f.Params)-1].Variadic()
}

// TagAttr is an attribute of a tag with an optional name.
type TagAttr struct {
Tag string
Name string
}

// Param is a function parameter.
type Param struct {
Name string
Expand All @@ -56,10 +63,13 @@ type Return struct {
}

func (src *Source) Tags() []string {
var tags []string
tags := make([]string, 0, len(src.Funcs)+1)
tags = append(tags, "") // default tag
for _, fn := range src.Funcs {
if !slices.Contains(tags, fn.Tag) {
tags = append(tags, fn.Tag)
for _, tag := range fn.Tags {
if !slices.Contains(tags, tag.Tag) {
tags = append(tags, tag.Tag)
}
}
}
slices.Sort(tags)
Expand Down
112 changes: 76 additions & 36 deletions internal/mkcgo/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,35 @@ import (
)

type fnAttributes struct {
tag string
tags []TagAttr
variadic bool
importName string
}

type attribute struct {
name string
description string
hasParameter bool
handle func(string, *fnAttributes)
name string
description string
handle func(*fnAttributes, ...string)
}

var attributes = [...]attribute{
{
name: "tag",
description: "the function will be loaded together with other functions with the same tag.",
hasParameter: true,
handle: func(s string, opts *fnAttributes) {
opts.tag = s
name: "tag",
description: "The function will be loaded together with other functions with the same tag. It can contain an optional name, which is the import name for the tag.",
handle: func(opts *fnAttributes, s ...string) {
var name string
if len(s) > 1 {
name = s[1]
}
opts.tags = append(opts.tags, TagAttr{Tag: s[0], Name: name})
},
},
{
name: "variadic",
description: "the function has variadic arguments, and its name is a custom wrapper for the actual C name, defined in this attribute.",
hasParameter: true,
handle: func(s string, opts *fnAttributes) {
name: "variadic",
description: "The function has variadic arguments, and its name is a custom wrapper for the actual C name, defined in this attribute.",
handle: func(opts *fnAttributes, s ...string) {
opts.variadic = true
opts.importName = s
opts.importName = s[0]
},
},
}
Expand Down Expand Up @@ -180,7 +181,7 @@ func newFn(s string, opts fnAttributes) (*Func, error) {
fn := &Func{
Ret: &Return{},
VariadicInst: opts.variadic,
Tag: opts.tag,
Tags: opts.tags,
}
var err error
fn.Params, err = extractParams(body)
Expand Down Expand Up @@ -245,11 +246,28 @@ func extractSection(s string, start, end string) (prefix, body, suffix string, f
prefix = a[0]
body = a[1]
}
a := strings.SplitN(body, end, 2)
if len(a) != 2 {
return "", "", "", false
idxStart := strings.Index(body, start)
idxEnd := strings.Index(body, end)
needBalancing := idxStart != -1 && idxEnd != -1 && idxStart < idxEnd
if !needBalancing {
a := strings.SplitN(body, end, 2)
if len(a) != 2 {
return "", "", "", false
}
return prefix, a[0], a[1], true
}
depth := 1
for i := range len(body) {
if strings.HasPrefix(body[i:], start) {
depth++
} else if strings.HasPrefix(body[i:], end) {
depth--
if depth == 0 {
return prefix, body[:i], body[i+len(end):], true
}
}
}
return prefix, a[0], a[1], true
return "", "", s, false
}

// processComments removes comments from line and returns the result.
Expand Down Expand Up @@ -281,39 +299,61 @@ func processComments(line string, inBlockComment *bool) (comment, remmaining str

// extractFunctionAttributes extracts mkcgo attributes from string s.
// The attributes format follows the GCC __attribute__ syntax as
// described in https://gcc.gnu.org/onlinedocs/gcc/Function-Attributes.html.
// described in https://gcc.gnu.org/onlinedocs/gcc/Attribute-Syntax.html.
func extractFunctionAttributes(s string, fnAttrs *fnAttributes) (string, error) {
// There can be spaces between __attribute__ and the opening parenthesis.
prefix, after, found := strings.Cut(s, "__attribute__")
prefix, body, found := strings.Cut(s, "__attribute__")
if !found {
return s, nil
}
_, body, suffix, found := extractSection(after, "((", "));")
_, body, suffix, found := extractSection(body, "(", ")")
if !found {
return s, nil
}
for _, v := range strings.Split(body, ",") {
v = trim(v)
if !strings.HasPrefix(body, "(") || !strings.HasSuffix(body, ")") {
// Attributes are enclosed in double parentheses.
return s, nil
}
body = trim(body[1 : len(body)-1])
for {
if body == "" {
break
}
// Attributes are separated by commas. Get the next attribute.
// We can't just use strings.Split because the attribute argument
// can contain commas.
var name, args string
idxComma := strings.IndexByte(body, ',')
idxParen := strings.IndexByte(body, '(')
if idxComma != -1 && (idxParen == -1 || idxComma < idxParen) {
// The attribute has no arguments.
name = body[:idxComma]
body = body[idxComma+1:]
} else if idxParen != -1 && (idxComma == -1 || idxComma > idxParen) {
// The attribute has arguments, possibly with commas.
name = body[:idxParen]
_, args, body, found = extractSection(body[idxParen:], "(", ")")
if !found {
return "", errors.New("unbalanced parentheses in mkcgo attribute: " + s)
}
body = strings.TrimPrefix(body, ",")
}
name, args = trim(name), trim(args)
var handled bool
for _, attr := range attributes {
if (!attr.hasParameter && v != attr.name) ||
(attr.hasParameter && !strings.HasPrefix(v, attr.name+"(")) {
if name != attr.name {
continue
}
var arg string
if attr.hasParameter {
var ok bool
if _, arg, _, ok = extractSection(v, "(", ")"); !ok {
return "", errors.New("could not extract mkcgo attribute argument from \"" + v + "\"")
}
arg = strings.Trim(arg, `"`)
vargs := strings.Split(args, ",")
for i := range vargs {
vargs[i] = trim(strings.Trim(vargs[i], `"`))
}
attr.handle(arg, fnAttrs)
attr.handle(fnAttrs, vargs...)
handled = true
break
}
if !handled {
return "", errors.New("unknown mkcgo attribute: " + v)
return "", errors.New("unknown mkcgo attribute: " + s)
}
}
return trim(prefix + suffix), nil
Expand Down
7 changes: 1 addition & 6 deletions rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,12 +278,7 @@ func HashSignRSAPKCS1v15(priv *PrivateKeyRSA, h crypto.Hash, msg []byte) ([]byte

func VerifyRSAPKCS1v15(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte) error {
if pub.withKey(func(pkey _EVP_PKEY_PTR) int32 {
var size int32
if vMajor == 1 {
size = go_openssl_EVP_PKEY_size(pkey)
} else {
size = go_openssl_EVP_PKEY_get_size(pkey)
}
size := go_openssl_EVP_PKEY_get_size(pkey)
if len(sig) < int(size) {
return 0
}
Expand Down
Loading

0 comments on commit a3c0336

Please sign in to comment.