Răsfoiți Sursa

fix(job): batch ip-limit per-email lookups and persistence

processObserved paid four round-trips per observed email every 10s scan:
an inbound-resolving join, a tracking-row read, an autocommit Save (one
fsync each under synchronous=FULL), and — worst of all — a full JSON
parse of the owning inbound's settings blob just to read that one
client's limitIp. On a big single inbound that parse alone made a scan
cost ~1.5s per online client.

The scan now front-loads three chunked batch queries (clients.limit_ip,
email->inbound through the client_inbounds relation keeping the lowest
inbound id like the old First(), and the tracking rows) and writes every
inbound_client_ips change inside one transaction, so M observed emails
cost a handful of queries and a single fsync. The per-email LIKE fallback
remains for emails missing from the relation, preserving the #4963
stale-email cleanup. limitIp now comes from the clients table (same
source B3 gates on) instead of the settings blob, and xray disconnects
for banned clients run after the commit so their network round-trips
never extend the write transaction node syncs contend with.
MHSanaei 1 zi în urmă
părinte
comite
c0d17e132d

+ 220 - 86
internal/web/job/check_client_ip_job.go

@@ -124,32 +124,187 @@ func (j *CheckClientIpJob) hasLimitIp() bool {
 	return err == nil && probe > 0
 	return err == nil && probe > 0
 }
 }
 
 
+const ipScanChunk = 400
+
+func chunkEmails(s []string, size int) [][]string {
+	if len(s) == 0 {
+		return nil
+	}
+	chunks := make([][]string, 0, (len(s)+size-1)/size)
+	for size < len(s) {
+		s, chunks = s[size:], append(chunks, s[:size])
+	}
+	return append(chunks, s)
+}
+
+// loadClientLimits maps each observed email to its clients.limit_ip in a few
+// chunked queries, replacing the per-email settings-JSON parse that previously
+// resolved the limit.
+func (j *CheckClientIpJob) loadClientLimits(emails []string) map[string]int {
+	db := database.GetDB()
+	out := make(map[string]int, len(emails))
+	for _, batch := range chunkEmails(emails, ipScanChunk) {
+		var rows []struct {
+			Email   string
+			LimitIp int
+		}
+		if err := db.Model(&model.ClientRecord{}).
+			Select("email, limit_ip").
+			Where("email IN ?", batch).
+			Scan(&rows).Error; err != nil {
+			j.checkError(err)
+			continue
+		}
+		for _, r := range rows {
+			out[r.Email] = r.LimitIp
+		}
+	}
+	return out
+}
+
+// loadInboundsByEmails resolves each email's owning inbound through the
+// clients/client_inbounds relation in chunked queries. Like the old per-email
+// First() it keeps the lowest inbound id when a client spans several inbounds.
+func (j *CheckClientIpJob) loadInboundsByEmails(emails []string) map[string]*model.Inbound {
+	db := database.GetDB()
+	minInboundByEmail := make(map[string]int, len(emails))
+	for _, batch := range chunkEmails(emails, ipScanChunk) {
+		var pairs []struct {
+			Email     string
+			InboundId int
+		}
+		if err := db.Table("client_inbounds").
+			Select("clients.email AS email, client_inbounds.inbound_id AS inbound_id").
+			Joins("JOIN clients ON clients.id = client_inbounds.client_id").
+			Where("clients.email IN ?", batch).
+			Scan(&pairs).Error; err != nil {
+			j.checkError(err)
+			return nil
+		}
+		for _, p := range pairs {
+			if cur, ok := minInboundByEmail[p.Email]; !ok || p.InboundId < cur {
+				minInboundByEmail[p.Email] = p.InboundId
+			}
+		}
+	}
+	if len(minInboundByEmail) == 0 {
+		return nil
+	}
+
+	idSet := make(map[int]struct{}, len(minInboundByEmail))
+	ids := make([]int, 0, len(minInboundByEmail))
+	for _, id := range minInboundByEmail {
+		if _, seen := idSet[id]; !seen {
+			idSet[id] = struct{}{}
+			ids = append(ids, id)
+		}
+	}
+	sort.Ints(ids)
+	inboundsById := make(map[int]*model.Inbound, len(ids))
+	for lo := 0; lo < len(ids); lo += ipScanChunk {
+		hi := min(lo+ipScanChunk, len(ids))
+		var page []*model.Inbound
+		if err := db.Model(&model.Inbound{}).Where("id IN ?", ids[lo:hi]).Find(&page).Error; err != nil {
+			j.checkError(err)
+			return nil
+		}
+		for _, ib := range page {
+			inboundsById[ib.Id] = ib
+		}
+	}
+
+	out := make(map[string]*model.Inbound, len(minInboundByEmail))
+	for email, id := range minInboundByEmail {
+		if ib, ok := inboundsById[id]; ok {
+			out[email] = ib
+		}
+	}
+	return out
+}
+
+func (j *CheckClientIpJob) loadClientIpRows(emails []string) map[string]*model.InboundClientIps {
+	db := database.GetDB()
+	out := make(map[string]*model.InboundClientIps, len(emails))
+	for _, batch := range chunkEmails(emails, ipScanChunk) {
+		var rows []model.InboundClientIps
+		if err := db.Where("client_email IN ?", batch).Find(&rows).Error; err != nil {
+			j.checkError(err)
+			continue
+		}
+		for i := range rows {
+			out[rows[i].ClientEmail] = &rows[i]
+		}
+	}
+	return out
+}
+
 // processObserved runs collection + enforcement for one scan's observations
 // processObserved runs collection + enforcement for one scan's observations
 // (email -> ip -> last-seen unix seconds). observedAreLive marks the
 // (email -> ip -> last-seen unix seconds). observedAreLive marks the
 // observations as live connections, which bypass the stale cutoff: a connection
 // observations as live connections, which bypass the stale cutoff: a connection
 // that opened hours ago is still live even though its timestamp is old. The
 // that opened hours ago is still live even though its timestamp is old. The
 // online-stats API always reports live connections, so the job passes true.
 // online-stats API always reports live connections, so the job passes true.
