瀏覽代碼

iplimit: don't count idle db-only ips toward the per-client limit

after #4083 the staleness window is 30 minutes, which still lets an ip
that stopped connecting a few minutes ago sit in the db blob and keep
the protected slot on the ascending sort. the ip that is actually
connecting right now gets classified as excess and sent to fail2ban,
and never lands in inbound_client_ips.ips so the panel doesn't show it
until you clear the log by hand.

only count ips observed in the current scan toward the limit. db-only
entries stay in the blob for display but don't participate in the ban
decision. live subset still uses the "protect oldest, ban newcomer"
rule.

closes #4091. followup to #4077.
pwnnex 2 天之前
父節點
當前提交
5f7c7c5f3d
共有 3 個文件被更改,包括 380 次插入、28 次刪除
  1. 61 28
      web/job/check_client_ip_job.go
  2. 250 0
      web/job/check_client_ip_job_integration_test.go
  3. 69 0
      web/job/check_client_ip_job_test.go

+ 61 - 28
web/job/check_client_ip_job.go

@@ -246,6 +246,37 @@ func mergeClientIps(old, new []IPWithTimestamp, staleCutoff int64) map[string]in
 	return ipMap
 }
 
+// partitionLiveIps splits the merged ip map into live (seen in the
+// current scan) and historical (only in the db blob, still inside the
+// staleness window).
+//
+// only live ips count toward the per-client limit. historical ones stay
+// in the db so the panel keeps showing them, but they must not take a
+// protected slot. the 30min cutoff alone isn't tight enough: an ip that
+// stopped connecting a few minutes ago still looks fresh to
+// mergeClientIps, and since the over-limit picker sorts ascending and
+// keeps the oldest, those idle entries used to win the slot while the
+// ip actually connecting got classified as excess and sent to fail2ban
+// every tick. see #4077 / #4091.
+//
+// both slices are sorted ascending by timestamp with the ip string as
+// a secondary key. the input comes from ranging over a map, so without
+// the tie-break equal-timestamp entries could swap order between ticks
+// and flip which ip holds the protected slot.
+func partitionLiveIps(ipMap map[string]int64, observedThisScan map[string]bool) (live, historical []IPWithTimestamp) {
+	live = make([]IPWithTimestamp, 0, len(observedThisScan))
+	historical = make([]IPWithTimestamp, 0, len(ipMap))
+	for ip, ts := range ipMap {
+		entry := IPWithTimestamp{IP: ip, Timestamp: ts}
+		if observedThisScan[ip] {
+			live = append(live, entry)
+		} else {
+			historical = append(historical, entry)
+		}
+	}
+	// shared comparator: ascending timestamp, then ascending ip for
+	// deterministic ordering on ties (map iteration order is random).
+	byTimestampThenIP := func(s []IPWithTimestamp) func(i, j int) bool {
+		return func(i, j int) bool {
+			if s[i].Timestamp != s[j].Timestamp {
+				return s[i].Timestamp < s[j].Timestamp
+			}
+			return s[i].IP < s[j].IP
+		}
+	}
+	sort.Slice(live, byTimestampThenIP(live))
+	sort.Slice(historical, byTimestampThenIP(historical))
+	return live, historical
+}
+
 func (j *CheckClientIpJob) checkFail2BanInstalled() bool {
 	cmd := "fail2ban-client"
 	args := []string{"-h"}
@@ -358,15 +389,13 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun
 	// re-observed in a while. See mergeClientIps / #4077 for why.
 	ipMap := mergeClientIps(oldIpsWithTime, newIpsWithTime, time.Now().Unix()-ipStaleAfterSeconds)
 
-	// Convert back to slice and sort by timestamp (oldest first)
-	// This ensures we always protect the original/current connections and ban new excess ones.
-	allIps := make([]IPWithTimestamp, 0, len(ipMap))
-	for ip, timestamp := range ipMap {
-		allIps = append(allIps, IPWithTimestamp{IP: ip, Timestamp: timestamp})
+	// only ips seen in this scan count toward the limit. see
+	// partitionLiveIps.
+	observedThisScan := make(map[string]bool, len(newIpsWithTime))
+	for _, ipTime := range newIpsWithTime {
+		observedThisScan[ipTime.IP] = true
 	}
-	sort.Slice(allIps, func(i, j int) bool {
-		return allIps[i].Timestamp < allIps[j].Timestamp // Ascending order (oldest first)
-	})
+	liveIps, historicalIps := partitionLiveIps(ipMap, observedThisScan)
 
 	shouldCleanLog := false
 	j.disAllowedIps = []string{}
@@ -381,35 +410,39 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun
 	log.SetOutput(logIpFile)
 	log.SetFlags(log.LstdFlags)
 
-	// Check if we exceed the limit
-	if len(allIps) > limitIp {
+	// historical db-only ips are excluded from this count on purpose.
+	var keptLive []IPWithTimestamp
+	if len(liveIps) > limitIp {
 		shouldCleanLog = true
 
-		// Keep the oldest IPs (currently active connections) and ban the new excess ones.
-		keptIps := allIps[:limitIp]
-		bannedIps := allIps[limitIp:]
+		// protect the oldest live ip, ban newcomers.
+		keptLive = liveIps[:limitIp]
+		bannedLive := liveIps[limitIp:]
 
-		// Log banned IPs in the format fail2ban filters expect: [LIMIT_IP] Email = X || Disconnecting OLD IP = Y || Timestamp = Z
-		for _, ipTime := range bannedIps {
+		// log format is load-bearing: x-ui.sh create_iplimit_jails builds
+		// filter.d/3x-ipl.conf with
+		//   failregex = \[LIMIT_IP\]\s*Email\s*=\s*<F-USER>.+</F-USER>\s*\|\|\s*Disconnecting OLD IP\s*=\s*<ADDR>\s*\|\|\s*Timestamp\s*=\s*\d+
+		// don't change the wording.
+		for _, ipTime := range bannedLive {
 			j.disAllowedIps = append(j.disAllowedIps, ipTime.IP)
 			log.Printf("[LIMIT_IP] Email = %s || Disconnecting OLD IP = %s || Timestamp = %d", clientEmail, ipTime.IP, ipTime.Timestamp)
 		}
 
-		// Actually disconnect banned IPs by temporarily removing and re-adding user
-		// This forces Xray to drop existing connections from banned IPs
-		if len(bannedIps) > 0 {
-			j.disconnectClientTemporarily(inbound, clientEmail, clients)
-		}
-
-		// Update database with only the currently active (kept) IPs
-		jsonIps, _ := json.Marshal(keptIps)
-		inboundClientIps.Ips = string(jsonIps)
+		// force xray to drop existing connections from banned ips
+		j.disconnectClientTemporarily(inbound, clientEmail, clients)
 	} else {
-		// Under limit, save all IPs
-		jsonIps, _ := json.Marshal(allIps)
-		inboundClientIps.Ips = string(jsonIps)
+		keptLive = liveIps
 	}
 
+	// keep kept-live + historical in the blob so the panel keeps showing
+	// recently seen ips. banned live ips are already in the fail2ban log
+	// and will reappear in the next scan if they reconnect.
+	dbIps := make([]IPWithTimestamp, 0, len(keptLive)+len(historicalIps))
+	dbIps = append(dbIps, keptLive...)
+	dbIps = append(dbIps, historicalIps...)
+	jsonIps, _ := json.Marshal(dbIps)
+	inboundClientIps.Ips = string(jsonIps)
+
 	db := database.GetDB()
 	err = db.Save(inboundClientIps).Error
 	if err != nil {
@@ -418,7 +451,7 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun
 	}
 
 	if len(j.disAllowedIps) > 0 {
-		logger.Infof("[LIMIT_IP] Client %s: Kept %d current IPs, queued %d new IPs for fail2ban", clientEmail, limitIp, len(j.disAllowedIps))
+		logger.Infof("[LIMIT_IP] Client %s: Kept %d live IPs, queued %d new IPs for fail2ban", clientEmail, len(keptLive), len(j.disAllowedIps))
 	}
 
 	return shouldCleanLog

+ 250 - 0
web/job/check_client_ip_job_integration_test.go

@@ -0,0 +1,250 @@
+package job
+
+import (
+	"encoding/json"
+	"log"
+	"os"
+	"path/filepath"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/mhsanaei/3x-ui/v2/database"
+	"github.com/mhsanaei/3x-ui/v2/database/model"
+	xuilogger "github.com/mhsanaei/3x-ui/v2/logger"
+	"github.com/op/go-logging"
+)
+
+// loggerInitOnce guards one-time initialisation of the shared 3x-ui
+// logger across all tests in the binary. it must fire before any code
+// path that can log a warning, because logging on the uninitialised
+// package logger panics on a nil backend.
+var loggerInitOnce sync.Once
+
+// setupIntegrationDB wires a temp sqlite db and log folder so
+// updateInboundClientIps can run end to end. it closes the db before
+// TempDir cleanup so windows doesn't complain about the file being in
+// use.
+func setupIntegrationDB(t *testing.T) {
+	t.Helper()
+
+	loggerInitOnce.Do(func() {
+		xuilogger.InitLogger(logging.ERROR)
+	})
+
+	dbDir := t.TempDir()
+	logDir := t.TempDir()
+
+	// t.Setenv restores the previous values automatically at test end.
+	t.Setenv("XUI_DB_FOLDER", dbDir)
+	t.Setenv("XUI_LOG_FOLDER", logDir)
+
+	// updateInboundClientIps calls log.SetOutput on the package global,
+	// which would leak to other tests in the same binary. snapshot and
+	// restore the stdlib logger state.
+	origLogWriter := log.Writer()
+	origLogFlags := log.Flags()
+	t.Cleanup(func() {
+		log.SetOutput(origLogWriter)
+		log.SetFlags(origLogFlags)
+	})
+
+	if err := database.InitDB(filepath.Join(dbDir, "3x-ui.db")); err != nil {
+		t.Fatalf("database.InitDB failed: %v", err)
+	}
+	// LIFO cleanup order: this runs before t.TempDir's own cleanup.
+	t.Cleanup(func() {
+		if err := database.CloseDB(); err != nil {
+			t.Logf("database.CloseDB warning: %v", err)
+		}
+	})
+}
+
+// seedInboundWithClient inserts an inbound whose settings json carries
+// a single enabled client with the given email and per-client ip limit.
+func seedInboundWithClient(t *testing.T, tag, email string, limitIp int) {
+	t.Helper()
+	settings := map[string]any{
+		"clients": []map[string]any{
+			{
+				"email":   email,
+				"limitIp": limitIp,
+				"enable":  true,
+			},
+		},
+	}
+	settingsJSON, err := json.Marshal(settings)
+	if err != nil {
+		t.Fatalf("marshal settings: %v", err)
+	}
+	// NOTE(review): port/protocol look arbitrary for these tests — the
+	// job presumably keys off email/limitIp only; confirm.
+	inbound := &model.Inbound{
+		Tag:      tag,
+		Enable:   true,
+		Protocol: model.VLESS,
+		Port:     4321,
+		Settings: string(settingsJSON),
+	}
+	if err := database.GetDB().Create(inbound).Error; err != nil {
+		t.Fatalf("seed inbound: %v", err)
+	}
+}
+
+// seedClientIps inserts an InboundClientIps row whose Ips blob is the
+// json-encoded entries, and returns the row for the test to pass on.
+func seedClientIps(t *testing.T, email string, ips []IPWithTimestamp) *model.InboundClientIps {
+	t.Helper()
+	blob, err := json.Marshal(ips)
+	if err != nil {
+		t.Fatalf("marshal ips: %v", err)
+	}
+	row := &model.InboundClientIps{
+		ClientEmail: email,
+		Ips:         string(blob),
+	}
+	if err := database.GetDB().Create(row).Error; err != nil {
+		t.Fatalf("seed InboundClientIps: %v", err)
+	}
+	return row
+}
+
+// readClientIps fetches the persisted Ips blob for email and decodes it
+// back into entries. an empty blob yields nil.
+func readClientIps(t *testing.T, email string) []IPWithTimestamp {
+	t.Helper()
+	row := &model.InboundClientIps{}
+	if err := database.GetDB().Where("client_email = ?", email).First(row).Error; err != nil {
+		t.Fatalf("read InboundClientIps for %s: %v", email, err)
+	}
+	if row.Ips == "" {
+		return nil
+	}
+	var out []IPWithTimestamp
+	if err := json.Unmarshal([]byte(row.Ips), &out); err != nil {
+		t.Fatalf("unmarshal Ips blob %q: %v", row.Ips, err)
+	}
+	return out
+}
+
+// ipSet builds an ip -> timestamp lookup so asserts don't depend on
+// slice order. if an ip repeats, the last entry wins.
+func ipSet(entries []IPWithTimestamp) map[string]int64 {
+	out := make(map[string]int64, len(entries))
+	for _, e := range entries {
+		out[e.IP] = e.Timestamp
+	}
+	return out
+}
+
+// #4091 repro: client has limit=3, db still holds 3 idle ips from a
+// few minutes ago, only one live ip is actually connecting. pre-fix:
+// live ip got banned every tick and never appeared in the panel.
+// post-fix: no ban, live ip persisted, historical ips still visible.
+func TestUpdateInboundClientIps_LiveIpNotBannedByStillFreshHistoricals(t *testing.T) {
+	setupIntegrationDB(t)
+
+	const email = "pr4091-repro"
+	seedInboundWithClient(t, "inbound-pr4091", email, 3)
+
+	now := time.Now().Unix()
+	// idle but still within the 30min staleness window.
+	row := seedClientIps(t, email, []IPWithTimestamp{
+		{IP: "10.0.0.1", Timestamp: now - 20*60},
+		{IP: "10.0.0.2", Timestamp: now - 15*60},
+		{IP: "10.0.0.3", Timestamp: now - 10*60},
+	})
+
+	j := NewCheckClientIpJob()
+	// the one that's actually connecting (user's 128.71.x.x).
+	live := []IPWithTimestamp{
+		{IP: "128.71.1.1", Timestamp: now},
+	}
+
+	shouldCleanLog := j.updateInboundClientIps(row, email, live)
+
+	if shouldCleanLog {
+		t.Fatalf("shouldCleanLog must be false, nothing should have been banned with 1 live ip under limit 3")
+	}
+	if len(j.disAllowedIps) != 0 {
+		t.Fatalf("disAllowedIps must be empty, got %v", j.disAllowedIps)
+	}
+
+	// all four ips must survive in the blob: the live one plus the
+	// three historical ones kept for panel display.
+	persisted := ipSet(readClientIps(t, email))
+	for _, want := range []string{"128.71.1.1", "10.0.0.1", "10.0.0.2", "10.0.0.3"} {
+		if _, ok := persisted[want]; !ok {
+			t.Errorf("expected %s to be persisted in inbound_client_ips.ips; got %v", want, persisted)
+		}
+	}
+	if got := persisted["128.71.1.1"]; got != now {
+		t.Errorf("live ip timestamp should match the scan timestamp %d, got %d", now, got)
+	}
+
+	// 3xipl.log must not contain a ban line.
+	if info, err := os.Stat(readIpLimitLogPath()); err == nil && info.Size() > 0 {
+		// read error deliberately ignored: body only feeds the failure message.
+		body, _ := os.ReadFile(readIpLimitLogPath())
+		t.Fatalf("3xipl.log should be empty when no ips are banned, got:\n%s", body)
+	}
+}
+
+// opposite invariant: when several ips are actually live and exceed
+// the limit, the newcomer still gets banned.
+func TestUpdateInboundClientIps_ExcessLiveIpIsStillBanned(t *testing.T) {
+	setupIntegrationDB(t)
+
+	const email = "pr4091-abuse"
+	// limit=1 so the second live ip must be classified as excess.
+	seedInboundWithClient(t, "inbound-pr4091-abuse", email, 1)
+
+	now := time.Now().Unix()
+	row := seedClientIps(t, email, []IPWithTimestamp{
+		{IP: "10.1.0.1", Timestamp: now - 60}, // original connection
+	})
+
+	j := NewCheckClientIpJob()
+	// both live, limit=1. use distinct timestamps so sort-by-timestamp
+	// is deterministic: 10.1.0.1 is the original (older), 192.0.2.9
+	// joined later and must get banned.
+	live := []IPWithTimestamp{
+		{IP: "10.1.0.1", Timestamp: now - 5},
+		{IP: "192.0.2.9", Timestamp: now},
+	}
+
+	shouldCleanLog := j.updateInboundClientIps(row, email, live)
+
+	if !shouldCleanLog {
+		t.Fatalf("shouldCleanLog must be true when the live set exceeds the limit")
+	}
+	if len(j.disAllowedIps) != 1 || j.disAllowedIps[0] != "192.0.2.9" {
+		t.Fatalf("expected 192.0.2.9 to be banned; disAllowedIps = %v", j.disAllowedIps)
+	}
+
+	// kept ip stays in the blob, banned ip must not be written back.
+	persisted := ipSet(readClientIps(t, email))
+	if _, ok := persisted["10.1.0.1"]; !ok {
+		t.Errorf("original IP 10.1.0.1 must still be persisted; got %v", persisted)
+	}
+	if _, ok := persisted["192.0.2.9"]; ok {
+		t.Errorf("banned IP 192.0.2.9 must NOT be persisted; got %v", persisted)
+	}
+
+	// 3xipl.log must contain the ban line in the exact fail2ban format.
+	body, err := os.ReadFile(readIpLimitLogPath())
+	if err != nil {
+		t.Fatalf("read 3xipl.log: %v", err)
+	}
+	wantSubstr := "[LIMIT_IP] Email = pr4091-abuse || Disconnecting OLD IP = 192.0.2.9"
+	if !contains(string(body), wantSubstr) {
+		t.Fatalf("3xipl.log missing expected ban line %q\nfull log:\n%s", wantSubstr, body)
+	}
+}
+
+// readIpLimitLogPath returns the 3xipl.log path, derived the same way
+// the job derives it via xray.GetIPLimitLogPath, but without importing
+// xray here just for the path helper (which would pull a lot more deps
+// into the test binary). The env-derived log folder is deterministic.
+// NOTE(review): the "./log" fallback duplicates the xray default —
+// keep the two in sync if that default ever changes.
+func readIpLimitLogPath() string {
+	folder := os.Getenv("XUI_LOG_FOLDER")
+	if folder == "" {
+		folder = filepath.Join(".", "log")
+	}
+	return filepath.Join(folder, "3xipl.log")
+}
+
+// contains reports whether needle occurs within haystack. kept as a
+// local name so the asserts read naturally, but delegates to the
+// stdlib instead of hand-rolling an O(n*m) scan.
+func contains(haystack, needle string) bool {
+	return strings.Contains(haystack, needle)
+}

+ 69 - 0
web/job/check_client_ip_job_test.go

@@ -75,3 +75,72 @@ func TestMergeClientIps_NoStaleCutoffStillWorks(t *testing.T) {
 		t.Fatalf("zero cutoff should keep everything\ngot:  %v\nwant: %v", got, want)
 	}
 }
+
+// collectIps projects entries down to their IP strings, preserving
+// slice order, so tests can compare against a plain []string. always
+// returns a non-nil slice (relevant for reflect.DeepEqual).
+func collectIps(entries []IPWithTimestamp) []string {
+	out := make([]string, 0, len(entries))
+	for _, e := range entries {
+		out = append(out, e.IP)
+	}
+	return out
+}
+
+func TestPartitionLiveIps_SingleLiveNotStarvedByStillFreshHistoricals(t *testing.T) {
+	// #4091: db holds A, B, C from minutes ago (still in the 30min
+	// window) but they're not connecting anymore. only D is. old code
+	// merged all four, sorted ascending, kept [A,B,C] and banned D
+	// every tick. pin the new rule: only live ips count toward the limit.
+	ipMap := map[string]int64{
+		"A": 1000,
+		"B": 1100,
+		"C": 1200,
+		"D": 2000,
+	}
+	// only D was seen by the current scan; A/B/C are db-only.
+	observed := map[string]bool{"D": true}
+
+	live, historical := partitionLiveIps(ipMap, observed)
+
+	if got := collectIps(live); !reflect.DeepEqual(got, []string{"D"}) {
+		t.Fatalf("live set should only contain the ip observed this scan\ngot:  %v\nwant: [D]", got)
+	}
+	// ascending-by-timestamp order comes from partitionLiveIps' sort.
+	if got := collectIps(historical); !reflect.DeepEqual(got, []string{"A", "B", "C"}) {
+		t.Fatalf("historical set should contain db-only ips in ascending order\ngot:  %v\nwant: [A B C]", got)
+	}
+}
+
+func TestPartitionLiveIps_ConcurrentLiveIpsStillBanNewcomers(t *testing.T) {
+	// keep the "protect original, ban newcomer" policy when several ips
+	// are really live. with limit=1, A must stay and B must be banned.
+	// note: the limit itself is applied by the caller; this test pins
+	// the ascending live ordering the caller's slicing relies on.
+	ipMap := map[string]int64{
+		"A": 5000,
+		"B": 5500,
+	}
+	observed := map[string]bool{"A": true, "B": true}
+
+	live, historical := partitionLiveIps(ipMap, observed)
+
+	if got := collectIps(live); !reflect.DeepEqual(got, []string{"A", "B"}) {
+		t.Fatalf("both live ips should be in the live set, ascending\ngot:  %v\nwant: [A B]", got)
+	}
+	if len(historical) != 0 {
+		t.Fatalf("no historical ips expected, got %v", historical)
+	}
+}
+
+func TestPartitionLiveIps_EmptyScanLeavesDbIntact(t *testing.T) {
+	// quiet tick: nothing observed => nothing live. everything merged
+	// is historical. keeps the panel from wiping recent-but-idle ips.
+	ipMap := map[string]int64{
+		"A": 1000,
+		"B": 1100,
+	}
+	// empty (non-nil) map: no ip was seen in the current scan.
+	observed := map[string]bool{}
+
+	live, historical := partitionLiveIps(ipMap, observed)
+
+	if len(live) != 0 {
+		t.Fatalf("no live ips expected, got %v", live)
+	}
+	if got := collectIps(historical); !reflect.DeepEqual(got, []string{"A", "B"}) {
+		t.Fatalf("all merged entries should flow to historical\ngot:  %v\nwant: [A B]", got)
+	}
+}