client.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package client
  2. import (
  3. "context"
  4. "fmt"
  5. "log"
  6. "net"
  7. "net/netip"
  8. "sync"
  9. "time"
  10. "github.com/SenseUnit/dtlspipe/util"
  11. "github.com/pion/dtls/v3"
  12. "github.com/pion/transport/v3/udp"
  13. )
  // MaxPktBuf is 64 KiB, enough to hold any UDP datagram payload.
// It is not referenced in this part of the file; presumably it sizes packet
// buffers elsewhere in the package — confirm against its users.
const MaxPktBuf = 65536

// Backlog is handed to udp.ListenConfig as the local listener's backlog.
const Backlog = 1024
  18. type Client struct {
  19. listener net.Listener
  20. dtlsConfig *dtls.Config
  21. remoteDialFn func(context.Context) (net.PacketConn, net.Addr, error)
  22. psk func([]byte) ([]byte, error)
  23. timeout time.Duration
  24. idleTimeout time.Duration
  25. baseCtx context.Context
  26. cancelCtx func()
  27. staleMode util.StaleMode
  28. workerWG sync.WaitGroup
  29. timeLimitFunc func() time.Duration
  30. allowFunc func(net.Addr) bool
  31. }
  32. func New(cfg *Config) (*Client, error) {
  33. cfg = cfg.populateDefaults()
  34. baseCtx, cancelCtx := context.WithCancel(cfg.BaseContext)
  35. client := &Client{
  36. remoteDialFn: cfg.RemoteDialFunc,
  37. timeout: cfg.Timeout,
  38. psk: cfg.PSKCallback,
  39. idleTimeout: cfg.IdleTimeout,
  40. baseCtx: baseCtx,
  41. cancelCtx: cancelCtx,
  42. staleMode: cfg.StaleMode,
  43. timeLimitFunc: cfg.TimeLimitFunc,
  44. allowFunc: cfg.AllowFunc,
  45. }
  46. lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
  47. if err != nil {
  48. cancelCtx()
  49. return nil, fmt.Errorf("can't parse bind address: %w", err)
  50. }
  51. client.dtlsConfig = &dtls.Config{
  52. ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
  53. PSK: client.psk,
  54. PSKIdentityHint: []byte(cfg.PSKIdentity),
  55. MTU: cfg.MTU,
  56. CipherSuites: cfg.CipherSuites,
  57. EllipticCurves: cfg.EllipticCurves,
  58. }
  59. lc := udp.ListenConfig{
  60. Backlog: Backlog,
  61. }
  62. listener, err := lc.Listen("udp", net.UDPAddrFromAddrPort(lAddrPort))
  63. if err != nil {
  64. cancelCtx()
  65. return nil, fmt.Errorf("client listen failed: %w", err)
  66. }
  67. client.listener = listener
  68. go client.listen()
  69. return client, nil
  70. }
  71. func (client *Client) listen() {
  72. defer client.Close()
  73. for client.baseCtx.Err() == nil {
  74. conn, err := client.listener.Accept()
  75. if err != nil {
  76. log.Printf("conn accept failed: %v", err)
  77. continue
  78. }
  79. if !client.allowFunc(conn.RemoteAddr()) {
  80. continue
  81. }
  82. client.workerWG.Add(1)
  83. go func(conn net.Conn) {
  84. defer client.workerWG.Done()
  85. defer conn.Close()
  86. client.serve(conn)
  87. }(conn)
  88. }
  89. }
  90. func (client *Client) serve(conn net.Conn) {
  91. log.Printf("[+] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
  92. defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
  93. defer conn.Close()
  94. ctx := client.baseCtx
  95. tl := client.timeLimitFunc()
  96. if tl != 0 {
  97. newCtx, cancel := context.WithTimeout(ctx, tl)
  98. defer cancel()
  99. ctx = newCtx
  100. }
  101. dialCtx, cancel := context.WithTimeout(ctx, client.timeout)
  102. defer cancel()
  103. remoteConn, remoteAddr, err := client.remoteDialFn(dialCtx)
  104. if err != nil {
  105. log.Printf("remote dial failed: %v", err)
  106. return
  107. }
  108. defer remoteConn.Close()
  109. dtlsConn, err := dtls.Client(remoteConn, remoteAddr, client.dtlsConfig)
  110. if err != nil {
  111. log.Printf("DTLS connection with remote server failed: %v", err)
  112. return
  113. }
  114. defer dtlsConn.Close()
  115. if err := dtlsConn.HandshakeContext(dialCtx); err != nil {
  116. log.Printf("DTLS handshake with remote server failed: %v", err)
  117. return
  118. }
  119. util.PairConn(ctx, conn, dtlsConn, client.idleTimeout, client.staleMode)
  120. }
  121. func (client *Client) Close() error {
  122. client.cancelCtx()
  123. err := client.listener.Close()
  124. client.workerWG.Wait()
  125. return err
  126. }