Browse Source

hook ratelimit to client and server configs

Vladislav Yarmak 1 year ago
parent
commit
5aa4d353c6
4 changed files with 40 additions and 0 deletions
  1. 5 0
      client/config.go
  2. 2 0
      cmd/dtlspipe/main.go
  3. 5 0
      server/config.go
  4. 28 0
      util/util.go

+ 5 - 0
client/config.go

@@ -2,6 +2,7 @@ package client
 
 import (
 	"context"
+	"net"
 	"time"
 
 	"github.com/Snawoot/dtlspipe/ciphers"
@@ -21,6 +22,7 @@ type Config struct {
 	EllipticCurves ciphers.CurveList
 	StaleMode      util.StaleMode
 	TimeLimit      time.Duration
+	AllowFunc      func(localAddr, remoteAddr net.Addr) bool
 }
 
 func (cfg *Config) populateDefaults() *Config {
@@ -39,5 +41,8 @@ func (cfg *Config) populateDefaults() *Config {
 	if cfg.EllipticCurves == nil {
 		cfg.EllipticCurves = ciphers.DefaultCurveList
 	}
+	if cfg.AllowFunc == nil {
+		cfg.AllowFunc = util.AllowAllFunc
+	}
 	return cfg
 }

+ 2 - 0
cmd/dtlspipe/main.go

@@ -169,6 +169,7 @@ func cmdClient(bindAddress, remoteAddress string) int {
 		EllipticCurves: curves.Value,
 		StaleMode:      staleMode,
 		TimeLimit:      *timeLimit,
+		AllowFunc:      util.AllowByRatelimit(rateLimit.value),
 	}
 
 	clt, err := client.New(&cfg)
@@ -207,6 +208,7 @@ func cmdServer(bindAddress, remoteAddress string) int {
 		EllipticCurves:  curves.Value,
 		StaleMode:       staleMode,
 		TimeLimit:       *timeLimit,
+		AllowFunc:       util.AllowByRatelimit(rateLimit.value),
 	}
 
 	srv, err := server.New(&cfg)

+ 5 - 0
server/config.go

@@ -2,6 +2,7 @@ package server
 
 import (
 	"context"
+	"net"
 	"time"
 
 	"github.com/Snawoot/dtlspipe/ciphers"
@@ -21,6 +22,7 @@ type Config struct {
 	EllipticCurves  ciphers.CurveList
 	StaleMode       util.StaleMode
 	TimeLimit       time.Duration
+	AllowFunc       func(localAddr, remoteAddr net.Addr) bool
 }
 
 func (cfg *Config) populateDefaults() *Config {
@@ -39,5 +41,8 @@ func (cfg *Config) populateDefaults() *Config {
 	if cfg.EllipticCurves == nil {
 		cfg.EllipticCurves = ciphers.DefaultCurveList
 	}
+	if cfg.AllowFunc == nil {
+		cfg.AllowFunc = util.AllowAllFunc
+	}
 	return cfg
 }

+ 28 - 0
util/util.go

@@ -7,8 +7,11 @@ import (
 	"fmt"
 	"log"
 	"net"
+	"net/netip"
 	"sync"
 	"time"
+
+	"github.com/Snawoot/rlzone"
 )
 
 func GenPSK(length int) ([]byte, error) {
@@ -117,3 +120,28 @@ func PairConn(ctx context.Context, left, right net.Conn, idleTimeout time.Durati
 	go copier(right, left, true)
 	wg.Wait()
 }
+
+func NetAddrToNetipAddrPort(a net.Addr) netip.AddrPort {
+	switch v := a.(type) {
+	case *net.UDPAddr:
+		return v.AddrPort()
+	case *net.TCPAddr:
+		return v.AddrPort()
+	}
+	res, _ := netip.ParseAddrPort(a.String())
+	return res
+}
+
+func AllowAllFunc(_, _ net.Addr) bool {
+	return true
+}
+
+func AllowByRatelimit(z rlzone.Ratelimiter[netip.Addr]) func(net.Addr, net.Addr) bool {
+	if z == nil {
+		return AllowAllFunc
+	}
+	return func(_, remoteAddr net.Addr) bool {
+		key := NetAddrToNetipAddrPort(remoteAddr).Addr()
+		return z.Allow(key)
+	}
+}