Ver Fonte

fix(security): SSRF-guard node and remote HTTP clients

The Node.Probe and Remote.do paths built outbound URLs by string-
formatting admin-controlled fields (Scheme/Address/Port/BasePath)
straight into requests, then dialed the result with the default
transport. CodeQL flagged this as go/request-forgery — an admin
(or anyone who compromises the admin account) could point a node
at internal infrastructure (cloud metadata, RFC1918 ranges, etc.)
and the panel would dutifully fetch it.

Add util/netsafe with a shared TOCTOU-safe DialContext that
resolves the host, rejects private/internal IPs unless the
per-request context whitelists them (per-node AllowPrivateAddress
flag, plumbed through context.Value), and dials the resolved IP
directly so the IP that passed the check is the IP we connect to.
This closes the DNS-rebinding window where a hostname could
resolve to a public IP at check time and a private one at dial.

Also tighten address validation (NormalizeHost rejects anything
that isn't a bare hostname or IP literal — no embedded paths,
userinfo, schemes) and switch URL construction from fmt.Sprintf to
url.URL{} + net.JoinHostPort so admin-supplied values can't smuggle
URL components.

custom_geo.go's isBlockedIP now delegates to netsafe so there's
one source of truth.
MHSanaei há 1 dia atrás
pai
commit
38da210ded
4 ficheiros alterados com 163 adições e 15 exclusões
  1. 101 0
      util/netsafe/netsafe.go
  2. 30 7
      web/runtime/remote.go
  3. 2 2
      web/service/custom_geo.go
  4. 30 6
      web/service/node.go

+ 101 - 0
util/netsafe/netsafe.go

@@ -0,0 +1,101 @@
+// Package netsafe provides SSRF-safe HTTP dialing primitives. A dialer
+// installed via SSRFGuardedDialContext resolves the host, rejects
+// private/internal IPs unless the per-request context whitelists them,
+// and dials the resolved IP directly so the IP checked is the IP used —
+// closing the DNS-rebinding TOCTOU window.
+package netsafe
+
+import (
+	"context"
+	"fmt"
+	"net"
+	"regexp"
+	"strings"
+	"time"
+)
+
+// IsBlockedIP returns true for loopback, RFC1918 private, link-local
+// (including 169.254.169.254 cloud-metadata), and unspecified addresses.
+func IsBlockedIP(ip net.IP) bool {
+	return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() ||
+		ip.IsLinkLocalMulticast() || ip.IsUnspecified()
+}
+
+type allowPrivateCtxKey struct{}
+
+// ContextWithAllowPrivate marks a context as permitting outbound requests
+// to private/internal IPs. Use only for callers (e.g. LAN-resident nodes)
+// where the admin has opted in explicitly.
+func ContextWithAllowPrivate(ctx context.Context, allow bool) context.Context {
+	return context.WithValue(ctx, allowPrivateCtxKey{}, allow)
+}
+
+func AllowPrivateFromContext(ctx context.Context) bool {
+	v, _ := ctx.Value(allowPrivateCtxKey{}).(bool)
+	return v
+}
+
+var defaultDialer = &net.Dialer{Timeout: 10 * time.Second}
+
+// SSRFGuardedDialContext is a net/http Transport.DialContext implementation
+// that enforces IsBlockedIP unless the context opts in via
+// ContextWithAllowPrivate.
+func SSRFGuardedDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
+	host, port, err := net.SplitHostPort(addr)
+	if err != nil {
+		return nil, err
+	}
+	allowPrivate := AllowPrivateFromContext(ctx)
+	var ips []net.IPAddr
+	if ip := net.ParseIP(host); ip != nil {
+		ips = []net.IPAddr{{IP: ip}}
+	} else {
+		ips, err = net.DefaultResolver.LookupIPAddr(ctx, host)
+		if err != nil {
+			return nil, err
+		}
+	}
+	var lastErr error
+	for _, ipAddr := range ips {
+		if !allowPrivate && IsBlockedIP(ipAddr.IP) {
+			lastErr = fmt.Errorf("blocked private/internal address %s", ipAddr.IP)
+			continue
+		}
+		conn, derr := defaultDialer.DialContext(ctx, network, net.JoinHostPort(ipAddr.IP.String(), port))
+		if derr == nil {
+			return conn, nil
+		}
+		lastErr = derr
+	}
+	if lastErr == nil {
+		lastErr = fmt.Errorf("no usable address for %s", host)
+	}
+	return nil, lastErr
+}
+
+// hostnamePattern accepts RFC 1123 hostnames (letters, digits, hyphens,
+// dots). Bracketed IPv6 forms ("[::1]") are stripped before this check
+// runs in NormalizeHost.
+var hostnamePattern = regexp.MustCompile(`^[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?(\.[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?)*$`)
+
+// NormalizeHost validates that addr is a plain hostname or IP literal with
+// no embedded path/userinfo/port/scheme — anything that could be used to
+// smuggle URL components past callers that string-format URLs from user
+// input. Returns the bare host (no brackets); callers wrap IPv6 via
+// net.JoinHostPort as needed.
+func NormalizeHost(addr string) (string, error) {
+	addr = strings.TrimSpace(addr)
+	if addr == "" {
+		return "", fmt.Errorf("address is required")
+	}
+	if strings.HasPrefix(addr, "[") && strings.HasSuffix(addr, "]") {
+		addr = addr[1 : len(addr)-1]
+	}
+	if ip := net.ParseIP(addr); ip != nil {
+		return ip.String(), nil
+	}
+	if len(addr) > 253 || !hostnamePattern.MatchString(addr) {
+		return "", fmt.Errorf("invalid host %q", addr)
+	}
+	return addr, nil
+}