+// Lookups are batched up front and all inbound_client_ips writes share one
+// transaction, so a scan costs a handful of queries and one fsync instead of
+// several per observed email.
 func (j *CheckClientIpJob) processObserved(observed map[string]map[string]int64, enforce, observedAreLive bool) bool {
 func (j *CheckClientIpJob) processObserved(observed map[string]map[string]int64, enforce, observedAreLive bool) bool {
 	shouldCleanLog := false
 	shouldCleanLog := false
 	now := time.Now().Unix()
 	now := time.Now().Unix()
+
+	emails := make([]string, 0, len(observed))
+	for email := range observed {
+		emails = append(emails, email)
+	}
+	sort.Strings(emails)
+
+	limitByEmail := j.loadClientLimits(emails)
+	inboundByEmail := j.loadInboundsByEmails(emails)
+	ipRowByEmail := j.loadClientIpRows(emails)
+
 	// attribution accumulates this scan's local observations per email so they can
 	// attribution accumulates this scan's local observations per email so they can
 	// be recorded under this panel's own guid for cross-node IP attribution.
 	// be recorded under this panel's own guid for cross-node IP attribution.
 	attribution := make(map[string][]model.ClientIpEntry, len(observed))
 	attribution := make(map[string][]model.ClientIpEntry, len(observed))
-	for email, ipTimestamps := range observed {
+
+	type pendingDisconnect struct {
+		inbound *model.Inbound
+		email   string
+	}
+	var disconnects []pendingDisconnect
+
+	db := database.GetDB()
+	tx := db.Begin()
+	if tx.Error != nil {
+		j.checkError(tx.Error)
+		return false
+	}
+	committed := false
+	defer func() {
+		if !committed {
+			tx.Rollback()
+		}
+	}()
+
+	for _, email := range emails {
+		ipTimestamps := observed[email]
 
 
 		// The observations can still reference a client that was just renamed
 		// The observations can still reference a client that was just renamed
 		// or deleted; its email no longer matches any inbound. Skip it (and
 		// or deleted; its email no longer matches any inbound. Skip it (and
 		// drop any orphaned tracking row) instead of recreating a row and
 		// drop any orphaned tracking row) instead of recreating a row and
-		// logging an ERROR every run (#4963).
-		inbound, err := j.getInboundByEmail(email)
-		if err != nil {
-			if errors.Is(err, gorm.ErrRecordNotFound) {
-				logger.Debugf("[LimitIP] skipping stale observed email %q (renamed or deleted)", email)
-				j.delInboundClientIps(email)
-			} else {
-				j.checkError(err)
+		// logging an ERROR every run (#4963). The batch map resolves through
+		// the clients relation; the per-email fallback keeps its settings LIKE
+		// net for clients not yet present there.
+		inbound, ok := inboundByEmail[email]
+		if !ok {
+			var err error
+			inbound, err = j.getInboundByEmail(email)
+			if err != nil {
+				if errors.Is(err, gorm.ErrRecordNotFound) {
+					logger.Debugf("[LimitIP] skipping stale observed email %q (renamed or deleted)", email)
+					j.delInboundClientIps(tx, email)
+				} else {
+					j.checkError(err)
+				}
+				continue
 			}
 			}
-			continue
 		}
 		}
 
 
 		// Convert to IPWithTimestamp slice
 		// Convert to IPWithTimestamp slice
@@ -170,13 +325,44 @@ func (j *CheckClientIpJob) processObserved(observed map[string]map[string]int64,
 			attribution[email] = attrEntries
 			attribution[email] = attrEntries
 		}
 		}
 
 
-		clientIpsRecord, err := j.getInboundClientIps(email)
-		if err != nil {
-			_ = j.addInboundClientIps(email, ipsWithTime)
+		clientIpsRecord, ok := ipRowByEmail[email]
+		if !ok {
+			jsonIps, err := json.Marshal(ipsWithTime)
+			if err != nil {
+				j.checkError(err)
+				continue
+			}
+			if err := tx.Save(&model.InboundClientIps{ClientEmail: email, Ips: string(jsonIps)}).Error; err != nil {
+				j.checkError(err)
+			}
 			continue
 			continue
 		}
 		}
 
 
-		shouldCleanLog = j.updateInboundClientIps(clientIpsRecord, inbound, email, ipsWithTime, enforce, observedAreLive) || shouldCleanLog
+		cleaned, banned := j.updateInboundClientIps(tx, clientIpsRecord, inbound, email, limitByEmail[email], ipsWithTime, enforce, observedAreLive)
+		shouldCleanLog = cleaned || shouldCleanLog
+		if banned {
+			disconnects = append(disconnects, pendingDisconnect{inbound: inbound, email: email})
+		}
+	}
+
+	if err := tx.Commit().Error; err != nil {
+		j.checkError(err)
+		return shouldCleanLog
+	}
+	committed = true
+
+	// Xray disconnects run after the commit so their network round-trips never
+	// extend the scan's write transaction (node syncs upsert the same table).
+	clientsCache := make(map[int][]model.Client)
+	for _, d := range disconnects {
+		clients, cached := clientsCache[d.inbound.Id]
+		if !cached {
+			settings := map[string][]model.Client{}
+			_ = json.Unmarshal([]byte(d.inbound.Settings), &settings)
+			clients = settings["clients"]
+			clientsCache[d.inbound.Id] = clients
+		}
+		j.disconnectClientTemporarily(d.inbound, d.email, clients)
 	}
 	}
 
 
 	j.recordLocalAttribution(attribution)
 	j.recordLocalAttribution(attribution)
@@ -275,81 +461,34 @@ func (j *CheckClientIpJob) checkError(e error) {
 	}
 	}
 }
 }
 
 
