1
0
Эх сурвалжийг харах

Merge pull request #13 from Snawoot/dos_mitigation

DoS risks mitigation
Snawoot 1 жил өмнө
parent
commit
a994441391
9 өөрчлөгдсөн 96 нэмэгдсэн , 3 устгасан
  1. 2 0
      README.md
  2. 6 0
      client/client.go
  3. 5 0
      client/config.go
  4. 30 0
      cmd/dtlspipe/main.go
  5. 7 3
      go.mod
  6. 7 0
      go.sum
  7. 5 0
      server/config.go
  8. 6 0
      server/server.go
  9. 28 0
      util/util.go

+ 2 - 0
README.md

@@ -89,6 +89,8 @@ Options:
     	MTU used for DTLS fragments (default 1400)
   -psk string
     	hex-encoded pre-shared key. Can be generated with genpsk subcommand
+  -rate-limit value
+    	limit for incoming connections rate. Format: <limit>/<time duration> or empty string to disable (default 20/1m0s)
   -skip-hello-verify
     	(server only) skip hello verify request. Useful to workaround DPI (default true)
   -stale-mode value

+ 6 - 0
client/client.go

@@ -31,6 +31,7 @@ type Client struct {
 	staleMode   util.StaleMode
 	workerWG    sync.WaitGroup
 	timeLimit   time.Duration
+	allowFunc   func(net.Addr, net.Addr) bool
 }
 
 func New(cfg *Config) (*Client, error) {
@@ -47,6 +48,7 @@ func New(cfg *Config) (*Client, error) {
 		cancelCtx:   cancelCtx,
 		staleMode:   cfg.StaleMode,
 		timeLimit:   cfg.TimeLimit,
+		allowFunc:   cfg.AllowFunc,
 	}
 
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
@@ -89,6 +91,10 @@ func (client *Client) listen() {
 			continue
 		}
 
+		if !client.allowFunc(conn.LocalAddr(), conn.RemoteAddr()) {
+			continue
+		}
+
 		client.workerWG.Add(1)
 		go func(conn net.Conn) {
 			defer client.workerWG.Done()

+ 5 - 0
client/config.go

@@ -2,6 +2,7 @@ package client
 
 import (
 	"context"
+	"net"
 	"time"
 
 	"github.com/Snawoot/dtlspipe/ciphers"
@@ -21,6 +22,7 @@ type Config struct {
 	EllipticCurves ciphers.CurveList
 	StaleMode      util.StaleMode
 	TimeLimit      time.Duration
+	AllowFunc      func(localAddr, remoteAddr net.Addr) bool
 }
 
 func (cfg *Config) populateDefaults() *Config {
@@ -39,5 +41,8 @@ func (cfg *Config) populateDefaults() *Config {
 	if cfg.EllipticCurves == nil {
 		cfg.EllipticCurves = ciphers.DefaultCurveList
 	}
+	if cfg.AllowFunc == nil {
+		cfg.AllowFunc = util.AllowAllFunc
+	}
 	return cfg
 }

+ 30 - 0
cmd/dtlspipe/main.go

@@ -5,6 +5,7 @@ import (
 	"flag"
 	"fmt"
 	"log"
+	"net/netip"
 	"os"
 	"os/signal"
 	"runtime/pprof"
@@ -17,6 +18,7 @@ import (
 	"github.com/Snawoot/dtlspipe/keystore"
 	"github.com/Snawoot/dtlspipe/server"
 	"github.com/Snawoot/dtlspipe/util"
+	"github.com/Snawoot/rlzone"
 )
 
 const (
@@ -58,6 +60,30 @@ func (l *curvelistArg) Set(s string) error {
 	return nil
 }
 
+type ratelimitArg struct {
+	value rlzone.Ratelimiter[netip.Addr]
+}
+
+func (r *ratelimitArg) String() string {
+	if r == nil || r.value == nil {
+		return ""
+	}
+	return r.value.String()
+}
+
+func (r *ratelimitArg) Set(s string) error {
+	if s == "" {
+		r.value = nil
+		return nil
+	}
+	rl, err := rlzone.FromString[netip.Addr](s)
+	if err != nil {
+		return err
+	}
+	r.value = rl
+	return nil
+}
+
 var (
 	version = "undefined"
 
@@ -73,12 +99,14 @@ var (
 	curves          = curvelistArg{}
 	staleMode       = util.EitherStale
 	timeLimit       = flag.Duration("time-limit", 0, "hard time limit for each session")
+	rateLimit       = ratelimitArg{rlzone.Must(rlzone.NewSmallest[netip.Addr](1*time.Minute, 20))}
 )
 
 func init() {
 	flag.Var(&ciphersuites, "ciphers", "colon-separated list of ciphers to use")
 	flag.Var(&curves, "curves", "colon-separated list of curves to use")
 	flag.Var(&staleMode, "stale-mode", "which stale side of connection makes whole session stale (both, either, left, right)")
+	flag.Var(&rateLimit, "rate-limit", "limit for incoming connections rate. Format: <limit>/<time duration> or empty string to disable")
 }
 
 func usage() {
@@ -141,6 +169,7 @@ func cmdClient(bindAddress, remoteAddress string) int {
 		EllipticCurves: curves.Value,
 		StaleMode:      staleMode,
 		TimeLimit:      *timeLimit,
+		AllowFunc:      util.AllowByRatelimit(rateLimit.value),
 	}
 
 	clt, err := client.New(&cfg)
@@ -179,6 +208,7 @@ func cmdServer(bindAddress, remoteAddress string) int {
 		EllipticCurves:  curves.Value,
 		StaleMode:       staleMode,
 		TimeLimit:       *timeLimit,
+		AllowFunc:       util.AllowByRatelimit(rateLimit.value),
 	}
 
 	srv, err := server.New(&cfg)

+ 7 - 3
go.mod

@@ -1,10 +1,14 @@
 module github.com/Snawoot/dtlspipe
 
-go 1.21.0
+go 1.21.1
+
+require (
+	github.com/Snawoot/rlzone v0.2.0
+	github.com/pion/dtls/v2 v2.2.7
+	github.com/pion/transport/v2 v2.2.1
+)
 
 require (
-	github.com/pion/dtls/v2 v2.2.7 // indirect
 	github.com/pion/logging v0.2.2 // indirect
-	github.com/pion/transport/v2 v2.2.1 // indirect
 	golang.org/x/crypto v0.8.0 // indirect
 )

+ 7 - 0
go.sum

@@ -1,4 +1,7 @@
+github.com/Snawoot/rlzone v0.2.0 h1:l/Gl8ncAdCjdalZlE7THD4xlwCnvn6jCF3hsiL4SmWQ=
+github.com/Snawoot/rlzone v0.2.0/go.mod h1:5yK8f9nJSOAPizq2LZ35arkortJhjFx1eO6ckOQCnwQ=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/pion/dtls/v2 v2.2.7 h1:cSUBsETxepsCSFSxC3mc/aDo14qQLMSL+O6IjG28yV8=
 github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s=
@@ -6,12 +9,14 @@ github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY=
 github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms=
 github.com/pion/transport/v2 v2.2.1 h1:7qYnCBlpgSJNYMbLCKuSY9KbQdBFoETvPNETv0y4N7c=
 github.com/pion/transport/v2 v2.2.1/go.mod h1:cXXWavvCnFF6McHTft3DWS9iic2Mftcz1Aq29pGcU5g=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
 github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
 github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
+github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
 github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
 github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
@@ -24,6 +29,7 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
 golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
 golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
+golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
 golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -51,4 +57,5 @@ golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

+ 5 - 0
server/config.go

@@ -2,6 +2,7 @@ package server
 
 import (
 	"context"
+	"net"
 	"time"
 
 	"github.com/Snawoot/dtlspipe/ciphers"
@@ -21,6 +22,7 @@ type Config struct {
 	EllipticCurves  ciphers.CurveList
 	StaleMode       util.StaleMode
 	TimeLimit       time.Duration
+	AllowFunc       func(localAddr, remoteAddr net.Addr) bool
 }
 
 func (cfg *Config) populateDefaults() *Config {
@@ -39,5 +41,8 @@ func (cfg *Config) populateDefaults() *Config {
 	if cfg.EllipticCurves == nil {
 		cfg.EllipticCurves = ciphers.DefaultCurveList
 	}
+	if cfg.AllowFunc == nil {
+		cfg.AllowFunc = util.AllowAllFunc
+	}
 	return cfg
 }

+ 6 - 0
server/server.go

@@ -32,6 +32,7 @@ type Server struct {
 	staleMode   util.StaleMode
 	workerWG    sync.WaitGroup
 	timeLimit   time.Duration
+	allowFunc   func(net.Addr, net.Addr) bool
 }
 
 func New(cfg *Config) (*Server, error) {
@@ -48,6 +49,7 @@ func New(cfg *Config) (*Server, error) {
 		cancelCtx:   cancelCtx,
 		staleMode:   cfg.StaleMode,
 		timeLimit:   cfg.TimeLimit,
+		allowFunc:   cfg.AllowFunc,
 	}
 
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
@@ -101,6 +103,10 @@ func (srv *Server) listen() {
 			continue
 		}
 
+		if !srv.allowFunc(conn.LocalAddr(), conn.RemoteAddr()) {
+			continue
+		}
+
 		srv.workerWG.Add(1)
 		go func(conn net.Conn) {
 			defer srv.workerWG.Done()

+ 28 - 0
util/util.go

@@ -7,8 +7,11 @@ import (
 	"fmt"
 	"log"
 	"net"
+	"net/netip"
 	"sync"
 	"time"
+
+	"github.com/Snawoot/rlzone"
 )
 
 func GenPSK(length int) ([]byte, error) {
@@ -117,3 +120,28 @@ func PairConn(ctx context.Context, left, right net.Conn, idleTimeout time.Durati
 	go copier(right, left, true)
 	wg.Wait()
 }
+
+func NetAddrToNetipAddrPort(a net.Addr) netip.AddrPort {
+	switch v := a.(type) {
+	case *net.UDPAddr:
+		return v.AddrPort()
+	case *net.TCPAddr:
+		return v.AddrPort()
+	}
+	res, _ := netip.ParseAddrPort(a.String())
+	return res
+}
+
+func AllowAllFunc(_, _ net.Addr) bool {
+	return true
+}
+
+func AllowByRatelimit(z rlzone.Ratelimiter[netip.Addr]) func(net.Addr, net.Addr) bool {
+	if z == nil {
+		return AllowAllFunc
+	}
+	return func(_, remoteAddr net.Addr) bool {
+		key := NetAddrToNetipAddrPort(remoteAddr).Addr()
+		return z.Allow(key)
+	}
+}