+ 30 - 7
web/runtime/remote.go

@@ -7,6 +7,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"net"
 	"net/http"
 	"net/url"
 	"strconv"
@@ -16,6 +17,7 @@ import (
 
 	"github.com/mhsanaei/3x-ui/v3/database/model"
 	"github.com/mhsanaei/3x-ui/v3/logger"
+	"github.com/mhsanaei/3x-ui/v3/util/netsafe"
 )
 
 const remoteHTTPTimeout = 10 * time.Second
@@ -25,6 +27,7 @@ var remoteHTTPClient = &http.Client{
 		MaxIdleConns:        64,
 		MaxIdleConnsPerHost: 4,
 		IdleConnTimeout:     60 * time.Second,
+		DialContext:         netsafe.SSRFGuardedDialContext,
 	},
 }
 
@@ -50,7 +53,18 @@ func NewRemote(n *model.Node) *Remote {
 
 func (r *Remote) Name() string { return "node:" + r.node.Name }
 
-func (r *Remote) baseURL() string {
+func (r *Remote) baseURL() (string, error) {
+	addr, err := netsafe.NormalizeHost(r.node.Address)
+	if err != nil {
+		return "", err
+	}
+	scheme := r.node.Scheme
+	if scheme != "http" && scheme != "https" {
+		scheme = "https"
+	}
+	if r.node.Port <= 0 || r.node.Port > 65535 {
+		return "", fmt.Errorf("invalid node port %d", r.node.Port)
+	}
 	bp := r.node.BasePath
 	if bp == "" {
 		bp = "/"
@@ -58,7 +72,12 @@ func (r *Remote) baseURL() string {
 	if !strings.HasSuffix(bp, "/") {
 		bp += "/"
 	}
-	return fmt.Sprintf("%s://%s:%d%s", r.node.Scheme, r.node.Address, r.node.Port, bp)
+	u := &url.URL{
+		Scheme: scheme,
+		Host:   net.JoinHostPort(addr, strconv.Itoa(r.node.Port)),
+		Path:   bp,
+	}
+	return u.String(), nil
 }
 
 func (r *Remote) do(ctx context.Context, method, path string, body any) (*envelope, error) {
@@ -66,7 +85,11 @@ func (r *Remote) do(ctx context.Context, method, path string, body any) (*envelo
 		return nil, errors.New("node has no API token configured")
 	}
 
-	target := r.baseURL() + strings.TrimPrefix(path, "/")
+	base, err := r.baseURL()
+	if err != nil {
+		return nil, err
+	}
+	target := base + strings.TrimPrefix(path, "/")
 
 	var (
 		reqBody     io.Reader
@@ -78,15 +101,15 @@ func (r *Remote) do(ctx context.Context, method, path string, body any) (*envelo
 		reqBody = strings.NewReader(b.Encode())
 		contentType = "application/x-www-form-urlencoded"
 	default:
-		buf, err := json.Marshal(b)
-		if err != nil {
-			return nil, fmt.Errorf("marshal body: %w", err)
+		buf, jerr := json.Marshal(b)
+		if jerr != nil {
+			return nil, fmt.Errorf("marshal body: %w", jerr)
 		}
 		reqBody = bytes.NewReader(buf)
 		contentType = "application/json"
 	}
 
-	cctx, cancel := context.WithTimeout(ctx, remoteHTTPTimeout)
+	cctx, cancel := context.WithTimeout(netsafe.ContextWithAllowPrivate(ctx, r.node.AllowPrivateAddress), remoteHTTPTimeout)
 	defer cancel()
 	req, err := http.NewRequestWithContext(cctx, method, target, reqBody)
 	if err != nil {

+ 2 - 2
web/service/custom_geo.go

@@ -18,6 +18,7 @@ import (
 	"github.com/mhsanaei/3x-ui/v3/database"
 	"github.com/mhsanaei/3x-ui/v3/database/model"
 	"github.com/mhsanaei/3x-ui/v3/logger"
+	"github.com/mhsanaei/3x-ui/v3/util/netsafe"
 )
 
 const (
@@ -164,8 +165,7 @@ func CustomGeoLocalFileNeedsRepair(path string) bool {
 }
 
 func isBlockedIP(ip net.IP) bool {
-	return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() ||
-		ip.IsLinkLocalMulticast() || ip.IsUnspecified()
+	return netsafe.IsBlockedIP(ip)
 }
 
 // checkSSRFDefault validates that the given host does not resolve to a private/internal IP.

+ 30 - 6
web/service/node.go

@@ -5,7 +5,9 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"net"
 	"net/http"
+	"net/url"
 	"strconv"
 	"strings"
 	"time"
@@ -13,6 +15,7 @@ import (
 	"github.com/mhsanaei/3x-ui/v3/database"
 	"github.com/mhsanaei/3x-ui/v3/database/model"
 	"github.com/mhsanaei/3x-ui/v3/util/common"
+	"github.com/mhsanaei/3x-ui/v3/util/netsafe"
 	"github.com/mhsanaei/3x-ui/v3/web/runtime"
 )
 
@@ -34,6 +37,7 @@ var nodeHTTPClient = &http.Client{
 		MaxIdleConns:        64,
 		MaxIdleConnsPerHost: 4,
 		IdleConnTimeout:     60 * time.Second,
+		DialContext:         netsafe.SSRFGuardedDialContext,
 	},
 }
 
@@ -69,14 +73,15 @@ func normalizeBasePath(p string) string {
 
 func (s *NodeService) normalize(n *model.Node) error {
 	n.Name = strings.TrimSpace(n.Name)
-	n.Address = strings.TrimSpace(n.Address)
 	n.ApiToken = strings.TrimSpace(n.ApiToken)
 	if n.Name == "" {
 		return common.NewError("node name is required")
 	}
-	if n.Address == "" {
-		return common.NewError("node address is required")
+	addr, err := netsafe.NormalizeHost(n.Address)
+	if err != nil {
+		return common.NewError(err.Error())
 	}
+	n.Address = addr
 	if n.Port <= 0 || n.Port > 65535 {
 		return common.NewError("node port must be 1-65535")
 	}
@@ -175,10 +180,29 @@ func (s *NodeService) AggregateNodeMetric(id int, metric string, bucketSeconds i
 
 func (s *NodeService) Probe(ctx context.Context, n *model.Node) (HeartbeatPatch, error) {
 	patch := HeartbeatPatch{LastHeartbeat: time.Now().Unix()}
-	url := fmt.Sprintf("%s://%s:%d%spanel/api/server/status",
-		n.Scheme, n.Address, n.Port, normalizeBasePath(n.BasePath))
 
-	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
+	addr, err := netsafe.NormalizeHost(n.Address)
+	if err != nil {
+		patch.LastError = err.Error()
+		return patch, err
+	}
+	scheme := n.Scheme
+	if scheme != "http" && scheme != "https" {
+		scheme = "https"
+	}
+	if n.Port <= 0 || n.Port > 65535 {
+		patch.LastError = "node port must be 1-65535"
+		return patch, errors.New(patch.LastError)
+	}
+	probeURL := &url.URL{
+		Scheme: scheme,
+		Host:   net.JoinHostPort(addr, strconv.Itoa(n.Port)),
+		Path:   normalizeBasePath(n.BasePath) + "panel/api/server/status",
+	}
+
+	req, err := http.NewRequestWithContext(
+		netsafe.ContextWithAllowPrivate(ctx, n.AllowPrivateAddress),
+		http.MethodGet, probeURL.String(), nil)
 	if err != nil {
 		patch.LastError = err.Error()
 		return patch, err