Prechádzať zdrojové kódy

client: proper dial context handling

Vladislav Yarmak 1 rok pred
rodič
commit
1746365295
2 zmenil súbory, kde vykonal 24 pridanie a 17 odobranie
  1. 23 16
      client/client.go
  2. 1 1
      server/server.go

+ 23 - 16
client/client.go

@@ -116,28 +116,35 @@ func (client *Client) serve(conn net.Conn) {
 		ctx = newCtx
 	}
 
-	dialCtx, cancel := context.WithTimeout(ctx, client.timeout)
-	defer cancel()
-	remoteConn, remoteAddr, err := client.remoteDialFn(dialCtx)
+	remoteConn, err := func() (net.Conn, error) {
+		dialCtx, cancel := context.WithTimeout(ctx, client.timeout)
+		defer cancel()
+		remoteConn, remoteAddr, err := client.remoteDialFn(dialCtx)
+		if err != nil {
+			return nil, fmt.Errorf("remote dial function failed: %w", err)
+		}
+
+		dtlsConn, err := dtls.Client(remoteConn, remoteAddr, client.dtlsConfig)
+		if err != nil {
+			remoteConn.Close()
+			return nil, fmt.Errorf("DTLS connection with remote server failed: %w", err)
+		}
+
+		if err := dtlsConn.HandshakeContext(dialCtx); err != nil {
+			dtlsConn.Close()
+			remoteConn.Close()
+			return nil, fmt.Errorf("DTLS handshake with remote server failed: %w", err)
+		}
+
+		return dtlsConn, nil
+	}()
 	if err != nil {
 		log.Printf("remote dial failed: %v", err)
 		return
 	}
 	defer remoteConn.Close()
 
-	dtlsConn, err := dtls.Client(remoteConn, remoteAddr, client.dtlsConfig)
-	if err != nil {
-		log.Printf("DTLS connection with remote server failed: %v", err)
-		return
-	}
-	defer dtlsConn.Close()
-
-	if err := dtlsConn.HandshakeContext(dialCtx); err != nil {
-		log.Printf("DTLS handshake with remote server failed: %v", err)
-		return
-	}
-
-	util.PairConn(ctx, conn, dtlsConn, client.idleTimeout, client.staleMode)
+	util.PairConn(ctx, conn, remoteConn, client.idleTimeout, client.staleMode)
 }
 
 func (client *Client) Close() error {

+ 1 - 1
server/server.go

@@ -64,7 +64,7 @@ func New(cfg *Config) (*Server, error) {
 		InsecureSkipVerifyHello: cfg.SkipHelloVerify,
 		CipherSuites:            cfg.CipherSuites,
 		EllipticCurves:          cfg.EllipticCurves,
-		OnConnectionAttempt:     func(a net.Addr) error {
+		OnConnectionAttempt: func(a net.Addr) error {
 			if !srv.allowFunc(a) {
 				return fmt.Errorf("address %s was not allowed by limiter", a.String())
 			}