server.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. package server
  2. import (
  3. "context"
  4. "fmt"
  5. "log"
  6. "net"
  7. "net/netip"
  8. "sync"
  9. "time"
  10. "github.com/SenseUnit/dtlspipe/util"
  11. "github.com/pion/dtls/v3"
  12. )
  // Backlog is an exported package constant with the value 1024.
// NOTE(review): it is not referenced in the visible part of this file;
// presumably it sizes a queue/backlog elsewhere in the package — confirm.
const Backlog = 1024
  16. type Server struct {
  17. listener net.Listener
  18. dtlsConfig *dtls.Config
  19. rAddr string
  20. psk func([]byte) ([]byte, error)
  21. timeout time.Duration
  22. idleTimeout time.Duration
  23. baseCtx context.Context
  24. cancelCtx func()
  25. staleMode util.StaleMode
  26. workerWG sync.WaitGroup
  27. timeLimitFunc func() time.Duration
  28. allowFunc func(net.Addr) bool
  29. }
  30. func New(cfg *Config) (*Server, error) {
  31. cfg = cfg.populateDefaults()
  32. baseCtx, cancelCtx := context.WithCancel(cfg.BaseContext)
  33. srv := &Server{
  34. rAddr: cfg.RemoteAddress,
  35. timeout: cfg.Timeout,
  36. psk: cfg.PSKCallback,
  37. idleTimeout: cfg.IdleTimeout,
  38. baseCtx: baseCtx,
  39. cancelCtx: cancelCtx,
  40. staleMode: cfg.StaleMode,
  41. timeLimitFunc: cfg.TimeLimitFunc,
  42. allowFunc: cfg.AllowFunc,
  43. }
  44. lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
  45. if err != nil {
  46. cancelCtx()
  47. return nil, fmt.Errorf("can't parse bind address: %w", err)
  48. }
  49. srv.dtlsConfig = &dtls.Config{
  50. ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
  51. PSK: srv.psk,
  52. MTU: cfg.MTU,
  53. InsecureSkipVerifyHello: cfg.SkipHelloVerify,
  54. CipherSuites: cfg.CipherSuites,
  55. EllipticCurves: cfg.EllipticCurves,
  56. OnConnectionAttempt: func(a net.Addr) error {
  57. if !srv.allowFunc(a) {
  58. return fmt.Errorf("address %s was not allowed by limiter", a.String())
  59. }
  60. return nil
  61. },
  62. }
  63. srv.listener, err = dtls.Listen("udp", net.UDPAddrFromAddrPort(lAddrPort), srv.dtlsConfig)
  64. if err != nil {
  65. cancelCtx()
  66. return nil, fmt.Errorf("can't initialize DTLS listener: %w", err)
  67. }
  68. go srv.listen()
  69. return srv, nil
  70. }
  71. func (srv *Server) listen() {
  72. defer srv.Close()
  73. for srv.baseCtx.Err() == nil {
  74. conn, err := srv.listener.Accept()
  75. if err != nil {
  76. log.Printf("DTLS conn accept failed: %v", err)
  77. continue
  78. }
  79. srv.workerWG.Add(1)
  80. go func(conn net.Conn) {
  81. defer srv.workerWG.Done()
  82. defer conn.Close()
  83. srv.serve(conn)
  84. }(conn)
  85. }
  86. }
  87. func (srv *Server) serve(conn net.Conn) {
  88. log.Printf("[+] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
  89. defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
  90. defer conn.Close()
  91. if handshaker, ok := conn.(interface {
  92. HandshakeContext(context.Context) error
  93. }); ok {
  94. err := func() error {
  95. hsCtx, cancel := context.WithTimeout(srv.baseCtx, srv.timeout)
  96. defer cancel()
  97. return handshaker.HandshakeContext(hsCtx)
  98. }()
  99. if err != nil {
  100. log.Printf("handshake %s <=> %s failed: %v", conn.LocalAddr(), conn.RemoteAddr(), err)
  101. return
  102. }
  103. }
  104. ctx := srv.baseCtx
  105. tl := srv.timeLimitFunc()
  106. if tl != 0 {
  107. newCtx, cancel := context.WithTimeout(ctx, tl)
  108. defer cancel()
  109. ctx = newCtx
  110. }
  111. dialCtx, cancel := context.WithTimeout(ctx, srv.timeout)
  112. defer cancel()
  113. remoteConn, err := (&net.Dialer{}).DialContext(dialCtx, "udp", srv.rAddr)
  114. if err != nil {
  115. log.Printf("remote dial failed: %v", err)
  116. return
  117. }
  118. defer remoteConn.Close()
  119. util.PairConn(ctx, conn, remoteConn, srv.idleTimeout, srv.staleMode)
  120. }
  121. func (srv *Server) Close() error {
  122. srv.cancelCtx()
  123. err := srv.listener.Close()
  124. srv.workerWG.Wait()
  125. return err
  126. }