Vladislav Yarmak 1 anno fa
parent
commit
d2bd40c953
4 ha cambiato i file con 56 aggiunte e 33 eliminazioni
  1. 35 25
      cmd/dtlspipe/main.go
  2. 17 0
      keystore/static.go
  3. 1 1
      server/config.go
  4. 3 7
      server/server.go

+ 35 - 25
cmd/dtlspipe/main.go

@@ -11,6 +11,7 @@ import (
 	"syscall"
 	"time"
 
+	"github.com/Snawoot/dtlspipe/keystore"
 	"github.com/Snawoot/dtlspipe/server"
 	"github.com/Snawoot/dtlspipe/util"
 )
@@ -25,7 +26,7 @@ var (
 
 	timeout   = flag.Duration("timeout", 10*time.Second, "network operation timeout")
 	idleTime  = flag.Duration("idle-time", 90*time.Second, "max idle time for UDP session")
-	pskHexOpt = flag.String("psk", "", "hex-encoded pre-shared key. Can be generated with `genpsk` subcommand")
+	pskHexOpt = flag.String("psk", "", "hex-encoded pre-shared key. Can be generated with genpsk subcommand")
 	keyLength = flag.Uint("key-length", 16, "generate key with specified length")
 )
 
@@ -62,7 +63,12 @@ func cmdVersion() int {
 	return 0
 }
 
-func cmdClient(bindAddress, remoteAddress string, psk []byte) int {
+func cmdClient(bindAddress, remoteAddress string) int {
+	_, err := simpleGetPSK()
+	if err != nil {
+		log.Printf("can't get PSK: %v", err)
+		return 2
+	}
 	log.Printf("starting dtlspipe client: %s =[wrap into DTLS]=> %s", bindAddress, remoteAddress)
 	defer log.Println("dtlspipe client stopped")
 
@@ -74,7 +80,12 @@ func cmdClient(bindAddress, remoteAddress string, psk []byte) int {
 	return 0
 }
 
-func cmdServer(bindAddress, remoteAddress string, psk []byte) int {
+func cmdServer(bindAddress, remoteAddress string) int {
+	psk, err := simpleGetPSK()
+	if err != nil {
+		log.Printf("can't get PSK: %v", err)
+		return 2
+	}
 	log.Printf("starting dtlspipe server: %s =[unwrap from DTLS]=> %s", bindAddress, remoteAddress)
 	defer log.Println("dtlspipe server stopped")
 
@@ -84,7 +95,7 @@ func cmdServer(bindAddress, remoteAddress string, psk []byte) int {
 	cfg := server.Config{
 		BindAddress:   bindAddress,
 		RemoteAddress: remoteAddress,
-		PSK:           psk,
+		PSKCallback:   keystore.NewStaticKeystore(psk).PSKCallback,
 		Timeout:       *timeout,
 		IdleTimeout:   *idleTime,
 		BaseContext:   appCtx,
@@ -114,30 +125,11 @@ func run() int {
 			return cmdVersion()
 		}
 	case 3:
-		pskHex := os.Getenv(PSKEnvVarKey)
-		if pskHex == "" {
-			os.Unsetenv(PSKEnvVarKey)
-		}
-		if *pskHexOpt != "" {
-			pskHex = *pskHexOpt
-		}
-		if pskHex == "" {
-			fmt.Fprintln(os.Stderr)
-			fmt.Fprintf(os.Stderr, "Error: no PSK option provided and neither %s environment variable is set\n", PSKEnvVarKey)
-			fmt.Fprintln(os.Stderr)
-			return 2
-		}
-
-		psk, err := util.PSKFromHex(pskHex)
-		if err != nil {
-			fmt.Fprintf(os.Stderr, "Error: can't hex-decode PSK: %v\n", err)
-			return 2
-		}
 		switch args[0] {
 		case "server":
-			return cmdServer(args[1], args[2], psk)
+			return cmdServer(args[1], args[2])
 		case "client":
-			return cmdClient(args[1], args[2], psk)
+			return cmdClient(args[1], args[2])
 		}
 	}
 	usage()
@@ -149,3 +141,21 @@ func main() {
 	log.Default().SetPrefix(strings.ToUpper(ProgName) + ": ")
 	os.Exit(run())
 }
+
+func simpleGetPSK() ([]byte, error) {
+	pskHex := os.Getenv(PSKEnvVarKey)
+	if pskHex == "" {
+		os.Unsetenv(PSKEnvVarKey)
+	}
+	if *pskHexOpt != "" {
+		pskHex = *pskHexOpt
+	}
+	if pskHex == "" {
+		return nil, fmt.Errorf("no PSK command line option provided and neither %s environment variable is set", PSKEnvVarKey)
+	}
+	psk, err := util.PSKFromHex(pskHex)
+	if err != nil {
+		return nil, fmt.Errorf("can't hex-decode PSK: %w", err)
+	}
+	return psk, nil
+}

+ 17 - 0
keystore/static.go

@@ -0,0 +1,17 @@
+package keystore
+
+import "bytes"
+
+type StaticKeystore struct {
+	psk []byte
+}
+
+func NewStaticKeystore(psk []byte) *StaticKeystore {
+	return &StaticKeystore{
+		psk: bytes.Clone(psk),
+	}
+}
+
+func (store *StaticKeystore) PSKCallback(hint []byte) ([]byte, error) {
+	return bytes.Clone(store.psk), nil
+}

+ 1 - 1
server/config.go

@@ -8,10 +8,10 @@ import (
 type Config struct {
 	BindAddress   string
 	RemoteAddress string
-	PSK           []byte
 	Timeout       time.Duration
 	IdleTimeout   time.Duration
 	BaseContext   context.Context
+	PSKCallback   func([]byte) ([]byte, error)
 }
 
 func (cfg *Config) populateDefaults() *Config {

+ 3 - 7
server/server.go

@@ -24,7 +24,7 @@ type Server struct {
 	listener    net.Listener
 	dtlsConfig  *dtls.Config
 	rAddr       string
-	psk         []byte
+	psk         func([]byte) ([]byte, error)
 	timeout     time.Duration
 	idleTimeout time.Duration
 	baseCtx     context.Context
@@ -39,7 +39,7 @@ func New(cfg *Config) (*Server, error) {
 	srv := &Server{
 		rAddr:       cfg.RemoteAddress,
 		timeout:     cfg.Timeout,
-		psk:         cfg.PSK,
+		psk:         cfg.PSKCallback,
 		idleTimeout: cfg.IdleTimeout,
 		baseCtx:     baseCtx,
 		cancelCtx:   cancelCtx,
@@ -62,7 +62,7 @@ func New(cfg *Config) (*Server, error) {
 		},
 		ExtendedMasterSecret: dtls.RequireExtendedMasterSecret,
 		ConnectContextMaker:  srv.contextMaker,
-		PSK:                  srv.getPSK,
+		PSK:                  srv.psk,
 	}
 	lc := udp.ListenConfig{
 		AcceptFilter: func(packet []byte) bool {
@@ -176,10 +176,6 @@ func (srv *Server) contextMaker() (context.Context, func()) {
 	return context.WithTimeout(srv.baseCtx, srv.timeout)
 }
 
-func (srv *Server) getPSK(hint []byte) ([]byte, error) {
-	return srv.psk, nil
-}
-
 func (srv *Server) Close() error {
 	srv.cancelCtx()
 	return srv.listener.Close()