소스 검색

random time limit

Vladislav Yarmak 1 년 전
부모
커밋
c26d1ca9b5
6개의 변경된 파일126개의 추가작업 그리고 53개의 파일을 삭제
  1. 24 23
      client/client.go
  2. 4 1
      client/config.go
  3. 43 3
      cmd/dtlspipe/main.go
  4. 4 1
      server/config.go
  5. 24 23
      server/server.go
  6. 27 2
      util/util.go

+ 24 - 23
client/client.go

@@ -20,18 +20,18 @@ const (
 )
 
 type Client struct {
-	listener    net.Listener
-	dtlsConfig  *dtls.Config
-	rAddr       string
-	psk         func([]byte) ([]byte, error)
-	timeout     time.Duration
-	idleTimeout time.Duration
-	baseCtx     context.Context
-	cancelCtx   func()
-	staleMode   util.StaleMode
-	workerWG    sync.WaitGroup
-	timeLimit   time.Duration
-	allowFunc   func(net.Addr, net.Addr) bool
+	listener      net.Listener
+	dtlsConfig    *dtls.Config
+	rAddr         string
+	psk           func([]byte) ([]byte, error)
+	timeout       time.Duration
+	idleTimeout   time.Duration
+	baseCtx       context.Context
+	cancelCtx     func()
+	staleMode     util.StaleMode
+	workerWG      sync.WaitGroup
+	timeLimitFunc func() time.Duration
+	allowFunc     func(net.Addr, net.Addr) bool
 }
 
 func New(cfg *Config) (*Client, error) {
@@ -40,15 +40,15 @@ func New(cfg *Config) (*Client, error) {
 	baseCtx, cancelCtx := context.WithCancel(cfg.BaseContext)
 
 	client := &Client{
-		rAddr:       cfg.RemoteAddress,
-		timeout:     cfg.Timeout,
-		psk:         cfg.PSKCallback,
-		idleTimeout: cfg.IdleTimeout,
-		baseCtx:     baseCtx,
-		cancelCtx:   cancelCtx,
-		staleMode:   cfg.StaleMode,
-		timeLimit:   cfg.TimeLimit,
-		allowFunc:   cfg.AllowFunc,
+		rAddr:         cfg.RemoteAddress,
+		timeout:       cfg.Timeout,
+		psk:           cfg.PSKCallback,
+		idleTimeout:   cfg.IdleTimeout,
+		baseCtx:       baseCtx,
+		cancelCtx:     cancelCtx,
+		staleMode:     cfg.StaleMode,
+		timeLimitFunc: cfg.TimeLimitFunc,
+		allowFunc:     cfg.AllowFunc,
 	}
 
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
@@ -110,8 +110,9 @@ func (client *Client) serve(conn net.Conn) {
 	defer conn.Close()
 
 	ctx := client.baseCtx
-	if client.timeLimit != 0 {
-		newCtx, cancel := context.WithTimeout(ctx, client.timeLimit)
+	tl := client.timeLimitFunc()
+	if tl != 0 {
+		newCtx, cancel := context.WithTimeout(ctx, tl)
 		defer cancel()
 		ctx = newCtx
 	}

+ 4 - 1
client/config.go

@@ -21,7 +21,7 @@ type Config struct {
 	CipherSuites   ciphers.CipherList
 	EllipticCurves ciphers.CurveList
 	StaleMode      util.StaleMode
-	TimeLimit      time.Duration
+	TimeLimitFunc  func() time.Duration
 	AllowFunc      func(localAddr, remoteAddr net.Addr) bool
 }
 
@@ -41,6 +41,9 @@ func (cfg *Config) populateDefaults() *Config {
 	if cfg.EllipticCurves == nil {
 		cfg.EllipticCurves = ciphers.DefaultCurveList
 	}
+	if cfg.TimeLimitFunc == nil {
+		cfg.TimeLimitFunc = util.FixedTimeLimitFunc(0)
+	}
 	if cfg.AllowFunc == nil {
 		cfg.AllowFunc = util.AllowAllFunc
 	}

+ 43 - 3
cmd/dtlspipe/main.go

@@ -2,6 +2,7 @@ package main
 
 import (
 	"context"
+	"errors"
 	"flag"
 	"fmt"
 	"log"
@@ -60,6 +61,44 @@ func (l *curvelistArg) Set(s string) error {
 	return nil
 }
 
+type timelimitArg struct {
+	low  time.Duration
+	high time.Duration
+}
+
+func (a *timelimitArg) String() string {
+	if a.low == a.high {
+		return a.low.String()
+	}
+	return fmt.Sprintf("%s-%s", a.low.String(), a.high.String())
+}
+
+func (a *timelimitArg) Set(s string) error {
+	parts := strings.SplitN(s, "-", 2)
+	switch len(parts) {
+	case 1:
+		dur, err := time.ParseDuration(s)
+		if err != nil {
+			return err
+		}
+		a.low, a.high = dur, dur
+		return nil
+	case 2:
+		durLow, err := time.ParseDuration(parts[0])
+		if err != nil {
+			return fmt.Errorf("first component parse failed: %w", err)
+		}
+		durHigh, err := time.ParseDuration(parts[1])
+		if err != nil {
+			return fmt.Errorf("second component parse failed: %w", err)
+		}
+		a.low, a.high = durLow, durHigh
+		return nil
+	default:
+		return errors.New("unexpected number of components")
+	}
+}
+
 type ratelimitArg struct {
 	value rlzone.Ratelimiter[netip.Addr]
 }
@@ -98,7 +137,7 @@ var (
 	ciphersuites    = cipherlistArg{}
 	curves          = curvelistArg{}
 	staleMode       = util.EitherStale
-	timeLimit       = flag.Duration("time-limit", 0, "hard time limit for each session")
+	timeLimit       = timelimitArg{}
 	rateLimit       = ratelimitArg{rlzone.Must(rlzone.NewSmallest[netip.Addr](1*time.Minute, 20))}
 )
 
@@ -107,6 +146,7 @@ func init() {
 	flag.Var(&curves, "curves", "colon-separated list of curves to use")
 	flag.Var(&staleMode, "stale-mode", "which stale side of connection makes whole session stale (both, either, left, right)")
 	flag.Var(&rateLimit, "rate-limit", "limit for incoming connections rate. Format: <limit>/<time duration> or empty string to disable")
+	flag.Var(&timeLimit, "time-limit", "limit for each session `duration`. Use single value X for fixed limit or range X-Y for randomized limit")
 }
 
 func usage() {
@@ -168,7 +208,7 @@ func cmdClient(bindAddress, remoteAddress string) int {
 		CipherSuites:   ciphersuites.Value,
 		EllipticCurves: curves.Value,
 		StaleMode:      staleMode,
-		TimeLimit:      *timeLimit,
+		TimeLimitFunc:  util.TimeLimitFunc(timeLimit.low, timeLimit.high),
 		AllowFunc:      util.AllowByRatelimit(rateLimit.value),
 	}
 
@@ -207,7 +247,7 @@ func cmdServer(bindAddress, remoteAddress string) int {
 		CipherSuites:    ciphersuites.Value,
 		EllipticCurves:  curves.Value,
 		StaleMode:       staleMode,
-		TimeLimit:       *timeLimit,
+		TimeLimitFunc:   util.TimeLimitFunc(timeLimit.low, timeLimit.high),
 		AllowFunc:       util.AllowByRatelimit(rateLimit.value),
 	}
 

+ 4 - 1
server/config.go

@@ -21,7 +21,7 @@ type Config struct {
 	CipherSuites    ciphers.CipherList
 	EllipticCurves  ciphers.CurveList
 	StaleMode       util.StaleMode
-	TimeLimit       time.Duration
+	TimeLimitFunc   func() time.Duration
 	AllowFunc       func(localAddr, remoteAddr net.Addr) bool
 }
 
@@ -41,6 +41,9 @@ func (cfg *Config) populateDefaults() *Config {
 	if cfg.EllipticCurves == nil {
 		cfg.EllipticCurves = ciphers.DefaultCurveList
 	}
+	if cfg.TimeLimitFunc == nil {
+		cfg.TimeLimitFunc = util.FixedTimeLimitFunc(0)
+	}
 	if cfg.AllowFunc == nil {
 		cfg.AllowFunc = util.AllowAllFunc
 	}

+ 24 - 23
server/server.go

@@ -21,18 +21,18 @@ const (
 )
 
 type Server struct {
-	listener    net.Listener
-	dtlsConfig  *dtls.Config
-	rAddr       string
-	psk         func([]byte) ([]byte, error)
-	timeout     time.Duration
-	idleTimeout time.Duration
-	baseCtx     context.Context
-	cancelCtx   func()
-	staleMode   util.StaleMode
-	workerWG    sync.WaitGroup
-	timeLimit   time.Duration
-	allowFunc   func(net.Addr, net.Addr) bool
+	listener      net.Listener
+	dtlsConfig    *dtls.Config
+	rAddr         string
+	psk           func([]byte) ([]byte, error)
+	timeout       time.Duration
+	idleTimeout   time.Duration
+	baseCtx       context.Context
+	cancelCtx     func()
+	staleMode     util.StaleMode
+	workerWG      sync.WaitGroup
+	timeLimitFunc func() time.Duration
+	allowFunc     func(net.Addr, net.Addr) bool
 }
 
 func New(cfg *Config) (*Server, error) {
@@ -41,15 +41,15 @@ func New(cfg *Config) (*Server, error) {
 	baseCtx, cancelCtx := context.WithCancel(cfg.BaseContext)
 
 	srv := &Server{
-		rAddr:       cfg.RemoteAddress,
-		timeout:     cfg.Timeout,
-		psk:         cfg.PSKCallback,
-		idleTimeout: cfg.IdleTimeout,
-		baseCtx:     baseCtx,
-		cancelCtx:   cancelCtx,
-		staleMode:   cfg.StaleMode,
-		timeLimit:   cfg.TimeLimit,
-		allowFunc:   cfg.AllowFunc,
+		rAddr:         cfg.RemoteAddress,
+		timeout:       cfg.Timeout,
+		psk:           cfg.PSKCallback,
+		idleTimeout:   cfg.IdleTimeout,
+		baseCtx:       baseCtx,
+		cancelCtx:     cancelCtx,
+		staleMode:     cfg.StaleMode,
+		timeLimitFunc: cfg.TimeLimitFunc,
+		allowFunc:     cfg.AllowFunc,
 	}
 
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
@@ -128,8 +128,9 @@ func (srv *Server) serve(conn net.Conn) {
 	defer conn.Close()
 
 	ctx := srv.baseCtx
-	if srv.timeLimit != 0 {
-		newCtx, cancel := context.WithTimeout(ctx, srv.timeLimit)
+	tl := srv.timeLimitFunc()
+	if tl != 0 {
+		newCtx, cancel := context.WithTimeout(ctx, tl)
 		defer cancel()
 		ctx = newCtx
 	}

+ 27 - 2
util/util.go

@@ -2,10 +2,11 @@ package util
 
 import (
 	"context"
-	"crypto/rand"
+	crand "crypto/rand"
 	"encoding/hex"
 	"fmt"
 	"log"
+	"math/rand"
 	"net"
 	"net/netip"
 	"sync"
@@ -16,7 +17,7 @@ import (
 
 func GenPSK(length int) ([]byte, error) {
 	b := make([]byte, length)
-	_, err := rand.Read(b)
+	_, err := crand.Read(b)
 	if err != nil {
 		return nil, fmt.Errorf("random bytes generation failed: %w", err)
 	}
@@ -145,3 +146,27 @@ func AllowByRatelimit(z rlzone.Ratelimiter[netip.Addr]) func(net.Addr, net.Addr)
 		return z.Allow(key)
 	}
 }
+
+func FixedTimeLimitFunc(d time.Duration) func() time.Duration {
+	return func() time.Duration {
+		return d
+	}
+}
+
+func TimeLimitFunc(low, high time.Duration) func() time.Duration {
+	if low > high {
+		return TimeLimitFunc(high, low)
+	}
+	if low == high {
+		return FixedTimeLimitFunc(low)
+	}
+
+	r := rand.New(rand.NewSource(time.Now().UnixNano()))
+	var mux sync.Mutex
+	delta := high - low
+	return func() time.Duration {
+		mux.Lock()
+		defer mux.Unlock()
+		return low + time.Duration(r.Int63n(int64(delta)))
+	}
+}