|  | @@ -0,0 +1,122 @@
 | 
	
		
			
				|  |  | +package client
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +import (
 | 
	
		
			
				|  |  | +	"context"
 | 
	
		
			
				|  |  | +	"fmt"
 | 
	
		
			
				|  |  | +	"log"
 | 
	
		
			
				|  |  | +	"net"
 | 
	
		
			
				|  |  | +	"net/netip"
 | 
	
		
			
				|  |  | +	"time"
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	"github.com/Snawoot/dtlspipe/util"
 | 
	
		
			
				|  |  | +	"github.com/pion/dtls/v2"
 | 
	
		
			
				|  |  | +	"github.com/pion/transport/v2/udp"
 | 
	
		
			
				|  |  | +)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +const (
 | 
	
		
			
				|  |  | +	MaxPktBuf = 4096
 | 
	
		
			
				|  |  | +)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +type Client struct {
 | 
	
		
			
				|  |  | +	listener    net.Listener
 | 
	
		
			
				|  |  | +	dtlsConfig  *dtls.Config
 | 
	
		
			
				|  |  | +	rAddr       string
 | 
	
		
			
				|  |  | +	psk         func([]byte) ([]byte, error)
 | 
	
		
			
				|  |  | +	timeout     time.Duration
 | 
	
		
			
				|  |  | +	idleTimeout time.Duration
 | 
	
		
			
				|  |  | +	baseCtx     context.Context
 | 
	
		
			
				|  |  | +	cancelCtx   func()
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +func New(cfg *Config) (*Client, error) {
 | 
	
		
			
				|  |  | +	cfg = cfg.populateDefaults()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	baseCtx, cancelCtx := context.WithCancel(cfg.BaseContext)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	client := &Client{
 | 
	
		
			
				|  |  | +		rAddr:       cfg.RemoteAddress,
 | 
	
		
			
				|  |  | +		timeout:     cfg.Timeout,
 | 
	
		
			
				|  |  | +		psk:         cfg.PSKCallback,
 | 
	
		
			
				|  |  | +		idleTimeout: cfg.IdleTimeout,
 | 
	
		
			
				|  |  | +		baseCtx:     baseCtx,
 | 
	
		
			
				|  |  | +		cancelCtx:   cancelCtx,
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
 | 
	
		
			
				|  |  | +	if err != nil {
 | 
	
		
			
				|  |  | +		cancelCtx()
 | 
	
		
			
				|  |  | +		return nil, fmt.Errorf("can't parse bind address: %w", err)
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	client.dtlsConfig = &dtls.Config{
 | 
	
		
			
				|  |  | +		CipherSuites: []dtls.CipherSuiteID{
 | 
	
		
			
				|  |  | +			dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256,
 | 
	
		
			
				|  |  | +			dtls.TLS_PSK_WITH_AES_128_CCM,
 | 
	
		
			
				|  |  | +			dtls.TLS_PSK_WITH_AES_128_CCM_8,
 | 
	
		
			
				|  |  | +			dtls.TLS_PSK_WITH_AES_256_CCM_8,
 | 
	
		
			
				|  |  | +			dtls.TLS_PSK_WITH_AES_128_GCM_SHA256,
 | 
	
		
			
				|  |  | +			dtls.TLS_PSK_WITH_AES_128_CBC_SHA256,
 | 
	
		
			
				|  |  | +		},
 | 
	
		
			
				|  |  | +		ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
 | 
	
		
			
				|  |  | +		ConnectContextMaker:  client.contextMaker,
 | 
	
		
			
				|  |  | +		PSK:                  client.psk,
 | 
	
		
			
				|  |  | +		PSKIdentityHint:      []byte(cfg.PSKIdentity),
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	lc := udp.ListenConfig{}
 | 
	
		
			
				|  |  | +	listener, err := lc.Listen("udp", net.UDPAddrFromAddrPort(lAddrPort))
 | 
	
		
			
				|  |  | +	if err != nil {
 | 
	
		
			
				|  |  | +		cancelCtx()
 | 
	
		
			
				|  |  | +		return nil, fmt.Errorf("client listen failed: %w", err)
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	client.listener = listener
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	go client.listen()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	return client, nil
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +func (client *Client) listen() {
 | 
	
		
			
				|  |  | +	defer client.Close()
 | 
	
		
			
				|  |  | +	for client.baseCtx.Err() == nil {
 | 
	
		
			
				|  |  | +		conn, err := client.listener.Accept()
 | 
	
		
			
				|  |  | +		if err != nil {
 | 
	
		
			
				|  |  | +			log.Printf("conn accept failed: %v", err)
 | 
	
		
			
				|  |  | +			continue
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		go client.serve(conn)
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +func (client *Client) serve(conn net.Conn) {
 | 
	
		
			
				|  |  | +	log.Printf("[+] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
 | 
	
		
			
				|  |  | +	defer log.Printf("[-] conn %s <=> %s", conn.LocalAddr(), conn.RemoteAddr())
 | 
	
		
			
				|  |  | +	defer conn.Close()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	dialCtx, cancel := context.WithTimeout(client.baseCtx, client.timeout)
 | 
	
		
			
				|  |  | +	defer cancel()
 | 
	
		
			
				|  |  | +	remoteConn, err := (&net.Dialer{}).DialContext(dialCtx, "udp", client.rAddr)
 | 
	
		
			
				|  |  | +	if err != nil {
 | 
	
		
			
				|  |  | +		log.Printf("remote dial failed: %v", err)
 | 
	
		
			
				|  |  | +		return
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +	defer remoteConn.Close()
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	remoteConn, err = dtls.ClientWithContext(dialCtx, remoteConn, client.dtlsConfig)
 | 
	
		
			
				|  |  | +	if err != nil {
 | 
	
		
			
				|  |  | +		log.Printf("DTL handshake with remote server failed: %v", err)
 | 
	
		
			
				|  |  | +		return
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	util.PairConn(conn, remoteConn, client.idleTimeout)
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +func (client *Client) contextMaker() (context.Context, func()) {
 | 
	
		
			
				|  |  | +	return context.WithTimeout(client.baseCtx, client.timeout)
 | 
	
		
			
				|  |  | +}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +func (client *Client) Close() error {
 | 
	
		
			
				|  |  | +	client.cancelCtx()
 | 
	
		
			
				|  |  | +	return client.listener.Close()
 | 
	
		
			
				|  |  | +}
 |