浏览代码

Merge pull request #10 from Snawoot/session_hard_time_limit

Session hard time limit
Snawoot 1 年之前
父节点
当前提交
cf593cb567
共有 6 个文件被更改,包括 30 次插入5 次删除
  1. 3 1
      README.md
  2. 11 2
      client/client.go
  3. 1 0
      client/config.go
  4. 3 0
      cmd/dtlspipe/main.go
  5. 1 0
      server/config.go
  6. 11 2
      server/server.go

+ 3 - 1
README.md

@@ -80,7 +80,7 @@ Options:
   -identity string
   -identity string
     	client identity sent to server
     	client identity sent to server
   -idle-time duration
   -idle-time duration
-    	max idle time for UDP session (default 1m30s)
+    	max idle time for UDP session (default 30s)
   -key-length uint
   -key-length uint
     	generate key with specified length (default 16)
     	generate key with specified length (default 16)
   -mtu int
   -mtu int
@@ -91,6 +91,8 @@ Options:
     	(server only) skip hello verify request. Useful to workaround DPI
     	(server only) skip hello verify request. Useful to workaround DPI
   -stale-mode value
   -stale-mode value
     	which stale side of connection makes whole session stale (both, either, left, right) (default either)
     	which stale side of connection makes whole session stale (both, either, left, right) (default either)
+  -time-limit duration
+    	hard time limit for each session
   -timeout duration
   -timeout duration
     	network operation timeout (default 10s)
     	network operation timeout (default 10s)
 ```
 ```

+ 11 - 2
client/client.go

