server.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. package server
  2. import (
  3. "context"
  4. "fmt"
  5. "log"
  6. "net"
  7. "net/netip"
  8. "sync"
  9. "time"
  10. "github.com/SenseUnit/dtlspipe/util"
  11. "github.com/pion/dtls/v3"
  12. )
  13. const (
  14. Backlog = 1024
  15. )
  16. type Server struct {
  17. listener net.Listener
  18. dtlsConfig *dtls.Config
  19. rAddr string
  20. psk func([]byte) ([]byte, error)
  21. timeout time.Duration
  22. idleTimeout time.Duration
  23. baseCtx context.Context
  24. cancelCtx func()
  25. staleMode util.StaleMode
  26. workerWG sync.WaitGroup
  27. timeLimitFunc func() time.Duration
  28. allowFunc func(net.Addr) bool
  29. }
  30. func New(cfg *Config) (*Server, error) {
  31. cfg = cfg.populateDefaults()
  32. baseCtx, cancelCtx := context.WithCancel(cfg.BaseContext)
  33. srv := &Server{
  34. rAddr: cfg.RemoteAddress,
  35. timeout: cfg.Timeout,
  36. psk: cfg.PSKCallback,
  37. idleTimeout: cfg.IdleTimeout,
  38. baseCtx: baseCtx,
  39. cancelCtx: cancelCtx,
  40. staleMode: cfg.StaleMode,
  41. timeLimitFunc: cfg.TimeLimitFunc,
  42. allowFunc: cfg.AllowFunc,
  43. }
  44. lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
  45. if err != nil {
  46. cancelCtx()
  47. return nil, fmt.Errorf("can't parse bind address: %w", err)
  48. }
  49. srv.dtlsConfig = &dtls.Config{
  50. ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
  51. PSK: srv.psk,
  52. MTU: cfg.MTU,
  53. InsecureSkipVerifyHello: cfg.SkipHelloVerify,
  54. CipherSuites: cfg.CipherSuites,
  55. EllipticCurves: cfg.EllipticCurves,
  56. OnConnectionAttempt: func(a net.Addr) error {
  57. if !srv.allowFunc(a) {
  58. return fmt.Errorf("address %s was not allowed by limiter", a.String())
  59. }
  60. return nil
  61. },
  62. }
  63. srv.listener, err = dtls.Listen("udp", net.UDPAddrFromAddrPort(lAddrPort), srv.dtlsConfig)
  64. if err != nil {
  65. cancelCtx()
  66. return nil, fmt.Errorf("can't initialize DTLS listener: %w", err)
  67. }
  68. go srv.listen()
  69. return srv, nil
  70. }
  71. func (srv *Server) listen() {
  72. defer srv.Close()
  73. for srv.baseCtx.Err() == nil {
  74. conn, err := srv.listener.Accept()
  75. if err != nil {
  76. log.Printf("DTLS conn accept failed: %v", err)
  77. continue
  78. }
  79. srv.workerWG.Add(1)
  80. go func(conn net.Conn) {
  81. defer srv.workerWG.Done()
  82. defer conn.Close()
  83. srv.serve(conn)
  84. }(conn)
  85. }
  86. }
  87. func (srv *Server) serve(conn net.Conn) {
  88. log.Printf("[+] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
  89. defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
  90. defer conn.Close()
  91. ctx := srv.baseCtx
  92. tl := srv.timeLimitFunc()
  93. if tl != 0 {
  94. newCtx, cancel := context.WithTimeout(ctx, tl)
  95. defer cancel()
  96. ctx = newCtx
  97. }
  98. dialCtx, cancel := context.WithTimeout(ctx, srv.timeout)
  99. defer cancel()
  100. remoteConn, err := (&net.Dialer{}).DialContext(dialCtx, "udp", srv.rAddr)
  101. if err != nil {
  102. log.Printf("remote dial failed: %v", err)
  103. return
  104. }
  105. defer remoteConn.Close()
  106. util.PairConn(ctx, conn, remoteConn, srv.idleTimeout, srv.staleMode)
  107. }
  108. func (srv *Server) contextMaker() (context.Context, func()) {
  109. return context.WithTimeout(srv.baseCtx, srv.timeout)
  110. }
  111. func (srv *Server) Close() error {
  112. srv.cancelCtx()
  113. err := srv.listener.Close()
  114. srv.workerWG.Wait()
  115. return err
  116. }