-func (j *CheckClientIpJob) getInboundClientIps(clientEmail string) (*model.InboundClientIps, error) {
-	db := database.GetDB()
-	InboundClientIps := &model.InboundClientIps{}
-	err := db.Model(model.InboundClientIps{}).Where("client_email = ?", clientEmail).First(InboundClientIps).Error
-	if err != nil {
-		return nil, err
-	}
-	return InboundClientIps, nil
-}
-
-func (j *CheckClientIpJob) addInboundClientIps(clientEmail string, ipsWithTime []IPWithTimestamp) error {
-	inboundClientIps := &model.InboundClientIps{}
-	jsonIps, err := json.Marshal(ipsWithTime)
-	j.checkError(err)
-
-	inboundClientIps.ClientEmail = clientEmail
-	inboundClientIps.Ips = string(jsonIps)
-
-	db := database.GetDB()
-	tx := db.Begin()
-
-	defer func() {
-		if err == nil {
-			tx.Commit()
-		} else {
-			tx.Rollback()
-		}
-	}()
-
-	err = tx.Save(inboundClientIps).Error
-	if err != nil {
-		return err
-	}
-	return nil
-}
-
 // delInboundClientIps drops the inbound_client_ips tracking row for an email
 // delInboundClientIps drops the inbound_client_ips tracking row for an email
 // that no longer maps to any inbound (a renamed or deleted client), so stale
 // that no longer maps to any inbound (a renamed or deleted client), so stale
 // access-log entries don't keep a ghost row alive (#4963).
 // access-log entries don't keep a ghost row alive (#4963).
