|
|
@@ -124,32 +124,187 @@ func (j *CheckClientIpJob) hasLimitIp() bool {
|
|
|
return err == nil && probe > 0
|
|
|
}
|
|
|
|
|
|
+const ipScanChunk = 400
|
|
|
+
|
|
|
+func chunkEmails(s []string, size int) [][]string {
|
|
|
+ if len(s) == 0 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ chunks := make([][]string, 0, (len(s)+size-1)/size)
|
|
|
+ for size < len(s) {
|
|
|
+ s, chunks = s[size:], append(chunks, s[:size])
|
|
|
+ }
|
|
|
+ return append(chunks, s)
|
|
|
+}
|
|
|
+
|
|
|
+// loadClientLimits maps each observed email to its clients.limit_ip in a few
|
|
|
+// chunked queries, replacing the per-email settings-JSON parse that previously
|
|
|
+// resolved the limit.
|
|
|
+func (j *CheckClientIpJob) loadClientLimits(emails []string) map[string]int {
|
|
|
+ db := database.GetDB()
|
|
|
+ out := make(map[string]int, len(emails))
|
|
|
+ for _, batch := range chunkEmails(emails, ipScanChunk) {
|
|
|
+ var rows []struct {
|
|
|
+ Email string
|
|
|
+ LimitIp int
|
|
|
+ }
|
|
|
+ if err := db.Model(&model.ClientRecord{}).
|
|
|
+ Select("email, limit_ip").
|
|
|
+ Where("email IN ?", batch).
|
|
|
+ Scan(&rows).Error; err != nil {
|
|
|
+ j.checkError(err)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ for _, r := range rows {
|
|
|
+ out[r.Email] = r.LimitIp
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return out
|
|
|
+}
|
|
|
+
|
|
|
+// loadInboundsByEmails resolves each email's owning inbound through the
|
|
|
+// clients/client_inbounds relation in chunked queries. Like the old per-email
|
|
|
+// First() it keeps the lowest inbound id when a client spans several inbounds.
|
|
|
+func (j *CheckClientIpJob) loadInboundsByEmails(emails []string) map[string]*model.Inbound {
|
|
|
+ db := database.GetDB()
|
|
|
+ minInboundByEmail := make(map[string]int, len(emails))
|
|
|
+ for _, batch := range chunkEmails(emails, ipScanChunk) {
|
|
|
+ var pairs []struct {
|
|
|
+ Email string
|
|
|
+ InboundId int
|
|
|
+ }
|
|
|
+ if err := db.Table("client_inbounds").
|
|
|
+ Select("clients.email AS email, client_inbounds.inbound_id AS inbound_id").
|
|
|
+ Joins("JOIN clients ON clients.id = client_inbounds.client_id").
|
|
|
+ Where("clients.email IN ?", batch).
|
|
|
+ Scan(&pairs).Error; err != nil {
|
|
|
+ j.checkError(err)
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ for _, p := range pairs {
|
|
|
+ if cur, ok := minInboundByEmail[p.Email]; !ok || p.InboundId < cur {
|
|
|
+ minInboundByEmail[p.Email] = p.InboundId
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if len(minInboundByEmail) == 0 {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ idSet := make(map[int]struct{}, len(minInboundByEmail))
|
|
|
+ ids := make([]int, 0, len(minInboundByEmail))
|
|
|
+ for _, id := range minInboundByEmail {
|
|
|
+ if _, seen := idSet[id]; !seen {
|
|
|
+ idSet[id] = struct{}{}
|
|
|
+ ids = append(ids, id)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ sort.Ints(ids)
|
|
|
+ inboundsById := make(map[int]*model.Inbound, len(ids))
|
|
|
+ for lo := 0; lo < len(ids); lo += ipScanChunk {
|
|
|
+ hi := min(lo+ipScanChunk, len(ids))
|
|
|
+ var page []*model.Inbound
|
|
|
+ if err := db.Model(&model.Inbound{}).Where("id IN ?", ids[lo:hi]).Find(&page).Error; err != nil {
|
|
|
+ j.checkError(err)
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ for _, ib := range page {
|
|
|
+ inboundsById[ib.Id] = ib
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ out := make(map[string]*model.Inbound, len(minInboundByEmail))
|
|
|
+ for email, id := range minInboundByEmail {
|
|
|
+ if ib, ok := inboundsById[id]; ok {
|
|
|
+ out[email] = ib
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return out
|
|
|
+}
|
|
|
+
|
|
|
+func (j *CheckClientIpJob) loadClientIpRows(emails []string) map[string]*model.InboundClientIps {
|
|
|
+ db := database.GetDB()
|
|
|
+ out := make(map[string]*model.InboundClientIps, len(emails))
|
|
|
+ for _, batch := range chunkEmails(emails, ipScanChunk) {
|
|
|
+ var rows []model.InboundClientIps
|
|
|
+ if err := db.Where("client_email IN ?", batch).Find(&rows).Error; err != nil {
|
|
|
+ j.checkError(err)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ for i := range rows {
|
|
|
+ out[rows[i].ClientEmail] = &rows[i]
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return out
|
|
|
+}
|
|
|
+
|
|
|
// processObserved runs collection + enforcement for one scan's observations
|
|
|
// (email -> ip -> last-seen unix seconds). observedAreLive marks the
|
|
|
// observations as live connections, which bypass the stale cutoff: a connection
|
|
|
// that opened hours ago is still live even though its timestamp is old. The
|
|
|
// online-stats API always reports live connections, so the job passes true.
|
|
|
+// Lookups are batched up front and all inbound_client_ips writes share one
|
|
|
+// transaction, so a scan costs a handful of queries and one fsync instead of
|
|
|
+// several per observed email.
|
|
|
func (j *CheckClientIpJob) processObserved(observed map[string]map[string]int64, enforce, observedAreLive bool) bool {
|
|
|
shouldCleanLog := false
|
|
|
now := time.Now().Unix()
|
|
|
+
|
|
|
+ emails := make([]string, 0, len(observed))
|
|
|
+ for email := range observed {
|
|
|
+ emails = append(emails, email)
|
|
|
+ }
|
|
|
+ sort.Strings(emails)
|
|
|
+
|
|
|
+ limitByEmail := j.loadClientLimits(emails)
|
|
|
+ inboundByEmail := j.loadInboundsByEmails(emails)
|
|
|
+ ipRowByEmail := j.loadClientIpRows(emails)
|
|
|
+
|
|
|
// attribution accumulates this scan's local observations per email so they can
|
|
|
// be recorded under this panel's own guid for cross-node IP attribution.
|
|
|
attribution := make(map[string][]model.ClientIpEntry, len(observed))
|
|
|
- for email, ipTimestamps := range observed {
|
|
|
+
|
|
|
+ type pendingDisconnect struct {
|
|
|
+ inbound *model.Inbound
|
|
|
+ email string
|
|
|
+ }
|
|
|
+ var disconnects []pendingDisconnect
|
|
|
+
|
|
|
+ db := database.GetDB()
|
|
|
+ tx := db.Begin()
|
|
|
+ if tx.Error != nil {
|
|
|
+ j.checkError(tx.Error)
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ committed := false
|
|
|
+ defer func() {
|
|
|
+ if !committed {
|
|
|
+ tx.Rollback()
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ for _, email := range emails {
|
|
|
+ ipTimestamps := observed[email]
|
|
|
|
|
|
// The observations can still reference a client that was just renamed
|
|
|
// or deleted; its email no longer matches any inbound. Skip it (and
|
|
|
// drop any orphaned tracking row) instead of recreating a row and
|
|
|
- // logging an ERROR every run (#4963).
|
|
|
- inbound, err := j.getInboundByEmail(email)
|
|
|
- if err != nil {
|
|
|
- if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
|
- logger.Debugf("[LimitIP] skipping stale observed email %q (renamed or deleted)", email)
|
|
|
- j.delInboundClientIps(email)
|
|
|
- } else {
|
|
|
- j.checkError(err)
|
|
|
+ // logging an ERROR every run (#4963). The batch map resolves through
|
|
|
+ // the clients relation; the per-email fallback keeps its settings LIKE
|
|
|
+ // net for clients not yet present there.
|
|
|
+ inbound, ok := inboundByEmail[email]
|
|
|
+ if !ok {
|
|
|
+ var err error
|
|
|
+ inbound, err = j.getInboundByEmail(email)
|
|
|
+ if err != nil {
|
|
|
+ if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
|
+ logger.Debugf("[LimitIP] skipping stale observed email %q (renamed or deleted)", email)
|
|
|
+ j.delInboundClientIps(tx, email)
|
|
|
+ } else {
|
|
|
+ j.checkError(err)
|
|
|
+ }
|
|
|
+ continue
|
|
|
}
|
|
|
- continue
|
|
|
}
|
|
|
|
|
|
// Convert to IPWithTimestamp slice
|
|
|
@@ -170,13 +325,44 @@ func (j *CheckClientIpJob) processObserved(observed map[string]map[string]int64,
|
|
|
attribution[email] = attrEntries
|
|
|
}
|
|
|
|
|
|
- clientIpsRecord, err := j.getInboundClientIps(email)
|
|
|
- if err != nil {
|
|
|
- _ = j.addInboundClientIps(email, ipsWithTime)
|
|
|
+ clientIpsRecord, ok := ipRowByEmail[email]
|
|
|
+ if !ok {
|
|
|
+ jsonIps, err := json.Marshal(ipsWithTime)
|
|
|
+ if err != nil {
|
|
|
+ j.checkError(err)
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if err := tx.Save(&model.InboundClientIps{ClientEmail: email, Ips: string(jsonIps)}).Error; err != nil {
|
|
|
+ j.checkError(err)
|
|
|
+ }
|
|
|
continue
|
|
|
}
|
|
|
|
|
|
- shouldCleanLog = j.updateInboundClientIps(clientIpsRecord, inbound, email, ipsWithTime, enforce, observedAreLive) || shouldCleanLog
|
|
|
+ cleaned, banned := j.updateInboundClientIps(tx, clientIpsRecord, inbound, email, limitByEmail[email], ipsWithTime, enforce, observedAreLive)
|
|
|
+ shouldCleanLog = cleaned || shouldCleanLog
|
|
|
+ if banned {
|
|
|
+ disconnects = append(disconnects, pendingDisconnect{inbound: inbound, email: email})
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := tx.Commit().Error; err != nil {
|
|
|
+ j.checkError(err)
|
|
|
+ return shouldCleanLog
|
|
|
+ }
|
|
|
+ committed = true
|
|
|
+
|
|
|
+ // Xray disconnects run after the commit so their network round-trips never
|
|
|
+ // extend the scan's write transaction (node syncs upsert the same table).
|
|
|
+ clientsCache := make(map[int][]model.Client)
|
|
|
+ for _, d := range disconnects {
|
|
|
+ clients, cached := clientsCache[d.inbound.Id]
|
|
|
+ if !cached {
|
|
|
+ settings := map[string][]model.Client{}
|
|
|
+ _ = json.Unmarshal([]byte(d.inbound.Settings), &settings)
|
|
|
+ clients = settings["clients"]
|
|
|
+ clientsCache[d.inbound.Id] = clients
|
|
|
+ }
|
|
|
+ j.disconnectClientTemporarily(d.inbound, d.email, clients)
|
|
|
}
|
|
|
|
|
|
j.recordLocalAttribution(attribution)
|
|
|
@@ -275,81 +461,34 @@ func (j *CheckClientIpJob) checkError(e error) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (j *CheckClientIpJob) getInboundClientIps(clientEmail string) (*model.InboundClientIps, error) {
|
|
|
- db := database.GetDB()
|
|
|
- InboundClientIps := &model.InboundClientIps{}
|
|
|
- err := db.Model(model.InboundClientIps{}).Where("client_email = ?", clientEmail).First(InboundClientIps).Error
|
|
|
- if err != nil {
|
|
|
- return nil, err
|
|
|
- }
|
|
|
- return InboundClientIps, nil
|
|
|
-}
|
|
|
-
|
|
|
-func (j *CheckClientIpJob) addInboundClientIps(clientEmail string, ipsWithTime []IPWithTimestamp) error {
|
|
|
- inboundClientIps := &model.InboundClientIps{}
|
|
|
- jsonIps, err := json.Marshal(ipsWithTime)
|
|
|
- j.checkError(err)
|
|
|
-
|
|
|
- inboundClientIps.ClientEmail = clientEmail
|
|
|
- inboundClientIps.Ips = string(jsonIps)
|
|
|
-
|
|
|
- db := database.GetDB()
|
|
|
- tx := db.Begin()
|
|
|
-
|
|
|
- defer func() {
|
|
|
- if err == nil {
|
|
|
- tx.Commit()
|
|
|
- } else {
|
|
|
- tx.Rollback()
|
|
|
- }
|
|
|
- }()
|
|
|
-
|
|
|
- err = tx.Save(inboundClientIps).Error
|
|
|
- if err != nil {
|
|
|
- return err
|
|
|
- }
|
|
|
- return nil
|
|
|
-}
|
|
|
-
|
|
|
// delInboundClientIps drops the inbound_client_ips tracking row for an email
|
|
|
// that no longer maps to any inbound (a renamed or deleted client), so stale
|
|
|
// access-log entries don't keep a ghost row alive (#4963).
|
|
|
-func (j *CheckClientIpJob) delInboundClientIps(clientEmail string) {
|
|
|
- db := database.GetDB()
|
|
|
- if err := db.Where("client_email = ?", clientEmail).Delete(&model.InboundClientIps{}).Error; err != nil {
|
|
|
+func (j *CheckClientIpJob) delInboundClientIps(tx *gorm.DB, clientEmail string) {
|
|
|
+ if err := tx.Where("client_email = ?", clientEmail).Delete(&model.InboundClientIps{}).Error; err != nil {
|
|
|
j.checkError(err)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.InboundClientIps, inbound *model.Inbound, clientEmail string, newIpsWithTime []IPWithTimestamp, enforce, observedAreLive bool) bool {
|
|
|
+// updateInboundClientIps merges one email's observed IPs into its tracking row
|
|
|
+// and applies the IP limit. limitIp comes from the caller (the clients table);
|
|
|
+// writes go through the caller's transaction. banned=true asks the caller to
|
|
|
+// disconnect the client after the transaction commits.
|
|
|
+func (j *CheckClientIpJob) updateInboundClientIps(tx *gorm.DB, inboundClientIps *model.InboundClientIps, inbound *model.Inbound, clientEmail string, limitIp int, newIpsWithTime []IPWithTimestamp, enforce, observedAreLive bool) (shouldCleanLog, banned bool) {
|
|
|
if inbound.Settings == "" {
|
|
|
logger.Debug("wrong data:", inbound)
|
|
|
- return false
|
|
|
- }
|
|
|
-
|
|
|
- settings := map[string][]model.Client{}
|
|
|
- _ = json.Unmarshal([]byte(inbound.Settings), &settings)
|
|
|
- clients := settings["clients"]
|
|
|
-
|
|
|
- // Find the client's IP limit
|
|
|
- var limitIp int
|
|
|
- var clientFound bool
|
|
|
- for _, client := range clients {
|
|
|
- if client.Email == clientEmail {
|
|
|
- limitIp = client.LimitIP
|
|
|
- clientFound = true
|
|
|
- break
|
|
|
- }
|
|
|
+ return false, false
|
|
|
}
|
|
|
|
|
|
- if !enforce || !clientFound || limitIp <= 0 || !inbound.Enable {
|
|
|
- // Nothing to enforce (collection-only run, no limit, client missing, or
|
|
|
- // inbound disabled): record the observed IPs for the panel and return.
|
|
|
+ if !enforce || limitIp <= 0 || !inbound.Enable {
|
|
|
+ // Nothing to enforce (collection-only run, no limit on the clients row,
|
|
|
+ // or inbound disabled): record the observed IPs for the panel and return.
|
|
|
jsonIps, _ := json.Marshal(newIpsWithTime)
|
|
|
inboundClientIps.Ips = string(jsonIps)
|
|
|
- db := database.GetDB()
|
|
|
- db.Save(inboundClientIps)
|
|
|
- return false
|
|
|
+ if err := tx.Save(inboundClientIps).Error; err != nil {
|
|
|
+ logger.Error("failed to save inboundClientIps:", err)
|
|
|
+ }
|
|
|
+ return false, false
|
|
|
}
|
|
|
|
|
|
// Parse old IPs from database
|
|
|
@@ -368,18 +507,18 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun
|
|
|
}
|
|
|
liveIps, historicalIps := partitionLiveIps(ipMap, observedThisScan)
|
|
|
|
|
|
- shouldCleanLog := false
|
|
|
j.disAllowedIps = []string{}
|
|
|
|
|
|
// historical db-only ips are excluded from this count on purpose.
|
|
|
keptLive, bannedLive := selectIpsToBan(liveIps, limitIp)
|
|
|
if len(bannedLive) > 0 {
|
|
|
shouldCleanLog = true
|
|
|
+ banned = true
|
|
|
|
|
|
logIpFile, err := os.OpenFile(xray.GetIPLimitLogPath(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
|
|
|
if err != nil {
|
|
|
logger.Errorf("failed to open IP limit log file: %s", err)
|
|
|
- return false
|
|
|
+ return false, false
|
|
|
}
|
|
|
defer logIpFile.Close()
|
|
|
ipLogger := log.New(logIpFile, "", log.LstdFlags)
|
|
|
@@ -392,9 +531,6 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun
|
|
|
j.disAllowedIps = append(j.disAllowedIps, ipTime.IP)
|
|
|
ipLogger.Printf("[LIMIT_IP] Email = %s || Disconnecting OLD IP = %s || Timestamp = %d", clientEmail, ipTime.IP, ipTime.Timestamp)
|
|
|
}
|
|
|
-
|
|
|
- // force xray to drop existing connections from banned ips
|
|
|
- j.disconnectClientTemporarily(inbound, clientEmail, clients)
|
|
|
}
|
|
|
|
|
|
// keep kept-live + historical in the blob so the panel keeps showing
|
|
|
@@ -406,18 +542,16 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun
|
|
|
jsonIps, _ := json.Marshal(dbIps)
|
|
|
inboundClientIps.Ips = string(jsonIps)
|
|
|
|
|
|
- db := database.GetDB()
|
|
|
- err := db.Save(inboundClientIps).Error
|
|
|
- if err != nil {
|
|
|
+ if err := tx.Save(inboundClientIps).Error; err != nil {
|
|
|
logger.Error("failed to save inboundClientIps:", err)
|
|
|
- return false
|
|
|
+ return false, banned
|
|
|
}
|
|
|
|
|
|
if len(j.disAllowedIps) > 0 {
|
|
|
logger.Infof("[LIMIT_IP] Client %s: Kept %d live IPs, queued %d old IPs for fail2ban", clientEmail, len(keptLive), len(j.disAllowedIps))
|
|
|
}
|
|
|
|
|
|
- return shouldCleanLog
|
|
|
+ return shouldCleanLog, banned
|
|
|
}
|
|
|
|
|
|
// disconnectClientTemporarily removes and re-adds a client to force disconnect banned connections
|