1
0

client_wireguard_test.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. package service
  2. import (
  3. "testing"
  4. "github.com/mhsanaei/3x-ui/v3/internal/database/model"
  5. wgutil "github.com/mhsanaei/3x-ui/v3/internal/util/wireguard"
  6. )
  7. func TestAllocateWireguardAddress(t *testing.T) {
  8. tests := []struct {
  9. name string
  10. used []string
  11. base string
  12. want string
  13. err bool
  14. }{
  15. {name: "empty starts at .2", used: nil, base: "10.0.0.0/24", want: "10.0.0.2/32"},
  16. {name: "skips used", used: []string{"10.0.0.2/32"}, base: "10.0.0.0/24", want: "10.0.0.3/32"},
  17. {name: "fills gap", used: []string{"10.0.0.3/32", "10.0.0.4/32"}, base: "10.0.0.0/24", want: "10.0.0.2/32"},
  18. {name: "ignores catch-all", used: []string{"0.0.0.0/0", "::/0"}, base: "10.0.0.0/24", want: "10.0.0.2/32"},
  19. {name: "default base when empty", used: nil, base: "", want: "10.0.0.2/32"},
  20. {name: "exhausted /30", used: []string{"10.9.0.2/32", "10.9.0.3/32"}, base: "10.9.0.0/30", err: true},
  21. }
  22. for _, tt := range tests {
  23. t.Run(tt.name, func(t *testing.T) {
  24. got, err := allocateWireguardAddress(tt.used, tt.base)
  25. if tt.err {
  26. if err == nil {
  27. t.Fatalf("expected error, got %q", got)
  28. }
  29. return
  30. }
  31. if err != nil {
  32. t.Fatalf("unexpected error: %v", err)
  33. }
  34. if got != tt.want {
  35. t.Fatalf("got %q, want %q", got, tt.want)
  36. }
  37. })
  38. }
  39. }
  40. func TestDefaultWireguardClientsGeneratesKeypair(t *testing.T) {
  41. clients := []model.Client{{Email: "a@wg"}}
  42. ifaces := []any{map[string]any{"email": "a@wg"}}
  43. if err := defaultWireguardClients(nil, clients, ifaces); err != nil {
  44. t.Fatalf("defaultWireguardClients: %v", err)
  45. }
  46. c := clients[0]
  47. if c.PrivateKey == "" || c.PublicKey == "" {
  48. t.Fatalf("keypair not generated: priv=%q pub=%q", c.PrivateKey, c.PublicKey)
  49. }
  50. if len(c.AllowedIPs) != 1 || c.AllowedIPs[0] != "10.0.0.2/32" {
  51. t.Fatalf("allowedIPs not allocated: %v", c.AllowedIPs)
  52. }
  53. m := ifaces[0].(map[string]any)
  54. if m["privateKey"] != c.PrivateKey || m["publicKey"] != c.PublicKey {
  55. t.Fatalf("interface map not updated: %v", m)
  56. }
  57. }
  58. func TestDefaultWireguardClientsDerivesPublicKey(t *testing.T) {
  59. priv, _, err := wgutil.GenerateWireguardKeypair()
  60. if err != nil {
  61. t.Fatal(err)
  62. }
  63. wantPub, err := wgutil.PublicKeyFromPrivate(priv)
  64. if err != nil {
  65. t.Fatal(err)
  66. }
  67. clients := []model.Client{{Email: "b@wg", PrivateKey: priv}}
  68. ifaces := []any{map[string]any{"email": "b@wg"}}
  69. if err := defaultWireguardClients(nil, clients, ifaces); err != nil {
  70. t.Fatalf("defaultWireguardClients: %v", err)
  71. }
  72. if clients[0].PublicKey != wantPub {
  73. t.Fatalf("derived public key = %q, want %q", clients[0].PublicKey, wantPub)
  74. }
  75. }
  76. func TestDefaultWireguardClientsPreservesProvided(t *testing.T) {
  77. clients := []model.Client{{
  78. Email: "c@wg",
  79. PrivateKey: "keep-priv",
  80. PublicKey: "keep-pub",
  81. AllowedIPs: []string{"10.0.0.50/32"},
  82. }}
  83. ifaces := []any{map[string]any{"email": "c@wg"}}
  84. if err := defaultWireguardClients(nil, clients, ifaces); err != nil {
  85. t.Fatalf("defaultWireguardClients: %v", err)
  86. }
  87. if clients[0].PrivateKey != "keep-priv" || clients[0].PublicKey != "keep-pub" {
  88. t.Fatalf("provided keys were rotated: %+v", clients[0])
  89. }
  90. if clients[0].AllowedIPs[0] != "10.0.0.50/32" {
  91. t.Fatalf("provided allowedIPs changed: %v", clients[0].AllowedIPs)
  92. }
  93. }
  94. func TestWireguardAllocationBase(t *testing.T) {
  95. tests := []struct {
  96. name string
  97. used []string
  98. fallback string
  99. want string
  100. }{
  101. {name: "no peers uses fallback", used: nil, fallback: "10.0.0.0/24", want: "10.0.0.0/24"},
  102. {name: "derives subnet from existing peer", used: []string{"172.16.0.2/32"}, fallback: "10.0.0.0/24", want: "172.16.0.0/24"},
  103. {name: "skips catch-all and ipv6", used: []string{"0.0.0.0/0", "::/0", "fd00::2/128", "192.168.5.7/32"}, fallback: "10.0.0.0/24", want: "192.168.5.0/24"},
  104. }
  105. for _, tt := range tests {
  106. t.Run(tt.name, func(t *testing.T) {
  107. if got := wireguardAllocationBase(tt.used, tt.fallback); got != tt.want {
  108. t.Fatalf("got %q, want %q", got, tt.want)
  109. }
  110. })
  111. }
  112. }
  113. func TestDefaultWireguardClientsHonorsExistingSubnet(t *testing.T) {
  114. existing := []model.Client{{Email: "old@wg", AllowedIPs: []string{"172.16.0.2/32"}}}
  115. clients := []model.Client{{Email: "new@wg"}}
  116. ifaces := []any{map[string]any{"email": "new@wg"}}
  117. if err := defaultWireguardClients(existing, clients, ifaces); err != nil {
  118. t.Fatalf("defaultWireguardClients: %v", err)
  119. }
  120. if got := clients[0].AllowedIPs[0]; got != "172.16.0.3/32" {
  121. t.Fatalf("new client address = %q, want 172.16.0.3/32 in existing subnet", got)
  122. }
  123. }
  124. func TestDefaultWireguardClientsAllocatesDistinctIPs(t *testing.T) {
  125. clients := []model.Client{{Email: "x@wg"}, {Email: "y@wg"}}
  126. ifaces := []any{map[string]any{"email": "x@wg"}, map[string]any{"email": "y@wg"}}
  127. if err := defaultWireguardClients(nil, clients, ifaces); err != nil {
  128. t.Fatalf("defaultWireguardClients: %v", err)
  129. }
  130. if clients[0].AllowedIPs[0] == clients[1].AllowedIPs[0] {
  131. t.Fatalf("two clients got the same address: %v", clients[0].AllowedIPs)
  132. }
  133. }