123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268 |
- package main
- import (
- "context"
- "flag"
- "fmt"
- "log"
- "os"
- "os/signal"
- "runtime/pprof"
- "strings"
- "syscall"
- "time"
- "github.com/Snawoot/dtlspipe/ciphers"
- "github.com/Snawoot/dtlspipe/client"
- "github.com/Snawoot/dtlspipe/keystore"
- "github.com/Snawoot/dtlspipe/server"
- "github.com/Snawoot/dtlspipe/util"
- )
const (
	// ProgName is the program name shown in usage text; its upper-cased
	// form is used as the log prefix.
	ProgName = "dtlspipe"

	// PSKEnvVarKey names the environment variable holding the hex-encoded
	// pre-shared key. It is consulted when the -psk option is empty.
	PSKEnvVarKey = "DTLSPIPE_PSK"
)
- type cipherlistArg struct {
- Value ciphers.CipherList
- }
- func (l *cipherlistArg) String() string {
- return ciphers.CipherListToString(l.Value)
- }
- func (l *cipherlistArg) Set(s string) error {
- parsed, err := ciphers.StringToCipherList(s)
- if err != nil {
- return fmt.Errorf("can't parse cipher list: %w", err)
- }
- l.Value = parsed
- return nil
- }
- type curvelistArg struct {
- Value ciphers.CurveList
- }
- func (l *curvelistArg) String() string {
- return ciphers.CurveListToString(l.Value)
- }
- func (l *curvelistArg) Set(s string) error {
- parsed, err := ciphers.StringToCurveList(s)
- if err != nil {
- return fmt.Errorf("can't parse curve list: %w", err)
- }
- l.Value = parsed
- return nil
- }
- var (
- version = "undefined"
- timeout = flag.Duration("timeout", 10*time.Second, "network operation timeout")
- idleTime = flag.Duration("idle-time", 30*time.Second, "max idle time for UDP session")
- pskHexOpt = flag.String("psk", "", "hex-encoded pre-shared key. Can be generated with genpsk subcommand")
- keyLength = flag.Uint("key-length", 16, "generate key with specified length")
- identity = flag.String("identity", "", "client identity sent to server")
- mtu = flag.Int("mtu", 1400, "MTU used for DTLS fragments")
- cpuprofile = flag.String("cpuprofile", "", "write cpu profile to file")
- skipHelloVerify = flag.Bool("skip-hello-verify", false, "(server only) skip hello verify request. Useful to workaround DPI")
- ciphersuites = cipherlistArg{}
- curves = curvelistArg{}
- staleMode = util.EitherStale
- timeLimit = flag.Duration("time-limit", 0, "hard time limit for each session")
- )
- func init() {
- flag.Var(&ciphersuites, "ciphers", "colon-separated list of ciphers to use")
- flag.Var(&curves, "curves", "colon-separated list of curves to use")
- flag.Var(&staleMode, "stale-mode", "which stale side of connection makes whole session stale (both, either, left, right)")
- }
- func usage() {
- out := flag.CommandLine.Output()
- fmt.Fprintln(out, "Usage:")
- fmt.Fprintln(out)
- fmt.Fprintf(out, "%s [OPTION]... server <BIND ADDRESS> <REMOTE ADDRESS>\n", ProgName)
- fmt.Fprintf(out, "%s [OPTION]... client <BIND ADDRESS> <REMOTE ADDRESS>\n", ProgName)
- fmt.Fprintf(out, "%s [OPTION]... genpsk\n", ProgName)
- fmt.Fprintf(out, "%s ciphers\n", ProgName)
- fmt.Fprintf(out, "%s curves\n", ProgName)
- fmt.Fprintf(out, "%s version\n", ProgName)
- fmt.Fprintln(out)
- fmt.Fprintln(out, "Options:")
- flag.PrintDefaults()
- }
- func cmdGenPSK() int {
- if *keyLength > 64 {
- fmt.Fprintln(os.Stderr, "key length is too big")
- return 1
- }
- psk, err := util.GenPSKHex(int(*keyLength))
- if err != nil {
- fmt.Fprintf(os.Stderr, "key generation error: %v\n", err)
- return 1
- }
- fmt.Println(psk)
- return 0
- }
- func cmdVersion() int {
- fmt.Println(version)
- return 0
- }
- func cmdClient(bindAddress, remoteAddress string) int {
- psk, err := simpleGetPSK()
- if err != nil {
- log.Printf("can't get PSK: %v", err)
- return 2
- }
- log.Printf("starting dtlspipe client: %s =[wrap into DTLS]=> %s", bindAddress, remoteAddress)
- defer log.Println("dtlspipe client stopped")
- appCtx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
- defer cancel()
- cfg := client.Config{
- BindAddress: bindAddress,
- RemoteAddress: remoteAddress,
- PSKCallback: keystore.NewStaticKeystore(psk).PSKCallback,
- PSKIdentity: *identity,
- Timeout: *timeout,
- IdleTimeout: *idleTime,
- BaseContext: appCtx,
- MTU: *mtu,
- CipherSuites: ciphersuites.Value,
- EllipticCurves: curves.Value,
- StaleMode: staleMode,
- TimeLimit: *timeLimit,
- }
- clt, err := client.New(&cfg)
- if err != nil {
- log.Fatalf("client startup failed: %v", err)
- }
- defer clt.Close()
- <-appCtx.Done()
- return 0
- }
- func cmdServer(bindAddress, remoteAddress string) int {
- psk, err := simpleGetPSK()
- if err != nil {
- log.Printf("can't get PSK: %v", err)
- return 2
- }
- log.Printf("starting dtlspipe server: %s =[unwrap from DTLS]=> %s", bindAddress, remoteAddress)
- defer log.Println("dtlspipe server stopped")
- appCtx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
- defer cancel()
- cfg := server.Config{
- BindAddress: bindAddress,
- RemoteAddress: remoteAddress,
- PSKCallback: keystore.NewStaticKeystore(psk).PSKCallback,
- Timeout: *timeout,
- IdleTimeout: *idleTime,
- BaseContext: appCtx,
- MTU: *mtu,
- SkipHelloVerify: *skipHelloVerify,
- CipherSuites: ciphersuites.Value,
- EllipticCurves: curves.Value,
- StaleMode: staleMode,
- TimeLimit: *timeLimit,
- }
- srv, err := server.New(&cfg)
- if err != nil {
- log.Fatalf("server startup failed: %v", err)
- }
- defer srv.Close()
- <-appCtx.Done()
- return 0
- }
- func cmdCiphers() int {
- for _, id := range ciphers.FullCipherList {
- fmt.Println(ciphers.CipherIDToString(id))
- }
- return 0
- }
- func cmdCurves() int {
- for _, curve := range ciphers.FullCurveList {
- fmt.Println(ciphers.CurveIDToString(curve))
- }
- return 0
- }
- func run() int {
- flag.CommandLine.Usage = usage
- flag.Parse()
- args := flag.Args()
- if *cpuprofile != "" {
- f, err := os.Create(*cpuprofile)
- if err != nil {
- log.Fatal(err)
- }
- pprof.StartCPUProfile(f)
- defer pprof.StopCPUProfile()
- }
- switch len(args) {
- case 1:
- switch args[0] {
- case "genpsk":
- return cmdGenPSK()
- case "ciphers":
- return cmdCiphers()
- case "curves":
- return cmdCurves()
- case "version":
- return cmdVersion()
- }
- case 3:
- switch args[0] {
- case "server":
- return cmdServer(args[1], args[2])
- case "client":
- return cmdClient(args[1], args[2])
- }
- }
- usage()
- return 2
- }
- func main() {
- log.Default().SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile)
- log.Default().SetPrefix(strings.ToUpper(ProgName) + ": ")
- os.Exit(run())
- }
- func simpleGetPSK() ([]byte, error) {
- pskHex := os.Getenv(PSKEnvVarKey)
- if pskHex == "" {
- os.Unsetenv(PSKEnvVarKey)
- }
- if *pskHexOpt != "" {
- pskHex = *pskHexOpt
- }
- if pskHex == "" {
- return nil, fmt.Errorf("no PSK command line option provided and neither %s environment variable is set", PSKEnvVarKey)
- }
- psk, err := util.PSKFromHex(pskHex)
- if err != nil {
- return nil, fmt.Errorf("can't hex-decode PSK: %w", err)
- }
- return psk, nil
- }
|