|
@@ -103,6 +103,20 @@ func (srv *Server) serve(conn net.Conn) {
|
|
|
defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
|
|
|
defer conn.Close()
|
|
|
|
|
|
+ if handshaker, ok := conn.(interface {
|
|
|
+ HandshakeContext(context.Context) error
|
|
|
+ }); ok {
|
|
|
+ err := func() error {
|
|
|
+ hsCtx, cancel := context.WithTimeout(srv.baseCtx, srv.timeout)
|
|
|
+ defer cancel()
|
|
|
+ return handshaker.HandshakeContext(hsCtx)
|
|
|
+ }()
|
|
|
+ if err != nil {
|
|
|
+ log.Printf("handshake %s <=> %s failed: %v", conn.LocalAddr(), conn.RemoteAddr(), err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
ctx := srv.baseCtx
|
|
|
tl := srv.timeLimitFunc()
|
|
|
if tl != 0 {
|
|
@@ -123,10 +137,6 @@ func (srv *Server) serve(conn net.Conn) {
|
|
|
util.PairConn(ctx, conn, remoteConn, srv.idleTimeout, srv.staleMode)
|
|
|
}
|
|
|
|
|
|
-func (srv *Server) contextMaker() (context.Context, func()) {
|
|
|
- return context.WithTimeout(srv.baseCtx, srv.timeout)
|
|
|
-}
|
|
|
-
|
|
|
func (srv *Server) Close() error {
|
|
|
srv.cancelCtx()
|
|
|
err := srv.listener.Close()
|