websocket.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. package controller
import (
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
	ws "github.com/gorilla/websocket"
	"github.com/mhsanaei/3x-ui/v2/logger"
	"github.com/mhsanaei/3x-ui/v2/util/common"
	"github.com/mhsanaei/3x-ui/v2/web/session"
	"github.com/mhsanaei/3x-ui/v2/web/websocket"
)
  14. const (
  15. // Time allowed to write a message to the peer
  16. writeWait = 10 * time.Second
  17. // Time allowed to read the next pong message from the peer
  18. pongWait = 60 * time.Second
  19. // Send pings to peer with this period (must be less than pongWait)
  20. pingPeriod = (pongWait * 9) / 10
  21. // Maximum message size allowed from peer
  22. maxMessageSize = 512
  23. )
  24. var upgrader = ws.Upgrader{
  25. ReadBufferSize: 4096, // Increased from 1024 for better performance
  26. WriteBufferSize: 4096, // Increased from 1024 for better performance
  27. CheckOrigin: func(r *http.Request) bool {
  28. // Check origin for security
  29. origin := r.Header.Get("Origin")
  30. if origin == "" {
  31. // Allow connections without Origin header (same-origin requests)
  32. return true
  33. }
  34. // Get the host from the request
  35. host := r.Host
  36. // Extract scheme and host from origin
  37. originURL := origin
  38. // Simple check: origin should match the request host
  39. // This prevents cross-origin WebSocket hijacking
  40. if strings.HasPrefix(originURL, "http://") || strings.HasPrefix(originURL, "https://") {
  41. // Extract host from origin
  42. originHost := strings.TrimPrefix(strings.TrimPrefix(originURL, "http://"), "https://")
  43. if idx := strings.Index(originHost, "/"); idx != -1 {
  44. originHost = originHost[:idx]
  45. }
  46. if idx := strings.Index(originHost, ":"); idx != -1 {
  47. originHost = originHost[:idx]
  48. }
  49. // Compare hosts (without port)
  50. requestHost := host
  51. if idx := strings.Index(requestHost, ":"); idx != -1 {
  52. requestHost = requestHost[:idx]
  53. }
  54. return originHost == requestHost || originHost == "" || requestHost == ""
  55. }
  56. return false
  57. },
  58. }
  59. // WebSocketController handles WebSocket connections for real-time updates
  60. type WebSocketController struct {
  61. BaseController
  62. hub *websocket.Hub
  63. }
  64. // NewWebSocketController creates a new WebSocket controller
  65. func NewWebSocketController(hub *websocket.Hub) *WebSocketController {
  66. return &WebSocketController{
  67. hub: hub,
  68. }
  69. }
  70. // HandleWebSocket handles WebSocket connections
  71. func (w *WebSocketController) HandleWebSocket(c *gin.Context) {
  72. // Check authentication
  73. if !session.IsLogin(c) {
  74. logger.Warningf("Unauthorized WebSocket connection attempt from %s", getRemoteIp(c))
  75. c.AbortWithStatus(http.StatusUnauthorized)
  76. return
  77. }
  78. // Upgrade connection to WebSocket
  79. conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
  80. if err != nil {
  81. logger.Error("Failed to upgrade WebSocket connection:", err)
  82. return
  83. }
  84. // Create client
  85. clientID := uuid.New().String()
  86. client := &websocket.Client{
  87. ID: clientID,
  88. Hub: w.hub,
  89. Send: make(chan []byte, 512), // Increased from 256 to 512 to prevent overflow
  90. Topics: make(map[websocket.MessageType]bool),
  91. }
  92. // Register client
  93. w.hub.Register(client)
  94. logger.Debugf("WebSocket client %s registered from %s", clientID, getRemoteIp(c))
  95. // Start goroutines for reading and writing
  96. go w.writePump(client, conn)
  97. go w.readPump(client, conn)
  98. }
  99. // readPump pumps messages from the WebSocket connection to the hub
  100. func (w *WebSocketController) readPump(client *websocket.Client, conn *ws.Conn) {
  101. defer func() {
  102. if r := common.Recover("WebSocket readPump panic"); r != nil {
  103. logger.Error("WebSocket readPump panic recovered:", r)
  104. }
  105. w.hub.Unregister(client)
  106. conn.Close()
  107. }()
  108. conn.SetReadDeadline(time.Now().Add(pongWait))
  109. conn.SetPongHandler(func(string) error {
  110. conn.SetReadDeadline(time.Now().Add(pongWait))
  111. return nil
  112. })
  113. conn.SetReadLimit(maxMessageSize)
  114. for {
  115. _, message, err := conn.ReadMessage()
  116. if err != nil {
  117. if ws.IsUnexpectedCloseError(err, ws.CloseGoingAway, ws.CloseAbnormalClosure) {
  118. logger.Debugf("WebSocket read error for client %s: %v", client.ID, err)
  119. }
  120. break
  121. }
  122. // Validate message size
  123. if len(message) > maxMessageSize {
  124. logger.Warningf("WebSocket message from client %s exceeds max size: %d bytes", client.ID, len(message))
  125. continue
  126. }
  127. // Handle incoming messages (e.g., subscription requests)
  128. // For now, we'll just log them
  129. logger.Debugf("Received WebSocket message from client %s: %s", client.ID, string(message))
  130. }
  131. }
  132. // writePump pumps messages from the hub to the WebSocket connection
  133. func (w *WebSocketController) writePump(client *websocket.Client, conn *ws.Conn) {
  134. ticker := time.NewTicker(pingPeriod)
  135. defer func() {
  136. if r := common.Recover("WebSocket writePump panic"); r != nil {
  137. logger.Error("WebSocket writePump panic recovered:", r)
  138. }
  139. ticker.Stop()
  140. conn.Close()
  141. }()
  142. for {
  143. select {
  144. case message, ok := <-client.Send:
  145. conn.SetWriteDeadline(time.Now().Add(writeWait))
  146. if !ok {
  147. // Hub closed the channel
  148. conn.WriteMessage(ws.CloseMessage, []byte{})
  149. return
  150. }
  151. // Send each message individually (no batching)
  152. // This ensures each JSON message is sent separately and can be parsed correctly
  153. if err := conn.WriteMessage(ws.TextMessage, message); err != nil {
  154. logger.Debugf("WebSocket write error for client %s: %v", client.ID, err)
  155. return
  156. }
  157. case <-ticker.C:
  158. conn.SetWriteDeadline(time.Now().Add(writeWait))
  159. if err := conn.WriteMessage(ws.PingMessage, nil); err != nil {
  160. logger.Debugf("WebSocket ping error for client %s: %v", client.ID, err)
  161. return
  162. }
  163. }
  164. }
  165. }