Просмотр исходного кода

fix(job): batch ip-limit per-email lookups and persistence

processObserved paid four round-trips per observed email every 10s scan:
an inbound-resolving join, a tracking-row read, an autocommit Save (one
fsync each under synchronous=FULL), and — worst of all — a full JSON
parse of the owning inbound's settings blob just to read that one
client's limitIp. On a big single inbound that parse alone made a scan
cost ~1.5s per online client.

The scan now front-loads three chunked batch queries (clients.limit_ip,
email->inbound through the client_inbounds relation keeping the lowest
inbound id like the old First(), and the tracking rows) and writes every
inbound_client_ips change inside one transaction, so M observed emails
cost a handful of queries and a single fsync. The per-email LIKE fallback
remains for emails missing from the relation, preserving the #4963
stale-email cleanup. limitIp now comes from the clients table (same
source B3 gates on) instead of the settings blob, and xray disconnects
for banned clients run after the commit so their network round-trips
never extend the write transaction node syncs contend with.
MHSanaei 1 день назад
Родитель
Коммит
c0d17e132d

+ 220 - 86
internal/web/job/check_client_ip_job.go

@@ -124,32 +124,187 @@ func (j *CheckClientIpJob) hasLimitIp() bool {
 	return err == nil && probe > 0
 }
 
