Kaynağa Gözat

Merge pull request #2 from Snawoot/ciphersuite_config

Ciphersuite config
Snawoot 1 yıl önce
ebeveyn
işleme
c918d220ec
7 değiştirilmiş dosya ile 116 ekleme ve 16 silme
  1. 3 0
      README.md
  2. 64 0
      ciphers/ciphers.go
  3. 1 8
      client/client.go
  4. 6 0
      client/config.go
  5. 35 0
      cmd/dtlspipe/main.go
  6. 6 0
      server/config.go
  7. 1 8
      server/server.go

+ 3 - 0
README.md

@@ -66,9 +66,12 @@ Usage:
 dtlspipe [OPTION]... server <BIND ADDRESS> <REMOTE ADDRESS>
 dtlspipe [OPTION]... client <BIND ADDRESS> <REMOTE ADDRESS>
 dtlspipe [OPTION]... genpsk
+dtlspipe ciphers
 dtlspipe version
 
 Options:
+  -ciphers value
+    	colon-separated list of ciphers to use
   -cpuprofile string
     	write cpu profile to file
   -identity string

+ 64 - 0
ciphers/ciphers.go

@@ -0,0 +1,64 @@
+package ciphers
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/pion/dtls/v2"
+)
+
+type CipherList = []dtls.CipherSuiteID
+
+var FullList = CipherList{
+	dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256,
+	dtls.TLS_PSK_WITH_AES_128_CCM,
+	dtls.TLS_PSK_WITH_AES_128_CCM_8,
+	dtls.TLS_PSK_WITH_AES_256_CCM_8,
+	dtls.TLS_PSK_WITH_AES_128_GCM_SHA256,
+	dtls.TLS_PSK_WITH_AES_128_CBC_SHA256,
+}
+
+var DefaultList = FullList
+var DefaultListString = ListToString(DefaultList)
+var NameToID map[string]dtls.CipherSuiteID
+
+func init() {
+	NameToID = make(map[string]dtls.CipherSuiteID)
+	for _, id := range FullList {
+		NameToID[dtls.CipherSuiteName(id)] = id
+	}
+}
+
+func IDToString(id dtls.CipherSuiteID) string {
+	return dtls.CipherSuiteName(id)
+}
+
+func ListToString(lst CipherList) string {
+	var b strings.Builder
+	var firstPrinted bool
+	for _, id := range lst {
+		if firstPrinted {
+			b.WriteByte(':')
+		} else {
+			firstPrinted = true
+		}
+		b.WriteString(dtls.CipherSuiteName(id))
+	}
+	return b.String()
+}
+
+func StringToList(str string) (CipherList, error) {
+	if str == "" {
+		return nil, nil
+	}
+	parts := strings.Split(str, ":")
+	var res CipherList
+	for _, name := range parts {
+		if id, ok := NameToID[name]; ok {
+			res = append(res, id)
+		} else {
+			return nil, fmt.Errorf("unknown ciphersuite: %q", name)
+		}
+	}
+	return res, nil
+}

+ 1 - 8
client/client.go

