Explorar el Código

configurable elliptic curves

Vladislav Yarmak hace 1 año
padre
commit
9cc5660d13
Se han modificado 7 ficheros con 121 adiciones y 19 borrados
  1. 1 1
      ciphers/ciphers.go
  2. 61 0
      ciphers/curves.go
  3. 1 0
      client/client.go
  4. 13 9
      client/config.go
  5. 40 9
      cmd/dtlspipe/main.go
  6. 4 0
      server/config.go
  7. 1 0
      server/server.go

+ 1 - 1
ciphers/ciphers.go

@@ -19,7 +19,7 @@ var FullCipherList = CipherList{
 }
 
 var DefaultCipherList = FullCipherList
-var DefaultListString = CipherListToString(DefaultCipherList)
+var DefaultCipherListString = CipherListToString(DefaultCipherList)
 var CipherNameToID map[string]dtls.CipherSuiteID
 
 func init() {

+ 61 - 0
ciphers/curves.go

@@ -0,0 +1,61 @@
+package ciphers
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/pion/dtls/v2/pkg/crypto/elliptic"
+)
+
+type CurveList = []elliptic.Curve
+
+var FullCurveList = CurveList{
+	elliptic.X25519,
+	elliptic.P256,
+	elliptic.P384,
+}
+
+var DefaultCurveList = FullCurveList
+var DefaultCurveListString = CurveListToString(DefaultCurveList)
+var CurveNameToID map[string]elliptic.Curve
+
+func init() {
+	CurveNameToID = make(map[string]elliptic.Curve)
+	for _, curve := range FullCurveList {
+		CurveNameToID[curve.String()] = curve
+	}
+}
+
+func CurveIDToString(curve elliptic.Curve) string {
+	return curve.String()
+}
+
+func CurveListToString(lst CurveList) string {
+	var b strings.Builder
+	var firstPrinted bool
+	for _, curve := range lst {
+		if firstPrinted {
+			b.WriteByte(':')
+		} else {
+			firstPrinted = true
+		}
+		b.WriteString(curve.String())
+	}
+	return b.String()
+}
+
+func StringToCurveList(str string) (CurveList, error) {
+	if str == "" {
+		return CurveList{}, nil
+	}
+	parts := strings.Split(str, ":")
+	var res CurveList
+	for _, name := range parts {
+		if id, ok := CurveNameToID[name]; ok {
+			res = append(res, id)
+		} else {
+			return nil, fmt.Errorf("unknown curve: %q", name)
+		}
+	}
+	return res, nil
+}

+ 1 - 0
client/client.go

