|
@@ -3,6 +3,7 @@ package middleware
|
|
import (
|
|
import (
|
|
"net"
|
|
"net"
|
|
"net/http"
|
|
"net/http"
|
|
|
|
+ "strings"
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
)
|
|
@@ -14,12 +15,17 @@ func DomainValidatorMiddleware(domain string) gin.HandlerFunc {
|
|
host = c.GetHeader("X-Real-IP")
|
|
host = c.GetHeader("X-Real-IP")
|
|
}
|
|
}
|
|
if host == "" {
|
|
if host == "" {
|
|
- host, _, _ := net.SplitHostPort(c.Request.Host)
|
|
|
|
- if host != domain {
|
|
|
|
- c.AbortWithStatus(http.StatusForbidden)
|
|
|
|
- return
|
|
|
|
|
|
+ host = c.Request.Host
|
|
|
|
+ if colonIndex := strings.LastIndex(host, ":"); colonIndex != -1 {
|
|
|
|
+ host, _, _ = net.SplitHostPort(host)
|
|
}
|
|
}
|
|
- c.Next()
|
|
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+ if host != domain {
|
|
|
|
+ c.AbortWithStatus(http.StatusForbidden)
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ c.Next()
|
|
}
|
|
}
|
|
}
|
|
}
|