node_mtls_test.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. package service
  2. import (
  3. "crypto/x509"
  4. "encoding/pem"
  5. "testing"
  6. "github.com/go-playground/validator/v10"
  7. "github.com/mhsanaei/3x-ui/v3/internal/database/model"
  8. )
  9. func TestNormalizeKeepsMtls(t *testing.T) {
  10. s := &NodeService{}
  11. cases := []struct {
  12. name string
  13. in model.Node
  14. wantMode string
  15. wantErr bool
  16. }{
  17. {"mtls over https preserved", model.Node{Name: "n", Address: "node.example.com", Port: 2053, Scheme: "https", TlsVerifyMode: "mtls"}, "mtls", false},
  18. {"mtls over http rejected", model.Node{Name: "n", Address: "node.example.com", Port: 2053, Scheme: "http", TlsVerifyMode: "mtls"}, "", true},
  19. {"unknown mode clamped to verify", model.Node{Name: "n", Address: "node.example.com", Port: 2053, Scheme: "https", TlsVerifyMode: "bogus"}, "verify", false},
  20. }
  21. for _, c := range cases {
  22. t.Run(c.name, func(t *testing.T) {
  23. n := c.in
  24. err := s.normalize(&n)
  25. if c.wantErr {
  26. if err == nil {
  27. t.Fatal("expected an error")
  28. }
  29. return
  30. }
  31. if err != nil {
  32. t.Fatalf("normalize: %v", err)
  33. }
  34. if n.TlsVerifyMode != c.wantMode {
  35. t.Fatalf("TlsVerifyMode = %q, want %q", n.TlsVerifyMode, c.wantMode)
  36. }
  37. })
  38. }
  39. }
  40. func TestNodeTlsVerifyModeValidatorAcceptsMtls(t *testing.T) {
  41. v := validator.New(validator.WithRequiredStructEnabled())
  42. base := model.Node{Name: "n", Address: "node.example.com", Port: 2053, Scheme: "https", ApiToken: "t"}
  43. for _, m := range []string{"verify", "skip", "pin", "mtls"} {
  44. n := base
  45. n.TlsVerifyMode = m
  46. if err := v.Struct(n); err != nil {
  47. t.Fatalf("validator rejected valid TlsVerifyMode %q: %v", m, err)
  48. }
  49. }
  50. bad := base
  51. bad.TlsVerifyMode = "bogus"
  52. if err := v.Struct(bad); err == nil {
  53. t.Fatal("validator must reject an unknown TlsVerifyMode")
  54. }
  55. }
  56. func TestNodeMtlsCaCert(t *testing.T) {
  57. _ = setupSettingMtlsDB(t)
  58. got, err := (&NodeService{}).NodeMtlsCaCert()
  59. if err != nil {
  60. t.Fatalf("NodeMtlsCaCert: %v", err)
  61. }
  62. block, _ := pem.Decode([]byte(got))
  63. if block == nil || block.Type != "CERTIFICATE" {
  64. t.Fatalf("NodeMtlsCaCert must return a CERTIFICATE PEM, got %q", got)
  65. }
  66. cert, err := x509.ParseCertificate(block.Bytes)
  67. if err != nil {
  68. t.Fatalf("parse returned cert: %v", err)
  69. }
  70. if !cert.IsCA {
  71. t.Fatal("NodeMtlsCaCert must return the CA certificate (IsCA)")
  72. }
  73. }
  74. func TestSetNodeMtlsTrustCA(t *testing.T) {
  75. _ = setupSettingMtlsDB(t)
  76. ns := &NodeService{}
  77. settings := SettingService{}
  78. ca, err := settings.EnsureNodeMtlsCA()
  79. if err != nil {
  80. t.Fatalf("EnsureNodeMtlsCA: %v", err)
  81. }
  82. if err := ns.SetNodeMtlsTrustCA(string(ca.CertPEM)); err != nil {
  83. t.Fatalf("SetNodeMtlsTrustCA(valid): %v", err)
  84. }
  85. pool, err := settings.NodeMtlsClientCAPool()
  86. if err != nil || pool == nil {
  87. t.Fatalf("valid trust CA must persist + build a pool: pool=%v err=%v", pool, err)
  88. }
  89. if err := ns.SetNodeMtlsTrustCA("not a certificate"); err == nil {
  90. t.Fatal("invalid PEM must be rejected (fail closed)")
  91. }
  92. if err := ns.SetNodeMtlsTrustCA(""); err != nil {
  93. t.Fatalf("clearing the trust CA must be allowed: %v", err)
  94. }
  95. pool, _ = settings.NodeMtlsClientCAPool()
  96. if pool != nil {
  97. t.Fatal("cleared trust CA must yield a nil pool (mTLS off)")
  98. }
  99. }