Browse Source

Merge pull request #9 from Snawoot/graceful_shutdown

More graceful shutdown
Snawoot 1 year ago
parent
commit
5ce3863e1f
3 changed files with 29 additions and 5 deletions
  1. 8 2
      client/client.go
  2. 8 2
      server/server.go
  3. 13 1
      util/util.go

+ 8 - 2
client/client.go

@@ -6,6 +6,7 @@ import (
 	"log"
 	"net"
 	"net/netip"
+	"sync"
 	"time"
 
 	"github.com/Snawoot/dtlspipe/util"
@@ -28,6 +29,7 @@ type Client struct {
 	baseCtx     context.Context
 	cancelCtx   func()
 	staleMode   util.StaleMode
+	workerWG    sync.WaitGroup
 }
 
 func New(cfg *Config) (*Client, error) {
@@ -85,7 +87,9 @@ func (client *Client) listen() {
 			continue
 		}
 
+		client.workerWG.Add(1)
 		go func(conn net.Conn) {
+			defer client.workerWG.Done()
 			defer conn.Close()
 			client.serve(conn)
 		}(conn)
@@ -112,7 +116,7 @@ func (client *Client) serve(conn net.Conn) {
 		return
 	}
 
-	util.PairConn(conn, remoteConn, client.idleTimeout, client.staleMode)
+	util.PairConn(client.baseCtx, conn, remoteConn, client.idleTimeout, client.staleMode)
 }
 
 func (client *Client) contextMaker() (context.Context, func()) {
@@ -121,5 +125,7 @@ func (client *Client) contextMaker() (context.Context, func()) {
 
 func (client *Client) Close() error {
 	client.cancelCtx()
-	return client.listener.Close()
+	err := client.listener.Close()
+	client.workerWG.Wait()
+	return err
 }

+ 8 - 2
server/server.go

@@ -6,6 +6,7 @@ import (
 	"log"
 	"net"
 	"net/netip"
+	"sync"
 	"time"
 
 	"github.com/Snawoot/dtlspipe/util"
@@ -29,6 +30,7 @@ type Server struct {
 	baseCtx     context.Context
 	cancelCtx   func()
 	staleMode   util.StaleMode
+	workerWG    sync.WaitGroup
 }
 
 func New(cfg *Config) (*Server, error) {
@@ -97,7 +99,9 @@ func (srv *Server) listen() {
 			continue
 		}
 
+		srv.workerWG.Add(1)
 		go func(conn net.Conn) {
+			defer srv.workerWG.Done()
 			defer conn.Close()
 			conn, err := dtls.Server(conn, srv.dtlsConfig)
 			if err != nil {
@@ -124,7 +128,7 @@ func (srv *Server) serve(conn net.Conn) {
 	}
 	defer remoteConn.Close()
 
-	util.PairConn(conn, remoteConn, srv.idleTimeout, srv.staleMode)
+	util.PairConn(srv.baseCtx, conn, remoteConn, srv.idleTimeout, srv.staleMode)
 }
 
 func (srv *Server) contextMaker() (context.Context, func()) {
@@ -133,5 +137,7 @@ func (srv *Server) contextMaker() (context.Context, func()) {
 
 func (srv *Server) Close() error {
 	srv.cancelCtx()
-	return srv.listener.Close()
+	err := srv.listener.Close()
+	srv.workerWG.Wait()
+	return err
 }

+ 13 - 1
util/util.go

@@ -1,6 +1,7 @@
 package util
 
 import (
+	"context"
 	"crypto/rand"
 	"encoding/hex"
 	"fmt"
@@ -55,10 +56,21 @@ const (
 	MaxPktBuf = 65536
 )
 
-func PairConn(left, right net.Conn, idleTimeout time.Duration, staleMode StaleMode) {
+func PairConn(ctx context.Context, left, right net.Conn, idleTimeout time.Duration, staleMode StaleMode) {
 	var wg sync.WaitGroup
 	tracker := newTracker(staleMode)
 
+	copyDone := make(chan struct{})
+	go func() {
+		select {
+		case <-ctx.Done():
+			left.Close()
+			right.Close()
+		case <-copyDone:
+		}
+	}()
+	defer close(copyDone)
+
 	copier := func(dst, src net.Conn, label bool) {
 		defer wg.Done()
 		defer dst.Close()