util.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. package util
  2. import (
  3. "context"
  4. "crypto/rand"
  5. "encoding/hex"
  6. "fmt"
  7. "log"
  8. "net"
  9. "net/netip"
  10. "sync"
  11. "time"
  12. "github.com/Snawoot/rlzone"
  13. )
  // GenPSK generates a random key of length bytes using crypto/rand.
  //
  // A negative length is reported as an error rather than panicking inside
  // make. Any error from the random source is wrapped and returned with a
  // nil key.
  func GenPSK(length int) ([]byte, error) {
  	if length < 0 {
  		return nil, fmt.Errorf("invalid key length %d: must not be negative", length)
  	}
  	b := make([]byte, length)
  	if _, err := rand.Read(b); err != nil {
  		return nil, fmt.Errorf("random bytes generation failed: %w", err)
  	}
  	return b, nil
  }
  22. func GenPSKHex(length int) (string, error) {
  23. b, err := GenPSK(length)
  24. if err != nil {
  25. return "", fmt.Errorf("can't generate hex key: %w", err)
  26. }
  27. return hex.EncodeToString(b), nil
  28. }
  // PSKFromHex decodes a hex-encoded key, such as one produced by GenPSKHex.
  //
  // On malformed input it returns a nil key and an error that wraps the
  // encoding/hex error. Callers can therefore still use
  // errors.Is(err, hex.ErrLength) or errors.As with hex.InvalidByteError.
  func PSKFromHex(input string) ([]byte, error) {
  	key, err := hex.DecodeString(input)
  	if err != nil {
  		// hex.DecodeString returns the bytes decoded before the failure;
  		// never hand out a partial key.
  		return nil, fmt.Errorf("can't decode hex key: %w", err)
  	}
  	return key, nil
  }
  // isTimeout reports whether err itself has a Timeout() bool method that
  // returns true. Wrapped errors are not unwrapped, and a nil err yields false.
  func isTimeout(err error) bool {
  	type timeouter interface{ Timeout() bool }
  	te, ok := err.(timeouter)
  	return ok && te.Timeout()
  }
  // isTemporary reports whether err itself has a Temporary() bool method that
  // returns true. Wrapped errors are not unwrapped, and a nil err yields false.
  // NOTE(review): net.Error.Temporary is deprecated upstream; callers treat a
  // true result as "retry the read".
  func isTemporary(err error) bool {
  	switch e := err.(type) {
  	case interface{ Temporary() bool }:
  		return e.Temporary()
  	default:
  		return false
  	}
  }
  // MaxPktBuf is the size, in bytes, of the read buffer PairConn allocates for
  // each copy direction (64 KiB).
  const MaxPktBuf = 1 << 16
  51. func PairConn(ctx context.Context, left, right net.Conn, idleTimeout time.Duration, staleMode StaleMode) {
  52. var wg sync.WaitGroup
  53. tracker := newTracker(staleMode)
  54. copyDone := make(chan struct{})
  55. go func() {
  56. select {
  57. case <-ctx.Done():
  58. left.Close()
  59. right.Close()
  60. case <-copyDone:
  61. }
  62. }()
  63. defer close(copyDone)
  64. copier := func(dst, src net.Conn, label bool) {
  65. defer wg.Done()
  66. defer dst.Close()
  67. buf := make([]byte, MaxPktBuf)
  68. for {
  69. if err := src.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
  70. log.Printf("can't update deadline for connection: %v", err)
  71. break
  72. }
  73. n, err := src.Read(buf)
  74. if err != nil {
  75. if isTimeout(err) {
  76. // hit read deadline
  77. if tracker.handleTimeout(label) {
  78. // not stale conn
  79. continue
  80. } else {
  81. log.Printf("dropping stale connection %s <=> %s", src.LocalAddr(), src.RemoteAddr())
  82. }
  83. } else {
  84. // any other error
  85. if isTemporary(err) {
  86. log.Printf("ignoring temporary error during read from %s: %v", src.RemoteAddr(), err)
  87. continue
  88. }
  89. log.Printf("read from %s error: %v", src.RemoteAddr(), err)
  90. }
  91. break
  92. }
  93. tracker.notify(label)
  94. _, err = dst.Write(buf[:n])
  95. if err != nil {
  96. log.Printf("write to %s error: %v", dst.RemoteAddr(), err)
  97. break
  98. }
  99. }
  100. }
  101. wg.Add(2)
  102. go copier(left, right, false)
  103. go copier(right, left, true)
  104. wg.Wait()
  105. }
  // NetAddrToNetipAddrPort converts a net.Addr into a netip.AddrPort.
  //
  // *net.UDPAddr and *net.TCPAddr (including typed nil pointers) are converted
  // with their AddrPort methods. Any other implementation is converted by
  // parsing its String() form as "ip:port". In two cases the zero
  // netip.AddrPort (IsValid() == false) is returned instead of panicking:
  // when that form cannot be parsed (for example a unix socket path), or when
  // a is a nil interface.
  func NetAddrToNetipAddrPort(a net.Addr) netip.AddrPort {
  	switch v := a.(type) {
  	case nil:
  		return netip.AddrPort{}
  	case *net.UDPAddr:
  		return v.AddrPort()
  	case *net.TCPAddr:
  		return v.AddrPort()
  	}
  	res, err := netip.ParseAddrPort(a.String())
  	if err != nil {
  		// Deliberately best-effort: callers get the invalid zero value.
  		return netip.AddrPort{}
  	}
  	return res
  }
  // AllowAllFunc is an address predicate that admits everything: it ignores
  // both addresses and always returns true.
  func AllowAllFunc(net.Addr, net.Addr) bool { return true }
  119. func AllowByRatelimit(z rlzone.Ratelimiter[netip.Addr]) func(net.Addr, net.Addr) bool {
  120. if z == nil {
  121. return AllowAllFunc
  122. }
  123. return func(_, remoteAddr net.Addr) bool {
  124. key := NetAddrToNetipAddrPort(remoteAddr).Addr()
  125. return z.Allow(key)
  126. }
  127. }