util.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. package util
import (
	"context"
	crand "crypto/rand"
	"encoding/hex"
	"errors"
	"fmt"
	"log"
	"math/rand"
	"net"
	"net/netip"
	"sync"
	"time"

	"github.com/Snawoot/rlzone"
)
  15. func GenPSK(length int) ([]byte, error) {
  16. b := make([]byte, length)
  17. _, err := crand.Read(b)
  18. if err != nil {
  19. return nil, fmt.Errorf("random bytes generation failed: %w", err)
  20. }
  21. return b, nil
  22. }
  23. func GenPSKHex(length int) (string, error) {
  24. b, err := GenPSK(length)
  25. if err != nil {
  26. return "", fmt.Errorf("can't generate hex key: %w", err)
  27. }
  28. return hex.EncodeToString(b), nil
  29. }
  30. func PSKFromHex(input string) ([]byte, error) {
  31. return hex.DecodeString(input)
  32. }
  33. func isTimeout(err error) bool {
  34. if timeoutErr, ok := err.(interface {
  35. Timeout() bool
  36. }); ok {
  37. return timeoutErr.Timeout()
  38. }
  39. return false
  40. }
  41. func isTemporary(err error) bool {
  42. if timeoutErr, ok := err.(interface {
  43. Temporary() bool
  44. }); ok {
  45. return timeoutErr.Temporary()
  46. }
  47. return false
  48. }
  49. const (
  50. MaxPktBuf = 65536
  51. )
  52. func PairConn(ctx context.Context, left, right net.Conn, idleTimeout time.Duration, staleMode StaleMode) {
  53. var wg sync.WaitGroup
  54. tracker := newTracker(staleMode)
  55. copyDone := make(chan struct{})
  56. go func() {
  57. select {
  58. case <-ctx.Done():
  59. left.Close()
  60. right.Close()
  61. case <-copyDone:
  62. }
  63. }()
  64. defer close(copyDone)
  65. copier := func(dst, src net.Conn, label bool) {
  66. defer wg.Done()
  67. defer dst.Close()
  68. buf := make([]byte, MaxPktBuf)
  69. for {
  70. if err := src.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
  71. log.Printf("can't update deadline for connection: %v", err)
  72. break
  73. }
  74. n, err := src.Read(buf)
  75. if err != nil {
  76. if isTimeout(err) {
  77. // hit read deadline
  78. if tracker.handleTimeout(label) {
  79. // not stale conn
  80. continue
  81. } else {
  82. log.Printf("dropping stale connection %s <=> %s", src.LocalAddr(), src.RemoteAddr())
  83. }
  84. } else {
  85. // any other error
  86. if isTemporary(err) {
  87. log.Printf("ignoring temporary error during read from %s: %v", src.RemoteAddr(), err)
  88. continue
  89. }
  90. log.Printf("read from %s error: %v", src.RemoteAddr(), err)
  91. }
  92. break
  93. }
  94. tracker.notify(label)
  95. _, err = dst.Write(buf[:n])
  96. if err != nil {
  97. log.Printf("write to %s error: %v", dst.RemoteAddr(), err)
  98. break
  99. }
  100. }
  101. }
  102. wg.Add(2)
  103. go copier(left, right, false)
  104. go copier(right, left, true)
  105. wg.Wait()
  106. }
  107. func NetAddrToNetipAddrPort(a net.Addr) netip.AddrPort {
  108. switch v := a.(type) {
  109. case *net.UDPAddr:
  110. return v.AddrPort()
  111. case *net.TCPAddr:
  112. return v.AddrPort()
  113. }
  114. res, _ := netip.ParseAddrPort(a.String())
  115. return res
  116. }
  117. func AllowAllFunc(_, _ net.Addr) bool {
  118. return true
  119. }
  120. func AllowByRatelimit(z rlzone.Ratelimiter[netip.Addr]) func(net.Addr, net.Addr) bool {
  121. if z == nil {
  122. return AllowAllFunc
  123. }
  124. return func(_, remoteAddr net.Addr) bool {
  125. key := NetAddrToNetipAddrPort(remoteAddr).Addr()
  126. return z.Allow(key)
  127. }
  128. }
  129. func FixedTimeLimitFunc(d time.Duration) func() time.Duration {
  130. return func() time.Duration {
  131. return d
  132. }
  133. }
  134. func TimeLimitFunc(low, high time.Duration) func() time.Duration {
  135. if low > high {
  136. return TimeLimitFunc(high, low)
  137. }
  138. if low == high {
  139. return FixedTimeLimitFunc(low)
  140. }
  141. r := rand.New(rand.NewSource(time.Now().UnixNano()))
  142. var mux sync.Mutex
  143. delta := high - low
  144. return func() time.Duration {
  145. mux.Lock()
  146. defer mux.Unlock()
  147. return low + time.Duration(r.Int63n(int64(delta)))
  148. }
  149. }
  150. type DynDialer struct {
  151. dial func(context.Context, string, string) (net.Conn, error)
  152. ep func() string
  153. }
  154. func NewDynDialer(ep func() string, dial func(context.Context, string, string) (net.Conn, error)) DynDialer {
  155. if dial == nil {
  156. dial = (&net.Dialer{}).DialContext
  157. }
  158. return DynDialer{
  159. ep: ep,
  160. dial: dial,
  161. }
  162. }
  163. func (d DynDialer) DialContext(ctx context.Context, network string) (net.Conn, error) {
  164. return d.dial(ctx, network, d.ep())
  165. }