Browse Source

some DTLS srv implementation

Vladislav Yarmak 1 year ago
parent
commit
b1b4adb5ef
3 changed files with 59 additions and 10 deletions
  1. 18 1
      cmd/dtlspipe/main.go
  2. 1 1
      server/config.go
  3. 40 8
      server/server.go

+ 18 - 1
cmd/dtlspipe/main.go

@@ -11,13 +11,15 @@ import (
 	"syscall"
 	"time"
 
+	"github.com/Snawoot/dtlspipe/server"
 	"github.com/Snawoot/dtlspipe/util"
 )
 
 const (
-	ProgName = "dtlspipe"
+	ProgName     = "dtlspipe"
 	PSKEnvVarKey = "DTLSPIPE_PSK"
 )
+
 var (
 	version = "undefined"
 
@@ -79,6 +81,21 @@ func cmdServer(bindAddress, remoteAddress string, psk []byte) int {
 	appCtx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
 	defer cancel()
 
+	cfg := server.Config{
+		BindAddress:   bindAddress,
+		RemoteAddress: remoteAddress,
+		PSK:           psk,
+		Timeout:       *timeout,
+		IdleTimeout:   *idleTime,
+		BaseContext:   appCtx,
+	}
+
+	srv, err := server.New(&cfg)
+	if err != nil {
+		log.Fatalf("server startup failed: %v", err)
+	}
+	defer srv.Close()
+
 	<-appCtx.Done()
 	return 0
 }

+ 1 - 1
server/config.go

@@ -8,7 +8,7 @@ import (
 type Config struct {
 	BindAddress   string
 	RemoteAddress string
-	Password      string
+	PSK           []byte
 	Timeout       time.Duration
 	IdleTimeout   time.Duration
 	BaseContext   context.Context

+ 40 - 8
server/server.go

@@ -3,16 +3,21 @@ package server
 import (
 	"context"
 	"fmt"
+	"io"
 	"log"
 	"net"
 	"net/netip"
 	"time"
 
 	"github.com/pion/dtls/v2"
+	"github.com/pion/dtls/v2/pkg/protocol"
+	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
+	"github.com/pion/transport/v2/udp"
 )
 
 type Server struct {
 	listener    net.Listener
+	dtlsConfig  *dtls.Config
 	rAddr       string
 	psk         []byte
 	timeout     time.Duration
@@ -28,8 +33,8 @@ func New(cfg *Config) (*Server, error) {
 
 	srv := &Server{
 		rAddr:       cfg.RemoteAddress,
-		psk:         []byte(cfg.Password), // TODO: key derivation
 		timeout:     cfg.Timeout,
+		psk:         cfg.PSK,
 		idleTimeout: cfg.IdleTimeout,
 		baseCtx:     baseCtx,
 		cancelCtx:   cancelCtx,
@@ -41,7 +46,7 @@ func New(cfg *Config) (*Server, error) {
 		return nil, fmt.Errorf("can't parse bind address: %w", err)
 	}
 
-	dtlsConfig := &dtls.Config{
+	srv.dtlsConfig = &dtls.Config{
 		CipherSuites: []dtls.CipherSuiteID{
 			dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256,
 			dtls.TLS_PSK_WITH_AES_128_CCM,
@@ -52,11 +57,22 @@ func New(cfg *Config) (*Server, error) {
 		},
 		ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
 		ConnectContextMaker:  srv.contextMaker,
-		PSK: func(hint []byte) ([]byte, error) {
-			return []byte(cfg.Password), nil
+		PSK:                  srv.getPSK,
+	}
+	lc := udp.ListenConfig{
+		AcceptFilter: func(packet []byte) bool {
+			pkts, err := recordlayer.UnpackDatagram(packet)
+			if err != nil || len(pkts) < 1 {
+				return false
+			}
+			h := &recordlayer.Header{}
+			if err := h.Unmarshal(pkts[0]); err != nil {
+				return false
+			}
+			return h.ContentType == protocol.ContentTypeHandshake
 		},
 	}
-	listener, err := dtls.Listen("udp", net.UDPAddrFromAddrPort(lAddrPort), dtlsConfig)
+	listener, err := lc.Listen("udp", net.UDPAddrFromAddrPort(lAddrPort))
 	if err != nil {
 		cancelCtx()
 		return nil, fmt.Errorf("server listen failed: %w", err)
@@ -64,30 +80,46 @@ func New(cfg *Config) (*Server, error) {
 
 	srv.listener = listener
 
+	go srv.listen()
+
 	return srv, nil
 }
 
 func (srv *Server) listen() {
+	defer srv.Close()
 	for srv.baseCtx.Err() == nil {
 		conn, err := srv.listener.Accept()
 		if err != nil {
 			log.Printf("conn accept failed: %v", err)
-			return
+			continue
 		}
 
-		go srv.serve(conn)
+		go func(conn net.Conn) {
+			conn, err := dtls.Server(conn, srv.dtlsConfig)
+			if err != nil {
+				log.Printf("DTLS accept error: %v", err)
+				return
+			}
+			srv.serve(conn)
+		}(conn)
 	}
 }
 
 func (srv *Server) serve(conn net.Conn) {
+	log.Printf("[+] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
+	defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
 	defer conn.Close()
-	conn.Write([]byte("Hello, World!"))
+	io.Copy(conn, conn)
 }
 
 func (srv *Server) contextMaker() (context.Context, func()) {
 	return context.WithTimeout(srv.baseCtx, srv.timeout)
 }
 
+func (srv *Server) getPSK(hint []byte) ([]byte, error) {
+	return srv.psk, nil
+}
+
 func (srv *Server) Close() error {
 	srv.cancelCtx()
 	return srv.listener.Close()