hub.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. // Package websocket provides WebSocket hub for real-time updates and notifications.
  2. package websocket
  3. import (
  4. "context"
  5. "encoding/json"
  6. "runtime"
  7. "sync"
  8. "time"
  9. "github.com/mhsanaei/3x-ui/v2/logger"
  10. )
  11. // MessageType represents the type of WebSocket message
  12. type MessageType string
  13. const (
  14. MessageTypeStatus MessageType = "status" // Server status update
  15. MessageTypeTraffic MessageType = "traffic" // Traffic statistics update
  16. MessageTypeInbounds MessageType = "inbounds" // Inbounds list update
  17. MessageTypeNotification MessageType = "notification" // System notification
  18. MessageTypeXrayState MessageType = "xray_state" // Xray state change
  19. MessageTypeOutbounds MessageType = "outbounds" // Outbounds list update
  20. MessageTypeInvalidate MessageType = "invalidate" // Lightweight signal telling frontend to re-fetch data via REST
  21. )
  22. // Message represents a WebSocket message
  23. type Message struct {
  24. Type MessageType `json:"type"`
  25. Payload any `json:"payload"`
  26. Time int64 `json:"time"`
  27. }
  28. // Client represents a WebSocket client connection
  29. type Client struct {
  30. ID string
  31. Send chan []byte
  32. Hub *Hub
  33. Topics map[MessageType]bool // Subscribed topics
  34. closeOnce sync.Once // Ensures Send channel is closed exactly once
  35. }
  36. // Hub maintains the set of active clients and broadcasts messages to them
  37. type Hub struct {
  38. // Registered clients
  39. clients map[*Client]bool
  40. // Inbound messages from clients
  41. broadcast chan []byte
  42. // Register requests from clients
  43. register chan *Client
  44. // Unregister requests from clients
  45. unregister chan *Client
  46. // Mutex for thread-safe operations
  47. mu sync.RWMutex
  48. // Context for graceful shutdown
  49. ctx context.Context
  50. cancel context.CancelFunc
  51. // Worker pool for parallel broadcasting
  52. workerPoolSize int
  53. }
  54. // NewHub creates a new WebSocket hub
  55. func NewHub() *Hub {
  56. ctx, cancel := context.WithCancel(context.Background())
  57. // Calculate optimal worker pool size (CPU cores * 2, but max 100)
  58. workerPoolSize := runtime.NumCPU() * 2
  59. if workerPoolSize > 100 {
  60. workerPoolSize = 100
  61. }
  62. if workerPoolSize < 10 {
  63. workerPoolSize = 10
  64. }
  65. return &Hub{
  66. clients: make(map[*Client]bool),
  67. broadcast: make(chan []byte, 2048), // Increased from 256 to 2048 for high load
  68. register: make(chan *Client, 100), // Buffered channel for fast registration
  69. unregister: make(chan *Client, 100), // Buffered channel for fast unregistration
  70. ctx: ctx,
  71. cancel: cancel,
  72. workerPoolSize: workerPoolSize,
  73. }
  74. }
  75. // Run starts the hub's main loop
  76. func (h *Hub) Run() {
  77. defer func() {
  78. if r := recover(); r != nil {
  79. logger.Error("WebSocket hub panic recovered:", r)
  80. // Restart the hub loop
  81. go h.Run()
  82. }
  83. }()
  84. for {
  85. select {
  86. case <-h.ctx.Done():
  87. // Graceful shutdown: close all clients
  88. h.mu.Lock()
  89. for client := range h.clients {
  90. client.closeOnce.Do(func() {
  91. close(client.Send)
  92. })
  93. }
  94. h.clients = make(map[*Client]bool)
  95. h.mu.Unlock()
  96. logger.Info("WebSocket hub stopped gracefully")
  97. return
  98. case client := <-h.register:
  99. if client == nil {
  100. continue
  101. }
  102. h.mu.Lock()
  103. h.clients[client] = true
  104. count := len(h.clients)
  105. h.mu.Unlock()
  106. logger.Debugf("WebSocket client connected: %s (total: %d)", client.ID, count)
  107. case client := <-h.unregister:
  108. if client == nil {
  109. continue
  110. }
  111. h.mu.Lock()
  112. if _, ok := h.clients[client]; ok {
  113. delete(h.clients, client)
  114. client.closeOnce.Do(func() {
  115. close(client.Send)
  116. })
  117. }
  118. count := len(h.clients)
  119. h.mu.Unlock()
  120. logger.Debugf("WebSocket client disconnected: %s (total: %d)", client.ID, count)
  121. case message := <-h.broadcast:
  122. if message == nil {
  123. continue
  124. }
  125. // Optimization: quickly copy client list and release lock
  126. h.mu.RLock()
  127. clientCount := len(h.clients)
  128. if clientCount == 0 {
  129. h.mu.RUnlock()
  130. continue
  131. }
  132. // Pre-allocate memory for client list
  133. clients := make([]*Client, 0, clientCount)
  134. for client := range h.clients {
  135. clients = append(clients, client)
  136. }
  137. h.mu.RUnlock()
  138. // Parallel broadcast using worker pool
  139. h.broadcastParallel(clients, message)
  140. }
  141. }
  142. }
  143. // broadcastParallel sends message to all clients in parallel for maximum performance
  144. func (h *Hub) broadcastParallel(clients []*Client, message []byte) {
  145. if len(clients) == 0 {
  146. return
  147. }
  148. // For small number of clients, use simple parallel sending
  149. if len(clients) < h.workerPoolSize {
  150. var wg sync.WaitGroup
  151. for _, client := range clients {
  152. wg.Add(1)
  153. go func(c *Client) {
  154. defer wg.Done()
  155. defer func() {
  156. if r := recover(); r != nil {
  157. // Channel may be closed, safely ignore
  158. logger.Debugf("WebSocket broadcast panic recovered for client %s: %v", c.ID, r)
  159. }
  160. }()
  161. select {
  162. case c.Send <- message:
  163. default:
  164. // Client's send buffer is full, disconnect
  165. logger.Debugf("WebSocket client %s send buffer full, disconnecting", c.ID)
  166. h.Unregister(c)
  167. }
  168. }(client)
  169. }
  170. wg.Wait()
  171. return
  172. }
  173. // For large number of clients, use worker pool for optimal performance
  174. clientChan := make(chan *Client, len(clients))
  175. for _, client := range clients {
  176. clientChan <- client
  177. }
  178. close(clientChan)
  179. // Use a local WaitGroup to avoid blocking hub shutdown
  180. var wg sync.WaitGroup
  181. wg.Add(h.workerPoolSize)
  182. for i := 0; i < h.workerPoolSize; i++ {
  183. go func() {
  184. defer wg.Done()
  185. for client := range clientChan {
  186. func() {
  187. defer func() {
  188. if r := recover(); r != nil {
  189. // Channel may be closed, safely ignore
  190. logger.Debugf("WebSocket broadcast panic recovered for client %s: %v", client.ID, r)
  191. }
  192. }()
  193. select {
  194. case client.Send <- message:
  195. default:
  196. // Client's send buffer is full, disconnect
  197. logger.Debugf("WebSocket client %s send buffer full, disconnecting", client.ID)
  198. h.Unregister(client)
  199. }
  200. }()
  201. }
  202. }()
  203. }
  204. // Wait for all workers to finish
  205. wg.Wait()
  206. }
  207. // Broadcast sends a message to all connected clients
  208. func (h *Hub) Broadcast(messageType MessageType, payload any) {
  209. if h == nil {
  210. return
  211. }
  212. if payload == nil {
  213. logger.Warning("Attempted to broadcast nil payload")
  214. return
  215. }
  216. // Skip all work if no clients are connected
  217. if h.GetClientCount() == 0 {
  218. return
  219. }
  220. msg := Message{
  221. Type: messageType,
  222. Payload: payload,
  223. Time: getCurrentTimestamp(),
  224. }
  225. data, err := json.Marshal(msg)
  226. if err != nil {
  227. logger.Error("Failed to marshal WebSocket message:", err)
  228. return
  229. }
  230. // If message exceeds size limit, send a lightweight invalidate notification
  231. // instead of dropping it entirely — the frontend will re-fetch via REST API
  232. const maxMessageSize = 10 * 1024 * 1024 // 10MB
  233. if len(data) > maxMessageSize {
  234. logger.Debugf("WebSocket message too large (%d bytes) for type %s, sending invalidate signal", len(data), messageType)
  235. h.broadcastInvalidate(messageType)
  236. return
  237. }
  238. // Non-blocking send with timeout to prevent delays
  239. select {
  240. case h.broadcast <- data:
  241. case <-time.After(100 * time.Millisecond):
  242. logger.Warning("WebSocket broadcast channel is full, dropping message")
  243. case <-h.ctx.Done():
  244. // Hub is shutting down
  245. }
  246. }
  247. // BroadcastToTopic sends a message only to clients subscribed to the specific topic
  248. func (h *Hub) BroadcastToTopic(messageType MessageType, payload any) {
  249. if h == nil {
  250. return
  251. }
  252. if payload == nil {
  253. logger.Warning("Attempted to broadcast nil payload to topic")
  254. return
  255. }
  256. // Skip all work if no clients are connected
  257. if h.GetClientCount() == 0 {
  258. return
  259. }
  260. msg := Message{
  261. Type: messageType,
  262. Payload: payload,
  263. Time: getCurrentTimestamp(),
  264. }
  265. data, err := json.Marshal(msg)
  266. if err != nil {
  267. logger.Error("Failed to marshal WebSocket message:", err)
  268. return
  269. }
  270. // If message exceeds size limit, send a lightweight invalidate notification
  271. const maxMessageSize = 10 * 1024 * 1024 // 10MB
  272. if len(data) > maxMessageSize {
  273. logger.Debugf("WebSocket message too large (%d bytes) for type %s, sending invalidate signal", len(data), messageType)
  274. h.broadcastInvalidate(messageType)
  275. return
  276. }
  277. h.mu.RLock()
  278. // Filter clients by topics and quickly release lock
  279. subscribedClients := make([]*Client, 0)
  280. for client := range h.clients {
  281. if len(client.Topics) == 0 || client.Topics[messageType] {
  282. subscribedClients = append(subscribedClients, client)
  283. }
  284. }
  285. h.mu.RUnlock()
  286. // Parallel send to subscribed clients
  287. if len(subscribedClients) > 0 {
  288. h.broadcastParallel(subscribedClients, data)
  289. }
  290. }
  291. // GetClientCount returns the number of connected clients
  292. func (h *Hub) GetClientCount() int {
  293. h.mu.RLock()
  294. defer h.mu.RUnlock()
  295. return len(h.clients)
  296. }
  297. // Register registers a new client with the hub
  298. func (h *Hub) Register(client *Client) {
  299. if h == nil || client == nil {
  300. return
  301. }
  302. select {
  303. case h.register <- client:
  304. case <-h.ctx.Done():
  305. // Hub is shutting down
  306. }
  307. }
  308. // Unregister unregisters a client from the hub
  309. func (h *Hub) Unregister(client *Client) {
  310. if h == nil || client == nil {
  311. return
  312. }
  313. select {
  314. case h.unregister <- client:
  315. case <-h.ctx.Done():
  316. // Hub is shutting down
  317. }
  318. }
  319. // Stop gracefully stops the hub and closes all connections
  320. func (h *Hub) Stop() {
  321. if h == nil {
  322. return
  323. }
  324. if h.cancel != nil {
  325. h.cancel()
  326. }
  327. }
  328. // broadcastInvalidate sends a lightweight invalidate message to all clients,
  329. // telling them to re-fetch the specified data type via REST API.
  330. // This is used when the full payload exceeds the WebSocket message size limit.
  331. func (h *Hub) broadcastInvalidate(originalType MessageType) {
  332. msg := Message{
  333. Type: MessageTypeInvalidate,
  334. Payload: map[string]string{"type": string(originalType)},
  335. Time: getCurrentTimestamp(),
  336. }
  337. data, err := json.Marshal(msg)
  338. if err != nil {
  339. logger.Error("Failed to marshal invalidate message:", err)
  340. return
  341. }
  342. // Non-blocking send with timeout
  343. select {
  344. case h.broadcast <- data:
  345. case <-time.After(100 * time.Millisecond):
  346. logger.Warning("WebSocket broadcast channel is full, dropping invalidate message")
  347. case <-h.ctx.Done():
  348. }
  349. }
  350. // getCurrentTimestamp returns current Unix timestamp in milliseconds
  351. func getCurrentTimestamp() int64 {
  352. return time.Now().UnixMilli()
  353. }