@@ -30,6 +30,7 @@ type Client struct {
 	cancelCtx   func()
 	cancelCtx   func()
 	staleMode   util.StaleMode
 	staleMode   util.StaleMode
 	workerWG    sync.WaitGroup
 	workerWG    sync.WaitGroup
+	timeLimit   time.Duration
 }
 }
 
 
 func New(cfg *Config) (*Client, error) {
 func New(cfg *Config) (*Client, error) {
@@ -45,6 +46,7 @@ func New(cfg *Config) (*Client, error) {
 		baseCtx:     baseCtx,
 		baseCtx:     baseCtx,
 		cancelCtx:   cancelCtx,
 		cancelCtx:   cancelCtx,
 		staleMode:   cfg.StaleMode,
 		staleMode:   cfg.StaleMode,
+		timeLimit:   cfg.TimeLimit,
 	}
 	}
 
 
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
@@ -101,7 +103,14 @@ func (client *Client) serve(conn net.Conn) {
 	defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
 	defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
 	defer conn.Close()
 	defer conn.Close()
 
 
-	dialCtx, cancel := context.WithTimeout(client.baseCtx, client.timeout)
+	ctx := client.baseCtx
+	if client.timeLimit != 0 {
+		newCtx, cancel := context.WithTimeout(ctx, client.timeLimit)
+		defer cancel()
+		ctx = newCtx
+	}
+
+	dialCtx, cancel := context.WithTimeout(ctx, client.timeout)
 	defer cancel()
 	defer cancel()
 	remoteConn, err := (&net.Dialer{}).DialContext(dialCtx, "udp", client.rAddr)
 	remoteConn, err := (&net.Dialer{}).DialContext(dialCtx, "udp", client.rAddr)
 	if err != nil {
 	if err != nil {
@@ -116,7 +125,7 @@ func (client *Client) serve(conn net.Conn) {
 		return
 		return
 	}
 	}
 
 
-	util.PairConn(client.baseCtx, conn, remoteConn, client.idleTimeout, client.staleMode)
+	util.PairConn(ctx, conn, remoteConn, client.idleTimeout, client.staleMode)
 }
 }
 
 
 func (client *Client) contextMaker() (context.Context, func()) {
 func (client *Client) contextMaker() (context.Context, func()) {

+ 1 - 0
client/config.go

@@ -20,6 +20,7 @@ type Config struct {
 	CipherSuites   ciphers.CipherList
 	CipherSuites   ciphers.CipherList
 	EllipticCurves ciphers.CurveList
 	EllipticCurves ciphers.CurveList
 	StaleMode      util.StaleMode
 	StaleMode      util.StaleMode
+	TimeLimit      time.Duration
 }
 }
 
 
 func (cfg *Config) populateDefaults() *Config {
 func (cfg *Config) populateDefaults() *Config {

+ 3 - 0
cmd/dtlspipe/main.go

@@ -72,6 +72,7 @@ var (
 	ciphersuites    = cipherlistArg{}
 	ciphersuites    = cipherlistArg{}
 	curves          = curvelistArg{}
 	curves          = curvelistArg{}
 	staleMode       = util.EitherStale
 	staleMode       = util.EitherStale
+	timeLimit       = flag.Duration("time-limit", 0, "hard time limit for each session")
 )
 )
 
 
 func init() {
 func init() {
@@ -139,6 +140,7 @@ func cmdClient(bindAddress, remoteAddress string) int {
 		CipherSuites:   ciphersuites.Value,
 		CipherSuites:   ciphersuites.Value,
 		EllipticCurves: curves.Value,
 		EllipticCurves: curves.Value,
 		StaleMode:      staleMode,
 		StaleMode:      staleMode,
+		TimeLimit:      *timeLimit,
 	}
 	}
 
 
 	clt, err := client.New(&cfg)
 	clt, err := client.New(&cfg)
@@ -176,6 +178,7 @@ func cmdServer(bindAddress, remoteAddress string) int {
 		CipherSuites:    ciphersuites.Value,
 		CipherSuites:    ciphersuites.Value,
 		EllipticCurves:  curves.Value,
 		EllipticCurves:  curves.Value,
 		StaleMode:       staleMode,
 		StaleMode:       staleMode,
+		TimeLimit:       *timeLimit,
 	}
 	}
 
 
 	srv, err := server.New(&cfg)
 	srv, err := server.New(&cfg)

+ 1 - 0
server/config.go

@@ -20,6 +20,7 @@ type Config struct {
 	CipherSuites    ciphers.CipherList
 	CipherSuites    ciphers.CipherList
 	EllipticCurves  ciphers.CurveList
 	EllipticCurves  ciphers.CurveList
 	StaleMode       util.StaleMode
 	StaleMode       util.StaleMode
+	TimeLimit       time.Duration
 }
 }
 
 
 func (cfg *Config) populateDefaults() *Config {
 func (cfg *Config) populateDefaults() *Config {

+ 11 - 2
server/server.go

@@ -31,6 +31,7 @@ type Server struct {
 	cancelCtx   func()
 	cancelCtx   func()
 	staleMode   util.StaleMode
 	staleMode   util.StaleMode
 	workerWG    sync.WaitGroup
 	workerWG    sync.WaitGroup
+	timeLimit   time.Duration
 }
 }
 
 
 func New(cfg *Config) (*Server, error) {
 func New(cfg *Config) (*Server, error) {
@@ -46,6 +47,7 @@ func New(cfg *Config) (*Server, error) {
 		baseCtx:     baseCtx,
 		baseCtx:     baseCtx,
 		cancelCtx:   cancelCtx,
 		cancelCtx:   cancelCtx,
 		staleMode:   cfg.StaleMode,
 		staleMode:   cfg.StaleMode,
+		timeLimit:   cfg.TimeLimit,
 	}
 	}
 
 
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
@@ -119,7 +121,14 @@ func (srv *Server) serve(conn net.Conn) {
 	defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
 	defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
 	defer conn.Close()
 	defer conn.Close()
 
 
-	dialCtx, cancel := context.WithTimeout(srv.baseCtx, srv.timeout)
+	ctx := srv.baseCtx
+	if srv.timeLimit != 0 {
+		newCtx, cancel := context.WithTimeout(ctx, srv.timeLimit)
+		defer cancel()
+		ctx = newCtx
+	}
+
+	dialCtx, cancel := context.WithTimeout(ctx, srv.timeout)
 	defer cancel()
 	defer cancel()
 	remoteConn, err := (&net.Dialer{}).DialContext(dialCtx, "udp", srv.rAddr)
 	remoteConn, err := (&net.Dialer{}).DialContext(dialCtx, "udp", srv.rAddr)
 	if err != nil {
 	if err != nil {
@@ -128,7 +137,7 @@ func (srv *Server) serve(conn net.Conn) {
 	}
 	}
 	defer remoteConn.Close()
 	defer remoteConn.Close()
 
 
-	util.PairConn(srv.baseCtx, conn, remoteConn, srv.idleTimeout, srv.staleMode)
+	util.PairConn(ctx, conn, remoteConn, srv.idleTimeout, srv.staleMode)
 }
 }
 
 
 func (srv *Server) contextMaker() (context.Context, func()) {
 func (srv *Server) contextMaker() (context.Context, func()) {