|
@@ -6,6 +6,7 @@ import (
|
|
|
"log"
|
|
|
"net"
|
|
|
"net/netip"
|
|
|
+ "sync"
|
|
|
"time"
|
|
|
|
|
|
"github.com/Snawoot/dtlspipe/util"
|
|
@@ -29,6 +30,7 @@ type Server struct {
|
|
|
baseCtx context.Context
|
|
|
cancelCtx func()
|
|
|
staleMode util.StaleMode
|
|
|
+ workerWG sync.WaitGroup
|
|
|
}
|
|
|
|
|
|
func New(cfg *Config) (*Server, error) {
|
|
@@ -97,7 +99,9 @@ func (srv *Server) listen() {
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
+ srv.workerWG.Add(1)
|
|
|
go func(conn net.Conn) {
|
|
|
+ defer srv.workerWG.Done()
|
|
|
defer conn.Close()
|
|
|
conn, err := dtls.Server(conn, srv.dtlsConfig)
|
|
|
if err != nil {
|
|
@@ -124,7 +128,7 @@ func (srv *Server) serve(conn net.Conn) {
|
|
|
}
|
|
|
defer remoteConn.Close()
|
|
|
|
|
|
- util.PairConn(conn, remoteConn, srv.idleTimeout, srv.staleMode)
|
|
|
+ util.PairConn(srv.baseCtx, conn, remoteConn, srv.idleTimeout, srv.staleMode)
|
|
|
}
|
|
|
|
|
|
func (srv *Server) contextMaker() (context.Context, func()) {
|
|
@@ -133,5 +137,7 @@ func (srv *Server) contextMaker() (context.Context, func()) {
|
|
|
|
|
|
func (srv *Server) Close() error {
|
|
|
srv.cancelCtx()
|
|
|
- return srv.listener.Close()
|
|
|
+ err := srv.listener.Close()
|
|
|
+ srv.workerWG.Wait()
|
|
|
+ return err
|
|
|
}
|