1
0
Vladislav Yarmak 1 жил өмнө
parent
commit
462f08c58b

+ 64 - 0
ciphers/ciphers.go

@@ -0,0 +1,64 @@
+package ciphers
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/pion/dtls/v2"
+)
+
+type CipherList = []dtls.CipherSuiteID
+
+var FullList = CipherList{
+	dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256,
+	dtls.TLS_PSK_WITH_AES_128_CCM,
+	dtls.TLS_PSK_WITH_AES_128_CCM_8,
+	dtls.TLS_PSK_WITH_AES_256_CCM_8,
+	dtls.TLS_PSK_WITH_AES_128_GCM_SHA256,
+	dtls.TLS_PSK_WITH_AES_128_CBC_SHA256,
+}
+
+var DefaultList = FullList
+var DefaultListString = ListToString(DefaultList)
+var NameToID map[string]dtls.CipherSuiteID
+
+func init() {
+	NameToID = make(map[string]dtls.CipherSuiteID)
+	for _, id := range FullList {
+		NameToID[dtls.CipherSuiteName(id)] = id
+	}
+}
+
+func IDToString(id dtls.CipherSuiteID) string {
+	return dtls.CipherSuiteName(id)
+}
+
+func ListToString(lst CipherList) string {
+	var b strings.Builder
+	var firstPrinted bool
+	for _, id := range lst {
+		if firstPrinted {
+			b.WriteByte(':')
+		} else {
+			firstPrinted = true
+		}
+		b.WriteString(dtls.CipherSuiteName(id))
+	}
+	return b.String()
+}
+
+func StringToList(str string) (CipherList, error) {
+	if str == "" {
+		return nil, nil
+	}
+	parts := strings.Split(str, ":")
+	var res CipherList
+	for _, name := range parts {
+		if id, ok := NameToID[name]; ok {
+			res = append(res, id)
+		} else {
+			return nil, fmt.Errorf("unknown ciphersuite: %q", name)
+		}
+	}
+	return res, nil
+}

+ 32 - 0
cmd/dtlspipe/main.go

@@ -12,6 +12,7 @@ import (
 	"syscall"
 	"time"
 
+	"github.com/Snawoot/dtlspipe/ciphers"
 	"github.com/Snawoot/dtlspipe/client"
 	"github.com/Snawoot/dtlspipe/keystore"
 	"github.com/Snawoot/dtlspipe/server"
@@ -23,6 +24,23 @@ const (
 	PSKEnvVarKey = "DTLSPIPE_PSK"
 )
 
+type cipherlistArg struct {
+	Value ciphers.CipherList
+}
+
+func (l *cipherlistArg) String() string {
+	return ciphers.ListToString(l.Value)
+}
+
+func (l *cipherlistArg) Set(s string) error {
+	parsed, err := ciphers.StringToList(s)
+	if err != nil {
+		return fmt.Errorf("can't parse cipher list: %w", err)
+	}
+	l.Value = parsed
+	return nil
+}
+
 var (
 	version = "undefined"
 
@@ -34,8 +52,13 @@ var (
 	mtu             = flag.Int("mtu", 1400, "MTU used for DTLS fragments")
 	cpuprofile      = flag.String("cpuprofile", "", "write cpu profile to file")
 	skipHelloVerify = flag.Bool("skip-hello-verify", false, "(server only) skip hello verify request. Useful to workaround DPI")
+	ciphersuites    = cipherlistArg{}
 )
 
+func init() {
+	flag.Var(&ciphersuites, "ciphers", "comma-separated list of ciphers to use")
+}
+
 func usage() {
 	out := flag.CommandLine.Output()
 	fmt.Fprintln(out, "Usage:")
@@ -136,6 +159,13 @@ func cmdServer(bindAddress, remoteAddress string) int {
 	return 0
 }
 
+func cmdCiphers() int {
+	for _, id := range ciphers.FullList {
+		fmt.Println(ciphers.IDToString(id))
+	}
+	return 0
+}
+
 func run() int {
 	flag.CommandLine.Usage = usage
 	flag.Parse()
@@ -155,6 +185,8 @@ func run() int {
 		switch args[0] {
 		case "genpsk":
 			return cmdGenPSK()
+		case "ciphers":
+			return cmdCiphers()
 		case "version":
 			return cmdVersion()
 		}