Prechádzať zdrojové kódy

perf(clients): make SyncInbound bulk to fix large-inbound timeouts (#4885)

Every client mutation funnels through SyncInbound, which ran O(n) DB
round-trips per call: one SELECT per client, a Save+UpdateColumn per
client, and a per-row junction INSERT. Toggling a single client on a
large inbound issued thousands of queries and timed out, badly so on
PostgreSQL where each round-trip pays TCP latency.

SyncInbound now:
- loads existing records with a single chunked SELECT ... email IN (...)
  instead of one query per client
- writes only the records that actually changed (skips no-op Saves), so
  toggling/editing one client writes one row, not all of them
- batch-creates new records and batch-inserts the junction rows

Merge and sticky-field semantics are unchanged. Measured on PostgreSQL
16: a single-client toggle on a 50k-client inbound drops from ~8m54s to
~0.9s, and seeding 50k clients from ~2m48s to ~1.6s; 200k clients sync
in seconds.

A skip-gated benchmark (web/service/sync_scale_postgres_test.go, run
with XUI_DB_TYPE=postgres and a non-empty XUI_DB_DSN) reproduces and
verifies the scaling.
MHSanaei pred 16 hodinami
rodič
commit
756746dbca
2 zmenené súbory: 351 pridaní a 56 odobraní
  1. 117 56
      web/service/client.go
  2. 234 0
      web/service/sync_scale_postgres_test.go

+ 117 - 56
web/service/client.go

@@ -196,73 +196,134 @@ func (s *ClientService) SyncInbound(tx *gorm.DB, inboundId int, clients []model.
 		return err
 	}
 
+	emails := make([]string, 0, len(clients))
+	seen := make(map[string]struct{}, len(clients))
 	for i := range clients {
-		c := clients[i]
-		email := strings.TrimSpace(c.Email)
+		email := strings.TrimSpace(clients[i].Email)
 		if email == "" {
 			continue
 		}
+		if _, ok := seen[email]; ok {
+			continue
+		}
+		seen[email] = struct{}{}
+		emails = append(emails, email)
+	}
 
-		incoming := c.ToRecord()
-		row := &model.ClientRecord{}
-		err := tx.Where("email = ?", email).First(row).Error
-		if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
+	existing := make(map[string]*model.ClientRecord, len(emails))
+	const selectChunk = 400
+	for start := 0; start < len(emails); start += selectChunk {
+		end := min(start+selectChunk, len(emails))
+		var rows []model.ClientRecord
+		if err := tx.Where("email IN ?", emails[start:end]).Find(&rows).Error; err != nil {
 			return err
 		}
-		if errors.Is(err, gorm.ErrRecordNotFound) {
-			if err := tx.Create(incoming).Error; err != nil {
-				return err
-			}
-			row = incoming
-		} else {
-			if incoming.UUID != "" {
-				row.UUID = incoming.UUID
-			}
-			if incoming.Password != "" {
-				row.Password = incoming.Password
-			}
-			if incoming.Auth != "" {
-				row.Auth = incoming.Auth
-			}
-			row.Flow = incoming.Flow
-			if incoming.Security != "" {
-				row.Security = incoming.Security
-			}
-			if incoming.Reverse != "" {
-				row.Reverse = incoming.Reverse
-			}
-			row.SubID = incoming.SubID
-			row.LimitIP = incoming.LimitIP
-			row.TotalGB = incoming.TotalGB
-			row.ExpiryTime = incoming.ExpiryTime
-			row.Enable = incoming.Enable
-			row.TgID = incoming.TgID
-			if incoming.Group != "" {
-				row.Group = incoming.Group
-			}
-			row.Comment = incoming.Comment
-			row.Reset = incoming.Reset
-			if incoming.CreatedAt > 0 && (row.CreatedAt == 0 || incoming.CreatedAt < row.CreatedAt) {
-				row.CreatedAt = incoming.CreatedAt
-			}
-			preservedUpdatedAt := max(incoming.UpdatedAt, row.UpdatedAt)
-			row.UpdatedAt = preservedUpdatedAt
-			if err := tx.Save(row).Error; err != nil {
-				return err
-			}
-			if err := tx.Model(&model.ClientRecord{}).
-				Where("id = ?", row.Id).
-				UpdateColumn("updated_at", preservedUpdatedAt).Error; err != nil {
-				return err
+		for i := range rows {
+			r := rows[i]
+			existing[r.Email] = &r
+		}
+	}
+
+	idByEmail := make(map[string]int, len(emails))
+	pending := make(map[string]*model.ClientRecord, len(emails))
+	toCreate := make([]*model.ClientRecord, 0, len(emails))
+	for i := range clients {
+		email := strings.TrimSpace(clients[i].Email)
+		if email == "" {
+			continue
+		}
+
+		incoming := clients[i].ToRecord()
+		row, ok := existing[email]
+		if !ok {
+			if _, dup := pending[email]; !dup {
+				pending[email] = incoming
+				toCreate = append(toCreate, incoming)
 			}
+			continue
 		}
 
-		link := model.ClientInbound{
-			ClientId:     row.Id,
-			InboundId:    inboundId,
-			FlowOverride: c.Flow,
+		before := *row
+		if incoming.UUID != "" {
+			row.UUID = incoming.UUID
+		}
+		if incoming.Password != "" {
+			row.Password = incoming.Password
+		}
+		if incoming.Auth != "" {
+			row.Auth = incoming.Auth
+		}
+		row.Flow = incoming.Flow
+		if incoming.Security != "" {
+			row.Security = incoming.Security
+		}
+		if incoming.Reverse != "" {
+			row.Reverse = incoming.Reverse
 		}
-		if err := tx.Create(&link).Error; err != nil {
+		row.SubID = incoming.SubID
+		row.LimitIP = incoming.LimitIP
+		row.TotalGB = incoming.TotalGB
+		row.ExpiryTime = incoming.ExpiryTime
+		row.Enable = incoming.Enable
+		row.TgID = incoming.TgID
+		if incoming.Group != "" {
+			row.Group = incoming.Group
+		}
+		row.Comment = incoming.Comment
+		row.Reset = incoming.Reset
+		if incoming.CreatedAt > 0 && (row.CreatedAt == 0 || incoming.CreatedAt < row.CreatedAt) {
+			row.CreatedAt = incoming.CreatedAt
+		}
+		preservedUpdatedAt := max(incoming.UpdatedAt, row.UpdatedAt)
+		row.UpdatedAt = preservedUpdatedAt
+
+		idByEmail[email] = row.Id
+
+		if *row == before {
+			continue
+		}
+		if err := tx.Save(row).Error; err != nil {
+			return err
+		}
+		if err := tx.Model(&model.ClientRecord{}).
+			Where("id = ?", row.Id).
+			UpdateColumn("updated_at", preservedUpdatedAt).Error; err != nil {
+			return err
+		}
+	}
+
+	if len(toCreate) > 0 {
+		if err := tx.CreateInBatches(toCreate, 200).Error; err != nil {
+			return err
+		}
+		for _, rec := range toCreate {
+			idByEmail[rec.Email] = rec.Id
+		}
+	}
+
+	links := make([]model.ClientInbound, 0, len(clients))
+	linked := make(map[int]struct{}, len(clients))
+	for i := range clients {
+		email := strings.TrimSpace(clients[i].Email)
+		if email == "" {
+			continue
+		}
+		id, ok := idByEmail[email]
+		if !ok {
+			continue
+		}
+		if _, dup := linked[id]; dup {
+			continue
+		}
+		linked[id] = struct{}{}
+		links = append(links, model.ClientInbound{
+			ClientId:     id,
+			InboundId:    inboundId,
+			FlowOverride: clients[i].Flow,
+		})
+	}
+	if len(links) > 0 {
+		if err := tx.CreateInBatches(links, 200).Error; err != nil {
 			return err
 		}
 	}

+ 234 - 0
web/service/sync_scale_postgres_test.go

@@ -0,0 +1,234 @@
+package service
+
+import (
+	"errors"
+	"fmt"
+	"os"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/mhsanaei/3x-ui/v3/database"
+	"github.com/mhsanaei/3x-ui/v3/database/model"
+
+	"gorm.io/gorm"
+)
+
+// syncInboundOld is the per-client sync loop used as the timing baseline in
+// the scale test. It first removes every client_inbounds row for inboundId,
+// then, for each client with a non-blank email, issues one SELECT by email,
+// either a Create (record missing) or a Save followed by an UpdateColumn on
+// updated_at (record present), and finally one link INSERT.
+//
+// A nil tx falls back to database.GetDB(). The first database error aborts
+// the loop and is returned unchanged.
+func syncInboundOld(tx *gorm.DB, inboundId int, clients []model.Client) error {
+	if tx == nil {
+		tx = database.GetDB()
+	}
+	unlink := tx.Where("inbound_id = ?", inboundId).Delete(&model.ClientInbound{})
+	if unlink.Error != nil {
+		return unlink.Error
+	}
+	for idx := range clients {
+		cl := clients[idx]
+		addr := strings.TrimSpace(cl.Email)
+		if addr == "" {
+			// Clients without an email are not synced at all.
+			continue
+		}
+		fresh := cl.ToRecord()
+		stored := &model.ClientRecord{}
+		lookupErr := tx.Where("email = ?", addr).First(stored).Error
+		switch {
+		case lookupErr == nil:
+			// Record exists: overwrite the mutable fields from the incoming client.
+			stored.Flow = fresh.Flow
+			stored.SubID = fresh.SubID
+			stored.LimitIP = fresh.LimitIP
+			stored.TotalGB = fresh.TotalGB
+			stored.ExpiryTime = fresh.ExpiryTime
+			stored.Enable = fresh.Enable
+			stored.TgID = fresh.TgID
+			stored.Comment = fresh.Comment
+			stored.Reset = fresh.Reset
+			keptUpdatedAt := max(fresh.UpdatedAt, stored.UpdatedAt)
+			stored.UpdatedAt = keptUpdatedAt
+			if err := tx.Save(stored).Error; err != nil {
+				return err
+			}
+			// Write updated_at again explicitly; presumably Save rewrites the
+			// column on its own (TODO confirm against the model's tags).
+			pin := tx.Model(&model.ClientRecord{}).
+				Where("id = ?", stored.Id).
+				UpdateColumn("updated_at", keptUpdatedAt)
+			if pin.Error != nil {
+				return pin.Error
+			}
+		case errors.Is(lookupErr, gorm.ErrRecordNotFound):
+			// No record yet: insert the incoming one and link against its new id.
+			if err := tx.Create(fresh).Error; err != nil {
+				return err
+			}
+			stored = fresh
+		default:
+			return lookupErr
+		}
+		join := model.ClientInbound{ClientId: stored.Id, InboundId: inboundId, FlowOverride: cl.Flow}
+		if err := tx.Create(&join).Error; err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// makeScaleClients returns n enabled clients for the scale tests. Each client
+// gets a freshly generated UUID as its ID and an email and sub ID derived from
+// its zero-padded index, so emails and sub IDs are distinct within one call.
+func makeScaleClients(n int) []model.Client {
+	clients := make([]model.Client, n)
+	for seq := range clients {
+		id := uuid.NewString()
+		email := fmt.Sprintf("user-%07d@scale", seq)
+		subID := fmt.Sprintf("sub-%07d", seq)
+		clients[seq] = model.Client{
+			ID:     id,
+			Email:  email,
+			SubID:  subID,
+			Enable: true,
+		}
+	}
+	return clients
+}
+
+// TestSyncInboundPostgresScale times SyncInbound against PostgreSQL for
+// increasing client counts. For each size it measures the initial seed, a
+// single-client Enable toggle and a no-op re-sync; for n <= 10000 it also
+// times the same toggle through syncInboundOld and logs the speedup.
+//
+// The test is skipped unless XUI_DB_TYPE is "postgres" and XUI_DB_DSN is
+// non-empty. Every subtest truncates the inbounds, clients and
+// client_inbounds tables, so point it only at a disposable database.
+func TestSyncInboundPostgresScale(t *testing.T) {
+	if strings.TrimSpace(os.Getenv("XUI_DB_DSN")) == "" || os.Getenv("XUI_DB_TYPE") != "postgres" {
+		t.Skip("set XUI_DB_TYPE=postgres and XUI_DB_DSN to run the postgres scale benchmark")
+	}
+	if err := database.InitDB(""); err != nil {
+		t.Fatalf("InitDB: %v", err)
+	}
+	// Best-effort close: a failure here cannot change the test outcome.
+	t.Cleanup(func() { _ = database.CloseDB() })
+
+	svc := &ClientService{}
+	sizes := []int{5000, 10000, 20000, 50000, 100000, 200000}
+
+	for _, n := range sizes {
+		t.Run(fmt.Sprintf("N=%d", n), func(t *testing.T) {
+			db := database.GetDB()
+			if err := db.Exec("TRUNCATE TABLE inbounds, clients, client_inbounds RESTART IDENTITY CASCADE").Error; err != nil {
+				t.Fatalf("truncate: %v", err)
+			}
+
+			clients := makeScaleClients(n)
+			ib := &model.Inbound{
+				Tag:      fmt.Sprintf("scale-%d", n),
+				Enable:   true,
+				Port:     40000,
+				Protocol: model.VLESS,
+				Settings: clientsSettings(t, clients),
+			}
+			if err := db.Create(ib).Error; err != nil {
+				t.Fatalf("create inbound: %v", err)
+			}
+
+			// Seed: every client is new, so this exercises the create path.
+			start := time.Now()
+			if err := svc.SyncInbound(nil, ib.Id, clients); err != nil {
+				t.Fatalf("seed SyncInbound: %v", err)
+			}
+			seed := time.Since(start)
+
+			// Toggle: exactly one client differs from what is stored.
+			clients[n/2].Enable = !clients[n/2].Enable
+			start = time.Now()
+			if err := svc.SyncInbound(nil, ib.Id, clients); err != nil {
+				t.Fatalf("toggle SyncInbound (new): %v", err)
+			}
+			toggleNew := time.Since(start)
+
+			// No-op: the same client list is synced again unchanged.
+			start = time.Now()
+			if err := svc.SyncInbound(nil, ib.Id, clients); err != nil {
+				t.Fatalf("noop SyncInbound (new): %v", err)
+			}
+			noopNew := time.Since(start)
+
+			// The old per-client path is only timed for the smaller sizes.
+			toggleOld := time.Duration(0)
+			if n <= 10000 {
+				clients[n/2].Enable = !clients[n/2].Enable
+				start = time.Now()
+				if err := syncInboundOld(db, ib.Id, clients); err != nil {
+					t.Fatalf("toggle SyncInbound (old): %v", err)
+				}
+				toggleOld = time.Since(start)
+			}
+
+			// Check the count queries themselves: an ignored query error would
+			// leave both counters at zero and surface as a bogus row mismatch.
+			var linkCount, recCount int64
+			if err := db.Model(&model.ClientInbound{}).Where("inbound_id = ?", ib.Id).Count(&linkCount).Error; err != nil {
+				t.Fatalf("count links: %v", err)
+			}
+			if err := db.Model(&model.ClientRecord{}).Count(&recCount).Error; err != nil {
+				t.Fatalf("count records: %v", err)
+			}
+			if int(linkCount) != n || int(recCount) != n {
+				t.Fatalf("row mismatch: links=%d records=%d want %d", linkCount, recCount, n)
+			}
+
+			oldStr, speedup := "skipped", ""
+			if toggleOld > 0 {
+				oldStr = toggleOld.Round(time.Millisecond).String()
+				// Floor the denominator at 1ms so a very fast run cannot divide by zero.
+				speedup = fmt.Sprintf("  speedup=%.0fx", float64(toggleOld)/float64(maxDur(toggleNew, time.Millisecond)))
+			}
+			t.Logf("N=%-7d seed=%-10v toggle_new=%-10v noop_new=%-10v toggle_old=%-10s%s",
+				n, seed.Round(time.Millisecond), toggleNew.Round(time.Millisecond),
+				noopNew.Round(time.Millisecond), oldStr, speedup)
+		})
+	}
+}
+
+// maxDur returns d, raised to floor when d is smaller than floor.
+func maxDur(d, floor time.Duration) time.Duration {
+	return max(d, floor)
+}
+
+// TestAddDelClientPostgresScale times AddInboundClient and DelInboundClient on
+// an inbound that already holds n clients, for increasing n. After one add and
+// one delete it expects the client record count to be n again.
+//
+// The test is skipped unless XUI_DB_TYPE is "postgres" and XUI_DB_DSN is
+// non-empty. Every subtest truncates the inbounds, clients, client_inbounds
+// and client_traffics tables, so point it only at a disposable database.
+func TestAddDelClientPostgresScale(t *testing.T) {
+	if strings.TrimSpace(os.Getenv("XUI_DB_DSN")) == "" || os.Getenv("XUI_DB_TYPE") != "postgres" {
+		t.Skip("set XUI_DB_TYPE=postgres and XUI_DB_DSN to run the postgres scale benchmark")
+	}
+	if err := database.InitDB(""); err != nil {
+		t.Fatalf("InitDB: %v", err)
+	}
+	// Best-effort close: a failure here cannot change the test outcome.
+	t.Cleanup(func() { _ = database.CloseDB() })
+
+	svc := &ClientService{}
+	inboundSvc := &InboundService{}
+	sizes := []int{5000, 20000, 50000, 100000, 200000}
+
+	for _, n := range sizes {
+		t.Run(fmt.Sprintf("N=%d", n), func(t *testing.T) {
+			db := database.GetDB()
+			if err := db.Exec("TRUNCATE TABLE inbounds, clients, client_inbounds, client_traffics RESTART IDENTITY CASCADE").Error; err != nil {
+				t.Fatalf("truncate: %v", err)
+			}
+
+			clients := makeScaleClients(n)
+			ib := &model.Inbound{
+				Tag:      fmt.Sprintf("adddel-%d", n),
+				Enable:   true,
+				Port:     40000,
+				Protocol: model.VLESS,
+				Settings: clientsSettings(t, clients),
+			}
+			if err := db.Create(ib).Error; err != nil {
+				t.Fatalf("create inbound: %v", err)
+			}
+			// Seed the inbound; this step is not part of the timed section.
+			if err := svc.SyncInbound(nil, ib.Id, clients); err != nil {
+				t.Fatalf("seed SyncInbound: %v", err)
+			}
+
+			newC := model.Client{
+				ID:     uuid.NewString(),
+				Email:  "added-client@scale",
+				SubID:  "added-sub",
+				Enable: true,
+			}
+			addData := &model.Inbound{Id: ib.Id, Protocol: model.VLESS, Settings: clientsSettings(t, []model.Client{newC})}
+			start := time.Now()
+			if _, err := svc.AddInboundClient(inboundSvc, addData); err != nil {
+				t.Fatalf("AddInboundClient: %v", err)
+			}
+			addDur := time.Since(start)
+
+			delId := clients[n/2].ID
+			start = time.Now()
+			if _, err := svc.DelInboundClient(inboundSvc, ib.Id, delId, false); err != nil {
+				t.Fatalf("DelInboundClient: %v", err)
+			}
+			delDur := time.Since(start)
+
+			// Check the count query itself: an ignored query error would leave
+			// the counter at zero and surface as a bogus count mismatch.
+			var recCount int64
+			if err := db.Model(&model.ClientRecord{}).Count(&recCount).Error; err != nil {
+				t.Fatalf("count records: %v", err)
+			}
+			if int(recCount) != n {
+				t.Fatalf("record count after add+del = %d, want %d", recCount, n)
+			}
+
+			t.Logf("N=%-7d add=%-10v del=%-10v", n, addDur.Round(time.Millisecond), delDur.Round(time.Millisecond))
+		})
+	}
+}