Quellcode durchsuchen

Merge pull request #10 from Snawoot/session_hard_time_limit

Session hard time limit
Snawoot vor 2 Jahren
Ursprung
Commit
57285a25fc
6 geänderte Dateien mit 30 neuen und 5 gelöschten Zeilen
  1. 3 1
      README.md
  2. 11 2
      client/client.go
  3. 1 0
      client/config.go
  4. 3 0
      cmd/dtlspipe/main.go
  5. 1 0
      server/config.go
  6. 11 2
      server/server.go

+ 3 - 1
README.md

@@ -80,7 +80,7 @@ Options:
   -identity string
     	client identity sent to server
   -idle-time duration
-    	max idle time for UDP session (default 1m30s)
+    	max idle time for UDP session (default 30s)
   -key-length uint
     	generate key with specified length (default 16)
   -mtu int
@@ -91,6 +91,8 @@ Options:
     	(server only) skip hello verify request. Useful to workaround DPI
   -stale-mode value
     	which stale side of connection makes whole session stale (both, either, left, right) (default either)
+  -time-limit duration
+    	hard time limit for each session
   -timeout duration
     	network operation timeout (default 10s)
 ```

+ 11 - 2
client/client.go

@@ -30,6 +30,7 @@ type Client struct {
 	cancelCtx   func()
 	staleMode   util.StaleMode
 	workerWG    sync.WaitGroup
+	timeLimit   time.Duration
 }
 
 func New(cfg *Config) (*Client, error) {
@@ -45,6 +46,7 @@ func New(cfg *Config) (*Client, error) {
 		baseCtx:     baseCtx,
 		cancelCtx:   cancelCtx,
 		staleMode:   cfg.StaleMode,
+		timeLimit:   cfg.TimeLimit,
 	}
 
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
@@ -101,7 +103,14 @@ func (client *Client) serve(conn net.Conn) {
 	defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
 	defer conn.Close()
 
-	dialCtx, cancel := context.WithTimeout(client.baseCtx, client.timeout)
+	ctx := client.baseCtx
+	if client.timeLimit != 0 {
+		newCtx, cancel := context.WithTimeout(ctx, client.timeLimit)
+		defer cancel()
+		ctx = newCtx
+	}
+
+	dialCtx, cancel := context.WithTimeout(ctx, client.timeout)
 	defer cancel()
 	remoteConn, err := (&net.Dialer{}).DialContext(dialCtx, "udp", client.rAddr)
 	if err != nil {
@@ -116,7 +125,7 @@ func (client *Client) serve(conn net.Conn) {
 		return
 	}
 
-	util.PairConn(client.baseCtx, conn, remoteConn, client.idleTimeout, client.staleMode)
+	util.PairConn(ctx, conn, remoteConn, client.idleTimeout, client.staleMode)
 }
 
 func (client *Client) contextMaker() (context.Context, func()) {

+ 1 - 0
client/config.go

@@ -20,6 +20,7 @@ type Config struct {
 	CipherSuites   ciphers.CipherList
 	EllipticCurves ciphers.CurveList
 	StaleMode      util.StaleMode
+	TimeLimit      time.Duration
 }
 
 func (cfg *Config) populateDefaults() *Config {

+ 3 - 0
cmd/dtlspipe/main.go

@@ -72,6 +72,7 @@ var (
 	ciphersuites    = cipherlistArg{}
 	curves          = curvelistArg{}
 	staleMode       = util.EitherStale
+	timeLimit       = flag.Duration("time-limit", 0, "hard time limit for each session")
 )
 
 func init() {
@@ -139,6 +140,7 @@ func cmdClient(bindAddress, remoteAddress string) int {
 		CipherSuites:   ciphersuites.Value,
 		EllipticCurves: curves.Value,
 		StaleMode:      staleMode,
+		TimeLimit:      *timeLimit,
 	}
 
 	clt, err := client.New(&cfg)
@@ -176,6 +178,7 @@ func cmdServer(bindAddress, remoteAddress string) int {
 		CipherSuites:    ciphersuites.Value,
 		EllipticCurves:  curves.Value,
 		StaleMode:       staleMode,
+		TimeLimit:       *timeLimit,
 	}
 
 	srv, err := server.New(&cfg)

+ 1 - 0
server/config.go

@@ -20,6 +20,7 @@ type Config struct {
 	CipherSuites    ciphers.CipherList
 	EllipticCurves  ciphers.CurveList
 	StaleMode       util.StaleMode
+	TimeLimit       time.Duration
 }
 
 func (cfg *Config) populateDefaults() *Config {

+ 11 - 2
server/server.go

@@ -31,6 +31,7 @@ type Server struct {
 	cancelCtx   func()
 	staleMode   util.StaleMode
 	workerWG    sync.WaitGroup
+	timeLimit   time.Duration
 }
 
 func New(cfg *Config) (*Server, error) {
@@ -46,6 +47,7 @@ func New(cfg *Config) (*Server, error) {
 		baseCtx:     baseCtx,
 		cancelCtx:   cancelCtx,
 		staleMode:   cfg.StaleMode,
+		timeLimit:   cfg.TimeLimit,
 	}
 
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
@@ -119,7 +121,14 @@ func (srv *Server) serve(conn net.Conn) {
 	defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
 	defer conn.Close()
 
-	dialCtx, cancel := context.WithTimeout(srv.baseCtx, srv.timeout)
+	ctx := srv.baseCtx
+	if srv.timeLimit != 0 {
+		newCtx, cancel := context.WithTimeout(ctx, srv.timeLimit)
+		defer cancel()
+		ctx = newCtx
+	}
+
+	dialCtx, cancel := context.WithTimeout(ctx, srv.timeout)
 	defer cancel()
 	remoteConn, err := (&net.Dialer{}).DialContext(dialCtx, "udp", srv.rAddr)
 	if err != nil {
@@ -128,7 +137,7 @@ func (srv *Server) serve(conn net.Conn) {
 	}
 	defer remoteConn.Close()
 
-	util.PairConn(srv.baseCtx, conn, remoteConn, srv.idleTimeout, srv.staleMode)
+	util.PairConn(ctx, conn, remoteConn, srv.idleTimeout, srv.staleMode)
 }
 
 func (srv *Server) contextMaker() (context.Context, func()) {