|
@@ -116,28 +116,35 @@ func (client *Client) serve(conn net.Conn) {
|
|
|
ctx = newCtx
|
|
|
}
|
|
|
|
|
|
- dialCtx, cancel := context.WithTimeout(ctx, client.timeout)
|
|
|
- defer cancel()
|
|
|
- remoteConn, remoteAddr, err := client.remoteDialFn(dialCtx)
|
|
|
+ remoteConn, err := func() (net.Conn, error) {
|
|
|
+ dialCtx, cancel := context.WithTimeout(ctx, client.timeout)
|
|
|
+ defer cancel()
|
|
|
+ remoteConn, remoteAddr, err := client.remoteDialFn(dialCtx)
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("remote dial function failed: %w", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ dtlsConn, err := dtls.Client(remoteConn, remoteAddr, client.dtlsConfig)
|
|
|
+ if err != nil {
|
|
|
+ remoteConn.Close()
|
|
|
+ return nil, fmt.Errorf("DTLS connection with remote server failed: %w", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := dtlsConn.HandshakeContext(dialCtx); err != nil {
|
|
|
+ dtlsConn.Close()
|
|
|
+ remoteConn.Close()
|
|
|
+ return nil, fmt.Errorf("DTLS handshake with remote server failed: %w", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ return dtlsConn, nil
|
|
|
+ }()
|
|
|
if err != nil {
|
|
|
log.Printf("remote dial failed: %v", err)
|
|
|
return
|
|
|
}
|
|
|
defer remoteConn.Close()
|
|
|
|
|
|
- dtlsConn, err := dtls.Client(remoteConn, remoteAddr, client.dtlsConfig)
|
|
|
- if err != nil {
|
|
|
- log.Printf("DTLS connection with remote server failed: %v", err)
|
|
|
- return
|
|
|
- }
|
|
|
- defer dtlsConn.Close()
|
|
|
-
|
|
|
- if err := dtlsConn.HandshakeContext(dialCtx); err != nil {
|
|
|
- log.Printf("DTLS handshake with remote server failed: %v", err)
|
|
|
- return
|
|
|
- }
|
|
|
-
|
|
|
- util.PairConn(ctx, conn, dtlsConn, client.idleTimeout, client.staleMode)
|
|
|
+ util.PairConn(ctx, conn, remoteConn, client.idleTimeout, client.staleMode)
|
|
|
}
|
|
|
|
|
|
func (client *Client) Close() error {
|