wireguard_migration_test.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. package database
  2. import (
  3. "encoding/json"
  4. "path/filepath"
  5. "testing"
  6. "github.com/mhsanaei/3x-ui/v3/internal/database/model"
  7. )
  8. func initWGMigrationDB(t *testing.T) {
  9. t.Helper()
  10. dbDir := t.TempDir()
  11. t.Setenv("XUI_DB_FOLDER", dbDir)
  12. if err := InitDB(filepath.Join(dbDir, "x-ui.db")); err != nil {
  13. t.Fatalf("InitDB failed: %v", err)
  14. }
  15. t.Cleanup(func() { _ = CloseDB() })
  16. }
  17. func createWGInbound(t *testing.T, remark string, port int, peers []any) *model.Inbound {
  18. t.Helper()
  19. settings, err := json.Marshal(map[string]any{
  20. "secretKey": "c2VjcmV0LWtleS1iYXNlNjQtMzJieXRlcy1wbGFjZWg=",
  21. "mtu": 1420,
  22. "peers": peers,
  23. })
  24. if err != nil {
  25. t.Fatalf("marshal settings: %v", err)
  26. }
  27. in := &model.Inbound{
  28. UserId: 1,
  29. Remark: remark,
  30. Port: port,
  31. Protocol: model.WireGuard,
  32. Settings: string(settings),
  33. Tag: remark,
  34. }
  35. if err := db.Create(in).Error; err != nil {
  36. t.Fatalf("create wg inbound: %v", err)
  37. }
  38. return in
  39. }
  40. func clearWGMigrationHistory(t *testing.T) {
  41. t.Helper()
  42. if err := db.Where("seeder_name = ?", "WireguardPeersToClients").Delete(&model.HistoryOfSeeders{}).Error; err != nil {
  43. t.Fatalf("clear history: %v", err)
  44. }
  45. }
  46. func reloadInboundSettings(t *testing.T, id int) map[string]any {
  47. t.Helper()
  48. var in model.Inbound
  49. if err := db.First(&in, id).Error; err != nil {
  50. t.Fatalf("reload inbound: %v", err)
  51. }
  52. var settings map[string]any
  53. if err := json.Unmarshal([]byte(in.Settings), &settings); err != nil {
  54. t.Fatalf("unmarshal settings: %v", err)
  55. }
  56. return settings
  57. }
  // wgPeer builds one pre-migration WireGuard peer object as stored under
  // settings.peers. The object holds the key pair, a single-element
  // allowedIPs list and the keepAlive value. The "comment" key is added only
  // when comment is non-empty.
  func wgPeer(comment, priv, pub, ip string, keepAlive int) any {
  	peer := make(map[string]any, 5)
  	peer["privateKey"] = priv
  	peer["publicKey"] = pub
  	peer["allowedIPs"] = []any{ip}
  	peer["keepAlive"] = keepAlive
  	if comment == "" {
  		return peer
  	}
  	peer["comment"] = comment
  	return peer
  }
  70. func TestSeedWireguardPeersToClientsCreatesClients(t *testing.T) {
  71. initWGMigrationDB(t)
  72. in := createWGInbound(t, "wg-server", 51820, []any{
  73. wgPeer("laptop", "priv-1", "pub-1", "10.0.0.2/32", 25),
  74. })
  75. clearWGMigrationHistory(t)
  76. if err := seedWireguardPeersToClients(); err != nil {
  77. t.Fatalf("seedWireguardPeersToClients: %v", err)
  78. }
  79. var rec model.ClientRecord
  80. if err := db.Where("email = ?", "wg-server-laptop").First(&rec).Error; err != nil {
  81. t.Fatalf("migrated client not found: %v", err)
  82. }
  83. if rec.PrivateKey != "priv-1" || rec.PublicKey != "pub-1" || rec.AllowedIPs != "10.0.0.2/32" {
  84. t.Fatalf("wg columns not migrated: %+v", rec)
  85. }
  86. var linkCount int64
  87. db.Model(&model.ClientInbound{}).Where("inbound_id = ? AND client_id = ?", in.Id, rec.Id).Count(&linkCount)
  88. if linkCount != 1 {
  89. t.Fatalf("expected 1 client_inbounds link, got %d", linkCount)
  90. }
  91. settings := reloadInboundSettings(t, in.Id)
  92. if _, ok := settings["peers"]; ok {
  93. t.Fatalf("peers key must be removed from stored settings")
  94. }
  95. clients, ok := settings["clients"].([]any)
  96. if !ok || len(clients) != 1 {
  97. t.Fatalf("settings.clients not written: %v", settings["clients"])
  98. }
  99. if settings["secretKey"] == nil || settings["mtu"] == nil {
  100. t.Fatalf("server fields not preserved: %v", settings)
  101. }
  102. }
  103. func TestSeedWireguardPeersToClientsIdempotent(t *testing.T) {
  104. initWGMigrationDB(t)
  105. in := createWGInbound(t, "wg-idem", 51823, []any{
  106. wgPeer("", "priv-a", "pub-a", "10.0.0.2/32", 0),
  107. })
  108. clearWGMigrationHistory(t)
  109. if err := seedWireguardPeersToClients(); err != nil {
  110. t.Fatalf("first run: %v", err)
  111. }
  112. if err := seedWireguardPeersToClients(); err != nil {
  113. t.Fatalf("second run (history gate): %v", err)
  114. }
  115. clearWGMigrationHistory(t)
  116. if err := seedWireguardPeersToClients(); err != nil {
  117. t.Fatalf("third run (linkCount gate): %v", err)
  118. }
  119. var clientCount int64
  120. db.Model(&model.ClientInbound{}).Where("inbound_id = ?", in.Id).Count(&clientCount)
  121. if clientCount != 1 {
  122. t.Fatalf("expected exactly 1 link after repeated runs, got %d", clientCount)
  123. }
  124. }
  125. func TestSeedWireguardPeersToClientsSkipsNonWireguard(t *testing.T) {
  126. initWGMigrationDB(t)
  127. vless := &model.Inbound{UserId: 1, Port: 41001, Protocol: model.VLESS, Tag: "vless-x", Settings: `{"clients":[]}`}
  128. if err := db.Create(vless).Error; err != nil {
  129. t.Fatalf("create vless: %v", err)
  130. }
  131. clearWGMigrationHistory(t)
  132. if err := seedWireguardPeersToClients(); err != nil {
  133. t.Fatalf("seed: %v", err)
  134. }
  135. var linkCount int64
  136. db.Model(&model.ClientInbound{}).Where("inbound_id = ?", vless.Id).Count(&linkCount)
  137. if linkCount != 0 {
  138. t.Fatalf("vless inbound must be untouched, got %d links", linkCount)
  139. }
  140. }
  141. func TestSeedWireguardPeersToClientsMultiplePeers(t *testing.T) {
  142. initWGMigrationDB(t)
  143. in := createWGInbound(t, "wg-multi", 51824, []any{
  144. wgPeer("alpha", "p1", "pub1", "10.0.0.2/32", 0),
  145. wgPeer("beta", "p2", "pub2", "10.0.0.3/32", 0),
  146. })
  147. clearWGMigrationHistory(t)
  148. if err := seedWireguardPeersToClients(); err != nil {
  149. t.Fatalf("seed: %v", err)
  150. }
  151. var links []model.ClientInbound
  152. if err := db.Where("inbound_id = ?", in.Id).Find(&links).Error; err != nil {
  153. t.Fatalf("load links: %v", err)
  154. }
  155. if len(links) != 2 {
  156. t.Fatalf("expected 2 links, got %d", len(links))
  157. }
  158. settings := reloadInboundSettings(t, in.Id)
  159. clients := settings["clients"].([]any)
  160. ips := map[string]bool{}
  161. emails := map[string]bool{}
  162. for _, c := range clients {
  163. m := c.(map[string]any)
  164. emails[m["email"].(string)] = true
  165. ip := m["allowedIPs"].([]any)[0].(string)
  166. ips[ip] = true
  167. }
  168. if len(ips) != 2 || len(emails) != 2 {
  169. t.Fatalf("expected distinct emails/ips, got emails=%v ips=%v", emails, ips)
  170. }
  171. }