+const ipScanChunk = 400
+
+// chunkEmails splits s into consecutive, order-preserving sub-slices holding at
+// most size elements each; only the last chunk may be shorter. It returns nil
+// for an empty input. A non-positive size cannot partition anything, so the
+// whole input is returned as a single chunk instead of dividing by zero or
+// slicing with a negative bound.
+//
+// The chunks alias the backing array of s (no strings are copied), but each
+// one is capacity-limited so that appending to a chunk can never overwrite the
+// elements of the chunk that follows it.
+func chunkEmails(s []string, size int) [][]string {
+	if len(s) == 0 {
+		return nil
+	}
+	if size <= 0 {
+		return [][]string{s[:len(s):len(s)]}
+	}
+	chunks := make([][]string, 0, (len(s)+size-1)/size)
+	for size < len(s) {
+		// Full slice expression: cap == len, so a later append reallocates
+		// rather than writing into the next chunk's first element.
+		chunks = append(chunks, s[:size:size])
+		s = s[size:]
+	}
+	return append(chunks, s)
+}
+
+// loadClientLimits returns email -> limit_ip as stored on the clients table
+// (model.ClientRecord) for the observed emails, using one IN query per chunk of
+// ipScanChunk emails. Emails without a clients row are absent from the result,
+// so a map lookup for them yields 0.
+//
+// A chunk whose query fails is reported through checkError and skipped; the
+// remaining chunks are still loaded, and the skipped emails are likewise
+// absent from the result.
+func (j *CheckClientIpJob) loadClientLimits(emails []string) map[string]int {
+	type limitRow struct {
+		Email   string
+		LimitIp int
+	}
+
+	conn := database.GetDB()
+	limits := make(map[string]int, len(emails))
+	for _, chunk := range chunkEmails(emails, ipScanChunk) {
+		var found []limitRow
+		err := conn.Model(&model.ClientRecord{}).
+			Select("email, limit_ip").
+			Where("email IN ?", chunk).
+			Scan(&found).Error
+		if err != nil {
+			j.checkError(err)
+			continue
+		}
+		for i := range found {
+			limits[found[i].Email] = found[i].LimitIp
+		}
+	}
+	return limits
+}
+
+// loadInboundsByEmails maps each observed email to the inbound that owns it,
+// resolved by joining client_inbounds to clients in chunks of ipScanChunk
+// emails. When a client is linked to several inbounds the one with the lowest
+// inbound id is chosen.
+//
+// It returns nil when no email resolves or when any query fails (the error is
+// reported through checkError). Emails that are not in the relation, or whose
+// inbound row cannot be loaded, are absent from the result.
+func (j *CheckClientIpJob) loadInboundsByEmails(emails []string) map[string]*model.Inbound {
+	conn := database.GetDB()
+
+	// Step 1: lowest inbound id per email.
+	lowest := make(map[string]int, len(emails))
+	for _, chunk := range chunkEmails(emails, ipScanChunk) {
+		var links []struct {
+			Email     string
+			InboundId int
+		}
+		err := conn.Table("client_inbounds").
+			Select("clients.email AS email, client_inbounds.inbound_id AS inbound_id").
+			Joins("JOIN clients ON clients.id = client_inbounds.client_id").
+			Where("clients.email IN ?", chunk).
+			Scan(&links).Error
+		if err != nil {
+			j.checkError(err)
+			return nil
+		}
+		for _, link := range links {
+			known, seen := lowest[link.Email]
+			if seen && known <= link.InboundId {
+				continue
+			}
+			lowest[link.Email] = link.InboundId
+		}
+	}
+	if len(lowest) == 0 {
+		return nil
+	}
+
+	// Step 2: distinct inbound ids, sorted so the paged IN queries are
+	// deterministic despite random map iteration order.
+	distinct := make(map[int]struct{}, len(lowest))
+	wanted := make([]int, 0, len(lowest))
+	for _, id := range lowest {
+		if _, dup := distinct[id]; dup {
+			continue
+		}
+		distinct[id] = struct{}{}
+		wanted = append(wanted, id)
+	}
+	sort.Ints(wanted)
+
+	// Step 3: load the inbound rows one page of ids at a time.
+	byID := make(map[int]*model.Inbound, len(wanted))
+	for start := 0; start < len(wanted); start += ipScanChunk {
+		end := min(start+ipScanChunk, len(wanted))
+		var fetched []*model.Inbound
+		err := conn.Model(&model.Inbound{}).Where("id IN ?", wanted[start:end]).Find(&fetched).Error
+		if err != nil {
+			j.checkError(err)
+			return nil
+		}
+		for _, inbound := range fetched {
+			byID[inbound.Id] = inbound
+		}
+	}
+
+	// Step 4: email -> inbound, dropping emails whose inbound row was not found.
+	resolved := make(map[string]*model.Inbound, len(lowest))
+	for email, id := range lowest {
+		inbound, found := byID[id]
+		if !found {
+			continue
+		}
+		resolved[email] = inbound
+	}
+	return resolved
+}
+
+// loadClientIpRows fetches the existing model.InboundClientIps tracking rows
+// for the observed emails, one IN query per chunk of ipScanChunk emails, and
+// returns them keyed by client email. Emails with no tracking row are absent
+// from the result.
+//
+// A chunk whose query fails is reported through checkError and skipped.
+// NOTE(review): the emails of a skipped chunk then look untracked to the
+// caller, which saves a fresh row for them — confirm that is acceptable.
+func (j *CheckClientIpJob) loadClientIpRows(emails []string) map[string]*model.InboundClientIps {
+	conn := database.GetDB()
+	tracked := make(map[string]*model.InboundClientIps, len(emails))
+	for _, chunk := range chunkEmails(emails, ipScanChunk) {
+		var found []model.InboundClientIps
+		if err := conn.Where("client_email IN ?", chunk).Find(&found).Error; err != nil {
+			j.checkError(err)
+			continue
+		}
+		for idx := range found {
+			// Point at the slice element itself; found is re-declared per
+			// chunk, so the pointers stay valid and distinct.
+			row := &found[idx]
+			tracked[row.ClientEmail] = row
+		}
+	}
+	return tracked
+}
+
 // processObserved runs collection + enforcement for one scan's observations
 // (email -> ip -> last-seen unix seconds). observedAreLive marks the
 // observations as live connections, which bypass the stale cutoff: a connection
 // that opened hours ago is still live even though its timestamp is old. The
 // online-stats API always reports live connections, so the job passes true.
