hub.go 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. // Package websocket provides WebSocket hub for real-time updates and notifications.
  2. package websocket
  3. import (
  4. "context"
  5. "encoding/json"
  6. "runtime"
  7. "sync"
  8. "time"
  9. "github.com/mhsanaei/3x-ui/v2/logger"
  10. )
  // MessageType names the category of a WebSocket message. Its underlying type
  // is string, and it is carried in the "type" field of Message.
  type MessageType string
  13. const (
  14. MessageTypeStatus MessageType = "status" // Server status update
  15. MessageTypeTraffic MessageType = "traffic" // Traffic statistics update
  16. MessageTypeInbounds MessageType = "inbounds" // Inbounds list update
  17. MessageTypeNotification MessageType = "notification" // System notification
  18. MessageTypeXrayState MessageType = "xray_state" // Xray state change
  19. )
  20. // Message represents a WebSocket message
  21. type Message struct {
  22. Type MessageType `json:"type"`
  23. Payload interface{} `json:"payload"`
  24. Time int64 `json:"time"`
  25. }
  26. // Client represents a WebSocket client connection
  27. type Client struct {
  28. ID string
  29. Send chan []byte
  30. Hub *Hub
  31. Topics map[MessageType]bool // Subscribed topics
  32. }
  33. // Hub maintains the set of active clients and broadcasts messages to them
  34. type Hub struct {
  35. // Registered clients
  36. clients map[*Client]bool
  37. // Inbound messages from clients
  38. broadcast chan []byte
  39. // Register requests from clients
  40. register chan *Client
  41. // Unregister requests from clients
  42. unregister chan *Client
  43. // Mutex for thread-safe operations
  44. mu sync.RWMutex
  45. // Context for graceful shutdown
  46. ctx context.Context
  47. cancel context.CancelFunc
  48. // Worker pool for parallel broadcasting
  49. workerPoolSize int
  50. broadcastWg sync.WaitGroup
  51. }
  52. // NewHub creates a new WebSocket hub
  53. func NewHub() *Hub {
  54. ctx, cancel := context.WithCancel(context.Background())
  55. // Calculate optimal worker pool size (CPU cores * 2, but max 100)
  56. workerPoolSize := runtime.NumCPU() * 2
  57. if workerPoolSize > 100 {
  58. workerPoolSize = 100
  59. }
  60. if workerPoolSize < 10 {
  61. workerPoolSize = 10
  62. }
  63. return &Hub{
  64. clients: make(map[*Client]bool),
  65. broadcast: make(chan []byte, 2048), // Increased from 256 to 2048 for high load
  66. register: make(chan *Client, 100), // Buffered channel for fast registration
  67. unregister: make(chan *Client, 100), // Buffered channel for fast unregistration
  68. ctx: ctx,
  69. cancel: cancel,
  70. workerPoolSize: workerPoolSize,
  71. }
  72. }
  73. // Run starts the hub's main loop
  74. func (h *Hub) Run() {
  75. defer func() {
  76. if r := recover(); r != nil {
  77. logger.Error("WebSocket hub panic recovered:", r)
  78. // Restart the hub loop
  79. go h.Run()
  80. }
  81. }()
  82. for {
  83. select {
  84. case <-h.ctx.Done():
  85. // Graceful shutdown: close all clients
  86. h.mu.Lock()
  87. for client := range h.clients {
  88. // Safely close channel (avoid double close panic)
  89. select {
  90. case _, stillOpen := <-client.Send:
  91. if stillOpen {
  92. close(client.Send)
  93. }
  94. default:
  95. close(client.Send)
  96. }
  97. }
  98. h.clients = make(map[*Client]bool)
  99. h.mu.Unlock()
  100. // Wait for all broadcast workers to finish
  101. h.broadcastWg.Wait()
  102. logger.Info("WebSocket hub stopped gracefully")
  103. return
  104. case client := <-h.register:
  105. if client == nil {
  106. continue
  107. }
  108. h.mu.Lock()
  109. h.clients[client] = true
  110. count := len(h.clients)
  111. h.mu.Unlock()
  112. logger.Debugf("WebSocket client connected: %s (total: %d)", client.ID, count)
  113. case client := <-h.unregister:
  114. if client == nil {
  115. continue
  116. }
  117. h.mu.Lock()
  118. if _, ok := h.clients[client]; ok {
  119. delete(h.clients, client)
  120. // Safely close channel (avoid double close panic)
  121. // Check if channel is already closed by trying to read from it
  122. select {
  123. case _, stillOpen := <-client.Send:
  124. if stillOpen {
  125. // Channel was open and had data, now it's empty, safe to close
  126. close(client.Send)
  127. }
  128. // If stillOpen is false, channel was already closed, do nothing
  129. default:
  130. // Channel is empty and open, safe to close
  131. close(client.Send)
  132. }
  133. }
  134. count := len(h.clients)
  135. h.mu.Unlock()
  136. logger.Debugf("WebSocket client disconnected: %s (total: %d)", client.ID, count)
  137. case message := <-h.broadcast:
  138. if message == nil {
  139. continue
  140. }
  141. // Optimization: quickly copy client list and release lock
  142. h.mu.RLock()
  143. clientCount := len(h.clients)
  144. if clientCount == 0 {
  145. h.mu.RUnlock()
  146. continue
  147. }
  148. // Pre-allocate memory for client list
  149. clients := make([]*Client, 0, clientCount)
  150. for client := range h.clients {
  151. clients = append(clients, client)
  152. }
  153. h.mu.RUnlock()
  154. // Parallel broadcast using worker pool
  155. h.broadcastParallel(clients, message)
  156. }
  157. }
  158. }
  159. // broadcastParallel sends message to all clients in parallel for maximum performance
  160. func (h *Hub) broadcastParallel(clients []*Client, message []byte) {
  161. if len(clients) == 0 {
  162. return
  163. }
  164. // For small number of clients, use simple parallel sending
  165. if len(clients) < h.workerPoolSize {
  166. var wg sync.WaitGroup
  167. for _, client := range clients {
  168. wg.Add(1)
  169. go func(c *Client) {
  170. defer wg.Done()
  171. defer func() {
  172. if r := recover(); r != nil {
  173. // Channel may be closed, safely ignore
  174. logger.Debugf("WebSocket broadcast panic recovered for client %s: %v", c.ID, r)
  175. }
  176. }()
  177. select {
  178. case c.Send <- message:
  179. default:
  180. // Client's send buffer is full, disconnect
  181. logger.Debugf("WebSocket client %s send buffer full, disconnecting", c.ID)
  182. h.Unregister(c)
  183. }
  184. }(client)
  185. }
  186. wg.Wait()
  187. return
  188. }
  189. // For large number of clients, use worker pool for optimal performance
  190. clientChan := make(chan *Client, len(clients))
  191. for _, client := range clients {
  192. clientChan <- client
  193. }
  194. close(clientChan)
  195. // Start workers for parallel processing
  196. h.broadcastWg.Add(h.workerPoolSize)
  197. for i := 0; i < h.workerPoolSize; i++ {
  198. go func() {
  199. defer h.broadcastWg.Done()
  200. for client := range clientChan {
  201. func() {
  202. defer func() {
  203. if r := recover(); r != nil {
  204. // Channel may be closed, safely ignore
  205. logger.Debugf("WebSocket broadcast panic recovered for client %s: %v", client.ID, r)
  206. }
  207. }()
  208. select {
  209. case client.Send <- message:
  210. default:
  211. // Client's send buffer is full, disconnect
  212. logger.Debugf("WebSocket client %s send buffer full, disconnecting", client.ID)
  213. h.Unregister(client)
  214. }
  215. }()
  216. }
  217. }()
  218. }
  219. // Wait for all workers to finish
  220. h.broadcastWg.Wait()
  221. }
  222. // Broadcast sends a message to all connected clients
  223. func (h *Hub) Broadcast(messageType MessageType, payload interface{}) {
  224. if h == nil {
  225. return
  226. }
  227. if payload == nil {
  228. logger.Warning("Attempted to broadcast nil payload")
  229. return
  230. }
  231. msg := Message{
  232. Type: messageType,
  233. Payload: payload,
  234. Time: getCurrentTimestamp(),
  235. }
  236. data, err := json.Marshal(msg)
  237. if err != nil {
  238. logger.Error("Failed to marshal WebSocket message:", err)
  239. return
  240. }
  241. // Limit message size to prevent memory issues
  242. const maxMessageSize = 1024 * 1024 // 1MB
  243. if len(data) > maxMessageSize {
  244. logger.Warningf("WebSocket message too large: %d bytes, dropping", len(data))
  245. return
  246. }
  247. // Non-blocking send with timeout to prevent delays
  248. select {
  249. case h.broadcast <- data:
  250. case <-time.After(100 * time.Millisecond):
  251. logger.Warning("WebSocket broadcast channel is full, dropping message")
  252. case <-h.ctx.Done():
  253. // Hub is shutting down
  254. }
  255. }
  256. // BroadcastToTopic sends a message only to clients subscribed to the specific topic
  257. func (h *Hub) BroadcastToTopic(messageType MessageType, payload interface{}) {
  258. if h == nil {
  259. return
  260. }
  261. if payload == nil {
  262. logger.Warning("Attempted to broadcast nil payload to topic")
  263. return
  264. }
  265. msg := Message{
  266. Type: messageType,
  267. Payload: payload,
  268. Time: getCurrentTimestamp(),
  269. }
  270. data, err := json.Marshal(msg)
  271. if err != nil {
  272. logger.Error("Failed to marshal WebSocket message:", err)
  273. return
  274. }
  275. // Limit message size to prevent memory issues
  276. const maxMessageSize = 1024 * 1024 // 1MB
  277. if len(data) > maxMessageSize {
  278. logger.Warningf("WebSocket message too large: %d bytes, dropping", len(data))
  279. return
  280. }
  281. h.mu.RLock()
  282. // Filter clients by topics and quickly release lock
  283. subscribedClients := make([]*Client, 0)
  284. for client := range h.clients {
  285. if len(client.Topics) == 0 || client.Topics[messageType] {
  286. subscribedClients = append(subscribedClients, client)
  287. }
  288. }
  289. h.mu.RUnlock()
  290. // Parallel send to subscribed clients
  291. if len(subscribedClients) > 0 {
  292. h.broadcastParallel(subscribedClients, data)
  293. }
  294. }
  295. // GetClientCount returns the number of connected clients
  296. func (h *Hub) GetClientCount() int {
  297. h.mu.RLock()
  298. defer h.mu.RUnlock()
  299. return len(h.clients)
  300. }
  301. // Register registers a new client with the hub
  302. func (h *Hub) Register(client *Client) {
  303. if h == nil || client == nil {
  304. return
  305. }
  306. select {
  307. case h.register <- client:
  308. case <-h.ctx.Done():
  309. // Hub is shutting down
  310. }
  311. }
  312. // Unregister unregisters a client from the hub
  313. func (h *Hub) Unregister(client *Client) {
  314. if h == nil || client == nil {
  315. return
  316. }
  317. select {
  318. case h.unregister <- client:
  319. case <-h.ctx.Done():
  320. // Hub is shutting down
  321. }
  322. }
  323. // Stop gracefully stops the hub and closes all connections
  324. func (h *Hub) Stop() {
  325. if h == nil {
  326. return
  327. }
  328. if h.cancel != nil {
  329. h.cancel()
  330. }
  331. }
  // getCurrentTimestamp reports the current wall-clock time as the number of
  // milliseconds elapsed since the Unix epoch.
  func getCurrentTimestamp() int64 {
  	now := time.Now()
  	return now.UnixMilli()
  }
  335. }