|
@@ -3,10 +3,11 @@ package server
|
|
|
import (
|
|
|
"context"
|
|
|
"fmt"
|
|
|
- "io"
|
|
|
"log"
|
|
|
"net"
|
|
|
"net/netip"
|
|
|
+ "sync"
|
|
|
+ "sync/atomic"
|
|
|
"time"
|
|
|
|
|
|
"github.com/pion/dtls/v2"
|
|
@@ -15,6 +16,10 @@ import (
|
|
|
"github.com/pion/transport/v2/udp"
|
|
|
)
|
|
|
|
|
|
+const (
|
|
|
+ MaxPktBuf = 4096
|
|
|
+)
|
|
|
+
|
|
|
type Server struct {
|
|
|
listener net.Listener
|
|
|
dtlsConfig *dtls.Config
|
|
@@ -109,7 +114,62 @@ func (srv *Server) serve(conn net.Conn) {
|
|
|
log.Printf("[+] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
|
|
|
defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
|
|
|
defer conn.Close()
|
|
|
- io.Copy(conn, conn)
|
|
|
+
|
|
|
+ dialCtx, cancel := context.WithTimeout(srv.baseCtx, srv.timeout)
|
|
|
+ defer cancel()
|
|
|
+ remoteConn, err := (&net.Dialer{}).DialContext(dialCtx, "udp", srv.rAddr)
|
|
|
+ if err != nil {
|
|
|
+ log.Printf("remote dial failed: %v", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ defer remoteConn.Close()
|
|
|
+
|
|
|
+ var lsn atomic.Int32
|
|
|
+ var wg sync.WaitGroup
|
|
|
+
|
|
|
+ copier := func(dst, src net.Conn) {
|
|
|
+ defer wg.Done()
|
|
|
+ defer dst.Close()
|
|
|
+ buf := make([]byte, MaxPktBuf)
|
|
|
+ for {
|
|
|
+ oldLSN := lsn.Load()
|
|
|
+
|
|
|
+ if err := src.SetReadDeadline(time.Now().Add(srv.idleTimeout)); err != nil {
|
|
|
+ log.Println("can't update deadline for connection: %v", err)
|
|
|
+ break
|
|
|
+ }
|
|
|
+
|
|
|
+ n, err := src.Read(buf)
|
|
|
+ if err != nil {
|
|
|
+ if isTimeout(err) {
|
|
|
+ // hit read deadline
|
|
|
+ if oldLSN != lsn.Load() {
|
|
|
+ // not stale conn
|
|
|
+ continue
|
|
|
+ } else {
|
|
|
+ log.Printf("dropping stale connection %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ // any other error
|
|
|
+ log.Printf("read from %s error: %v", src.RemoteAddr(), err)
|
|
|
+ }
|
|
|
+ break
|
|
|
+ }
|
|
|
+
|
|
|
+ lsn.Add(1)
|
|
|
+
|
|
|
+ _, err = dst.Write(buf[:n])
|
|
|
+ if err != nil {
|
|
|
+ log.Printf("write to %s error: %v", dst.RemoteAddr(), err)
|
|
|
+ break
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ wg.Add(2)
|
|
|
+ go copier(conn, remoteConn)
|
|
|
+ go copier(remoteConn, conn)
|
|
|
+ wg.Wait()
|
|
|
}
|
|
|
|
|
|
func (srv *Server) contextMaker() (context.Context, func()) {
|
|
@@ -124,3 +184,12 @@ func (srv *Server) Close() error {
|
|
|
srv.cancelCtx()
|
|
|
return srv.listener.Close()
|
|
|
}
|
|
|
+
|
|
|
+func isTimeout(err error) bool {
|
|
|
+ if timeoutErr, ok := err.(interface {
|
|
|
+ Timeout() bool
|
|
|
+ }); ok {
|
|
|
+ return timeoutErr.Timeout()
|
|
|
+ }
|
|
|
+ return false
|
|
|
+}
|