main.go 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. package main
  2. import (
  3. "context"
  4. "flag"
  5. "fmt"
  6. "log"
  7. "os"
  8. "os/signal"
  9. "runtime/pprof"
  10. "strings"
  11. "syscall"
  12. "time"
  13. "github.com/Snawoot/dtlspipe/ciphers"
  14. "github.com/Snawoot/dtlspipe/client"
  15. "github.com/Snawoot/dtlspipe/keystore"
  16. "github.com/Snawoot/dtlspipe/server"
  17. "github.com/Snawoot/dtlspipe/util"
  18. )
  19. const (
  20. ProgName = "dtlspipe"
  21. PSKEnvVarKey = "DTLSPIPE_PSK"
  22. )
  23. type cipherlistArg struct {
  24. Value ciphers.CipherList
  25. }
  26. func (l *cipherlistArg) String() string {
  27. return ciphers.CipherListToString(l.Value)
  28. }
  29. func (l *cipherlistArg) Set(s string) error {
  30. parsed, err := ciphers.StringToCipherList(s)
  31. if err != nil {
  32. return fmt.Errorf("can't parse cipher list: %w", err)
  33. }
  34. l.Value = parsed
  35. return nil
  36. }
  37. type curvelistArg struct {
  38. Value ciphers.CurveList
  39. }
  40. func (l *curvelistArg) String() string {
  41. return ciphers.CurveListToString(l.Value)
  42. }
  43. func (l *curvelistArg) Set(s string) error {
  44. parsed, err := ciphers.StringToCurveList(s)
  45. if err != nil {
  46. return fmt.Errorf("can't parse curve list: %w", err)
  47. }
  48. l.Value = parsed
  49. return nil
  50. }
  51. var (
  52. version = "undefined"
  53. timeout = flag.Duration("timeout", 10*time.Second, "network operation timeout")
  54. idleTime = flag.Duration("idle-time", 30*time.Second, "max idle time for UDP session")
  55. pskHexOpt = flag.String("psk", "", "hex-encoded pre-shared key. Can be generated with genpsk subcommand")
  56. keyLength = flag.Uint("key-length", 16, "generate key with specified length")
  57. identity = flag.String("identity", "", "client identity sent to server")
  58. mtu = flag.Int("mtu", 1400, "MTU used for DTLS fragments")
  59. cpuprofile = flag.String("cpuprofile", "", "write cpu profile to file")
  60. skipHelloVerify = flag.Bool("skip-hello-verify", false, "(server only) skip hello verify request. Useful to workaround DPI")
  61. ciphersuites = cipherlistArg{}
  62. curves = curvelistArg{}
  63. staleMode = util.EitherStale
  64. timeLimit = flag.Duration("time-limit", 0, "hard time limit for each session")
  65. )
  66. func init() {
  67. flag.Var(&ciphersuites, "ciphers", "colon-separated list of ciphers to use")
  68. flag.Var(&curves, "curves", "colon-separated list of curves to use")
  69. flag.Var(&staleMode, "stale-mode", "which stale side of connection makes whole session stale (both, either, left, right)")
  70. }
  71. func usage() {
  72. out := flag.CommandLine.Output()
  73. fmt.Fprintln(out, "Usage:")
  74. fmt.Fprintln(out)
  75. fmt.Fprintf(out, "%s [OPTION]... server <BIND ADDRESS> <REMOTE ADDRESS>\n", ProgName)
  76. fmt.Fprintf(out, "%s [OPTION]... client <BIND ADDRESS> <REMOTE ADDRESS>\n", ProgName)
  77. fmt.Fprintf(out, "%s [OPTION]... genpsk\n", ProgName)
  78. fmt.Fprintf(out, "%s ciphers\n", ProgName)
  79. fmt.Fprintf(out, "%s curves\n", ProgName)
  80. fmt.Fprintf(out, "%s version\n", ProgName)
  81. fmt.Fprintln(out)
  82. fmt.Fprintln(out, "Options:")
  83. flag.PrintDefaults()
  84. }
  85. func cmdGenPSK() int {
  86. if *keyLength > 64 {
  87. fmt.Fprintln(os.Stderr, "key length is too big")
  88. return 1
  89. }
  90. psk, err := util.GenPSKHex(int(*keyLength))
  91. if err != nil {
  92. fmt.Fprintf(os.Stderr, "key generation error: %v\n", err)
  93. return 1
  94. }
  95. fmt.Println(psk)
  96. return 0
  97. }
  98. func cmdVersion() int {
  99. fmt.Println(version)
  100. return 0
  101. }
  102. func cmdClient(bindAddress, remoteAddress string) int {
  103. psk, err := simpleGetPSK()
  104. if err != nil {
  105. log.Printf("can't get PSK: %v", err)
  106. return 2
  107. }
  108. log.Printf("starting dtlspipe client: %s =[wrap into DTLS]=> %s", bindAddress, remoteAddress)
  109. defer log.Println("dtlspipe client stopped")
  110. appCtx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
  111. defer cancel()
  112. cfg := client.Config{
  113. BindAddress: bindAddress,
  114. RemoteAddress: remoteAddress,
  115. PSKCallback: keystore.NewStaticKeystore(psk).PSKCallback,
  116. PSKIdentity: *identity,
  117. Timeout: *timeout,
  118. IdleTimeout: *idleTime,
  119. BaseContext: appCtx,
  120. MTU: *mtu,
  121. CipherSuites: ciphersuites.Value,
  122. EllipticCurves: curves.Value,
  123. StaleMode: staleMode,
  124. TimeLimit: *timeLimit,
  125. }
  126. clt, err := client.New(&cfg)
  127. if err != nil {
  128. log.Fatalf("client startup failed: %v", err)
  129. }
  130. defer clt.Close()
  131. <-appCtx.Done()
  132. return 0
  133. }
  134. func cmdServer(bindAddress, remoteAddress string) int {
  135. psk, err := simpleGetPSK()
  136. if err != nil {
  137. log.Printf("can't get PSK: %v", err)
  138. return 2
  139. }
  140. log.Printf("starting dtlspipe server: %s =[unwrap from DTLS]=> %s", bindAddress, remoteAddress)
  141. defer log.Println("dtlspipe server stopped")
  142. appCtx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
  143. defer cancel()
  144. cfg := server.Config{
  145. BindAddress: bindAddress,
  146. RemoteAddress: remoteAddress,
  147. PSKCallback: keystore.NewStaticKeystore(psk).PSKCallback,
  148. Timeout: *timeout,
  149. IdleTimeout: *idleTime,
  150. BaseContext: appCtx,
  151. MTU: *mtu,
  152. SkipHelloVerify: *skipHelloVerify,
  153. CipherSuites: ciphersuites.Value,
  154. EllipticCurves: curves.Value,
  155. StaleMode: staleMode,
  156. TimeLimit: *timeLimit,
  157. }
  158. srv, err := server.New(&cfg)
  159. if err != nil {
  160. log.Fatalf("server startup failed: %v", err)
  161. }
  162. defer srv.Close()
  163. <-appCtx.Done()
  164. return 0
  165. }
  166. func cmdCiphers() int {
  167. for _, id := range ciphers.FullCipherList {
  168. fmt.Println(ciphers.CipherIDToString(id))
  169. }
  170. return 0
  171. }
  172. func cmdCurves() int {
  173. for _, curve := range ciphers.FullCurveList {
  174. fmt.Println(ciphers.CurveIDToString(curve))
  175. }
  176. return 0
  177. }
  178. func run() int {
  179. flag.CommandLine.Usage = usage
  180. flag.Parse()
  181. args := flag.Args()
  182. if *cpuprofile != "" {
  183. f, err := os.Create(*cpuprofile)
  184. if err != nil {
  185. log.Fatal(err)
  186. }
  187. pprof.StartCPUProfile(f)
  188. defer pprof.StopCPUProfile()
  189. }
  190. switch len(args) {
  191. case 1:
  192. switch args[0] {
  193. case "genpsk":
  194. return cmdGenPSK()
  195. case "ciphers":
  196. return cmdCiphers()
  197. case "curves":
  198. return cmdCurves()
  199. case "version":
  200. return cmdVersion()
  201. }
  202. case 3:
  203. switch args[0] {
  204. case "server":
  205. return cmdServer(args[1], args[2])
  206. case "client":
  207. return cmdClient(args[1], args[2])
  208. }
  209. }
  210. usage()
  211. return 2
  212. }
  213. func main() {
  214. log.Default().SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile)
  215. log.Default().SetPrefix(strings.ToUpper(ProgName) + ": ")
  216. os.Exit(run())
  217. }
  218. func simpleGetPSK() ([]byte, error) {
  219. pskHex := os.Getenv(PSKEnvVarKey)
  220. if pskHex == "" {
  221. os.Unsetenv(PSKEnvVarKey)
  222. }
  223. if *pskHexOpt != "" {
  224. pskHex = *pskHexOpt
  225. }
  226. if pskHex == "" {
  227. return nil, fmt.Errorf("no PSK command line option provided and neither %s environment variable is set", PSKEnvVarKey)
  228. }
  229. psk, err := util.PSKFromHex(pskHex)
  230. if err != nil {
  231. return nil, fmt.Errorf("can't hex-decode PSK: %w", err)
  232. }
  233. return psk, nil
  234. }