Browse Source

client impl.

Vladislav Yarmak 1 year ago
parent
commit
04d34c0121
5 changed files with 239 additions and 62 deletions
  1. 122 0
      client/client.go
  2. 29 0
      client/config.go
  3. 19 1
      cmd/dtlspipe/main.go
  4. 2 61
      server/server.go
  5. 67 0
      util/util.go

+ 122 - 0
client/client.go

@@ -0,0 +1,122 @@
+package client
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"net"
+	"net/netip"
+	"time"
+
+	"github.com/Snawoot/dtlspipe/util"
+	"github.com/pion/dtls/v2"
+	"github.com/pion/transport/v2/udp"
+)
+
+const (
+	MaxPktBuf = 4096
+)
+
+type Client struct {
+	listener    net.Listener
+	dtlsConfig  *dtls.Config
+	rAddr       string
+	psk         func([]byte) ([]byte, error)
+	timeout     time.Duration
+	idleTimeout time.Duration
+	baseCtx     context.Context
+	cancelCtx   func()
+}
+
+func New(cfg *Config) (*Client, error) {
+	cfg = cfg.populateDefaults()
+
+	baseCtx, cancelCtx := context.WithCancel(cfg.BaseContext)
+
+	client := &Client{
+		rAddr:       cfg.RemoteAddress,
+		timeout:     cfg.Timeout,
+		psk:         cfg.PSKCallback,
+		idleTimeout: cfg.IdleTimeout,
+		baseCtx:     baseCtx,
+		cancelCtx:   cancelCtx,
+	}
+
+	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
+	if err != nil {
+		cancelCtx()
+		return nil, fmt.Errorf("can't parse bind address: %w", err)
+	}
+
+	client.dtlsConfig = &dtls.Config{
+		CipherSuites: []dtls.CipherSuiteID{
+			dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256,
+			dtls.TLS_PSK_WITH_AES_128_CCM,
+			dtls.TLS_PSK_WITH_AES_128_CCM_8,
+			dtls.TLS_PSK_WITH_AES_256_CCM_8,
+			dtls.TLS_PSK_WITH_AES_128_GCM_SHA256,
+			dtls.TLS_PSK_WITH_AES_128_CBC_SHA256,
+		},
+		ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
+		ConnectContextMaker:  client.contextMaker,
+		PSK:                  client.psk,
+		PSKIdentityHint:      []byte(cfg.PSKIdentity),
+	}
+	lc := udp.ListenConfig{}
+	listener, err := lc.Listen("udp", net.UDPAddrFromAddrPort(lAddrPort))
+	if err != nil {
+		cancelCtx()
+		return nil, fmt.Errorf("client listen failed: %w", err)
+	}
+
+	client.listener = listener
+
+	go client.listen()
+
+	return client, nil
+}
+
+func (client *Client) listen() {
+	defer client.Close()
+	for client.baseCtx.Err() == nil {
+		conn, err := client.listener.Accept()
+		if err != nil {
+			log.Printf("conn accept failed: %v", err)
+			continue
+		}
+
+		go client.serve(conn)
+	}
+}
+
+func (client *Client) serve(conn net.Conn) {
+	log.Printf("[+] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
+	defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
+	defer conn.Close()
+
+	dialCtx, cancel := context.WithTimeout(client.baseCtx, client.timeout)
+	defer cancel()
+	remoteConn, err := (&net.Dialer{}).DialContext(dialCtx, "udp", client.rAddr)
+	if err != nil {
+		log.Printf("remote dial failed: %v", err)
+		return
+	}
+	defer remoteConn.Close()
+
+	remoteConn, err = dtls.ClientWithContext(dialCtx, remoteConn, client.dtlsConfig)
+	if err != nil {
+		log.Printf("DTL handshake with remote server failed: %v", err)
+		return
+	}
+
+	util.PairConn(conn, remoteConn, client.idleTimeout)
+}
+
+func (client *Client) contextMaker() (context.Context, func()) {
+	return context.WithTimeout(client.baseCtx, client.timeout)
+}
+
+func (client *Client) Close() error {
+	client.cancelCtx()
+	return client.listener.Close()
+}

+ 29 - 0
client/config.go

@@ -0,0 +1,29 @@
+package client
+
+import (
+	"context"
+	"time"
+)
+
+type Config struct {
+	BindAddress   string
+	RemoteAddress string
+	Timeout       time.Duration
+	IdleTimeout   time.Duration
+	BaseContext   context.Context
+	PSKCallback   func([]byte) ([]byte, error)
+	PSKIdentity   string
+}
+
+func (cfg *Config) populateDefaults() *Config {
+	newCfg := new(Config)
+	*newCfg = *cfg
+	cfg = newCfg
+	if cfg.BaseContext == nil {
+		cfg.BaseContext = context.Background()
+	}
+	if cfg.IdleTimeout == 0 {
+		cfg.IdleTimeout = 90 * time.Second
+	}
+	return cfg
+}

+ 19 - 1
cmd/dtlspipe/main.go

@@ -11,6 +11,7 @@ import (
 	"syscall"
 	"time"
 
+	"github.com/Snawoot/dtlspipe/client"
 	"github.com/Snawoot/dtlspipe/keystore"
 	"github.com/Snawoot/dtlspipe/server"
 	"github.com/Snawoot/dtlspipe/util"
@@ -28,6 +29,7 @@ var (
 	idleTime  = flag.Duration("idle-time", 90*time.Second, "max idle time for UDP session")
 	pskHexOpt = flag.String("psk", "", "hex-encoded pre-shared key. Can be generated with genpsk subcommand")
 	keyLength = flag.Uint("key-length", 16, "generate key with specified length")
+	identity  = flag.String("identity", "", "client identity sent to server")
 )
 
 func usage() {
@@ -64,7 +66,7 @@ func cmdVersion() int {
 }
 
 func cmdClient(bindAddress, remoteAddress string) int {
-	_, err := simpleGetPSK()
+	psk, err := simpleGetPSK()
 	if err != nil {
 		log.Printf("can't get PSK: %v", err)
 		return 2
@@ -75,6 +77,22 @@ func cmdClient(bindAddress, remoteAddress string) int {
 	appCtx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
 	defer cancel()
 
+	cfg := client.Config{
+		BindAddress:   bindAddress,
+		RemoteAddress: remoteAddress,
+		PSKCallback:   keystore.NewStaticKeystore(psk).PSKCallback,
+		PSKIdentity:   *identity,
+		Timeout:       *timeout,
+		IdleTimeout:   *idleTime,
+		BaseContext:   appCtx,
+	}
+
+	clt, err := client.New(&cfg)
+	if err != nil {
+		log.Fatalf("client startup failed: %v", err)
+	}
+	defer clt.Close()
+
 	<-appCtx.Done()
 
 	return 0

+ 2 - 61
server/server.go

@@ -6,20 +6,15 @@ import (
 	"log"
 	"net"
 	"net/netip"
-	"sync"
-	"sync/atomic"
 	"time"
 
+	"github.com/Snawoot/dtlspipe/util"
 	"github.com/pion/dtls/v2"
 	"github.com/pion/dtls/v2/pkg/protocol"
 	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
 	"github.com/pion/transport/v2/udp"
 )
 
-const (
-	MaxPktBuf = 4096
-)
-
 type Server struct {
 	listener    net.Listener
 	dtlsConfig  *dtls.Config
@@ -124,52 +119,7 @@ func (srv *Server) serve(conn net.Conn) {
 	}
 	defer remoteConn.Close()
 
-	var lsn atomic.Int32
-	var wg sync.WaitGroup
-
-	copier := func(dst, src net.Conn) {
-		defer wg.Done()
-		defer dst.Close()
-		buf := make([]byte, MaxPktBuf)
-		for {
-			oldLSN := lsn.Load()
-
-			if err := src.SetReadDeadline(time.Now().Add(srv.idleTimeout)); err != nil {
-				log.Println("can't update deadline for connection: %v", err)
-				break
-			}
-
-			n, err := src.Read(buf)
-			if err != nil {
-				if isTimeout(err) {
-					// hit read deadline
-					if oldLSN != lsn.Load() {
-						// not stale conn
-						continue
-					} else {
-						log.Printf("dropping stale connection %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
-					}
-				} else {
-					// any other error
-					log.Printf("read from %s error: %v", src.RemoteAddr(), err)
-				}
-				break
-			}
-
-			lsn.Add(1)
-
-			_, err = dst.Write(buf[:n])
-			if err != nil {
-				log.Printf("write to %s error: %v", dst.RemoteAddr(), err)
-				break
-			}
-		}
-	}
-
-	wg.Add(2)
-	go copier(conn, remoteConn)
-	go copier(remoteConn, conn)
-	wg.Wait()
+	util.PairConn(conn, remoteConn, srv.idleTimeout)
 }
 
 func (srv *Server) contextMaker() (context.Context, func()) {
@@ -180,12 +130,3 @@ func (srv *Server) Close() error {
 	srv.cancelCtx()
 	return srv.listener.Close()
 }
-
-func isTimeout(err error) bool {
-	if timeoutErr, ok := err.(interface {
-		Timeout() bool
-	}); ok {
-		return timeoutErr.Timeout()
-	}
-	return false
-}

+ 67 - 0
util/util.go

@@ -4,6 +4,11 @@ import (
 	"crypto/rand"
 	"encoding/hex"
 	"fmt"
+	"log"
+	"net"
+	"sync"
+	"sync/atomic"
+	"time"
 )
 
 func GenPSK(length int) ([]byte, error) {
@@ -28,3 +33,65 @@ func GenPSKHex(length int) (string, error) {
 func PSKFromHex(input string) ([]byte, error) {
 	return hex.DecodeString(input)
 }
+
+func isTimeout(err error) bool {
+	if timeoutErr, ok := err.(interface {
+		Timeout() bool
+	}); ok {
+		return timeoutErr.Timeout()
+	}
+	return false
+}
+
+const (
+	MaxPktBuf = 4096
+)
+
+func PairConn(left, right net.Conn, idleTimeout time.Duration) {
+	var lsn atomic.Int32
+	var wg sync.WaitGroup
+
+	copier := func(dst, src net.Conn) {
+		defer wg.Done()
+		defer dst.Close()
+		buf := make([]byte, MaxPktBuf)
+		for {
+			oldLSN := lsn.Load()
+
+			if err := src.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
+				log.Println("can't update deadline for connection: %v", err)
+				break
+			}
+
+			n, err := src.Read(buf)
+			if err != nil {
+				if isTimeout(err) {
+					// hit read deadline
+					if oldLSN != lsn.Load() {
+						// not stale conn
+						continue
+					} else {
+						log.Printf("dropping stale connection %s <=> %s", src.LocalAddr(), src.RemoteAddr())
+					}
+				} else {
+					// any other error
+					log.Printf("read from %s error: %v", src.RemoteAddr(), err)
+				}
+				break
+			}
+
+			lsn.Add(1)
+
+			_, err = dst.Write(buf[:n])
+			if err != nil {
+				log.Printf("write to %s error: %v", dst.RemoteAddr(), err)
+				break
+			}
+		}
+	}
+
+	wg.Add(2)
+	go copier(left, right)
+	go copier(right, left)
+	wg.Wait()
+}