diff --git a/transport/shadowsocks/packet.go b/transport/shadowsocks/packet.go index cf5219bb..c4d5feb4 100644 --- a/transport/shadowsocks/packet.go +++ b/transport/shadowsocks/packet.go @@ -25,17 +25,18 @@ var ErrShortPacket = errors.New("short packet") // Assumes all ciphers have NonceSize() <= 12. var zeroNonce [12]byte -// Pack encrypts a Shadowsocks-UDP packet and returns a slice containing the encrypted packet. +// PackSalt encrypts a Shadowsocks-UDP packet and returns a slice containing the encrypted packet. // dst must be big enough to hold the encrypted packet. // If plaintext and dst overlap but are not aligned for in-place encryption, this // function will panic. -func Pack(dst, plaintext []byte, key *EncryptionKey) ([]byte, error) { +// It uses the given [SaltGenerator] to generate the salt. +func PackSalt(dst, plaintext []byte, key *EncryptionKey, sg SaltGenerator) ([]byte, error) { saltSize := key.SaltSize() if len(dst) < saltSize { return nil, io.ErrShortBuffer } salt := dst[:saltSize] - if err := RandomSaltGenerator.GetSalt(salt); err != nil { + if err := sg.GetSalt(salt); err != nil { return nil, err } @@ -50,6 +51,11 @@ func Pack(dst, plaintext []byte, key *EncryptionKey) ([]byte, error) { return aead.Seal(salt, zeroNonce[:aead.NonceSize()], plaintext, nil), nil } +// Pack calls PackSalt with the [RandomSaltGenerator]. +func Pack(dst, plaintext []byte, key *EncryptionKey) ([]byte, error) { + return PackSalt(dst, plaintext, key, RandomSaltGenerator) +} + // Unpack decrypts a Shadowsocks-UDP packet in the format [salt][cipherText][AEAD tag] and returns a slice containing // the decrypted payload or an error. // If dst is present, it is used to store the plaintext, and must have enough capacity. diff --git a/transport/shadowsocks/packet_listener.go b/transport/shadowsocks/packet_listener.go index e1ece175..4915db9a 100644 --- a/transport/shadowsocks/packet_listener.go +++ b/transport/shadowsocks/packet_listener.go @@ -33,33 +33,45 @@ const clientUDPBufferSize = 16 * 1024 var udpPool = slicepool.MakePool(clientUDPBufferSize) type packetListener struct { - endpoint transport.PacketEndpoint - key *EncryptionKey + endpoint transport.PacketEndpoint + key *EncryptionKey + saltGenerator SaltGenerator } var _ transport.PacketListener = (*packetListener)(nil) -func NewPacketListener(endpoint transport.PacketEndpoint, key *EncryptionKey) (transport.PacketListener, error) { +type PacketListener = *packetListener + +// NewPacketListener creates a new Shadowsocks PacketListener that connects to the proxy on the given endpoint +// and uses the given key for encryption. +func NewPacketListener(endpoint transport.PacketEndpoint, key *EncryptionKey) (PacketListener, error) { if endpoint == nil { return nil, errors.New("argument endpoint must not be nil") } if key == nil { return nil, errors.New("argument key must not be nil") } - return &packetListener{endpoint: endpoint, key: key}, nil + return &packetListener{endpoint: endpoint, key: key, saltGenerator: RandomSaltGenerator}, nil +} + +// SetSaltGenerator sets the SaltGenerator to use for encryption. If not set, it used the [RandomSaltGenerator] by default. +func (pl *packetListener) SetSaltGenerator(sg SaltGenerator) { + pl.saltGenerator = sg } -func (c *packetListener) ListenPacket(ctx context.Context) (net.PacketConn, error) { - proxyConn, err := c.endpoint.ConnectPacket(ctx) +// ListenPacket creates a net.PackeConn to send packets from the remote endpoint. +func (pl *packetListener) ListenPacket(ctx context.Context) (net.PacketConn, error) { + proxyConn, err := pl.endpoint.ConnectPacket(ctx) if err != nil { return nil, fmt.Errorf("could not connect to endpoint: %w", err) } - return NewPacketConn(proxyConn, c.key), nil + return &packetConn{Conn: proxyConn, key: pl.key, saltGenerator: pl.saltGenerator}, nil } type packetConn struct { net.Conn - key *EncryptionKey + key *EncryptionKey + saltGenerator SaltGenerator } var _ net.PacketConn = (*packetConn)(nil) @@ -70,7 +82,7 @@ var _ net.PacketConn = (*packetConn)(nil) // // Closing the returned [net.PacketConn] will also close the underlying [net.Conn]. func NewPacketConn(conn net.Conn, key *EncryptionKey) net.PacketConn { - return &packetConn{Conn: conn, key: key} + return &packetConn{Conn: conn, key: key, saltGenerator: RandomSaltGenerator} } // WriteTo encrypts `b` and writes to `addr` through the proxy. @@ -87,7 +99,7 @@ func (c *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) { // partially overlapping the plaintext and cipher slices since `Pack` skips the salt when calling // `AEAD.Seal` (see https://golang.org/pkg/crypto/cipher/#AEAD). plaintextBuf := append(append(cipherBuf[saltSize:saltSize], socksTargetAddr...), b...) - buf, err := Pack(cipherBuf, plaintextBuf, c.key) + buf, err := PackSalt(cipherBuf, plaintextBuf, c.key, c.saltGenerator) if err != nil { return 0, err } diff --git a/transport/shadowsocks/packet_test.go b/transport/shadowsocks/packet_test.go index 40408be1..81e59fc0 100644 --- a/transport/shadowsocks/packet_test.go +++ b/transport/shadowsocks/packet_test.go @@ -15,6 +15,7 @@ package shadowsocks import ( + "io" "testing" "time" @@ -43,3 +44,32 @@ func BenchmarkPack(b *testing.B) { megabits := float64(8*len(plaintextBuf)*b.N) * 1e-6 b.ReportMetric(megabits/(elapsed.Seconds()), "mbps") } + +type fixedSaltGenerator struct { + Salt []byte +} + +func (sg *fixedSaltGenerator) GetSalt(salt []byte) error { + n := copy(salt, sg.Salt) + if n < len(salt) { + return io.ErrUnexpectedEOF + } + return nil +} + +func TestPack(t *testing.T) { + key := makeTestKey(t) + payload := makeTestPayload(100) + encrypted := make([]byte, len(payload)+key.SaltSize()+key.cipher.tagSize) + salt := makeTestPayload(key.SaltSize()) + sg := &fixedSaltGenerator{salt} + encrypted, err := PackSalt(encrypted, payload, key, sg) + require.NoError(t, err) + // Ensure the selected salt is used. + require.Equal(t, salt, encrypted[:len(salt)]) + + // Ensure it decrypts correctly. + decrypted, err := Unpack(nil, encrypted, key) + require.NoError(t, err) + require.Equal(t, payload, decrypted) +}