1
0
Vladislav Yarmak 1 жил өмнө
parent
commit
5b95bd34e2
1 өөрчлөгдсөн 71 нэмэгдсэн , 2 устгасан
  1. 71 2
      server/server.go

+ 71 - 2
server/server.go

@@ -3,10 +3,11 @@ package server
 import (
 	"context"
 	"fmt"
-	"io"
 	"log"
 	"net"
 	"net/netip"
+	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/pion/dtls/v2"
@@ -15,6 +16,10 @@ import (
 	"github.com/pion/transport/v2/udp"
 )
 
+const (
+	MaxPktBuf = 4096
+)
+
 type Server struct {
 	listener    net.Listener
 	dtlsConfig  *dtls.Config
@@ -109,7 +114,62 @@ func (srv *Server) serve(conn net.Conn) {
 	log.Printf("[+] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
 	defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
 	defer conn.Close()
-	io.Copy(conn, conn)
+
+	dialCtx, cancel := context.WithTimeout(srv.baseCtx, srv.timeout)
+	defer cancel()
+	remoteConn, err := (&net.Dialer{}).DialContext(dialCtx, "udp", srv.rAddr)
+	if err != nil {
+		log.Printf("remote dial failed: %v", err)
+		return
+	}
+	defer remoteConn.Close()
+
+	var lsn atomic.Int32
+	var wg sync.WaitGroup
+
+	copier := func(dst, src net.Conn) {
+		defer wg.Done()
+		defer dst.Close()
+		buf := make([]byte, MaxPktBuf)
+		for {
+			oldLSN := lsn.Load()
+
+			if err := src.SetReadDeadline(time.Now().Add(srv.idleTimeout)); err != nil {
+				log.Println("can't update deadline for connection: %v", err)
+				break
+			}
+
+			n, err := src.Read(buf)
+			if err != nil {
+				if isTimeout(err) {
+					// hit read deadline
+					if oldLSN != lsn.Load() {
+						// not stale conn
+						continue
+					} else {
+						log.Printf("dropping stale connection %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
+					}
+				} else {
+					// any other error
+					log.Printf("read from %s error: %v", src.RemoteAddr(), err)
+				}
+				break
+			}
+
+			lsn.Add(1)
+
+			_, err = dst.Write(buf[:n])
+			if err != nil {
+				log.Printf("write to %s error: %v", dst.RemoteAddr(), err)
+				break
+			}
+		}
+	}
+
+	wg.Add(2)
+	go copier(conn, remoteConn)
+	go copier(remoteConn, conn)
+	wg.Wait()
 }
 
 func (srv *Server) contextMaker() (context.Context, func()) {
@@ -124,3 +184,12 @@ func (srv *Server) Close() error {
 	srv.cancelCtx()
 	return srv.listener.Close()
 }
+
+func isTimeout(err error) bool {
+	if timeoutErr, ok := err.(interface {
+		Timeout() bool
+	}); ok {
+		return timeoutErr.Timeout()
+	}
+	return false
+}