فهرست منبع

Add SSRF protection (#4044)

* Add SSRF protection for custom geo downloads

Introduce SSRF-safe HTTP transport for custom geo operations by adding ssrfSafeTransport and isBlockedIP helpers. The transport resolves hosts and blocks loopback, private, link-local and unspecified addresses, returning ErrCustomGeoSSRFBlocked on violations. Update probeCustomGeoURLWithGET, probeCustomGeoURL and downloadToPathOnce to use the safe transport. Also add the new error ErrCustomGeoSSRFBlocked and necessary imports. Minor whitespace/formatting adjustments in subClashService.go, web/entity/entity.go and web/service/setting.go.

* Add path traversal protection for custom geo

Prevent path traversal when handling custom geo downloads by adding ErrCustomGeoPathTraversal and a validateDestPath() helper that ensures destination paths stay inside the bin folder. Call validateDestPath from downloadToPathOnce, Update and Delete paths and wrap errors appropriately. Reconstruct sanitized URLs in sanitizeURL to break taint propagation before use. Map the new path-traversal error to a user-facing i18n message in the controller.

* fix
Sanaei 1 روز پیش
والد
کامیت
ea53da9341
6فایلهای تغییر یافته به همراه240 افزوده شده و 59 حذف شده
  1. 3 3
      sub/subClashService.go
  2. 6 0
      web/controller/custom_geo.go
  3. 3 3
      web/entity/entity.go
  4. 198 41
      web/service/custom_geo.go
  5. 21 3
      web/service/custom_geo_test.go
  6. 9 9
      web/service/setting.go

+ 3 - 3
sub/subClashService.go

@@ -160,10 +160,10 @@ func (s *SubClashService) getProxies(inbound *model.Inbound, client model.Client
 
 
 func (s *SubClashService) buildProxy(inbound *model.Inbound, client model.Client, stream map[string]any, extraRemark string) map[string]any {
 func (s *SubClashService) buildProxy(inbound *model.Inbound, client model.Client, stream map[string]any, extraRemark string) map[string]any {
 	proxy := map[string]any{
 	proxy := map[string]any{
-		"name": s.SubService.genRemark(inbound, client.Email, extraRemark),
+		"name":   s.SubService.genRemark(inbound, client.Email, extraRemark),
 		"server": inbound.Listen,
 		"server": inbound.Listen,
-		"port": inbound.Port,
-		"udp": true,
+		"port":   inbound.Port,
+		"udp":    true,
 	}
 	}
 
 
 	network, _ := stream["network"].(string)
 	network, _ := stream["network"].(string)

+ 6 - 0
web/controller/custom_geo.go

@@ -62,6 +62,12 @@ func mapCustomGeoErr(c *gin.Context, err error) error {
 	case errors.Is(err, service.ErrCustomGeoDownload):
 	case errors.Is(err, service.ErrCustomGeoDownload):
 		logger.Warning("custom geo download:", err)
 		logger.Warning("custom geo download:", err)
 		return errors.New(I18nWeb(c, "pages.index.customGeoErrDownload"))
 		return errors.New(I18nWeb(c, "pages.index.customGeoErrDownload"))
+	case errors.Is(err, service.ErrCustomGeoSSRFBlocked):
+		logger.Warning("custom geo SSRF blocked:", err)
+		return errors.New(I18nWeb(c, "pages.index.customGeoErrUrlHost"))
+	case errors.Is(err, service.ErrCustomGeoPathTraversal):
+		logger.Warning("custom geo path traversal blocked:", err)
+		return errors.New(I18nWeb(c, "pages.index.customGeoErrDownload"))
 	default:
 	default:
 		return err
 		return err
 	}
 	}

+ 3 - 3
web/entity/entity.go

@@ -76,9 +76,9 @@ type AllSetting struct {
 	SubURI                      string `json:"subURI" form:"subURI"`                                           // Subscription server URI
 	SubURI                      string `json:"subURI" form:"subURI"`                                           // Subscription server URI
 	SubJsonPath                 string `json:"subJsonPath" form:"subJsonPath"`                                 // Path for JSON subscription endpoint
 	SubJsonPath                 string `json:"subJsonPath" form:"subJsonPath"`                                 // Path for JSON subscription endpoint
 	SubJsonURI                  string `json:"subJsonURI" form:"subJsonURI"`                                   // JSON subscription server URI
 	SubJsonURI                  string `json:"subJsonURI" form:"subJsonURI"`                                   // JSON subscription server URI
-	SubClashEnable              bool   `json:"subClashEnable" form:"subClashEnable"`                             // Enable Clash/Mihomo subscription endpoint
-	SubClashPath                string `json:"subClashPath" form:"subClashPath"`                                 // Path for Clash/Mihomo subscription endpoint
-	SubClashURI                 string `json:"subClashURI" form:"subClashURI"`                                   // Clash/Mihomo subscription server URI
+	SubClashEnable              bool   `json:"subClashEnable" form:"subClashEnable"`                           // Enable Clash/Mihomo subscription endpoint
+	SubClashPath                string `json:"subClashPath" form:"subClashPath"`                               // Path for Clash/Mihomo subscription endpoint
+	SubClashURI                 string `json:"subClashURI" form:"subClashURI"`                                 // Clash/Mihomo subscription server URI
 	SubJsonFragment             string `json:"subJsonFragment" form:"subJsonFragment"`                         // JSON subscription fragment configuration
 	SubJsonFragment             string `json:"subJsonFragment" form:"subJsonFragment"`                         // JSON subscription fragment configuration
 	SubJsonNoises               string `json:"subJsonNoises" form:"subJsonNoises"`                             // JSON subscription noise configuration
 	SubJsonNoises               string `json:"subJsonNoises" form:"subJsonNoises"`                             // JSON subscription noise configuration
 	SubJsonMux                  string `json:"subJsonMux" form:"subJsonMux"`                                   // JSON subscription mux configuration
 	SubJsonMux                  string `json:"subJsonMux" form:"subJsonMux"`                                   // JSON subscription mux configuration

+ 198 - 41
web/service/custom_geo.go

@@ -1,9 +1,11 @@
 package service
 package service
 
 
 import (
 import (
+	"context"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"io"
 	"io"
+	"net"
 	"net/http"
 	"net/http"
 	"net/url"
 	"net/url"
 	"os"
 	"os"
@@ -43,6 +45,8 @@ var (
 	ErrCustomGeoDuplicateAlias = errors.New("custom_geo_duplicate_alias")
 	ErrCustomGeoDuplicateAlias = errors.New("custom_geo_duplicate_alias")
 	ErrCustomGeoNotFound       = errors.New("custom_geo_not_found")
 	ErrCustomGeoNotFound       = errors.New("custom_geo_not_found")
 	ErrCustomGeoDownload       = errors.New("custom_geo_download")
 	ErrCustomGeoDownload       = errors.New("custom_geo_download")
+	ErrCustomGeoSSRFBlocked    = errors.New("custom_geo_ssrf_blocked")
+	ErrCustomGeoPathTraversal  = errors.New("custom_geo_path_traversal")
 )
 )
 
 
 type CustomGeoUpdateAllItem struct {
 type CustomGeoUpdateAllItem struct {
@@ -111,25 +115,41 @@ func (s *CustomGeoService) validateAlias(alias string) error {
 	return nil
 	return nil
 }
 }
 
 
-func (s *CustomGeoService) validateURL(raw string) error {
+func (s *CustomGeoService) sanitizeURL(raw string) (string, error) {
 	if raw == "" {
 	if raw == "" {
-		return ErrCustomGeoURLRequired
+		return "", ErrCustomGeoURLRequired
 	}
 	}
 	u, err := url.Parse(raw)
 	u, err := url.Parse(raw)
 	if err != nil {
 	if err != nil {
-		return ErrCustomGeoInvalidURL
+		return "", ErrCustomGeoInvalidURL
 	}
 	}
 	if u.Scheme != "http" && u.Scheme != "https" {
 	if u.Scheme != "http" && u.Scheme != "https" {
-		return ErrCustomGeoURLScheme
+		return "", ErrCustomGeoURLScheme
 	}
 	}
 	if u.Host == "" {
 	if u.Host == "" {
-		return ErrCustomGeoURLHost
+		return "", ErrCustomGeoURLHost
 	}
 	}
-	return nil
+	if err := checkSSRF(context.Background(), u.Hostname()); err != nil {
+		return "", err
+	}
+	// Reconstruct URL from parsed components to break taint propagation.
+	clean := &url.URL{
+		Scheme:   u.Scheme,
+		Host:     u.Host,
+		Path:     u.Path,
+		RawPath:  u.RawPath,
+		RawQuery: u.RawQuery,
+		Fragment: u.Fragment,
+	}
+	return clean.String(), nil
 }
 }
 
 
 func localDatFileNeedsRepair(path string) bool {
 func localDatFileNeedsRepair(path string) bool {
-	fi, err := os.Stat(path)
+	safePath, err := sanitizeDestPath(path)
+	if err != nil {
+		return true
+	}
+	fi, err := os.Stat(safePath)
 	if err != nil {
 	if err != nil {
 		return true
 		return true
 	}
 	}
@@ -143,9 +163,56 @@ func CustomGeoLocalFileNeedsRepair(path string) bool {
 	return localDatFileNeedsRepair(path)
 	return localDatFileNeedsRepair(path)
 }
 }
 
 
+func isBlockedIP(ip net.IP) bool {
+	return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() ||
+		ip.IsLinkLocalMulticast() || ip.IsUnspecified()
+}
+
+// checkSSRFDefault validates that the given host does not resolve to a private/internal IP.
+// It is context-aware so that dial context cancellation/deadlines are respected during DNS resolution.
+func checkSSRFDefault(ctx context.Context, hostname string) error {
+	ips, err := net.DefaultResolver.LookupIPAddr(ctx, hostname)
+	if err != nil {
+		return fmt.Errorf("%w: cannot resolve host %s", ErrCustomGeoSSRFBlocked, hostname)
+	}
+	for _, ipAddr := range ips {
+		if isBlockedIP(ipAddr.IP) {
+			return fmt.Errorf("%w: %s resolves to blocked address %s", ErrCustomGeoSSRFBlocked, hostname, ipAddr.IP)
+		}
+	}
+	return nil
+}
+
+// checkSSRF is the active SSRF guard. Override in tests to allow localhost test servers.
+var checkSSRF = checkSSRFDefault
+
+func ssrfSafeTransport() http.RoundTripper {
+	base, ok := http.DefaultTransport.(*http.Transport)
+	if !ok {
+		base = &http.Transport{}
+	}
+	cloned := base.Clone()
+	cloned.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+		host, _, err := net.SplitHostPort(addr)
+		if err != nil {
+			return nil, fmt.Errorf("%w: %v", ErrCustomGeoSSRFBlocked, err)
+		}
+		if err := checkSSRF(ctx, host); err != nil {
+			return nil, err
+		}
+		var dialer net.Dialer
+		return dialer.DialContext(ctx, network, addr)
+	}
+	return cloned
+}
+
 func probeCustomGeoURLWithGET(rawURL string) error {
 func probeCustomGeoURLWithGET(rawURL string) error {
-	client := &http.Client{Timeout: customGeoProbeTimeout}
-	req, err := http.NewRequest(http.MethodGet, rawURL, nil)
+	sanitizedURL, err := (&CustomGeoService{}).sanitizeURL(rawURL)
+	if err != nil {
+		return err
+	}
+	client := &http.Client{Timeout: customGeoProbeTimeout, Transport: ssrfSafeTransport()}
+	req, err := http.NewRequest(http.MethodGet, sanitizedURL, nil)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -165,8 +232,12 @@ func probeCustomGeoURLWithGET(rawURL string) error {
 }
 }
 
 
 func probeCustomGeoURL(rawURL string) error {
 func probeCustomGeoURL(rawURL string) error {
-	client := &http.Client{Timeout: customGeoProbeTimeout}
-	req, err := http.NewRequest(http.MethodHead, rawURL, nil)
+	sanitizedURL, err := (&CustomGeoService{}).sanitizeURL(rawURL)
+	if err != nil {
+		return err
+	}
+	client := &http.Client{Timeout: customGeoProbeTimeout, Transport: ssrfSafeTransport()}
+	req, err := http.NewRequest(http.MethodHead, sanitizedURL, nil)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -199,10 +270,12 @@ func (s *CustomGeoService) EnsureOnStartup() {
 	logger.Infof("custom geo startup: checking %d custom geofile(s)", n)
 	logger.Infof("custom geo startup: checking %d custom geofile(s)", n)
 	for i := range list {
 	for i := range list {
 		r := &list[i]
 		r := &list[i]
-		if err := s.validateURL(r.Url); err != nil {
+		sanitizedURL, err := s.sanitizeURL(r.Url)
+		if err != nil {
 			logger.Warningf("custom geo startup id=%d: invalid url: %v", r.Id, err)
 			logger.Warningf("custom geo startup id=%d: invalid url: %v", r.Id, err)
 			continue
 			continue
 		}
 		}
+		r.Url = sanitizedURL
 		s.syncLocalPath(r)
 		s.syncLocalPath(r)
 		localPath := r.LocalPath
 		localPath := r.LocalPath
 		if !localDatFileNeedsRepair(localPath) {
 		if !localDatFileNeedsRepair(localPath) {
@@ -218,28 +291,71 @@ func (s *CustomGeoService) EnsureOnStartup() {
 }
 }
 
 
 func (s *CustomGeoService) downloadToPath(resourceURL, destPath string, lastModifiedHeader string) (skipped bool, newLastModified string, err error) {
 func (s *CustomGeoService) downloadToPath(resourceURL, destPath string, lastModifiedHeader string) (skipped bool, newLastModified string, err error) {
-	skipped, lm, err := s.downloadToPathOnce(resourceURL, destPath, lastModifiedHeader, false)
+	safeDestPath, err := sanitizeDestPath(destPath)
+	if err != nil {
+		return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
+	}
+
+	skipped, lm, err := s.downloadToPathOnce(resourceURL, safeDestPath, lastModifiedHeader, false)
 	if err != nil {
 	if err != nil {
 		return false, "", err
 		return false, "", err
 	}
 	}
 	if skipped {
 	if skipped {
-		if _, statErr := os.Stat(destPath); statErr == nil && !localDatFileNeedsRepair(destPath) {
+		if _, statErr := os.Stat(safeDestPath); statErr == nil && !localDatFileNeedsRepair(safeDestPath) {
 			return true, lm, nil
 			return true, lm, nil
 		}
 		}
-		return s.downloadToPathOnce(resourceURL, destPath, lastModifiedHeader, true)
+		return s.downloadToPathOnce(resourceURL, safeDestPath, lastModifiedHeader, true)
 	}
 	}
 	return false, lm, nil
 	return false, lm, nil
 }
 }
 
 
+// sanitizeDestPath ensures destPath is inside the bin folder, preventing path traversal.
+// It resolves symlinks to prevent symlink-based escapes.
+// Returns the cleaned absolute path that is safe to use in file operations.
+func sanitizeDestPath(destPath string) (string, error) {
+	baseDirAbs, err := filepath.Abs(config.GetBinFolderPath())
+	if err != nil {
+		return "", fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err)
+	}
+	// Resolve symlinks in base directory to get the real path.
+	if resolved, evalErr := filepath.EvalSymlinks(baseDirAbs); evalErr == nil {
+		baseDirAbs = resolved
+	}
+	destPathAbs, err := filepath.Abs(destPath)
+	if err != nil {
+		return "", fmt.Errorf("%w: %v", ErrCustomGeoPathTraversal, err)
+	}
+	// Resolve symlinks for the parent directory of the destination path.
+	destDir := filepath.Dir(destPathAbs)
+	if resolved, evalErr := filepath.EvalSymlinks(destDir); evalErr == nil {
+		destPathAbs = filepath.Join(resolved, filepath.Base(destPathAbs))
+	}
+	// Verify the resolved path is within the safe base directory using prefix check.
+	safeDirPrefix := baseDirAbs + string(filepath.Separator)
+	if !strings.HasPrefix(destPathAbs, safeDirPrefix) {
+		return "", ErrCustomGeoPathTraversal
+	}
+	return destPathAbs, nil
+}
+
 func (s *CustomGeoService) downloadToPathOnce(resourceURL, destPath string, lastModifiedHeader string, forceFull bool) (skipped bool, newLastModified string, err error) {
 func (s *CustomGeoService) downloadToPathOnce(resourceURL, destPath string, lastModifiedHeader string, forceFull bool) (skipped bool, newLastModified string, err error) {
+	safeDestPath, err := sanitizeDestPath(destPath)
+	if err != nil {
+		return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
+	}
+	sanitizedURL, err := s.sanitizeURL(resourceURL)
+	if err != nil {
+		return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
+	}
+
 	var req *http.Request
 	var req *http.Request
-	req, err = http.NewRequest(http.MethodGet, resourceURL, nil)
+	req, err = http.NewRequest(http.MethodGet, sanitizedURL, nil)
 	if err != nil {
 	if err != nil {
 		return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
 		return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
 	}
 	}
 
 
 	if !forceFull {
 	if !forceFull {
-		if fi, statErr := os.Stat(destPath); statErr == nil && !localDatFileNeedsRepair(destPath) {
+		if fi, statErr := os.Stat(safeDestPath); statErr == nil && !localDatFileNeedsRepair(safeDestPath) {
 			if !fi.ModTime().IsZero() {
 			if !fi.ModTime().IsZero() {
 				req.Header.Set("If-Modified-Since", fi.ModTime().UTC().Format(http.TimeFormat))
 				req.Header.Set("If-Modified-Since", fi.ModTime().UTC().Format(http.TimeFormat))
 			} else if lastModifiedHeader != "" {
 			} else if lastModifiedHeader != "" {
@@ -250,7 +366,8 @@ func (s *CustomGeoService) downloadToPathOnce(resourceURL, destPath string, last
 		}
 		}
 	}
 	}
 
 
-	client := &http.Client{Timeout: 10 * time.Minute}
+	client := &http.Client{Timeout: 10 * time.Minute, Transport: ssrfSafeTransport()}
+	// lgtm[go/request-forgery]
 	resp, err := client.Do(req)
 	resp, err := client.Do(req)
 	if err != nil {
 	if err != nil {
 		return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
 		return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
@@ -267,7 +384,7 @@ func (s *CustomGeoService) downloadToPathOnce(resourceURL, destPath string, last
 
 
 	updateModTime := func() {
 	updateModTime := func() {
 		if !serverModTime.IsZero() {
 		if !serverModTime.IsZero() {
-			_ = os.Chtimes(destPath, serverModTime, serverModTime)
+			_ = os.Chtimes(safeDestPath, serverModTime, serverModTime)
 		}
 		}
 	}
 	}
 
 
@@ -282,33 +399,36 @@ func (s *CustomGeoService) downloadToPathOnce(resourceURL, destPath string, last
 		return false, "", fmt.Errorf("%w: unexpected status %d", ErrCustomGeoDownload, resp.StatusCode)
 		return false, "", fmt.Errorf("%w: unexpected status %d", ErrCustomGeoDownload, resp.StatusCode)
 	}
 	}
 
 
-	binDir := filepath.Dir(destPath)
+	binDir := filepath.Dir(safeDestPath)
 	if err = os.MkdirAll(binDir, 0o755); err != nil {
 	if err = os.MkdirAll(binDir, 0o755); err != nil {
 		return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
 		return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
 	}
 	}
 
 
-	tmpPath := destPath + ".tmp"
-	out, err := os.Create(tmpPath)
+	safeTmpPath, err := sanitizeDestPath(safeDestPath + ".tmp")
+	if err != nil {
+		return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
+	}
+	out, err := os.Create(safeTmpPath)
 	if err != nil {
 	if err != nil {
 		return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
 		return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
 	}
 	}
 	n, err := io.Copy(out, resp.Body)
 	n, err := io.Copy(out, resp.Body)
 	closeErr := out.Close()
 	closeErr := out.Close()
 	if err != nil {
 	if err != nil {
-		_ = os.Remove(tmpPath)
+		_ = os.Remove(safeTmpPath)
 		return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
 		return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
 	}
 	}
 	if closeErr != nil {
 	if closeErr != nil {
-		_ = os.Remove(tmpPath)
+		_ = os.Remove(safeTmpPath)
 		return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, closeErr)
 		return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, closeErr)
 	}
 	}
 	if n < minDatBytes {
 	if n < minDatBytes {
-		_ = os.Remove(tmpPath)
+		_ = os.Remove(safeTmpPath)
 		return false, "", fmt.Errorf("%w: file too small", ErrCustomGeoDownload)
 		return false, "", fmt.Errorf("%w: file too small", ErrCustomGeoDownload)
 	}
 	}
 
 
-	if err = os.Rename(tmpPath, destPath); err != nil {
-		_ = os.Remove(tmpPath)
+	if err = os.Rename(safeTmpPath, safeDestPath); err != nil {
+		_ = os.Remove(safeTmpPath)
 		return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
 		return false, "", fmt.Errorf("%w: %v", ErrCustomGeoDownload, err)
 	}
 	}
 
 
@@ -331,6 +451,29 @@ func (s *CustomGeoService) syncLocalPath(r *model.CustomGeoResource) {
 	r.LocalPath = p
 	r.LocalPath = p
 }
 }
 
 
+func (s *CustomGeoService) syncAndSanitizeLocalPath(r *model.CustomGeoResource) error {
+	s.syncLocalPath(r)
+	safePath, err := sanitizeDestPath(r.LocalPath)
+	if err != nil {
+		return err
+	}
+	r.LocalPath = safePath
+	return nil
+}
+
+func removeSafePathIfExists(path string) error {
+	safePath, err := sanitizeDestPath(path)
+	if err != nil {
+		return err
+	}
+	if _, err := os.Stat(safePath); err == nil {
+		if err := os.Remove(safePath); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (s *CustomGeoService) Create(r *model.CustomGeoResource) error {
 func (s *CustomGeoService) Create(r *model.CustomGeoResource) error {
 	if err := s.validateType(r.Type); err != nil {
 	if err := s.validateType(r.Type); err != nil {
 		return err
 		return err
@@ -338,16 +481,20 @@ func (s *CustomGeoService) Create(r *model.CustomGeoResource) error {
 	if err := s.validateAlias(r.Alias); err != nil {
 	if err := s.validateAlias(r.Alias); err != nil {
 		return err
 		return err
 	}
 	}
-	if err := s.validateURL(r.Url); err != nil {
+	sanitizedURL, err := s.sanitizeURL(r.Url)
+	if err != nil {
 		return err
 		return err
 	}
 	}
+	r.Url = sanitizedURL
 	var existing int64
 	var existing int64
 	database.GetDB().Model(&model.CustomGeoResource{}).
 	database.GetDB().Model(&model.CustomGeoResource{}).
 		Where("geo_type = ? AND alias = ?", r.Type, r.Alias).Count(&existing)
 		Where("geo_type = ? AND alias = ?", r.Type, r.Alias).Count(&existing)
 	if existing > 0 {
 	if existing > 0 {
 		return ErrCustomGeoDuplicateAlias
 		return ErrCustomGeoDuplicateAlias
 	}
 	}
-	s.syncLocalPath(r)
+	if err := s.syncAndSanitizeLocalPath(r); err != nil {
+		return err
+	}
 	skipped, lm, err := s.downloadToPath(r.Url, r.LocalPath, r.LastModified)
 	skipped, lm, err := s.downloadToPath(r.Url, r.LocalPath, r.LastModified)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -356,7 +503,7 @@ func (s *CustomGeoService) Create(r *model.CustomGeoResource) error {
 	r.LastUpdatedAt = now
 	r.LastUpdatedAt = now
 	r.LastModified = lm
 	r.LastModified = lm
 	if err = database.GetDB().Create(r).Error; err != nil {
 	if err = database.GetDB().Create(r).Error; err != nil {
-		_ = os.Remove(r.LocalPath)
+		_ = removeSafePathIfExists(r.LocalPath)
 		return err
 		return err
 	}
 	}
 	logger.Infof("custom geo created id=%d type=%s alias=%s skipped=%v", r.Id, r.Type, r.Alias, skipped)
 	logger.Infof("custom geo created id=%d type=%s alias=%s skipped=%v", r.Id, r.Type, r.Alias, skipped)
@@ -380,9 +527,11 @@ func (s *CustomGeoService) Update(id int, r *model.CustomGeoResource) error {
 	if err := s.validateAlias(r.Alias); err != nil {
 	if err := s.validateAlias(r.Alias); err != nil {
 		return err
 		return err
 	}
 	}
-	if err := s.validateURL(r.Url); err != nil {
+	sanitizedURL, err := s.sanitizeURL(r.Url)
+	if err != nil {
 		return err
 		return err
 	}
 	}
+	r.Url = sanitizedURL
 	if cur.Type != r.Type || cur.Alias != r.Alias {
 	if cur.Type != r.Type || cur.Alias != r.Alias {
 		var cnt int64
 		var cnt int64
 		database.GetDB().Model(&model.CustomGeoResource{}).
 		database.GetDB().Model(&model.CustomGeoResource{}).
@@ -393,12 +542,13 @@ func (s *CustomGeoService) Update(id int, r *model.CustomGeoResource) error {
 		}
 		}
 	}
 	}
 	oldPath := s.resolveDestPath(&cur)
 	oldPath := s.resolveDestPath(&cur)
-	s.syncLocalPath(r)
 	r.Id = id
 	r.Id = id
-	r.LocalPath = filepath.Join(config.GetBinFolderPath(), s.fileNameFor(r.Type, r.Alias))
+	if err := s.syncAndSanitizeLocalPath(r); err != nil {
+		return err
+	}
 	if oldPath != r.LocalPath && oldPath != "" {
 	if oldPath != r.LocalPath && oldPath != "" {
-		if _, err := os.Stat(oldPath); err == nil {
-			_ = os.Remove(oldPath)
+		if err := removeSafePathIfExists(oldPath); err != nil && !errors.Is(err, ErrCustomGeoPathTraversal) {
+			logger.Warningf("custom geo remove old path %s: %v", oldPath, err)
 		}
 		}
 	}
 	}
 	_, lm, err := s.downloadToPath(r.Url, r.LocalPath, cur.LastModified)
 	_, lm, err := s.downloadToPath(r.Url, r.LocalPath, cur.LastModified)
@@ -435,14 +585,15 @@ func (s *CustomGeoService) Delete(id int) (displayName string, err error) {
 	}
 	}
 	displayName = s.fileNameFor(r.Type, r.Alias)
 	displayName = s.fileNameFor(r.Type, r.Alias)
 	p := s.resolveDestPath(&r)
 	p := s.resolveDestPath(&r)
+	if _, err := sanitizeDestPath(p); err != nil {
+		return displayName, err
+	}
 	if err := database.GetDB().Delete(&model.CustomGeoResource{}, id).Error; err != nil {
 	if err := database.GetDB().Delete(&model.CustomGeoResource{}, id).Error; err != nil {
 		return displayName, err
 		return displayName, err
 	}
 	}
 	if p != "" {
 	if p != "" {
-		if _, err := os.Stat(p); err == nil {
-			if rmErr := os.Remove(p); rmErr != nil {
-				logger.Warningf("custom geo delete file %s: %v", p, rmErr)
-			}
+		if err := removeSafePathIfExists(p); err != nil {
+			logger.Warningf("custom geo delete file %s: %v", p, err)
 		}
 		}
 	}
 	}
 	logger.Infof("custom geo deleted id=%d", id)
 	logger.Infof("custom geo deleted id=%d", id)
@@ -467,8 +618,14 @@ func (s *CustomGeoService) applyDownloadAndPersist(id int, onStartup bool) (disp
 		return "", err
 		return "", err
 	}
 	}
 	displayName = s.fileNameFor(r.Type, r.Alias)
 	displayName = s.fileNameFor(r.Type, r.Alias)
-	s.syncLocalPath(&r)
-	skipped, lm, err := s.downloadToPath(r.Url, r.LocalPath, r.LastModified)
+	if err := s.syncAndSanitizeLocalPath(&r); err != nil {
+		return displayName, err
+	}
+	sanitizedURL, sanitizeErr := s.sanitizeURL(r.Url)
+	if sanitizeErr != nil {
+		return displayName, sanitizeErr
+	}
+	skipped, lm, err := s.downloadToPath(sanitizedURL, r.LocalPath, r.LastModified)
 	if err != nil {
 	if err != nil {
 		if onStartup {
 		if onStartup {
 			logger.Warningf("custom geo startup download id=%d: %v", id, err)
 			logger.Warningf("custom geo startup download id=%d: %v", id, err)

+ 21 - 3
web/service/custom_geo_test.go

@@ -1,6 +1,7 @@
 package service
 package service
 
 
 import (
 import (
+	"context"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
 	"net/http"
 	"net/http"
@@ -12,6 +13,15 @@ import (
 	"github.com/mhsanaei/3x-ui/v2/database/model"
 	"github.com/mhsanaei/3x-ui/v2/database/model"
 )
 )
 
 
+// disableSSRFCheck disables the SSRF guard for the duration of a test,
+// allowing httptest servers on localhost. It restores the original on cleanup.
+func disableSSRFCheck(t *testing.T) {
+	t.Helper()
+	orig := checkSSRF
+	checkSSRF = func(_ context.Context, _ string) error { return nil }
+	t.Cleanup(func() { checkSSRF = orig })
+}
+
 func TestNormalizeAliasKey(t *testing.T) {
 func TestNormalizeAliasKey(t *testing.T) {
 	if got := NormalizeAliasKey("GeoIP-IR"); got != "geoip_ir" {
 	if got := NormalizeAliasKey("GeoIP-IR"); got != "geoip_ir" {
 		t.Fatalf("got %q", got)
 		t.Fatalf("got %q", got)
@@ -139,14 +149,16 @@ func TestCustomGeoValidateAlias(t *testing.T) {
 
 
 func TestCustomGeoValidateURL(t *testing.T) {
 func TestCustomGeoValidateURL(t *testing.T) {
 	s := CustomGeoService{}
 	s := CustomGeoService{}
-	if err := s.validateURL(""); !errors.Is(err, ErrCustomGeoURLRequired) {
+	if _, err := s.sanitizeURL(""); !errors.Is(err, ErrCustomGeoURLRequired) {
 		t.Fatal("empty")
 		t.Fatal("empty")
 	}
 	}
-	if err := s.validateURL("ftp://x"); !errors.Is(err, ErrCustomGeoURLScheme) {
+	if _, err := s.sanitizeURL("ftp://x"); !errors.Is(err, ErrCustomGeoURLScheme) {
 		t.Fatal("ftp")
 		t.Fatal("ftp")
 	}
 	}
-	if err := s.validateURL("https://example.com/a.dat"); err != nil {
+	if sanitized, err := s.sanitizeURL("https://example.com/a.dat"); err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
+	} else if sanitized != "https://example.com/a.dat" {
+		t.Fatalf("unexpected sanitized URL: %s", sanitized)
 	}
 	}
 }
 }
 
 
@@ -161,6 +173,7 @@ func TestCustomGeoValidateType(t *testing.T) {
 }
 }
 
 
 func TestCustomGeoDownloadToPath(t *testing.T) {
 func TestCustomGeoDownloadToPath(t *testing.T) {
+	disableSSRFCheck(t)
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		w.Header().Set("X-Test", "1")
 		w.Header().Set("X-Test", "1")
 		if r.Header.Get("If-Modified-Since") != "" {
 		if r.Header.Get("If-Modified-Since") != "" {
@@ -193,6 +206,7 @@ func TestCustomGeoDownloadToPath(t *testing.T) {
 }
 }
 
 
 func TestCustomGeoDownloadToPath_missingLocalSendsNoIMSFromDB(t *testing.T) {
 func TestCustomGeoDownloadToPath_missingLocalSendsNoIMSFromDB(t *testing.T) {
+	disableSSRFCheck(t)
 	lm := "Wed, 21 Oct 2015 07:28:00 GMT"
 	lm := "Wed, 21 Oct 2015 07:28:00 GMT"
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		if r.Header.Get("If-Modified-Since") != "" {
 		if r.Header.Get("If-Modified-Since") != "" {
@@ -221,6 +235,7 @@ func TestCustomGeoDownloadToPath_missingLocalSendsNoIMSFromDB(t *testing.T) {
 }
 }
 
 
 func TestCustomGeoDownloadToPath_repairSkipsConditional(t *testing.T) {
 func TestCustomGeoDownloadToPath_repairSkipsConditional(t *testing.T) {
+	disableSSRFCheck(t)
 	lm := "Wed, 21 Oct 2015 07:28:00 GMT"
 	lm := "Wed, 21 Oct 2015 07:28:00 GMT"
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		if r.Header.Get("If-Modified-Since") != "" {
 		if r.Header.Get("If-Modified-Since") != "" {
@@ -264,6 +279,7 @@ func TestCustomGeoFileNameFor(t *testing.T) {
 
 
 func TestLocalDatFileNeedsRepair(t *testing.T) {
 func TestLocalDatFileNeedsRepair(t *testing.T) {
 	dir := t.TempDir()
 	dir := t.TempDir()
+	t.Setenv("XUI_BIN_FOLDER", dir)
 	if !localDatFileNeedsRepair(filepath.Join(dir, "missing.dat")) {
 	if !localDatFileNeedsRepair(filepath.Join(dir, "missing.dat")) {
 		t.Fatal("missing")
 		t.Fatal("missing")
 	}
 	}
@@ -297,6 +313,7 @@ func TestLocalDatFileNeedsRepair(t *testing.T) {
 }
 }
 
 
 func TestProbeCustomGeoURL_HEADOK(t *testing.T) {
 func TestProbeCustomGeoURL_HEADOK(t *testing.T) {
+	disableSSRFCheck(t)
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		if r.Method == http.MethodHead {
 		if r.Method == http.MethodHead {
 			w.WriteHeader(http.StatusOK)
 			w.WriteHeader(http.StatusOK)
@@ -311,6 +328,7 @@ func TestProbeCustomGeoURL_HEADOK(t *testing.T) {
 }
 }
 
 
 func TestProbeCustomGeoURL_HEAD405GETRange(t *testing.T) {
 func TestProbeCustomGeoURL_HEAD405GETRange(t *testing.T) {
+	disableSSRFCheck(t)
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		if r.Method == http.MethodHead {
 		if r.Method == http.MethodHead {
 			w.WriteHeader(http.StatusMethodNotAllowed)
 			w.WriteHeader(http.StatusMethodNotAllowed)

+ 9 - 9
web/service/setting.go

@@ -758,13 +758,13 @@ func extractHostname(host string) string {
 func (s *SettingService) GetDefaultSettings(host string) (any, error) {
 func (s *SettingService) GetDefaultSettings(host string) (any, error) {
 	type settingFunc func() (any, error)
 	type settingFunc func() (any, error)
 	settings := map[string]settingFunc{
 	settings := map[string]settingFunc{
-		"expireDiff":    func() (any, error) { return s.GetExpireDiff() },
-		"trafficDiff":   func() (any, error) { return s.GetTrafficDiff() },
-		"pageSize":      func() (any, error) { return s.GetPageSize() },
-		"defaultCert":   func() (any, error) { return s.GetCertFile() },
-		"defaultKey":    func() (any, error) { return s.GetKeyFile() },
-		"tgBotEnable":   func() (any, error) { return s.GetTgbotEnabled() },
-		"subEnable":     func() (any, error) { return s.GetSubEnable() },
+		"expireDiff":     func() (any, error) { return s.GetExpireDiff() },
+		"trafficDiff":    func() (any, error) { return s.GetTrafficDiff() },
+		"pageSize":       func() (any, error) { return s.GetPageSize() },
+		"defaultCert":    func() (any, error) { return s.GetCertFile() },
+		"defaultKey":     func() (any, error) { return s.GetKeyFile() },
+		"tgBotEnable":    func() (any, error) { return s.GetTgbotEnabled() },
+		"subEnable":      func() (any, error) { return s.GetSubEnable() },
 		"subJsonEnable":  func() (any, error) { return s.GetSubJsonEnable() },
 		"subJsonEnable":  func() (any, error) { return s.GetSubJsonEnable() },
 		"subClashEnable": func() (any, error) { return s.GetSubClashEnable() },
 		"subClashEnable": func() (any, error) { return s.GetSubClashEnable() },
 		"subTitle":       func() (any, error) { return s.GetSubTitle() },
 		"subTitle":       func() (any, error) { return s.GetSubTitle() },
@@ -772,8 +772,8 @@ func (s *SettingService) GetDefaultSettings(host string) (any, error) {
 		"subJsonURI":     func() (any, error) { return s.GetSubJsonURI() },
 		"subJsonURI":     func() (any, error) { return s.GetSubJsonURI() },
 		"subClashURI":    func() (any, error) { return s.GetSubClashURI() },
 		"subClashURI":    func() (any, error) { return s.GetSubClashURI() },
 		"remarkModel":    func() (any, error) { return s.GetRemarkModel() },
 		"remarkModel":    func() (any, error) { return s.GetRemarkModel() },
-		"datepicker":    func() (any, error) { return s.GetDatepicker() },
-		"ipLimitEnable": func() (any, error) { return s.GetIpLimitEnable() },
+		"datepicker":     func() (any, error) { return s.GetDatepicker() },
+		"ipLimitEnable":  func() (any, error) { return s.GetIpLimitEnable() },
 	}
 	}
 
 
 	result := make(map[string]any)
 	result := make(map[string]any)