-func (j *CheckClientIpJob) delInboundClientIps(clientEmail string) {
-	db := database.GetDB()
-	if err := db.Where("client_email = ?", clientEmail).Delete(&model.InboundClientIps{}).Error; err != nil {
+func (j *CheckClientIpJob) delInboundClientIps(tx *gorm.DB, clientEmail string) {
+	if err := tx.Where("client_email = ?", clientEmail).Delete(&model.InboundClientIps{}).Error; err != nil {
 		j.checkError(err)
 		j.checkError(err)
 	}
 	}
 }
 }
 
 
-func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.InboundClientIps, inbound *model.Inbound, clientEmail string, newIpsWithTime []IPWithTimestamp, enforce, observedAreLive bool) bool {
+// updateInboundClientIps merges one email's observed IPs into its tracking row
+// and applies the IP limit. limitIp comes from the caller (the clients table);
+// writes go through the caller's transaction. banned=true asks the caller to
+// disconnect the client after the transaction commits.
+func (j *CheckClientIpJob) updateInboundClientIps(tx *gorm.DB, inboundClientIps *model.InboundClientIps, inbound *model.Inbound, clientEmail string, limitIp int, newIpsWithTime []IPWithTimestamp, enforce, observedAreLive bool) (shouldCleanLog, banned bool) {
 	if inbound.Settings == "" {
 	if inbound.Settings == "" {
 		logger.Debug("wrong data:", inbound)
 		logger.Debug("wrong data:", inbound)
-		return false
-	}
-
-	settings := map[string][]model.Client{}
-	_ = json.Unmarshal([]byte(inbound.Settings), &settings)
-	clients := settings["clients"]
-
-	// Find the client's IP limit
-	var limitIp int
-	var clientFound bool
-	for _, client := range clients {
-		if client.Email == clientEmail {
-			limitIp = client.LimitIP
-			clientFound = true
-			break
-		}
+		return false, false
 	}
 	}
 
 
-	if !enforce || !clientFound || limitIp <= 0 || !inbound.Enable {
-		// Nothing to enforce (collection-only run, no limit, client missing, or
-		// inbound disabled): record the observed IPs for the panel and return.
+	if !enforce || limitIp <= 0 || !inbound.Enable {
+		// Nothing to enforce (collection-only run, no limit on the clients row,
+		// or inbound disabled): record the observed IPs for the panel and return.
 		jsonIps, _ := json.Marshal(newIpsWithTime)
 		jsonIps, _ := json.Marshal(newIpsWithTime)
 		inboundClientIps.Ips = string(jsonIps)
 		inboundClientIps.Ips = string(jsonIps)
-		db := database.GetDB()
-		db.Save(inboundClientIps)
-		return false
+		if err := tx.Save(inboundClientIps).Error; err != nil {
+			logger.Error("failed to save inboundClientIps:", err)
+		}
+		return false, false
 	}
 	}
 
 
 	// Parse old IPs from database
 	// Parse old IPs from database
@@ -368,18 +507,18 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun
 	}
 	}
 	liveIps, historicalIps := partitionLiveIps(ipMap, observedThisScan)
 	liveIps, historicalIps := partitionLiveIps(ipMap, observedThisScan)
 
 
