Browse Source

use AllowFunc in client and server

Vladislav Yarmak 1 year ago
parent
commit
788ecfbaec
2 changed files with 12 additions and 0 deletions
  1. 6 0
      client/client.go
  2. 6 0
      server/server.go

+ 6 - 0
client/client.go

@@ -31,6 +31,7 @@ type Client struct {
 	staleMode   util.StaleMode
 	workerWG    sync.WaitGroup
 	timeLimit   time.Duration
+	allowFunc   func(net.Addr, net.Addr) bool
 }
 
 func New(cfg *Config) (*Client, error) {
@@ -47,6 +48,7 @@ func New(cfg *Config) (*Client, error) {
 		cancelCtx:   cancelCtx,
 		staleMode:   cfg.StaleMode,
 		timeLimit:   cfg.TimeLimit,
+		allowFunc:   cfg.AllowFunc,
 	}
 
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
@@ -89,6 +91,10 @@ func (client *Client) listen() {
 			continue
 		}
 
+		if !client.allowFunc(conn.LocalAddr(), conn.RemoteAddr()) {
+			continue
+		}
+
 		client.workerWG.Add(1)
 		go func(conn net.Conn) {
 			defer client.workerWG.Done()

+ 6 - 0
server/server.go

@@ -32,6 +32,7 @@ type Server struct {
 	staleMode   util.StaleMode
 	workerWG    sync.WaitGroup
 	timeLimit   time.Duration
+	allowFunc   func(net.Addr, net.Addr) bool
 }
 
 func New(cfg *Config) (*Server, error) {
@@ -48,6 +49,7 @@ func New(cfg *Config) (*Server, error) {
 		cancelCtx:   cancelCtx,
 		staleMode:   cfg.StaleMode,
 		timeLimit:   cfg.TimeLimit,
+		allowFunc:   cfg.AllowFunc,
 	}
 
 	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
@@ -101,6 +103,10 @@ func (srv *Server) listen() {
 			continue
 		}
 
+		if !srv.allowFunc(conn.LocalAddr(), conn.RemoteAddr()) {
+			continue
+		}
+
 		srv.workerWG.Add(1)
 		go func(conn net.Conn) {
 			defer srv.workerWG.Done()