websocket.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. package controller
  2. import (
  3. "net/http"
  4. "strings"
  5. "time"
  6. "github.com/google/uuid"
  7. "github.com/mhsanaei/3x-ui/v2/logger"
  8. "github.com/mhsanaei/3x-ui/v2/util/common"
  9. "github.com/mhsanaei/3x-ui/v2/web/session"
  10. "github.com/mhsanaei/3x-ui/v2/web/websocket"
  11. "github.com/gin-gonic/gin"
  12. ws "github.com/gorilla/websocket"
  13. )
  // Timing and size limits applied to every WebSocket connection.
  const (
  	// maxMessageSize caps the size, in bytes, of a message accepted from a peer.
  	maxMessageSize = 512

  	// writeWait bounds how long a single write to the peer may take.
  	writeWait = 10 * time.Second

  	// pongWait is how long to wait for the next pong from the peer.
  	pongWait = 60 * time.Second

  	// pingPeriod is the interval between pings sent to the peer. It is
  	// derived as nine tenths of pongWait so that it stays below pongWait.
  	pingPeriod = pongWait * 9 / 10
  )
  24. var upgrader = ws.Upgrader{
  25. ReadBufferSize: 32768,
  26. WriteBufferSize: 32768,
  27. EnableCompression: true, // Negotiate permessage-deflate compression if the client supports it
  28. CheckOrigin: func(r *http.Request) bool {
  29. // Check origin for security
  30. origin := r.Header.Get("Origin")
  31. if origin == "" {
  32. // Allow connections without Origin header (same-origin requests)
  33. return true
  34. }
  35. // Get the host from the request
  36. host := r.Host
  37. // Extract scheme and host from origin
  38. originURL := origin
  39. // Simple check: origin should match the request host
  40. // This prevents cross-origin WebSocket hijacking
  41. if strings.HasPrefix(originURL, "http://") || strings.HasPrefix(originURL, "https://") {
  42. // Extract host from origin
  43. originHost := strings.TrimPrefix(strings.TrimPrefix(originURL, "http://"), "https://")
  44. if idx := strings.Index(originHost, "/"); idx != -1 {
  45. originHost = originHost[:idx]
  46. }
  47. if idx := strings.Index(originHost, ":"); idx != -1 {
  48. originHost = originHost[:idx]
  49. }
  50. // Compare hosts (without port)
  51. requestHost := host
  52. if idx := strings.Index(requestHost, ":"); idx != -1 {
  53. requestHost = requestHost[:idx]
  54. }
  55. return originHost == requestHost || originHost == "" || requestHost == ""
  56. }
  57. return false
  58. },
  59. }
  60. // WebSocketController handles WebSocket connections for real-time updates
  61. type WebSocketController struct {
  62. BaseController
  63. hub *websocket.Hub
  64. }
  65. // NewWebSocketController creates a new WebSocket controller
  66. func NewWebSocketController(hub *websocket.Hub) *WebSocketController {
  67. return &WebSocketController{
  68. hub: hub,
  69. }
  70. }
  71. // HandleWebSocket handles WebSocket connections
  72. func (w *WebSocketController) HandleWebSocket(c *gin.Context) {
  73. // Check authentication
  74. if !session.IsLogin(c) {
  75. logger.Warningf("Unauthorized WebSocket connection attempt from %s", getRemoteIp(c))
  76. c.AbortWithStatus(http.StatusUnauthorized)
  77. return
  78. }
  79. // Upgrade connection to WebSocket
  80. conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
  81. if err != nil {
  82. logger.Error("Failed to upgrade WebSocket connection:", err)
  83. return
  84. }
  85. // Create client
  86. clientID := uuid.New().String()
  87. client := &websocket.Client{
  88. ID: clientID,
  89. Hub: w.hub,
  90. Send: make(chan []byte, 512), // Increased from 256 to 512 to prevent overflow
  91. Topics: make(map[websocket.MessageType]bool),
  92. }
  93. // Register client
  94. w.hub.Register(client)
  95. logger.Debugf("WebSocket client %s registered from %s", clientID, getRemoteIp(c))
  96. // Start goroutines for reading and writing
  97. go w.writePump(client, conn)
  98. go w.readPump(client, conn)
  99. }
  100. // readPump pumps messages from the WebSocket connection to the hub
  101. func (w *WebSocketController) readPump(client *websocket.Client, conn *ws.Conn) {
  102. defer func() {
  103. if r := common.Recover("WebSocket readPump panic"); r != nil {
  104. logger.Error("WebSocket readPump panic recovered:", r)
  105. }
  106. w.hub.Unregister(client)
  107. conn.Close()
  108. }()
  109. conn.SetReadDeadline(time.Now().Add(pongWait))
  110. conn.SetPongHandler(func(string) error {
  111. conn.SetReadDeadline(time.Now().Add(pongWait))
  112. return nil
  113. })
  114. conn.SetReadLimit(maxMessageSize)
  115. for {
  116. _, message, err := conn.ReadMessage()
  117. if err != nil {
  118. if ws.IsUnexpectedCloseError(err, ws.CloseGoingAway, ws.CloseAbnormalClosure) {
  119. logger.Debugf("WebSocket read error for client %s: %v", client.ID, err)
  120. }
  121. break
  122. }
  123. // Validate message size
  124. if len(message) > maxMessageSize {
  125. logger.Warningf("WebSocket message from client %s exceeds max size: %d bytes", client.ID, len(message))
  126. continue
  127. }
  128. // Handle incoming messages (e.g., subscription requests)
  129. // For now, we'll just log them
  130. logger.Debugf("Received WebSocket message from client %s: %s", client.ID, string(message))
  131. }
  132. }
  133. // writePump pumps messages from the hub to the WebSocket connection
  134. func (w *WebSocketController) writePump(client *websocket.Client, conn *ws.Conn) {
  135. ticker := time.NewTicker(pingPeriod)
  136. defer func() {
  137. if r := common.Recover("WebSocket writePump panic"); r != nil {
  138. logger.Error("WebSocket writePump panic recovered:", r)
  139. }
  140. ticker.Stop()
  141. conn.Close()
  142. }()
  143. for {
  144. select {
  145. case message, ok := <-client.Send:
  146. conn.SetWriteDeadline(time.Now().Add(writeWait))
  147. if !ok {
  148. // Hub closed the channel
  149. conn.WriteMessage(ws.CloseMessage, []byte{})
  150. return
  151. }
  152. // Send each message individually (no batching)
  153. // This ensures each JSON message is sent separately and can be parsed correctly
  154. if err := conn.WriteMessage(ws.TextMessage, message); err != nil {
  155. logger.Debugf("WebSocket write error for client %s: %v", client.ID, err)
  156. return
  157. }
  158. case <-ticker.C:
  159. conn.SetWriteDeadline(time.Now().Add(writeWait))
  160. if err := conn.WriteMessage(ws.PingMessage, nil); err != nil {
  161. logger.Debugf("WebSocket ping error for client %s: %v", client.ID, err)
  162. return
  163. }
  164. }
  165. }
  166. }