hub.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
// Package websocket provides a WebSocket hub for real-time updates and notifications.
  2. package websocket
  3. import (
  4. "context"
  5. "encoding/json"
  6. "runtime"
  7. "sync"
  8. "time"
  9. "github.com/mhsanaei/3x-ui/v2/logger"
  10. )
  11. // MessageType represents the type of WebSocket message
  12. type MessageType string
  13. const (
  14. MessageTypeStatus MessageType = "status" // Server status update
  15. MessageTypeTraffic MessageType = "traffic" // Traffic statistics update
  16. MessageTypeInbounds MessageType = "inbounds" // Inbounds list update
  17. MessageTypeNotification MessageType = "notification" // System notification
  18. MessageTypeXrayState MessageType = "xray_state" // Xray state change
  19. MessageTypeOutbounds MessageType = "outbounds" // Outbounds list update
  20. )
  21. // Message represents a WebSocket message
  22. type Message struct {
  23. Type MessageType `json:"type"`
  24. Payload any `json:"payload"`
  25. Time int64 `json:"time"`
  26. }
  27. // Client represents a WebSocket client connection
  28. type Client struct {
  29. ID string
  30. Send chan []byte
  31. Hub *Hub
  32. Topics map[MessageType]bool // Subscribed topics
  33. }
  34. // Hub maintains the set of active clients and broadcasts messages to them
  35. type Hub struct {
  36. // Registered clients
  37. clients map[*Client]bool
  38. // Inbound messages from clients
  39. broadcast chan []byte
  40. // Register requests from clients
  41. register chan *Client
  42. // Unregister requests from clients
  43. unregister chan *Client
  44. // Mutex for thread-safe operations
  45. mu sync.RWMutex
  46. // Context for graceful shutdown
  47. ctx context.Context
  48. cancel context.CancelFunc
  49. // Worker pool for parallel broadcasting
  50. workerPoolSize int
  51. broadcastWg sync.WaitGroup
  52. }
  53. // NewHub creates a new WebSocket hub
  54. func NewHub() *Hub {
  55. ctx, cancel := context.WithCancel(context.Background())
  56. // Calculate optimal worker pool size (CPU cores * 2, but max 100)
  57. workerPoolSize := runtime.NumCPU() * 2
  58. if workerPoolSize > 100 {
  59. workerPoolSize = 100
  60. }
  61. if workerPoolSize < 10 {
  62. workerPoolSize = 10
  63. }
  64. return &Hub{
  65. clients: make(map[*Client]bool),
  66. broadcast: make(chan []byte, 2048), // Increased from 256 to 2048 for high load
  67. register: make(chan *Client, 100), // Buffered channel for fast registration
  68. unregister: make(chan *Client, 100), // Buffered channel for fast unregistration
  69. ctx: ctx,
  70. cancel: cancel,
  71. workerPoolSize: workerPoolSize,
  72. }
  73. }
  74. // Run starts the hub's main loop
  75. func (h *Hub) Run() {
  76. defer func() {
  77. if r := recover(); r != nil {
  78. logger.Error("WebSocket hub panic recovered:", r)
  79. // Restart the hub loop
  80. go h.Run()
  81. }
  82. }()
  83. for {
  84. select {
  85. case <-h.ctx.Done():
  86. // Graceful shutdown: close all clients
  87. h.mu.Lock()
  88. for client := range h.clients {
  89. // Safely close channel (avoid double close panic)
  90. select {
  91. case _, stillOpen := <-client.Send:
  92. if stillOpen {
  93. close(client.Send)
  94. }
  95. default:
  96. close(client.Send)
  97. }
  98. }
  99. h.clients = make(map[*Client]bool)
  100. h.mu.Unlock()
  101. // Wait for all broadcast workers to finish
  102. h.broadcastWg.Wait()
  103. logger.Info("WebSocket hub stopped gracefully")
  104. return
  105. case client := <-h.register:
  106. if client == nil {
  107. continue
  108. }
  109. h.mu.Lock()
  110. h.clients[client] = true
  111. count := len(h.clients)
  112. h.mu.Unlock()
  113. logger.Debugf("WebSocket client connected: %s (total: %d)", client.ID, count)
  114. case client := <-h.unregister:
  115. if client == nil {
  116. continue
  117. }
  118. h.mu.Lock()
  119. if _, ok := h.clients[client]; ok {
  120. delete(h.clients, client)
  121. // Safely close channel (avoid double close panic)
  122. // Check if channel is already closed by trying to read from it
  123. select {
  124. case _, stillOpen := <-client.Send:
  125. if stillOpen {
  126. // Channel was open and had data, now it's empty, safe to close
  127. close(client.Send)
  128. }
  129. // If stillOpen is false, channel was already closed, do nothing
  130. default:
  131. // Channel is empty and open, safe to close
  132. close(client.Send)
  133. }
  134. }
  135. count := len(h.clients)
  136. h.mu.Unlock()
  137. logger.Debugf("WebSocket client disconnected: %s (total: %d)", client.ID, count)
  138. case message := <-h.broadcast:
  139. if message == nil {
  140. continue
  141. }
  142. // Optimization: quickly copy client list and release lock
  143. h.mu.RLock()
  144. clientCount := len(h.clients)
  145. if clientCount == 0 {
  146. h.mu.RUnlock()
  147. continue
  148. }
  149. // Pre-allocate memory for client list
  150. clients := make([]*Client, 0, clientCount)
  151. for client := range h.clients {
  152. clients = append(clients, client)
  153. }
  154. h.mu.RUnlock()
  155. // Parallel broadcast using worker pool
  156. h.broadcastParallel(clients, message)
  157. }
  158. }
  159. }
  160. // broadcastParallel sends message to all clients in parallel for maximum performance
  161. func (h *Hub) broadcastParallel(clients []*Client, message []byte) {
  162. if len(clients) == 0 {
  163. return
  164. }
  165. // For small number of clients, use simple parallel sending
  166. if len(clients) < h.workerPoolSize {
  167. var wg sync.WaitGroup
  168. for _, client := range clients {
  169. wg.Add(1)
  170. go func(c *Client) {
  171. defer wg.Done()
  172. defer func() {
  173. if r := recover(); r != nil {
  174. // Channel may be closed, safely ignore
  175. logger.Debugf("WebSocket broadcast panic recovered for client %s: %v", c.ID, r)
  176. }
  177. }()
  178. select {
  179. case c.Send <- message:
  180. default:
  181. // Client's send buffer is full, disconnect
  182. logger.Debugf("WebSocket client %s send buffer full, disconnecting", c.ID)
  183. h.Unregister(c)
  184. }
  185. }(client)
  186. }
  187. wg.Wait()
  188. return
  189. }
  190. // For large number of clients, use worker pool for optimal performance
  191. clientChan := make(chan *Client, len(clients))
  192. for _, client := range clients {
  193. clientChan <- client
  194. }
  195. close(clientChan)
  196. // Start workers for parallel processing
  197. h.broadcastWg.Add(h.workerPoolSize)
  198. for i := 0; i < h.workerPoolSize; i++ {
  199. go func() {
  200. defer h.broadcastWg.Done()
  201. for client := range clientChan {
  202. func() {
  203. defer func() {
  204. if r := recover(); r != nil {
  205. // Channel may be closed, safely ignore
  206. logger.Debugf("WebSocket broadcast panic recovered for client %s: %v", client.ID, r)
  207. }
  208. }()
  209. select {
  210. case client.Send <- message:
  211. default:
  212. // Client's send buffer is full, disconnect
  213. logger.Debugf("WebSocket client %s send buffer full, disconnecting", client.ID)
  214. h.Unregister(client)
  215. }
  216. }()
  217. }
  218. }()
  219. }
  220. // Wait for all workers to finish
  221. h.broadcastWg.Wait()
  222. }
  223. // Broadcast sends a message to all connected clients
  224. func (h *Hub) Broadcast(messageType MessageType, payload any) {
  225. if h == nil {
  226. return
  227. }
  228. if payload == nil {
  229. logger.Warning("Attempted to broadcast nil payload")
  230. return
  231. }
  232. msg := Message{
  233. Type: messageType,
  234. Payload: payload,
  235. Time: getCurrentTimestamp(),
  236. }
  237. data, err := json.Marshal(msg)
  238. if err != nil {
  239. logger.Error("Failed to marshal WebSocket message:", err)
  240. return
  241. }
  242. // Limit message size to prevent memory issues
  243. const maxMessageSize = 1024 * 1024 // 1MB
  244. if len(data) > maxMessageSize {
  245. logger.Warningf("WebSocket message too large: %d bytes, dropping", len(data))
  246. return
  247. }
  248. // Non-blocking send with timeout to prevent delays
  249. select {
  250. case h.broadcast <- data:
  251. case <-time.After(100 * time.Millisecond):
  252. logger.Warning("WebSocket broadcast channel is full, dropping message")
  253. case <-h.ctx.Done():
  254. // Hub is shutting down
  255. }
  256. }
  257. // BroadcastToTopic sends a message only to clients subscribed to the specific topic
  258. func (h *Hub) BroadcastToTopic(messageType MessageType, payload any) {
  259. if h == nil {
  260. return
  261. }
  262. if payload == nil {
  263. logger.Warning("Attempted to broadcast nil payload to topic")
  264. return
  265. }
  266. msg := Message{
  267. Type: messageType,
  268. Payload: payload,
  269. Time: getCurrentTimestamp(),
  270. }
  271. data, err := json.Marshal(msg)
  272. if err != nil {
  273. logger.Error("Failed to marshal WebSocket message:", err)
  274. return
  275. }
  276. // Limit message size to prevent memory issues
  277. const maxMessageSize = 1024 * 1024 // 1MB
  278. if len(data) > maxMessageSize {
  279. logger.Warningf("WebSocket message too large: %d bytes, dropping", len(data))
  280. return
  281. }
  282. h.mu.RLock()
  283. // Filter clients by topics and quickly release lock
  284. subscribedClients := make([]*Client, 0)
  285. for client := range h.clients {
  286. if len(client.Topics) == 0 || client.Topics[messageType] {
  287. subscribedClients = append(subscribedClients, client)
  288. }
  289. }
  290. h.mu.RUnlock()
  291. // Parallel send to subscribed clients
  292. if len(subscribedClients) > 0 {
  293. h.broadcastParallel(subscribedClients, data)
  294. }
  295. }
  296. // GetClientCount returns the number of connected clients
  297. func (h *Hub) GetClientCount() int {
  298. h.mu.RLock()
  299. defer h.mu.RUnlock()
  300. return len(h.clients)
  301. }
  302. // Register registers a new client with the hub
  303. func (h *Hub) Register(client *Client) {
  304. if h == nil || client == nil {
  305. return
  306. }
  307. select {
  308. case h.register <- client:
  309. case <-h.ctx.Done():
  310. // Hub is shutting down
  311. }
  312. }
  313. // Unregister unregisters a client from the hub
  314. func (h *Hub) Unregister(client *Client) {
  315. if h == nil || client == nil {
  316. return
  317. }
  318. select {
  319. case h.unregister <- client:
  320. case <-h.ctx.Done():
  321. // Hub is shutting down
  322. }
  323. }
  324. // Stop gracefully stops the hub and closes all connections
  325. func (h *Hub) Stop() {
  326. if h == nil {
  327. return
  328. }
  329. if h.cancel != nil {
  330. h.cancel()
  331. }
  332. }
  333. // getCurrentTimestamp returns current Unix timestamp in milliseconds
  334. func getCurrentTimestamp() int64 {
  335. return time.Now().UnixMilli()
  336. }