util.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. package util
import (
	"context"
	"crypto/rand"
	"encoding/hex"
	"errors"
	"fmt"
	"log"
	"net"
	"sync"
	"time"
)
  12. func GenPSK(length int) ([]byte, error) {
  13. b := make([]byte, length)
  14. _, err := rand.Read(b)
  15. if err != nil {
  16. return nil, fmt.Errorf("random bytes generation failed: %w", err)
  17. }
  18. return b, nil
  19. }
  20. func GenPSKHex(length int) (string, error) {
  21. b, err := GenPSK(length)
  22. if err != nil {
  23. return "", fmt.Errorf("can't generate hex key: %w", err)
  24. }
  25. return hex.EncodeToString(b), nil
  26. }
  27. func PSKFromHex(input string) ([]byte, error) {
  28. return hex.DecodeString(input)
  29. }
  30. func isTimeout(err error) bool {
  31. if timeoutErr, ok := err.(interface {
  32. Timeout() bool
  33. }); ok {
  34. return timeoutErr.Timeout()
  35. }
  36. return false
  37. }
  38. func isTemporary(err error) bool {
  39. if timeoutErr, ok := err.(interface {
  40. Temporary() bool
  41. }); ok {
  42. return timeoutErr.Temporary()
  43. }
  44. return false
  45. }
  46. const (
  47. MaxPktBuf = 65536
  48. )
  49. func PairConn(ctx context.Context, left, right net.Conn, idleTimeout time.Duration, staleMode StaleMode) {
  50. var wg sync.WaitGroup
  51. tracker := newTracker(staleMode)
  52. copyDone := make(chan struct{})
  53. go func() {
  54. select {
  55. case <-ctx.Done():
  56. left.Close()
  57. right.Close()
  58. case <-copyDone:
  59. }
  60. }()
  61. defer close(copyDone)
  62. copier := func(dst, src net.Conn, label bool) {
  63. defer wg.Done()
  64. defer dst.Close()
  65. buf := make([]byte, MaxPktBuf)
  66. for {
  67. if err := src.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
  68. log.Printf("can't update deadline for connection: %v", err)
  69. break
  70. }
  71. n, err := src.Read(buf)
  72. if err != nil {
  73. if isTimeout(err) {
  74. // hit read deadline
  75. if tracker.handleTimeout(label) {
  76. // not stale conn
  77. continue
  78. } else {
  79. log.Printf("dropping stale connection %s <=> %s", src.LocalAddr(), src.RemoteAddr())
  80. }
  81. } else {
  82. // any other error
  83. if isTemporary(err) {
  84. log.Printf("ignoring temporary error during read from %s: %v", src.RemoteAddr(), err)
  85. continue
  86. }
  87. log.Printf("read from %s error: %v", src.RemoteAddr(), err)
  88. }
  89. break
  90. }
  91. tracker.notify(label)
  92. _, err = dst.Write(buf[:n])
  93. if err != nil {
  94. log.Printf("write to %s error: %v", dst.RemoteAddr(), err)
  95. break
  96. }
  97. }
  98. }
  99. wg.Add(2)
  100. go copier(left, right, false)
  101. go copier(right, left, true)
  102. wg.Wait()
  103. }