util.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. package util
  import (
	"crypto/rand"
	"encoding/hex"
	"errors"
	"fmt"
	"log"
	"net"
	"sync"
	"sync/atomic"
	"time"
)
  // GenPSK returns length bytes read from crypto/rand, intended for use as a
// pre-shared key. The only error it reports is a failure to read from the
// random source, wrapped with context.
func GenPSK(length int) ([]byte, error) {
	key := make([]byte, length)
	if _, err := rand.Read(key); err != nil {
		return nil, fmt.Errorf("random bytes generation failed: %w", err)
	}
	return key, nil
}
  20. func GenPSKHex(length int) (string, error) {
  21. b, err := GenPSK(length)
  22. if err != nil {
  23. return "", fmt.Errorf("can't generate hex key: %w", err)
  24. }
  25. return hex.EncodeToString(b), nil
  26. }
  // PSKFromHex decodes a hex-encoded key (such as one produced by GenPSKHex)
// back into raw bytes. The returned slice and error are exactly those of
// hex.DecodeString, including its partial result on malformed input.
func PSKFromHex(input string) ([]byte, error) {
	key, err := hex.DecodeString(input)
	return key, err
}
  // isTimeout reports whether err, or any error in its Unwrap chain, has a
// Timeout() bool method and that method returns true. A nil err yields false.
func isTimeout(err error) bool {
	// errors.As walks the wrap chain, so a timeout wrapped with %w by an
	// intermediate layer is still recognised. A bare type assertion would
	// only look at the outermost error.
	var te interface{ Timeout() bool }
	return errors.As(err, &te) && te.Timeout()
}
  // isTemporary reports whether err, or any error in its Unwrap chain, has a
// Temporary() bool method and that method returns true. A nil err yields
// false.
//
// NOTE(review): the standard library deprecates net.Error.Temporary because
// "temporary" is not well-defined (most such errors are timeouts). Prefer
// isTimeout for new checks; this helper is kept for existing callers.
func isTemporary(err error) bool {
	// errors.As sees through %w wrapping, unlike a bare type assertion.
	var te interface{ Temporary() bool }
	return errors.As(err, &te) && te.Temporary()
}
  // MaxPktBuf is the size in bytes (64 KiB) of the buffer PairConn allocates
// for each relay direction, and therefore the most a single Read can return.
const MaxPktBuf = 1 << 16
  49. func PairConn(left, right net.Conn, idleTimeout time.Duration) {
  50. var lsn atomic.Int32
  51. var wg sync.WaitGroup
  52. copier := func(dst, src net.Conn) {
  53. defer wg.Done()
  54. defer dst.Close()
  55. buf := make([]byte, MaxPktBuf)
  56. for {
  57. oldLSN := lsn.Load()
  58. if err := src.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
  59. log.Printf("can't update deadline for connection: %v", err)
  60. break
  61. }
  62. n, err := src.Read(buf)
  63. if err != nil {
  64. if isTimeout(err) {
  65. // hit read deadline
  66. if oldLSN != lsn.Load() {
  67. // not stale conn
  68. continue
  69. } else {
  70. log.Printf("dropping stale connection %s <=> %s", src.LocalAddr(), src.RemoteAddr())
  71. }
  72. } else {
  73. // any other error
  74. if isTemporary(err) {
  75. log.Printf("ignoring temporary error during read from %s: %v", src.RemoteAddr(), err)
  76. continue
  77. }
  78. log.Printf("read from %s error: %v", src.RemoteAddr(), err)
  79. }
  80. break
  81. }
  82. lsn.Add(1)
  83. _, err = dst.Write(buf[:n])
  84. if err != nil {
  85. log.Printf("write to %s error: %v", dst.RemoteAddr(), err)
  86. break
  87. }
  88. }
  89. }
  90. wg.Add(2)
  91. go copier(left, right)
  92. go copier(right, left)
  93. wg.Wait()
  94. }