|
@@ -10,10 +10,7 @@ import (
|
|
|
"time"
|
|
|
|
|
|
"github.com/SenseUnit/dtlspipe/util"
|
|
|
- "github.com/pion/dtls/v2"
|
|
|
- "github.com/pion/dtls/v2/pkg/protocol"
|
|
|
- "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
|
|
|
- "github.com/pion/transport/v2/udp"
|
|
|
+ "github.com/pion/dtls/v3"
|
|
|
)
|
|
|
|
|
|
const (
|
|
@@ -22,6 +19,7 @@ const (
|
|
|
|
|
|
type Server struct {
|
|
|
listener net.Listener
|
|
|
+ dialer *net.Dialer
|
|
|
dtlsConfig *dtls.Config
|
|
|
rAddr string
|
|
|
psk func([]byte) ([]byte, error)
|
|
@@ -32,7 +30,7 @@ type Server struct {
|
|
|
staleMode util.StaleMode
|
|
|
workerWG sync.WaitGroup
|
|
|
timeLimitFunc func() time.Duration
|
|
|
- allowFunc func(net.Addr, net.Addr) bool
|
|
|
+ allowFunc func(net.Addr) bool
|
|
|
}
|
|
|
|
|
|
func New(cfg *Config) (*Server, error) {
|
|
@@ -41,6 +39,7 @@ func New(cfg *Config) (*Server, error) {
|
|
|
baseCtx, cancelCtx := context.WithCancel(cfg.BaseContext)
|
|
|
|
|
|
srv := &Server{
|
|
|
+ dialer: new(net.Dialer),
|
|
|
rAddr: cfg.RemoteAddress,
|
|
|
timeout: cfg.Timeout,
|
|
|
psk: cfg.PSKCallback,
|
|
@@ -60,35 +59,24 @@ func New(cfg *Config) (*Server, error) {
|
|
|
|
|
|
srv.dtlsConfig = &dtls.Config{
|
|
|
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
|
|
|
- ConnectContextMaker: srv.contextMaker,
|
|
|
PSK: srv.psk,
|
|
|
MTU: cfg.MTU,
|
|
|
InsecureSkipVerifyHello: cfg.SkipHelloVerify,
|
|
|
CipherSuites: cfg.CipherSuites,
|
|
|
EllipticCurves: cfg.EllipticCurves,
|
|
|
- }
|
|
|
- lc := udp.ListenConfig{
|
|
|
- AcceptFilter: func(packet []byte) bool {
|
|
|
- pkts, err := recordlayer.UnpackDatagram(packet)
|
|
|
- if err != nil || len(pkts) < 1 {
|
|
|
- return false
|
|
|
- }
|
|
|
- h := &recordlayer.Header{}
|
|
|
- if err := h.Unmarshal(pkts[0]); err != nil {
|
|
|
- return false
|
|
|
+ OnConnectionAttempt: func(a net.Addr) error {
|
|
|
+ if !srv.allowFunc(a) {
|
|
|
+ return fmt.Errorf("address %s was not allowed by limiter", a.String())
|
|
|
}
|
|
|
- return h.ContentType == protocol.ContentTypeHandshake
|
|
|
+ return nil
|
|
|
},
|
|
|
- Backlog: Backlog,
|
|
|
}
|
|
|
- listener, err := lc.Listen("udp", net.UDPAddrFromAddrPort(lAddrPort))
|
|
|
+ srv.listener, err = dtls.Listen("udp", net.UDPAddrFromAddrPort(lAddrPort), srv.dtlsConfig)
|
|
|
if err != nil {
|
|
|
cancelCtx()
|
|
|
- return nil, fmt.Errorf("server listen failed: %w", err)
|
|
|
+ return nil, fmt.Errorf("can't initialize DTLS listener: %w", err)
|
|
|
}
|
|
|
|
|
|
- srv.listener = listener
|
|
|
-
|
|
|
go srv.listen()
|
|
|
|
|
|
return srv, nil
|
|
@@ -99,11 +87,7 @@ func (srv *Server) listen() {
|
|
|
for srv.baseCtx.Err() == nil {
|
|
|
conn, err := srv.listener.Accept()
|
|
|
if err != nil {
|
|
|
- log.Printf("conn accept failed: %v", err)
|
|
|
- continue
|
|
|
- }
|
|
|
-
|
|
|
- if !srv.allowFunc(conn.LocalAddr(), conn.RemoteAddr()) {
|
|
|
+ log.Printf("DTLS conn accept failed: %v", err)
|
|
|
continue
|
|
|
}
|
|
|
|
|
@@ -111,12 +95,6 @@ func (srv *Server) listen() {
|
|
|
go func(conn net.Conn) {
|
|
|
defer srv.workerWG.Done()
|
|
|
defer conn.Close()
|
|
|
- conn, err := dtls.Server(conn, srv.dtlsConfig)
|
|
|
- if err != nil {
|
|
|
- log.Printf("DTLS accept error: %v", err)
|
|
|
- return
|
|
|
- }
|
|
|
- defer conn.Close()
|
|
|
srv.serve(conn)
|
|
|
}(conn)
|
|
|
}
|
|
@@ -127,6 +105,20 @@ func (srv *Server) serve(conn net.Conn) {
|
|
|
defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
|
|
|
defer conn.Close()
|
|
|
|
|
|
+ if handshaker, ok := conn.(interface {
|
|
|
+ HandshakeContext(context.Context) error
|
|
|
+ }); ok {
|
|
|
+ err := func() error {
|
|
|
+ hsCtx, cancel := context.WithTimeout(srv.baseCtx, srv.timeout)
|
|
|
+ defer cancel()
|
|
|
+ return handshaker.HandshakeContext(hsCtx)
|
|
|
+ }()
|
|
|
+ if err != nil {
|
|
|
+ log.Printf("handshake %s <=> %s failed: %v", conn.LocalAddr(), conn.RemoteAddr(), err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
ctx := srv.baseCtx
|
|
|
tl := srv.timeLimitFunc()
|
|
|
if tl != 0 {
|
|
@@ -135,9 +127,11 @@ func (srv *Server) serve(conn net.Conn) {
|
|
|
ctx = newCtx
|
|
|
}
|
|
|
|
|
|
- dialCtx, cancel := context.WithTimeout(ctx, srv.timeout)
|
|
|
- defer cancel()
|
|
|
- remoteConn, err := (&net.Dialer{}).DialContext(dialCtx, "udp", srv.rAddr)
|
|
|
+ remoteConn, err := func() (net.Conn, error) {
|
|
|
+ dialCtx, cancel := context.WithTimeout(ctx, srv.timeout)
|
|
|
+ defer cancel()
|
|
|
+ return srv.dialer.DialContext(dialCtx, "udp", srv.rAddr)
|
|
|
+ }()
|
|
|
if err != nil {
|
|
|
log.Printf("remote dial failed: %v", err)
|
|
|
return
|
|
@@ -147,10 +141,6 @@ func (srv *Server) serve(conn net.Conn) {
|
|
|
util.PairConn(ctx, conn, remoteConn, srv.idleTimeout, srv.staleMode)
|
|
|
}
|
|
|
|
|
|
-func (srv *Server) contextMaker() (context.Context, func()) {
|
|
|
- return context.WithTimeout(srv.baseCtx, srv.timeout)
|
|
|
-}
|
|
|
-
|
|
|
func (srv *Server) Close() error {
|
|
|
srv.cancelCtx()
|
|
|
err := srv.listener.Close()
|