Vladislav Yarmak 8 месяцев назад
Родитель
Сommit
7ead1fcf7c
10 измененных файлов с 78 добавлено и 132 удалено
  1. 1 1
      ciphers/ciphers.go
  2. 1 1
      ciphers/curves.go
  3. 14 13
      client/client.go
  4. 2 2
      client/config.go
  5. 0 2
      cmd/dtlspipe/main.go
  6. 5 5
      go.mod
  7. 12 61
      go.sum
  8. 1 1
      server/config.go
  9. 9 33
      server/server.go
  10. 33 13
      util/util.go

+ 1 - 1
ciphers/ciphers.go

@@ -4,7 +4,7 @@ import (
 	"fmt"
 	"strings"
 
-	"github.com/pion/dtls/v2"
+	"github.com/pion/dtls/v3"
 )
 
 type CipherList = []dtls.CipherSuiteID

+ 1 - 1
ciphers/curves.go

@@ -4,7 +4,7 @@ import (
 	"fmt"
 	"strings"
 
-	"github.com/pion/dtls/v2/pkg/crypto/elliptic"
+	"github.com/pion/dtls/v3/pkg/crypto/elliptic"
 )
 
 type CurveList = []elliptic.Curve

+ 14 - 13
client/client.go

@@ -10,8 +10,8 @@ import (
 	"time"
 
 	"github.com/SenseUnit/dtlspipe/util"
-	"github.com/pion/dtls/v2"
-	"github.com/pion/transport/v2/udp"
+	"github.com/pion/dtls/v3"
+	"github.com/pion/transport/v3/udp"
 )
 
 const (
@@ -22,7 +22,7 @@ const (
 type Client struct {
 	listener      net.Listener
 	dtlsConfig    *dtls.Config
-	remoteDialFn  func(context.Context, string) (net.Conn, error)
+	remoteDialFn  func(context.Context) (net.PacketConn, net.Addr, error)
 	psk           func([]byte) ([]byte, error)
 	timeout       time.Duration
 	idleTimeout   time.Duration
@@ -31,7 +31,7 @@ type Client struct {
 	staleMode     util.StaleMode
 	workerWG      sync.WaitGroup
 	timeLimitFunc func() time.Duration
-	allowFunc     func(net.Addr, net.Addr) bool
+	allowFunc     func(net.Addr) bool
 }
 
 func New(cfg *Config) (*Client, error) {
@@ -59,7 +59,6 @@ func New(cfg *Config) (*Client, error) {
 
 	client.dtlsConfig = &dtls.Config{
 		ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
-		ConnectContextMaker:  client.contextMaker,
 		PSK:                  client.psk,
 		PSKIdentityHint:      []byte(cfg.PSKIdentity),
 		MTU:                  cfg.MTU,
@@ -91,7 +90,7 @@ func (client *Client) listen() {
 			continue
 		}
 
-		if !client.allowFunc(conn.LocalAddr(), conn.RemoteAddr()) {
+		if !client.allowFunc(conn.RemoteAddr()) {
 			continue
 		}
 
@@ -119,24 +118,26 @@ func (client *Client) serve(conn net.Conn) {
 
 	dialCtx, cancel := context.WithTimeout(ctx, client.timeout)
 	defer cancel()
-	remoteConn, err := client.remoteDialFn(dialCtx, "udp")
+	remoteConn, remoteAddr, err := client.remoteDialFn(dialCtx)
 	if err != nil {
 		log.Printf("remote dial failed: %v", err)
 		return
 	}
 	defer remoteConn.Close()
 
-	remoteConn, err = dtls.ClientWithContext(dialCtx, remoteConn, client.dtlsConfig)
+	dtlsConn, err := dtls.Client(remoteConn, remoteAddr, client.dtlsConfig)
 	if err != nil {
-		log.Printf("DTLS handshake with remote server failed: %v", err)
+		log.Printf("DTLS connection with remote server failed: %v", err)
 		return
 	}
+	defer dtlsConn.Close()
 
-	util.PairConn(ctx, conn, remoteConn, client.idleTimeout, client.staleMode)
-}
+	if err := dtlsConn.HandshakeContext(dialCtx); err != nil {
+		log.Printf("DTLS handshake with remote server failed: %v", err)
+		return
+	}
 
-func (client *Client) contextMaker() (context.Context, func()) {
-	return context.WithTimeout(client.baseCtx, client.timeout)
+	util.PairConn(ctx, conn, dtlsConn, client.idleTimeout, client.staleMode)
 }
 
 func (client *Client) Close() error {

+ 2 - 2
client/config.go

@@ -11,7 +11,7 @@ import (
 
 type Config struct {
 	BindAddress    string
-	RemoteDialFunc func(ctx context.Context, network string) (net.Conn, error)
+	RemoteDialFunc func(ctx context.Context) (net.PacketConn, net.Addr, error)
 	Timeout        time.Duration
 	IdleTimeout    time.Duration
 	BaseContext    context.Context
@@ -22,7 +22,7 @@ type Config struct {
 	EllipticCurves ciphers.CurveList
 	StaleMode      util.StaleMode
 	TimeLimitFunc  func() time.Duration
-	AllowFunc      func(localAddr, remoteAddr net.Addr) bool
+	AllowFunc      func(net.Addr) bool
 }
 
 func (cfg *Config) populateDefaults() *Config {

+ 0 - 2
cmd/dtlspipe/main.go

@@ -241,7 +241,6 @@ func cmdClient(bindAddress, remoteAddress string) int {
 		BindAddress: bindAddress,
 		RemoteDialFunc: util.NewDynDialer(
 			addrgen.SingleEndpoint(remoteAddress).Endpoint,
-			nil,
 		).DialContext,
 		PSKCallback:    keystore.NewStaticKeystore(psk).PSKCallback,
 		PSKIdentity:    *identity,
@@ -295,7 +294,6 @@ func cmdHoppingClient(args []string) int {
 				log.Printf("selected new endpoint %s", ep)
 				return ep
 			},
-			nil,
 		).DialContext,
 		PSKCallback:    keystore.NewStaticKeystore(psk).PSKCallback,
 		PSKIdentity:    *identity,

+ 5 - 5
go.mod

@@ -4,13 +4,13 @@ go 1.21.1
 
 require (
 	github.com/Snawoot/rlzone v0.2.0
-	github.com/pion/dtls/v2 v2.2.10
-	github.com/pion/transport/v2 v2.2.4
+	github.com/pion/dtls/v3 v3.0.0
+	github.com/pion/transport/v3 v3.0.5
 )
 
 require (
 	github.com/pion/logging v0.2.2 // indirect
-	golang.org/x/crypto v0.18.0 // indirect
-	golang.org/x/net v0.20.0 // indirect
-	golang.org/x/sys v0.16.0 // indirect
+	golang.org/x/crypto v0.24.0 // indirect
+	golang.org/x/net v0.26.0 // indirect
+	golang.org/x/sys v0.22.0 // indirect
 )

+ 12 - 61
go.sum

@@ -1,71 +1,22 @@
 github.com/Snawoot/rlzone v0.2.0 h1:l/Gl8ncAdCjdalZlE7THD4xlwCnvn6jCF3hsiL4SmWQ=
 github.com/Snawoot/rlzone v0.2.0/go.mod h1:5yK8f9nJSOAPizq2LZ35arkortJhjFx1eO6ckOQCnwQ=
-github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
-github.com/pion/dtls/v2 v2.2.10 h1:u2Axk+FyIR1VFTPurktB+1zoEPGIW3bmyj3LEFrXjAA=
-github.com/pion/dtls/v2 v2.2.10/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE=
+github.com/pion/dtls/v3 v3.0.0 h1:m2hzwPkzqoBjVKXm5ymNuX01OAjht82TdFL6LoTzgi4=
+github.com/pion/dtls/v3 v3.0.0/go.mod h1:tiX7NaneB0wNoRaUpaMVP7igAlkMCTQkbpiY+OfeIi0=
 github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
 github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
-github.com/pion/transport/v2 v2.2.4 h1:41JJK6DZQYSeVLxILA2+F4ZkKb4Xd/tFJZRFZQ9QAlo=
-github.com/pion/transport/v2 v2.2.4/go.mod h1:q2U/tf9FEfnSBGSW6w5Qp5PFWRLRj3NjLhCCgpRK4p0=
+github.com/pion/transport/v3 v3.0.5 h1:ofVrcbPNqVPuKaTO5AMFnFuJ1ZX7ElYiWzC5PCf9YVQ=
+github.com/pion/transport/v3 v3.0.5/go.mod h1:HvJr2N/JwNJAfipsRleqwFoR3t/pWyHeZUs89v3+t5s=
 github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
-github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
-github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
-github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
-github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
-github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
-github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
-github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
-golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
-golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
-golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc=
-golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg=
-golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
-golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
-golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
-golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
-golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
-golang.org/x/net v0.14.0/go.mod h1:PpSgVXXLK0OxS0F31C1/tv6XNguvCrnXIDrFMspZIUI=
-golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
-golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
-golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
-golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
-golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
-golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
-golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU=
-golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
-golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
-golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
-golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
-golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
-golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
-golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
-golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
-golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
-golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
-gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
+github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI=
+golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM=
+golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
+golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
+golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
+golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

+ 1 - 1
server/config.go

@@ -22,7 +22,7 @@ type Config struct {
 	EllipticCurves  ciphers.CurveList
 	StaleMode       util.StaleMode
 	TimeLimitFunc   func() time.Duration
-	AllowFunc       func(localAddr, remoteAddr net.Addr) bool
+	AllowFunc       func(net.Addr) bool
 }
 
 func (cfg *Config) populateDefaults() *Config {

+ 9 - 33
server/server.go

@@ -10,10 +10,7 @@ import (
 	"time"
 
 	"github.com/SenseUnit/dtlspipe/util"
-	"github.com/pion/dtls/v2"
-	"github.com/pion/dtls/v2/pkg/protocol"
-	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
-	"github.com/pion/transport/v2/udp"
+	"github.com/pion/dtls/v3"
 )
 
 const (
@@ -32,7 +29,7 @@ type Server struct {
 	staleMode     util.StaleMode
 	workerWG      sync.WaitGroup
 	timeLimitFunc func() time.Duration
-	allowFunc     func(net.Addr, net.Addr) bool
+	allowFunc     func(net.Addr) bool
 }
 
 func New(cfg *Config) (*Server, error) {
@@ -60,35 +57,24 @@ func New(cfg *Config) (*Server, error) {
 
 	srv.dtlsConfig = &dtls.Config{
 		ExtendedMasterSecret:    dtls.RequireExtendedMasterSecret,
-		ConnectContextMaker:     srv.contextMaker,
 		PSK:                     srv.psk,
 		MTU:                     cfg.MTU,
 		InsecureSkipVerifyHello: cfg.SkipHelloVerify,
 		CipherSuites:            cfg.CipherSuites,
 		EllipticCurves:          cfg.EllipticCurves,
-	}
-	lc := udp.ListenConfig{
-		AcceptFilter: func(packet []byte) bool {
-			pkts, err := recordlayer.UnpackDatagram(packet)
-			if err != nil || len(pkts) < 1 {
-				return false
-			}
-			h := &recordlayer.Header{}
-			if err := h.Unmarshal(pkts[0]); err != nil {
-				return false
+		OnConnectionAttempt:     func(a net.Addr) error {
+			if !srv.allowFunc(a) {
+				return fmt.Errorf("address %s was not allowed by limiter", a.String())
 			}
-			return h.ContentType == protocol.ContentTypeHandshake
+			return nil
 		},
-		Backlog: Backlog,
 	}
-	listener, err := lc.Listen("udp", net.UDPAddrFromAddrPort(lAddrPort))
+	srv.listener, err = dtls.Listen("udp", net.UDPAddrFromAddrPort(lAddrPort), srv.dtlsConfig)
 	if err != nil {
 		cancelCtx()
-		return nil, fmt.Errorf("server listen failed: %w", err)
+		return nil, fmt.Errorf("can't initialize DTLS listener: %w", err)
 	}
 
-	srv.listener = listener
-
 	go srv.listen()
 
 	return srv, nil
@@ -99,11 +85,7 @@ func (srv *Server) listen() {
 	for srv.baseCtx.Err() == nil {
 		conn, err := srv.listener.Accept()
 		if err != nil {
-			log.Printf("conn accept failed: %v", err)
-			continue
-		}
-
-		if !srv.allowFunc(conn.LocalAddr(), conn.RemoteAddr()) {
+			log.Printf("DTLS conn accept failed: %v", err)
 			continue
 		}
 
@@ -111,12 +93,6 @@ func (srv *Server) listen() {
 		go func(conn net.Conn) {
 			defer srv.workerWG.Done()
 			defer conn.Close()
-			conn, err := dtls.Server(conn, srv.dtlsConfig)
-			if err != nil {
-				log.Printf("DTLS accept error: %v", err)
-				return
-			}
-			defer conn.Close()
 			srv.serve(conn)
 		}(conn)
 	}

+ 33 - 13
util/util.go

@@ -133,15 +133,15 @@ func NetAddrToNetipAddrPort(a net.Addr) netip.AddrPort {
 	return res
 }
 
-func AllowAllFunc(_, _ net.Addr) bool {
+func AllowAllFunc(_ net.Addr) bool {
 	return true
 }
 
-func AllowByRatelimit(z rlzone.Ratelimiter[netip.Addr]) func(net.Addr, net.Addr) bool {
+func AllowByRatelimit(z rlzone.Ratelimiter[netip.Addr]) func(net.Addr) bool {
 	if z == nil {
 		return AllowAllFunc
 	}
-	return func(_, remoteAddr net.Addr) bool {
+	return func(remoteAddr net.Addr) bool {
 		key := NetAddrToNetipAddrPort(remoteAddr).Addr()
 		return z.Allow(key)
 	}
@@ -172,20 +172,40 @@ func TimeLimitFunc(low, high time.Duration) func() time.Duration {
 }
 
 type DynDialer struct {
-	dial func(context.Context, string, string) (net.Conn, error)
-	ep   func() string
+	ep       func() string
+	resolver *net.Resolver
 }
 
-func NewDynDialer(ep func() string, dial func(context.Context, string, string) (net.Conn, error)) DynDialer {
-	if dial == nil {
-		dial = (&net.Dialer{}).DialContext
-	}
+func NewDynDialer(ep func() string) DynDialer {
 	return DynDialer{
-		ep:   ep,
-		dial: dial,
+		resolver: new(net.Resolver),
+		ep:       ep,
 	}
 }
 
-func (d DynDialer) DialContext(ctx context.Context, network string) (net.Conn, error) {
-	return d.dial(ctx, network, d.ep())
+func (d DynDialer) DialContext(ctx context.Context) (net.PacketConn, net.Addr, error) {
+	host, port, err := net.SplitHostPort(d.ep())
+	if err != nil {
+		return nil, nil, fmt.Errorf("unable to split host and port: %w", err)
+	}
+	addrs, err := d.resolver.LookupIPAddr(ctx, host)
+	if err != nil {
+		return nil, nil, fmt.Errorf("address lookup failed: %w", err)
+	}
+	if len(addrs) == 0 {
+		return nil, nil, fmt.Errorf("no addresses were resolved")
+	}
+	portNum, err := d.resolver.LookupPort(ctx, "udp", port)
+	if err != nil {
+		return nil, nil, fmt.Errorf("port lookup failed: %w", err)
+	}
+	pConn, err := net.ListenUDP("udp", nil)
+	if err != nil {
+		return nil, nil, fmt.Errorf("unable to open UDP socket: %w", err)
+	}
+	return pConn, &net.UDPAddr{
+		IP:   addrs[0].IP,
+		Port: portNum,
+		Zone: addrs[0].Zone,
+	}, nil
 }