server.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. package server
  import (
	"context"
	"errors"
	"fmt"
	"log"
	"net"
	"net/netip"
	"sync"
	"sync/atomic"
	"time"

	"github.com/pion/dtls/v2"
	"github.com/pion/dtls/v2/pkg/protocol"
	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
	"github.com/pion/transport/v2/udp"
  )
  // MaxPktBuf is the size, in bytes, of the buffer each relay direction in
  // serve reads datagrams into, i.e. the largest payload forwarded per read.
  const MaxPktBuf = 4096
  19. type Server struct {
  20. listener net.Listener
  21. dtlsConfig *dtls.Config
  22. rAddr string
  23. psk func([]byte) ([]byte, error)
  24. timeout time.Duration
  25. idleTimeout time.Duration
  26. baseCtx context.Context
  27. cancelCtx func()
  28. }
  29. func New(cfg *Config) (*Server, error) {
  30. cfg = cfg.populateDefaults()
  31. baseCtx, cancelCtx := context.WithCancel(cfg.BaseContext)
  32. srv := &Server{
  33. rAddr: cfg.RemoteAddress,
  34. timeout: cfg.Timeout,
  35. psk: cfg.PSKCallback,
  36. idleTimeout: cfg.IdleTimeout,
  37. baseCtx: baseCtx,
  38. cancelCtx: cancelCtx,
  39. }
  40. lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
  41. if err != nil {
  42. cancelCtx()
  43. return nil, fmt.Errorf("can't parse bind address: %w", err)
  44. }
  45. srv.dtlsConfig = &dtls.Config{
  46. CipherSuites: []dtls.CipherSuiteID{
  47. dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256,
  48. dtls.TLS_PSK_WITH_AES_128_CCM,
  49. dtls.TLS_PSK_WITH_AES_128_CCM_8,
  50. dtls.TLS_PSK_WITH_AES_256_CCM_8,
  51. dtls.TLS_PSK_WITH_AES_128_GCM_SHA256,
  52. dtls.TLS_PSK_WITH_AES_128_CBC_SHA256,
  53. },
  54. ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
  55. ConnectContextMaker: srv.contextMaker,
  56. PSK: srv.psk,
  57. }
  58. lc := udp.ListenConfig{
  59. AcceptFilter: func(packet []byte) bool {
  60. pkts, err := recordlayer.UnpackDatagram(packet)
  61. if err != nil || len(pkts) < 1 {
  62. return false
  63. }
  64. h := &recordlayer.Header{}
  65. if err := h.Unmarshal(pkts[0]); err != nil {
  66. return false
  67. }
  68. return h.ContentType == protocol.ContentTypeHandshake
  69. },
  70. }
  71. listener, err := lc.Listen("udp", net.UDPAddrFromAddrPort(lAddrPort))
  72. if err != nil {
  73. cancelCtx()
  74. return nil, fmt.Errorf("server listen failed: %w", err)
  75. }
  76. srv.listener = listener
  77. go srv.listen()
  78. return srv, nil
  79. }
  80. func (srv *Server) listen() {
  81. defer srv.Close()
  82. for srv.baseCtx.Err() == nil {
  83. conn, err := srv.listener.Accept()
  84. if err != nil {
  85. log.Printf("conn accept failed: %v", err)
  86. continue
  87. }
  88. go func(conn net.Conn) {
  89. conn, err := dtls.Server(conn, srv.dtlsConfig)
  90. if err != nil {
  91. log.Printf("DTLS accept error: %v", err)
  92. return
  93. }
  94. srv.serve(conn)
  95. }(conn)
  96. }
  97. }
  98. func (srv *Server) serve(conn net.Conn) {
  99. log.Printf("[+] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
  100. defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
  101. defer conn.Close()
  102. dialCtx, cancel := context.WithTimeout(srv.baseCtx, srv.timeout)
  103. defer cancel()
  104. remoteConn, err := (&net.Dialer{}).DialContext(dialCtx, "udp", srv.rAddr)
  105. if err != nil {
  106. log.Printf("remote dial failed: %v", err)
  107. return
  108. }
  109. defer remoteConn.Close()
  110. var lsn atomic.Int32
  111. var wg sync.WaitGroup
  112. copier := func(dst, src net.Conn) {
  113. defer wg.Done()
  114. defer dst.Close()
  115. buf := make([]byte, MaxPktBuf)
  116. for {
  117. oldLSN := lsn.Load()
  118. if err := src.SetReadDeadline(time.Now().Add(srv.idleTimeout)); err != nil {
  119. log.Println("can't update deadline for connection: %v", err)
  120. break
  121. }
  122. n, err := src.Read(buf)
  123. if err != nil {
  124. if isTimeout(err) {
  125. // hit read deadline
  126. if oldLSN != lsn.Load() {
  127. // not stale conn
  128. continue
  129. } else {
  130. log.Printf("dropping stale connection %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
  131. }
  132. } else {
  133. // any other error
  134. log.Printf("read from %s error: %v", src.RemoteAddr(), err)
  135. }
  136. break
  137. }
  138. lsn.Add(1)
  139. _, err = dst.Write(buf[:n])
  140. if err != nil {
  141. log.Printf("write to %s error: %v", dst.RemoteAddr(), err)
  142. break
  143. }
  144. }
  145. }
  146. wg.Add(2)
  147. go copier(conn, remoteConn)
  148. go copier(remoteConn, conn)
  149. wg.Wait()
  150. }
  151. func (srv *Server) contextMaker() (context.Context, func()) {
  152. return context.WithTimeout(srv.baseCtx, srv.timeout)
  153. }
  154. func (srv *Server) Close() error {
  155. srv.cancelCtx()
  156. return srv.listener.Close()
  157. }
  // isTimeout reports whether the first error in err's chain that has a
  // Timeout() bool method reports a timeout. This is how expired read
  // deadlines surface from net.Conn implementations. Unlike a plain type
  // assertion, it also recognises timeouts wrapped with fmt.Errorf("%w").
  // A nil err is not a timeout.
  func isTimeout(err error) bool {
	var te interface{ Timeout() bool }
	return errors.As(err, &te) && te.Timeout()
  }