Browse Source

add time-limit option

Vladislav Yarmak 1 year ago
parent
commit
c13a52f014
5 changed files with 27 additions and 4 deletions
  1. 11 2
      client/client.go
  2. 1 0
      client/config.go
  3. 3 0
      cmd/dtlspipe/main.go
  4. 1 0
      server/config.go
  5. 11 2
      server/server.go

+ 11 - 2
client/client.go

@@ -30,6 +30,7 @@ type Client struct {
 	cancelCtx   func()
 	staleMode   util.StaleMode
 	workerWG    sync.WaitGroup
+	timeLimit   time.Duration
 }
 
 func New(cfg *Config) (*Client, error) {
@@ -45,6 +46,7 @@ func New(cfg *Config) (*Client, error) {
 		baseCtx:     baseCtx,
 		cancelCtx:   cancelCtx,
 		staleMode:   cfg.StaleMode,
+		timeLimit:   cfg.TimeLimit,
 	}
 
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
@@ -101,7 +103,14 @@ func (client *Client) serve(conn net.Conn) {
 	defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
 	defer conn.Close()
 
-	dialCtx, cancel := context.WithTimeout(client.baseCtx, client.timeout)
+	ctx := client.baseCtx
+	if client.timeLimit != 0 {
+		newCtx, cancel := context.WithTimeout(ctx, client.timeLimit)
+		defer cancel()
+		ctx = newCtx
+	}
+
+	dialCtx, cancel := context.WithTimeout(ctx, client.timeout)
 	defer cancel()
 	remoteConn, err := (&net.Dialer{}).DialContext(dialCtx, "udp", client.rAddr)
 	if err != nil {
@@ -116,7 +125,7 @@ func (client *Client) serve(conn net.Conn) {
 		return
 	}
 
-	util.PairConn(client.baseCtx, conn, remoteConn, client.idleTimeout, client.staleMode)
+	util.PairConn(ctx, conn, remoteConn, client.idleTimeout, client.staleMode)
 }
 
 func (client *Client) contextMaker() (context.Context, func()) {

+ 1 - 0
client/config.go

@@ -20,6 +20,7 @@ type Config struct {
 	CipherSuites   ciphers.CipherList
 	EllipticCurves ciphers.CurveList
 	StaleMode      util.StaleMode
+	TimeLimit      time.Duration
 }
 
 func (cfg *Config) populateDefaults() *Config {

+ 3 - 0
cmd/dtlspipe/main.go

@@ -72,6 +72,7 @@ var (
 	ciphersuites    = cipherlistArg{}
 	curves          = curvelistArg{}
 	staleMode       = util.EitherStale
+	timeLimit       = flag.Duration("time-limit", 0, "hard time limit for each session")
 )
 
 func init() {
@@ -139,6 +140,7 @@ func cmdClient(bindAddress, remoteAddress string) int {
 		CipherSuites:   ciphersuites.Value,
 		EllipticCurves: curves.Value,
 		StaleMode:      staleMode,
+		TimeLimit:      *timeLimit,
 	}
 
 	clt, err := client.New(&cfg)
@@ -176,6 +178,7 @@ func cmdServer(bindAddress, remoteAddress string) int {
 		CipherSuites:    ciphersuites.Value,
 		EllipticCurves:  curves.Value,
 		StaleMode:       staleMode,
+		TimeLimit:       *timeLimit,
 	}
 
 	srv, err := server.New(&cfg)

+ 1 - 0
server/config.go

@@ -20,6 +20,7 @@ type Config struct {
 	CipherSuites    ciphers.CipherList
 	EllipticCurves  ciphers.CurveList
 	StaleMode       util.StaleMode
+	TimeLimit       time.Duration
 }
 
 func (cfg *Config) populateDefaults() *Config {

+ 11 - 2
server/server.go

@@ -31,6 +31,7 @@ type Server struct {
 	cancelCtx   func()
 	staleMode   util.StaleMode
 	workerWG    sync.WaitGroup
+	timeLimit   time.Duration
 }
 
 func New(cfg *Config) (*Server, error) {
@@ -46,6 +47,7 @@ func New(cfg *Config) (*Server, error) {
 		baseCtx:     baseCtx,
 		cancelCtx:   cancelCtx,
 		staleMode:   cfg.StaleMode,
+		timeLimit:   cfg.TimeLimit,
 	}
 
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
@@ -119,7 +121,14 @@ func (srv *Server) serve(conn net.Conn) {
 	defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
 	defer conn.Close()
 
-	dialCtx, cancel := context.WithTimeout(srv.baseCtx, srv.timeout)
+	ctx := srv.baseCtx
+	if srv.timeLimit != 0 {
+		newCtx, cancel := context.WithTimeout(ctx, srv.timeLimit)
+		defer cancel()
+		ctx = newCtx
+	}
+
+	dialCtx, cancel := context.WithTimeout(ctx, srv.timeout)
 	defer cancel()
 	remoteConn, err := (&net.Dialer{}).DialContext(dialCtx, "udp", srv.rAddr)
 	if err != nil {
@@ -128,7 +137,7 @@ func (srv *Server) serve(conn net.Conn) {
 	}
 	defer remoteConn.Close()
 
-	util.PairConn(srv.baseCtx, conn, remoteConn, srv.idleTimeout, srv.staleMode)
+	util.PairConn(ctx, conn, remoteConn, srv.idleTimeout, srv.staleMode)
 }
 
 func (srv *Server) contextMaker() (context.Context, func()) {