package util import ( "context" crand "crypto/rand" "encoding/hex" "fmt" "log" "math/rand" "net" "net/netip" "sync" "time" "github.com/Snawoot/rlzone" ) func GenPSK(length int) ([]byte, error) { b := make([]byte, length) _, err := crand.Read(b) if err != nil { return nil, fmt.Errorf("random bytes generation failed: %w", err) } return b, nil } func GenPSKHex(length int) (string, error) { b, err := GenPSK(length) if err != nil { return "", fmt.Errorf("can't generate hex key: %w", err) } return hex.EncodeToString(b), nil } func PSKFromHex(input string) ([]byte, error) { return hex.DecodeString(input) } func isTimeout(err error) bool { if timeoutErr, ok := err.(interface { Timeout() bool }); ok { return timeoutErr.Timeout() } return false } func isTemporary(err error) bool { if timeoutErr, ok := err.(interface { Temporary() bool }); ok { return timeoutErr.Temporary() } return false } const ( MaxPktBuf = 65536 ) func PairConn(ctx context.Context, left, right net.Conn, idleTimeout time.Duration, staleMode StaleMode) { var wg sync.WaitGroup tracker := newTracker(staleMode) copyDone := make(chan struct{}) go func() { select { case <-ctx.Done(): left.Close() right.Close() case <-copyDone: } }() defer close(copyDone) copier := func(dst, src net.Conn, label bool) { defer wg.Done() defer dst.Close() buf := make([]byte, MaxPktBuf) for { if err := src.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil { log.Printf("can't update deadline for connection: %v", err) break } n, err := src.Read(buf) if err != nil { if isTimeout(err) { // hit read deadline if tracker.handleTimeout(label) { // not stale conn continue } else { log.Printf("dropping stale connection %s <=> %s", src.LocalAddr(), src.RemoteAddr()) } } else { // any other error if isTemporary(err) { log.Printf("ignoring temporary error during read from %s: %v", src.RemoteAddr(), err) continue } log.Printf("read from %s error: %v", src.RemoteAddr(), err) } break } tracker.notify(label) _, err = dst.Write(buf[:n]) if err != nil { log.Printf("write to %s error: %v", dst.RemoteAddr(), err) break } } } wg.Add(2) go copier(left, right, false) go copier(right, left, true) wg.Wait() } func NetAddrToNetipAddrPort(a net.Addr) netip.AddrPort { switch v := a.(type) { case *net.UDPAddr: return v.AddrPort() case *net.TCPAddr: return v.AddrPort() } res, _ := netip.ParseAddrPort(a.String()) return res } func AllowAllFunc(_ net.Addr) bool { return true } func AllowByRatelimit(z rlzone.Ratelimiter[netip.Addr]) func(net.Addr) bool { if z == nil { return AllowAllFunc } return func(remoteAddr net.Addr) bool { key := NetAddrToNetipAddrPort(remoteAddr).Addr() return z.Allow(key) } } func FixedTimeLimitFunc(d time.Duration) func() time.Duration { return func() time.Duration { return d } } func TimeLimitFunc(low, high time.Duration) func() time.Duration { if low > high { return TimeLimitFunc(high, low) } if low == high { return FixedTimeLimitFunc(low) } r := rand.New(rand.NewSource(time.Now().UnixNano())) var mux sync.Mutex delta := high - low return func() time.Duration { mux.Lock() defer mux.Unlock() return low + time.Duration(r.Int63n(int64(delta))) } } type DynDialer struct { ep func() string resolver *net.Resolver } func NewDynDialer(ep func() string) DynDialer { return DynDialer{ resolver: new(net.Resolver), ep: ep, } } func (d DynDialer) DialContext(ctx context.Context) (net.PacketConn, net.Addr, error) { host, port, err := net.SplitHostPort(d.ep()) if err != nil { return nil, nil, fmt.Errorf("unable to split host and port: %w", err) } addrs, err := d.resolver.LookupIPAddr(ctx, host) if err != nil { return nil, nil, fmt.Errorf("address lookup failed: %w", err) } if len(addrs) == 0 { return nil, nil, fmt.Errorf("no addresses were resolved") } portNum, err := d.resolver.LookupPort(ctx, "udp", port) if err != nil { return nil, nil, fmt.Errorf("port lookup failed: %w", err) } pConn, err := net.ListenUDP("udp", nil) if err != nil { return nil, nil, fmt.Errorf("unable to open UDP socket: %w", err) } return pConn, &net.UDPAddr{ IP: addrs[0].IP, Port: portNum, Zone: addrs[0].Zone, }, nil }