|
@@ -3,16 +3,21 @@ package server
|
|
|
import (
|
|
|
"context"
|
|
|
"fmt"
|
|
|
+ "io"
|
|
|
"log"
|
|
|
"net"
|
|
|
"net/netip"
|
|
|
"time"
|
|
|
|
|
|
"github.com/pion/dtls/v2"
|
|
|
+ "github.com/pion/dtls/v2/pkg/protocol"
|
|
|
+ "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
|
|
|
+ "github.com/pion/transport/v2/udp"
|
|
|
)
|
|
|
|
|
|
type Server struct {
|
|
|
listener net.Listener
|
|
|
+ dtlsConfig *dtls.Config
|
|
|
rAddr string
|
|
|
psk []byte
|
|
|
timeout time.Duration
|
|
@@ -28,8 +33,8 @@ func New(cfg *Config) (*Server, error) {
|
|
|
|
|
|
srv := &Server{
|
|
|
rAddr: cfg.RemoteAddress,
|
|
|
- psk: []byte(cfg.Password), // TODO: key derivation
|
|
|
timeout: cfg.Timeout,
|
|
|
+ psk: cfg.PSK,
|
|
|
idleTimeout: cfg.IdleTimeout,
|
|
|
baseCtx: baseCtx,
|
|
|
cancelCtx: cancelCtx,
|
|
@@ -41,7 +46,7 @@ func New(cfg *Config) (*Server, error) {
|
|
|
return nil, fmt.Errorf("can't parse bind address: %w", err)
|
|
|
}
|
|
|
|
|
|
- dtlsConfig := &dtls.Config{
|
|
|
+ srv.dtlsConfig = &dtls.Config{
|
|
|
CipherSuites: []dtls.CipherSuiteID{
|
|
|
dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256,
|
|
|
dtls.TLS_PSK_WITH_AES_128_CCM,
|
|
@@ -52,11 +57,22 @@ func New(cfg *Config) (*Server, error) {
|
|
|
},
|
|
|
ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
|
|
|
ConnectContextMaker: srv.contextMaker,
|
|
|
- PSK: func(hint []byte) ([]byte, error) {
|
|
|
- return []byte(cfg.Password), nil
|
|
|
+ PSK: srv.getPSK,
|
|
|
+ }
|
|
|
+ lc := udp.ListenConfig{
|
|
|
+ AcceptFilter: func(packet []byte) bool {
|
|
|
+ pkts, err := recordlayer.UnpackDatagram(packet)
|
|
|
+ if err != nil || len(pkts) < 1 {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ h := &recordlayer.Header{}
|
|
|
+ if err := h.Unmarshal(pkts[0]); err != nil {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ return h.ContentType == protocol.ContentTypeHandshake
|
|
|
},
|
|
|
}
|
|
|
- listener, err := dtls.Listen("udp", net.UDPAddrFromAddrPort(lAddrPort), dtlsConfig)
|
|
|
+ listener, err := lc.Listen("udp", net.UDPAddrFromAddrPort(lAddrPort))
|
|
|
if err != nil {
|
|
|
cancelCtx()
|
|
|
return nil, fmt.Errorf("server listen failed: %w", err)
|
|
@@ -64,30 +80,46 @@ func New(cfg *Config) (*Server, error) {
|
|
|
|
|
|
srv.listener = listener
|
|
|
|
|
|
+ go srv.listen()
|
|
|
+
|
|
|
return srv, nil
|
|
|
}
|
|
|
|
|
|
func (srv *Server) listen() {
|
|
|
+ defer srv.Close()
|
|
|
for srv.baseCtx.Err() == nil {
|
|
|
conn, err := srv.listener.Accept()
|
|
|
if err != nil {
|
|
|
log.Printf("conn accept failed: %v", err)
|
|
|
- return
|
|
|
+ continue
|
|
|
}
|
|
|
|
|
|
- go srv.serve(conn)
|
|
|
+ go func(conn net.Conn) {
|
|
|
+ conn, err := dtls.Server(conn, srv.dtlsConfig)
|
|
|
+ if err != nil {
|
|
|
+ log.Printf("DTLS accept error: %v", err)
|
|
|
+ return
|
|
|
+ }
|
|
|
+ srv.serve(conn)
|
|
|
+ }(conn)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
func (srv *Server) serve(conn net.Conn) {
|
|
|
+ log.Printf("[+] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
|
|
|
+ defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
|
|
|
defer conn.Close()
|
|
|
- conn.Write([]byte("Hello, World!"))
|
|
|
+ io.Copy(conn, conn)
|
|
|
}
|
|
|
|
|
|
func (srv *Server) contextMaker() (context.Context, func()) {
|
|
|
return context.WithTimeout(srv.baseCtx, srv.timeout)
|
|
|
}
|
|
|
|
|
|
+func (srv *Server) getPSK(hint []byte) ([]byte, error) {
|
|
|
+ return srv.psk, nil
|
|
|
+}
|
|
|
+
|
|
|
func (srv *Server) Close() error {
|
|
|
srv.cancelCtx()
|
|
|
return srv.listener.Close()
|