123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211 |
- package util
- import (
- "context"
- crand "crypto/rand"
- "encoding/hex"
- "fmt"
- "log"
- "math/rand"
- "net"
- "net/netip"
- "sync"
- "time"
- "github.com/Snawoot/rlzone"
- )
// GenPSK returns length bytes read from the system's cryptographically secure
// random source, for use as a pre-shared key.
//
// A negative length is reported as an error rather than panicking inside
// make. A zero length yields an empty, non-nil slice. Any failure of the
// random source is wrapped and returned.
func GenPSK(length int) ([]byte, error) {
	if length < 0 {
		return nil, fmt.Errorf("invalid key length: %d", length)
	}
	b := make([]byte, length)
	if _, err := crand.Read(b); err != nil {
		return nil, fmt.Errorf("random bytes generation failed: %w", err)
	}
	return b, nil
}
- func GenPSKHex(length int) (string, error) {
- b, err := GenPSK(length)
- if err != nil {
- return "", fmt.Errorf("can't generate hex key: %w", err)
- }
- return hex.EncodeToString(b), nil
- }
// PSKFromHex decodes a hex-encoded key back into raw bytes.
//
// The result and error are passed through from hex.DecodeString unchanged,
// so malformed or odd-length input yields that function's error.
func PSKFromHex(input string) ([]byte, error) {
	key, err := hex.DecodeString(input)
	return key, err
}
// isTimeout reports whether err itself exposes a Timeout() bool method and
// that method returns true.
//
// Only err is inspected: wrapped errors are not unwrapped. A nil err, or one
// without the method, yields false.
func isTimeout(err error) bool {
	type timeouter interface {
		Timeout() bool
	}
	t, ok := err.(timeouter)
	return ok && t.Timeout()
}
// isTemporary reports whether err itself exposes a Temporary() bool method
// and that method returns true.
//
// Only err is inspected: wrapped errors are not unwrapped. A nil err, or one
// without the method, yields false.
func isTemporary(err error) bool {
	type temporary interface {
		Temporary() bool
	}
	t, ok := err.(temporary)
	return ok && t.Temporary()
}
// MaxPktBuf is the size in bytes (64 KiB) of the buffer PairConn allocates
// for each relay direction; a single Read never returns more than this.
const MaxPktBuf = 1 << 16
- func PairConn(ctx context.Context, left, right net.Conn, idleTimeout time.Duration, staleMode StaleMode) {
- var wg sync.WaitGroup
- tracker := newTracker(staleMode)
- copyDone := make(chan struct{})
- go func() {
- select {
- case <-ctx.Done():
- left.Close()
- right.Close()
- case <-copyDone:
- }
- }()
- defer close(copyDone)
- copier := func(dst, src net.Conn, label bool) {
- defer wg.Done()
- defer dst.Close()
- buf := make([]byte, MaxPktBuf)
- for {
- if err := src.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
- log.Printf("can't update deadline for connection: %v", err)
- break
- }
- n, err := src.Read(buf)
- if err != nil {
- if isTimeout(err) {
- // hit read deadline
- if tracker.handleTimeout(label) {
- // not stale conn
- continue
- } else {
- log.Printf("dropping stale connection %s <=> %s", src.LocalAddr(), src.RemoteAddr())
- }
- } else {
- // any other error
- if isTemporary(err) {
- log.Printf("ignoring temporary error during read from %s: %v", src.RemoteAddr(), err)
- continue
- }
- log.Printf("read from %s error: %v", src.RemoteAddr(), err)
- }
- break
- }
- tracker.notify(label)
- _, err = dst.Write(buf[:n])
- if err != nil {
- log.Printf("write to %s error: %v", dst.RemoteAddr(), err)
- break
- }
- }
- }
- wg.Add(2)
- go copier(left, right, false)
- go copier(right, left, true)
- wg.Wait()
- }
// NetAddrToNetipAddrPort converts a net.Addr into a netip.AddrPort.
//
// *net.UDPAddr and *net.TCPAddr are converted through their AddrPort
// methods. Any other implementation is converted by parsing its String()
// form as "ip:port". A nil address, or one whose string form does not parse,
// yields an invalid AddrPort (IsValid reports false).
func NetAddrToNetipAddrPort(a net.Addr) netip.AddrPort {
	switch v := a.(type) {
	case nil:
		// Nothing to convert; calling a.String() below would panic.
		return netip.AddrPort{}
	case *net.UDPAddr:
		return v.AddrPort()
	case *net.TCPAddr:
		return v.AddrPort()
	}
	// The parse error is intentionally dropped: callers only get an
	// AddrPort, and an unparsable address is signalled by an invalid result.
	res, _ := netip.ParseAddrPort(a.String())
	return res
}
// AllowAllFunc is an admission predicate that accepts every remote address;
// its argument is ignored.
func AllowAllFunc(net.Addr) bool {
	const allowed = true
	return allowed
}
- func AllowByRatelimit(z rlzone.Ratelimiter[netip.Addr]) func(net.Addr) bool {
- if z == nil {
- return AllowAllFunc
- }
- return func(remoteAddr net.Addr) bool {
- key := NetAddrToNetipAddrPort(remoteAddr).Addr()
- return z.Allow(key)
- }
- }
// FixedTimeLimitFunc returns a function that yields d on every call.
func FixedTimeLimitFunc(d time.Duration) func() time.Duration {
	constant := func() time.Duration { return d }
	return constant
}
- func TimeLimitFunc(low, high time.Duration) func() time.Duration {
- if low > high {
- return TimeLimitFunc(high, low)
- }
- if low == high {
- return FixedTimeLimitFunc(low)
- }
- r := rand.New(rand.NewSource(time.Now().UnixNano()))
- var mux sync.Mutex
- delta := high - low
- return func() time.Duration {
- mux.Lock()
- defer mux.Unlock()
- return low + time.Duration(r.Int63n(int64(delta)))
- }
- }
// DynDialer opens UDP sockets towards an endpoint whose "host:port" string
// is obtained from a callback at dial time instead of being fixed up front.
type DynDialer struct {
	// ep returns the current endpoint as a "host:port" string.
	ep func() string
	// resolver performs the host and port lookups for the endpoint.
	resolver *net.Resolver
}
- func NewDynDialer(ep func() string) DynDialer {
- return DynDialer{
- resolver: new(net.Resolver),
- ep: ep,
- }
- }
- func (d DynDialer) DialContext(ctx context.Context) (net.PacketConn, net.Addr, error) {
- host, port, err := net.SplitHostPort(d.ep())
- if err != nil {
- return nil, nil, fmt.Errorf("unable to split host and port: %w", err)
- }
- addrs, err := d.resolver.LookupIPAddr(ctx, host)
- if err != nil {
- return nil, nil, fmt.Errorf("address lookup failed: %w", err)
- }
- if len(addrs) == 0 {
- return nil, nil, fmt.Errorf("no addresses were resolved")
- }
- portNum, err := d.resolver.LookupPort(ctx, "udp", port)
- if err != nil {
- return nil, nil, fmt.Errorf("port lookup failed: %w", err)
- }
- pConn, err := net.ListenUDP("udp", nil)
- if err != nil {
- return nil, nil, fmt.Errorf("unable to open UDP socket: %w", err)
- }
- return pConn, &net.UDPAddr{
- IP: addrs[0].IP,
- Port: portNum,
- Zone: addrs[0].Zone,
- }, nil
- }
|