Skip to content

Commit

Permalink
feat: add SaltGenerator to Shadowsocks PacketListener (#370)
Browse files Browse the repository at this point in the history
  • Loading branch information
fortuna authored Feb 4, 2025
1 parent 994a1f3 commit 93ebcb0
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 13 deletions.
12 changes: 9 additions & 3 deletions transport/shadowsocks/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,18 @@ var ErrShortPacket = errors.New("short packet")
// Assumes all ciphers have NonceSize() <= 12.
var zeroNonce [12]byte

// Pack encrypts a Shadowsocks-UDP packet and returns a slice containing the encrypted packet.
// PackSalt encrypts a Shadowsocks-UDP packet and returns a slice containing the encrypted packet.
// dst must be big enough to hold the encrypted packet.
// If plaintext and dst overlap but are not aligned for in-place encryption, this
// function will panic.
func Pack(dst, plaintext []byte, key *EncryptionKey) ([]byte, error) {
// It uses the given [SaltGenerator] to generate the salt.
func PackSalt(dst, plaintext []byte, key *EncryptionKey, sg SaltGenerator) ([]byte, error) {
saltSize := key.SaltSize()
if len(dst) < saltSize {
return nil, io.ErrShortBuffer
}
salt := dst[:saltSize]
if err := RandomSaltGenerator.GetSalt(salt); err != nil {
if err := sg.GetSalt(salt); err != nil {
return nil, err
}

Expand All @@ -50,6 +51,11 @@ func Pack(dst, plaintext []byte, key *EncryptionKey) ([]byte, error) {
return aead.Seal(salt, zeroNonce[:aead.NonceSize()], plaintext, nil), nil
}

// Pack calls PackSalt with the [RandomSaltGenerator].
func Pack(dst, plaintext []byte, key *EncryptionKey) ([]byte, error) {
return PackSalt(dst, plaintext, key, RandomSaltGenerator)
}

// Unpack decrypts a Shadowsocks-UDP packet in the format [salt][cipherText][AEAD tag] and returns a slice containing
// the decrypted payload or an error.
// If dst is present, it is used to store the plaintext, and must have enough capacity.
Expand Down
32 changes: 22 additions & 10 deletions transport/shadowsocks/packet_listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,33 +33,45 @@ const clientUDPBufferSize = 16 * 1024
var udpPool = slicepool.MakePool(clientUDPBufferSize)

type packetListener struct {
endpoint transport.PacketEndpoint
key *EncryptionKey
endpoint transport.PacketEndpoint
key *EncryptionKey
saltGenerator SaltGenerator
}

var _ transport.PacketListener = (*packetListener)(nil)

func NewPacketListener(endpoint transport.PacketEndpoint, key *EncryptionKey) (transport.PacketListener, error) {
type PacketListener = *packetListener

// NewPacketListener creates a new Shadowsocks PacketListener that connects to the proxy on the given endpoint
// and uses the given key for encryption.
func NewPacketListener(endpoint transport.PacketEndpoint, key *EncryptionKey) (PacketListener, error) {
if endpoint == nil {
return nil, errors.New("argument endpoint must not be nil")
}
if key == nil {
return nil, errors.New("argument key must not be nil")
}
return &packetListener{endpoint: endpoint, key: key}, nil
return &packetListener{endpoint: endpoint, key: key, saltGenerator: RandomSaltGenerator}, nil
}

// SetSaltGenerator sets the SaltGenerator to use for encryption. If not set, it used the [RandomSaltGenerator] by default.
func (pl *packetListener) SetSaltGenerator(sg SaltGenerator) {
pl.saltGenerator = sg
}

func (c *packetListener) ListenPacket(ctx context.Context) (net.PacketConn, error) {
proxyConn, err := c.endpoint.ConnectPacket(ctx)
// ListenPacket creates a net.PackeConn to send packets from the remote endpoint.
func (pl *packetListener) ListenPacket(ctx context.Context) (net.PacketConn, error) {
proxyConn, err := pl.endpoint.ConnectPacket(ctx)
if err != nil {
return nil, fmt.Errorf("could not connect to endpoint: %w", err)
}
return NewPacketConn(proxyConn, c.key), nil
return &packetConn{Conn: proxyConn, key: pl.key, saltGenerator: pl.saltGenerator}, nil
}

type packetConn struct {
net.Conn
key *EncryptionKey
key *EncryptionKey
saltGenerator SaltGenerator
}

var _ net.PacketConn = (*packetConn)(nil)
Expand All @@ -70,7 +82,7 @@ var _ net.PacketConn = (*packetConn)(nil)
//
// Closing the returned [net.PacketConn] will also close the underlying [net.Conn].
func NewPacketConn(conn net.Conn, key *EncryptionKey) net.PacketConn {
return &packetConn{Conn: conn, key: key}
return &packetConn{Conn: conn, key: key, saltGenerator: RandomSaltGenerator}
}

// WriteTo encrypts `b` and writes to `addr` through the proxy.
Expand All @@ -87,7 +99,7 @@ func (c *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) {
// partially overlapping the plaintext and cipher slices since `Pack` skips the salt when calling
// `AEAD.Seal` (see https://golang.org/pkg/crypto/cipher/#AEAD).
plaintextBuf := append(append(cipherBuf[saltSize:saltSize], socksTargetAddr...), b...)
buf, err := Pack(cipherBuf, plaintextBuf, c.key)
buf, err := PackSalt(cipherBuf, plaintextBuf, c.key, c.saltGenerator)
if err != nil {
return 0, err
}
Expand Down
30 changes: 30 additions & 0 deletions transport/shadowsocks/packet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package shadowsocks

import (
"io"
"testing"
"time"

Expand Down Expand Up @@ -43,3 +44,32 @@ func BenchmarkPack(b *testing.B) {
megabits := float64(8*len(plaintextBuf)*b.N) * 1e-6
b.ReportMetric(megabits/(elapsed.Seconds()), "mbps")
}

type fixedSaltGenerator struct {
Salt []byte
}

func (sg *fixedSaltGenerator) GetSalt(salt []byte) error {
n := copy(salt, sg.Salt)
if n < len(salt) {
return io.ErrUnexpectedEOF
}
return nil
}

func TestPack(t *testing.T) {
key := makeTestKey(t)
payload := makeTestPayload(100)
encrypted := make([]byte, len(payload)+key.SaltSize()+key.cipher.tagSize)
salt := makeTestPayload(key.SaltSize())
sg := &fixedSaltGenerator{salt}
encrypted, err := PackSalt(encrypted, payload, key, sg)
require.NoError(t, err)
// Ensure the selected salt is used.
require.Equal(t, salt, encrypted[:len(salt)])

// Ensure it decrypts correctly.
decrypted, err := Unpack(nil, encrypted, key)
require.NoError(t, err)
require.Equal(t, payload, decrypted)
}

0 comments on commit 93ebcb0

Please sign in to comment.