@@ -56,6 +56,7 @@ func New(cfg *Config) (*Client, error) {
 		PSKIdentityHint:      []byte(cfg.PSKIdentity),
 		MTU:                  cfg.MTU,
 		CipherSuites:         cfg.CipherSuites,
+		EllipticCurves:       cfg.EllipticCurves,
 	}
 	lc := udp.ListenConfig{
 		Backlog: Backlog,

+ 13 - 9
client/config.go

@@ -8,15 +8,16 @@ import (
 )
 
 type Config struct {
-	BindAddress   string
-	RemoteAddress string
-	Timeout       time.Duration
-	IdleTimeout   time.Duration
-	BaseContext   context.Context
-	PSKCallback   func([]byte) ([]byte, error)
-	PSKIdentity   string
-	MTU           int
-	CipherSuites  ciphers.CipherList
+	BindAddress    string
+	RemoteAddress  string
+	Timeout        time.Duration
+	IdleTimeout    time.Duration
+	BaseContext    context.Context
+	PSKCallback    func([]byte) ([]byte, error)
+	PSKIdentity    string
+	MTU            int
+	CipherSuites   ciphers.CipherList
+	EllipticCurves ciphers.CurveList
 }
 
 func (cfg *Config) populateDefaults() *Config {
@@ -32,5 +33,8 @@ func (cfg *Config) populateDefaults() *Config {
 	if cfg.CipherSuites == nil {
 		cfg.CipherSuites = ciphers.DefaultCipherList
 	}
+	if cfg.EllipticCurves == nil {
+		cfg.EllipticCurves = ciphers.DefaultCurveList
+	}
 	return cfg
 }

+ 40 - 9
cmd/dtlspipe/main.go

@@ -41,6 +41,23 @@ func (l *cipherlistArg) Set(s string) error {
 	return nil
 }
 
+type curvelistArg struct {
+	Value ciphers.CurveList
+}
+
+func (l *curvelistArg) String() string {
+	return ciphers.CurveListToString(l.Value)
+}
+
+func (l *curvelistArg) Set(s string) error {
+	parsed, err := ciphers.StringToCurveList(s)
+	if err != nil {
+		return fmt.Errorf("can't parse curve list: %w", err)
+	}
+	l.Value = parsed
+	return nil
+}
+
 var (
 	version = "undefined"
 
@@ -53,10 +70,12 @@ var (
 	cpuprofile      = flag.String("cpuprofile", "", "write cpu profile to file")
 	skipHelloVerify = flag.Bool("skip-hello-verify", false, "(server only) skip hello verify request. Useful to workaround DPI")
 	ciphersuites    = cipherlistArg{}
+	curves          = curvelistArg{}
 )
 
 func init() {
 	flag.Var(&ciphersuites, "ciphers", "colon-separated list of ciphers to use")
+	flag.Var(&curves, "curves", "colon-separated list of curves to use")
 }
 
 func usage() {
@@ -67,6 +86,7 @@ func usage() {
 	fmt.Fprintf(out, "%s [OPTION]... client <BIND ADDRESS> <REMOTE ADDRESS>\n", ProgName)
 	fmt.Fprintf(out, "%s [OPTION]... genpsk\n", ProgName)
 	fmt.Fprintf(out, "%s ciphers\n", ProgName)
+	fmt.Fprintf(out, "%s curves\n", ProgName)
 	fmt.Fprintf(out, "%s version\n", ProgName)
 	fmt.Fprintln(out)
 	fmt.Fprintln(out, "Options:")
@@ -106,15 +126,16 @@ func cmdClient(bindAddress, remoteAddress string) int {
 	defer cancel()
 
 	cfg := client.Config{
-		BindAddress:   bindAddress,
-		RemoteAddress: remoteAddress,
-		PSKCallback:   keystore.NewStaticKeystore(psk).PSKCallback,
-		PSKIdentity:   *identity,
-		Timeout:       *timeout,
-		IdleTimeout:   *idleTime,
-		BaseContext:   appCtx,
-		MTU:           *mtu,
-		CipherSuites:  ciphersuites.Value,
+		BindAddress:    bindAddress,
+		RemoteAddress:  remoteAddress,
+		PSKCallback:    keystore.NewStaticKeystore(psk).PSKCallback,
+		PSKIdentity:    *identity,
+		Timeout:        *timeout,
+		IdleTimeout:    *idleTime,
+		BaseContext:    appCtx,
+		MTU:            *mtu,
+		CipherSuites:   ciphersuites.Value,
+		EllipticCurves: curves.Value,
 	}
 
 	clt, err := client.New(&cfg)
@@ -150,6 +171,7 @@ func cmdServer(bindAddress, remoteAddress string) int {
 		MTU:             *mtu,
 		SkipHelloVerify: *skipHelloVerify,
 		CipherSuites:    ciphersuites.Value,
+		EllipticCurves:  curves.Value,
 	}
 
 	srv, err := server.New(&cfg)
@@ -169,6 +191,13 @@ func cmdCiphers() int {
 	return 0
 }
 
+func cmdCurves() int {
+	for _, curve := range ciphers.FullCurveList {
+		fmt.Println(ciphers.CurveIDToString(curve))
+	}
+	return 0
+}
+
 func run() int {
 	flag.CommandLine.Usage = usage
 	flag.Parse()
@@ -190,6 +219,8 @@ func run() int {
 			return cmdGenPSK()
 		case "ciphers":
 			return cmdCiphers()
+		case "curves":
+			return cmdCurves()
 		case "version":
 			return cmdVersion()
 		}

+ 4 - 0
server/config.go

@@ -17,6 +17,7 @@ type Config struct {
 	MTU             int
 	SkipHelloVerify bool
 	CipherSuites    ciphers.CipherList
+	EllipticCurves  ciphers.CurveList
 }
 
 func (cfg *Config) populateDefaults() *Config {
@@ -32,5 +33,8 @@ func (cfg *Config) populateDefaults() *Config {
 	if cfg.CipherSuites == nil {
 		cfg.CipherSuites = ciphers.DefaultCipherList
 	}
+	if cfg.EllipticCurves == nil {
+		cfg.EllipticCurves = ciphers.DefaultCurveList
+	}
 	return cfg
 }

+ 1 - 0
server/server.go

@@ -57,6 +57,7 @@ func New(cfg *Config) (*Server, error) {
 		MTU:                     cfg.MTU,
 		InsecureSkipVerifyHello: cfg.SkipHelloVerify,
 		CipherSuites:            cfg.CipherSuites,
+		EllipticCurves:          cfg.EllipticCurves,
 	}
 	lc := udp.ListenConfig{
 		AcceptFilter: func(packet []byte) bool {