server.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. package server
  2. import (
  3. "context"
  4. "fmt"
  5. "log"
  6. "net"
  7. "net/netip"
  8. "sync"
  9. "time"
  10. "github.com/SenseUnit/dtlspipe/util"
  11. "github.com/pion/dtls/v3"
  12. )
  13. const (
  14. Backlog = 1024
  15. )
  16. type Server struct {
  17. listener net.Listener
  18. dialer *net.Dialer
  19. dtlsConfig *dtls.Config
  20. rAddr string
  21. psk func([]byte) ([]byte, error)
  22. timeout time.Duration
  23. idleTimeout time.Duration
  24. baseCtx context.Context
  25. cancelCtx func()
  26. staleMode util.StaleMode
  27. workerWG sync.WaitGroup
  28. timeLimitFunc func() time.Duration
  29. allowFunc func(net.Addr) bool
  30. }
  31. func New(cfg *Config) (*Server, error) {
  32. cfg = cfg.populateDefaults()
  33. baseCtx, cancelCtx := context.WithCancel(cfg.BaseContext)
  34. srv := &Server{
  35. dialer: new(net.Dialer),
  36. rAddr: cfg.RemoteAddress,
  37. timeout: cfg.Timeout,
  38. psk: cfg.PSKCallback,
  39. idleTimeout: cfg.IdleTimeout,
  40. baseCtx: baseCtx,
  41. cancelCtx: cancelCtx,
  42. staleMode: cfg.StaleMode,
  43. timeLimitFunc: cfg.TimeLimitFunc,
  44. allowFunc: cfg.AllowFunc,
  45. }
  46. lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
  47. if err != nil {
  48. cancelCtx()
  49. return nil, fmt.Errorf("can't parse bind address: %w", err)
  50. }
  51. srv.dtlsConfig = &dtls.Config{
  52. ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
  53. PSK: srv.psk,
  54. MTU: cfg.MTU,
  55. InsecureSkipVerifyHello: cfg.SkipHelloVerify,
  56. CipherSuites: cfg.CipherSuites,
  57. EllipticCurves: cfg.EllipticCurves,
  58. OnConnectionAttempt: func(a net.Addr) error {
  59. if !srv.allowFunc(a) {
  60. return fmt.Errorf("address %s was not allowed by limiter", a.String())
  61. }
  62. return nil
  63. },
  64. }
  65. srv.listener, err = dtls.Listen("udp", net.UDPAddrFromAddrPort(lAddrPort), srv.dtlsConfig)
  66. if err != nil {
  67. cancelCtx()
  68. return nil, fmt.Errorf("can't initialize DTLS listener: %w", err)
  69. }
  70. go srv.listen()
  71. return srv, nil
  72. }
  73. func (srv *Server) listen() {
  74. defer srv.Close()
  75. for srv.baseCtx.Err() == nil {
  76. conn, err := srv.listener.Accept()
  77. if err != nil {
  78. log.Printf("DTLS conn accept failed: %v", err)
  79. continue
  80. }
  81. srv.workerWG.Add(1)
  82. go func(conn net.Conn) {
  83. defer srv.workerWG.Done()
  84. defer conn.Close()
  85. srv.serve(conn)
  86. }(conn)
  87. }
  88. }
  89. func (srv *Server) serve(conn net.Conn) {
  90. log.Printf("[+] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
  91. defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
  92. defer conn.Close()
  93. if handshaker, ok := conn.(interface {
  94. HandshakeContext(context.Context) error
  95. }); ok {
  96. err := func() error {
  97. hsCtx, cancel := context.WithTimeout(srv.baseCtx, srv.timeout)
  98. defer cancel()
  99. return handshaker.HandshakeContext(hsCtx)
  100. }()
  101. if err != nil {
  102. log.Printf("handshake %s <=> %s failed: %v", conn.LocalAddr(), conn.RemoteAddr(), err)
  103. return
  104. }
  105. }
  106. ctx := srv.baseCtx
  107. tl := srv.timeLimitFunc()
  108. if tl != 0 {
  109. newCtx, cancel := context.WithTimeout(ctx, tl)
  110. defer cancel()
  111. ctx = newCtx
  112. }
  113. remoteConn, err := func() (net.Conn, error) {
  114. dialCtx, cancel := context.WithTimeout(ctx, srv.timeout)
  115. defer cancel()
  116. return srv.dialer.DialContext(dialCtx, "udp", srv.rAddr)
  117. }()
  118. if err != nil {
  119. log.Printf("remote dial failed: %v", err)
  120. return
  121. }
  122. defer remoteConn.Close()
  123. util.PairConn(ctx, conn, remoteConn, srv.idleTimeout, srv.staleMode)
  124. }
  125. func (srv *Server) Close() error {
  126. srv.cancelCtx()
  127. err := srv.listener.Close()
  128. srv.workerWG.Wait()
  129. return err
  130. }