// custom_geo_test.go (9.0 KB)

  1. package service
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "net/http/httptest"
  8. "os"
  9. "path/filepath"
  10. "testing"
  11. "github.com/mhsanaei/3x-ui/v2/database/model"
  12. )
  13. // disableSSRFCheck disables the SSRF guard for the duration of a test,
  14. // allowing httptest servers on localhost. It restores the original on cleanup.
  15. func disableSSRFCheck(t *testing.T) {
  16. t.Helper()
  17. orig := checkSSRF
  18. checkSSRF = func(_ context.Context, _ string) error { return nil }
  19. t.Cleanup(func() { checkSSRF = orig })
  20. }
  21. func TestNormalizeAliasKey(t *testing.T) {
  22. if got := NormalizeAliasKey("GeoIP-IR"); got != "geoip_ir" {
  23. t.Fatalf("got %q", got)
  24. }
  25. if got := NormalizeAliasKey("a-b_c"); got != "a_b_c" {
  26. t.Fatalf("got %q", got)
  27. }
  28. }
  29. func TestNewCustomGeoService(t *testing.T) {
  30. s := NewCustomGeoService()
  31. if err := s.validateAlias("ok_alias-1"); err != nil {
  32. t.Fatal(err)
  33. }
  34. }
  35. func TestTriggerUpdateAllAllSuccess(t *testing.T) {
  36. s := CustomGeoService{}
  37. s.updateAllGetAll = func() ([]model.CustomGeoResource, error) {
  38. return []model.CustomGeoResource{
  39. {Id: 1, Alias: "a"},
  40. {Id: 2, Alias: "b"},
  41. }, nil
  42. }
  43. s.updateAllApply = func(id int, onStartup bool) (string, error) {
  44. return fmt.Sprintf("geo_%d.dat", id), nil
  45. }
  46. restartCalls := 0
  47. s.updateAllRestart = func() error {
  48. restartCalls++
  49. return nil
  50. }
  51. res, err := s.TriggerUpdateAll()
  52. if err != nil {
  53. t.Fatal(err)
  54. }
  55. if len(res.Succeeded) != 2 || len(res.Failed) != 0 {
  56. t.Fatalf("unexpected result: %+v", res)
  57. }
  58. if restartCalls != 1 {
  59. t.Fatalf("expected 1 restart, got %d", restartCalls)
  60. }
  61. }
  62. func TestTriggerUpdateAllPartialSuccess(t *testing.T) {
  63. s := CustomGeoService{}
  64. s.updateAllGetAll = func() ([]model.CustomGeoResource, error) {
  65. return []model.CustomGeoResource{
  66. {Id: 1, Alias: "ok"},
  67. {Id: 2, Alias: "bad"},
  68. }, nil
  69. }
  70. s.updateAllApply = func(id int, onStartup bool) (string, error) {
  71. if id == 2 {
  72. return "geo_2.dat", ErrCustomGeoDownload
  73. }
  74. return "geo_1.dat", nil
  75. }
  76. restartCalls := 0
  77. s.updateAllRestart = func() error {
  78. restartCalls++
  79. return nil
  80. }
  81. res, err := s.TriggerUpdateAll()
  82. if err != nil {
  83. t.Fatal(err)
  84. }
  85. if len(res.Succeeded) != 1 || len(res.Failed) != 1 {
  86. t.Fatalf("unexpected result: %+v", res)
  87. }
  88. if restartCalls != 1 {
  89. t.Fatalf("expected 1 restart, got %d", restartCalls)
  90. }
  91. }
  92. func TestTriggerUpdateAllAllFailure(t *testing.T) {
  93. s := CustomGeoService{}
  94. s.updateAllGetAll = func() ([]model.CustomGeoResource, error) {
  95. return []model.CustomGeoResource{
  96. {Id: 1, Alias: "a"},
  97. {Id: 2, Alias: "b"},
  98. }, nil
  99. }
  100. s.updateAllApply = func(id int, onStartup bool) (string, error) {
  101. return fmt.Sprintf("geo_%d.dat", id), ErrCustomGeoDownload
  102. }
  103. restartCalls := 0
  104. s.updateAllRestart = func() error {
  105. restartCalls++
  106. return nil
  107. }
  108. res, err := s.TriggerUpdateAll()
  109. if err != nil {
  110. t.Fatal(err)
  111. }
  112. if len(res.Succeeded) != 0 || len(res.Failed) != 2 {
  113. t.Fatalf("unexpected result: %+v", res)
  114. }
  115. if restartCalls != 0 {
  116. t.Fatalf("expected 0 restart, got %d", restartCalls)
  117. }
  118. }
  119. func TestCustomGeoValidateAlias(t *testing.T) {
  120. s := CustomGeoService{}
  121. if err := s.validateAlias(""); !errors.Is(err, ErrCustomGeoAliasRequired) {
  122. t.Fatal("empty alias")
  123. }
  124. if err := s.validateAlias("Bad"); !errors.Is(err, ErrCustomGeoAliasPattern) {
  125. t.Fatal("uppercase")
  126. }
  127. if err := s.validateAlias("a b"); !errors.Is(err, ErrCustomGeoAliasPattern) {
  128. t.Fatal("space")
  129. }
  130. if err := s.validateAlias("ok_alias-1"); err != nil {
  131. t.Fatal(err)
  132. }
  133. if err := s.validateAlias("geoip"); !errors.Is(err, ErrCustomGeoAliasReserved) {
  134. t.Fatal("reserved")
  135. }
  136. }
  137. func TestCustomGeoValidateURL(t *testing.T) {
  138. s := CustomGeoService{}
  139. if _, err := s.sanitizeURL(""); !errors.Is(err, ErrCustomGeoURLRequired) {
  140. t.Fatal("empty")
  141. }
  142. if _, err := s.sanitizeURL("ftp://x"); !errors.Is(err, ErrCustomGeoURLScheme) {
  143. t.Fatal("ftp")
  144. }
  145. if sanitized, err := s.sanitizeURL("https://example.com/a.dat"); err != nil {
  146. t.Fatal(err)
  147. } else if sanitized != "https://example.com/a.dat" {
  148. t.Fatalf("unexpected sanitized URL: %s", sanitized)
  149. }
  150. }
  151. func TestCustomGeoValidateType(t *testing.T) {
  152. s := CustomGeoService{}
  153. if err := s.validateType("geosite"); err != nil {
  154. t.Fatal(err)
  155. }
  156. if err := s.validateType("x"); !errors.Is(err, ErrCustomGeoInvalidType) {
  157. t.Fatal("bad type")
  158. }
  159. }
  160. func TestCustomGeoDownloadToPath(t *testing.T) {
  161. disableSSRFCheck(t)
  162. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  163. w.Header().Set("X-Test", "1")
  164. if r.Header.Get("If-Modified-Since") != "" {
  165. w.WriteHeader(http.StatusNotModified)
  166. return
  167. }
  168. w.WriteHeader(http.StatusOK)
  169. _, _ = w.Write(make([]byte, minDatBytes+1))
  170. }))
  171. defer ts.Close()
  172. dir := t.TempDir()
  173. t.Setenv("XUI_BIN_FOLDER", dir)
  174. dest := filepath.Join(dir, "geoip_t.dat")
  175. s := CustomGeoService{}
  176. skipped, _, err := s.downloadToPath(ts.URL, dest, "")
  177. if err != nil {
  178. t.Fatal(err)
  179. }
  180. if skipped {
  181. t.Fatal("expected download")
  182. }
  183. st, err := os.Stat(dest)
  184. if err != nil || st.Size() < minDatBytes {
  185. t.Fatalf("file %v", err)
  186. }
  187. skipped2, _, err2 := s.downloadToPath(ts.URL, dest, "")
  188. if err2 != nil || !skipped2 {
  189. t.Fatalf("304 expected skipped=%v err=%v", skipped2, err2)
  190. }
  191. }
  192. func TestCustomGeoDownloadToPath_missingLocalSendsNoIMSFromDB(t *testing.T) {
  193. disableSSRFCheck(t)
  194. lm := "Wed, 21 Oct 2015 07:28:00 GMT"
  195. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  196. if r.Header.Get("If-Modified-Since") != "" {
  197. w.WriteHeader(http.StatusNotModified)
  198. return
  199. }
  200. w.Header().Set("Last-Modified", lm)
  201. w.WriteHeader(http.StatusOK)
  202. _, _ = w.Write(make([]byte, minDatBytes+1))
  203. }))
  204. defer ts.Close()
  205. dir := t.TempDir()
  206. t.Setenv("XUI_BIN_FOLDER", dir)
  207. dest := filepath.Join(dir, "geoip_rebuild.dat")
  208. s := CustomGeoService{}
  209. skipped, _, err := s.downloadToPath(ts.URL, dest, lm)
  210. if err != nil {
  211. t.Fatal(err)
  212. }
  213. if skipped {
  214. t.Fatal("must not treat as not-modified when local file is missing")
  215. }
  216. if _, err := os.Stat(dest); err != nil {
  217. t.Fatal("file should exist after container-style rebuild")
  218. }
  219. }
  220. func TestCustomGeoDownloadToPath_repairSkipsConditional(t *testing.T) {
  221. disableSSRFCheck(t)
  222. lm := "Wed, 21 Oct 2015 07:28:00 GMT"
  223. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  224. if r.Header.Get("If-Modified-Since") != "" {
  225. w.WriteHeader(http.StatusNotModified)
  226. return
  227. }
  228. w.Header().Set("Last-Modified", lm)
  229. w.WriteHeader(http.StatusOK)
  230. _, _ = w.Write(make([]byte, minDatBytes+1))
  231. }))
  232. defer ts.Close()
  233. dir := t.TempDir()
  234. t.Setenv("XUI_BIN_FOLDER", dir)
  235. dest := filepath.Join(dir, "geoip_bad.dat")
  236. if err := os.WriteFile(dest, make([]byte, minDatBytes-1), 0o644); err != nil {
  237. t.Fatal(err)
  238. }
  239. s := CustomGeoService{}
  240. skipped, _, err := s.downloadToPath(ts.URL, dest, lm)
  241. if err != nil {
  242. t.Fatal(err)
  243. }
  244. if skipped {
  245. t.Fatal("corrupt local file must be re-downloaded, not 304")
  246. }
  247. st, err := os.Stat(dest)
  248. if err != nil || st.Size() < minDatBytes {
  249. t.Fatalf("file repaired: %v", err)
  250. }
  251. }
  252. func TestCustomGeoFileNameFor(t *testing.T) {
  253. s := CustomGeoService{}
  254. if s.fileNameFor("geoip", "a") != "geoip_a.dat" {
  255. t.Fatal("geoip name")
  256. }
  257. if s.fileNameFor("geosite", "b") != "geosite_b.dat" {
  258. t.Fatal("geosite name")
  259. }
  260. }
  261. func TestLocalDatFileNeedsRepair(t *testing.T) {
  262. dir := t.TempDir()
  263. t.Setenv("XUI_BIN_FOLDER", dir)
  264. if !localDatFileNeedsRepair(filepath.Join(dir, "missing.dat")) {
  265. t.Fatal("missing")
  266. }
  267. smallPath := filepath.Join(dir, "small.dat")
  268. if err := os.WriteFile(smallPath, make([]byte, minDatBytes-1), 0o644); err != nil {
  269. t.Fatal(err)
  270. }
  271. if !localDatFileNeedsRepair(smallPath) {
  272. t.Fatal("small")
  273. }
  274. okPath := filepath.Join(dir, "ok.dat")
  275. if err := os.WriteFile(okPath, make([]byte, minDatBytes), 0o644); err != nil {
  276. t.Fatal(err)
  277. }
  278. if localDatFileNeedsRepair(okPath) {
  279. t.Fatal("ok size")
  280. }
  281. dirPath := filepath.Join(dir, "isdir.dat")
  282. if err := os.Mkdir(dirPath, 0o755); err != nil {
  283. t.Fatal(err)
  284. }
  285. if !localDatFileNeedsRepair(dirPath) {
  286. t.Fatal("dir should need repair")
  287. }
  288. if !CustomGeoLocalFileNeedsRepair(dirPath) {
  289. t.Fatal("exported wrapper dir")
  290. }
  291. if CustomGeoLocalFileNeedsRepair(okPath) {
  292. t.Fatal("exported wrapper ok file")
  293. }
  294. }
  295. func TestProbeCustomGeoURL_HEADOK(t *testing.T) {
  296. disableSSRFCheck(t)
  297. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  298. if r.Method == http.MethodHead {
  299. w.WriteHeader(http.StatusOK)
  300. return
  301. }
  302. w.WriteHeader(http.StatusOK)
  303. }))
  304. defer ts.Close()
  305. if err := probeCustomGeoURL(ts.URL); err != nil {
  306. t.Fatal(err)
  307. }
  308. }
  309. func TestProbeCustomGeoURL_HEAD405GETRange(t *testing.T) {
  310. disableSSRFCheck(t)
  311. ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  312. if r.Method == http.MethodHead {
  313. w.WriteHeader(http.StatusMethodNotAllowed)
  314. return
  315. }
  316. if r.Method == http.MethodGet && r.Header.Get("Range") != "" {
  317. w.WriteHeader(http.StatusPartialContent)
  318. _, _ = w.Write([]byte{0})
  319. return
  320. }
  321. w.WriteHeader(http.StatusBadRequest)
  322. }))
  323. defer ts.Close()
  324. if err := probeCustomGeoURL(ts.URL); err != nil {
  325. t.Fatal(err)
  326. }
  327. }