Skip to content

Commit

Permalink
Merge pull request #8 from Snawoot/stale_modes
Browse files Browse the repository at this point in the history
Stale modes
  • Loading branch information
Snawoot authored Oct 1, 2023
2 parents 7b8e4f2 + fef237b commit c8af4c2
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 13 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ Options:
hex-encoded pre-shared key. Can be generated with genpsk subcommand
-skip-hello-verify
(server only) skip hello verify request. Useful to workaround DPI
-stale-mode value
which stale side of connection makes whole session stale (both, either, left, right) (default either)
-timeout duration
network operation timeout (default 10s)
```
Expand Down
4 changes: 3 additions & 1 deletion client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type Client struct {
idleTimeout time.Duration
baseCtx context.Context
cancelCtx func()
staleMode util.StaleMode
}

func New(cfg *Config) (*Client, error) {
Expand All @@ -41,6 +42,7 @@ func New(cfg *Config) (*Client, error) {
idleTimeout: cfg.IdleTimeout,
baseCtx: baseCtx,
cancelCtx: cancelCtx,
staleMode: cfg.StaleMode,
}

lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
Expand Down Expand Up @@ -110,7 +112,7 @@ func (client *Client) serve(conn net.Conn) {
return
}

util.PairConn(conn, remoteConn, client.idleTimeout)
util.PairConn(conn, remoteConn, client.idleTimeout, client.staleMode)
}

func (client *Client) contextMaker() (context.Context, func()) {
Expand Down
2 changes: 2 additions & 0 deletions client/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"time"

"github.com/Snawoot/dtlspipe/ciphers"
"github.com/Snawoot/dtlspipe/util"
)

type Config struct {
Expand All @@ -18,6 +19,7 @@ type Config struct {
MTU int
CipherSuites ciphers.CipherList
EllipticCurves ciphers.CurveList
StaleMode util.StaleMode
}

func (cfg *Config) populateDefaults() *Config {
Expand Down
6 changes: 5 additions & 1 deletion cmd/dtlspipe/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ var (
version = "undefined"

timeout = flag.Duration("timeout", 10*time.Second, "network operation timeout")
idleTime = flag.Duration("idle-time", 90*time.Second, "max idle time for UDP session")
idleTime = flag.Duration("idle-time", 30*time.Second, "max idle time for UDP session")
pskHexOpt = flag.String("psk", "", "hex-encoded pre-shared key. Can be generated with genpsk subcommand")
keyLength = flag.Uint("key-length", 16, "generate key with specified length")
identity = flag.String("identity", "", "client identity sent to server")
Expand All @@ -71,11 +71,13 @@ var (
skipHelloVerify = flag.Bool("skip-hello-verify", false, "(server only) skip hello verify request. Useful to workaround DPI")
ciphersuites = cipherlistArg{}
curves = curvelistArg{}
staleMode = util.EitherStale
)

func init() {
flag.Var(&ciphersuites, "ciphers", "colon-separated list of ciphers to use")
flag.Var(&curves, "curves", "colon-separated list of curves to use")
flag.Var(&staleMode, "stale-mode", "which stale side of connection makes whole session stale (both, either, left, right)")
}

func usage() {
Expand Down Expand Up @@ -136,6 +138,7 @@ func cmdClient(bindAddress, remoteAddress string) int {
MTU: *mtu,
CipherSuites: ciphersuites.Value,
EllipticCurves: curves.Value,
StaleMode: staleMode,
}

clt, err := client.New(&cfg)
Expand Down Expand Up @@ -172,6 +175,7 @@ func cmdServer(bindAddress, remoteAddress string) int {
SkipHelloVerify: *skipHelloVerify,
CipherSuites: ciphersuites.Value,
EllipticCurves: curves.Value,
StaleMode: staleMode,
}

srv, err := server.New(&cfg)
Expand Down
2 changes: 2 additions & 0 deletions server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"time"

"github.com/Snawoot/dtlspipe/ciphers"
"github.com/Snawoot/dtlspipe/util"
)

type Config struct {
Expand All @@ -18,6 +19,7 @@ type Config struct {
SkipHelloVerify bool
CipherSuites ciphers.CipherList
EllipticCurves ciphers.CurveList
StaleMode util.StaleMode
}

func (cfg *Config) populateDefaults() *Config {
Expand Down
4 changes: 3 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type Server struct {
idleTimeout time.Duration
baseCtx context.Context
cancelCtx func()
staleMode util.StaleMode
}

func New(cfg *Config) (*Server, error) {
Expand All @@ -42,6 +43,7 @@ func New(cfg *Config) (*Server, error) {
idleTimeout: cfg.IdleTimeout,
baseCtx: baseCtx,
cancelCtx: cancelCtx,
staleMode: cfg.StaleMode,
}

lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
Expand Down Expand Up @@ -122,7 +124,7 @@ func (srv *Server) serve(conn net.Conn) {
}
defer remoteConn.Close()

util.PairConn(conn, remoteConn, srv.idleTimeout)
util.PairConn(conn, remoteConn, srv.idleTimeout, srv.staleMode)
}

func (srv *Server) contextMaker() (context.Context, func()) {
Expand Down
106 changes: 106 additions & 0 deletions util/tracker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package util

import (
"errors"
"sync/atomic"
)

type StaleMode int

const (
BothStale StaleMode = iota
EitherStale
LeftStale
RightStale
)

func (m *StaleMode) String() string {
if m == nil {
return "<nil>"
}
switch *m {
case BothStale:
return "both"
case EitherStale:
return "either"
case LeftStale:
return "left"
case RightStale:
return "right"
}
return "<unknown>"
}

func (m *StaleMode) Set(val string) error {
switch val {
case "both":
*m = BothStale
case "either":
*m = EitherStale
case "left":
*m = LeftStale
case "right":
*m = RightStale
default:
return errors.New("unknown stale mode")
}
return nil
}

type tracker struct {
leftCounter atomic.Int32
rightCounter atomic.Int32
leftTimedOutAt atomic.Int32
rightTimedOutAt atomic.Int32
staleFunc func() bool
}

func newTracker(staleMode StaleMode) *tracker {
t := &tracker{}
switch staleMode {
case BothStale:
t.staleFunc = t.bothStale
case EitherStale:
t.staleFunc = t.eitherStale
case LeftStale:
t.staleFunc = t.leftStale
case RightStale:
t.staleFunc = t.rightStale
default:
panic("unsupported stale mode")
}
return t
}

func (t *tracker) notify(isLeft bool) {
if isLeft {
t.leftCounter.Add(1)
} else {
t.rightCounter.Add(1)
}
}

func (t *tracker) handleTimeout(isLeft bool) bool {
if isLeft {
t.leftTimedOutAt.Store(t.leftCounter.Load())
} else {
t.rightTimedOutAt.Store(t.rightCounter.Load())
}
return !t.staleFunc()
}

func (t *tracker) leftStale() bool {
return t.leftCounter.Load() == t.leftTimedOutAt.Load()
}

func (t *tracker) rightStale() bool {
return t.rightCounter.Load() == t.rightTimedOutAt.Load()
}

func (t *tracker) bothStale() bool {
return t.leftStale() && t.rightStale()
}

func (t *tracker) eitherStale() bool {
return t.leftStale() || t.rightStale()
}
17 changes: 7 additions & 10 deletions util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"log"
"net"
"sync"
"sync/atomic"
"time"
)

Expand Down Expand Up @@ -56,17 +55,15 @@ const (
MaxPktBuf = 65536
)

func PairConn(left, right net.Conn, idleTimeout time.Duration) {
var lsn atomic.Int32
func PairConn(left, right net.Conn, idleTimeout time.Duration, staleMode StaleMode) {
var wg sync.WaitGroup
tracker := newTracker(staleMode)

copier := func(dst, src net.Conn) {
copier := func(dst, src net.Conn, label bool) {
defer wg.Done()
defer dst.Close()
buf := make([]byte, MaxPktBuf)
for {
oldLSN := lsn.Load()

if err := src.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
log.Printf("can't update deadline for connection: %v", err)
break
Expand All @@ -76,7 +73,7 @@ func PairConn(left, right net.Conn, idleTimeout time.Duration) {
if err != nil {
if isTimeout(err) {
// hit read deadline
if oldLSN != lsn.Load() {
if tracker.handleTimeout(label) {
// not stale conn
continue
} else {
Expand All @@ -93,7 +90,7 @@ func PairConn(left, right net.Conn, idleTimeout time.Duration) {
break
}

lsn.Add(1)
tracker.notify(label)

_, err = dst.Write(buf[:n])
if err != nil {
Expand All @@ -104,7 +101,7 @@ func PairConn(left, right net.Conn, idleTimeout time.Duration) {
}

wg.Add(2)
go copier(left, right)
go copier(right, left)
go copier(left, right, false)
go copier(right, left, true)
wg.Wait()
}

0 comments on commit c8af4c2

Please sign in to comment.