server.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. package server
  2. import (
  3. "context"
  4. "fmt"
  5. "log"
  6. "net"
  7. "net/netip"
  8. "sync"
  9. "time"
  10. "github.com/SenseUnit/dtlspipe/util"
  11. "github.com/pion/dtls/v3"
  12. )
  13. const (
  14. Backlog = 1024
  15. )
  16. type Server struct {
  17. listener net.Listener
  18. dialer *net.Dialer
  19. dtlsConfig *dtls.Config
  20. rAddr string
  21. psk func([]byte) ([]byte, error)
  22. timeout time.Duration
  23. idleTimeout time.Duration
  24. baseCtx context.Context
  25. cancelCtx func()
  26. staleMode util.StaleMode
  27. workerWG sync.WaitGroup
  28. timeLimitFunc func() time.Duration
  29. allowFunc func(net.Addr) bool
  30. }
  31. func New(cfg *Config) (*Server, error) {
  32. cfg = cfg.populateDefaults()
  33. baseCtx, cancelCtx := context.WithCancel(cfg.BaseContext)
  34. srv := &Server{
  35. dialer: new(net.Dialer),
  36. rAddr: cfg.RemoteAddress,
  37. timeout: cfg.Timeout,
  38. psk: cfg.PSKCallback,
  39. idleTimeout: cfg.IdleTimeout,
  40. baseCtx: baseCtx,
  41. cancelCtx: cancelCtx,
  42. staleMode: cfg.StaleMode,
  43. timeLimitFunc: cfg.TimeLimitFunc,
  44. allowFunc: cfg.AllowFunc,
  45. }
  46. lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
  47. if err != nil {
  48. cancelCtx()
  49. return nil, fmt.Errorf("can't parse bind address: %w", err)
  50. }
  51. srv.dtlsConfig = &dtls.Config{
  52. ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
  53. PSK: srv.psk,
  54. MTU: cfg.MTU,
  55. InsecureSkipVerifyHello: cfg.SkipHelloVerify,
  56. CipherSuites: cfg.CipherSuites,
  57. EllipticCurves: cfg.EllipticCurves,
  58. OnConnectionAttempt: func(a net.Addr) error {
  59. if !srv.allowFunc(a) {
  60. return fmt.Errorf("address %s was not allowed by limiter", a.String())
  61. }
  62. return nil
  63. },
  64. }
  65. if cfg.EnableCID {
  66. srv.dtlsConfig.ConnectionIDGenerator = dtls.RandomCIDGenerator(8)
  67. }
  68. srv.listener, err = dtls.Listen("udp", net.UDPAddrFromAddrPort(lAddrPort), srv.dtlsConfig)
  69. if err != nil {
  70. cancelCtx()
  71. return nil, fmt.Errorf("can't initialize DTLS listener: %w", err)
  72. }
  73. go srv.listen()
  74. return srv, nil
  75. }
  76. func (srv *Server) listen() {
  77. defer srv.Close()
  78. for srv.baseCtx.Err() == nil {
  79. conn, err := srv.listener.Accept()
  80. if err != nil {
  81. log.Printf("DTLS conn accept failed: %v", err)
  82. continue
  83. }
  84. srv.workerWG.Add(1)
  85. go func(conn net.Conn) {
  86. defer srv.workerWG.Done()
  87. defer conn.Close()
  88. srv.serve(conn)
  89. }(conn)
  90. }
  91. }
  92. func (srv *Server) serve(conn net.Conn) {
  93. log.Printf("[+] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
  94. defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
  95. defer conn.Close()
  96. if handshaker, ok := conn.(interface {
  97. HandshakeContext(context.Context) error
  98. }); ok {
  99. err := func() error {
  100. hsCtx, cancel := context.WithTimeout(srv.baseCtx, srv.timeout)
  101. defer cancel()
  102. return handshaker.HandshakeContext(hsCtx)
  103. }()
  104. if err != nil {
  105. log.Printf("handshake %s <=> %s failed: %v", conn.LocalAddr(), conn.RemoteAddr(), err)
  106. return
  107. }
  108. }
  109. ctx := srv.baseCtx
  110. tl := srv.timeLimitFunc()
  111. if tl != 0 {
  112. newCtx, cancel := context.WithTimeout(ctx, tl)
  113. defer cancel()
  114. ctx = newCtx
  115. }
  116. remoteConn, err := func() (net.Conn, error) {
  117. dialCtx, cancel := context.WithTimeout(ctx, srv.timeout)
  118. defer cancel()
  119. return srv.dialer.DialContext(dialCtx, "udp", srv.rAddr)
  120. }()
  121. if err != nil {
  122. log.Printf("remote dial failed: %v", err)
  123. return
  124. }
  125. defer remoteConn.Close()
  126. util.PairConn(ctx, conn, remoteConn, srv.idleTimeout, srv.staleMode)
  127. }
  128. func (srv *Server) Close() error {
  129. srv.cancelCtx()
  130. err := srv.listener.Close()
  131. srv.workerWG.Wait()
  132. return err
  133. }