|
@@ -7,8 +7,11 @@ import (
|
|
|
"fmt"
|
|
|
"log"
|
|
|
"net"
|
|
|
+ "net/netip"
|
|
|
"sync"
|
|
|
"time"
|
|
|
+
|
|
|
+ "github.com/Snawoot/rlzone"
|
|
|
)
|
|
|
|
|
|
func GenPSK(length int) ([]byte, error) {
|
|
@@ -117,3 +120,28 @@ func PairConn(ctx context.Context, left, right net.Conn, idleTimeout time.Durati
|
|
|
go copier(right, left, true)
|
|
|
wg.Wait()
|
|
|
}
|
|
|
+
|
|
|
+func NetAddrToNetipAddrPort(a net.Addr) netip.AddrPort {
|
|
|
+ switch v := a.(type) {
|
|
|
+ case *net.UDPAddr:
|
|
|
+ return v.AddrPort()
|
|
|
+ case *net.TCPAddr:
|
|
|
+ return v.AddrPort()
|
|
|
+ }
|
|
|
+ res, _ := netip.ParseAddrPort(a.String())
|
|
|
+ return res
|
|
|
+}
|
|
|
+
|
|
|
+func AllowAllFunc(_, _ net.Addr) bool {
|
|
|
+ return true
|
|
|
+}
|
|
|
+
|
|
|
+func AllowByRatelimit(z rlzone.Ratelimiter[netip.Addr]) func(net.Addr, net.Addr) bool {
|
|
|
+ if z == nil {
|
|
|
+ return AllowAllFunc
|
|
|
+ }
|
|
|
+ return func(_, remoteAddr net.Addr) bool {
|
|
|
+ key := NetAddrToNetipAddrPort(remoteAddr).Addr()
|
|
|
+ return z.Allow(key)
|
|
|
+ }
|
|
|
+}
|