util.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. package util
  import (
  	"context"
  	crand "crypto/rand"
  	"encoding/hex"
  	"errors"
  	"fmt"
  	"log"
  	"math/rand"
  	"net"
  	"net/netip"
  	"sync"
  	"time"

  	"github.com/Snawoot/rlzone"
  )
  // GenPSK returns length bytes read from crypto/rand, intended for use as a
  // pre-shared key.
  //
  // It returns an error if length is negative (make would otherwise panic) or
  // if reading from the system random source fails. A length of 0 yields an
  // empty, non-nil slice.
  func GenPSK(length int) ([]byte, error) {
  	if length < 0 {
  		return nil, fmt.Errorf("invalid key length %d: must not be negative", length)
  	}
  	b := make([]byte, length)
  	if _, err := crand.Read(b); err != nil {
  		return nil, fmt.Errorf("random bytes generation failed: %w", err)
  	}
  	return b, nil
  }
  23. func GenPSKHex(length int) (string, error) {
  24. b, err := GenPSK(length)
  25. if err != nil {
  26. return "", fmt.Errorf("can't generate hex key: %w", err)
  27. }
  28. return hex.EncodeToString(b), nil
  29. }
  // PSKFromHex decodes a hex-encoded pre-shared key.
  //
  // On malformed input it returns a nil key together with the unmodified
  // encoding/hex error (for example hex.ErrLength or a hex.InvalidByteError),
  // so callers can still compare or inspect it.
  func PSKFromHex(input string) ([]byte, error) {
  	key, err := hex.DecodeString(input)
  	if err != nil {
  		// hex.DecodeString returns the bytes decoded before the failure;
  		// never hand back a truncated key.
  		return nil, err
  	}
  	return key, nil
  }
  // isTimeout reports whether err, or any error in its wrap chain, has a
  // Timeout() bool method that returns true. The first error in the chain that
  // has such a method decides the result. It returns false for a nil err.
  func isTimeout(err error) bool {
  	var te interface{ Timeout() bool }
  	return errors.As(err, &te) && te.Timeout()
  }
  // isTemporary reports whether err, or any error in its wrap chain, has a
  // Temporary() bool method that returns true. The first error in the chain
  // that has such a method decides the result. It returns false for a nil err.
  // Note that net.Error.Temporary is deprecated in the standard library.
  func isTemporary(err error) bool {
  	var te interface{ Temporary() bool }
  	return errors.As(err, &te) && te.Temporary()
  }
  // MaxPktBuf is the size in bytes (64 KiB) of the buffer PairConn allocates
  // for each copy direction; it is the most a single Read can return there.
  const MaxPktBuf = 64 * 1024
  52. func PairConn(ctx context.Context, left, right net.Conn, idleTimeout time.Duration, staleMode StaleMode) {
  53. var wg sync.WaitGroup
  54. tracker := newTracker(staleMode)
  55. copyDone := make(chan struct{})
  56. go func() {
  57. select {
  58. case <-ctx.Done():
  59. left.Close()
  60. right.Close()
  61. case <-copyDone:
  62. }
  63. }()
  64. defer close(copyDone)
  65. copier := func(dst, src net.Conn, label bool) {
  66. defer wg.Done()
  67. defer dst.Close()
  68. buf := make([]byte, MaxPktBuf)
  69. for {
  70. if err := src.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
  71. log.Printf("can't update deadline for connection: %v", err)
  72. break
  73. }
  74. n, err := src.Read(buf)
  75. if err != nil {
  76. if isTimeout(err) {
  77. // hit read deadline
  78. if tracker.handleTimeout(label) {
  79. // not stale conn
  80. continue
  81. } else {
  82. log.Printf("dropping stale connection %s <=> %s", src.LocalAddr(), src.RemoteAddr())
  83. }
  84. } else {
  85. // any other error
  86. if isTemporary(err) {
  87. log.Printf("ignoring temporary error during read from %s: %v", src.RemoteAddr(), err)
  88. continue
  89. }
  90. log.Printf("read from %s error: %v", src.RemoteAddr(), err)
  91. }
  92. break
  93. }
  94. tracker.notify(label)
  95. _, err = dst.Write(buf[:n])
  96. if err != nil {
  97. log.Printf("write to %s error: %v", dst.RemoteAddr(), err)
  98. break
  99. }
  100. }
  101. }
  102. wg.Add(2)
  103. go copier(left, right, false)
  104. go copier(right, left, true)
  105. wg.Wait()
  106. }
  // NetAddrToNetipAddrPort converts a net.Addr into a netip.AddrPort.
  //
  // *net.UDPAddr and *net.TCPAddr are converted directly via their AddrPort
  // methods (a 16-byte IPv4 address therefore becomes an IPv4-mapped IPv6
  // Addr). Any other implementation is parsed from its String() form. The zero
  // AddrPort (IsValid() == false) is returned when a is nil or its string form
  // is not a valid "host:port".
  func NetAddrToNetipAddrPort(a net.Addr) netip.AddrPort {
  	if a == nil {
  		// A nil interface would panic on a.String() below.
  		return netip.AddrPort{}
  	}
  	switch v := a.(type) {
  	case *net.UDPAddr:
  		return v.AddrPort()
  	case *net.TCPAddr:
  		return v.AddrPort()
  	}
  	// A parse failure deliberately yields the zero AddrPort.
  	res, _ := netip.ParseAddrPort(a.String())
  	return res
  }
  // AllowAllFunc is an admission predicate that accepts every remote address;
  // its argument is ignored.
  func AllowAllFunc(net.Addr) bool {
  	const allowed = true
  	return allowed
  }
  120. func AllowByRatelimit(z rlzone.Ratelimiter[netip.Addr]) func(net.Addr) bool {
  121. if z == nil {
  122. return AllowAllFunc
  123. }
  124. return func(remoteAddr net.Addr) bool {
  125. key := NetAddrToNetipAddrPort(remoteAddr).Addr()
  126. return z.Allow(key)
  127. }
  128. }
  // FixedTimeLimitFunc returns a duration generator that always yields d.
  func FixedTimeLimitFunc(d time.Duration) func() time.Duration {
  	constant := func() time.Duration { return d }
  	return constant
  }
  // TimeLimitFunc returns a generator of pseudo-random durations drawn
  // uniformly from the half-open range [min(low, high), max(low, high)).
  // If low == high, the generator always returns that value.
  //
  // The generator uses its own math/rand source seeded from the current time.
  // Calls are serialized by a mutex, so it is safe for concurrent use.
  func TimeLimitFunc(low, high time.Duration) func() time.Duration {
  	if high < low {
  		low, high = high, low
  	}
  	if high == low {
  		return func() time.Duration { return low }
  	}
  	span := int64(high - low)
  	var lock sync.Mutex
  	rng := rand.New(rand.NewSource(time.Now().UnixNano()))
  	return func() time.Duration {
  		lock.Lock()
  		defer lock.Unlock()
  		offset := time.Duration(rng.Int63n(span))
  		return low + offset
  	}
  }
  // DynDialer opens UDP sockets toward an endpoint that is re-read on every
  // DialContext call, so the target may change between dials.
  type DynDialer struct {
  	ep       func() string // returns the "host:port" endpoint; called once per DialContext
  	resolver *net.Resolver // used for both the address and the port lookup
  }
  154. func NewDynDialer(ep func() string) DynDialer {
  155. return DynDialer{
  156. resolver: new(net.Resolver),
  157. ep: ep,
  158. }
  159. }
  160. func (d DynDialer) DialContext(ctx context.Context) (net.PacketConn, net.Addr, error) {
  161. host, port, err := net.SplitHostPort(d.ep())
  162. if err != nil {
  163. return nil, nil, fmt.Errorf("unable to split host and port: %w", err)
  164. }
  165. addrs, err := d.resolver.LookupIPAddr(ctx, host)
  166. if err != nil {
  167. return nil, nil, fmt.Errorf("address lookup failed: %w", err)
  168. }
  169. if len(addrs) == 0 {
  170. return nil, nil, fmt.Errorf("no addresses were resolved")
  171. }
  172. portNum, err := d.resolver.LookupPort(ctx, "udp", port)
  173. if err != nil {
  174. return nil, nil, fmt.Errorf("port lookup failed: %w", err)
  175. }
  176. pConn, err := net.ListenUDP("udp", nil)
  177. if err != nil {
  178. return nil, nil, fmt.Errorf("unable to open UDP socket: %w", err)
  179. }
  180. return pConn, &net.UDPAddr{
  181. IP: addrs[0].IP,
  182. Port: portNum,
  183. Zone: addrs[0].Zone,
  184. }, nil
  185. }