1
0

server.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. package server
  2. import (
  3. "context"
  4. "fmt"
  5. "log"
  6. "net"
  7. "net/netip"
  8. "time"
  9. "github.com/pion/dtls/v2"
  10. )
  // Server accepts PSK-authenticated DTLS connections on a local UDP address.
  // Build one with New and release it with Close.
  type Server struct {
  	// listener is the DTLS listener bound by New.
  	listener net.Listener

  	// rAddr is Config.RemoteAddress; not yet read by code in this file.
  	rAddr string
  	// psk is the pre-shared key, currently the raw Config.Password bytes.
  	psk []byte

  	// timeout bounds each handshake context produced by contextMaker.
  	timeout time.Duration
  	// idleTimeout is Config.IdleTimeout; not yet read by code in this file.
  	idleTimeout time.Duration

  	// baseCtx is the parent of every handshake context; Close cancels it.
  	baseCtx context.Context
  	// cancelCtx cancels baseCtx.
  	cancelCtx func()
  }
  20. func New(cfg *Config) (*Server, error) {
  21. cfg = cfg.populateDefaults()
  22. baseCtx, cancelCtx := context.WithCancel(cfg.BaseContext)
  23. srv := &Server{
  24. rAddr: cfg.RemoteAddress,
  25. psk: []byte(cfg.Password), // TODO: key derivation
  26. timeout: cfg.Timeout,
  27. idleTimeout: cfg.IdleTimeout,
  28. baseCtx: baseCtx,
  29. cancelCtx: cancelCtx,
  30. }
  31. lAddrPort, err := netip.ParseAddrPort(cfg.BindAddress)
  32. if err != nil {
  33. cancelCtx()
  34. return nil, fmt.Errorf("can't parse bind address: %w", err)
  35. }
  36. dtlsConfig := &dtls.Config{
  37. CipherSuites: []dtls.CipherSuiteID{
  38. dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256,
  39. dtls.TLS_PSK_WITH_AES_128_CCM,
  40. dtls.TLS_PSK_WITH_AES_128_CCM_8,
  41. dtls.TLS_PSK_WITH_AES_256_CCM_8,
  42. dtls.TLS_PSK_WITH_AES_128_GCM_SHA256,
  43. dtls.TLS_PSK_WITH_AES_128_CBC_SHA256,
  44. },
  45. ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
  46. ConnectContextMaker: srv.contextMaker,
  47. PSK: func(hint []byte) ([]byte, error) {
  48. return []byte(cfg.Password), nil
  49. },
  50. }
  51. listener, err := dtls.Listen("udp", net.UDPAddrFromAddrPort(lAddrPort), dtlsConfig)
  52. if err != nil {
  53. cancelCtx()
  54. return nil, fmt.Errorf("server listen failed: %w", err)
  55. }
  56. srv.listener = listener
  57. return srv, nil
  58. }
  59. func (srv *Server) listen() {
  60. for srv.baseCtx.Err() == nil {
  61. conn, err := srv.listener.Accept()
  62. if err != nil {
  63. log.Printf("conn accept failed: %v", err)
  64. return
  65. }
  66. go srv.serve(conn)
  67. }
  68. }
  69. func (srv *Server) serve(conn net.Conn) {
  70. defer conn.Close()
  71. conn.Write([]byte("Hello, World!"))
  72. }
  73. func (srv *Server) contextMaker() (context.Context, func()) {
  74. return context.WithTimeout(srv.baseCtx, srv.timeout)
  75. }
  76. func (srv *Server) Close() error {
  77. srv.cancelCtx()
  78. return srv.listener.Close()
  79. }