util.go 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. package util
import (
	"crypto/rand"
	"encoding/hex"
	"errors"
	"fmt"
	"log"
	"net"
	"sync"
	"sync/atomic"
	"time"
)
  12. func GenPSK(length int) ([]byte, error) {
  13. b := make([]byte, length)
  14. _, err := rand.Read(b)
  15. if err != nil {
  16. return nil, fmt.Errorf("random bytes generation failed: %w", err)
  17. }
  18. return b, nil
  19. }
  20. func GenPSKHex(length int) (string, error) {
  21. b, err := GenPSK(length)
  22. if err != nil {
  23. return "", fmt.Errorf("can't generate hex key: %w", err)
  24. }
  25. return hex.EncodeToString(b), nil
  26. }
  27. func PSKFromHex(input string) ([]byte, error) {
  28. return hex.DecodeString(input)
  29. }
  30. func isTimeout(err error) bool {
  31. if timeoutErr, ok := err.(interface {
  32. Timeout() bool
  33. }); ok {
  34. return timeoutErr.Timeout()
  35. }
  36. return false
  37. }
  38. const (
  39. MaxPktBuf = 4096
  40. )
  41. func PairConn(left, right net.Conn, idleTimeout time.Duration) {
  42. var lsn atomic.Int32
  43. var wg sync.WaitGroup
  44. copier := func(dst, src net.Conn) {
  45. defer wg.Done()
  46. defer dst.Close()
  47. buf := make([]byte, MaxPktBuf)
  48. for {
  49. oldLSN := lsn.Load()
  50. if err := src.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
  51. log.Println("can't update deadline for connection: %v", err)
  52. break
  53. }
  54. n, err := src.Read(buf)
  55. if err != nil {
  56. if isTimeout(err) {
  57. // hit read deadline
  58. if oldLSN != lsn.Load() {
  59. // not stale conn
  60. continue
  61. } else {
  62. log.Printf("dropping stale connection %s <=> %s", src.LocalAddr(), src.RemoteAddr())
  63. }
  64. } else {
  65. // any other error
  66. log.Printf("read from %s error: %v", src.RemoteAddr(), err)
  67. }
  68. break
  69. }
  70. lsn.Add(1)
  71. _, err = dst.Write(buf[:n])
  72. if err != nil {
  73. log.Printf("write to %s error: %v", dst.RemoteAddr(), err)
  74. break
  75. }
  76. }
  77. }
  78. wg.Add(2)
  79. go copier(left, right)
  80. go copier(right, left)
  81. wg.Wait()
  82. }