+// Lookups are batched up front and all inbound_client_ips writes share one
+// transaction, so a scan costs a handful of queries and one fsync instead of
+// several per observed email.
 func (j *CheckClientIpJob) processObserved(observed map[string]map[string]int64, enforce, observedAreLive bool) bool {
 	shouldCleanLog := false
 	now := time.Now().Unix()
+
+	emails := make([]string, 0, len(observed))
+	for email := range observed {
+		emails = append(emails, email)
+	}
+	sort.Strings(emails)
+
+	limitByEmail := j.loadClientLimits(emails)
+	inboundByEmail := j.loadInboundsByEmails(emails)
+	ipRowByEmail := j.loadClientIpRows(emails)
+
 	// attribution accumulates this scan's local observations per email so they can
 	// be recorded under this panel's own guid for cross-node IP attribution.
 	attribution := make(map[string][]model.ClientIpEntry, len(observed))
-	for email, ipTimestamps := range observed {
+
+	type pendingDisconnect struct {
+		inbound *model.Inbound
+		email   string
+	}
+	var disconnects []pendingDisconnect
+
+	db := database.GetDB()
+	tx := db.Begin()
+	if tx.Error != nil {
+		j.checkError(tx.Error)
+		return false
+	}
+	committed := false
+	defer func() {
+		if !committed {
+			tx.Rollback()
+		}
+	}()
+
+	for _, email := range emails {
+		ipTimestamps := observed[email]
 
 		// The observations can still reference a client that was just renamed
 		// or deleted; its email no longer matches any inbound. Skip it (and
 		// drop any orphaned tracking row) instead of recreating a row and
-		// logging an ERROR every run (#4963).
-		inbound, err := j.getInboundByEmail(email)
-		if err != nil {
-			if errors.Is(err, gorm.ErrRecordNotFound) {
-				logger.Debugf("[LimitIP] skipping stale observed email %q (renamed or deleted)", email)
-				j.delInboundClientIps(email)
-			} else {
-				j.checkError(err)
+		// logging an ERROR every run (#4963). The batch map resolves through
+		// the clients relation; the per-email fallback keeps its settings LIKE
+		// net for clients not yet present there.
+		inbound, ok := inboundByEmail[email]
+		if !ok {
+			var err error
+			inbound, err = j.getInboundByEmail(email)
+			if err != nil {
+				if errors.Is(err, gorm.ErrRecordNotFound) {
+					logger.Debugf("[LimitIP] skipping stale observed email %q (renamed or deleted)", email)
+					j.delInboundClientIps(tx, email)
+				} else {
+					j.checkError(err)
+				}
+				continue
 			}
-			continue
 		}
 
 		// Convert to IPWithTimestamp slice
@@ -170,13 +325,44 @@ func (j *CheckClientIpJob) processObserved(observed map[string]map[string]int64,
 			attribution[email] = attrEntries
 		}
 
-		clientIpsRecord, err := j.getInboundClientIps(email)
-		if err != nil {
-			_ = j.addInboundClientIps(email, ipsWithTime)
+		clientIpsRecord, ok := ipRowByEmail[email]
+		if !ok {
+			jsonIps, err := json.Marshal(ipsWithTime)
+			if err != nil {
+				j.checkError(err)
+				continue
+			}
+			if err := tx.Save(&model.InboundClientIps{ClientEmail: email, Ips: string(jsonIps)}).Error; err != nil {
+				j.checkError(err)
+			}
 			continue
 		}
 
-		shouldCleanLog = j.updateInboundClientIps(clientIpsRecord, inbound, email, ipsWithTime, enforce, observedAreLive) || shouldCleanLog
+		cleaned, banned := j.updateInboundClientIps(tx, clientIpsRecord, inbound, email, limitByEmail[email], ipsWithTime, enforce, observedAreLive)
+		shouldCleanLog = cleaned || shouldCleanLog
+		if banned {
+			disconnects = append(disconnects, pendingDisconnect{inbound: inbound, email: email})
+		}
+	}
+
+	if err := tx.Commit().Error; err != nil {
+		j.checkError(err)
+		return shouldCleanLog
+	}
+	committed = true
+
+	// Xray disconnects run after the commit so their network round-trips never
+	// extend the scan's write transaction (node syncs upsert the same table).
+	clientsCache := make(map[int][]model.Client)
+	for _, d := range disconnects {
+		clients, cached := clientsCache[d.inbound.Id]
+		if !cached {
+			settings := map[string][]model.Client{}
+			_ = json.Unmarshal([]byte(d.inbound.Settings), &settings)
+			clients = settings["clients"]
+			clientsCache[d.inbound.Id] = clients
+		}
+		j.disconnectClientTemporarily(d.inbound, d.email, clients)
 	}
 
 	j.recordLocalAttribution(attribution)
@@ -275,81 +461,34 @@ func (j *CheckClientIpJob) checkError(e error) {
 	}
 }
 
-func (j *CheckClientIpJob) getInboundClientIps(clientEmail string) (*model.InboundClientIps, error) {
-	db := database.GetDB()
-	InboundClientIps := &model.InboundClientIps{}
-	err := db.Model(model.InboundClientIps{}).Where("client_email = ?", clientEmail).First(InboundClientIps).Error
-	if err != nil {
-		return nil, err
-	}
-	return InboundClientIps, nil
-}
-
-func (j *CheckClientIpJob) addInboundClientIps(clientEmail string, ipsWithTime []IPWithTimestamp) error {
-	inboundClientIps := &model.InboundClientIps{}
-	jsonIps, err := json.Marshal(ipsWithTime)
-	j.checkError(err)
-
-	inboundClientIps.ClientEmail = clientEmail
-	inboundClientIps.Ips = string(jsonIps)
-
-	db := database.GetDB()
-	tx := db.Begin()
-
-	defer func() {
-		if err == nil {
-			tx.Commit()
-		} else {
-			tx.Rollback()
-		}
-	}()
-
-	err = tx.Save(inboundClientIps).Error
-	if err != nil {
-		return err
-	}
-	return nil
-}
-
 // delInboundClientIps drops the inbound_client_ips tracking row for an email
 // that no longer maps to any inbound (a renamed or deleted client), so stale
 // access-log entries don't keep a ghost row alive (#4963).
-func (j *CheckClientIpJob) delInboundClientIps(clientEmail string) {
-	db := database.GetDB()
-	if err := db.Where("client_email = ?", clientEmail).Delete(&model.InboundClientIps{}).Error; err != nil {
+func (j *CheckClientIpJob) delInboundClientIps(tx *gorm.DB, clientEmail string) {
+	if err := tx.Where("client_email = ?", clientEmail).Delete(&model.InboundClientIps{}).Error; err != nil {
 		j.checkError(err)
 	}
 }
 
-func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.InboundClientIps, inbound *model.Inbound, clientEmail string, newIpsWithTime []IPWithTimestamp, enforce, observedAreLive bool) bool {
+// updateInboundClientIps merges one email's observed IPs into its tracking row
+// and applies the IP limit. limitIp comes from the caller (the clients table);
+// writes go through the caller's transaction. banned=true asks the caller to
+// disconnect the client after the transaction commits.
+func (j *CheckClientIpJob) updateInboundClientIps(tx *gorm.DB, inboundClientIps *model.InboundClientIps, inbound *model.Inbound, clientEmail string, limitIp int, newIpsWithTime []IPWithTimestamp, enforce, observedAreLive bool) (shouldCleanLog, banned bool) {
 	if inbound.Settings == "" {
 		logger.Debug("wrong data:", inbound)
-		return false
-	}
-
-	settings := map[string][]model.Client{}
-	_ = json.Unmarshal([]byte(inbound.Settings), &settings)
-	clients := settings["clients"]
-
-	// Find the client's IP limit
-	var limitIp int
-	var clientFound bool
-	for _, client := range clients {
-		if client.Email == clientEmail {
-			limitIp = client.LimitIP
-			clientFound = true
-			break
-		}
+		return false, false
 	}
 
-	if !enforce || !clientFound || limitIp <= 0 || !inbound.Enable {
-		// Nothing to enforce (collection-only run, no limit, client missing, or
-		// inbound disabled): record the observed IPs for the panel and return.
+	if !enforce || limitIp <= 0 || !inbound.Enable {
+		// Nothing to enforce (collection-only run, no limit on the clients row,
+		// or inbound disabled): record the observed IPs for the panel and return.
 		jsonIps, _ := json.Marshal(newIpsWithTime)
 		inboundClientIps.Ips = string(jsonIps)
-		db := database.GetDB()
-		db.Save(inboundClientIps)
-		return false
+		if err := tx.Save(inboundClientIps).Error; err != nil {
+			logger.Error("failed to save inboundClientIps:", err)
+		}
+		return false, false
 	}
 
 	// Parse old IPs from database
@@ -368,18 +507,18 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun
 	}
 	liveIps, historicalIps := partitionLiveIps(ipMap, observedThisScan)
 
-	shouldCleanLog := false
 	j.disAllowedIps = []string{}
 
 	// historical db-only ips are excluded from this count on purpose.
 	keptLive, bannedLive := selectIpsToBan(liveIps, limitIp)
 	if len(bannedLive) > 0 {
 		shouldCleanLog = true
+		banned = true
 
 		logIpFile, err := os.OpenFile(xray.GetIPLimitLogPath(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
 		if err != nil {
 			logger.Errorf("failed to open IP limit log file: %s", err)
-			return false
+			return false, false
 		}
 		defer logIpFile.Close()
 		ipLogger := log.New(logIpFile, "", log.LstdFlags)
@@ -392,9 +531,6 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun
 			j.disAllowedIps = append(j.disAllowedIps, ipTime.IP)
 			ipLogger.Printf("[LIMIT_IP] Email = %s || Disconnecting OLD IP = %s || Timestamp = %d", clientEmail, ipTime.IP, ipTime.Timestamp)
 		}
-
-		// force xray to drop existing connections from banned ips
-		j.disconnectClientTemporarily(inbound, clientEmail, clients)
 	}
 
 	// keep kept-live + historical in the blob so the panel keeps showing
@@ -406,18 +542,16 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun
 	jsonIps, _ := json.Marshal(dbIps)
 	inboundClientIps.Ips = string(jsonIps)
 
-	db := database.GetDB()
-	err := db.Save(inboundClientIps).Error
-	if err != nil {
+	if err := tx.Save(inboundClientIps).Error; err != nil {
 		logger.Error("failed to save inboundClientIps:", err)
-		return false
+		return false, banned
 	}
 
 	if len(j.disAllowedIps) > 0 {
 		logger.Infof("[LIMIT_IP] Client %s: Kept %d live IPs, queued %d old IPs for fail2ban", clientEmail, len(keptLive), len(j.disAllowedIps))
 	}
 
-	return shouldCleanLog
+	return shouldCleanLog, banned
 }
 
 // disconnectClientTemporarily removes and re-adds a client to force disconnect banned connections

+ 9 - 3
internal/web/job/check_client_ip_job_integration_test.go

@@ -95,7 +95,7 @@ func seedInboundOnlyWithClient(t *testing.T, tag, email string, limitIp int) *mo
 func seedLinkedInboundWithClient(t *testing.T, tag, email string, limitIp int) *model.Inbound {
 	t.Helper()
 	inbound := seedInboundOnlyWithClient(t, tag, email, limitIp)
-	client := &model.ClientRecord{Email: email}
+	client := &model.ClientRecord{Email: email, LimitIP: limitIp}
 	if err := database.GetDB().Create(client).Error; err != nil {
 		t.Fatalf("seed client record: %v", err)
 	}
@@ -206,11 +206,14 @@ func TestUpdateInboundClientIps_LiveIpNotBannedByStillFreshHistoricals(t *testin
 	if err != nil {
 		t.Fatalf("getInboundByEmail: %v", err)
 	}
-	shouldCleanLog := j.updateInboundClientIps(row, inbound, email, live, true, false)
+	shouldCleanLog, banned := j.updateInboundClientIps(database.GetDB(), row, inbound, email, 3, live, true, false)
 
 	if shouldCleanLog {
 		t.Fatalf("shouldCleanLog must be false, nothing should have been banned with 1 live ip under limit 3")
 	}
+	if banned {
+		t.Fatalf("banned must be false with 1 live ip under limit 3")
+	}
 	if len(j.disAllowedIps) != 0 {
 		t.Fatalf("disAllowedIps must be empty, got %v", j.disAllowedIps)
 	}
@@ -259,11 +262,14 @@ func TestUpdateInboundClientIps_ExcessLiveIpIsStillBanned(t *testing.T) {
 	if err != nil {
 		t.Fatalf("getInboundByEmail: %v", err)
 	}
-	shouldCleanLog := j.updateInboundClientIps(row, inbound, email, live, true, false)
+	shouldCleanLog, banned := j.updateInboundClientIps(database.GetDB(), row, inbound, email, 1, live, true, false)
 
 	if !shouldCleanLog {
 		t.Fatalf("shouldCleanLog must be true when the live set exceeds the limit")
 	}
+	if !banned {
+		t.Fatalf("banned must be true when the live set exceeds the limit")
+	}
 	if len(j.disAllowedIps) != 1 || j.disAllowedIps[0] != "10.1.0.1" {
 		t.Fatalf("expected 10.1.0.1 to be banned; disAllowedIps = %v", j.disAllowedIps)
 	}