1
0
Vladislav Yarmak 2 жил өмнө
parent
commit
f5c594166c

+ 3 - 1
client/client.go

@@ -27,6 +27,7 @@ type Client struct {
 	idleTimeout time.Duration
 	baseCtx     context.Context
 	cancelCtx   func()
+	staleMode   util.StaleMode
 }
 
 func New(cfg *Config) (*Client, error) {
@@ -41,6 +42,7 @@ func New(cfg *Config) (*Client, error) {
 		idleTimeout: cfg.IdleTimeout,
 		baseCtx:     baseCtx,
 		cancelCtx:   cancelCtx,
+		staleMode:   cfg.StaleMode,
 	}
 
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
@@ -110,7 +112,7 @@ func (client *Client) serve(conn net.Conn) {
 		return
 	}
 
-	util.PairConn(conn, remoteConn, client.idleTimeout)
+	util.PairConn(conn, remoteConn, client.idleTimeout, client.staleMode)
 }
 
 func (client *Client) contextMaker() (context.Context, func()) {

+ 2 - 0
client/config.go

@@ -5,6 +5,7 @@ import (
 	"time"
 
 	"github.com/Snawoot/dtlspipe/ciphers"
+	"github.com/Snawoot/dtlspipe/util"
 )
 
 type Config struct {
@@ -18,6 +19,7 @@ type Config struct {
 	MTU            int
 	CipherSuites   ciphers.CipherList
 	EllipticCurves ciphers.CurveList
+	StaleMode      util.StaleMode
 }
 
 func (cfg *Config) populateDefaults() *Config {

+ 4 - 0
cmd/dtlspipe/main.go

@@ -71,11 +71,13 @@ var (
 	skipHelloVerify = flag.Bool("skip-hello-verify", false, "(server only) skip hello verify request. Useful to workaround DPI")
 	ciphersuites    = cipherlistArg{}
 	curves          = curvelistArg{}
+	staleMode       = util.EitherStale
 )
 
 func init() {
 	flag.Var(&ciphersuites, "ciphers", "colon-separated list of ciphers to use")
 	flag.Var(&curves, "curves", "colon-separated list of curves to use")
+	flag.Var(&staleMode, "stale-mode", "which stale side of connection makes whole session stale")
 }
 
 func usage() {
@@ -136,6 +138,7 @@ func cmdClient(bindAddress, remoteAddress string) int {
 		MTU:            *mtu,
 		CipherSuites:   ciphersuites.Value,
 		EllipticCurves: curves.Value,
+		StaleMode:      staleMode,
 	}
 
 	clt, err := client.New(&cfg)
@@ -172,6 +175,7 @@ func cmdServer(bindAddress, remoteAddress string) int {
 		SkipHelloVerify: *skipHelloVerify,
 		CipherSuites:    ciphersuites.Value,
 		EllipticCurves:  curves.Value,
+		StaleMode:       staleMode,
 	}
 
 	srv, err := server.New(&cfg)

+ 2 - 0
server/config.go

@@ -5,6 +5,7 @@ import (
 	"time"
 
 	"github.com/Snawoot/dtlspipe/ciphers"
+	"github.com/Snawoot/dtlspipe/util"
 )
 
 type Config struct {
@@ -18,6 +19,7 @@ type Config struct {
 	SkipHelloVerify bool
 	CipherSuites    ciphers.CipherList
 	EllipticCurves  ciphers.CurveList
+	StaleMode       util.StaleMode
 }
 
 func (cfg *Config) populateDefaults() *Config {

+ 3 - 1
server/server.go

@@ -28,6 +28,7 @@ type Server struct {
 	idleTimeout time.Duration
 	baseCtx     context.Context
 	cancelCtx   func()
+	staleMode   util.StaleMode
 }
 
 func New(cfg *Config) (*Server, error) {
@@ -42,6 +43,7 @@ func New(cfg *Config) (*Server, error) {
 		idleTimeout: cfg.IdleTimeout,
 		baseCtx:     baseCtx,
 		cancelCtx:   cancelCtx,
+		staleMode:   cfg.StaleMode,
 	}
 
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
@@ -122,7 +124,7 @@ func (srv *Server) serve(conn net.Conn) {
 	}
 	defer remoteConn.Close()
 
-	util.PairConn(conn, remoteConn, srv.idleTimeout)
+	util.PairConn(conn, remoteConn, srv.idleTimeout, srv.staleMode)
 }
 
 func (srv *Server) contextMaker() (context.Context, func()) {

+ 44 - 8
util/tracker.go

@@ -1,35 +1,71 @@
 package util
 
-import "sync/atomic"
+import (
+	"errors"
+	"sync/atomic"
+)
 
 type StaleMode int
 
 const (
-	BothStale = iota
+	BothStale StaleMode = iota
 	EitherStale
 	LeftStale
 	RightStale
 )
 
+func (m *StaleMode) String() string {
+	if m == nil {
+		return "<nil>"
+	}
+	switch *m {
+	case BothStale:
+		return "both"
+	case EitherStale:
+		return "either"
+	case LeftStale:
+		return "left"
+	case RightStale:
+		return "right"
+	}
+	return "<unknown>"
+}
+
+func (m *StaleMode) Set(val string) error {
+	switch val {
+	case "both":
+		*m = BothStale
+	case "either":
+		*m = EitherStale
+	case "left":
+		*m = LeftStale
+	case "right":
+		*m = RightStale
+	default:
+		return errors.New("unknown stale mode")
+	}
+	return nil
+}
+
 type tracker struct {
 	leftCounter     atomic.Int32
 	rightCounter    atomic.Int32
 	leftTimedOutAt  atomic.Int32
 	rightTimedOutAt atomic.Int32
-	staleFun        func() bool
+	staleFunc       func() bool
 }
 
 func newTracker(staleMode StaleMode) *tracker {
 	t := &tracker{}
 	switch staleMode {
 	case BothStale:
-		t.staleFun = t.bothStale
+		t.staleFunc = t.bothStale
 	case EitherStale:
-		t.staleFun = t.eitherStale
+		t.staleFunc = t.eitherStale
 	case LeftStale:
-		t.staleFun = t.leftStale
+		t.staleFunc = t.leftStale
 	case RightStale:
-		t.staleFun = t.rightStale
+		t.staleFunc = t.rightStale
 	default:
 		panic("unsupported stale mode")
 	}
@@ -50,7 +86,7 @@ func (t *tracker) handleTimeout(isLeft bool) bool {
 	} else {
 		t.rightTimedOutAt.Store(t.rightCounter.Load())
 	}
-	return t.staleFun()
+	return !t.staleFunc()
 }
 
 func (t *tracker) leftStale() bool {

+ 7 - 10
util/util.go

@@ -7,7 +7,6 @@ import (
 	"log"
 	"net"
 	"sync"
-	"sync/atomic"
 	"time"
 )
 
@@ -56,17 +55,15 @@ const (
 	MaxPktBuf = 65536
 )
 
-func PairConn(left, right net.Conn, idleTimeout time.Duration) {
-	var lsn atomic.Int32
+func PairConn(left, right net.Conn, idleTimeout time.Duration, staleMode StaleMode) {
 	var wg sync.WaitGroup
+	tracker := newTracker(staleMode)
 
-	copier := func(dst, src net.Conn) {
+	copier := func(dst, src net.Conn, label bool) {
 		defer wg.Done()
 		defer dst.Close()
 		buf := make([]byte, MaxPktBuf)
 		for {
-			oldLSN := lsn.Load()
-
 			if err := src.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
 				log.Printf("can't update deadline for connection: %v", err)
 				break
@@ -76,7 +73,7 @@ func PairConn(left, right net.Conn, idleTimeout time.Duration) {
 			if err != nil {
 				if isTimeout(err) {
 					// hit read deadline
-					if oldLSN != lsn.Load() {
+					if tracker.handleTimeout(label) {
 						// not stale conn
 						continue
 					} else {
@@ -93,7 +90,7 @@ func PairConn(left, right net.Conn, idleTimeout time.Duration) {
 				break
 			}
 
-			lsn.Add(1)
+			tracker.notify(label)
 
 			_, err = dst.Write(buf[:n])
 			if err != nil {
@@ -104,7 +101,7 @@ func PairConn(left, right net.Conn, idleTimeout time.Duration) {
 	}
 
 	wg.Add(2)
-	go copier(left, right)
-	go copier(right, left)
+	go copier(left, right, false)
+	go copier(right, left, true)
 	wg.Wait()
 }