tls_client_test.go 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. package runtime
import (
	"context"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strconv"
	"strings"
	"sync/atomic"
	"testing"

	"github.com/mhsanaei/3x-ui/v3/internal/database/model"
)
  15. // nodeForServer builds a node pointing at a loopback test server (loopback is
  16. // SSRF-blocked, so AllowPrivateAddress is set for the guarded dialer).
  17. func nodeForServer(t *testing.T, srv *httptest.Server, mode, pin string) *model.Node {
  18. t.Helper()
  19. u, err := url.Parse(srv.URL)
  20. if err != nil {
  21. t.Fatalf("parse server url: %v", err)
  22. }
  23. port, err := strconv.Atoi(u.Port())
  24. if err != nil {
  25. t.Fatalf("parse server port: %v", err)
  26. }
  27. return &model.Node{
  28. Id: 1,
  29. Name: "n1",
  30. Scheme: "https",
  31. Address: u.Hostname(),
  32. Port: port,
  33. BasePath: "/",
  34. ApiToken: "token",
  35. Enable: true,
  36. AllowPrivateAddress: true,
  37. TlsVerifyMode: mode,
  38. PinnedCertSha256: pin,
  39. }
  40. }
  41. func leafPinBase64(srv *httptest.Server) string {
  42. sum := sha256.Sum256(srv.Certificate().Raw)
  43. return base64.StdEncoding.EncodeToString(sum[:])
  44. }
  45. // A self-signed node must be reachable by Remote ops under skip/pin and
  46. // rejected under verify — the split issue #5264 reported.
  47. func TestRemoteHonorsTLSVerifyMode(t *testing.T) {
  48. srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
  49. w.Header().Set("Content-Type", "application/json")
  50. _, _ = w.Write([]byte(`{"success":true,"obj":[]}`))
  51. }))
  52. defer srv.Close()
  53. goodPin := leafPinBase64(srv)
  54. wrongPin := base64.StdEncoding.EncodeToString(make([]byte, sha256.Size))
  55. cases := []struct {
  56. name string
  57. mode string
  58. pin string
  59. wantErr bool
  60. }{
  61. {"verify rejects self-signed", "verify", "", true},
  62. {"skip accepts self-signed", "skip", "", false},
  63. {"pin accepts matching cert", "pin", goodPin, false},
  64. {"pin rejects mismatched cert", "pin", wrongPin, true},
  65. }
  66. for _, c := range cases {
  67. t.Run(c.name, func(t *testing.T) {
  68. r := NewRemote(nodeForServer(t, srv, c.mode, c.pin), nil)
  69. _, err := r.ListInboundOptions(context.Background())
  70. if c.wantErr && err == nil {
  71. t.Fatalf("mode %q: expected error, got nil", c.mode)
  72. }
  73. if !c.wantErr && err != nil {
  74. t.Fatalf("mode %q: unexpected error: %v", c.mode, err)
  75. }
  76. })
  77. }
  78. }
  79. // The lazily-built client is cached for the Remote's lifetime so repeated
  80. // operations reuse one pooled transport rather than rebuilding TLS each call.
  81. func TestRemoteClientCached(t *testing.T) {
  82. r := NewRemote(&model.Node{Scheme: "https", TlsVerifyMode: "skip"}, nil)
  83. c1, err1 := r.httpClient()
  84. c2, err2 := r.httpClient()
  85. if err1 != nil || err2 != nil {
  86. t.Fatalf("httpClient errors: %v %v", err1, err2)
  87. }
  88. if c1 != c2 {
  89. t.Fatal("expected the same cached client across calls")
  90. }
  91. }
  92. func TestHTTPClientForNodeVerifyShared(t *testing.T) {
  93. // verify mode and plain http both reuse the shared default client.
  94. for _, n := range []*model.Node{
  95. {Scheme: "https", TlsVerifyMode: "verify"},
  96. {Scheme: "https", TlsVerifyMode: ""},
  97. {Scheme: "http", TlsVerifyMode: "skip"},
  98. } {
  99. c, err := HTTPClientForNode(n, "")
  100. if err != nil {
  101. t.Fatalf("HTTPClientForNode(%+v): %v", n, err)
  102. }
  103. if c != defaultNodeHTTPClient {
  104. t.Fatalf("HTTPClientForNode(%+v) = %p, want shared default %p", n, c, defaultNodeHTTPClient)
  105. }
  106. }
  107. }
  108. func TestHTTPClientForNodePinInvalid(t *testing.T) {
  109. // pin mode must fail closed, and with a specific error per cause — not merely
  110. // "some error" (which a bug anywhere in the build path would also satisfy).
  111. cases := []struct {
  112. name string
  113. pin string
  114. wantErr string
  115. }{
  116. {"garbage pin", "not-a-pin", "must be a SHA-256 hash"},
  117. {"empty pin", "", "certificate pin is empty"},
  118. }
  119. for _, c := range cases {
  120. t.Run(c.name, func(t *testing.T) {
  121. _, err := HTTPClientForNode(&model.Node{Scheme: "https", TlsVerifyMode: "pin", PinnedCertSha256: c.pin}, "")
  122. if err == nil {
  123. t.Fatalf("expected error for pin %q", c.pin)
  124. }
  125. if !strings.Contains(err.Error(), c.wantErr) {
  126. t.Fatalf("error = %q, want it to contain %q", err.Error(), c.wantErr)
  127. }
  128. })
  129. }
  130. }
  131. // TestHTTPClientForNode_ProxyPinPreservesPinEnforcement covers the proxy+pin branch
  132. // (tls_client.go:43-52): when a node uses a proxy AND pin mode, the proxy client's
  133. // transport must carry the pinning tls.Config (the `transport.TLSClientConfig = tlsCfg`
  134. // line). Dropping it would silently disable certificate pinning whenever a proxy is set.
  135. func TestHTTPClientForNode_ProxyPinPreservesPinEnforcement(t *testing.T) {
  136. pin := base64.StdEncoding.EncodeToString(make([]byte, sha256.Size))
  137. n := &model.Node{Scheme: "https", TlsVerifyMode: "pin", PinnedCertSha256: pin}
  138. c, err := HTTPClientForNode(n, "socks5://127.0.0.1:1080")
  139. if err != nil {
  140. t.Fatalf("HTTPClientForNode: %v", err)
  141. }
  142. if c == defaultNodeHTTPClient {
  143. t.Fatal("proxy client must not be the shared default client")
  144. }
  145. tr, ok := c.Transport.(*http.Transport)
  146. if !ok {
  147. t.Fatalf("transport is %T, want *http.Transport", c.Transport)
  148. }
  149. if tr.TLSClientConfig == nil || tr.TLSClientConfig.VerifyConnection == nil {
  150. t.Fatal("pin mode over a proxy must install a pinning tls.Config (VerifyConnection); pin enforcement was dropped")
  151. }
  152. }
  153. // TestHTTPClientForNode_ProxyVerifyNoPin covers the proxy+verify branch
  154. // (tls_client.go:40-42): verify mode over a proxy returns the proxy client as-is,
  155. // using system-CA verification and NOT a pin VerifyConnection.
  156. func TestHTTPClientForNode_ProxyVerifyNoPin(t *testing.T) {
  157. n := &model.Node{Scheme: "https", TlsVerifyMode: "verify"}
  158. c, err := HTTPClientForNode(n, "socks5://127.0.0.1:1080")
  159. if err != nil {
  160. t.Fatalf("HTTPClientForNode: %v", err)
  161. }
  162. if c == defaultNodeHTTPClient {
  163. t.Fatal("proxy client must not be the shared default client")
  164. }
  165. if tr, ok := c.Transport.(*http.Transport); ok && tr.TLSClientConfig != nil && tr.TLSClientConfig.VerifyConnection != nil {
  166. t.Fatal("verify mode must not install a pin VerifyConnection")
  167. }
  168. }
  169. func TestDecodeCertPin(t *testing.T) {
  170. raw := sha256.Sum256([]byte("cert"))
  171. hexColon := strings.ToUpper(hex.EncodeToString(raw[:]))
  172. // reinsert colons in openssl -fingerprint style
  173. var withColons strings.Builder
  174. for i := 0; i < len(hexColon); i += 2 {
  175. if i > 0 {
  176. withColons.WriteByte(':')
  177. }
  178. withColons.WriteString(hexColon[i : i+2])
  179. }
  180. cases := []struct {
  181. name string
  182. in string
  183. wantErr bool
  184. }{
  185. {"base64 std", base64.StdEncoding.EncodeToString(raw[:]), false},
  186. {"base64 raw url", base64.RawURLEncoding.EncodeToString(raw[:]), false},
  187. {"hex bare", hex.EncodeToString(raw[:]), false},
  188. {"hex colon openssl", withColons.String(), false},
  189. {"empty", "", true},
  190. {"garbage", "not-a-pin", true},
  191. }
  192. for _, c := range cases {
  193. t.Run(c.name, func(t *testing.T) {
  194. got, err := DecodeCertPin(c.in)
  195. if c.wantErr {
  196. if err == nil {
  197. t.Fatalf("expected error for %q", c.in)
  198. }
  199. return
  200. }
  201. if err != nil {
  202. t.Fatalf("unexpected error for %q: %v", c.in, err)
  203. }
  204. if string(got) != string(raw[:]) {
  205. t.Fatalf("decoded bytes mismatch for %q", c.in)
  206. }
  207. })
  208. }
  209. }