diff --git a/base64_amd64.go b/base64_amd64.go index b45b6b1..3c64912 100644 --- a/base64_amd64.go +++ b/base64_amd64.go @@ -40,7 +40,7 @@ func decode(enc *Encoding, dst, src []byte) (int, error) { if remain < srcLen { // decoded by SIMD - remain = srcLen - remain + remain = srcLen - remain // remain is decoded length now src = src[remain:] dstStart := (remain / 4) * 3 dst = dst[dstStart:] diff --git a/base64_arm64.go b/base64_arm64.go index d00fadb..185e7a7 100644 --- a/base64_arm64.go +++ b/base64_arm64.go @@ -14,6 +14,17 @@ var dencodeStdLut = [128]byte{ 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 255, 255, 255, 255, } +var dencodeUrlLut = [128]byte{ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255, 255, + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255, 255, 255, 255, 255, + 0, 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, + 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 255, 255, 255, 255, + 63, 255, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 255, 255, 255, 255, +} + //go:noescape func encodeAsm(dst, src []byte, lut *[64]byte) int @@ -30,5 +41,27 @@ func encode(enc *Encoding, dst, src []byte) { } func decode(enc *Encoding, dst, src []byte) (int, error) { + srcLen := len(src) + if srcLen >= 64 { + remain := srcLen + if enc.lut == &encodeStdLut { + remain = decodeAsm(dst, src, &dencodeStdLut) + } else if enc.lut == &encodeURLLut { + remain = decodeAsm(dst, src, &dencodeUrlLut) + } + + if remain < srcLen { + // decoded by SIMD + remain = srcLen - remain // remain is decoded length now + src = src[remain:] + dstStart := (remain / 4) * 3 + dst = dst[dstStart:] + n, err := decodeGeneric(enc, dst, src) + if cerr, ok := err.(CorruptInputError); ok { + return n + dstStart, CorruptInputError(int(cerr) + remain) + } + return n + dstStart, err + } + } return decodeGeneric(enc, dst, src) } diff --git a/base64_arm64.s b/base64_arm64.s index d0ae56c..4c7bfed 100644 --- a/base64_arm64.s +++ b/base64_arm64.s @@ -107,8 +107,6 @@ loop: VORR V19.B16, V16.B16, V16.B16 // Check that all bits are zero: - // Do NOT use UMAXV first - // WORD $0x6e30aa05 // VUMAXV V16.B16, R5 VMOV V16.D[0], R5 CBNZ R5, done VMOV V16.D[1], R5 diff --git a/base64_arm64_test.go b/base64_arm64_test.go index d393f3f..aaa3514 100644 --- a/base64_arm64_test.go +++ b/base64_arm64_test.go @@ -31,7 +31,8 @@ func TestStdEncodeSIMD(t *testing.T) { func TestStdDecodeSIMD(t *testing.T) { pairs := []testpair{ - // {"abcdefghijklabcdefghijklabcdefghijklabcdefghijkl", "YWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamts"}, + {"abcdefghijklabcdefghijklabcdefghijklabcdefghijkl", "YWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamts"}, + {"abcdefghijklabcdefghijklabcdefghijklabcdefghijkl", "YWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamt="}, {"abcdefghijklabcdefghijklabcdefghijklabcdefghijklabcdefghijklabcdefghijklabcdefghijklabcdefghijkl", "YWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamts"}, } for _, p := range pairs { @@ -48,3 +49,45 @@ func TestStdDecodeSIMD(t *testing.T) { } } } + +func TestUrlEncodeSIMD(t *testing.T) { + pairs := []testpair{ + {"!?$*&()'-=@~!?$*&()'-=@~!?$*&()'-=@~!?$*&()'-=@~", "IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-"}, + {"!?$*&()'-=@~!?$*&()'-=@~!?$*&()'-=@~!?$*&()'-=@~!?$*&()'-=@~!?$*&()'-=@~!?$*&()'-=@~!?$*&()'-=@~", "IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-"}, + } + for _, p := range pairs { + src := []byte(p.decoded) + expected := []byte(p.encoded) + dst := make([]byte, len(expected)) + + ret := encodeAsm(dst, src, &StdEncoding.encode) + if ret != len(expected) { + t.Fatalf("should return %v", len(expected)) + } + if !bytes.Equal(dst, expected) { + t.Fatalf("got %v", string(dst)) + } + + } +} + +func TestUrlDecodeSIMD(t *testing.T) { + pairs := []testpair{ + {"!?$*&()'-=@~!?$*&()'-=@~!?$*&()'-=@~!?$*&()'-=@~", "IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-"}, + {"abcdefghijklabcdefghijklabcdefghijklabcdefghijkl", "YWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamtsYWJjZGVmZ2hpamt="}, + {"!?$*&()'-=@~!?$*&()'-=@~!?$*&()'-=@~!?$*&()'-=@~!?$*&()'-=@~!?$*&()'-=@~!?$*&()'-=@~!?$*&()'-=@~", "IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-IT8kKiYoKSctPUB-"}, + } + for _, p := range pairs { + expected := []byte(p.decoded) + src := []byte(p.encoded) + dst := make([]byte, len(expected)) + + ret := decodeAsm(dst, src, &dencodeStdLut) + if ret == len(src) { + t.Fatal("should return decode") + } + if !bytes.Equal(dst, expected) { + t.Fatalf("got %x, expected %x", dst, expected) + } + } +}