csrf.go 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. package session
  2. import (
  3. "crypto/rand"
  4. "crypto/subtle"
  5. "encoding/base64"
  6. "io"
  7. "github.com/gin-contrib/sessions"
  8. "github.com/gin-gonic/gin"
  9. )
  10. const csrfTokenKey = "CSRF_TOKEN"
  11. // CSRFHeaderName is the request header used by browser clients for unsafe methods.
  12. const CSRFHeaderName = "X-CSRF-Token"
  13. // EnsureCSRFToken returns the current session CSRF token or creates one.
  14. func EnsureCSRFToken(c *gin.Context) (string, error) {
  15. s := sessions.Default(c)
  16. if token, ok := s.Get(csrfTokenKey).(string); ok && token != "" {
  17. return token, nil
  18. }
  19. token, err := newCSRFToken()
  20. if err != nil {
  21. return "", err
  22. }
  23. s.Set(csrfTokenKey, token)
  24. return token, s.Save()
  25. }
  26. // ValidateCSRFToken checks the submitted CSRF token against the session token.
  27. func ValidateCSRFToken(c *gin.Context) bool {
  28. s := sessions.Default(c)
  29. expected, ok := s.Get(csrfTokenKey).(string)
  30. if !ok || expected == "" {
  31. return false
  32. }
  33. actual := c.GetHeader(CSRFHeaderName)
  34. if actual == "" {
  35. actual = c.PostForm("_csrf")
  36. }
  37. if len(actual) != len(expected) {
  38. return false
  39. }
  40. return subtle.ConstantTimeCompare([]byte(actual), []byte(expected)) == 1
  41. }
  42. func newCSRFToken() (string, error) {
  43. buf := make([]byte, 32)
  44. if _, err := io.ReadFull(rand.Reader, buf); err != nil {
  45. return "", err
  46. }
  47. return base64.RawURLEncoding.EncodeToString(buf), nil
  48. }