1
0

login_limiter.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. package controller
  2. import (
  3. "strings"
  4. "sync"
  5. "time"
  6. )
  // Policy used to build defaultLoginLimiter.
const (
	// loginLimitMaxFailures is how many failed logins for one (ip, username)
	// pair, counted within loginLimitWindow, trigger a block.
	loginLimitMaxFailures = 5

	// loginLimitWindow is how far back failed logins are counted.
	loginLimitWindow time.Duration = 5 * time.Minute

	// loginLimitCooldown is how long a pair stays blocked once the
	// threshold is reached.
	loginLimitCooldown time.Duration = 15 * time.Minute
)
  12. var defaultLoginLimiter = newLoginLimiter(loginLimitMaxFailures, loginLimitWindow, loginLimitCooldown)
  13. type loginLimiter struct {
  14. mu sync.Mutex
  15. now func() time.Time
  16. maxFailures int
  17. window time.Duration
  18. cooldown time.Duration
  19. attempts map[string]*loginLimitRecord
  20. }
  // loginLimitRecord is the limiter state kept for one (ip, username) pair.
type loginLimitRecord struct {
	// failures holds the times of recent failed logins in the order they
	// were registered (registerFailure appends the current time).
	failures []time.Time
	// blockedUntil is when the current block ends; the zero Time means no
	// block is recorded.
	blockedUntil time.Time
}
  25. func newLoginLimiter(maxFailures int, window, cooldown time.Duration) *loginLimiter {
  26. return &loginLimiter{
  27. now: time.Now,
  28. maxFailures: maxFailures,
  29. window: window,
  30. cooldown: cooldown,
  31. attempts: make(map[string]*loginLimitRecord),
  32. }
  33. }
  34. func (l *loginLimiter) allow(ip, username string) (time.Time, bool) {
  35. l.mu.Lock()
  36. defer l.mu.Unlock()
  37. key := loginLimitKey(ip, username)
  38. record := l.attempts[key]
  39. if record == nil {
  40. return time.Time{}, true
  41. }
  42. now := l.now()
  43. if now.Before(record.blockedUntil) {
  44. return record.blockedUntil, false
  45. }
  46. record.blockedUntil = time.Time{}
  47. record.failures = pruneLoginFailures(record.failures, now.Add(-l.window))
  48. if len(record.failures) == 0 {
  49. delete(l.attempts, key)
  50. }
  51. return time.Time{}, true
  52. }
  53. func (l *loginLimiter) registerFailure(ip, username string) (time.Time, bool) {
  54. l.mu.Lock()
  55. defer l.mu.Unlock()
  56. key := loginLimitKey(ip, username)
  57. record := l.attempts[key]
  58. if record == nil {
  59. record = &loginLimitRecord{}
  60. l.attempts[key] = record
  61. }
  62. now := l.now()
  63. record.failures = pruneLoginFailures(record.failures, now.Add(-l.window))
  64. record.failures = append(record.failures, now)
  65. if len(record.failures) >= l.maxFailures {
  66. record.failures = nil
  67. record.blockedUntil = now.Add(l.cooldown)
  68. return record.blockedUntil, true
  69. }
  70. return time.Time{}, false
  71. }
  72. func (l *loginLimiter) registerSuccess(ip, username string) {
  73. l.mu.Lock()
  74. defer l.mu.Unlock()
  75. delete(l.attempts, loginLimitKey(ip, username))
  76. }
  // loginLimitKey builds the attempts-map key for an (ip, username) pair.
// Surrounding whitespace is trimmed from both parts and the username is
// lower-cased, so "Alice " and "alice" share one counter. A NUL byte
// separates the two parts.
func loginLimitKey(ip, username string) string {
	normalizedIP := strings.TrimSpace(ip)
	normalizedUser := strings.ToLower(strings.TrimSpace(username))
	return strings.Join([]string{normalizedIP, normalizedUser}, "\x00")
}
  // pruneLoginFailures drops the leading entries of failures that are strictly
// before cutoff and returns the rest. It stops at the first entry at or after
// cutoff, so it removes every stale entry only when failures is in
// chronological order. The result is a subslice sharing failures' backing
// array.
func pruneLoginFailures(failures []time.Time, cutoff time.Time) []time.Time {
	for i, at := range failures {
		if !at.Before(cutoff) {
			return failures[i:]
		}
	}
	return failures[len(failures):]
}