-	shouldCleanLog := false
 	j.disAllowedIps = []string{}
 	j.disAllowedIps = []string{}
 
 
 	// historical db-only ips are excluded from this count on purpose.
 	// historical db-only ips are excluded from this count on purpose.
 	keptLive, bannedLive := selectIpsToBan(liveIps, limitIp)
 	keptLive, bannedLive := selectIpsToBan(liveIps, limitIp)
 	if len(bannedLive) > 0 {
 	if len(bannedLive) > 0 {
 		shouldCleanLog = true
 		shouldCleanLog = true
+		banned = true
 
 
 		logIpFile, err := os.OpenFile(xray.GetIPLimitLogPath(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
 		logIpFile, err := os.OpenFile(xray.GetIPLimitLogPath(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
 		if err != nil {
 		if err != nil {
 			logger.Errorf("failed to open IP limit log file: %s", err)
 			logger.Errorf("failed to open IP limit log file: %s", err)
-			return false
+			return false, false
 		}
 		}
 		defer logIpFile.Close()
 		defer logIpFile.Close()
 		ipLogger := log.New(logIpFile, "", log.LstdFlags)
 		ipLogger := log.New(logIpFile, "", log.LstdFlags)
@@ -392,9 +531,6 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun
 			j.disAllowedIps = append(j.disAllowedIps, ipTime.IP)
 			j.disAllowedIps = append(j.disAllowedIps, ipTime.IP)
 			ipLogger.Printf("[LIMIT_IP] Email = %s || Disconnecting OLD IP = %s || Timestamp = %d", clientEmail, ipTime.IP, ipTime.Timestamp)
 			ipLogger.Printf("[LIMIT_IP] Email = %s || Disconnecting OLD IP = %s || Timestamp = %d", clientEmail, ipTime.IP, ipTime.Timestamp)
 		}
 		}
-
-		// force xray to drop existing connections from banned ips
-		j.disconnectClientTemporarily(inbound, clientEmail, clients)
 	}
 	}
 
 
 	// keep kept-live + historical in the blob so the panel keeps showing
 	// keep kept-live + historical in the blob so the panel keeps showing
@@ -406,18 +542,16 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun
 	jsonIps, _ := json.Marshal(dbIps)
 	jsonIps, _ := json.Marshal(dbIps)
 	inboundClientIps.Ips = string(jsonIps)
 	inboundClientIps.Ips = string(jsonIps)
 
 
-	db := database.GetDB()
-	err := db.Save(inboundClientIps).Error
-	if err != nil {
+	if err := tx.Save(inboundClientIps).Error; err != nil {
 		logger.Error("failed to save inboundClientIps:", err)
 		logger.Error("failed to save inboundClientIps:", err)
-		return false
+		return false, banned
 	}
 	}
 
 
 	if len(j.disAllowedIps) > 0 {
 	if len(j.disAllowedIps) > 0 {
 		logger.Infof("[LIMIT_IP] Client %s: Kept %d live IPs, queued %d old IPs for fail2ban", clientEmail, len(keptLive), len(j.disAllowedIps))
 		logger.Infof("[LIMIT_IP] Client %s: Kept %d live IPs, queued %d old IPs for fail2ban", clientEmail, len(keptLive), len(j.disAllowedIps))
 	}
 	}
 
 
-	return shouldCleanLog
+	return shouldCleanLog, banned
 }
 }
 
 
 // disconnectClientTemporarily removes and re-adds a client to force disconnect banned connections
 // disconnectClientTemporarily removes and re-adds a client to force disconnect banned connections

