Browse Source

restore server hs timeout

Vladislav Yarmak 8 months ago
parent
commit
8c3d514fca
1 changed files with 14 additions and 4 deletions
  1. 14 4
      server/server.go

+ 14 - 4
server/server.go

@@ -103,6 +103,20 @@ func (srv *Server) serve(conn net.Conn) {
 	defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
 	defer conn.Close()
 
+	if handshaker, ok := conn.(interface {
+		HandshakeContext(context.Context) error
+	}); ok {
+		err := func() error {
+			hsCtx, cancel := context.WithTimeout(srv.baseCtx, srv.timeout)
+			defer cancel()
+			return handshaker.HandshakeContext(hsCtx)
+		}()
+		if err != nil {
+			log.Printf("handshake %s <=> %s failed: %v", conn.LocalAddr(), conn.RemoteAddr(), err)
+			return
+		}
+	}
+
 	ctx := srv.baseCtx
 	tl := srv.timeLimitFunc()
 	if tl != 0 {
@@ -123,10 +137,6 @@ func (srv *Server) serve(conn net.Conn) {
 	util.PairConn(ctx, conn, remoteConn, srv.idleTimeout, srv.staleMode)
 }
 
-func (srv *Server) contextMaker() (context.Context, func()) {
-	return context.WithTimeout(srv.baseCtx, srv.timeout)
-}
-
 func (srv *Server) Close() error {
 	srv.cancelCtx()
 	err := srv.listener.Close()