Browse Source

feat: Enhance host extraction from headers (#2292)

- Refactor SUBController subs and subJsons methods to extract host from X-Forwarded-Host header, falling back to X-Real-IP header and then to the request host if unavailable.
- Update html function to extract host from X-Forwarded-Host header, falling back to X-Real-IP header and then to the request host if unavailable.
- Update DomainValidatorMiddleware to first attempt to extract host from X-Forwarded-Host header, falling back to X-Real-IP header and then to the request host.

Fixes: #2284

Signed-off-by: Ahmad Thoriq Najahi <[email protected]>
Ahmad Thoriq Najahi 9 tháng trước cách đây
mục cha
commit
d070a82b3d
3 tập tin đã thay đổi với 46 bổ sung9 xóa
  1. 24 2
      sub/subController.go
  2. 12 1
      web/controller/util.go
  3. 10 6
      web/middleware/domainValidator.go

+ 24 - 2
sub/subController.go

@@ -54,7 +54,18 @@ func (a *SUBController) initRouter(g *gin.RouterGroup) {
 
 func (a *SUBController) subs(c *gin.Context) {
 	subId := c.Param("subid")
-	host, _, _ := net.SplitHostPort(c.Request.Host)
+	host := c.GetHeader("X-Forwarded-Host")
+	if host == "" {
+		host = c.GetHeader("X-Real-IP")
+	}
+	if host == "" {
+		var err error
+		host, _, err = net.SplitHostPort(c.Request.Host)
+		if err != nil {
+			host = c.Request.Host
+		}
+	}
+	host = host
 	subs, header, err := a.subService.GetSubs(subId, host)
 	if err != nil || len(subs) == 0 {
 		c.String(400, "Error!")
@@ -79,7 +90,18 @@ func (a *SUBController) subs(c *gin.Context) {
 
 func (a *SUBController) subJsons(c *gin.Context) {
 	subId := c.Param("subid")
-	host, _, _ := net.SplitHostPort(c.Request.Host)
+	host := c.GetHeader("X-Forwarded-Host")
+	if host == "" {
+		host = c.GetHeader("X-Real-IP")
+	}
+	if host == "" {
+		var err error
+		host, _, err = net.SplitHostPort(c.Request.Host)
+		if err != nil {
+			host = c.Request.Host
+		}
+	}
+	host = host
 	jsonSub, header, err := a.subJsonService.GetJson(subId, host)
 	if err != nil || len(jsonSub) == 0 {
 		c.String(400, "Error!")

+ 12 - 1
web/controller/util.go

@@ -64,7 +64,18 @@ func html(c *gin.Context, name string, title string, data gin.H) {
 		data = gin.H{}
 	}
 	data["title"] = title
-	data["host"], _, _ = net.SplitHostPort(c.Request.Host)
+	host := c.GetHeader("X-Forwarded-Host")
+	if host == "" {
+		host = c.GetHeader("X-Real-IP")
+	}
+	if host == "" {
+		var err error
+		host, _, err = net.SplitHostPort(c.Request.Host)
+		if err != nil {
+			host = c.Request.Host
+		}
+	}
+	data["host"] = host
 	data["request_uri"] = c.Request.RequestURI
 	data["base_path"] = c.GetString("base_path")
 	c.HTML(http.StatusOK, name, getContext(data))

+ 10 - 6
web/middleware/domainValidator.go

@@ -9,13 +9,17 @@ import (
 
 func DomainValidatorMiddleware(domain string) gin.HandlerFunc {
 	return func(c *gin.Context) {
-		host, _, _ := net.SplitHostPort(c.Request.Host)
-
-		if host != domain {
-			c.AbortWithStatus(http.StatusForbidden)
-			return
+		host := c.GetHeader("X-Forwarded-Host")
+		if host == "" {
+			host = c.GetHeader("X-Real-IP")
 		}
-
+		if host == "" {
+			host, _, _ := net.SplitHostPort(c.Request.Host)
+			if host != domain {
+				c.AbortWithStatus(http.StatusForbidden)
+				return
+			}
 		c.Next()
+		}
 	}
 }