瀏覽代碼

Merge pull request #14 from Snawoot/random_timelimit

Random timelimit
Snawoot 1 年之前
父節點
當前提交
dc3492b067
共有 7 個文件被更改,包括 127 次插入54 次删除
  1. 1 1
      README.md
  2. 24 23
      client/client.go
  3. 4 1
      client/config.go
  4. 43 3
      cmd/dtlspipe/main.go
  5. 4 1
      server/config.go
  6. 24 23
      server/server.go
  7. 27 2
      util/util.go

+ 1 - 1
README.md

@@ -96,7 +96,7 @@ Options:
   -stale-mode value
   -stale-mode value
     	which stale side of connection makes whole session stale (both, either, left, right) (default either)
     	which stale side of connection makes whole session stale (both, either, left, right) (default either)
   -time-limit duration
   -time-limit duration
-    	hard time limit for each session
+    	limit for each session duration. Use single value X for fixed limit or range X-Y for randomized limit
   -timeout duration
   -timeout duration
     	network operation timeout (default 10s)
     	network operation timeout (default 10s)
 ```
 ```

+ 24 - 23
client/client.go

@@ -20,18 +20,18 @@ const (
 )
 )
 
 
 type Client struct {
 type Client struct {
-	listener    net.Listener
-	dtlsConfig  *dtls.Config
-	rAddr       string
-	psk         func([]byte) ([]byte, error)
-	timeout     time.Duration
-	idleTimeout time.Duration
-	baseCtx     context.Context
-	cancelCtx   func()
-	staleMode   util.StaleMode
-	workerWG    sync.WaitGroup
-	timeLimit   time.Duration
-	allowFunc   func(net.Addr, net.Addr) bool
+	listener      net.Listener
+	dtlsConfig    *dtls.Config
+	rAddr         string
+	psk           func([]byte) ([]byte, error)
+	timeout       time.Duration
+	idleTimeout   time.Duration
+	baseCtx       context.Context
+	cancelCtx     func()
+	staleMode     util.StaleMode
+	workerWG      sync.WaitGroup
+	timeLimitFunc func() time.Duration
+	allowFunc     func(net.Addr, net.Addr) bool
 }
 }
 
 
 func New(cfg *Config) (*Client, error) {
 func New(cfg *Config) (*Client, error) {
@@ -40,15 +40,15 @@ func New(cfg *Config) (*Client, error) {
 	baseCtx, cancelCtx := context.WithCancel(cfg.BaseContext)
 	baseCtx, cancelCtx := context.WithCancel(cfg.BaseContext)
 
 
 	client := &Client{
 	client := &Client{
-		rAddr:       cfg.RemoteAddress,
-		timeout:     cfg.Timeout,
-		psk:         cfg.PSKCallback,
-		idleTimeout: cfg.IdleTimeout,
-		baseCtx:     baseCtx,
-		cancelCtx:   cancelCtx,
-		staleMode:   cfg.StaleMode,
-		timeLimit:   cfg.TimeLimit,
-		allowFunc:   cfg.AllowFunc,
+		rAddr:         cfg.RemoteAddress,
+		timeout:       cfg.Timeout,
+		psk:           cfg.PSKCallback,
+		idleTimeout:   cfg.IdleTimeout,
+		baseCtx:       baseCtx,
+		cancelCtx:     cancelCtx,
+		staleMode:     cfg.StaleMode,
+		timeLimitFunc: cfg.TimeLimitFunc,
+		allowFunc:     cfg.AllowFunc,
 	}
 	}
 
 
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
@@ -110,8 +110,9 @@ func (client *Client) serve(conn net.Conn) {
 	defer conn.Close()
 	defer conn.Close()
 
 
 	ctx := client.baseCtx
 	ctx := client.baseCtx
-	if client.timeLimit != 0 {
-		newCtx, cancel := context.WithTimeout(ctx, client.timeLimit)
+	tl := client.timeLimitFunc()
+	if tl != 0 {
+		newCtx, cancel := context.WithTimeout(ctx, tl)
 		defer cancel()
 		defer cancel()
 		ctx = newCtx
 		ctx = newCtx
 	}
 	}

+ 4 - 1
client/config.go

@@ -21,7 +21,7 @@ type Config struct {
 	CipherSuites   ciphers.CipherList
 	CipherSuites   ciphers.CipherList
 	EllipticCurves ciphers.CurveList
 	EllipticCurves ciphers.CurveList
 	StaleMode      util.StaleMode
 	StaleMode      util.StaleMode
-	TimeLimit      time.Duration
+	TimeLimitFunc  func() time.Duration
 	AllowFunc      func(localAddr, remoteAddr net.Addr) bool
 	AllowFunc      func(localAddr, remoteAddr net.Addr) bool
 }
 }
 
 
@@ -41,6 +41,9 @@ func (cfg *Config) populateDefaults() *Config {
 	if cfg.EllipticCurves == nil {
 	if cfg.EllipticCurves == nil {
 		cfg.EllipticCurves = ciphers.DefaultCurveList
 		cfg.EllipticCurves = ciphers.DefaultCurveList
 	}
 	}
+	if cfg.TimeLimitFunc == nil {
+		cfg.TimeLimitFunc = util.FixedTimeLimitFunc(0)
+	}
 	if cfg.AllowFunc == nil {
 	if cfg.AllowFunc == nil {
 		cfg.AllowFunc = util.AllowAllFunc
 		cfg.AllowFunc = util.AllowAllFunc
 	}
 	}

+ 43 - 3
cmd/dtlspipe/main.go

@@ -2,6 +2,7 @@ package main
 
 
 import (
 import (
 	"context"
 	"context"
+	"errors"
 	"flag"
 	"flag"
 	"fmt"
 	"fmt"
 	"log"
 	"log"
@@ -60,6 +61,44 @@ func (l *curvelistArg) Set(s string) error {
 	return nil
 	return nil
 }
 }
 
 
+type timelimitArg struct {
+	low  time.Duration
+	high time.Duration
+}
+
+func (a *timelimitArg) String() string {
+	if a.low == a.high {
+		return a.low.String()
+	}
+	return fmt.Sprintf("%s-%s", a.low.String(), a.high.String())
+}
+
+func (a *timelimitArg) Set(s string) error {
+	parts := strings.SplitN(s, "-", 2)
+	switch len(parts) {
+	case 1:
+		dur, err := time.ParseDuration(s)
+		if err != nil {
+			return err
+		}
+		a.low, a.high = dur, dur
+		return nil
+	case 2:
+		durLow, err := time.ParseDuration(parts[0])
+		if err != nil {
+			return fmt.Errorf("first component parse failed: %w", err)
+		}
+		durHigh, err := time.ParseDuration(parts[1])
+		if err != nil {
+			return fmt.Errorf("second component parse failed: %w", err)
+		}
+		a.low, a.high = durLow, durHigh
+		return nil
+	default:
+		return errors.New("unexpected number of components")
+	}
+}
+
 type ratelimitArg struct {
 type ratelimitArg struct {
 	value rlzone.Ratelimiter[netip.Addr]
 	value rlzone.Ratelimiter[netip.Addr]
 }
 }
@@ -98,7 +137,7 @@ var (
 	ciphersuites    = cipherlistArg{}
 	ciphersuites    = cipherlistArg{}
 	curves          = curvelistArg{}
 	curves          = curvelistArg{}
 	staleMode       = util.EitherStale
 	staleMode       = util.EitherStale
-	timeLimit       = flag.Duration("time-limit", 0, "hard time limit for each session")
+	timeLimit       = timelimitArg{}
 	rateLimit       = ratelimitArg{rlzone.Must(rlzone.NewSmallest[netip.Addr](1*time.Minute, 20))}
 	rateLimit       = ratelimitArg{rlzone.Must(rlzone.NewSmallest[netip.Addr](1*time.Minute, 20))}
 )
 )
 
 
@@ -107,6 +146,7 @@ func init() {
 	flag.Var(&curves, "curves", "colon-separated list of curves to use")
 	flag.Var(&curves, "curves", "colon-separated list of curves to use")
 	flag.Var(&staleMode, "stale-mode", "which stale side of connection makes whole session stale (both, either, left, right)")
 	flag.Var(&staleMode, "stale-mode", "which stale side of connection makes whole session stale (both, either, left, right)")
 	flag.Var(&rateLimit, "rate-limit", "limit for incoming connections rate. Format: <limit>/<time duration> or empty string to disable")
 	flag.Var(&rateLimit, "rate-limit", "limit for incoming connections rate. Format: <limit>/<time duration> or empty string to disable")
+	flag.Var(&timeLimit, "time-limit", "limit for each session `duration`. Use single value X for fixed limit or range X-Y for randomized limit")
 }
 }
 
 
 func usage() {
 func usage() {
@@ -168,7 +208,7 @@ func cmdClient(bindAddress, remoteAddress string) int {
 		CipherSuites:   ciphersuites.Value,
 		CipherSuites:   ciphersuites.Value,
 		EllipticCurves: curves.Value,
 		EllipticCurves: curves.Value,
 		StaleMode:      staleMode,
 		StaleMode:      staleMode,
-		TimeLimit:      *timeLimit,
+		TimeLimitFunc:  util.TimeLimitFunc(timeLimit.low, timeLimit.high),
 		AllowFunc:      util.AllowByRatelimit(rateLimit.value),
 		AllowFunc:      util.AllowByRatelimit(rateLimit.value),
 	}
 	}
 
 
@@ -207,7 +247,7 @@ func cmdServer(bindAddress, remoteAddress string) int {
 		CipherSuites:    ciphersuites.Value,
 		CipherSuites:    ciphersuites.Value,
 		EllipticCurves:  curves.Value,
 		EllipticCurves:  curves.Value,
 		StaleMode:       staleMode,
 		StaleMode:       staleMode,
-		TimeLimit:       *timeLimit,
+		TimeLimitFunc:   util.TimeLimitFunc(timeLimit.low, timeLimit.high),
 		AllowFunc:       util.AllowByRatelimit(rateLimit.value),
 		AllowFunc:       util.AllowByRatelimit(rateLimit.value),
 	}
 	}
 
 

+ 4 - 1
server/config.go

@@ -21,7 +21,7 @@ type Config struct {
 	CipherSuites    ciphers.CipherList
 	CipherSuites    ciphers.CipherList
 	EllipticCurves  ciphers.CurveList
 	EllipticCurves  ciphers.CurveList
 	StaleMode       util.StaleMode
 	StaleMode       util.StaleMode
-	TimeLimit       time.Duration
+	TimeLimitFunc   func() time.Duration
 	AllowFunc       func(localAddr, remoteAddr net.Addr) bool
 	AllowFunc       func(localAddr, remoteAddr net.Addr) bool
 }
 }
 
 
@@ -41,6 +41,9 @@ func (cfg *Config) populateDefaults() *Config {
 	if cfg.EllipticCurves == nil {
 	if cfg.EllipticCurves == nil {
 		cfg.EllipticCurves = ciphers.DefaultCurveList
 		cfg.EllipticCurves = ciphers.DefaultCurveList
 	}
 	}
+	if cfg.TimeLimitFunc == nil {
+		cfg.TimeLimitFunc = util.FixedTimeLimitFunc(0)
+	}
 	if cfg.AllowFunc == nil {
 	if cfg.AllowFunc == nil {
 		cfg.AllowFunc = util.AllowAllFunc
 		cfg.AllowFunc = util.AllowAllFunc
 	}
 	}

+ 24 - 23
server/server.go

@@ -21,18 +21,18 @@ const (
 )
 )
 
 
 type Server struct {
 type Server struct {
-	listener    net.Listener
-	dtlsConfig  *dtls.Config
-	rAddr       string
-	psk         func([]byte) ([]byte, error)
-	timeout     time.Duration
-	idleTimeout time.Duration
-	baseCtx     context.Context
-	cancelCtx   func()
-	staleMode   util.StaleMode
-	workerWG    sync.WaitGroup
-	timeLimit   time.Duration
-	allowFunc   func(net.Addr, net.Addr) bool
+	listener      net.Listener
+	dtlsConfig    *dtls.Config
+	rAddr         string
+	psk           func([]byte) ([]byte, error)
+	timeout       time.Duration
+	idleTimeout   time.Duration
+	baseCtx       context.Context
+	cancelCtx     func()
+	staleMode     util.StaleMode
+	workerWG      sync.WaitGroup
+	timeLimitFunc func() time.Duration
+	allowFunc     func(net.Addr, net.Addr) bool
 }
 }
 
 
 func New(cfg *Config) (*Server, error) {
 func New(cfg *Config) (*Server, error) {
@@ -41,15 +41,15 @@ func New(cfg *Config) (*Server, error) {
 	baseCtx, cancelCtx := context.WithCancel(cfg.BaseContext)
 	baseCtx, cancelCtx := context.WithCancel(cfg.BaseContext)
 
 
 	srv := &Server{
 	srv := &Server{
-		rAddr:       cfg.RemoteAddress,
-		timeout:     cfg.Timeout,
-		psk:         cfg.PSKCallback,
-		idleTimeout: cfg.IdleTimeout,
-		baseCtx:     baseCtx,
-		cancelCtx:   cancelCtx,
-		staleMode:   cfg.StaleMode,
-		timeLimit:   cfg.TimeLimit,
-		allowFunc:   cfg.AllowFunc,
+		rAddr:         cfg.RemoteAddress,
+		timeout:       cfg.Timeout,
+		psk:           cfg.PSKCallback,
+		idleTimeout:   cfg.IdleTimeout,
+		baseCtx:       baseCtx,
+		cancelCtx:     cancelCtx,
+		staleMode:     cfg.StaleMode,
+		timeLimitFunc: cfg.TimeLimitFunc,
+		allowFunc:     cfg.AllowFunc,
 	}
 	}
 
 
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
@@ -128,8 +128,9 @@ func (srv *Server) serve(conn net.Conn) {
 	defer conn.Close()
 	defer conn.Close()
 
 
 	ctx := srv.baseCtx
 	ctx := srv.baseCtx
-	if srv.timeLimit != 0 {
-		newCtx, cancel := context.WithTimeout(ctx, srv.timeLimit)
+	tl := srv.timeLimitFunc()
+	if tl != 0 {
+		newCtx, cancel := context.WithTimeout(ctx, tl)
 		defer cancel()
 		defer cancel()
 		ctx = newCtx
 		ctx = newCtx
 	}
 	}

+ 27 - 2
util/util.go

@@ -2,10 +2,11 @@ package util
 
 
 import (
 import (
 	"context"
 	"context"
-	"crypto/rand"
+	crand "crypto/rand"
 	"encoding/hex"
 	"encoding/hex"
 	"fmt"
 	"fmt"
 	"log"
 	"log"
+	"math/rand"
 	"net"
 	"net"
 	"net/netip"
 	"net/netip"
 	"sync"
 	"sync"
@@ -16,7 +17,7 @@ import (
 
 
 func GenPSK(length int) ([]byte, error) {
 func GenPSK(length int) ([]byte, error) {
 	b := make([]byte, length)
 	b := make([]byte, length)
-	_, err := rand.Read(b)
+	_, err := crand.Read(b)
 	if err != nil {
 	if err != nil {
 		return nil, fmt.Errorf("random bytes generation failed: %w", err)
 		return nil, fmt.Errorf("random bytes generation failed: %w", err)
 	}
 	}
@@ -145,3 +146,27 @@ func AllowByRatelimit(z rlzone.Ratelimiter[netip.Addr]) func(net.Addr, net.Addr)
 		return z.Allow(key)
 		return z.Allow(key)
 	}
 	}
 }
 }
+
+func FixedTimeLimitFunc(d time.Duration) func() time.Duration {
+	return func() time.Duration {
+		return d
+	}
+}
+
+func TimeLimitFunc(low, high time.Duration) func() time.Duration {
+	if low > high {
+		return TimeLimitFunc(high, low)
+	}
+	if low == high {
+		return FixedTimeLimitFunc(low)
+	}
+
+	r := rand.New(rand.NewSource(time.Now().UnixNano()))
+	var mux sync.Mutex
+	delta := high - low
+	return func() time.Duration {
+		mux.Lock()
+		defer mux.Unlock()
+		return low + time.Duration(r.Int63n(int64(delta)))
+	}
+}