sync_scale_postgres_test.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. package service
  2. import (
  3. "errors"
  4. "fmt"
  5. "os"
  6. "strings"
  7. "testing"
  8. "time"
  9. "github.com/google/uuid"
  10. "github.com/mhsanaei/3x-ui/v3/database"
  11. "github.com/mhsanaei/3x-ui/v3/database/model"
  12. "gorm.io/gorm"
  13. )
  14. func syncInboundOld(tx *gorm.DB, inboundId int, clients []model.Client) error {
  15. if tx == nil {
  16. tx = database.GetDB()
  17. }
  18. if err := tx.Where("inbound_id = ?", inboundId).Delete(&model.ClientInbound{}).Error; err != nil {
  19. return err
  20. }
  21. for i := range clients {
  22. c := clients[i]
  23. email := strings.TrimSpace(c.Email)
  24. if email == "" {
  25. continue
  26. }
  27. incoming := c.ToRecord()
  28. row := &model.ClientRecord{}
  29. err := tx.Where("email = ?", email).First(row).Error
  30. if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
  31. return err
  32. }
  33. if errors.Is(err, gorm.ErrRecordNotFound) {
  34. if err := tx.Create(incoming).Error; err != nil {
  35. return err
  36. }
  37. row = incoming
  38. } else {
  39. row.Flow = incoming.Flow
  40. row.SubID = incoming.SubID
  41. row.LimitIP = incoming.LimitIP
  42. row.TotalGB = incoming.TotalGB
  43. row.ExpiryTime = incoming.ExpiryTime
  44. row.Enable = incoming.Enable
  45. row.TgID = incoming.TgID
  46. row.Comment = incoming.Comment
  47. row.Reset = incoming.Reset
  48. preservedUpdatedAt := max(incoming.UpdatedAt, row.UpdatedAt)
  49. row.UpdatedAt = preservedUpdatedAt
  50. if err := tx.Save(row).Error; err != nil {
  51. return err
  52. }
  53. if err := tx.Model(&model.ClientRecord{}).
  54. Where("id = ?", row.Id).
  55. UpdateColumn("updated_at", preservedUpdatedAt).Error; err != nil {
  56. return err
  57. }
  58. }
  59. link := model.ClientInbound{ClientId: row.Id, InboundId: inboundId, FlowOverride: c.Flow}
  60. if err := tx.Create(&link).Error; err != nil {
  61. return err
  62. }
  63. }
  64. return nil
  65. }
  66. func makeScaleClients(n int) []model.Client {
  67. out := make([]model.Client, n)
  68. for i := 0; i < n; i++ {
  69. out[i] = model.Client{
  70. ID: uuid.NewString(),
  71. Email: fmt.Sprintf("user-%07d@scale", i),
  72. SubID: fmt.Sprintf("sub-%07d", i),
  73. Enable: true,
  74. }
  75. }
  76. return out
  77. }
  78. func TestSyncInboundPostgresScale(t *testing.T) {
  79. if strings.TrimSpace(os.Getenv("XUI_DB_DSN")) == "" || os.Getenv("XUI_DB_TYPE") != "postgres" {
  80. t.Skip("set XUI_DB_TYPE=postgres and XUI_DB_DSN to run the postgres scale benchmark")
  81. }
  82. if err := database.InitDB(""); err != nil {
  83. t.Fatalf("InitDB: %v", err)
  84. }
  85. t.Cleanup(func() { _ = database.CloseDB() })
  86. svc := &ClientService{}
  87. sizes := []int{5000, 10000, 20000, 50000, 100000, 200000}
  88. for _, n := range sizes {
  89. t.Run(fmt.Sprintf("N=%d", n), func(t *testing.T) {
  90. db := database.GetDB()
  91. if err := db.Exec("TRUNCATE TABLE inbounds, clients, client_inbounds RESTART IDENTITY CASCADE").Error; err != nil {
  92. t.Fatalf("truncate: %v", err)
  93. }
  94. clients := makeScaleClients(n)
  95. ib := &model.Inbound{
  96. Tag: fmt.Sprintf("scale-%d", n),
  97. Enable: true,
  98. Port: 40000,
  99. Protocol: model.VLESS,
  100. Settings: clientsSettings(t, clients),
  101. }
  102. if err := db.Create(ib).Error; err != nil {
  103. t.Fatalf("create inbound: %v", err)
  104. }
  105. start := time.Now()
  106. if err := svc.SyncInbound(nil, ib.Id, clients); err != nil {
  107. t.Fatalf("seed SyncInbound: %v", err)
  108. }
  109. seed := time.Since(start)
  110. clients[n/2].Enable = !clients[n/2].Enable
  111. start = time.Now()
  112. if err := svc.SyncInbound(nil, ib.Id, clients); err != nil {
  113. t.Fatalf("toggle SyncInbound (new): %v", err)
  114. }
  115. toggleNew := time.Since(start)
  116. start = time.Now()
  117. if err := svc.SyncInbound(nil, ib.Id, clients); err != nil {
  118. t.Fatalf("noop SyncInbound (new): %v", err)
  119. }
  120. noopNew := time.Since(start)
  121. toggleOld := time.Duration(0)
  122. if n <= 10000 {
  123. clients[n/2].Enable = !clients[n/2].Enable
  124. start = time.Now()
  125. if err := syncInboundOld(db, ib.Id, clients); err != nil {
  126. t.Fatalf("toggle SyncInbound (old): %v", err)
  127. }
  128. toggleOld = time.Since(start)
  129. }
  130. var linkCount, recCount int64
  131. db.Model(&model.ClientInbound{}).Where("inbound_id = ?", ib.Id).Count(&linkCount)
  132. db.Model(&model.ClientRecord{}).Count(&recCount)
  133. if int(linkCount) != n || int(recCount) != n {
  134. t.Fatalf("row mismatch: links=%d records=%d want %d", linkCount, recCount, n)
  135. }
  136. oldStr, speedup := "skipped", ""
  137. if toggleOld > 0 {
  138. oldStr = toggleOld.Round(time.Millisecond).String()
  139. speedup = fmt.Sprintf(" speedup=%.0fx", float64(toggleOld)/float64(maxDur(toggleNew, time.Millisecond)))
  140. }
  141. t.Logf("N=%-7d seed=%-10v toggle_new=%-10v noop_new=%-10v toggle_old=%-10s%s",
  142. n, seed.Round(time.Millisecond), toggleNew.Round(time.Millisecond),
  143. noopNew.Round(time.Millisecond), oldStr, speedup)
  144. })
  145. }
  146. }
  147. func maxDur(d, floor time.Duration) time.Duration {
  148. if d < floor {
  149. return floor
  150. }
  151. return d
  152. }
  153. func TestAddDelClientPostgresScale(t *testing.T) {
  154. if strings.TrimSpace(os.Getenv("XUI_DB_DSN")) == "" || os.Getenv("XUI_DB_TYPE") != "postgres" {
  155. t.Skip("set XUI_DB_TYPE=postgres and XUI_DB_DSN to run the postgres scale benchmark")
  156. }
  157. if err := database.InitDB(""); err != nil {
  158. t.Fatalf("InitDB: %v", err)
  159. }
  160. t.Cleanup(func() { _ = database.CloseDB() })
  161. svc := &ClientService{}
  162. inboundSvc := &InboundService{}
  163. sizes := []int{5000, 20000, 50000, 100000, 200000}
  164. for _, n := range sizes {
  165. t.Run(fmt.Sprintf("N=%d", n), func(t *testing.T) {
  166. db := database.GetDB()
  167. if err := db.Exec("TRUNCATE TABLE inbounds, clients, client_inbounds, client_traffics RESTART IDENTITY CASCADE").Error; err != nil {
  168. t.Fatalf("truncate: %v", err)
  169. }
  170. clients := makeScaleClients(n)
  171. ib := &model.Inbound{
  172. Tag: fmt.Sprintf("adddel-%d", n),
  173. Enable: true,
  174. Port: 40000,
  175. Protocol: model.VLESS,
  176. Settings: clientsSettings(t, clients),
  177. }
  178. if err := db.Create(ib).Error; err != nil {
  179. t.Fatalf("create inbound: %v", err)
  180. }
  181. if err := svc.SyncInbound(nil, ib.Id, clients); err != nil {
  182. t.Fatalf("seed SyncInbound: %v", err)
  183. }
  184. newC := model.Client{
  185. ID: uuid.NewString(),
  186. Email: "added-client@scale",
  187. SubID: "added-sub",
  188. Enable: true,
  189. }
  190. addData := &model.Inbound{Id: ib.Id, Protocol: model.VLESS, Settings: clientsSettings(t, []model.Client{newC})}
  191. start := time.Now()
  192. if _, err := svc.AddInboundClient(inboundSvc, addData); err != nil {
  193. t.Fatalf("AddInboundClient: %v", err)
  194. }
  195. addDur := time.Since(start)
  196. delId := clients[n/2].ID
  197. start = time.Now()
  198. if _, err := svc.DelInboundClient(inboundSvc, ib.Id, delId, false); err != nil {
  199. t.Fatalf("DelInboundClient: %v", err)
  200. }
  201. delDur := time.Since(start)
  202. var recCount int64
  203. db.Model(&model.ClientRecord{}).Count(&recCount)
  204. if int(recCount) != n {
  205. t.Fatalf("record count after add+del = %d, want %d", recCount, n)
  206. }
  207. t.Logf("N=%-7d add=%-10v del=%-10v", n, addDur.Round(time.Millisecond), delDur.Round(time.Millisecond))
  208. })
  209. }
  210. }