1
0

client.go 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. package client
  2. import (
  3. "context"
  4. "fmt"
  5. "log"
  6. "net"
  7. "net/netip"
  8. "sync"
  9. "time"
  10. "github.com/SenseUnit/dtlspipe/util"
  11. "github.com/pion/dtls/v3"
  12. "github.com/pion/transport/v3/udp"
  13. )
  const (
	// MaxPktBuf is 64 KiB. Judging by its name it is presumably the packet
	// buffer size used elsewhere in the package — NOTE(review): it is not
	// referenced by the code visible in this file; confirm its users.
	MaxPktBuf = 64 << 10

	// Backlog is the accept backlog given to udp.ListenConfig when the
	// client starts listening.
	Backlog = 1 << 10
)
  18. type Client struct {
  19. listener net.Listener
  20. dtlsConfig *dtls.Config
  21. remoteDialFn func(context.Context) (net.PacketConn, net.Addr, error)
  22. psk func([]byte) ([]byte, error)
  23. timeout time.Duration
  24. idleTimeout time.Duration
  25. baseCtx context.Context
  26. cancelCtx func()
  27. staleMode util.StaleMode
  28. workerWG sync.WaitGroup
  29. timeLimitFunc func() time.Duration
  30. allowFunc func(net.Addr) bool
  31. }
  32. func New(cfg *Config) (*Client, error) {
  33. cfg = cfg.populateDefaults()
  34. baseCtx, cancelCtx := context.WithCancel(cfg.BaseContext)
  35. client := &Client{
  36. remoteDialFn: cfg.RemoteDialFunc,
  37. timeout: cfg.Timeout,
  38. psk: cfg.PSKCallback,
  39. idleTimeout: cfg.IdleTimeout,
  40. baseCtx: baseCtx,
  41. cancelCtx: cancelCtx,
  42. staleMode: cfg.StaleMode,
  43. timeLimitFunc: cfg.TimeLimitFunc,
  44. allowFunc: cfg.AllowFunc,
  45. }
  46. lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
  47. if err != nil {
  48. cancelCtx()
  49. return nil, fmt.Errorf("can't parse bind address: %w", err)
  50. }
  51. client.dtlsConfig = &dtls.Config{
  52. ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
  53. PSK: client.psk,
  54. PSKIdentityHint: []byte(cfg.PSKIdentity),
  55. MTU: cfg.MTU,
  56. CipherSuites: cfg.CipherSuites,
  57. EllipticCurves: cfg.EllipticCurves,
  58. }
  59. if cfg.EnableCID {
  60. client.dtlsConfig.ConnectionIDGenerator = dtls.OnlySendCIDGenerator()
  61. }
  62. lc := udp.ListenConfig{
  63. Backlog: Backlog,
  64. }
  65. listener, err := lc.Listen("udp", net.UDPAddrFromAddrPort(lAddrPort))
  66. if err != nil {
  67. cancelCtx()
  68. return nil, fmt.Errorf("client listen failed: %w", err)
  69. }
  70. client.listener = listener
  71. go client.listen()
  72. return client, nil
  73. }
  74. func (client *Client) listen() {
  75. defer client.Close()
  76. for client.baseCtx.Err() == nil {
  77. conn, err := client.listener.Accept()
  78. if err != nil {
  79. log.Printf("conn accept failed: %v", err)
  80. continue
  81. }
  82. if !client.allowFunc(conn.RemoteAddr()) {
  83. continue
  84. }
  85. client.workerWG.Add(1)
  86. go func(conn net.Conn) {
  87. defer client.workerWG.Done()
  88. defer conn.Close()
  89. client.serve(conn)
  90. }(conn)
  91. }
  92. }
  93. func (client *Client) serve(conn net.Conn) {
  94. log.Printf("[+] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
  95. defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
  96. defer conn.Close()
  97. ctx := client.baseCtx
  98. tl := client.timeLimitFunc()
  99. if tl != 0 {
  100. newCtx, cancel := context.WithTimeout(ctx, tl)
  101. defer cancel()
  102. ctx = newCtx
  103. }
  104. remoteConn, err := func() (net.Conn, error) {
  105. dialCtx, cancel := context.WithTimeout(ctx, client.timeout)
  106. defer cancel()
  107. remoteConn, remoteAddr, err := client.remoteDialFn(dialCtx)
  108. if err != nil {
  109. return nil, fmt.Errorf("remote dial function failed: %w", err)
  110. }
  111. dtlsConn, err := dtls.Client(remoteConn, remoteAddr, client.dtlsConfig)
  112. if err != nil {
  113. remoteConn.Close()
  114. return nil, fmt.Errorf("DTLS connection with remote server failed: %w", err)
  115. }
  116. if err := dtlsConn.HandshakeContext(dialCtx); err != nil {
  117. dtlsConn.Close()
  118. remoteConn.Close()
  119. return nil, fmt.Errorf("DTLS handshake with remote server failed: %w", err)
  120. }
  121. return dtlsConn, nil
  122. }()
  123. if err != nil {
  124. log.Printf("remote dial failed: %v", err)
  125. return
  126. }
  127. defer remoteConn.Close()
  128. util.PairConn(ctx, conn, remoteConn, client.idleTimeout, client.staleMode)
  129. }
  130. func (client *Client) Close() error {
  131. client.cancelCtx()
  132. err := client.listener.Close()
  133. client.workerWG.Wait()
  134. return err
  135. }