1
0

websocket.go 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. package controller
  2. import (
  3. "net"
  4. "net/http"
  5. "net/url"
  6. "strings"
  7. "time"
  8. "github.com/google/uuid"
  9. "github.com/mhsanaei/3x-ui/v2/logger"
  10. "github.com/mhsanaei/3x-ui/v2/util/common"
  11. "github.com/mhsanaei/3x-ui/v2/web/session"
  12. "github.com/mhsanaei/3x-ui/v2/web/websocket"
  13. "github.com/gin-gonic/gin"
  14. ws "github.com/gorilla/websocket"
  15. )
  16. const (
  17. writeWait = 10 * time.Second
  18. pongWait = 60 * time.Second
  19. pingPeriod = (pongWait * 9) / 10
  20. clientReadLimit = 512
  21. )
  22. var upgrader = ws.Upgrader{
  23. ReadBufferSize: 32768,
  24. WriteBufferSize: 32768,
  25. EnableCompression: true,
  26. CheckOrigin: checkSameOrigin,
  27. }
  28. // checkSameOrigin allows requests with no Origin header (same-origin or non-browser
  29. // clients) and otherwise requires the Origin hostname to match the request hostname.
  30. // Comparison is case-insensitive (RFC 7230 §2.7.3) and ignores port differences
  31. // (the panel often sits behind a reverse proxy on a different port).
  32. func checkSameOrigin(r *http.Request) bool {
  33. origin := r.Header.Get("Origin")
  34. if origin == "" {
  35. return true
  36. }
  37. u, err := url.Parse(origin)
  38. if err != nil || u.Hostname() == "" {
  39. return false
  40. }
  41. host, _, err := net.SplitHostPort(r.Host)
  42. if err != nil {
  43. // IPv6 literals without a port arrive as "[::1]"; net.SplitHostPort
  44. // fails in that case while url.Hostname() returns the address without
  45. // brackets. Strip them so same-origin checks pass for bare IPv6 hosts.
  46. host = r.Host
  47. if len(host) >= 2 && host[0] == '[' && host[len(host)-1] == ']' {
  48. host = host[1 : len(host)-1]
  49. }
  50. }
  51. return strings.EqualFold(u.Hostname(), host)
  52. }
  53. // WebSocketController handles WebSocket connections for real-time updates.
  54. type WebSocketController struct {
  55. BaseController
  56. hub *websocket.Hub
  57. }
  58. // NewWebSocketController creates a new WebSocket controller.
  59. func NewWebSocketController(hub *websocket.Hub) *WebSocketController {
  60. return &WebSocketController{hub: hub}
  61. }
  62. // HandleWebSocket upgrades the HTTP connection and starts the read/write pumps.
  63. func (w *WebSocketController) HandleWebSocket(c *gin.Context) {
  64. if !session.IsLogin(c) {
  65. logger.Warningf("Unauthorized WebSocket connection attempt from %s", getRemoteIp(c))
  66. c.AbortWithStatus(http.StatusUnauthorized)
  67. return
  68. }
  69. conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
  70. if err != nil {
  71. logger.Error("Failed to upgrade WebSocket connection:", err)
  72. return
  73. }
  74. client := websocket.NewClient(uuid.New().String())
  75. w.hub.Register(client)
  76. logger.Debugf("WebSocket client %s registered from %s", client.ID, getRemoteIp(c))
  77. go w.writePump(client, conn)
  78. go w.readPump(client, conn)
  79. }
  80. // readPump consumes inbound frames so the gorilla deadline/pong machinery keeps
  81. // running. Clients send no commands today; frames are discarded.
  82. func (w *WebSocketController) readPump(client *websocket.Client, conn *ws.Conn) {
  83. defer func() {
  84. if r := common.Recover("WebSocket readPump panic"); r != nil {
  85. logger.Error("WebSocket readPump panic recovered:", r)
  86. }
  87. w.hub.Unregister(client)
  88. conn.Close()
  89. }()
  90. conn.SetReadLimit(clientReadLimit)
  91. conn.SetReadDeadline(time.Now().Add(pongWait))
  92. conn.SetPongHandler(func(string) error {
  93. return conn.SetReadDeadline(time.Now().Add(pongWait))
  94. })
  95. for {
  96. if _, _, err := conn.ReadMessage(); err != nil {
  97. if ws.IsUnexpectedCloseError(err, ws.CloseGoingAway, ws.CloseAbnormalClosure) {
  98. logger.Debugf("WebSocket read error for client %s: %v", client.ID, err)
  99. }
  100. return
  101. }
  102. }
  103. }
  104. // writePump pushes hub messages to the connection and emits keepalive pings.
  105. func (w *WebSocketController) writePump(client *websocket.Client, conn *ws.Conn) {
  106. ticker := time.NewTicker(pingPeriod)
  107. defer func() {
  108. if r := common.Recover("WebSocket writePump panic"); r != nil {
  109. logger.Error("WebSocket writePump panic recovered:", r)
  110. }
  111. ticker.Stop()
  112. conn.Close()
  113. }()
  114. for {
  115. select {
  116. case msg, ok := <-client.Send:
  117. conn.SetWriteDeadline(time.Now().Add(writeWait))
  118. if !ok {
  119. conn.WriteMessage(ws.CloseMessage, []byte{})
  120. return
  121. }
  122. if err := conn.WriteMessage(ws.TextMessage, msg); err != nil {
  123. logger.Debugf("WebSocket write error for client %s: %v", client.ID, err)
  124. return
  125. }
  126. case <-ticker.C:
  127. conn.SetWriteDeadline(time.Now().Add(writeWait))
  128. if err := conn.WriteMessage(ws.PingMessage, nil); err != nil {
  129. logger.Debugf("WebSocket ping error for client %s: %v", client.ID, err)
  130. return
  131. }
  132. }
  133. }
  134. }