|
@@ -11,6 +11,7 @@ import (
|
|
|
"syscall"
|
|
|
"time"
|
|
|
|
|
|
+ "github.com/Snawoot/dtlspipe/keystore"
|
|
|
"github.com/Snawoot/dtlspipe/server"
|
|
|
"github.com/Snawoot/dtlspipe/util"
|
|
|
)
|
|
@@ -25,7 +26,7 @@ var (
|
|
|
|
|
|
timeout = flag.Duration("timeout", 10*time.Second, "network operation timeout")
|
|
|
idleTime = flag.Duration("idle-time", 90*time.Second, "max idle time for UDP session")
|
|
|
- pskHexOpt = flag.String("psk", "", "hex-encoded pre-shared key. Can be generated with `genpsk` subcommand")
|
|
|
+ pskHexOpt = flag.String("psk", "", "hex-encoded pre-shared key. Can be generated with genpsk subcommand")
|
|
|
keyLength = flag.Uint("key-length", 16, "generate key with specified length")
|
|
|
)
|
|
|
|
|
@@ -62,7 +63,12 @@ func cmdVersion() int {
|
|
|
return 0
|
|
|
}
|
|
|
|
|
|
-func cmdClient(bindAddress, remoteAddress string, psk []byte) int {
|
|
|
+func cmdClient(bindAddress, remoteAddress string) int {
|
|
|
+ _, err := simpleGetPSK()
|
|
|
+ if err != nil {
|
|
|
+ log.Printf("can't get PSK: %v", err)
|
|
|
+ return 2
|
|
|
+ }
|
|
|
log.Printf("starting dtlspipe client: %s =[wrap into DTLS]=> %s", bindAddress, remoteAddress)
|
|
|
defer log.Println("dtlspipe client stopped")
|
|
|
|
|
@@ -74,7 +80,12 @@ func cmdClient(bindAddress, remoteAddress string, psk []byte) int {
|
|
|
return 0
|
|
|
}
|
|
|
|
|
|
-func cmdServer(bindAddress, remoteAddress string, psk []byte) int {
|
|
|
+func cmdServer(bindAddress, remoteAddress string) int {
|
|
|
+ psk, err := simpleGetPSK()
|
|
|
+ if err != nil {
|
|
|
+ log.Printf("can't get PSK: %v", err)
|
|
|
+ return 2
|
|
|
+ }
|
|
|
log.Printf("starting dtlspipe server: %s =[unwrap from DTLS]=> %s", bindAddress, remoteAddress)
|
|
|
defer log.Println("dtlspipe server stopped")
|
|
|
|
|
@@ -84,7 +95,7 @@ func cmdServer(bindAddress, remoteAddress string, psk []byte) int {
|
|
|
cfg := server.Config{
|
|
|
BindAddress: bindAddress,
|
|
|
RemoteAddress: remoteAddress,
|
|
|
- PSK: psk,
|
|
|
+ PSKCallback: keystore.NewStaticKeystore(psk).PSKCallback,
|
|
|
Timeout: *timeout,
|
|
|
IdleTimeout: *idleTime,
|
|
|
BaseContext: appCtx,
|
|
@@ -114,30 +125,11 @@ func run() int {
|
|
|
return cmdVersion()
|
|
|
}
|
|
|
case 3:
|
|
|
- pskHex := os.Getenv(PSKEnvVarKey)
|
|
|
- if pskHex == "" {
|
|
|
- os.Unsetenv(PSKEnvVarKey)
|
|
|
- }
|
|
|
- if *pskHexOpt != "" {
|
|
|
- pskHex = *pskHexOpt
|
|
|
- }
|
|
|
- if pskHex == "" {
|
|
|
- fmt.Fprintln(os.Stderr)
|
|
|
- fmt.Fprintf(os.Stderr, "Error: no PSK option provided and neither %s environment variable is set\n", PSKEnvVarKey)
|
|
|
- fmt.Fprintln(os.Stderr)
|
|
|
- return 2
|
|
|
- }
|
|
|
-
|
|
|
- psk, err := util.PSKFromHex(pskHex)
|
|
|
- if err != nil {
|
|
|
- fmt.Fprintf(os.Stderr, "Error: can't hex-decode PSK: %v\n", err)
|
|
|
- return 2
|
|
|
- }
|
|
|
switch args[0] {
|
|
|
case "server":
|
|
|
- return cmdServer(args[1], args[2], psk)
|
|
|
+ return cmdServer(args[1], args[2])
|
|
|
case "client":
|
|
|
- return cmdClient(args[1], args[2], psk)
|
|
|
+ return cmdClient(args[1], args[2])
|
|
|
}
|
|
|
}
|
|
|
usage()
|
|
@@ -149,3 +141,21 @@ func main() {
|
|
|
log.Default().SetPrefix(strings.ToUpper(ProgName) + ": ")
|
|
|
os.Exit(run())
|
|
|
}
|
|
|
+
|
|
|
+func simpleGetPSK() ([]byte, error) {
|
|
|
+ pskHex := os.Getenv(PSKEnvVarKey)
|
|
|
+ if pskHex == "" {
|
|
|
+ os.Unsetenv(PSKEnvVarKey)
|
|
|
+ }
|
|
|
+ if *pskHexOpt != "" {
|
|
|
+ pskHex = *pskHexOpt
|
|
|
+ }
|
|
|
+ if pskHex == "" {
|
|
|
+ return nil, fmt.Errorf("no PSK command line option provided and neither %s environment variable is set", PSKEnvVarKey)
|
|
|
+ }
|
|
|
+ psk, err := util.PSKFromHex(pskHex)
|
|
|
+ if err != nil {
|
|
|
+ return nil, fmt.Errorf("can't hex-decode PSK: %w", err)
|
|
|
+ }
|
|
|
+ return psk, nil
|
|
|
+}
|