@@ -50,19 +50,12 @@ func New(cfg *Config) (*Client, error) {
 	}
 
 	client.dtlsConfig = &dtls.Config{
-		CipherSuites: []dtls.CipherSuiteID{
-			dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256,
-			dtls.TLS_PSK_WITH_AES_128_CCM,
-			dtls.TLS_PSK_WITH_AES_128_CCM_8,
-			dtls.TLS_PSK_WITH_AES_256_CCM_8,
-			dtls.TLS_PSK_WITH_AES_128_GCM_SHA256,
-			dtls.TLS_PSK_WITH_AES_128_CBC_SHA256,
-		},
 		ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
 		ConnectContextMaker:  client.contextMaker,
 		PSK:                  client.psk,
 		PSKIdentityHint:      []byte(cfg.PSKIdentity),
 		MTU:                  cfg.MTU,
+		CipherSuites:         cfg.CipherSuites,
 	}
 	lc := udp.ListenConfig{
 		Backlog: Backlog,

+ 6 - 0
client/config.go

@@ -3,6 +3,8 @@ package client
 import (
 	"context"
 	"time"
+
+	"github.com/Snawoot/dtlspipe/ciphers"
 )
 
 type Config struct {
@@ -14,6 +16,7 @@ type Config struct {
 	PSKCallback   func([]byte) ([]byte, error)
 	PSKIdentity   string
 	MTU           int
+	CipherSuites  ciphers.CipherList
 }
 
 func (cfg *Config) populateDefaults() *Config {
@@ -26,5 +29,8 @@ func (cfg *Config) populateDefaults() *Config {
 	if cfg.IdleTimeout == 0 {
 		cfg.IdleTimeout = 90 * time.Second
 	}
+	if cfg.CipherSuites == nil {
+		cfg.CipherSuites = ciphers.DefaultList
+	}
 	return cfg
 }

+ 35 - 0
cmd/dtlspipe/main.go

@@ -12,6 +12,7 @@ import (
 	"syscall"
 	"time"
 
+	"github.com/Snawoot/dtlspipe/ciphers"
 	"github.com/Snawoot/dtlspipe/client"
 	"github.com/Snawoot/dtlspipe/keystore"
 	"github.com/Snawoot/dtlspipe/server"
@@ -23,6 +24,23 @@ const (
 	PSKEnvVarKey = "DTLSPIPE_PSK"
 )
 
+type cipherlistArg struct {
+	Value ciphers.CipherList
+}
+
+func (l *cipherlistArg) String() string {
+	return ciphers.ListToString(l.Value)
+}
+
+func (l *cipherlistArg) Set(s string) error {
+	parsed, err := ciphers.StringToList(s)
+	if err != nil {
+		return fmt.Errorf("can't parse cipher list: %w", err)
+	}
+	l.Value = parsed
+	return nil
+}
+
 var (
 	version = "undefined"
 
@@ -34,8 +52,13 @@ var (
 	mtu             = flag.Int("mtu", 1400, "MTU used for DTLS fragments")
 	cpuprofile      = flag.String("cpuprofile", "", "write cpu profile to file")
 	skipHelloVerify = flag.Bool("skip-hello-verify", false, "(server only) skip hello verify request. Useful to workaround DPI")
+	ciphersuites    = cipherlistArg{}
 )
 
+func init() {
+	flag.Var(&ciphersuites, "ciphers", "colon-separated list of ciphers to use")
+}
+
 func usage() {
 	out := flag.CommandLine.Output()
 	fmt.Fprintln(out, "Usage:")
@@ -43,6 +66,7 @@ func usage() {
 	fmt.Fprintf(out, "%s [OPTION]... server <BIND ADDRESS> <REMOTE ADDRESS>\n", ProgName)
 	fmt.Fprintf(out, "%s [OPTION]... client <BIND ADDRESS> <REMOTE ADDRESS>\n", ProgName)
 	fmt.Fprintf(out, "%s [OPTION]... genpsk\n", ProgName)
+	fmt.Fprintf(out, "%s ciphers\n", ProgName)
 	fmt.Fprintf(out, "%s version\n", ProgName)
 	fmt.Fprintln(out)
 	fmt.Fprintln(out, "Options:")
@@ -90,6 +114,7 @@ func cmdClient(bindAddress, remoteAddress string) int {
 		IdleTimeout:   *idleTime,
 		BaseContext:   appCtx,
 		MTU:           *mtu,
+		CipherSuites:  ciphersuites.Value,
 	}
 
 	clt, err := client.New(&cfg)
@@ -124,6 +149,7 @@ func cmdServer(bindAddress, remoteAddress string) int {
 		BaseContext:     appCtx,
 		MTU:             *mtu,
 		SkipHelloVerify: *skipHelloVerify,
+		CipherSuites:    ciphersuites.Value,
 	}
 
 	srv, err := server.New(&cfg)
@@ -136,6 +162,13 @@ func cmdServer(bindAddress, remoteAddress string) int {
 	return 0
 }
 
+func cmdCiphers() int {
+	for _, id := range ciphers.FullList {
+		fmt.Println(ciphers.IDToString(id))
+	}
+	return 0
+}
+
 func run() int {
 	flag.CommandLine.Usage = usage
 	flag.Parse()
@@ -155,6 +188,8 @@ func run() int {
 		switch args[0] {
 		case "genpsk":
 			return cmdGenPSK()
+		case "ciphers":
+			return cmdCiphers()
 		case "version":
 			return cmdVersion()
 		}

+ 6 - 0
server/config.go

@@ -3,6 +3,8 @@ package server
 import (
 	"context"
 	"time"
+
+	"github.com/Snawoot/dtlspipe/ciphers"
 )
 
 type Config struct {
@@ -14,6 +16,7 @@ type Config struct {
 	PSKCallback     func([]byte) ([]byte, error)
 	MTU             int
 	SkipHelloVerify bool
+	CipherSuites    ciphers.CipherList
 }
 
 func (cfg *Config) populateDefaults() *Config {
@@ -26,5 +29,8 @@ func (cfg *Config) populateDefaults() *Config {
 	if cfg.IdleTimeout == 0 {
 		cfg.IdleTimeout = 90 * time.Second
 	}
+	if cfg.CipherSuites == nil {
+		cfg.CipherSuites = ciphers.DefaultList
+	}
 	return cfg
 }

+ 1 - 8
server/server.go

@@ -51,19 +51,12 @@ func New(cfg *Config) (*Server, error) {
 	}
 
 	srv.dtlsConfig = &dtls.Config{
-		CipherSuites: []dtls.CipherSuiteID{
-			dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256,
-			dtls.TLS_PSK_WITH_AES_128_CCM,
-			dtls.TLS_PSK_WITH_AES_128_CCM_8,
-			dtls.TLS_PSK_WITH_AES_256_CCM_8,
-			dtls.TLS_PSK_WITH_AES_128_GCM_SHA256,
-			dtls.TLS_PSK_WITH_AES_128_CBC_SHA256,
-		},
 		ExtendedMasterSecret:    dtls.RequireExtendedMasterSecret,
 		ConnectContextMaker:     srv.contextMaker,
 		PSK:                     srv.psk,
 		MTU:                     cfg.MTU,
 		InsecureSkipVerifyHello: cfg.SkipHelloVerify,
+		CipherSuites:            cfg.CipherSuites,
 	}
 	lc := udp.ListenConfig{
 		AcceptFilter: func(packet []byte) bool {