1
0

security_test.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. package middleware
  2. import (
  3. "net/http"
  4. "net/http/httptest"
  5. "testing"
  6. "github.com/mhsanaei/3x-ui/v3/web/session"
  7. "github.com/gin-contrib/sessions"
  8. "github.com/gin-contrib/sessions/cookie"
  9. "github.com/gin-gonic/gin"
  10. )
  11. func TestCSRFMiddlewareAllowsSafeMethods(t *testing.T) {
  12. gin.SetMode(gin.TestMode)
  13. router := gin.New()
  14. router.Use(CSRFMiddleware())
  15. router.GET("/safe", func(c *gin.Context) {
  16. c.String(http.StatusOK, "ok")
  17. })
  18. rec := httptest.NewRecorder()
  19. req := httptest.NewRequest(http.MethodGet, "/safe", nil)
  20. router.ServeHTTP(rec, req)
  21. if rec.Code != http.StatusOK {
  22. t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
  23. }
  24. }
  25. func TestCSRFMiddlewareRejectsMissingTokenAndAcceptsValidToken(t *testing.T) {
  26. gin.SetMode(gin.TestMode)
  27. router := gin.New()
  28. store := cookie.NewStore([]byte("01234567890123456789012345678901"))
  29. router.Use(sessions.Sessions("3x-ui", store))
  30. router.GET("/token", func(c *gin.Context) {
  31. token, err := session.EnsureCSRFToken(c)
  32. if err != nil {
  33. t.Fatal(err)
  34. }
  35. c.String(http.StatusOK, token)
  36. })
  37. router.POST("/submit", CSRFMiddleware(), func(c *gin.Context) {
  38. c.String(http.StatusOK, "ok")
  39. })
  40. tokenRec := httptest.NewRecorder()
  41. tokenReq := httptest.NewRequest(http.MethodGet, "/token", nil)
  42. router.ServeHTTP(tokenRec, tokenReq)
  43. if tokenRec.Code != http.StatusOK {
  44. t.Fatalf("token status = %d, want %d", tokenRec.Code, http.StatusOK)
  45. }
  46. cookies := tokenRec.Result().Cookies()
  47. token := tokenRec.Body.String()
  48. missingRec := httptest.NewRecorder()
  49. missingReq := httptest.NewRequest(http.MethodPost, "/submit", nil)
  50. for _, cookie := range cookies {
  51. missingReq.AddCookie(cookie)
  52. }
  53. router.ServeHTTP(missingRec, missingReq)
  54. if missingRec.Code != http.StatusForbidden {
  55. t.Fatalf("missing token status = %d, want %d", missingRec.Code, http.StatusForbidden)
  56. }
  57. validRec := httptest.NewRecorder()
  58. validReq := httptest.NewRequest(http.MethodPost, "/submit", nil)
  59. for _, cookie := range cookies {
  60. validReq.AddCookie(cookie)
  61. }
  62. validReq.Header.Set(session.CSRFHeaderName, token)
  63. router.ServeHTTP(validRec, validReq)
  64. if validRec.Code != http.StatusOK {
  65. t.Fatalf("valid token status = %d, want %d", validRec.Code, http.StatusOK)
  66. }
  67. }
  68. func TestSecurityHeadersMiddleware(t *testing.T) {
  69. gin.SetMode(gin.TestMode)
  70. router := gin.New()
  71. router.Use(SecurityHeadersMiddleware(true))
  72. router.GET("/", func(c *gin.Context) {
  73. c.String(http.StatusOK, "ok")
  74. })
  75. rec := httptest.NewRecorder()
  76. req := httptest.NewRequest(http.MethodGet, "/", nil)
  77. router.ServeHTTP(rec, req)
  78. headers := rec.Result().Header
  79. if got := headers.Get("X-Content-Type-Options"); got != "nosniff" {
  80. t.Fatalf("X-Content-Type-Options = %q", got)
  81. }
  82. if got := headers.Get("X-Frame-Options"); got != "DENY" {
  83. t.Fatalf("X-Frame-Options = %q", got)
  84. }
  85. if got := headers.Get("Referrer-Policy"); got != "no-referrer" {
  86. t.Fatalf("Referrer-Policy = %q", got)
  87. }
  88. if got := headers.Get("Strict-Transport-Security"); got == "" {
  89. t.Fatal("Strict-Transport-Security should be set for direct HTTPS")
  90. }
  91. }
  92. func TestSecurityHeadersMiddlewareSkipsHSTSWithoutDirectHTTPS(t *testing.T) {
  93. gin.SetMode(gin.TestMode)
  94. router := gin.New()
  95. router.Use(SecurityHeadersMiddleware(false))
  96. router.GET("/", func(c *gin.Context) {
  97. c.String(http.StatusOK, "ok")
  98. })
  99. rec := httptest.NewRecorder()
  100. req := httptest.NewRequest(http.MethodGet, "/", nil)
  101. router.ServeHTTP(rec, req)
  102. if got := rec.Result().Header.Get("Strict-Transport-Security"); got != "" {
  103. t.Fatalf("Strict-Transport-Security = %q, want empty", got)
  104. }
  105. }