+ 9 - 3
internal/web/job/check_client_ip_job_integration_test.go

@@ -95,7 +95,7 @@ func seedInboundOnlyWithClient(t *testing.T, tag, email string, limitIp int) *mo
 func seedLinkedInboundWithClient(t *testing.T, tag, email string, limitIp int) *model.Inbound {
 func seedLinkedInboundWithClient(t *testing.T, tag, email string, limitIp int) *model.Inbound {
 	t.Helper()
 	t.Helper()
 	inbound := seedInboundOnlyWithClient(t, tag, email, limitIp)
 	inbound := seedInboundOnlyWithClient(t, tag, email, limitIp)
-	client := &model.ClientRecord{Email: email}
+	client := &model.ClientRecord{Email: email, LimitIP: limitIp}
 	if err := database.GetDB().Create(client).Error; err != nil {
 	if err := database.GetDB().Create(client).Error; err != nil {
 		t.Fatalf("seed client record: %v", err)
 		t.Fatalf("seed client record: %v", err)
 	}
 	}
@@ -206,11 +206,14 @@ func TestUpdateInboundClientIps_LiveIpNotBannedByStillFreshHistoricals(t *testin
 	if err != nil {
 	if err != nil {
 		t.Fatalf("getInboundByEmail: %v", err)
 		t.Fatalf("getInboundByEmail: %v", err)
 	}
 	}
-	shouldCleanLog := j.updateInboundClientIps(row, inbound, email, live, true, false)
+	shouldCleanLog, banned := j.updateInboundClientIps(database.GetDB(), row, inbound, email, 3, live, true, false)
 
 
 	if shouldCleanLog {
 	if shouldCleanLog {
 		t.Fatalf("shouldCleanLog must be false, nothing should have been banned with 1 live ip under limit 3")
 		t.Fatalf("shouldCleanLog must be false, nothing should have been banned with 1 live ip under limit 3")
 	}
 	}
+	if banned {
+		t.Fatalf("banned must be false with 1 live ip under limit 3")
+	}
 	if len(j.disAllowedIps) != 0 {
 	if len(j.disAllowedIps) != 0 {
 		t.Fatalf("disAllowedIps must be empty, got %v", j.disAllowedIps)
 		t.Fatalf("disAllowedIps must be empty, got %v", j.disAllowedIps)
 	}
 	}
@@ -259,11 +262,14 @@ func TestUpdateInboundClientIps_ExcessLiveIpIsStillBanned(t *testing.T) {
 	if err != nil {
 	if err != nil {
 		t.Fatalf("getInboundByEmail: %v", err)
 		t.Fatalf("getInboundByEmail: %v", err)
 	}
 	}
-	shouldCleanLog := j.updateInboundClientIps(row, inbound, email, live, true, false)
+	shouldCleanLog, banned := j.updateInboundClientIps(database.GetDB(), row, inbound, email, 1, live, true, false)
 
 
 	if !shouldCleanLog {
 	if !shouldCleanLog {
 		t.Fatalf("shouldCleanLog must be true when the live set exceeds the limit")
 		t.Fatalf("shouldCleanLog must be true when the live set exceeds the limit")
 	}
 	}
+	if !banned {
+		t.Fatalf("banned must be true when the live set exceeds the limit")
+	}
 	if len(j.disAllowedIps) != 1 || j.disAllowedIps[0] != "10.1.0.1" {
 	if len(j.disAllowedIps) != 1 || j.disAllowedIps[0] != "10.1.0.1" {
 		t.Fatalf("expected 10.1.0.1 to be banned; disAllowedIps = %v", j.disAllowedIps)
 		t.Fatalf("expected 10.1.0.1 to be banned; disAllowedIps = %v", j.disAllowedIps)
 	}
 	}