Prechádzať zdrojové kódy

fix(sub): error instead of silently truncating oversized subscription (#5495)

The external subscription fetcher read the remote body with a plain
io.LimitReader, silently truncating at 2 MiB and decoding whatever
prefix arrived (possibly a half share link). Detect the overflow with
the established N+1 pattern and return an error so the caller serves the
last cached value instead of a corrupted partial list.

Co-authored-by: Sanaei <[email protected]>
n0ctal 1 deň pred
rodič
commit
67344cae6f

+ 8 - 2
internal/sub/external_subscription.go

@@ -78,14 +78,20 @@ func doFetchSubscriptionLinks(rawURL string) ([]string, error) {
 	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
 		return nil, errBadStatus
 	}
-	body, err := io.ReadAll(io.LimitReader(resp.Body, subscriptionMaxBytes))
+	body, err := io.ReadAll(io.LimitReader(resp.Body, subscriptionMaxBytes+1))
 	if err != nil {
 		return nil, err
 	}
+	if len(body) > subscriptionMaxBytes {
+		return nil, errSubscriptionBodyTooLarge
+	}
 	return decodeSubscriptionBody(body), nil
 }
 
-var errBadStatus = &subError{"non-2xx subscription response"}
+var (
+	errBadStatus                = &subError{"non-2xx subscription response"}
+	errSubscriptionBodyTooLarge = &subError{"subscription response body exceeds size limit"}
+)
 
 type subError struct{ msg string }
 

+ 43 - 0
internal/sub/external_subscription_test.go

@@ -0,0 +1,43 @@
+package sub
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+)
+
+func TestDoFetchSubscriptionLinks_RejectsOversizedBody(t *testing.T) {
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		_, _ = w.Write([]byte(strings.Repeat("a", subscriptionMaxBytes+1)))
+	}))
+	defer srv.Close()
+
+	links, err := doFetchSubscriptionLinks(srv.URL)
+	if err != errSubscriptionBodyTooLarge {
+		t.Fatalf("err = %v, want errSubscriptionBodyTooLarge", err)
+	}
+	if links != nil {
+		t.Fatalf("links = %v, want nil", links)
+	}
+}
+
+func TestDoFetchSubscriptionLinks_AcceptsBodyAtLimit(t *testing.T) {
+	link := "vless://example"
+	body := link + "\n" + strings.Repeat("#", subscriptionMaxBytes-len(link)-1)
+	if len(body) != subscriptionMaxBytes {
+		t.Fatalf("fixture size = %d, want %d", len(body), subscriptionMaxBytes)
+	}
+	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		_, _ = w.Write([]byte(body))
+	}))
+	defer srv.Close()
+
+	links, err := doFetchSubscriptionLinks(srv.URL)
+	if err != nil {
+		t.Fatalf("unexpected err: %v", err)
+	}
+	if len(links) != 1 || links[0] != link {
+		t.Fatalf("links = %v, want [%q]", links, link)
+	}
+}